|
from .huggingface_utils import get_auth_token |
|
from .ort_settings import get_onnx_runtime_sessions |
|
from .onnx_exporter import ( |
|
generate_onnx_representation, |
|
quantize, |
|
get_model_paths, |
|
saved_models_path, |
|
) |
|
from pathlib import Path |
|
|
|
from transformers import ( |
|
AutoConfig, |
|
MT5Config, |
|
T5ForConditionalGeneration, |
|
) |
|
from transformers.modeling_outputs import ( |
|
Seq2SeqLMOutput, |
|
BaseModelOutput, |
|
) |
|
import torch |
|
import functools |
|
import operator |
|
import numpy |
|
|
|
|
|
class T5Encoder(torch.nn.Module):
    """Wraps an ONNX Runtime encoder session so it can be called like a
    regular ``torch.nn.Module`` encoder from HuggingFace ``generate()``."""

    def __init__(self, encoder_sess):
        super().__init__()
        self.encoder = encoder_sess
        self.main_input_name = "input_ids"

    def forward(
        self,
        input_ids,
        attention_mask,
        inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # ONNX Runtime consumes numpy arrays, so tensors are moved to CPU
        # and converted before running the session.  The extra keyword
        # arguments exist only for signature compatibility with the HF
        # encoder and are ignored here.
        onnx_feed = {
            "input_ids": input_ids.cpu().numpy(),
            "attention_mask": attention_mask.cpu().numpy(),
        }
        last_hidden_state = self.encoder.run(None, onnx_feed)[0]
        return BaseModelOutput(torch.from_numpy(last_hidden_state))
|
|
|
|
|
class T5DecoderInit(torch.nn.Module):
    """First decoding step (no cache yet): runs the init-decoder ONNX
    session and builds the initial past-key-value cache."""

    def __init__(self, decoder_sess):
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, encoder_attention_mask, encoder_hidden_states):
        feed = {
            "input_ids": input_ids.cpu().numpy(),
            "encoder_attention_mask": encoder_attention_mask.cpu().numpy(),
            "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
        }
        outputs = self.decoder.run(None, feed)

        logits = torch.from_numpy(outputs[0])
        # Everything after the logits is the flat key/value cache; each
        # decoder layer contributes 4 tensors, so regroup them into
        # per-layer 4-tuples.
        cache_tensors = [torch.from_numpy(arr) for arr in outputs[1:]]
        past_key_values = tuple(
            tuple(cache_tensors[start: start + 4])
            for start in range(0, len(cache_tensors), 4)
        )
        return logits, past_key_values
|
|
|
|
|
class T5Decoder(torch.nn.Module):
    """Decoding steps with an existing cache: feeds the flattened
    past-key-values into the ONNX decoder session."""

    def __init__(self, decoder_sess):
        super().__init__()
        self.decoder = decoder_sess

    def forward(self, input_ids, attention_mask, encoder_output, past_key_values):
        feed = {
            "input_ids": input_ids.cpu().numpy(),
            "encoder_attention_mask": attention_mask.cpu().numpy(),
            "encoder_hidden_states": encoder_output.cpu().numpy(),
        }
        # The ONNX graph expects the cache as flat inputs named
        # pkv_0..pkv_N, so unnest the per-layer tuples and number the
        # tensors sequentially.
        flat_cache = [tensor for layer in past_key_values for tensor in layer]
        for index, tensor in enumerate(flat_cache):
            feed[f"pkv_{index}"] = tensor.cpu().numpy()

        outputs = self.decoder.run(None, feed)

        logits = torch.from_numpy(outputs[0])
        # Regroup the returned flat cache into per-layer 4-tuples.
        new_cache = [torch.from_numpy(arr) for arr in outputs[1:]]
        past_key_values = tuple(
            tuple(new_cache[start: start + 4])
            for start in range(0, len(new_cache), 4)
        )
        return logits, past_key_values
