# Copyright 2024 Rhymes AI. All rights reserved.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from transformers.activations import ACT2FN


class FFN(nn.Module):
    """
    Feed-Forward Network module.

    Args:
        embed_dim (int): Input embedding dimension.
        ff_dim (int): Hidden dimension of the feed-forward network.
        output_dim (int): Output dimension.
    """

    def __init__(self, embed_dim, ff_dim, output_dim):
        super().__init__()
        self.linear_in = nn.Linear(embed_dim, ff_dim, bias=False)
        self.linear_out = nn.Linear(ff_dim, output_dim, bias=False)
        self.act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_states = self.act(self.linear_in(hidden_states))
        hidden_states = self.linear_out(hidden_states)
        return hidden_states


class CrossAttention(nn.Module):
    """
    Cross-Attention module.

    Args:
        kv_dim (int): Dimension of key and value.
        embed_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        drop_out_rate (float): Dropout rate. Default is 0.
    """

    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)

        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(drop_out_rate)

        self.layer_norm = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(kv_dim)

    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
        """
        Forward pass of the CrossAttention module.

        Args:
            x (torch.Tensor): Input tensor for key and value.
            hidden_states (torch.Tensor): Input tensor for query.
            attn_mask (torch.Tensor, optional): Attention mask. Default is None.
            add_residual (bool): Whether to add residual connection. Default is False.

        Returns:
            torch.Tensor: Output tensor after cross-attention.
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        query = self.q_proj(normed_hidden_states).permute(1, 0, 2)

        x = self.ln_kv(x)
        key = self.k_proj(x).permute(1, 0, 2)
        value = self.v_proj(x).permute(1, 0, 2)

        attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)

        attn_output = attn_output.permute(1, 0, 2)

        if add_residual:
            attn_output = hidden_states + self.dropout(self.linear(attn_output))
        else:
            attn_output = self.dropout(self.linear(attn_output))

        return attn_output
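# Shape reference for the modules above:
#   FFN:            (batch, seq, embed_dim) -> (batch, seq, output_dim)
#   CrossAttention: x of shape (batch, num_patches, kv_dim) supplies keys/values,
#                   hidden_states of shape (batch, num_queries, embed_dim) supplies
#                   queries, and the output is (batch, num_queries, embed_dim).
#                   The permutes convert to and from the (seq_len, batch, embed_dim)
#                   layout expected by nn.MultiheadAttention, whose default is
#                   batch_first=False.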
class AriaProjector(nn.Module):
    """
    A projection module with one cross attention layer and one FFN layer, which
    projects ViT's outputs into MoE's inputs.

    Args:
        patch_to_query_dict (dict): Maps patch numbers to their corresponding query numbers,
            e.g., {1225: 128, 4900: 256}. This allows for different query sizes based on
            image resolution.
        embed_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        kv_dim (int): Dimension of key and value.
        ff_dim (int): Hidden dimension of the feed-forward network.
        output_dim (int): Output dimension.
        norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.

    Outputs:
        A tensor of shape (batch_size, query_number, output_dim).
    """

    def __init__(
        self,
        patch_to_query_dict,
        embed_dim,
        num_heads,
        kv_dim,
        ff_dim,
        output_dim,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.patch_to_query_dict = patch_to_query_dict
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.query = nn.Parameter(
            torch.zeros(max(patch_to_query_dict.values()), self.embed_dim)
        )

        trunc_normal_(self.query, std=0.02)

        self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)

        self.ln_ffn = norm_layer(embed_dim)
        self.ffn = FFN(embed_dim, ff_dim, output_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        """
        Forward pass of the AriaProjector module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_patches, kv_dim).
            attn_mask (torch.Tensor, optional): Attention mask. Default is None.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, query_number, output_dim).
        """
        bs = x.shape[0]
        queries = self.query.unsqueeze(0).repeat(bs, 1, 1)

        query_num = self.patch_to_query_dict.get(x.shape[1], None)
        assert (
            query_num is not None
        ), f"Query number for {x.shape[1]} patches is not provided"

        # Only the first query_num learned queries are used for this resolution.
        queries = queries[:, :query_num, :]

        if attn_mask is not None:
            # Expand a (batch, num_patches) mask to the (batch * num_heads,
            # num_queries, num_patches) shape nn.MultiheadAttention expects
            # for a 3D attn_mask.
            attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)

        attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)

        out = self.ffn(self.ln_ffn(attention_out))

        return out
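

# A minimal smoke test of the projector (a sketch; the dictionary, dimensions, and
# batch size below are illustrative assumptions, not values shipped with any model).
if __name__ == "__main__":
    patch_to_query_dict = {1225: 128, 4900: 256}  # mapping taken from the docstring example
    projector = AriaProjector(
        patch_to_query_dict=patch_to_query_dict,
        embed_dim=1024,  # assumed query/attention width
        num_heads=16,  # assumed; must divide embed_dim
        kv_dim=1152,  # assumed ViT feature dimension
        ff_dim=4096,  # assumed FFN hidden dimension
        output_dim=2048,  # assumed MoE input dimension
    )

    # Dummy ViT output: 2 images, 1225 patches each, kv_dim features per patch.
    vit_features = torch.randn(2, 1225, 1152)
    projected = projector(vit_features)
    print(projected.shape)  # torch.Size([2, 128, 2048])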