bankholdup committed
Commit b4f9c39
1 Parent(s): 48db4ac

Create app.py

Files changed (1):
  app.py +244 -0
app.py ADDED
@@ -0,0 +1,244 @@
import os
import argparse
import datetime
import logging

import numpy as np
import torch

from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
    XLMWithLMHeadModel,
    XLNetLMHeadModel,
    XLNetTokenizer,
)


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}


def set_seed(args):
    # args.seed is parsed below but not used here: a fresh random seed is
    # drawn for every generation so repeated prompts produce different lyrics
    rd = np.random.randint(100000)
    print("seed =", rd)
    np.random.seed(rd)
    torch.manual_seed(rd)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(rd)


#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        logger.info("WARNING! You are not starting your generation from a control code, so you won't get good results.")
    return prompt_text


def prepare_xlm_input(args, model, tokenizer, prompt_text):
    # kwargs = {"language": None, "mask_token_id": None}

    # Set the language
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
    if hasattr(model.config, "lang2id") and use_lang_emb:
        available_languages = model.config.lang2id.keys()
        if args.xlm_language in available_languages:
            language = args.xlm_language
        else:
            language = None
            while language not in available_languages:
                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")

        model.config.lang_id = model.config.lang2id[language]
        # kwargs["language"] = tokenizer.lang2id[language]

    # TODO fix mask_token_id setup when configurations are synchronized between models and tokenizers
    # XLM masked-language modeling (MLM) models need a masked token
    # is_xlm_mlm = "mlm" in args.model_name_or_path
    # if is_xlm_mlm:
    #     kwargs["mask_token_id"] = tokenizer.mask_token_id

    return prompt_text


# PADDING_TEXT was referenced below but never defined in this file, which
# would raise a NameError for the XLNet and Transfo-XL paths. The value here
# is a stand-in: the upstream transformers run_generation.py example uses a
# long article excerpt; any fluent paragraph works as a warm-up prefix for
# these short-prompt models.
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered. <eod>"""


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text


PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length
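
# A quick sanity check of the clamping logic above (hypothetical values,
# assuming a model window of 1024 positions):
#   adjust_length_to_model(-1, 1024)   -> 1024  (no length given: use the full window)
#   adjust_length_to_model(2000, 1024) -> 1024  (too long: clamp to model size)
#   adjust_length_to_model(20, 1024)   -> 20    (fits: keep the request)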


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default="</s>", help="Token at which lyrics generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect; lower values tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for the CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0, help="top-k filtering; 0 disables it")
    parser.add_argument("--p", type=float, default=0.9, help="nucleus (top-p) sampling threshold")

    parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization (unused; see set_seed)")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    # Initialize the model and tokenizer
    try:
        args.model_type = args.model_type.lower()
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        raise KeyError(
            "the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(args.model_type)
        )

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
    logger.info(args)

    generated_sequences = []
    prompt_text = ""
    while prompt_text != "stop":
        set_seed(args)
        while not len(prompt_text):
            prompt_text = args.prompt if args.prompt else input("Context >>> ")
        # prompt_text is reset to "" at the end of each pass, so the loop-head
        # condition alone can never see "stop"; exit explicitly here instead
        if prompt_text == "stop":
            break

        # Different models need different input formatting and/or extra arguments
        requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
        if requires_preprocessing:
            prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
            preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
            encoded_prompt = tokenizer.encode(
                preprocessed_prompt_text,
                add_special_tokens=False,
                return_tensors="pt",
                add_space_before_punct_symbol=True,
            )
        else:
            encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
        encoded_prompt = encoded_prompt.to(args.device)

        output_sequences = model.generate(
            input_ids=encoded_prompt,
            max_length=args.length + len(encoded_prompt[0]),
            temperature=args.temperature,
            top_k=args.k,  # 0 disables top-k filtering
            top_p=args.p,  # nucleus sampling: keep the smallest token set with cumulative probability >= p
            repetition_penalty=args.repetition_penalty,
            do_sample=True,
            num_return_sequences=args.num_return_sequences,
        )

        # Remove the batch dimension when returning multiple sequences
        if len(output_sequences.shape) > 2:
            output_sequences.squeeze_()

        now = datetime.datetime.now()
        date_time = now.strftime("%Y%m%d_%H%M%S%f")

        for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
            print("ruGPT:")
            generated_sequence = generated_sequence.tolist()

            # Decode lyrics
            text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

            # Remove all lyrics after the stop token, if it is actually present
            # (a blind text.find() would return -1 and eat the last character)
            if args.stop_token:
                stop_idx = text.find(args.stop_token)
                if stop_idx >= 0:
                    text = text[:stop_idx]

            # Add the prompt at the beginning of the sequence and strip the
            # excess text that was only used for pre-processing
            total_sequence = (
                prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
            )

            generated_sequences.append(total_sequence)
            # os.system('clear')
            print(total_sequence)

            # Create the output directory if needed, and include the sequence
            # index in the name so multiple samples don't overwrite each other
            os.makedirs("./songs", exist_ok=True)
            file_name = "./songs/{}_{}.txt".format(date_time, generated_sequence_idx + 1)
            with open(file_name, "w") as f:
                f.write(total_sequence)

        prompt_text = ""
        if args.prompt:
            break

    return generated_sequences


if __name__ == "__main__":
    main()
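
Example usage (illustrative; the checkpoint name is a placeholder for whichever
GPT-2-compatible model this app is meant to load):

    python app.py --model_type gpt2 --model_name_or_path <model-id-or-path> --length 150 --p 0.9 --temperature 1.0

The script reads prompts interactively from "Context >>> ", prints each
generated lyric, and saves it under ./songs/; typing "stop" at the prompt exits.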