# Aria / projector.py
# Provenance (captured from the model-hub file page): uploaded by aria-dev,
# commit 0531a03 ("first version"), 6.62 kB; the page's scan reported "No virus".
# Copyright 2024 Rhymes AI. All rights reserved.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from transformers.activations import ACT2FN
class FFN(nn.Module):
    """
    Two-layer feed-forward block.

    The input is projected (without bias) up to ``ff_dim``, passed through the
    ``gelu_new`` activation, and projected (without bias) to ``output_dim``.

    Args:
        embed_dim (int): Size of the last dimension of the input.
        ff_dim (int): Hidden (intermediate) width of the block.
        output_dim (int): Size of the last dimension of the output.
    """

    def __init__(self, embed_dim, ff_dim, output_dim):
        super().__init__()
        # Layers are created in a fixed order: it determines the state_dict
        # key order and the sequence of RNG draws made by default init.
        self.linear_in = nn.Linear(embed_dim, ff_dim, bias=False)
        self.linear_out = nn.Linear(ff_dim, output_dim, bias=False)
        self.act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        """Return ``linear_out(act(linear_in(hidden_states)))``."""
        expanded = self.linear_in(hidden_states)
        return self.linear_out(self.act(expanded))
class CrossAttention(nn.Module):
    """
    Cross-attention layer: query states attend over a separate key/value input.

    Queries come from ``hidden_states`` (``layer_norm`` then ``q_proj``). Keys
    and values come from ``x`` (``ln_kv`` then ``k_proj`` / ``v_proj``). The
    attended result goes through ``linear`` and dropout, and can optionally be
    added back onto the un-normalized ``hidden_states``.

    Args:
        kv_dim (int): Last-dimension size of the key/value input ``x``.
        embed_dim (int): Last-dimension size of the query input and the output.
        num_heads (int): Number of heads of the inner ``nn.MultiheadAttention``.
        drop_out_rate (float): Dropout probability applied after the output
            projection. Default is 0.
    """

    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
        super().__init__()
        self.num_heads = num_heads
        # Registration order is kept stable: it fixes the state_dict layout
        # and the RNG draws consumed by each layer's default initialisation.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(drop_out_rate)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(kv_dim)

    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
        """
        Run cross-attention of ``hidden_states`` (queries) over ``x``.

        Args:
            x (torch.Tensor): Key/value source, 3-D with batch on axis 0.
            hidden_states (torch.Tensor): Query source, 3-D with batch on axis 0.
            attn_mask (torch.Tensor, optional): Passed unchanged to
                ``nn.MultiheadAttention``. Default is None.
            add_residual (bool): When True, add the un-normalized
                ``hidden_states`` to the result. Default is False.

        Returns:
            torch.Tensor: Result laid out like ``hidden_states`` (batch first).
        """
        # The inner nn.MultiheadAttention is built without batch_first, so it
        # expects (seq, batch, dim); swap the first two axes in and out.
        def swap_batch_and_seq(tensor):
            return tensor.permute(1, 0, 2)

        query = swap_batch_and_seq(self.q_proj(self.layer_norm(hidden_states)))
        kv_input = self.ln_kv(x)
        key = swap_batch_and_seq(self.k_proj(kv_input))
        value = swap_batch_and_seq(self.v_proj(kv_input))

        attended, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
        projected = self.dropout(self.linear(swap_batch_and_seq(attended)))
        return hidden_states + projected if add_residual else projected
class AriaProjector(nn.Module):
    """
    A projection module with one cross attention layer and one FFN layer, which projects ViT's outputs into MoE's inputs.

    A learnable bank of ``max(patch_to_query_dict.values())`` query vectors
    attends over the input patches. The number of patches in the input selects
    how many of those queries (taken from the front of the bank) are used.

    Args:
        patch_to_query_dict (dict): Maps patch counts to query counts,
            e.g., {1225: 128, 4900: 256}. This allows for different query sizes based on image resolution.
        embed_dim (int): Embedding dimension of the queries and attention.
        num_heads (int): Number of attention heads.
        kv_dim (int): Dimension of key and value (last dim of the input ``x``).
        ff_dim (int): Hidden dimension of the feed-forward network.
        output_dim (int): Output dimension.
        norm_layer (Callable[[int], nn.Module]): Factory called with
            ``embed_dim`` to build the norm applied before the FFN.
            Default is nn.LayerNorm.

    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim)
    """

    def __init__(
        self,
        patch_to_query_dict,
        embed_dim,
        num_heads,
        kv_dim,
        ff_dim,
        output_dim,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.patch_to_query_dict = patch_to_query_dict
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # One row per query; sized for the largest configured query count.
        self.query = nn.Parameter(
            torch.zeros(max(patch_to_query_dict.values()), self.embed_dim)
        )
        trunc_normal_(self.query, std=0.02)
        self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
        self.ln_ffn = norm_layer(embed_dim)
        self.ffn = FFN(embed_dim, ff_dim, output_dim)
        # Re-initialise every submodule (including those inside cross_attn
        # and ffn), overriding their default Linear/LayerNorm init.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialise one module; intended to be passed to ``Module.apply``.

        ``nn.Linear`` gets trunc-normal(std=0.02) weights and zero bias.
        ``nn.LayerNorm`` gets weight 1 and bias 0. Parameters that a layer
        does not have (``None``, e.g. a LayerNorm without affine parameters)
        are skipped.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Guard against norms built with elementwise_affine=False (or
            # bias=False), whose weight/bias are None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        """
        Forward pass of the Projector module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_patches, kv_dim).
            attn_mask (torch.Tensor, optional): 2-D mask with one row per
                batch element and one column per patch. It is repeated once
                per head along dim 0 and broadcast over the queries before
                being handed to ``nn.MultiheadAttention``, which defines its
                semantics. Default is None.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, query_number, output_dim).

        Raises:
            ValueError: If ``x.shape[1]`` is not a key of ``patch_to_query_dict``.
        """
        bs = x.shape[0]
        num_patches = x.shape[1]
        query_num = self.patch_to_query_dict.get(num_patches)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, and slicing with None would then silently use every
        # query in the bank.
        if query_num is None:
            raise ValueError(
                f"Query number for {num_patches} patches is not provided"
            )
        # Slice before repeating so only the queries actually used are
        # copied for each batch element.
        queries = self.query[:query_num].unsqueeze(0).repeat(bs, 1, 1)
        if attn_mask is not None:
            attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
        attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)
        return self.ffn(self.ln_ffn(attention_out))