import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig, T5EncoderModel

from .nn import SiLU, linear, timestep_embedding


class TransformerNetModel(nn.Module):
    """Transformer denoiser for embedding diffusion: a T5 stack conditioned on a
    caption via cross-attention and on the diffusion timestep via an additive
    embedding."""

    def __init__(
        self,
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        config_name="QizhiPei/biot5-base-text2mol",
        vocab_size=None,  # e.g. 821; must be provided by the caller
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=12,
    ):
        super().__init__()

        config = AutoConfig.from_pretrained(config_name)
        config.is_decoder = True  # decoder-style T5 blocks, so the stack includes cross-attention
        config.add_cross_attention = True
        config.hidden_dropout_prob = dropout  # wired to the ctor argument (default 0.1)
        config.num_attention_heads = num_attention_heads
        config.num_hidden_layers = num_hidden_layers
        config.max_position_embeddings = 512
        config.layer_norm_eps = 1e-12
        config.vocab_size = vocab_size
        config.d_model = hidden_size
        # NOTE: T5Config uses num_heads / num_layers internally; the BERT-style
        # num_attention_heads / num_hidden_layers values above are stored on the
        # config but do not resize the T5 stack. The max_position_embeddings,
        # layer_norm_eps, and hidden_dropout_prob values are read back below.

        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        # Token embedding and weight-tied output head.
        self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
        self.lm_head = nn.Linear(self.in_channels, vocab_size)
        self.lm_head.weight = self.word_embedding.weight

        # Projects the 768-d caption encoder states to the model's hidden size.
        self.caption_down_proj = nn.Sequential(
            linear(768, self.hidden_size),
            SiLU(),
            linear(self.hidden_size, self.hidden_size),
        )

        # Sinusoidal timestep embedding -> MLP -> hidden size.
        time_embed_dim = model_channels * 4  # 512 with the default model_channels=128
        self.time_embed = nn.Sequential(
            linear(self.model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, self.hidden_size),
        )

        # Lifts the noised token embeddings from in_channels to the hidden size.
        self.input_up_proj = nn.Sequential(
            nn.Linear(self.in_channels, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )

        self.input_transformers = T5EncoderModel(config)
        # self.input_transformers.eval()
        # for param in self.input_transformers.parameters():
        #     param.requires_grad = False

        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, self.hidden_size
        )
        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Projects the transformer output back down to in_channels.
        self.output_down_proj = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, self.in_channels),
        )

    def get_embeds(self, input_ids):
        return self.word_embedding(input_ids)

    def get_embeds_with_deep(self, input_ids):
        # NOTE: self.deep_embedding is not created in __init__; this helper
        # assumes it is attached elsewhere before being called.
        atom, deep = input_ids
        atom = self.word_embedding(atom)
        deep = self.deep_embedding(deep)
        return torch.concat([atom, deep], dim=-1)

    def get_logits(self, hidden_repr):
        return self.lm_head(hidden_repr)

    def forward(self, x, timesteps, caption_state, caption_mask, y=None):
        """
        x:             (batch, seq_len, in_channels) noised token embeddings
        timesteps:     (batch,) diffusion timesteps
        caption_state: (batch, cap_len, 768) caption encoder hidden states
        caption_mask:  (batch, cap_len) attention mask for the caption
        returns:       (batch, seq_len, in_channels) predicted embeddings
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        emb_x = self.input_up_proj(x)
        seq_length = x.size(1)
        position_ids = self.position_ids[:, :seq_length]

        # Sum of position, token, and (broadcast) timestep embeddings.
        emb_inputs = (
            self.position_embeddings(position_ids)
            + emb_x
            + emb.unsqueeze(1).expand(-1, seq_length, -1)
        )
        emb_inputs = self.dropout(self.LayerNorm(emb_inputs))

        caption_state = self.dropout(
            self.LayerNorm(self.caption_down_proj(caption_state))
        )

        # T5 stack with cross-attention over the projected caption states.
        input_trans_hidden_states = self.input_transformers.encoder(
            inputs_embeds=emb_inputs,
            encoder_hidden_states=caption_state,
            encoder_attention_mask=caption_mask,
        ).last_hidden_state

        h = self.output_down_proj(input_trans_hidden_states)
        h = h.type(x.dtype)
        return h
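

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# Assumptions: vocab_size=821 as hinted above, network access to fetch the
# "QizhiPei/biot5-base-text2mol" config from the Hugging Face Hub, and random
# tensors standing in for the noised embeddings and caption encoder states.
# Run with `python -m <package>.<module>` so the relative import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_len, cap_len = 2, 64, 16
    model = TransformerNetModel(vocab_size=821)
    model.eval()

    x = torch.randn(batch_size, seq_len, model.in_channels)         # noised embeddings
    timesteps = torch.randint(0, 1000, (batch_size,))                # diffusion steps
    caption_state = torch.randn(batch_size, cap_len, 768)            # caption encoder output
    caption_mask = torch.ones(batch_size, cap_len, dtype=torch.long)

    with torch.no_grad():
        h = model(x, timesteps, caption_state, caption_mask)         # (2, 64, 32)
        logits = model.get_logits(h)                                 # (2, 64, 821)
    print(h.shape, logits.shape)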