pappymu committed
Commit
8cadfac
1 Parent(s): fd65c79

Create new file

Files changed (1): app.py +255 -0
app.py ADDED
@@ -0,0 +1,255 @@
+ import streamlit as st
+ import numpy as np
+ import pandas as pd
+
+ # modeling
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ import pytorch_lightning as pl
+ from pytorch_lightning import Trainer, seed_everything
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+ from transformers import (
+     T5ForConditionalGeneration,
+     T5TokenizerFast as T5Tokenizer,
+ )
+ from transformers.optimization import Adafactor
+
+ # aesthetics
+ from IPython.display import Markdown, display, clear_output
+ import re
+ import warnings
+ warnings.filterwarnings(
+     "ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*"
+ )
+ seed_everything(25429)
+
+ # scoring
+ import spacy
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(device)
+
+ # loading the model
+ hug = 't5-small'
+ t5tokenizer = T5Tokenizer.from_pretrained(hug)
+ t5model = T5ForConditionalGeneration.from_pretrained(hug, return_dict=True)
+
+ # defining tokens
+ SEP_TOKEN = '<sep>'
+ MASK_TOKEN = '[MASK]'
+ MASKING_CHANCE = 0.1
+ # register the custom tokens so the embedding resize in T5Model below matches
+ # the vocabulary the checkpoint was (presumably) trained with
+ t5tokenizer.add_tokens([SEP_TOKEN, MASK_TOKEN])
+
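+ # (Assumed environment: streamlit, torch, pytorch_lightning, transformers,
+ # spacy, and nltk installed; spacy/nltk are only used by the commented-out
+ # scoring code further down.)
+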
+ class DataEncodings(Dataset):
+     '''
+     tokenizes, pads, and adds special tokens
+     '''
+     def __init__(
+         self,
+         data: pd.DataFrame,
+         tokenizer,
+         source_max_token_len: int,
+         target_max_token_len: int
+     ):
+         self.tokenizer = tokenizer  # was hard-coded to the global t5tokenizer
+         self.data = data
+         self.source_max_token_len = source_max_token_len
+         self.target_max_token_len = target_max_token_len
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, index: int):
+         data_row = self.data.iloc[index]
+         # adds a random mask for answer-agnostic qg
+         if np.random.rand() > MASKING_CHANCE:
+             answer = data_row['answer']
+         else:
+             answer = MASK_TOKEN
+
+         source_encoding = self.tokenizer(
+             f"{answer} {SEP_TOKEN} {data_row['context']}",
+             max_length=self.source_max_token_len,
+             padding='max_length',
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors='pt'
+         )
+
+         target_encoding = self.tokenizer(
+             f"{data_row['answer']} {SEP_TOKEN} {data_row['question']}",
+             max_length=self.target_max_token_len,
+             padding='max_length',
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors='pt'
+         )
+
+         labels = target_encoding['input_ids']
+         labels[labels == 0] = -100  # ignore pad tokens in the loss
+
+         encodings = dict(
+             answer=data_row['answer'],
+             context=data_row['context'],
+             question=data_row['question'],
+             input_ids=source_encoding['input_ids'].flatten(),
+             attention_mask=source_encoding['attention_mask'].flatten(),
+             labels=labels.flatten()
+         )
+
+         return encodings
+
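+ # Usage sketch (illustrative, not part of the app flow): the dataset expects a
+ # DataFrame with 'context', 'question', and 'answer' columns, e.g.
+ #   df = pd.DataFrame([{'context': 'The capital of France is Paris.',
+ #                       'question': 'What is the capital of France?',
+ #                       'answer': 'Paris'}])
+ #   item = DataEncodings(df, t5tokenizer, 512, 80)[0]
+ #   # -> dict with input_ids, attention_mask, labels (token lengths assumed)
+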
+ class DataModule(pl.LightningDataModule):
+
+     def __init__(
+         self,
+         train: pd.DataFrame,
+         val: pd.DataFrame,
+         tokenizer,
+         batch_size,
+         source_max_token_len: int,
+         target_max_token_len: int
+     ):
+         super().__init__()
+         self.batch_size = batch_size
+         self.train = train
+         self.val = val
+         self.tokenizer = tokenizer  # was hard-coded to the global t5tokenizer
+         self.source_max_token_len = source_max_token_len
+         self.target_max_token_len = target_max_token_len
+
+     def setup(self, stage=None):  # Lightning passes a stage argument
+         self.train_dataset = DataEncodings(self.train, self.tokenizer, self.source_max_token_len, self.target_max_token_len)
+         self.val_dataset = DataEncodings(self.val, self.tokenizer, self.source_max_token_len, self.target_max_token_len)
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)
+
+     def val_dataloader(self):
+         return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=0)
+
+ # hyperparameters
+ num_epochs = 16
+ batch_size = 32
+ learning_rate = 0.001
+
+ # model
+ class T5Model(pl.LightningModule):
+     def __init__(self):
+         super().__init__()
+         self.model = t5model
+         self.model.resize_token_embeddings(len(t5tokenizer))  # resizing after adding new tokens to the tokenizer
+
+     # feed forward pass
+     def forward(self, input_ids, attention_mask, labels=None):
+         output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+         return output.loss, output.logits
+
+     # train model and compute loss
+     def training_step(self, batch, batch_idx):
+         input_ids = batch['input_ids']
+         attention_mask = batch['attention_mask']
+         labels = batch['labels']
+         loss, output = self(input_ids, attention_mask, labels)
+         self.log('train_loss', loss, prog_bar=True, logger=True, batch_size=batch_size)
+         return loss
+
+     # gets model predictions, returns loss
+     def validation_step(self, batch, batch_idx):
+         input_ids = batch['input_ids']
+         attention_mask = batch['attention_mask']
+         labels = batch['labels']
+         loss, output = self(input_ids, attention_mask, labels)
+         self.log('val_loss', loss, prog_bar=True, logger=True, batch_size=batch_size)
+         return {'val_loss': loss}
+
+     # def validation_epoch_end(self, outputs):
+     #     # outputs = list of dictionaries to print loss
+     #     avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+     #     tensorboard_logs = {'avg_val_loss': avg_loss}
+     #     return {'val_loss': avg_loss, 'log': tensorboard_logs}
+
+     def configure_optimizers(self):
+         # with relative_step=False, Adafactor uses the fixed learning rate above
+         return Adafactor(self.parameters(), scale_parameter=False, relative_step=False, lr=learning_rate)
+
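+ # Offline training sketch (not executed by this app; train_df/val_df are
+ # placeholder DataFrames with context/question/answer columns):
+ #   dm = DataModule(train_df, val_df, t5tokenizer, batch_size, 512, 80)
+ #   trainer = Trainer(max_epochs=num_epochs,
+ #                     callbacks=[ModelCheckpoint(monitor='val_loss'),
+ #                                EarlyStopping(monitor='val_loss')])
+ #   trainer.fit(T5Model(), dm)
+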
+ def generate(model: T5Model, answer: str, context: str, beams, length, temper) -> str:
+     source_encoding = t5tokenizer(
+         f"{answer} {SEP_TOKEN} {context}",
+         max_length=512,
+         padding='max_length',
+         truncation=True,
+         return_attention_mask=True,
+         add_special_tokens=True,
+         return_tensors='pt'
+     )
+
+     generated_ids = model.model.generate(
+         input_ids=source_encoding['input_ids'],
+         attention_mask=source_encoding['attention_mask'],
+         num_beams=beams,
+         max_length=length,
+         repetition_penalty=2.5,
+         length_penalty=0.8,
+         temperature=temper,
+         early_stopping=True,
+         use_cache=True
+     )
+
+     # a list (not a set) keeps the decoded sequences in generation order
+     preds = [
+         t5tokenizer.decode(generated_id, skip_special_tokens=False, clean_up_tokenization_spaces=True)
+         for generated_id in generated_ids
+     ]
+
+     return ''.join(preds)
+
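+ # Example call (sketch): with skip_special_tokens=False the decoded output keeps
+ # the separator, so something like
+ #   generate(best_model, 'Paris', 'The capital of France is Paris.', beams=5, length=80, temper=1.0)
+ # returns a string roughly shaped like '<pad> Paris<sep> What is ...?</s>',
+ # which show_result() below splits apart with a regex.
+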
+ def show_result(generated: str, answer: str, context: str, original_question: str = ''):
+     # pull the text between special tokens out of the raw decoded string
+     regex = r"(?<=>)(.*?)(?=<)"
+     matches = re.findall(regex, generated)
+     matches[1] = matches[1][5:]  # trim stray characters left over from the raw decode (original heuristic)
+     final = {cat: match.strip() for cat, match in zip(['Answer', 'Question'], matches)}
+     st.title('Context')
+     st.write(context)
+     st.title('Answer')
+     st.write(answer)
+     st.title('Generated')
+     st.write(final)
+     # if original_question:
+     #     printBold('Original Question')
+     #     print(original_question)
+     #     gen = nlp(matches[1])
+     #     ori = nlp(original_question)
+     #     bleu_score = sentence_bleu(matches[1], original_question, smoothing_function=SmoothingFunction().method5)
+     #     cs_score = ori.similarity(gen)
+     #     printBold('Scores')
+     #     print(f"BLEU: {bleu_score}")
+     #     print(f'Cosine Similarity: {cs_score}')
+     #     return bleu_score, cs_score
+
+ # streamlit app
+ st.title('Question Generation From Text')
+
+ with st.form('my_form'):
+     context = st.text_input('Enter a context passage for question generation:', 'The capital of France is Paris.')
+     answer = st.text_input('Give a correct answer, or [MASK] for unsupervised generation:', 'Paris')
+     # question = st.text_input('Question', 'What is the capital of France?')
+     # original_question = st.text_input('Original Question', 'What is the capital of France?')
+     beams = st.sidebar.slider('Beams', min_value=1, max_value=20)
+     length = st.sidebar.slider('Maximum length of generated question', min_value=50, max_value=200)
+     temper = st.sidebar.slider('Temperature', value=1.0, min_value=0.0, max_value=1.0, step=0.05)
+     submitted = st.form_submit_button('Generate')
+
+ with st.spinner('Loading Model...'):
+     best_model_dir = 't5-chkpt-v2.ckpt'
+     # load_from_checkpoint is a classmethod, so call it on the class itself
+     best_model = T5Model.load_from_checkpoint(best_model_dir)
+     # best_model = model.load_from_checkpoint(callback.best_model_path)
+     best_model.freeze()
+
+ if submitted:
+     with st.spinner('Generating...'):
+         generated = generate(best_model, answer, context, beams, length, temper)
+         show_result(generated, answer, context)
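+
+ # Note: running this app assumes the fine-tuned checkpoint 't5-chkpt-v2.ckpt'
+ # is present in the working directory; without it, load_from_checkpoint raises.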