hylee719 committed
Commit 0da9196
Parent: 25ca7c8

add focusing questions

Files changed (1)
handler.py  +74 -15
handler.py CHANGED

@@ -13,9 +13,11 @@ from transformers import BertTokenizer, BertForSequenceClassification
 
 transformers.logging.set_verbosity_debug()
 
-UPTAKE_MODEL='ddemszky/uptake-model'
-REASONING_MODEL ='ddemszky/student-reasoning'
-QUESTION_MODEL ='ddemszky/question-detection'
+UPTAKE_MODEL = 'ddemszky/uptake-model'
+REASONING_MODEL = 'ddemszky/student-reasoning'
+QUESTION_MODEL = 'ddemszky/question-detection'
+FOCUSING_QUESTION_MODEL = 'ddemszky/focusing-questions'
+
 
 class Utterance:
     def __init__(self, speaker, text, uid=None,
@@ -31,6 +33,7 @@ class Utterance:
         self.uptake = None
         self.reasoning = None
         self.question = None
+        self.focusing_question = None
 
     def get_clean_text(self, remove_punct=False):
         if remove_punct:
@@ -50,6 +53,7 @@ class Utterance:
             'uptake': self.uptake,
             'reasoning': self.reasoning,
             'question': self.question,
+            'focusingquestion': self.focusing_question,
             **self.props
         }
 
@@ -58,6 +62,7 @@ class Utterance:
                f"text='{self.text}', uid={self.uid}," \
                f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"
 
+
 class Transcript:
     def __init__(self, **kwargs):
         self.utterances = []
@@ -90,6 +95,7 @@ class Transcript:
     def __repr__(self):
         return f"Transcript(utterances={self.utterances}, custom_params={self.params})"
 
+
 class QuestionModel:
     def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
         print("Loading models...")
@@ -97,10 +103,10 @@ class QuestionModel:
         self.tokenizer = tokenizer
         self.input_builder = input_builder
         self.max_length = max_length
-        self.model = MultiHeadModel.from_pretrained(path, head2size={"is_question": 2})
+        self.model = MultiHeadModel.from_pretrained(
+            path, head2size={"is_question": 2})
         self.model.to(self.device)
 
-
     def run_inference(self, transcript):
         self.model.eval()
         with torch.no_grad():
@@ -114,12 +120,14 @@ class QuestionModel:
                                                            input_str=True)
                 output = self.get_prediction(instance)
                 print(output)
-                utt.question = np.argmax(output["is_question_logits"][0].tolist())
+                utt.question = np.argmax(
+                    output["is_question_logits"][0].tolist())
 
     def get_prediction(self, instance):
         instance["attention_mask"] = [[1] * len(instance["input_ids"])]
         for key in ["input_ids", "token_type_ids", "attention_mask"]:
-            instance[key] = torch.tensor(instance[key]).unsqueeze(0)  # Batch size = 1
+            instance[key] = torch.tensor(
+                instance[key]).unsqueeze(0)  # Batch size = 1
             instance[key].to(self.device)
 
         output = self.model(input_ids=instance["input_ids"],
@@ -128,6 +136,7 @@ class QuestionModel:
                             return_pooler_output=False)
         return output
 
+
 class ReasoningModel:
     def __init__(self, device, tokenizer, input_builder, max_length=128, path=REASONING_MODEL):
         print("Loading models...")
@@ -152,7 +161,8 @@ class ReasoningModel:
     def get_prediction(self, instance):
         instance["attention_mask"] = [[1] * len(instance["input_ids"])]
         for key in ["input_ids", "token_type_ids", "attention_mask"]:
-            instance[key] = torch.tensor(instance[key]).unsqueeze(0)  # Batch size = 1
+            instance[key] = torch.tensor(
+                instance[key]).unsqueeze(0)  # Batch size = 1
             instance[key].to(self.device)
 
         output = self.model(input_ids=instance["input_ids"],
@@ -160,6 +170,7 @@ class ReasoningModel:
                             token_type_ids=instance["token_type_ids"])
         return output
 
+
class UptakeModel:
     def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
         print("Loading models...")
@@ -184,14 +195,16 @@ class UptakeModel:
                                                            input_str=True)
                 output = self.get_prediction(instance)
 
-                utt.uptake = int(softmax(output["nsp_logits"][0].tolist())[1] > .8)
+                utt.uptake = int(
+                    softmax(output["nsp_logits"][0].tolist())[1] > .8)
             prev_num_words = utt.get_num_words()
             prev_utt = utt
 
     def get_prediction(self, instance):
         instance["attention_mask"] = [[1] * len(instance["input_ids"])]
         for key in ["input_ids", "token_type_ids", "attention_mask"]:
-            instance[key] = torch.tensor(instance[key]).unsqueeze(0)  # Batch size = 1
+            instance[key] = torch.tensor(
+                instance[key]).unsqueeze(0)  # Batch size = 1
             instance[key].to(self.device)
 
         output = self.model(input_ids=instance["input_ids"],
@@ -201,6 +214,44 @@ class UptakeModel:
         return output
 
 
+
+class FocusingQuestionModel:
+    def __init__(self, device, tokenizer, input_builder, max_length=128, path=FOCUSING_QUESTION_MODEL):
+        print("Loading models...")
+        self.device = device
+        self.tokenizer = tokenizer
+        self.input_builder = input_builder
+        self.model = BertForSequenceClassification.from_pretrained(path)
+        self.model.to(self.device)
+        self.max_length = max_length
+
+    def run_inference(self, transcript, min_focusing_words=0, uptake_speaker=None):
+        self.model.eval()
+        with torch.no_grad():
+            for i, utt in enumerate(transcript.utterances):
+                if utt.speaker != uptake_speaker or uptake_speaker is None:
+                    utt.focusing_question = None
+                    continue
+                if utt.get_num_words() < min_focusing_words:
+                    utt.focusing_question = None
+                    continue
+                instance = self.input_builder.build_inputs([], utt.text, max_length=self.max_length, input_str=True)
+                output = self.get_prediction(instance)
+                utt.focusing_question = np.argmax(output["logits"][0].tolist())
+
+    def get_prediction(self, instance):
+        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
+        for key in ["input_ids", "token_type_ids", "attention_mask"]:
+            instance[key] = torch.tensor(
+                instance[key]).unsqueeze(0)  # Batch size = 1
+            instance[key].to(self.device)
+
+        output = self.model(input_ids=instance["input_ids"],
+                            attention_mask=instance["attention_mask"],
+                            token_type_ids=instance["token_type_ids"])
+        return output
+
+
 class EndpointHandler():
     def __init__(self, path="."):
         print("Loading models...")
@@ -231,18 +282,26 @@ class EndpointHandler():
             transcript.add_utterance(Utterance(**utt))
 
         print("Running inference on %d examples..." % transcript.length())
-
+        uptake_speaker = params.pop("uptake_speaker", None)
         # Uptake
-        uptake_model = UptakeModel(self.device, self.tokenizer, self.input_builder)
+        uptake_model = UptakeModel(
+            self.device, self.tokenizer, self.input_builder)
         uptake_model.run_inference(transcript, min_prev_words=params['uptake_min_num_words'],
-                                   uptake_speaker=params.pop("uptake_speaker", None))
+                                   uptake_speaker=uptake_speaker)
 
         # Reasoning
-        reasoning_model = ReasoningModel(self.device, self.tokenizer, self.input_builder)
+        reasoning_model = ReasoningModel(
+            self.device, self.tokenizer, self.input_builder)
         reasoning_model.run_inference(transcript)
 
         # Question
-        question_model = QuestionModel(self.device, self.tokenizer, self.input_builder)
+        question_model = QuestionModel(
+            self.device, self.tokenizer, self.input_builder)
         question_model.run_inference(transcript)
 
+        # Focusing Question
+        focusing_question_model = FocusingQuestionModel(
+            self.device, self.tokenizer, self.input_builder)
+        focusing_question_model.run_inference(transcript, uptake_speaker=uptake_speaker)
+
         return transcript.to_dict()
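
For reference, a minimal invocation sketch of the updated handler. The payload shape below is an assumption: the diff only shows that the handler builds Utterance objects from the request, reads params['uptake_min_num_words'], and pops 'uptake_speaker'. The key names "utterances" and "params", the speaker labels, and the __call__(data) signature (the usual convention for custom endpoint handlers) are all hypothetical, not taken from this commit.

# Hypothetical usage sketch -- not part of the commit. Key names "utterances"
# and "params" and the speaker labels are assumptions; only uptake_min_num_words
# and uptake_speaker appear in the diff above.
from handler import EndpointHandler

handler = EndpointHandler(path=".")

data = {
    "utterances": [
        {"speaker": "teacher", "text": "What do you notice about the two graphs?"},
        {"speaker": "student", "text": "They have the same slope."},
    ],
    "params": {
        "uptake_min_num_words": 5,    # threshold passed to uptake_model.run_inference
        "uptake_speaker": "teacher",  # reused by focusing_question_model.run_inference
    },
}

result = handler(data)  # assuming the standard __call__(self, data) handler signature
print(result)           # each utterance dict now carries a 'focusingquestion' field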
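A note on the design: uptake_speaker is now popped from params once, before the uptake pass, rather than inline in the run_inference call, so the same value can be reused by FocusingQuestionModel.run_inference. Given the guard `if utt.speaker != uptake_speaker or uptake_speaker is None`, focusing questions are scored only for utterances by the designated speaker, and if no uptake_speaker is supplied every focusing_question stays None. Consumers of the response should also note that the serialized key in to_dict is 'focusingquestion', without an underscore.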