# Copyright 2024 EPFL and Apple Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import random import copy from functools import partial from typing import Any, Dict, Optional, Tuple, Union import torch from einops import rearrange, repeat from torch import nn import torch.nn.functional as F from fourm.utils.timm.registry import register_model from huggingface_hub import PyTorchModelHubMixin from .fm_utils import Block, DecoderBlock, LayerNorm from fourm.data.modality_info import MODALITY_INFO # Model definitions __all__ = [ # GELU models 'fm_tiny_6e_6d_gelu', 'fm_small_8e_8d_gelu', 'fm_base_12e_12d_gelu', 'fm_large_24e_24d_gelu', 'fm_xlarge_24e_24d_gelu', # SwiGLU models 'fm_tiny_6e_6d_swiglu_nobias', 'fm_small_8e_8d_swiglu_nobias', 'fm_base_12e_12d_swiglu_nobias', 'fm_large_24e_24d_swiglu_nobias', 'fm_xlarge_24e_24d_swiglu_nobias', # SwiGLU + QKNorm models 'fm_base_12e_12d_swiglu_qknorm_nobias', 'fm_large_24e_24d_swiglu_qknorm_nobias', 'fm_xlarge_24e_24d_swiglu_qknorm_nobias', ] class FourM(nn.Module): """4M model. Args: encoder_embeddings: Dict of encoder embedding modules. decoder_embeddings: Dict of decoder embedding modules. modality_info: Dict containing modality information. dim: Embedding dimension. encoder_depth: Number of encoder blocks. decoder_depth: Number of decoder blocks. num_heads: Number of attention heads. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value projections. proj_bias: If True, add a learnable bias to the last projection of the attention block. mlp_bias: If True, add a learnable bias to linear layers in the MLP / feed-forward. drop_path_rate_encoder: Stochastic depth rate for encoder. drop_path_rate_decoder: Stochastic depth rate for decoder. shared_drop_path: If True, shares drop path between encoder and decoder. act_layer: Activation layer to be used. norm_layer: Normalization layer to be used. gated_mlp: If True, make the feedforward gated (e.g., SwiGLU). qk_norm: If True, applies normalization to queries and keys (QKNorm). decoder_causal_mask: If True, decoder will use a causal mask for all tokens. decoder_sep_mask: If True, decoder attention is restricted to within each modality only. num_register_tokens: Number of register tokens. use_act_checkpoint: If True, use activation checkpoint for each block. """ def __init__(self, encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], modality_info: Dict[str, Any], dim: int = 768, encoder_depth: int = 12, decoder_depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, proj_bias: bool = True, mlp_bias: bool = True, drop_path_rate_encoder: float = 0.0, drop_path_rate_decoder: float = 0.0, shared_drop_path: bool = False, act_layer: nn.Module = nn.GELU, norm_layer: Union[partial, nn.Module] = partial(LayerNorm, eps=1e-6), gated_mlp: bool = False, # Make the feedforward gated for e.g. 
SwiGLU qk_norm: bool = False, decoder_causal_mask: bool = False, decoder_sep_mask: bool = True, num_register_tokens: int = 0, use_act_checkpoint: bool = False, share_modality_embeddings: bool = True, ): super().__init__() self.modality_info = modality_info self.dim = dim self.decoder_causal_mask = decoder_causal_mask self.decoder_sep_mask = decoder_sep_mask self.init_std = 0.02 self.use_act_checkpoint = use_act_checkpoint self.num_register_tokens = num_register_tokens # Encoder embeddings & init self.encoder_modalities = set(encoder_embeddings.keys()) for emb in encoder_embeddings.values(): emb.init(dim_tokens=dim, init_std=self.init_std) self.encoder_embeddings = nn.ModuleDict(encoder_embeddings) # Decoder embeddings & init self.decoder_modalities = set(decoder_embeddings.keys()) for emb in decoder_embeddings.values(): emb.init(dim_tokens=dim, init_std=self.init_std) self.decoder_embeddings = nn.ModuleDict(decoder_embeddings) # Share modality embeddings across the encoder and decoder embedding modules if share_modality_embeddings: self.share_modality_embeddings() ## Transformer encoder if shared_drop_path: dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth + decoder_depth)][:encoder_depth] else: dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth)] # stochastic depth decay rule self.encoder = nn.ModuleList([ Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias, drop_path=dpr_encoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm) for i in range(encoder_depth) ]) self.encoder_norm = norm_layer(dim) ## Transformer decoder if shared_drop_path: dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, encoder_depth + decoder_depth)][encoder_depth:] else: dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, decoder_depth)] # stochastic depth decay rule # Projection of encoder tokens before adding the embeddings again self.decoder_proj_context = nn.Linear(dim, dim) self.decoder = nn.ModuleList([ DecoderBlock(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias, drop_path=dpr_decoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm) for i in range(decoder_depth) ]) self.decoder_norm = norm_layer(dim) self.mask_token = nn.Parameter(torch.zeros(1, 1, dim)) nn.init.normal_(self.mask_token, std=self.init_std) # Additional register tokens that can be used by the encoder during fine-tuning if self.num_register_tokens > 0: self.register_tokens = nn.Parameter(torch.zeros(1, self.num_register_tokens, dim)) nn.init.normal_(self.register_tokens, std=self.init_std) else: self.register_tokens = None # Weight init self.init_weights() def share_modality_embeddings(self): """Share modality embeddings across the encoder and decoder embedding modules.""" shared_modalities = self.encoder_modalities & self.decoder_modalities for mod in shared_modalities: self.decoder_embeddings[mod].mod_emb = self.encoder_embeddings[mod].mod_emb def init_weights(self): """Weight initialization following MAE's initialization scheme""" for name, m in self.named_modules(): # Skipping tokenizers to avoid reinitializing them if "tokenizer" in name: continue # Linear elif isinstance(m, nn.Linear): if 'qkv' in name: # treat the weights of Q, K, V separately val = math.sqrt(6. 
/ float(m.weight.shape[0] // 3 + m.weight.shape[1])) nn.init.uniform_(m.weight, -val, val) elif 'kv' in name: # treat the weights of K, V separately val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1])) nn.init.uniform_(m.weight, -val, val) else: nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) # LayerNorm elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm): nn.init.constant_(m.weight, 1.0) if m.bias is not None: nn.init.constant_(m.bias, 0) # Embedding elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=self.init_std) # Conv2d elif isinstance(m, nn.Conv2d): if '.proj' in name: # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d) w = m.weight.data nn.init.xavier_uniform_(w.view([w.shape[0], -1])) def get_num_layers_encoder(self): return len(self.encoder) def get_num_layers_decoder(self): return len(self.decoder) def get_num_layers(self): return self.get_num_layers_encoder() + self.get_num_layers_decoder() @torch.jit.ignore def no_weight_decay(self): no_wd_set = set() for mod, emb_module in self.encoder_embeddings.items(): if hasattr(emb_module, 'no_weight_decay'): to_skip = emb_module.no_weight_decay() to_skip = set([f'encoder_embeddings.{mod}.{name}' for name in to_skip]) no_wd_set = no_wd_set | to_skip for mod, emb_module in self.decoder_embeddings.items(): if hasattr(emb_module, 'no_weight_decay'): to_skip = emb_module.no_weight_decay() to_skip = set([f'decoder_embeddings.{mod}.{name}' for name in to_skip]) no_wd_set = no_wd_set | to_skip return no_wd_set def cat_encoder_tensors(self, mod_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor]: """Concatenate encoder tensors from different modalities. Args: mod_dict (dict): A dictionary containing information for each modality. Expected keys for each modality are 'x' (input tokens), 'emb' (embeddings), 'input_mask', etc. Returns: tuple: - encoder_tokens_all (torch.Tensor): Concatenated encoder tokens from all modalities. Shape (B, O, D) where O is the total number of all encoder tokens. - emb_all (torch.Tensor): Concatenated encoder embeddings from all modalities. Shape (B, O, D) - encoder_mask_all (torch.Tensor): Concatenated boolean masks indicating which tokens are part of the encoder input (set to 0 for valid tokens, 1 otherwise). Shape (B, O) - mod_mask_all (torch.Tensor): Concatenated integer mask marking the modality type for each encoder token. Shape (B, O) """ encoder_tokens_all = [] emb_all = [] encoder_mask_all = [] mod_mask_all = [] for mod, d in mod_dict.items(): encoder_tokens_all.append(d['x']) emb_all.append(d['emb']) encoder_mask_all.append(d['input_mask']) mod_mask_all.append(torch.full_like(d['input_mask'], self.modality_info[mod]['id'], dtype=torch.int16)) encoder_tokens_all = torch.cat(encoder_tokens_all, dim=1) emb_all = torch.cat(emb_all, dim=1) encoder_mask_all = torch.cat(encoder_mask_all, dim=1) mod_mask_all = torch.cat(mod_mask_all, dim=1) return encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all def cat_decoder_tensors(self, mod_dict: Dict[str, Dict[str, torch.Tensor]]) -> Tuple[torch.Tensor]: """Concatenate decoder tensors from different modalities. Args: mod_dict (dict): A dictionary containing information for each modality. Expected keys for each modality include 'x' (input tokens), 'ids' (target IDs), 'emb' (embeddings), 'target_mask', 'decoder_attention_mask', etc. Returns: tuple: - decoder_tokens_all (torch.Tensor): Concatenated decoder tokens from all modalities. 
                  Shape (B, P, D) where P is the total number of all decoder tokens.
                - emb_all (torch.Tensor): Concatenated decoder embeddings from all modalities. Shape (B, P, D)
                - decoder_mask_all (torch.Tensor): Concatenated boolean masks indicating which tokens are part of
                  the decoder input / target (set to 0 for valid tokens, 1 otherwise). Shape (B, P)
                - target_ids_all (torch.Tensor): Concatenated target IDs from all modalities. Shape (B, P)
                - attention_mask_all (torch.Tensor): Concatenated attention masks in compressed format; these need
                  to be passed to adapt_decoder_attention_mask() to obtain the final attention mask. Shape (B, P)
                - mod_mask_all (torch.Tensor): Concatenated integer mask marking the modality type for each decoder
                  token. Shape (B, P)
        """
        decoder_tokens_all = []
        target_ids_all = []
        emb_all = []
        decoder_mask_all = []
        attention_mask_all = []
        mod_mask_all = []

        # Shuffle the order in which modalities are provided (useful for the modality causal mask).
        # random.sample requires a sequence, so the dict items view is materialized as a list (required on Python 3.11+).
        mod_dict = {mod: d for mod, d in random.sample(list(mod_dict.items()), len(mod_dict))}

        for mod, d in mod_dict.items():
            if self.modality_info[mod]['type'] in ['seq', 'seq_emb', 'seq_token']:
                # Important: This makes the assumption that the target sequence appears sequentially
                # before sorting / gathering
                decoder_tokens_all.append(d['x'][:, :-1])
                target_ids_all.append(d['ids'][:, 1:])  # Shifted left
                emb_all.append(d['emb'][:, :-1])
                # Logical or with left shifting removes the last unmasked position
                decoder_mask_all.append(torch.logical_or(d['target_mask'][:, 1:], d['target_mask'][:, :-1]))
                # Add attention mask ids
                attention_mask_all.append(d['decoder_attention_mask'][:, :-1])
                mod_mask_all.append(torch.full_like(d['ids'][:, :-1], self.modality_info[mod]['id'], dtype=torch.int16))
            else:
                # Important: For 2d / image modalities, the decoder input tokens are replaced by the mask token
                decoder_tokens_all.append(torch.zeros_like(d['x']) + self.mask_token)  # Replace x by mask token
                target_ids_all.append(d['ids'])
                emb_all.append(d['emb'])
                decoder_mask_all.append(d['target_mask'])
                attention_mask_all.append(d['decoder_attention_mask'])
                mod_mask_all.append(torch.full_like(d['ids'], self.modality_info[mod]['id'], dtype=torch.int16))

        decoder_tokens_all = torch.cat(decoder_tokens_all, dim=1)
        emb_all = torch.cat(emb_all, dim=1)
        decoder_mask_all = torch.cat(decoder_mask_all, dim=1)
        target_ids_all = torch.cat(target_ids_all, dim=1)
        attention_mask_all = torch.cat(attention_mask_all, dim=1)
        mod_mask_all = torch.cat(mod_mask_all, dim=1)

        return decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, attention_mask_all, mod_mask_all

    def forward_mask_encoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_encoder_tokens: int) -> Tuple[torch.Tensor]:
        """Concatenates and masks encoder tensors based on the provided modality information.

        This function consolidates encoder tokens from multiple modalities, then selects a specified number of them
        based on modality information (i.e. masking).

        Args:
            mod_dict (dict): Dictionary containing tensors for different modalities.
                It is expected to have keys for each modality and values containing the modalities' associated tensors.
            num_encoder_tokens (int): Number of encoder tokens to retain after masking.

        Returns:
            tuple:
                - encoder_tokens (torch.Tensor): Selected encoder tokens from all modalities.
                  Shape (B, N, D) where N is the number of selected encoder tokens.
                - encoder_emb (torch.Tensor): Corresponding embeddings for encoder tokens.
Shape (B, N, D) - encoder_mask (torch.Tensor): A boolean mask indicating which encoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N) - mod_mask (torch.Tensor): An integer mask marking the modality type for each encoder token (with -1 indicating unassigned pad tokens). Shape (B, N) Notes: - If `num_register_tokens` is set and greater than 0, register tokens are added at the beginning of the sequence. """ B = list(mod_dict.values())[0]['tensor'].shape[0] encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all = self.cat_encoder_tensors(mod_dict) # Add arange multiplied by small constant to mask so they get sorted in a deterministic way mask_arange = torch.arange(encoder_mask_all.shape[1], device=encoder_mask_all.device).unsqueeze(0) * 1e-6 ids_shuffle = torch.argsort(encoder_mask_all + mask_arange, dim=1) # ids_restore = torch.argsort(ids_shuffle, dim=1) ids_keep = ids_shuffle[:, :num_encoder_tokens] encoder_tokens = torch.gather(encoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=encoder_tokens_all.shape[2])) encoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2])) encoder_mask = torch.gather(encoder_mask_all, dim=1, index=ids_keep) mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep) if self.num_register_tokens > 0: register_tokens = repeat(self.register_tokens, '() n d -> b n d', b=B) # We add register tokens at the beginning of the sequence encoder_tokens = torch.cat([register_tokens, encoder_tokens], dim=1) encoder_emb = torch.cat([torch.zeros_like(register_tokens), encoder_emb], dim=1) encoder_mask = torch.cat([torch.zeros((B, register_tokens.shape[1]), dtype=torch.bool, device=encoder_mask.device), encoder_mask], dim=1) mod_mask = torch.cat([torch.full((B, register_tokens.shape[1]), -1, dtype=torch.int16, device=mod_mask.device), mod_mask], dim=1) encoder_tokens[encoder_mask] = 0. encoder_emb[encoder_mask] = 0. mod_mask[encoder_mask] = -1 # Mask could be of shape 'b n1 n2' but not needed for masked_fill # This means this mask can then be re-used for decoder cross-attention encoder_mask = rearrange(encoder_mask, 'b n2 -> b 1 n2') return encoder_tokens, encoder_emb, encoder_mask, mod_mask def forward_mask_decoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_decoder_tokens: int) -> Tuple[torch.Tensor]: """Concatenates and mask decoder tensors based on provided modality information. This function consolidates decoder tokens from multiple modalities, selects a specified number of them based on modality information, and applies appropriate masking. Args: mod_dict (dict): Dictionary containing tensors for different modalities. It is expected to have keys for each modality and values containing the modalities' associated tensors. num_decoder_tokens (int): Number of decoder tokens to retain after masking. Returns: tuple: - decoder_tokens (torch.Tensor): Selected decoder tokens from all modalities. Shape (B, M, D) where M is the number of selected decoder tokens. - decoder_emb (torch.Tensor): Corresponding embeddings for decoder tokens. Shape (B, M, D) - decoder_mask (torch.Tensor): A boolean mask indicating which decoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, M) - target_ids (torch.Tensor): IDs of the target tokens corresponding to the decoder tokens. Shape (B, M) - decoder_attention_mask (torch.Tensor): Mask for the decoder self-attention layers. 
                  Shape (B, M, M)
                - mod_mask (torch.Tensor): An integer mask marking the modality type for each decoder token
                  (with -1 indicating unassigned pad tokens). Shape (B, M)
        """
        # decoder_mask and target_mask are equivalent; the latter is renamed here to harmonize with forward_mask_encoder
        decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, decoder_attention_mask_all, mod_mask_all = self.cat_decoder_tensors(mod_dict)

        # Add an arange multiplied by a small constant to the mask so that ties are broken deterministically
        # when sorting (see also the token-selection sketch at the end of this file)
        mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_decoder_tokens]

        decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
        decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
        target_ids = torch.gather(target_ids_all, dim=1, index=ids_keep)
        decoder_attention_mask = torch.gather(decoder_attention_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)

        decoder_tokens[decoder_mask] = 0.
        decoder_emb[decoder_mask] = 0.
        target_ids[decoder_mask] = 0
        decoder_attention_mask = self.adapt_decoder_attention_mask(decoder_attention_mask, mod_mask)
        mod_mask[decoder_mask] = -1

        # This means this mask can then be re-used for decoder cross-attention
        decoder_mask = rearrange(decoder_mask, 'b n2 -> b 1 n2')

        return decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, mod_mask

    def adapt_decoder_attention_mask(self, decoder_attention_mask: torch.Tensor, mod_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Transforms the compressed decoder attention mask to a full attention mask based on the specified constraints.

        Args:
            decoder_attention_mask (torch.Tensor): Initial attention mask indicating attention constraints.
                Shape (B, M) where M is the number of decoder tokens.
            mod_mask (torch.Tensor, optional): Modality mask used to separate the attention masks per modality. Shape (B, M)

        Returns:
            torch.Tensor: Adapted attention mask. Shape (B, M, M) where M is the number of decoder tokens.
        """
        B, N = decoder_attention_mask.shape

        if self.decoder_causal_mask:
            # For causal mode, tokens can only attend to preceding tokens and themselves.
            causal_mask = torch.ones((N, N), dtype=torch.bool, device=decoder_attention_mask.device).triu(1)
            causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B)
            adapted_attention_mask = causal_mask
        else:
            # Cumulatively sum the attention mask to determine token-wise attention behavior.
            # Examples (see also the illustrative sketch at the end of this file):
            #   Mask [4, 0, 0, 0] -> Cumsum: [4, 4, 4, 4] -> All tokens attend to each other.
            #   Mask [1, 1, 1, 1] -> Cumsum: [1, 2, 3, 4] -> Strict autoregressive behavior.
            #   Mask [2, 0, 1, 1] -> Cumsum: [2, 2, 3, 4] -> Tokens 1 and 2 attend to each other, token 3 attends
            #     to tokens 1-3, and token 4 attends to all tokens.
            attention_arange = torch.arange(N, device=decoder_attention_mask.device)
            attention_arange = repeat(attention_arange, "n2 -> b n1 n2", b=B, n1=N)
            cumsum_mask = torch.cumsum(decoder_attention_mask, dim=-1)
            cumsum_mask = rearrange(cumsum_mask, "b n -> b n 1")
            adapted_attention_mask = (attention_arange >= cumsum_mask)

        if self.decoder_sep_mask:
            # Separate attention between tokens based on their modality using mod_mask.
sep_mask = repeat(mod_mask, "b n2 -> b n1 n2", n1=N) != repeat(mod_mask, "b n1 -> b n1 n2", n2=N) adapted_attention_mask = adapted_attention_mask | sep_mask return adapted_attention_mask def forward_encoder(self, x: torch.Tensor, encoder_mask: torch.Tensor) -> torch.Tensor: """Forward pass for the encoder. Args: x (torch.Tensor): Encoder input tokens. Shape (B, N, D) where N is the number of encoder tokens. encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N) Returns: torch.Tensor: Encoder output. Shape (B, N, D) """ for blk in self.encoder: x = blk(x, mask=encoder_mask) x = self.encoder_norm(x) return x def forward_decoder(self, y: torch.Tensor, context: torch.Tensor, encoder_mask: torch.Tensor, decoder_attention_mask: torch.Tensor) -> torch.Tensor: """Forward pass for the decoder. Args: y (torch.Tensor): Decoder input tokens. Shape (B, M, D). context (torch.Tensor): Context for the decoder (i.e. encoder output). Shape (B, N, D). encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N). decoder_attention_mask (torch.Tensor): Decoder attention mask. Shape (B, M, M). Returns: torch.Tensor: Decoder output. Shape (B, M, D). """ for blk in self.decoder: y = blk(y, context, sa_mask=decoder_attention_mask, xa_mask=encoder_mask) y = self.decoder_norm(y) return y def forward_logits(self, y: torch.Tensor, decoder_mod_dict: Dict[str, Dict[str, torch.Tensor]], decoder_mod_mask: torch.Tensor, return_all_logits: bool = False) -> Dict[str, torch.Tensor]: """Forward computation of logits for each modality. Args: y (torch.Tensor): Decoder output. Shape (B, M, D). decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder. decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M). Returns: Dict[str, torch.Tensor]: Dictionary of logits for each modality. """ mod_logits = {} for mod, d in decoder_mod_dict.items(): idx = self.modality_info[mod]["id"] if return_all_logits: logits = self.decoder_embeddings[mod].forward_logits(y) else: logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx]) mod_logits[mod] = logits return mod_logits def forward_loss(self, y: torch.Tensor, target_ids: torch.Tensor, decoder_mod_dict: Dict[str, Any], decoder_mod_mask: torch.Tensor, loss_type: str) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Computes the loss based on the specified loss type. Args: y (torch.Tensor): Decoder output. Shape (B, M, D). target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M). decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder. decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M). loss_type (str): The type of loss to compute. Either 'mod' or 'token'. Returns: Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total loss and dictionary of loss for each modality. 
""" if loss_type in ['mod', 'modality']: loss, mod_loss = self.forward_mod_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask) elif loss_type == 'token': loss, mod_loss = self.forward_token_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask) else: raise ValueError("Invalid loss type") return loss, mod_loss def forward_mod_loss(self, y: torch.Tensor, target_ids: torch.Tensor, decoder_mod_dict: Dict[str, Any], decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Computes the modality-wise loss. Args: y (torch.Tensor): Decoder tokens. Shape (B, M, D). target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M). decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder. decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M). Returns: Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total modality loss and dictionary of loss for each modality. """ mod_loss = {} for mod, d in decoder_mod_dict.items(): idx = self.modality_info[mod]["id"] logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx]) if logits.numel() == 0: # If there are no logits / targets, set mod_loss to 0 mod_loss[mod] = torch.zeros(1, device=logits.device) else: loss = F.cross_entropy(logits, target_ids[decoder_mod_mask == idx].long(), reduction='mean') mod_loss[mod] = loss loss = sum(mod_loss.values()) / len(mod_loss) return loss, mod_loss def forward_token_loss(self, y: torch.Tensor, target_ids: torch.Tensor, decoder_mod_dict: Dict[str, Any], decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Computes the token-wise loss. Args: y (torch.Tensor): Decoder tokens. Shape (B, M, D). target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M). decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder. decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M). Returns: Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total token loss and dictionary of loss for each modality. """ mod_loss = {} mod_count = {} for mod, d in decoder_mod_dict.items(): idx = self.modality_info[mod]["id"] logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx]) if logits.numel() == 0: # If there are no logits / targets, set mod_loss to 0 mod_loss[mod] = torch.zeros(1, device=logits.device) mod_count[mod] = 0 else: loss = F.cross_entropy(logits, target_ids[decoder_mod_mask == idx].long(), reduction='mean') mod_loss[mod] = loss mod_count[mod] = logits.numel() loss = sum([mod_loss[mod] * mod_count[mod] for mod in mod_loss.keys()]) / sum(mod_count.values()) return loss, mod_loss def forward(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_encoder_tokens: int, num_decoder_tokens: int, loss_type: str = 'mod', return_logits: bool = False) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: """ Forward pass for the model. Args: mod_dict (Dict[str, Dict[str, torch.Tensor]]): Dictionary containing the tensors, masks, and other info for each modality. - mod_dict[modality_name]["tensor_name"]: Shape can vary based on tensor_name and modality. num_encoder_tokens (int): Number of tokens to keep for the encoder. num_decoder_tokens (int): Number of tokens to keep for the decoder. loss_type (str, optional): The type of loss to compute. Can be 'mod' (average of loss per modality) or 'token' (average loss per token). Default is 'mod'. 
return_logits (bool, optional): If True, return the logits. Default is False. Returns: Union[dict, tuple]: - If return_logits is True: Dictionary of logits for each modality. - Otherwise: Tuple containing the total loss and dictionary of loss for each modality. """ # Mod dicts encoder_mod_dict = {mod: self.encoder_embeddings[mod](d) for mod, d in mod_dict.items() if mod in self.encoder_embeddings} encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder(encoder_mod_dict, num_encoder_tokens) decoder_mod_dict = {mod: self.decoder_embeddings[mod].forward_embed(d) for mod, d in mod_dict.items() if mod in self.decoder_embeddings} decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, decoder_mod_mask = self.forward_mask_decoder(decoder_mod_dict, num_decoder_tokens) # Encoder x = encoder_tokens + encoder_emb x = self.forward_encoder(x, encoder_mask=encoder_mask) # Decoder context = self.decoder_proj_context(x) + encoder_emb y = decoder_tokens + decoder_emb y = self.forward_decoder(y, context, encoder_mask=encoder_mask, decoder_attention_mask=decoder_attention_mask) # Logits if return_logits: mod_logits = self.forward_logits(y, decoder_mod_dict, decoder_mod_mask, return_all_logits=True) return mod_logits # Loss loss, mod_loss = self.forward_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask, loss_type) return loss, mod_loss def freeze_encoder(self, freeze_embeddings=True): for param in self.encoder.parameters(): param.requires_grad = False for param in self.encoder_norm.parameters(): param.requires_grad = False if freeze_embeddings: for param in self.encoder_embeddings.parameters(): param.requires_grad = False def freeze_encoder_except_specific_embeddings(self, frozen_embedding_domain): frozen_embedding_domain = frozen_embedding_domain.split('-') for param in self.encoder.parameters(): param.requires_grad = False for param in self.encoder_norm.parameters(): param.requires_grad = False for name, param in self.encoder_embeddings.named_parameters(): if name.split('.')[0] in frozen_embedding_domain: param.requires_grad = False def unfreeze_encoder(self, unfreeze_embeddings=True): for param in self.encoder.parameters(): param.requires_grad = True for param in self.encoder_norm.parameters(): param.requires_grad = True if unfreeze_embeddings: for param in self.encoder_embeddings.parameters(): param.requires_grad = True def freeze_decoder(self, freeze_embeddings=True): for param in self.decoder.parameters(): param.requires_grad = False for param in self.decoder_norm.parameters(): param.requires_grad = False if freeze_embeddings: for param in self.decoder_embeddings.parameters(): param.requires_grad = False def freeze_decoder_except_specific_embeddings(self, frozen_embedding_domain): frozen_embedding_domain = frozen_embedding_domain.split('-') for param in self.decoder.parameters(): param.requires_grad = False for param in self.decoder_norm.parameters(): param.requires_grad = False for name, param in self.decoder_embeddings.named_parameters(): if name.split('.')[0] in frozen_embedding_domain: param.requires_grad = False def unfreeze_decoder(self, unfreeze_embeddings=True): for param in self.decoder.parameters(): param.requires_grad = True for param in self.decoder_norm.parameters(): param.requires_grad = True if unfreeze_embeddings: for param in self.decoder_embeddings.parameters(): param.requires_grad = True def freeze_shared_params(self): self.freeze_encoder(freeze_embeddings=False) self.freeze_decoder(freeze_embeddings=False) def 
freeze_params_except_specific_embeddings(self, frozen_embedding_domain): self.freeze_encoder_except_specific_embeddings(frozen_embedding_domain=frozen_embedding_domain) self.freeze_decoder_except_specific_embeddings(frozen_embedding_domain=frozen_embedding_domain) def unfreeze_shared_params(self): self.unfreeze_encoder(unfreeze_embeddings=False) self.unfreeze_decoder(unfreeze_embeddings=False) def unfreeze_all(self): self.unfreeze_encoder(unfreeze_embeddings=True) self.unfreeze_decoder(unfreeze_embeddings=True) ################################################ # Wrapper for easy loading with Huggingface Hub class FM(FourM, PyTorchModelHubMixin): """Wrapper around FourM for easy loading with Huggingface Hub. Args: config (dict): Dictionary containing the model and modality configuration, used for loading from Huggingface Hub. """ def __init__(self, config: dict): config = copy.deepcopy(config) all_domains = sorted(list(set(config['domains_in']) | set(config['domains_out']))) modality_info = {mod: MODALITY_INFO[mod] for mod in all_domains} encoder_embeddings = {} for mod in config['domains_in']: info = modality_info[mod] if info.get("encoder_embedding", None) is not None: if info["type"] == "img": image_size, patch_size = info.get('input_size', config['image_size']), info.get('patch_size', config['patch_size']) encoder_embeddings[mod] = info["encoder_embedding"](patch_size=patch_size, image_size=image_size) else: encoder_embeddings[mod] = info["encoder_embedding"]() decoder_embeddings = {} for mod in config['domains_out']: info = modality_info[mod] if info.get("decoder_embedding", None) is not None: if info["type"] == "img": image_size, patch_size = info.get('input_size', config['image_size']), info.get('patch_size', config['patch_size']) decoder_embeddings[mod] = info["decoder_embedding"](patch_size=patch_size, image_size=image_size, share_embedding=False) else: decoder_embeddings[mod] = info["decoder_embedding"](share_embedding=False) config['norm_layer'] = partial(LayerNorm, eps=1e-6, bias=config['norm_bias']) config['act_layer'] = getattr(torch.nn, config['act_layer']) del config['norm_bias'] del config['domains_in'] del config['domains_out'] del config['image_size'] del config['patch_size'] super().__init__( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, modality_info=modality_info, **config ) ################################################ # Model definitions # GELU variants @register_model def fm_tiny_6e_6d_gelu( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=6, decoder_depth=6, dim=384, num_heads=6, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs ) return model @register_model def fm_small_8e_8d_gelu( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=8, decoder_depth=8, dim=512, num_heads=8, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs ) return model @register_model def fm_base_12e_12d_gelu( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=12, decoder_depth=12, dim=768, num_heads=12, mlp_ratio=4, qkv_bias=True, 
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs ) return model @register_model def fm_large_24e_24d_gelu( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=24, decoder_depth=24, dim=1024, num_heads=16, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs ) return model @register_model def fm_xlarge_24e_24d_gelu( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=24, decoder_depth=24, dim=2048, num_heads=32, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs ) return model # SwiGLU variants @register_model def fm_tiny_6e_6d_swiglu_nobias( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=6, decoder_depth=6, dim=384, num_heads=6, mlp_ratio=4, qkv_bias=False, proj_bias=False, mlp_bias=False, norm_layer=partial(LayerNorm, eps=1e-6, bias=False), act_layer=nn.SiLU, gated_mlp=True, **kwargs ) return model @register_model def fm_small_8e_8d_swiglu_nobias( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=8, decoder_depth=8, dim=512, num_heads=8, mlp_ratio=4, qkv_bias=False, proj_bias=False, mlp_bias=False, norm_layer=partial(LayerNorm, eps=1e-6, bias=False), act_layer=nn.SiLU, gated_mlp=True, **kwargs ) return model @register_model def fm_base_12e_12d_swiglu_nobias( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=12, decoder_depth=12, dim=768, num_heads=12, mlp_ratio=4, qkv_bias=False, proj_bias=False, mlp_bias=False, norm_layer=partial(LayerNorm, eps=1e-6, bias=False), act_layer=nn.SiLU, gated_mlp=True, **kwargs ) return model @register_model def fm_large_24e_24d_swiglu_nobias( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=24, decoder_depth=24, dim=1024, num_heads=16, mlp_ratio=4, qkv_bias=False, proj_bias=False, mlp_bias=False, norm_layer=partial(LayerNorm, eps=1e-6, bias=False), act_layer=nn.SiLU, gated_mlp=True, **kwargs ) return model @register_model def fm_xlarge_24e_24d_swiglu_nobias( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=24, decoder_depth=24, dim=2048, num_heads=32, mlp_ratio=4, qkv_bias=False, proj_bias=False, mlp_bias=False, norm_layer=partial(LayerNorm, eps=1e-6, bias=False), act_layer=nn.SiLU, gated_mlp=True, **kwargs ) return model # SwiGLU + QKNorm variants @register_model def fm_base_12e_12d_swiglu_qknorm_nobias( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=12, decoder_depth=12, dim=768, num_heads=12, 
mlp_ratio=4, qkv_bias=False, proj_bias=False, mlp_bias=False, norm_layer=partial(LayerNorm, eps=1e-6, bias=False), act_layer=nn.SiLU, gated_mlp=True, qk_norm=True, **kwargs ) return model @register_model def fm_large_24e_24d_swiglu_qknorm_nobias( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=24, decoder_depth=24, dim=1024, num_heads=16, mlp_ratio=4, qkv_bias=False, proj_bias=False, mlp_bias=False, norm_layer=partial(LayerNorm, eps=1e-6, bias=False), act_layer=nn.SiLU, gated_mlp=True, qk_norm=True, **kwargs ) return model @register_model def fm_xlarge_24e_24d_swiglu_qknorm_nobias( encoder_embeddings: Dict[str, nn.Module], decoder_embeddings: Dict[str, nn.Module], **kwargs): model = FourM( encoder_embeddings=encoder_embeddings, decoder_embeddings=decoder_embeddings, encoder_depth=24, decoder_depth=24, dim=2048, num_heads=32, mlp_ratio=4, qkv_bias=False, proj_bias=False, mlp_bias=False, norm_layer=partial(LayerNorm, eps=1e-6, bias=False), act_layer=nn.SiLU, gated_mlp=True, qk_norm=True, **kwargs ) return model
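

################################################
# Illustrative sketches (not part of the model; safe to remove)

# The helpers below are standalone, hypothetical sketches added for illustration only. They are not used by
# FourM / FM and are not part of the original 4M API; their names are made up for this file. Each one mirrors
# a piece of logic defined above, under the conventions used in this file (boolean masks with 0 = valid token,
# True = blocked attention).

def _sketch_deterministic_token_selection(mask: torch.Tensor, num_tokens: int) -> torch.Tensor:
    """Illustration only: mirrors the argsort trick in forward_mask_encoder / forward_mask_decoder.

    `mask` is a (B, N) boolean mask with False (0) marking valid tokens. Adding a tiny arange offset before
    argsort keeps valid tokens first while preserving their original order, so slicing the first `num_tokens`
    indices performs the masking deterministically.
    """
    mask_arange = torch.arange(mask.shape[1], device=mask.device).unsqueeze(0) * 1e-6
    ids_shuffle = torch.argsort(mask + mask_arange, dim=1)  # valid tokens first, original order preserved
    return ids_shuffle[:, :num_tokens]

# Example: for mask [[False, True, False, True]] and num_tokens=2, the kept indices are [[0, 2]],
# i.e. the two valid tokens in their original order.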
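
# The following sketch mirrors the non-causal branch of FourM.adapt_decoder_attention_mask (without the
# per-modality separation), so the compressed-mask convention can be inspected in isolation. The function
# name is hypothetical and the helper is for illustration only.

def _sketch_expand_compressed_attention_mask(compressed_mask: torch.Tensor) -> torch.Tensor:
    """Illustration only: expand a compressed (B, M) attention mask into a (B, M, M) blocking mask.

    True entries mark blocked attention, matching the convention used in this file.
    """
    B, M = compressed_mask.shape
    attention_arange = repeat(torch.arange(M, device=compressed_mask.device), "n2 -> b n1 n2", b=B, n1=M)
    cumsum_mask = rearrange(torch.cumsum(compressed_mask, dim=-1), "b n -> b n 1")
    return attention_arange >= cumsum_mask

# Example, reproducing the cases documented in adapt_decoder_attention_mask:
#   [[4, 0, 0, 0]] -> no entry is blocked (full attention)
#   [[1, 1, 1, 1]] -> strictly autoregressive (only the lower triangle is unblocked)
#   [[2, 0, 1, 1]] -> tokens 1-2 attend to each other, token 3 to tokens 1-3, token 4 to all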
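
# The toy modules below are hypothetical stand-ins written only to illustrate the interface FourM expects from
# its embedding modules (init(), forward() / forward_embed(), forward_logits()) and the per-modality dict fields
# used above ('tensor', 'x', 'emb', 'ids', 'input_mask', 'target_mask', 'decoder_attention_mask'). The class,
# function, and modality names are invented for this sketch; real embeddings should be built from
# fourm.data.modality_info.MODALITY_INFO, as done in the FM wrapper above.

class _ToySeqEncoderEmbedding(nn.Module):
    """Hypothetical encoder embedding for a token-sequence modality (illustration only)."""

    def __init__(self, vocab_size: int, max_length: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length

    def init(self, dim_tokens: int, init_std: float = 0.02):
        # Called by FourM.__init__ once the model dimension is known.
        self.token_emb = nn.Embedding(self.vocab_size, dim_tokens)
        self.pos_emb = nn.Parameter(torch.zeros(1, self.max_length, dim_tokens))
        self.mod_emb = nn.Parameter(torch.zeros(1, 1, dim_tokens))
        nn.init.normal_(self.token_emb.weight, std=init_std)

    def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # forward_mask_encoder expects 'tensor', 'x', 'emb', and 'input_mask'.
        x = self.token_emb(d['tensor'])
        emb = (self.pos_emb[:, :x.shape[1]] + self.mod_emb).expand(x.shape[0], -1, -1)
        return {'tensor': d['tensor'], 'x': x, 'emb': emb, 'input_mask': d['input_mask']}


class _ToyImgDecoderEmbedding(nn.Module):
    """Hypothetical decoder embedding for a tokenized image-like modality (illustration only)."""

    def __init__(self, vocab_size: int, num_tokens: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_tokens = num_tokens

    def init(self, dim_tokens: int, init_std: float = 0.02):
        self.dim_tokens = dim_tokens
        self.pos_emb = nn.Parameter(torch.zeros(1, self.num_tokens, dim_tokens))
        self.mod_emb = nn.Parameter(torch.zeros(1, 1, dim_tokens))
        self.to_logits = nn.Linear(dim_tokens, self.vocab_size)

    def forward_embed(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # forward_mask_decoder expects 'x', 'ids', 'emb', 'target_mask', and 'decoder_attention_mask'.
        B, N = d['tensor'].shape
        x = torch.zeros(B, N, self.dim_tokens, device=d['tensor'].device)  # replaced by the mask token anyway
        emb = (self.pos_emb + self.mod_emb).expand(B, -1, -1)
        return {'x': x, 'ids': d['tensor'], 'emb': emb,
                'target_mask': d['target_mask'], 'decoder_attention_mask': d['decoder_attention_mask']}

    def forward_logits(self, y: torch.Tensor) -> torch.Tensor:
        return self.to_logits(y)


def _sketch_toy_forward_pass():
    """Illustration only: run a tiny FourM forward/backward pass on random data."""
    B, seq_len, img_tokens, vocab = 2, 8, 16, 32
    modality_info = {
        'toy_caption': {'id': 0, 'type': 'seq'},   # hypothetical modality names, ids, and types
        'toy_tok_rgb': {'id': 1, 'type': 'img'},
    }
    model = FourM(
        encoder_embeddings={'toy_caption': _ToySeqEncoderEmbedding(vocab, seq_len)},
        decoder_embeddings={'toy_tok_rgb': _ToyImgDecoderEmbedding(vocab, img_tokens)},
        modality_info=modality_info,
        dim=64, encoder_depth=2, decoder_depth=2, num_heads=4,
    )
    mod_dict = {
        'toy_caption': {
            'tensor': torch.randint(0, vocab, (B, seq_len)),
            'input_mask': torch.zeros(B, seq_len, dtype=torch.bool),   # 0 = valid encoder input token
        },
        'toy_tok_rgb': {
            'tensor': torch.randint(0, vocab, (B, img_tokens)),
            'target_mask': torch.zeros(B, img_tokens, dtype=torch.bool),  # 0 = token is a decoder target
            # Compressed attention mask: first entry = img_tokens -> full attention within the modality
            'decoder_attention_mask': torch.cat(
                [torch.full((B, 1), img_tokens, dtype=torch.long),
                 torch.zeros(B, img_tokens - 1, dtype=torch.long)], dim=1),
        },
    }
    loss, mod_loss = model(mod_dict, num_encoder_tokens=seq_len, num_decoder_tokens=img_tokens, loss_type='mod')
    loss.backward()
    return loss, mod_loss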