ViXuan committed on
Commit
1571051
1 Parent(s): 3e0f6bf

cleaner files

.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  venv
2
- .vscode
1
  venv
2
+ .vscode
3
+ s2v_reddit_2015_md.tar.gz
4
+ __pycache__
FastT5/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .huggingface_utils import set_auth_token
2
+ from .onnx_models import OnnxT5, export_and_get_onnx_model, get_onnx_model
3
+ from .ort_settings import get_onnx_runtime_sessions
4
+ from .onnx_exporter import generate_onnx_representation, quantize
FastT5/huggingface_utils.py ADDED
@@ -0,0 +1,25 @@
1
+ _auth_token = None
2
+
3
+ def set_auth_token(token):
4
+ """Set the token which allows the user to authenticate to hugginface.co for downloading private models
5
+
6
+ Args:
7
+ token (Union[str, bool]): The token value to store. One of:
8
+ - an API key (from https://huggingface.co/organizations/ORGNAME/settings/token),
9
+ - a login token obtained by running `$ transformers-cli login`
10
+ - `True`, which tells transformers to use the login token stored in ~/.huggingface/token
11
+
12
+ Returns:
13
+ None
14
+ """
15
+ global _auth_token
16
+ _auth_token = token
17
+
18
+ def get_auth_token():
19
+ """Get the user-configurable auth token, which defaults to None
20
+
21
+ Returns:
22
+ auth_token (Optional[Union[str, bool]]) for authenticating with huggingface.co
23
+ """
24
+ global _auth_token
25
+ return _auth_token
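
A minimal usage sketch (the model id below is hypothetical; `set_auth_token` and `export_and_get_onnx_model` are re-exported from the package `__init__`):

from FastT5 import set_auth_token, export_and_get_onnx_model

set_auth_token(True)  # reuse the login token stored by `transformers-cli login`
model = export_and_get_onnx_model("my-org/my-private-t5")  # hypothetical private model id
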
FastT5/mcq.py ADDED
@@ -0,0 +1,311 @@
1
+ from flashtext import KeywordProcessor
2
+ from nltk.tokenize import sent_tokenize
3
+ from similarity.normalized_levenshtein import NormalizedLevenshtein
4
+ from nltk.corpus import stopwords
5
+ import torch
6
+ from collections import OrderedDict
7
+ import string
8
+ import pke
9
+ import nltk
10
+ import random
11
+ nltk.download('brown')
12
+ nltk.download('stopwords')
13
+ nltk.download('popular')
14
+
15
+
16
+ def MCQs_available(word, s2v):
17
+ word = word.replace(" ", "_")
18
+ sense = s2v.get_best_sense(word)
19
+ if sense is not None:
20
+ return True
21
+ else:
22
+ return False
23
+
24
+
25
+ def edits(word):
26
+ "All edits that are one edit away from `word`."
27
+ letters = 'abcdefghijklmnopqrstuvwxyz '+string.punctuation
28
+ splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
29
+ deletes = [L + R[1:] for L, R in splits if R]
30
+ transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
31
+ replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
32
+ inserts = [L + c + R for L, R in splits for c in letters]
33
+ return set(deletes + transposes + replaces + inserts)
34
+
35
+
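A quick illustrative check of the one-edit variants `edits` returns (deletions, transpositions, replacements and insertions):

variants = edits("cat")
print("at" in variants, "act" in variants, "cart" in variants)  # True True True
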
36
+ def sense2vec_get_words(word, s2v):
37
+ output = []
38
+
39
+ word_preprocessed = word.translate(
40
+ word.maketrans("", "", string.punctuation))
41
+ word_preprocessed = word_preprocessed.lower()
42
+
43
+ word_edits = edits(word_preprocessed)
44
+
45
+ word = word.replace(" ", "_")
46
+
47
+ sense = s2v.get_best_sense(word)
48
+ most_similar = s2v.most_similar(sense, n=15)
49
+
50
+ compare_list = [word_preprocessed]
51
+ for each_word in most_similar:
52
+ append_word = each_word[0].split("|")[0].replace("_", " ")
53
+ append_word = append_word.strip()
54
+ append_word_processed = append_word.lower()
55
+ append_word_processed = append_word_processed.translate(
56
+ append_word_processed.maketrans("", "", string.punctuation))
57
+ if append_word_processed not in compare_list and word_preprocessed not in append_word_processed and append_word_processed not in word_edits:
58
+ output.append(append_word.title())
59
+ compare_list.append(append_word_processed)
60
+
61
+ out = list(OrderedDict.fromkeys(output))
62
+
63
+ return out
64
+
65
+
66
+ def get_options(answer, s2v):
67
+ distractors = []
68
+
69
+ try:
70
+ distractors = sense2vec_get_words(answer, s2v)
71
+ if len(distractors) > 0:
72
+ print(" Sense2vec_distractors successful for word : ", answer)
73
+ return distractors, "sense2vec"
74
+ except Exception:
75
+ print(" Sense2vec_distractors failed for word : ", answer)
76
+
77
+ return distractors, "None"
78
+
79
+
80
+ def tokenize_sentences(text):
81
+ sentences = [sent_tokenize(text)]
82
+ sentences = [y for x in sentences for y in x]
83
+ # Remove any short sentences (20 characters or fewer).
84
+ sentences = [sentence.strip()
85
+ for sentence in sentences if len(sentence) > 20]
86
+ return sentences
87
+
88
+
89
+ def get_sentences_for_keyword(keywords, sentences):
90
+ keyword_processor = KeywordProcessor()
91
+ keyword_sentences = {}
92
+ for word in keywords:
93
+ word = word.strip()
94
+ keyword_sentences[word] = []
95
+ keyword_processor.add_keyword(word)
96
+ for sentence in sentences:
97
+ keywords_found = keyword_processor.extract_keywords(sentence)
98
+ for key in keywords_found:
99
+ keyword_sentences[key].append(sentence)
100
+
101
+ for key in keyword_sentences.keys():
102
+ values = keyword_sentences[key]
103
+ values = sorted(values, key=len, reverse=True)
104
+ keyword_sentences[key] = values
105
+
106
+ delete_keys = []
107
+ for k in keyword_sentences.keys():
108
+ if len(keyword_sentences[k]) == 0:
109
+ delete_keys.append(k)
110
+ for del_key in delete_keys:
111
+ del keyword_sentences[del_key]
112
+
113
+ return keyword_sentences
114
+
115
+
116
+ def is_far(words_list, currentword, thresh, normalized_levenshtein):
117
+ threshold = thresh
118
+ score_list = []
119
+ for word in words_list:
120
+ score_list.append(normalized_levenshtein.distance(
121
+ word.lower(), currentword.lower()))
122
+ if min(score_list) >= threshold:
123
+ return True
124
+ else:
125
+ return False
126
+
127
+
128
+ def filter_phrases(phrase_keys, max_phrases, normalized_levenshtein):
129
+ filtered_phrases = []
130
+ if len(phrase_keys) > 0:
131
+ filtered_phrases.append(phrase_keys[0])
132
+ for ph in phrase_keys[1:]:
133
+ if is_far(filtered_phrases, ph, 0.7, normalized_levenshtein):
134
+ filtered_phrases.append(ph)
135
+ if len(filtered_phrases) >= max_phrases:
136
+ break
137
+ return filtered_phrases
138
+
139
+
140
+ def get_nouns_multipartite(text):
141
+ out = []
142
+
143
+ extractor = pke.unsupervised.MultipartiteRank()
144
+ extractor.load_document(input=text, language='en')
145
+ pos = {'PROPN', 'NOUN'}
146
+ stoplist = list(string.punctuation)
147
+ stoplist += stopwords.words('english')
148
+ extractor.candidate_selection(pos=pos)
149
+ # 4. build the Multipartite graph and rank candidates using random walk,
150
+ # alpha controls the weight adjustment mechanism, see TopicRank for
151
+ # threshold/method parameters.
152
+ try:
153
+ extractor.candidate_weighting(alpha=1.1,
154
+ threshold=0.75,
155
+ method='average')
156
+ except Exception:
157
+ return out
158
+
159
+ keyphrases = extractor.get_n_best(n=10)
160
+
161
+ for key in keyphrases:
162
+ out.append(key[0])
163
+
164
+ return out
165
+
166
+
167
+ def get_phrases(doc):
168
+ phrases = {}
169
+ for np in doc.noun_chunks:
170
+ phrase = np.text
171
+ len_phrase = len(phrase.split())
172
+ if len_phrase > 1:
173
+ if phrase not in phrases:
174
+ phrases[phrase] = 1
175
+ else:
176
+ phrases[phrase] = phrases[phrase]+1
177
+
178
+ phrase_keys = list(phrases.keys())
179
+ phrase_keys = sorted(phrase_keys, key=lambda x: len(x), reverse=True)
180
+ phrase_keys = phrase_keys[:50]
181
+ return phrase_keys
182
+
183
+
184
+ def get_keywords(nlp, text, max_keywords, s2v, fdist, normalized_levenshtein, no_of_sentences):
185
+ doc = nlp(text)
186
+ max_keywords = int(max_keywords)
187
+
188
+ keywords = get_nouns_multipartite(text)
189
+ keywords = sorted(keywords, key=lambda x: fdist[x])
190
+ keywords = filter_phrases(keywords, max_keywords, normalized_levenshtein)
191
+
192
+ phrase_keys = get_phrases(doc)
193
+ filtered_phrases = filter_phrases(
194
+ phrase_keys, max_keywords, normalized_levenshtein)
195
+
196
+ total_phrases = keywords + filtered_phrases
197
+
198
+ total_phrases_filtered = filter_phrases(total_phrases, min(
199
+ max_keywords, 2*no_of_sentences), normalized_levenshtein)
200
+
201
+ answers = []
202
+ for answer in total_phrases_filtered:
203
+ if answer not in answers and MCQs_available(answer, s2v):
204
+ answers.append(answer)
205
+
206
+ answers = answers[:max_keywords]
207
+ return answers
208
+
209
+
210
+ def generate_questions_mcq(keyword_sent_mapping, device, tokenizer, model, sense2vec, normalized_levenshtein):
211
+ batch_text = []
212
+
213
+ answers = keyword_sent_mapping.keys()
214
+ for answer in answers:
215
+ txt = keyword_sent_mapping[answer]
216
+ txt_str = "\n".join(txt)
217
+ context = "context: " + txt_str
218
+ text = context + " " + "answer: " + answer + " </s>"
219
+ batch_text.append(text)
220
+ print(batch_text)
221
+
222
+ encoding = tokenizer.batch_encode_plus(
223
+ batch_text, pad_to_max_length=True, return_tensors="pt")
224
+
225
+ print("Running model for generation")
226
+ input_ids, attention_masks = encoding["input_ids"].to(
227
+ device), encoding["attention_mask"].to(device)
228
+
229
+ with torch.no_grad():
230
+ outs = model.generate(input_ids=input_ids,
231
+ attention_mask=attention_masks,
232
+ max_length=150)
233
+
234
+ output_array = {}
235
+ output_array["questions"] = []
236
+ # print(outs)
237
+ for index, val in enumerate(answers):
238
+ individual_question = {}
239
+ out = outs[index, :]
240
+ dec = tokenizer.decode(out, skip_special_tokens=True,
241
+ clean_up_tokenization_spaces=True)
242
+
243
+ Question = dec.replace("question:", "")
244
+ Question = Question.strip()
245
+ individual_question["question_statement"] = Question
246
+ individual_question["question_type"] = "MCQ"
247
+ individual_question["answer"] = val
248
+ individual_question["id"] = index+1
249
+ individual_question["options"], individual_question["options_algorithm"] = get_options(
250
+ val, sense2vec)
251
+
252
+ individual_question["options"] = filter_phrases(
253
+ individual_question["options"], 10, normalized_levenshtein)
254
+ num_options = 3  # keep the first three options as choices; the rest become extra options
255
+ individual_question["extra_options"] = individual_question["options"][index:]
256
+ individual_question["options"] = individual_question["options"][:index]
257
+ individual_question["context"] = keyword_sent_mapping[val]
258
+
259
+ if len(individual_question["options"]) > 0:
260
+ output_array["questions"].append(individual_question)
261
+
262
+ return output_array
263
+
264
+
265
+ # for normal one word questions
266
+ def generate_normal_questions(keyword_sent_mapping, device, tokenizer, model):
267
+ batch_text = ""
268
+ answers = keyword_sent_mapping.keys()
269
+ for answer in answers:
270
+ txt = keyword_sent_mapping[answer]
271
+ context = "context: " + txt
272
+ text = context + " " + "answer: " + answer + " </s>"
273
+ batch_text.append(text)
274
+
275
+ encoding = tokenizer.batch_encode_plus(
276
+ batch_text, pad_to_max_length=True, return_tensors="pt")
277
+
278
+ print("Running model for generation")
279
+ input_ids, attention_masks = encoding["input_ids"].to(
280
+ device), encoding["attention_mask"].to(device)
281
+
282
+ with torch.no_grad():
283
+ outs = model.generate(input_ids=input_ids,
284
+ attention_mask=attention_masks,
285
+ max_length=150)
286
+
287
+ output_array = {}
288
+ output_array["questions"] = []
289
+
290
+ for index, val in enumerate(answers):
291
+ individual_quest = {}
292
+ out = outs[index, :]
293
+ dec = tokenizer.decode(out, skip_special_tokens=True,
294
+ clean_up_tokenization_spaces=True)
295
+
296
+ Question = dec.replace('question:', '')
297
+ Question = Question.strip()
298
+
299
+ individual_quest['Question'] = Question
300
+ individual_quest['Answer'] = val
301
+ individual_quest["id"] = index+1
302
+ individual_quest["context"] = keyword_sent_mapping[val]
303
+
304
+ output_array["questions"].append(individual_quest)
305
+
306
+ return output_array
307
+
308
+
309
+ def random_choice():
310
+ a = random.choice([0, 1])
311
+ return bool(a)
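
A hedged sketch of how these helpers are typically driven, assuming the sense2vec vectors (the s2v_reddit_2015_md archive ignored in .gitignore) have been extracted locally; the vector path and the input file are assumptions:

from sense2vec import Sense2Vec

s2v = Sense2Vec().from_disk("s2v_old")  # folder extracted from s2v_reddit_2015_md.tar.gz (assumed path)
sentences = tokenize_sentences(open("article.txt").read())  # hypothetical input text
sentence_map = get_sentences_for_keyword(["machine learning"], sentences)
distractors, algorithm = get_options("machine learning", s2v)
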
FastT5/model_testing_tools.py ADDED
@@ -0,0 +1,103 @@
1
+ from time import perf_counter as pc
2
+ from matplotlib import pyplot as plt
3
+ from transformers import AutoTokenizer
4
+
5
+ import numpy as np
6
+
7
+
8
+ def speed_test(
9
+ onnx_model,
10
+ torch_model,
11
+ beam_range: range = range(1, 10, 1),
12
+ seq_length_range: range = range(10, 500, 50),
13
+ input_text=None,
14
+ ):
15
+ """
16
+ prints the time taken by the onnx and pytorch models to finish a text generation task
17
+
18
+ args:
19
+ input_text (str) : text input for the model.
20
+ onnx_model : onnx representation of the t5 model,
21
+ torch_model : torch representation of the t5 model,
22
+ beam_range (range) : range of beam counts to test, given as start, end and step (don't start with 0)
23
+ seq_length_range (range) : range of sequence lengths to test, given as start, end and step (start with 10)
24
+ return :
25
+ onnx_model_latency : numpy array of latency for each beam number and sequence length
26
+ pytorch_model_latency : numpy array of latency for each beam number and sequence length
27
+ """
28
+
29
+ if input_text is None:
30
+ input_text = """translate English to French: A nucleus is a collection of a large number of up and down quarks, confined into triplets (neutrons and protons). According to the strange matter hypothesis, strangelets are more stable than nuclei, so nuclei are expected to decay into strangelets. But this process may be extremely slow because there is a large energy barrier to overcome:
31
+ as the weak interaction starts making a nucleus into a strangelet, the first few strange quarks form strange baryons, such as the Lambda, which are heavy. Only if many conversions occur almost simultaneously will the number of strange quarks reach the critical proportion required to achieve a lower energy state. This is very unlikely to happen, so even if the strange matter hypothesis were correct, nuclei would never be seen to decay to strangelets because their lifetime would be longer than the age of the universe.
32
+ The stability of strangelets depends on their size. This is because of (a) surface tension at the interface between quark matter and vacuum (which affects small strangelets more than big ones), and (b) screening of charges, which allows small strangelets to be charged, with a neutralizing cloud of electrons/positrons around them, but requires large strangelets, like any large piece of matter, to be electrically neutral in their interior. The charge screening distance tends to be of the order of a few femtometers, so only the outer few femtometers of a strangelet can carry charge.
33
+ The surface tension of strange matter is unknown. If it is smaller than a critical value (a few MeV per square femtometer) then large strangelets are unstable and will tend to fission into smaller strangelets (strange stars would still be stabilized by gravity). If it is larger than the critical value, then strangelets become more stable as they get bigger.
34
+ The known particles with strange quarks are unstable. Because the strange quark is heavier than the up and down quarks, it can spontaneously decay, via the weak interaction into an up quark. Consequently particles containing strange quarks, such as the Lambda particle, always lose their strangeness, by decaying into lighter particles containing only up and down quarks.
35
+ But condensed states with a larger number of quarks might not suffer from this instability. That possible stability against decay is the "strange matter hypothesis" proposed separately by Arnold Bodmer[3] and Edward Witten.[4] According to this hypothesis, when a large enough number of quarks are concentrated together, the lowest energy state is one which has roughly equal numbers of up, down, and strange quarks, namely a strangelet. This stability would occur because of the Pauli exclusion principle; having three types of quarks, rather than two as in normal nuclear matter, allows more quarks to be placed in lower energy levels
36
+ """
37
+
38
+ tokenizer = AutoTokenizer.from_pretrained(torch_model.name_or_path)
39
+
40
+ xx = []
41
+ yy = []
42
+
43
+ for j in beam_range:
44
+ x = []
45
+ y = []
46
+ prev = [1, 2]
47
+ for i in seq_length_range:
48
+
49
+ token = tokenizer(
50
+ input_text,
51
+ padding=True,
52
+ truncation=True,
53
+ max_length=i,
54
+ pad_to_max_length=i,
55
+ return_tensors="pt",
56
+ )
57
+
58
+ input_ids = token["input_ids"]
59
+ attention_mask = token["attention_mask"]
60
+
61
+ a = pc()
62
+ out = onnx_model.generate(
63
+ input_ids=input_ids,
64
+ attention_mask=attention_mask,
65
+ max_length=i,
66
+ num_beams=j,
67
+ )
68
+ b = pc()
69
+ x.append(b - a)
70
+
71
+ c = pc()
72
+ o = torch_model.generate(
73
+ input_ids=input_ids,
74
+ attention_mask=attention_mask,
75
+ max_length=i,
76
+ num_beams=j,
77
+ )
78
+ d = pc()
79
+ y.append(d - c)
80
+
81
+ mean_y = np.mean(y)
82
+ mean_x = np.mean(x)
83
+ mean_ratio = mean_y / mean_x
84
+
85
+ print(f"seqL : {i}, onnx-{b-a}, pt-{d-c} .. X faster {(d-c)/(b-a)}")
86
+
87
+ # ...bleu_score-{bleu.compute(predictions=, references=[[tokenizer.decode(o.squeeze(), skip_special_tokens=True)], ])}')
88
+ # print(f'o---{tokenizer.decode(out.squeeze(), skip_special_tokens=True)}...p---{tokenizer.decode(o.squeeze(), skip_special_tokens=True)}')
89
+
90
+ if (o.shape[1] == prev[-1]) and (o.shape[1] == prev[-2]):
91
+ break
92
+
93
+ prev.append(o.shape[1])
94
+
95
+ print(f"beam no.- {j} onnx-{mean_x} pt-{mean_y} X ratio-{mean_ratio}")
96
+
97
+ xx.append(x)
98
+ yy.append(y)
99
+ plt.plot(x, "g", y, "r")
100
+ plt.pause(0.05)
101
+
102
+ plt.show()
103
+ return np.array(xx), np.array(yy)
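
A usage sketch for the benchmark above, assuming a small public checkpoint; `export_and_get_onnx_model` comes from this package:

from transformers import T5ForConditionalGeneration
from FastT5 import export_and_get_onnx_model

torch_model = T5ForConditionalGeneration.from_pretrained("t5-small")
onnx_model = export_and_get_onnx_model("t5-small")
onnx_latency, torch_latency = speed_test(
    onnx_model, torch_model,
    beam_range=range(1, 4),
    seq_length_range=range(10, 110, 50),
)
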
FastT5/onnx_exporter.py ADDED
@@ -0,0 +1,294 @@
1
+ from .huggingface_utils import get_auth_token
2
+ from .onnx_models_structure import (
3
+ T5Encoder,
4
+ DecoderWithLMhead,
5
+ DecoderWithLMheadInitial,
6
+ )
7
+ from transformers import (
8
+ AutoConfig,
9
+ T5ForConditionalGeneration,
10
+ MT5ForConditionalGeneration,
11
+ )
12
+ import torch
13
+ import functools
14
+ import operator
15
+ from progress.bar import Bar
16
+ from pathlib import Path
17
+ import os
18
+
19
+ _folder = Path.cwd()
20
+ saved_models_path = _folder.joinpath("models")
21
+
22
+ Bar.check_tty = False
23
+
24
+
25
+ def create_t5_encoder_decoder(pretrained_version="t5-base"):
26
+ """Generates an encoder and a decoder model with a language model head from a pretrained huggingface model
27
+
28
+ Args:
29
+ pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5
30
+
31
+ Returns:
32
+ simplified_encoder: pytorch t5 encoder with a wrapper to output only the hidden states
33
+ decoder_with_lm_head: pytorch t5 decoder with a language modeling head
34
+ """
35
+
36
+ if 'mt5' in pretrained_version:
37
+ model = MT5ForConditionalGeneration.from_pretrained(pretrained_version, use_auth_token=get_auth_token())
38
+ else:
39
+ model = T5ForConditionalGeneration.from_pretrained(pretrained_version, use_auth_token=get_auth_token())
40
+
41
+ return turn_model_into_encoder_decoder(model)
42
+
43
+
44
+ def turn_model_into_encoder_decoder(model):
45
+ encoder = model.encoder
46
+ decoder = model.decoder
47
+ lm_head = model.lm_head
48
+
49
+ decoder_with_lm_head = DecoderWithLMhead(decoder, lm_head, model.config)
50
+ simplified_encoder = T5Encoder(encoder)
51
+ decoder_with_lm_head_init = DecoderWithLMheadInitial(decoder, lm_head, model.config)
52
+
53
+ return simplified_encoder, decoder_with_lm_head, decoder_with_lm_head_init
54
+
55
+
56
+ def generate_onnx_representation(
57
+ pretrained_version=None,
58
+ model=None,
59
+ output_path=None,
60
+ input_sequence_length=256,
61
+ onnx_opset_version=12, # no other opset versions are tested, change at your own risk
62
+ ):
63
+ """Exports a given huggingface pretrained model, or a given model and tokenizer, to onnx
64
+
65
+ Args:
66
+ pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5
67
+ output_path (Optional[str]): if missing then use ./models
68
+ input_sequence_length (Optional[int]): typical input sequence length, for use by the ORT for possible optimization
69
+ onnx_opset_version (Optional[int]): ONNX Operator Set Version, default 12 is the only tested version
70
+ """
71
+ if (pretrained_version is None) and model is None:
72
+ print(
73
+ "You need to specify pretrained_version (the pretrained model you wish to export). Alternatively you can export a model you have in memory."
74
+ )
75
+ return
76
+
77
+ if model is not None:
78
+ (
79
+ simplified_encoder,
80
+ decoder_with_lm_head,
81
+ decoder_with_lm_head_init,
82
+ ) = turn_model_into_encoder_decoder(model)
83
+ else:
84
+ (
85
+ simplified_encoder,
86
+ decoder_with_lm_head,
87
+ decoder_with_lm_head_init,
88
+ ) = create_t5_encoder_decoder(pretrained_version)
89
+
90
+ # model paths for enc, dec and dec_init
91
+ output_path = saved_models_path if output_path is None else Path(output_path)
92
+ encoder_path, decoder_path, init_decoder_path = get_model_paths(
93
+ pretrained_version, output_path, quantized=False
94
+ )
95
+
96
+ model_config = AutoConfig.from_pretrained(pretrained_version, use_auth_token=get_auth_token())
97
+
98
+ # Though these are dummy inputs, ORT optimizations do reference these values,
99
+ # so it is worth using values as close to production as possible
100
+ batch_size = 1 # not configurable since only CPU
101
+ enc_seq_length = input_sequence_length
102
+ dec_seq_length = 1 # a decoder sequence length is always one because it's just the last generated token
103
+ input_ids = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)
104
+ attention_mask = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)
105
+
106
+ n_heads = model_config.num_heads
107
+ d_kv = model_config.d_kv
108
+
109
+ input_ids_dec = torch.ones(batch_size, dec_seq_length, dtype=torch.int64)
110
+ attention_mask_dec = torch.ones(batch_size, dec_seq_length, dtype=torch.int64)
111
+ enc_out = torch.ones(
112
+ (batch_size, enc_seq_length, model_config.d_model), dtype=torch.float32
113
+ )
114
+
115
+ # self_attention_past_key_values = torch.ones(
116
+ # (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_a, d_kv), dtype=torch.float32)
117
+ # cross_attention_past_key_values = torch.ones(
118
+ # (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_b, d_kv), dtype=torch.float32)
119
+
120
+ sa = torch.ones(
121
+ (batch_size, n_heads, dec_seq_length, d_kv), dtype=torch.float32
122
+ ) # 1, 8, 1, 64
123
+ ca = torch.ones(
124
+ (batch_size, n_heads, enc_seq_length, d_kv), dtype=torch.float32
125
+ ) # 1, 8, variable, 64
126
+ t5_block = (sa, sa, ca, ca)
127
+ past_key_values = (t5_block,) * model_config.num_decoder_layers
128
+
129
+ flat_past_key_values = functools.reduce(operator.iconcat, past_key_values, [])
130
+
131
+ decoder_all_inputs = tuple(
132
+ [input_ids_dec, attention_mask_dec, enc_out] + flat_past_key_values
133
+ )
134
+
135
+ # for progress bars
136
+ bar = Bar("Exporting to onnx...", max=3)
137
+
138
+ import warnings
139
+
140
+ # ignores all the warnings during conversion
141
+ warnings.filterwarnings("ignore")
142
+
143
+ # Exports to ONNX
144
+ with torch.no_grad():
145
+
146
+ decoder_inputs = [
147
+ "input_ids",
148
+ "encoder_attention_mask",
149
+ "encoder_hidden_states",
150
+ ]
151
+
152
+ pkv_input_names = ["pkv_{}".format(i) for i in range(len(flat_past_key_values))]
153
+
154
+ decoder_input_names = decoder_inputs + pkv_input_names
155
+
156
+ decoder_output_names = ["logits", "output_past_key_values"]
157
+
158
+ dyn_axis_general = {0: "batch", 1: "sequence"}
159
+ dyn_axis_pkv = {0: "batch", 2: "seq_length"}
160
+
161
+ dyn_axis = {
162
+ "input_ids": dyn_axis_general,
163
+ "encoder_attention_mask": dyn_axis_general,
164
+ "encoder_hidden_states": dyn_axis_general,
165
+ "logits": dyn_axis_general,
166
+ "output_past_key_values": dyn_axis_general,
167
+ }
168
+
169
+ dyn_pkv = {
170
+ "pkv_{}".format(i): dyn_axis_pkv
171
+ for i in range(len(flat_past_key_values))
172
+ }
173
+
174
+ dyn_axis_params = {**dyn_axis, **dyn_pkv}
175
+
176
+ # decoder to utilize past key values:
177
+ torch.onnx.export(
178
+ decoder_with_lm_head,
179
+ decoder_all_inputs,
180
+ decoder_path.as_posix(),
181
+ export_params=True,
182
+ do_constant_folding=True,
183
+ opset_version=onnx_opset_version,
184
+ input_names=decoder_input_names,
185
+ output_names=decoder_output_names,
186
+ dynamic_axes=dyn_axis_params,
187
+ )
188
+ bar.next()
189
+
190
+ torch.onnx.export(
191
+ simplified_encoder,
192
+ args=(input_ids, attention_mask),
193
+ f=encoder_path.as_posix(),
194
+ export_params=True,
195
+ opset_version=onnx_opset_version,
196
+ do_constant_folding=True,
197
+ input_names=["input_ids", "attention_mask"],
198
+ output_names=["hidden_states"],
199
+ dynamic_axes={
200
+ "input_ids": dyn_axis_general,
201
+ "attention_mask": dyn_axis_general,
202
+ "hidden_states": dyn_axis_general,
203
+ },
204
+ )
205
+ bar.next()
206
+ # initial decoder to produce past key values
207
+ torch.onnx.export(
208
+ decoder_with_lm_head_init,
209
+ (input_ids_dec, attention_mask_dec, enc_out),
210
+ init_decoder_path.as_posix(),
211
+ export_params=True,
212
+ opset_version=onnx_opset_version,
213
+ input_names=[
214
+ "input_ids",
215
+ "encoder_attention_mask",
216
+ "encoder_hidden_states",
217
+ ],
218
+ output_names=["logits", "past_key_values"],
219
+ dynamic_axes={
220
+ # batch_size, seq_length = input_shape
221
+ "input_ids": dyn_axis_general,
222
+ "encoder_attention_mask": dyn_axis_general,
223
+ "encoder_hidden_states": dyn_axis_general,
224
+ "logits": dyn_axis_general,
225
+ "past_key_values": dyn_axis_general,
226
+ },
227
+ )
228
+ bar.next()
229
+ bar.finish()
230
+
231
+ return encoder_path, decoder_path, init_decoder_path
232
+
233
+
234
+ def get_model_paths(pretrained_model, model_path, quantized):
235
+
236
+ model_path.mkdir(parents=True, exist_ok=True)
237
+
238
+ # gets only the filename
239
+ pretrained_model_name = Path(pretrained_model).stem
240
+
241
+ if not quantized:
242
+ encoder_path = model_path.joinpath(f"{pretrained_model_name}-encoder.onnx")
243
+ decoder_path = model_path.joinpath(f"{pretrained_model_name}-decoder.onnx")
244
+ init_decoder_path = model_path.joinpath(
245
+ f"{pretrained_model_name}-init-decoder.onnx"
246
+ )
247
+ else:
248
+ encoder_path = model_path.joinpath(
249
+ f"{pretrained_model_name}-encoder-quantized.onnx"
250
+ )
251
+ decoder_path = model_path.joinpath(
252
+ f"{pretrained_model_name}-decoder-quantized.onnx"
253
+ )
254
+ init_decoder_path = model_path.joinpath(
255
+ f"{pretrained_model_name}-init-decoder-quantized.onnx"
256
+ )
257
+
258
+ return encoder_path, decoder_path, init_decoder_path
259
+
260
+
261
+ def quantize(models_name_or_path):
262
+ """
263
+ Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs
264
+
265
+ Uses unsigned ints for activation values, signed ints for weights, per
266
+ https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
267
+ it is faster on most CPU architectures
268
+ Args:
269
+ models_name_or_path: list or tuple of Paths to the exported ONNX models
270
+ Returns: tuple of Paths to the quantized models
271
+ """
272
+ from onnxruntime.quantization import quantize_dynamic, QuantType
273
+
274
+ bar = Bar("Quantizing...", max=3)
275
+
276
+ quant_model_paths = []
277
+ for model in models_name_or_path:
278
+ model_name = model.as_posix()
279
+ output_model_name = f"{model_name[:-5]}-quantized.onnx"
280
+ quantize_dynamic(
281
+ model_input=model_name,
282
+ model_output=output_model_name,
283
+ per_channel=True,
284
+ reduce_range=True, # should be the same as per_channel
285
+ activation_type=QuantType.QUInt8,
286
+ weight_type=QuantType.QInt8, # per docs, signed is faster on most CPUs
287
+ optimize_model=False,
288
+ ) # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
289
+ quant_model_paths.append(output_model_name)
290
+ bar.next()
291
+
292
+ bar.finish()
293
+
294
+ return tuple(quant_model_paths)
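
Taken together, a sketch of the export pipeline these helpers implement (the quantized files land next to the exported ones under ./models by default):

from FastT5 import get_onnx_runtime_sessions

onnx_paths = generate_onnx_representation("t5-small")   # encoder, decoder, init-decoder .onnx paths
quantized_paths = quantize(onnx_paths)                   # int8-quantized copies of the three models
sessions = get_onnx_runtime_sessions(quantized_paths)    # ORT InferenceSessions, ready for OnnxT5
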
FastT5/onnx_models.py ADDED
@@ -0,0 +1,269 @@
1
+ from .huggingface_utils import get_auth_token
2
+ from .ort_settings import get_onnx_runtime_sessions
3
+ from .onnx_exporter import (
4
+ generate_onnx_representation,
5
+ quantize,
6
+ get_model_paths,
7
+ saved_models_path,
8
+ )
9
+ from pathlib import Path
10
+
11
+ from transformers import (
12
+ AutoConfig,
13
+ MT5Config,
14
+ T5ForConditionalGeneration,
15
+ )
16
+ from transformers.modeling_outputs import (
17
+ Seq2SeqLMOutput,
18
+ BaseModelOutput,
19
+ )
20
+ import torch
21
+ import functools
22
+ import operator
23
+ import numpy
24
+
25
+
26
+ class T5Encoder(torch.nn.Module):
27
+ def __init__(self, encoder_sess):
28
+ super().__init__()
29
+ self.encoder = encoder_sess
30
+ self.main_input_name = "input_ids"
31
+
32
+ def forward(
33
+ self,
34
+ input_ids,
35
+ attention_mask,
36
+ inputs_embeds=None,
37
+ head_mask=None,
38
+ output_attentions=None,
39
+ output_hidden_states=None,
40
+ return_dict=None,
41
+ ):
42
+
43
+ encoder_hidden_state = torch.from_numpy(
44
+ self.encoder.run(
45
+ None,
46
+ {
47
+ "input_ids": input_ids.cpu().numpy(),
48
+ "attention_mask": attention_mask.cpu().numpy(),
49
+ },
50
+ )[0]
51
+ )
52
+
53
+ return BaseModelOutput(encoder_hidden_state)
54
+
55
+
56
+ class T5DecoderInit(torch.nn.Module):
57
+ def __init__(self, decoder_sess):
58
+ super().__init__()
59
+ self.decoder = decoder_sess
60
+
61
+ def forward(self, input_ids, encoder_attention_mask, encoder_hidden_states):
62
+
63
+ decoder_outputs = self.decoder.run(
64
+ None,
65
+ {
66
+ "input_ids": input_ids.cpu().numpy(),
67
+ "encoder_attention_mask": encoder_attention_mask.cpu().numpy(),
68
+ "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
69
+ },
70
+ )
71
+
72
+ list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])
73
+
74
+ out_past_key_values = tuple(
75
+ list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
76
+ )
77
+
78
+ return torch.from_numpy(decoder_outputs[0]), out_past_key_values
79
+
80
+
81
+ class T5Decoder(torch.nn.Module):
82
+ def __init__(self, decoder_sess):
83
+ super().__init__()
84
+ self.decoder = decoder_sess
85
+
86
+ def forward(self, input_ids, attention_mask, encoder_output, past_key_values):
87
+
88
+ decoder_inputs = {
89
+ "input_ids": input_ids.cpu().numpy(),
90
+ "encoder_attention_mask": attention_mask.cpu().numpy(),
91
+ "encoder_hidden_states": encoder_output.cpu().numpy(),
92
+ }
93
+
94
+ flat_past_key_values = functools.reduce(
95
+ operator.iconcat, past_key_values, [])
96
+
97
+ past_key_values = {
98
+ f"pkv_{i}": pkv.cpu().numpy() for i, pkv in enumerate(flat_past_key_values)
99
+ }
100
+
101
+ decoder_outputs = self.decoder.run(
102
+ None, {**decoder_inputs, **past_key_values})
103
+ # converts each value of the list to tensor from numpy
104
+ list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])
105
+
106
+ # creates a tuple of tuples of shape 6x4 from the above tuple
107
+ out_past_key_values = tuple(
108
+ list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
109
+ )
110
+
111
+ return torch.from_numpy(decoder_outputs[0]), out_past_key_values
112
+
113
+
114
+ class OnnxT5(T5ForConditionalGeneration):
115
+ """creates a T5 model using onnx sessions (encode, decoder & init_decoder)"""
116
+
117
+ def __init__(self, model_or_model_path, onnx_model_sessions):
118
+ config = AutoConfig.from_pretrained(
119
+ model_or_model_path, use_auth_token=get_auth_token()
120
+ )
121
+ super().__init__(config)
122
+
123
+ # monkeypatch to work for MT5
124
+ if (
125
+ isinstance(model_or_model_path, str)
126
+ and "mt5" in model_or_model_path.lower()
127
+ ) or (
128
+ hasattr(model_or_model_path, "name_or_path")
129
+ and "mt5" in model_or_model_path.name_or_path
130
+ ):
131
+ self.model_type = "mt5"
132
+ self.config_class = MT5Config
133
+ self._keys_to_ignore_on_load_missing = [
134
+ r"encoder\.embed_tokens\.weight",
135
+ ]
136
+ self._keys_to_ignore_on_save = [
137
+ r"encoder\.embed_tokens\.weight",
138
+ ]
139
+
140
+ assert len(onnx_model_sessions) == 3, "all three models should be given"
141
+
142
+ encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions
143
+
144
+ self.encoder = T5Encoder(encoder_sess)
145
+ self.decoder = T5Decoder(decoder_sess)
146
+ self.decoder_init = T5DecoderInit(decoder_sess_init)
147
+
148
+ def forward(
149
+ self,
150
+ input_ids=None,
151
+ attention_mask=None,
152
+ decoder_input_ids=None,
153
+ decoder_attention_mask=None,
154
+ head_mask=None,
155
+ decoder_head_mask=None,
156
+ cross_attn_head_mask=None,
157
+ encoder_outputs=None,
158
+ past_key_values=None,
159
+ inputs_embeds=None,
160
+ decoder_inputs_embeds=None,
161
+ labels=None,
162
+ use_cache=None,
163
+ output_attentions=None,
164
+ output_hidden_states=None,
165
+ return_dict=None,
166
+ ):
167
+
168
+ if encoder_outputs is None:
169
+ # Convert encoder inputs in embeddings if needed
170
+ encoder_outputs = self.encoder(
171
+ input_ids=input_ids, attention_mask=attention_mask
172
+ )
173
+
174
+ encoder_hidden_states = encoder_outputs[0]
175
+
176
+ if past_key_values is not None:
177
+ if decoder_input_ids is not None:
178
+ decoder_input_ids = decoder_input_ids[:, -1:]
179
+ if decoder_inputs_embeds is not None:
180
+ decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
181
+
182
+ if past_key_values is None:
183
+
184
+ # runs only for the first time:
185
+ init_onnx_outputs = self.decoder_init(
186
+ decoder_input_ids, attention_mask, encoder_hidden_states
187
+ )
188
+
189
+ logits, past_key_values = init_onnx_outputs
190
+
191
+ else:
192
+
193
+ onnx_outputs = self.decoder(
194
+ decoder_input_ids,
195
+ attention_mask,
196
+ encoder_hidden_states,
197
+ past_key_values,
198
+ )
199
+
200
+ logits, past_key_values = onnx_outputs
201
+
202
+ return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
203
+
204
+
205
+ def export_and_get_onnx_model(
206
+ model_or_model_path, custom_output_path=saved_models_path, quantized=True
207
+ ):
208
+ """
209
+ Method for whole pipeline,
210
+ converts from pytorch to onnx --> quantizes model --> sets onnx runtime
211
+ --> builds whole onnx model with all sessions
212
+
213
+ """
214
+
215
+ # Step 1. convert huggingfaces t5 model to onnx
216
+ onnx_model_paths = generate_onnx_representation(
217
+ model_or_model_path, output_path=custom_output_path
218
+ )
219
+
220
+ if quantized:
221
+ # Step 2. (recommended) quantize the converted model for fast inference and to reduce model size.
222
+ quant_model_paths = quantize(onnx_model_paths)
223
+
224
+ # step 3. setup onnx runtime
225
+ print("Setting up onnx model...")
226
+ model_sessions = get_onnx_runtime_sessions(quant_model_paths)
227
+ else:
228
+ print("Setting up onnx model...")
229
+ model_sessions = get_onnx_runtime_sessions(onnx_model_paths)
230
+
231
+ # step 4. get the onnx model
232
+ model = OnnxT5(model_or_model_path, model_sessions)
233
+ print("Done!")
234
+
235
+ return model
236
+
237
+
238
+ def get_onnx_model(model_name, onnx_models_path=saved_models_path, quantized=True):
239
+ """
240
+ gets the onnx model, if already-converted models exist in the given folder
241
+ Example:
242
+ >> get_onnx_model(model_name="t5-finetuned", onnx_models_path="../models/onnx/quantized/")
243
+
244
+ """
245
+
246
+ encoder_path, decoder_path, init_decoder_path = get_model_paths(
247
+ model_name, Path(onnx_models_path), quantized
248
+ )
249
+
250
+ if quantized:
251
+ assert (
252
+ encoder_path.exists()
253
+ and decoder_path.exists()
254
+ and init_decoder_path.exists()
255
+ ), "quantized model don't exist in the model folder, first quantize the model!"
256
+ else:
257
+ assert (
258
+ encoder_path.exists()
259
+ and decoder_path.exists()
260
+ and init_decoder_path.exists()
261
+ ), "all or some models don't exists in the model folder, first convert the model! "
262
+
263
+ model_paths = encoder_path, decoder_path, init_decoder_path
264
+
265
+ model_sessions = get_onnx_runtime_sessions(model_paths)
266
+
267
+ model = OnnxT5(model_name, model_sessions)
268
+
269
+ return model
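
A minimal end-to-end sketch with a public checkpoint, from export to generation through ONNX Runtime:

from transformers import AutoTokenizer

model = export_and_get_onnx_model("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokens = tokenizer("translate English to French: The weather is nice.", return_tensors="pt")
out = model.generate(input_ids=tokens["input_ids"],
                     attention_mask=tokens["attention_mask"],
                     num_beams=2, max_length=32)
print(tokenizer.decode(out.squeeze(), skip_special_tokens=True))
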
FastT5/onnx_models_structure.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+
3
+
4
+ class DecoderWithLMhead(torch.nn.Module):
5
+ """ Creation of a class to combine the decoder and the lm head """
6
+
7
+ def __init__(self, decoder, lm_head, config):
8
+ super().__init__()
9
+ self.decoder = decoder
10
+ self.lm_head = lm_head
11
+ self.config = config
12
+
13
+ def forward(self, *inputs):
14
+
15
+ input_ids, attention_mask, encoder_hidden_states = inputs[:3]
16
+
17
+ list_pkv = inputs[3:]
18
+ past_key_values = tuple(list_pkv[i : i + 4] for i in range(0, len(list_pkv), 4))
19
+
20
+ decoder_output = self.decoder(
21
+ input_ids=input_ids, # decoder_input_ids
22
+ encoder_attention_mask=attention_mask,
23
+ encoder_hidden_states=encoder_hidden_states,
24
+ past_key_values=past_key_values,
25
+ )
26
+
27
+ lm_head_out = self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5))
28
+
29
+ return lm_head_out, decoder_output[1]
30
+
31
+
32
+ class T5Encoder(torch.nn.Module):
33
+ """ Creation of a class to output only the last hidden state from the encoder """
34
+
35
+ def __init__(self, encoder):
36
+ super().__init__()
37
+ self.encoder = encoder
38
+
39
+ def forward(self, *input, **kwargs):
40
+ return self.encoder(*input, **kwargs)[0]
41
+
42
+
43
+ class DecoderWithLMheadInitial(torch.nn.Module):
44
+ """ Creation of a class to combine the decoder and the lm head """
45
+
46
+ def __init__(self, decoder, lm_head, config):
47
+ super().__init__()
48
+ self.decoder = decoder
49
+ self.lm_head = lm_head
50
+ self.config = config
51
+
52
+ def forward(self, input_ids, attention_mask, encoder_hidden_states):
53
+ decoder_output = self.decoder(
54
+ input_ids=input_ids,
55
+ encoder_attention_mask=attention_mask,
56
+ encoder_hidden_states=encoder_hidden_states,
57
+ )
58
+
59
+ return (
60
+ self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),
61
+ decoder_output[1],
62
+ )
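
For reference, a sketch of how these wrappers are assembled from a full T5 model before export (this mirrors turn_model_into_encoder_decoder in onnx_exporter.py):

from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")
simplified_encoder = T5Encoder(model.encoder)
decoder_with_lm_head = DecoderWithLMhead(model.decoder, model.lm_head, model.config)
decoder_with_lm_head_init = DecoderWithLMheadInitial(model.decoder, model.lm_head, model.config)
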
FastT5/ort_settings.py ADDED
@@ -0,0 +1,96 @@
1
+ import os, psutil
2
+
3
+ os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=True))
4
+ os.environ["OMP_WAIT_POLICY"] = "ACTIVE"
5
+
6
+
7
+ from onnxruntime import (
8
+ GraphOptimizationLevel,
9
+ InferenceSession,
10
+ SessionOptions,
11
+ ExecutionMode,
12
+ )
13
+
14
+
15
+ def get_onnx_runtime_sessions(
16
+ model_paths,
17
+ default: bool = True,
18
+ opt_level: int = 99,
19
+ parallel_exe_mode: bool = True,
20
+ n_threads: int = 0,
21
+ provider=[
22
+ "CPUExecutionProvider",
23
+ ],
24
+ ) -> tuple:
25
+ """
26
+ Creates onnxruntime InferenceSessions for the encoder, decoder and initial decoder models
27
+
28
+ Args:
29
+ model_paths (List or Tuple of str) : the path to, in order:
30
+ path_to_encoder (str) : the path of input onnx encoder model.
31
+ path_to_decoder (str) : the path of input onnx decoder model.
32
+ path_to_initial_decoder (str) : the path of input initial onnx decoder model.
33
+ default : if True, create plain InferenceSessions and let onnxruntime choose the settings for your hardware.
34
+ (you can test out different settings for better results.)
35
+ opt_level (int) : sess_options.GraphOptimizationLevel param if set 1 uses 'ORT_ENABLE_BASIC',
36
+ 2 for 'ORT_ENABLE_EXTENDED' and 99 for 'ORT_ENABLE_ALL',
37
+ default value is set to 99.
38
+ parallel_exe_mode (bool) : Sets the execution mode. Default is True (parallel).
39
+ n_threads (int) : Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose
40
+ provider : execution providers list.
41
+
42
+ Returns:
43
+ encoder_session : encoder onnx InferenceSession
44
+ decoder_session : decoder onnx InferenceSession
45
+ decoder_sess_init : initial decoder onnx InferenceSession
46
+
47
+ """
48
+ path_to_encoder, path_to_decoder, path_to_initial_decoder = model_paths
49
+
50
+ if default:
51
+
52
+ encoder_sess = InferenceSession(str(path_to_encoder))
53
+
54
+ decoder_sess = InferenceSession(str(path_to_decoder))
55
+
56
+ decoder_sess_init = InferenceSession(str(path_to_initial_decoder))
57
+
58
+ else:
59
+
60
+ # Few properties that might have an impact on performances
61
+ options = SessionOptions()
62
+
63
+ if opt_level == 1:
64
+ options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
65
+ elif opt_level == 2:
66
+ options.graph_optimization_level = (
67
+ GraphOptimizationLevel.ORT_ENABLE_EXTENDED
68
+ )
69
+ else:
70
+ assert opt_level == 99
71
+ options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
72
+
73
+ # set this true for better performance
74
+ if parallel_exe_mode == True:
75
+ options.execution_mode = ExecutionMode.ORT_PARALLEL
76
+ else:
77
+ options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
78
+
79
+ options.intra_op_num_threads = n_threads
80
+ # options.inter_op_num_threads = 10
81
+
82
+ # options.enable_profiling = True
83
+
84
+ encoder_sess = InferenceSession(
85
+ str(path_to_encoder), options, providers=provider
86
+ )
87
+
88
+ decoder_sess = InferenceSession(
89
+ str(path_to_decoder), options, providers=provider
90
+ )
91
+
92
+ decoder_sess_init = InferenceSession(
93
+ str(path_to_initial_decoder), options, providers=provider
94
+ )
95
+
96
+ return encoder_sess, decoder_sess, decoder_sess_init
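
A sketch of the non-default path, with hand-picked session options (the model paths are assumed to be the quantized files produced by quantize()):

import psutil

model_paths = (
    "models/t5-small-encoder-quantized.onnx",       # assumed quantize() output paths
    "models/t5-small-decoder-quantized.onnx",
    "models/t5-small-init-decoder-quantized.onnx",
)
enc_sess, dec_sess, dec_init_sess = get_onnx_runtime_sessions(
    model_paths,
    default=False,
    opt_level=99,
    parallel_exe_mode=True,
    n_threads=psutil.cpu_count(logical=True),
)
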
app.py CHANGED
@@ -1,742 +1,9 @@
1
- import psutil
2
- from transformers import (
3
- AutoConfig,
4
- T5ForConditionalGeneration,
5
- MT5ForConditionalGeneration,
6
- )
7
- import torch
8
  import time
9
  import gradio as gr
10
  from transformers import AutoTokenizer
11
- import onnxruntime as ort
12
- from transformers.modeling_outputs import (
13
- Seq2SeqLMOutput,
14
- BaseModelOutput,
15
- )
16
  import os
17
  from pathlib import Path
18
- from progress.bar import Bar
19
- import operator
20
- import functools
21
- from onnxruntime import (
22
- GraphOptimizationLevel,
23
- InferenceSession,
24
- SessionOptions,
25
- ExecutionMode,
26
- )
27
- _auth_token = None
28
-
29
-
30
- def set_auth_token(token):
31
- """Set the token which allows the user to authenticate to hugginface.co for downloading private models
32
-
33
- Args:
34
- token (Union[str, bool]): The token value to store. One of:
35
- - an API key (from https://huggingface.co/organizations/ORGNAME/settings/token),
36
- - a login token obtained by running `$ transformers-cli login`
37
- - `True`, which tells transformers to use the login token stored in ~/.huggingface/token
38
-
39
- Returns:
40
- None
41
- """
42
- global _auth_token
43
- _auth_token = token
44
-
45
-
46
- def get_auth_token():
47
- """Get the user-configurable auth token, which defaults to None
48
-
49
- Returns:
50
- auth_token (Optional[Union[str, bool]]) for authenticating with huggingface.co
51
- """
52
- global _auth_token
53
- return _auth_token
54
-
55
-
56
- os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=True))
57
- os.environ["OMP_WAIT_POLICY"] = "ACTIVE"
58
-
59
-
60
- def get_onnx_runtime_sessions(
61
- model_paths,
62
- default: bool = True,
63
- opt_level: int = 99,
64
- parallel_exe_mode: bool = True,
65
- n_threads: int = 0,
66
- provider=[
67
- "CPUExecutionProvider",
68
- ],
69
- ) -> InferenceSession:
70
- """
71
- Optimizes the model
72
-
73
- Args:
74
- model_paths (List or Tuple of str) : the path to, in order:
75
- path_to_encoder (str) : the path of input onnx encoder model.
76
- path_to_decoder (str) : the path of input onnx decoder model.
77
- path_to_initial_decoder (str) : the path of input initial onnx decoder model.
78
- default : set this to true, ort will choose the best settings for your hardware.
79
- (you can test out different settings for better results.)
80
- opt_level (int) : sess_options.GraphOptimizationLevel param if set 1 uses 'ORT_ENABLE_BASIC',
81
- 2 for 'ORT_ENABLE_EXTENDED' and 99 for 'ORT_ENABLE_ALL',
82
- default value is set to 99.
83
- parallel_exe_mode (bool) : Sets the execution mode. Default is True (parallel).
84
- n_threads (int) : Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose
85
- provider : execution providers list.
86
-
87
- Returns:
88
- encoder_session : encoder onnx InferenceSession
89
- decoder_session : decoder onnx InferenceSession
90
- decoder_sess_init : initial decoder onnx InferenceSession
91
-
92
- """
93
- path_to_encoder, path_to_decoder, path_to_initial_decoder = model_paths
94
-
95
- if default:
96
-
97
- encoder_sess = InferenceSession(str(path_to_encoder))
98
-
99
- decoder_sess = InferenceSession(str(path_to_decoder))
100
-
101
- decoder_sess_init = InferenceSession(str(path_to_initial_decoder))
102
-
103
- else:
104
-
105
- # Few properties that might have an impact on performances
106
- options = SessionOptions()
107
-
108
- if opt_level == 1:
109
- options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
110
- elif opt_level == 2:
111
- options.graph_optimization_level = (
112
- GraphOptimizationLevel.ORT_ENABLE_EXTENDED
113
- )
114
- else:
115
- assert opt_level == 99
116
- options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
117
-
118
- # set this true for better performance
119
- if parallel_exe_mode == True:
120
- options.execution_mode = ExecutionMode.ORT_PARALLEL
121
- else:
122
- options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
123
-
124
- options.intra_op_num_threads = n_threads
125
- # options.inter_op_num_threads = 10
126
-
127
- # options.enable_profiling = True
128
-
129
- encoder_sess = InferenceSession(
130
- str(path_to_encoder), options, providers=provider
131
- )
132
-
133
- decoder_sess = InferenceSession(
134
- str(path_to_decoder), options, providers=provider
135
- )
136
-
137
- decoder_sess_init = InferenceSession(
138
- str(path_to_initial_decoder), options, providers=provider
139
- )
140
-
141
- return encoder_sess, decoder_sess, decoder_sess_init
142
-
143
-
144
- class DecoderWithLMhead(torch.nn.Module):
145
- """ Creation of a class to combine the decoder and the lm head """
146
-
147
- def __init__(self, decoder, lm_head, config):
148
- super().__init__()
149
- self.decoder = decoder
150
- self.lm_head = lm_head
151
- self.config = config
152
-
153
- def forward(self, *inputs):
154
-
155
- input_ids, attention_mask, encoder_hidden_states = inputs[:3]
156
-
157
- list_pkv = inputs[3:]
158
- past_key_values = tuple(list_pkv[i: i + 4]
159
- for i in range(0, len(list_pkv), 4))
160
-
161
- decoder_output = self.decoder(
162
- input_ids=input_ids, # decoder_input_ids
163
- encoder_attention_mask=attention_mask,
164
- encoder_hidden_states=encoder_hidden_states,
165
- past_key_values=past_key_values,
166
- )
167
-
168
- lm_head_out = self.lm_head(
169
- decoder_output[0] * (self.config.d_model ** -0.5))
170
-
171
- return lm_head_out, decoder_output[1]
172
-
173
-
174
- class T5Encoder(torch.nn.Module):
175
- """ Creation of a class to output only the last hidden state from the encoder """
176
-
177
- def __init__(self, encoder):
178
- super().__init__()
179
- self.encoder = encoder
180
-
181
- def forward(self, *input, **kwargs):
182
- return self.encoder(*input, **kwargs)[0]
183
-
184
-
185
- class DecoderWithLMheadInitial(torch.nn.Module):
186
- """ Creation of a class to combine the decoder and the lm head """
187
-
188
- def __init__(self, decoder, lm_head, config):
189
- super().__init__()
190
- self.decoder = decoder
191
- self.lm_head = lm_head
192
- self.config = config
193
-
194
- def forward(self, input_ids, attention_mask, encoder_hidden_states):
195
- decoder_output = self.decoder(
196
- input_ids=input_ids,
197
- encoder_attention_mask=attention_mask,
198
- encoder_hidden_states=encoder_hidden_states,
199
- )
200
-
201
- return (
202
- self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),
203
- decoder_output[1],
204
- )
205
-
206
-
207
- _folder = Path.cwd()
208
- saved_models_path = _folder.joinpath("models")
209
-
210
- Bar.check_tty = False
211
-
212
-
213
- def create_t5_encoder_decoder(pretrained_version="t5-base"):
214
- """Generates an encoder and a decoder model with a language model head from a pretrained huggingface model
215
-
216
- Args:
217
- pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5
218
-
219
- Returns:
220
- simplified_encoder: pytorch t5 encoder with a wrapper to output only the hidden states
221
- decoder_with_lm_head: pytorch t5 decoder with a language modeling head
222
- """
223
-
224
- if 'mt5' in pretrained_version:
225
- model = MT5ForConditionalGeneration.from_pretrained(
226
- pretrained_version, use_auth_token=get_auth_token())
227
- else:
228
- model = T5ForConditionalGeneration.from_pretrained(
229
- pretrained_version, use_auth_token=get_auth_token())
230
-
231
- return turn_model_into_encoder_decoder(model)
232
-
233
-
234
- def turn_model_into_encoder_decoder(model):
235
- encoder = model.encoder
236
- decoder = model.decoder
237
- lm_head = model.lm_head
238
-
239
- decoder_with_lm_head = DecoderWithLMhead(decoder, lm_head, model.config)
240
- simplified_encoder = T5Encoder(encoder)
241
- decoder_with_lm_head_init = DecoderWithLMheadInitial(
242
- decoder, lm_head, model.config)
243
-
244
- return simplified_encoder, decoder_with_lm_head, decoder_with_lm_head_init
245
-
246
-
247
- def generate_onnx_representation(
248
- pretrained_version=None,
249
- model=None,
250
- output_path=None,
251
- input_sequence_length=256,
252
- onnx_opset_version=12, # no other opset versions are tested, change at your own risk
253
- ):
254
- """Exports a given huggingface pretrained model, or a given model and tokenizer, to onnx
255
-
256
- Args:
257
- pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5
258
- output_path (Optional[str]): if missing then use ./models
259
- input_sequence_length (Optional[int]): typical input sequence length, for use by the ORT for possible optimization
260
- onnx_opset_version (Optional[int]): ONNX Operator Set Version, default 12 is the only tested version
261
- """
262
- if (pretrained_version is None) and model is None:
263
- print(
264
- "You need to specify pretrained_version (the pretrained model you wish to export). Alternatively you can export a model you have in memory."
265
- )
266
- return
267
-
268
- if model is not None:
269
- (
270
- simplified_encoder,
271
- decoder_with_lm_head,
272
- decoder_with_lm_head_init,
273
- ) = turn_model_into_encoder_decoder(model)
274
- else:
275
- (
276
- simplified_encoder,
277
- decoder_with_lm_head,
278
- decoder_with_lm_head_init,
279
- ) = create_t5_encoder_decoder(pretrained_version)
280
-
281
- # model paths for enc, dec and dec_init
282
- output_path = saved_models_path if output_path is None else Path(
283
- output_path)
284
- encoder_path, decoder_path, init_decoder_path = get_model_paths(
285
- pretrained_version, output_path, quantized=False
286
- )
287
-
288
- model_config = AutoConfig.from_pretrained(
289
- pretrained_version, use_auth_token=get_auth_token())
290
-
291
- # Though these are dummy inputs, ORT optimizations do reference these values,
292
- # so it is worth using values as close to production as possible
293
- batch_size = 1 # not configurable since only CPU
294
- enc_seq_length = input_sequence_length
295
- # a decoder sequence length is always one because it's just the last generated token
296
- dec_seq_length = 1
297
- input_ids = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)
298
- attention_mask = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)
299
-
300
- n_heads = model_config.num_heads
301
- d_kv = model_config.d_kv
302
-
303
- input_ids_dec = torch.ones(batch_size, dec_seq_length, dtype=torch.int64)
304
- attention_mask_dec = torch.ones(
305
- batch_size, dec_seq_length, dtype=torch.int64)
306
- enc_out = torch.ones(
307
- (batch_size, enc_seq_length, model_config.d_model), dtype=torch.float32
308
- )
309
-
310
- # self_attention_past_key_values = torch.ones(
311
- # (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_a, d_kv), dtype=torch.float32)
312
- # cross_attention_past_key_values = torch.ones(
313
- # (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_b, d_kv), dtype=torch.float32)
314
-
315
- sa = torch.ones(
316
- (batch_size, n_heads, dec_seq_length, d_kv), dtype=torch.float32
317
- ) # 1, 8, 1, 64
318
- ca = torch.ones(
319
- (batch_size, n_heads, enc_seq_length, d_kv), dtype=torch.float32
320
- ) # 1, 8, variable, 64
321
- t5_block = (sa, sa, ca, ca)
322
- past_key_values = (t5_block,) * model_config.num_decoder_layers
323
-
324
- flat_past_key_values = functools.reduce(
325
- operator.iconcat, past_key_values, [])
326
-
327
- decoder_all_inputs = tuple(
328
- [input_ids_dec, attention_mask_dec, enc_out] + flat_past_key_values
329
- )
330
-
331
- # for progress bars
332
- bar = Bar("Exporting to onnx...", max=3)
333
-
334
- import warnings
335
-
336
- # ignores all the warnings during conversion
337
- warnings.filterwarnings("ignore")
338
-
339
- # Exports to ONNX
340
- with torch.no_grad():
341
-
342
- decoder_inputs = [
343
- "input_ids",
344
- "encoder_attention_mask",
345
- "encoder_hidden_states",
346
- ]
347
-
348
- pkv_input_names = ["pkv_{}".format(
349
- i) for i in range(len(flat_past_key_values))]
350
-
351
- decoder_input_names = decoder_inputs + pkv_input_names
352
-
353
- decoder_output_names = ["logits", "output_past_key_values"]
354
-
355
- dyn_axis_general = {0: "batch", 1: "sequence"}
356
- dyn_axis_pkv = {0: "batch", 2: "seq_length"}
357
-
358
- dyn_axis = {
359
- "input_ids": dyn_axis_general,
360
- "encoder_attention_mask": dyn_axis_general,
361
- "encoder_hidden_states": dyn_axis_general,
362
- "logits": dyn_axis_general,
363
- "output_past_key_values": dyn_axis_general,
364
- }
365
-
366
- dyn_pkv = {
367
- "pkv_{}".format(i): dyn_axis_pkv
368
- for i in range(len(flat_past_key_values))
369
- }
370
-
371
- dyn_axis_params = {**dyn_axis, **dyn_pkv}
372
-
373
- # decoder to utilize past key values:
374
- torch.onnx.export(
375
- decoder_with_lm_head,
376
- decoder_all_inputs,
377
- decoder_path.as_posix(),
378
- export_params=True,
379
- do_constant_folding=True,
380
- opset_version=onnx_opset_version,
381
- input_names=decoder_input_names,
382
- output_names=decoder_output_names,
383
- dynamic_axes=dyn_axis_params,
384
- )
385
- bar.next()
386
-
387
- torch.onnx.export(
388
- simplified_encoder,
389
- args=(input_ids, attention_mask),
390
- f=encoder_path.as_posix(),
391
- export_params=True,
392
- opset_version=onnx_opset_version,
393
- do_constant_folding=True,
394
- input_names=["input_ids", "attention_mask"],
395
- output_names=["hidden_states"],
396
- dynamic_axes={
397
- "input_ids": dyn_axis_general,
398
- "attention_mask": dyn_axis_general,
399
- "hidden_states": dyn_axis_general,
400
- },
401
- )
402
- bar.next()
403
- # initial decoder to produce past key values
404
- torch.onnx.export(
405
- decoder_with_lm_head_init,
406
- (input_ids_dec, attention_mask_dec, enc_out),
407
- init_decoder_path.as_posix(),
408
- export_params=True,
409
- opset_version=onnx_opset_version,
410
- input_names=[
411
- "input_ids",
412
- "encoder_attention_mask",
413
- "encoder_hidden_states",
414
- ],
415
- output_names=["logits", "past_key_values"],
416
- dynamic_axes={
417
- # batch_size, seq_length = input_shape
418
- "input_ids": dyn_axis_general,
419
- "encoder_attention_mask": dyn_axis_general,
420
- "encoder_hidden_states": dyn_axis_general,
421
- "logits": dyn_axis_general,
422
- "past_key_values": dyn_axis_general,
423
- },
424
- )
425
- bar.next()
426
- bar.finish()
427
-
428
- return encoder_path, decoder_path, init_decoder_path
429
-
430
-
431
- def get_model_paths(pretrained_model, model_path, quantized):
432
-
433
- model_path.mkdir(parents=True, exist_ok=True)
434
-
435
- # gets only the filename
436
- pretrained_model_name = Path(pretrained_model).stem
437
-
438
- if not quantized:
439
- encoder_path = model_path.joinpath(
440
- f"{pretrained_model_name}-encoder.onnx")
441
- decoder_path = model_path.joinpath(
442
- f"{pretrained_model_name}-decoder.onnx")
443
- init_decoder_path = model_path.joinpath(
444
- f"{pretrained_model_name}-init-decoder.onnx"
445
- )
446
- else:
447
- encoder_path = model_path.joinpath(
448
- f"{pretrained_model_name}-encoder-quantized.onnx"
449
- )
450
- decoder_path = model_path.joinpath(
451
- f"{pretrained_model_name}-decoder-quantized.onnx"
452
- )
453
- init_decoder_path = model_path.joinpath(
454
- f"{pretrained_model_name}-init-decoder-quantized.onnx"
455
- )
456
-
457
- return encoder_path, decoder_path, init_decoder_path
458
-
459
-
460
- def quantize(models_name_or_path):
461
- """
462
- Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs
463
-
464
- Uses unsigned ints for activation values, signed ints for weights, per
465
- https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
466
- since this is faster on most CPU architectures.
467
- Args:
468
- models_name_or_path: Paths to the locations where the exported ONNX models are stored
469
- Returns: The paths generated for the quantized models
470
- """
471
- from onnxruntime.quantization import quantize_dynamic, QuantType
472
-
473
- bar = Bar("Quantizing...", max=3)
474
-
475
- quant_model_paths = []
476
- for model in models_name_or_path:
477
- model_name = model.as_posix()
478
- output_model_name = f"{model_name[:-5]}-quantized.onnx"
479
- quantize_dynamic(
480
- model_input=model_name,
481
- model_output=output_model_name,
482
- per_channel=True,
483
- reduce_range=True, # should be the same as per_channel
484
- activation_type=QuantType.QUInt8,
485
- weight_type=QuantType.QInt8, # per docs, signed is faster on most CPUs
486
- optimize_model=False,
487
- ) # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
488
- quant_model_paths.append(output_model_name)
489
- bar.next()
490
-
491
- bar.finish()
492
-
493
- return tuple(quant_model_paths)
494
-
495
-
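As a minimal usage sketch of the dynamic-quantization step above: the file names below are placeholder assumptions, and the keyword arguments simply mirror the quantize() call in this diff (some onnxruntime versions expose a slightly different signature).

from pathlib import Path
from onnxruntime.quantization import quantize_dynamic, QuantType

# placeholder paths for already-exported ONNX graphs
for onnx_file in [Path("t5-encoder.onnx"), Path("t5-decoder.onnx")]:
    quantize_dynamic(
        model_input=onnx_file.as_posix(),
        model_output=f"{onnx_file.as_posix()[:-5]}-quantized.onnx",
        per_channel=True,
        reduce_range=True,                 # kept in sync with per_channel
        activation_type=QuantType.QUInt8,  # unsigned activations
        weight_type=QuantType.QInt8,       # signed weights: faster on most CPUs
    )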
496
- class T5Encoder(torch.nn.Module):
497
- def __init__(self, encoder_sess):
498
- super().__init__()
499
- self.encoder = encoder_sess
500
- self.main_input_name = "input_ids"
501
-
502
- def forward(
503
- self,
504
- input_ids,
505
- attention_mask,
506
- inputs_embeds=None,
507
- head_mask=None,
508
- output_attentions=None,
509
- output_hidden_states=None,
510
- return_dict=None,
511
- ):
512
-
513
- encoder_hidden_state = torch.from_numpy(
514
- self.encoder.run(
515
- None,
516
- {
517
- "input_ids": input_ids.cpu().numpy(),
518
- "attention_mask": attention_mask.cpu().numpy(),
519
- },
520
- )[0]
521
- )
522
-
523
- return BaseModelOutput(encoder_hidden_state)
524
-
525
-
526
- class T5DecoderInit(torch.nn.Module):
527
- def __init__(self, decoder_sess):
528
- super().__init__()
529
- self.decoder = decoder_sess
530
-
531
- def forward(self, input_ids, encoder_attention_mask, encoder_hidden_states):
532
-
533
- decoder_outputs = self.decoder.run(
534
- None,
535
- {
536
- "input_ids": input_ids.cpu().numpy(),
537
- "encoder_attention_mask": encoder_attention_mask.cpu().numpy(),
538
- "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
539
- },
540
- )
541
-
542
- list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])
543
-
544
- out_past_key_values = tuple(
545
- list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
546
- )
547
-
548
- return torch.from_numpy(decoder_outputs[0]), out_past_key_values
549
-
550
-
551
- class T5Decoder(torch.nn.Module):
552
- def __init__(self, decoder_sess):
553
- super().__init__()
554
- self.decoder = decoder_sess
555
-
556
- def forward(self, input_ids, attention_mask, encoder_output, past_key_values):
557
-
558
- decoder_inputs = {
559
- "input_ids": input_ids.cpu().numpy(),
560
- "encoder_attention_mask": attention_mask.cpu().numpy(),
561
- "encoder_hidden_states": encoder_output.cpu().numpy(),
562
- }
563
-
564
- flat_past_key_values = functools.reduce(
565
- operator.iconcat, past_key_values, [])
566
-
567
- past_key_values = {
568
- f"pkv_{i}": pkv.cpu().numpy() for i, pkv in enumerate(flat_past_key_values)
569
- }
570
-
571
- decoder_outputs = self.decoder.run(
572
- None, {**decoder_inputs, **past_key_values})
573
- # converts each value of the list from numpy to a torch tensor
574
- list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])
575
-
576
- # regroups the flat tuple into per-layer 4-tuples (num_decoder_layers x 4)
577
- out_past_key_values = tuple(
578
- list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
579
- )
580
-
581
- return torch.from_numpy(decoder_outputs[0]), out_past_key_values
582
-
583
-
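The flatten/regroup round trip used by T5Decoder above can be illustrated in isolation; the toy shapes below are assumptions (two layers, eight heads, d_kv of 64), not values read from any real config.

import functools
import operator
import torch

# one (self_k, self_v, cross_k, cross_v) 4-tuple per decoder layer
layer = tuple(torch.zeros(1, 8, 1, 64) for _ in range(4))
past_key_values = (layer, layer)  # toy model with two decoder layers

# flatten to a single sequence of tensors, as fed to the ONNX decoder inputs
flat_past_key_values = functools.reduce(operator.iconcat, past_key_values, [])
assert len(flat_past_key_values) == 2 * 4

# regroup flat outputs back into per-layer 4-tuples
out_past_key_values = tuple(
    tuple(flat_past_key_values[i: i + 4])
    for i in range(0, len(flat_past_key_values), 4)
)
assert len(out_past_key_values) == 2 and len(out_past_key_values[0]) == 4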
584
- class OnnxT5(T5ForConditionalGeneration):
585
- """creates a T5 model using onnx sessions (encode, decoder & init_decoder)"""
586
-
587
- def __init__(self, model_or_model_path, onnx_model_sessions):
588
- config = AutoConfig.from_pretrained(
589
- model_or_model_path, use_auth_token=get_auth_token()
590
- )
591
- super().__init__(config)
592
-
593
- # monkeypatch to work for MT5
594
- if (
595
- isinstance(model_or_model_path, str)
596
- and "mt5" in model_or_model_path.lower()
597
- ) or (
598
- hasattr(model_or_model_path, "name_or_path")
599
- and "mt5" in model_or_model_path.name_or_path
600
- ):
601
- self.model_type = "mt5"
602
- self.config_class = MT5Config
603
- self._keys_to_ignore_on_load_missing = [
604
- r"encoder\.embed_tokens\.weight",
605
- ]
606
- self._keys_to_ignore_on_save = [
607
- r"encoder\.embed_tokens\.weight",
608
- ]
609
-
610
- assert len(onnx_model_sessions) == 3, "all three models should be given"
611
-
612
- encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions
613
-
614
- self.encoder = T5Encoder(encoder_sess)
615
- self.decoder = T5Decoder(decoder_sess)
616
- self.decoder_init = T5DecoderInit(decoder_sess_init)
617
-
618
- def forward(
619
- self,
620
- input_ids=None,
621
- attention_mask=None,
622
- decoder_input_ids=None,
623
- decoder_attention_mask=None,
624
- head_mask=None,
625
- decoder_head_mask=None,
626
- cross_attn_head_mask=None,
627
- encoder_outputs=None,
628
- past_key_values=None,
629
- inputs_embeds=None,
630
- decoder_inputs_embeds=None,
631
- labels=None,
632
- use_cache=None,
633
- output_attentions=None,
634
- output_hidden_states=None,
635
- return_dict=None,
636
- ):
637
-
638
- if encoder_outputs is None:
639
- # Convert encoder inputs in embeddings if needed
640
- encoder_outputs = self.encoder(
641
- input_ids=input_ids, attention_mask=attention_mask
642
- )
643
-
644
- encoder_hidden_states = encoder_outputs[0]
645
-
646
- if past_key_values is not None:
647
- if decoder_input_ids is not None:
648
- decoder_input_ids = decoder_input_ids[:, -1:]
649
- if decoder_inputs_embeds is not None:
650
- decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
651
-
652
- if past_key_values is None:
653
-
654
- # runs only for the first time:
655
- init_onnx_outputs = self.decoder_init(
656
- decoder_input_ids, attention_mask, encoder_hidden_states
657
- )
658
-
659
- logits, past_key_values = init_onnx_outputs
660
-
661
- else:
662
-
663
- onnx_outputs = self.decoder(
664
- decoder_input_ids,
665
- attention_mask,
666
- encoder_hidden_states,
667
- past_key_values,
668
- )
669
-
670
- logits, past_key_values = onnx_outputs
671
-
672
- return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
673
-
674
-
675
- def export_and_get_onnx_model(
676
- model_or_model_path, custom_output_path=saved_models_path, quantized=True
677
- ):
678
- """
679
- Method for the whole pipeline:
680
- converts from PyTorch to ONNX --> quantizes the model --> sets up the ONNX runtime
681
- --> builds the whole ONNX model with all sessions
682
-
683
- """
684
-
685
- # Step 1. convert the Hugging Face T5 model to ONNX
686
- onnx_model_paths = generate_onnx_representation(
687
- model_or_model_path, output_path=custom_output_path
688
- )
689
-
690
- if quantized:
691
- # Step 2. (recommended) quantize the converted model for fast inference and to reduce model size.
692
- quant_model_paths = quantize(onnx_model_paths)
693
-
694
- # step 3. setup onnx runtime
695
- print("Setting up onnx model...")
696
- model_sessions = get_onnx_runtime_sessions(quant_model_paths)
697
- else:
698
- print("Setting up onnx model...")
699
- model_sessions = get_onnx_runtime_sessions(onnx_model_paths)
700
-
701
- # step 4. get the onnx model
702
- model = OnnxT5(model_or_model_path, model_sessions)
703
- print("Done!")
704
-
705
- return model
706
-
707
-
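A hedged end-to-end sketch of the pipeline implemented by export_and_get_onnx_model; the checkpoint name and prompt are illustrative assumptions, not part of this repository.

from transformers import AutoTokenizer

model = export_and_get_onnx_model("t5-small", quantized=True)  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained("t5-small")

tokens = tokenizer("translate English to French: Hello world", return_tensors="pt")
output_ids = model.generate(
    input_ids=tokens["input_ids"],
    attention_mask=tokens["attention_mask"],
    num_beams=2,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))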
708
- def get_onnx_model(model_name, onnx_models_path=saved_models_path, quantized=True):
709
- """
710
- Gets the ONNX model when already-converted models exist.
711
- Example:
712
- >> get_onnx_model(model_name="t5-finetuned", onnx_models_path="../models/onnx/quantized/")
713
-
714
- """
715
-
716
- encoder_path, decoder_path, init_decoder_path = get_model_paths(
717
- model_name, Path(onnx_models_path), quantized
718
- )
719
-
720
- if quantized:
721
- assert (
722
- encoder_path.exists()
723
- and decoder_path.exists()
724
- and init_decoder_path.exists()
725
- ), "quantized model don't exist in the model folder, first quantize the model!"
726
- else:
727
- assert (
728
- encoder_path.exists()
729
- and decoder_path.exists()
730
- and init_decoder_path.exists()
731
- ), "all or some models don't exists in the model folder, first convert the model! "
732
-
733
- model_paths = encoder_path, decoder_path, init_decoder_path
734
-
735
- model_sessions = get_onnx_runtime_sessions(model_paths)
736
-
737
- model = OnnxT5(model_name, model_sessions)
738
-
739
- return model
740
 
741
 
742
  trained_model_path = './t5_squad_v1/'
 
1
  import time
2
  import gradio as gr
3
  from transformers import AutoTokenizer
4
  import os
5
  from pathlib import Path
6
+ from FastT5 import get_onnx_runtime_sessions, OnnxT5
7
 
8
 
9
  trained_model_path = './t5_squad_v1/'
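A hedged sketch of how the imports above are typically wired to this checkpoint; the model directory and file names are assumptions inferred from get_model_paths, not verified against this commit.

onnx_model_paths = (
    Path("models/t5_squad_v1-encoder-quantized.onnx"),
    Path("models/t5_squad_v1-decoder-quantized.onnx"),
    Path("models/t5_squad_v1-init-decoder-quantized.onnx"),
)
model_sessions = get_onnx_runtime_sessions(onnx_model_paths)  # one session per graph
model = OnnxT5(trained_model_path, model_sessions)
tokenizer = AutoTokenizer.from_pretrained(trained_model_path)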