|
|
|
|
|
class OnnxT5(T5ForConditionalGeneration):
    """T5ForConditionalGeneration whose encoder/decoder passes run through
    ONNX Runtime sessions (encoder, cached decoder & init decoder)."""

    def __init__(self, model_or_model_path, onnx_model_sessions):
        config = AutoConfig.from_pretrained(
            model_or_model_path, use_auth_token=get_auth_token()
        )
        super().__init__(config)

        # Detect mT5 checkpoints either from the path string or from the
        # ``name_or_path`` attribute of an already-loaded model object.
        if isinstance(model_or_model_path, str):
            is_mt5 = "mt5" in model_or_model_path.lower()
        else:
            is_mt5 = "mt5" in getattr(model_or_model_path, "name_or_path", "")

        if is_mt5:
            self.model_type = "mt5"
            self.config_class = MT5Config
            # The exported mT5 sessions carry no embedding weights, so the
            # (tied) encoder embedding matrix may legitimately be absent.
            self._keys_to_ignore_on_load_missing = [
                r"encoder\.embed_tokens\.weight",
            ]
            self._keys_to_ignore_on_save = [
                r"encoder\.embed_tokens\.weight",
            ]

        assert len(onnx_model_sessions) == 3, "all three models should be given"

        encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions

        self.encoder = T5Encoder(encoder_sess)
        self.decoder = T5Decoder(decoder_sess)
        self.decoder_init = T5DecoderInit(decoder_sess_init)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Encode once; on subsequent generation steps ``generate()`` passes
        # the cached encoder_outputs back in.
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )
        encoder_hidden_states = encoder_outputs[0]

        if past_key_values is None:
            # First step: no cache yet, run the init decoder to create it.
            logits, past_key_values = self.decoder_init(
                decoder_input_ids, attention_mask, encoder_hidden_states
            )
        else:
            # Later steps: only the newest token is fed to the decoder.
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
            logits, past_key_values = self.decoder(
                decoder_input_ids,
                attention_mask,
                encoder_hidden_states,
                past_key_values,
            )

        return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
|
|
|
|
|
def export_and_get_onnx_model(
    model_or_model_path, custom_output_path=saved_models_path, quantized=True
):
    """End-to-end pipeline: export the pytorch model to ONNX, optionally
    quantize it, build the ONNX Runtime sessions and wrap them in OnnxT5.
    """
    onnx_model_paths = generate_onnx_representation(
        model_or_model_path, output_path=custom_output_path
    )

    # Quantization produces a second set of model files; the sessions are
    # built from whichever set was requested.
    session_paths = quantize(onnx_model_paths) if quantized else onnx_model_paths

    print("Setting up onnx model...")
    model_sessions = get_onnx_runtime_sessions(session_paths)

    model = OnnxT5(model_or_model_path, model_sessions)
    print("Done!")

    return model
|
|
|
|
|
def get_onnx_model(model_name, onnx_models_path=saved_models_path, quantized=True):
    """Load an already-exported ONNX T5 model from disk.

    Args:
        model_name: name of the (fine-tuned) model the ONNX files were
            exported from.
        onnx_models_path: folder containing the exported ``.onnx`` files.
        quantized: if True, look for the quantized model files instead of
            the full-precision ones.

    Returns:
        An ``OnnxT5`` instance backed by the three ONNX Runtime sessions.

    Raises:
        AssertionError: if any of the three model files is missing.

    Example:
        >> get_onnx_model(model_name="t5-finetuned", onnx_models_path="../models/onnx/quantized/")
    """
    encoder_path, decoder_path, init_decoder_path = get_model_paths(
        model_name, Path(onnx_models_path), quantized
    )

    model_paths = encoder_path, decoder_path, init_decoder_path

    # Fail early with a clear message if any of the three files is missing;
    # the two branches only differ in the hint given to the user.
    if quantized:
        missing_msg = (
            "quantized models don't exist in the model folder, "
            "first quantize the model!"
        )
    else:
        missing_msg = (
            "all or some models don't exist in the model folder, "
            "first convert the model!"
        )
    assert all(path.exists() for path in model_paths), missing_msg

    model_sessions = get_onnx_runtime_sessions(model_paths)

    model = OnnxT5(model_name, model_sessions)

    return model
|
|