# Page-header residue from the web file viewer (not part of the module):
# questgen / FastT5 /onnx_models_structure.py
# ViXuan's picture
# cleaner files
# 1571051
# raw
# history blame
# No virus
# 1.91 kB
import torch
class DecoderWithLMhead(torch.nn.Module):
    """Combine the decoder and the LM head into one module (with cached states).

    ``forward`` accepts a flat positional argument list so that previously
    computed key/value states can be supplied as individual inputs.
    """

    def __init__(self, decoder, lm_head, config):
        """Store the wrapped sub-modules.

        Args:
            decoder: Callable decoder; invoked with the keyword arguments
                ``input_ids``, ``encoder_attention_mask``,
                ``encoder_hidden_states`` and ``past_key_values``.
            lm_head: Callable applied to the (optionally rescaled) first
                decoder output.
            config: Object exposing ``d_model`` and, optionally,
                ``tie_word_embeddings``.
        """
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config

    def forward(self, *inputs):
        """Run one decoding step and project the result through the LM head.

        Args:
            *inputs: ``input_ids``, ``attention_mask`` and
                ``encoder_hidden_states`` (in that order), followed by the
                flattened past key/value states.

        Returns:
            A pair ``(lm_head_output, decoder_output[1])``.
        """
        input_ids, attention_mask, encoder_hidden_states = inputs[:3]

        # Everything after the first three inputs is regrouped four at a time.
        # NOTE(review): presumably one group per decoder layer (self-attention
        # key/value plus cross-attention key/value) -- confirm against the
        # decoder implementation.
        flat_pkv = inputs[3:]
        past_key_values = tuple(
            flat_pkv[i : i + 4] for i in range(0, len(flat_pkv), 4)
        )

        decoder_output = self.decoder(
            input_ids=input_ids,  # decoder_input_ids
            encoder_attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
        )

        sequence_output = decoder_output[0]
        # The d_model ** -0.5 rescale is only correct when the LM head shares
        # its weights with the input embeddings (this mirrors Hugging Face's
        # T5ForConditionalGeneration). Configs without the flag keep the
        # previous always-rescale behaviour.
        if getattr(self.config, "tie_word_embeddings", True):
            sequence_output = sequence_output * (self.config.d_model ** -0.5)

        return self.lm_head(sequence_output), decoder_output[1]
class T5Encoder(torch.nn.Module):
    """Thin wrapper that exposes only the first output of an encoder."""

    def __init__(self, encoder):
        """Keep a reference to the wrapped ``encoder`` callable."""
        super().__init__()
        self.encoder = encoder

    def forward(self, *input, **kwargs):
        """Forward all arguments to the encoder and return element 0 of its result."""
        encoder_outputs = self.encoder(*input, **kwargs)
        first_output = encoder_outputs[0]
        return first_output
class DecoderWithLMheadInitial(torch.nn.Module):
    """Combine the decoder and the LM head into one module (no cached states).

    Unlike a cached decoding step, ``forward`` here passes no
    ``past_key_values`` to the decoder.
    """

    def __init__(self, decoder, lm_head, config):
        """Store the wrapped sub-modules.

        Args:
            decoder: Callable decoder; invoked with the keyword arguments
                ``input_ids``, ``encoder_attention_mask`` and
                ``encoder_hidden_states``.
            lm_head: Callable applied to the (optionally rescaled) first
                decoder output.
            config: Object exposing ``d_model`` and, optionally,
                ``tie_word_embeddings``.
        """
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config

    def forward(self, input_ids, attention_mask, encoder_hidden_states):
        """Run the decoder and project the result through the LM head.

        Args:
            input_ids: Passed to the decoder as ``input_ids``.
            attention_mask: Passed to the decoder as
                ``encoder_attention_mask``.
            encoder_hidden_states: Passed to the decoder unchanged.

        Returns:
            A pair ``(lm_head_output, decoder_output[1])``.
        """
        decoder_output = self.decoder(
            input_ids=input_ids,
            encoder_attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
        )

        sequence_output = decoder_output[0]
        # The d_model ** -0.5 rescale is only correct when the LM head shares
        # its weights with the input embeddings (this mirrors Hugging Face's
        # T5ForConditionalGeneration). Configs without the flag keep the
        # previous always-rescale behaviour.
        if getattr(self.config, "tie_word_embeddings", True):
            sequence_output = sequence_output * (self.config.d_model ** -0.5)

        return (
            self.lm_head(sequence_output),
            decoder_output[1],
        )