"""Endpoint handler that annotates transcript utterances.

Each utterance is annotated with uptake, reasoning, question and
focusing-question predictions plus a count of matched math terms.
"""

import re
import weakref
from typing import Dict, List, Any

import numpy as np
from scipy.special import softmax
import torch
import transformers
from transformers import BertTokenizer, BertForSequenceClassification

from utils import clean_str, clean_str_nopunct
from utils import MultiHeadModel, BertInputBuilder, get_num_words, MATH_PREFIXES, MATH_WORDS

transformers.logging.set_verbosity_debug()

UPTAKE_MODEL = 'ddemszky/uptake-model'
REASONING_MODEL = 'ddemszky/student-reasoning'
QUESTION_MODEL = 'ddemszky/question-detection'
FOCUSING_QUESTION_MODEL = 'ddemszky/focusing-questions'


def _prepare_model_inputs(instance, device):
    """Turn a built input instance into batched tensors placed on ``device``.

    Adds an all-ones ``attention_mask`` sized from ``input_ids`` and replaces
    the ``input_ids``, ``token_type_ids`` and ``attention_mask`` entries of
    ``instance`` (in place) with tensors that have a leading batch dimension.

    Args:
        instance: Mapping produced by the input builder; must contain
            ``input_ids`` and ``token_type_ids``.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The same ``instance`` mapping, updated in place.
    """
    # NOTE(review): the mask list is nested one level deeper than
    # ``input_ids`` is assumed to be, so after ``unsqueeze(0)`` it may carry
    # one more dimension than the ids. Kept exactly as the original code
    # built it - confirm against BertInputBuilder.build_inputs.
    instance["attention_mask"] = [[1] * len(instance["input_ids"])]
    for key in ("input_ids", "token_type_ids", "attention_mask"):
        # Tensor.to() returns a new tensor instead of moving in place, so the
        # result has to be stored or the inputs silently stay on the CPU.
        instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(device)
    return instance


class Utterance:
    """A single utterance plus the annotations computed for it."""

    def __init__(self, speaker, text, uid=None,
                 transcript=None, starttime=None, endtime=None, **kwargs):
        """Store the utterance fields; extra keyword arguments become props.

        Args:
            speaker: Speaker label of the utterance.
            text: Raw utterance text.
            uid: Optional identifier, used by ``Transcript.get_uid``.
            transcript: Optional owning transcript; held as a weak reference.
            starttime: Optional start time, stored as given.
            endtime: Optional end time, stored as given.
            **kwargs: Custom properties, merged into ``to_dict`` output.
        """
        self.speaker = speaker
        self.text = text
        self.uid = uid
        self.starttime = starttime
        self.endtime = endtime
        # Weak reference so the utterance does not keep its transcript alive.
        self.transcript = weakref.ref(transcript) if transcript else None
        self.props = kwargs

        # Annotation slots; they stay None until an annotator fills them in.
        self.num_math_terms = None
        self.math_terms = None
        self.uptake = None
        self.reasoning = None
        self.question = None
        self.focusing_question = None

    def get_clean_text(self, remove_punct=False):
        """Return the text cleaned by ``clean_str`` or ``clean_str_nopunct``."""
        if remove_punct:
            return clean_str_nopunct(self.text)
        return clean_str(self.text)

    def get_num_words(self):
        """Return the word count of the raw text as computed by ``get_num_words``."""
        return get_num_words(self.text)

    def to_dict(self):
        """Return the utterance and its annotations as a plain dict.

        Custom props are unpacked last, so a prop sharing a name with a
        built-in key overrides that key.
        """
        return {
            'speaker': self.speaker,
            'text': self.text,
            'uid': self.uid,
            'starttime': self.starttime,
            'endtime': self.endtime,
            'uptake': self.uptake,
            'reasoning': self.reasoning,
            'question': self.question,
            'focusingQuestion': self.focusing_question,
            'numMathTerms': self.num_math_terms,
            'mathTerms': self.math_terms,
            **self.props
        }

    def __repr__(self):
        return f"Utterance(speaker='{self.speaker}'," \
               f"text='{self.text}', uid={self.uid}," \
               f"starttime={self.starttime}, endtime={self.endtime}, props={self.props})"


class Transcript:
    """An ordered collection of utterances with free-form parameters."""

    def __init__(self, **kwargs):
        """Create an empty transcript; ``kwargs`` are kept as custom params."""
        self.utterances = []
        self.params = kwargs

    def add_utterance(self, utterance):
        """Append ``utterance`` and point it back at this transcript (weakly)."""
        utterance.transcript = weakref.ref(self)
        self.utterances.append(utterance)

    def get_idx(self, idx):
        """Return the utterance at position ``idx``, or None if out of range.

        Negative indices count from the end, as with list indexing.
        """
        size = len(self.utterances)
        # Check both bounds: a too-negative index used to raise IndexError
        # while a too-large one returned None.
        if idx >= size or idx < -size:
            return None
        return self.utterances[idx]

    def get_uid(self, uid):
        """Return the first utterance whose ``uid`` equals ``uid``, else None."""
        for utt in self.utterances:
            if utt.uid == uid:
                return utt
        return None

    def length(self):
        """Return the number of utterances."""
        # Deliberately not exposed as __len__: that would make an empty
        # transcript falsy and change the truthiness test in Utterance.__init__.
        return len(self.utterances)

    def to_dict(self):
        """Return the utterances (as dicts) together with the custom params."""
        return {
            'utterances': [utterance.to_dict() for utterance in self.utterances],
            **self.params
        }

    def __repr__(self):
        return f"Transcript(utterances={self.utterances}, custom_params={self.params})"


class QuestionModel:
    """Sets ``utt.question`` to 1/0 for every utterance of a transcript."""

    def __init__(self, device, tokenizer, input_builder, max_length=300, path=QUESTION_MODEL):
        """Load the ``is_question`` multi-head model from ``path`` onto ``device``."""
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = MultiHeadModel.from_pretrained(
            path, head2size={"is_question": 2})
        self.model.to(self.device)

    def run_inference(self, transcript):
        """Annotate each utterance in ``transcript`` with ``question``.

        Text containing a literal "?" is marked 1 without running the model;
        everything else gets the argmax of the ``is_question_logits`` head.
        """
        self.model.eval()
        with torch.no_grad():
            for utt in transcript.utterances:
                if "?" in utt.text:
                    utt.question = 1
                    continue
                text = utt.get_clean_text(remove_punct=True)
                instance = self.input_builder.build_inputs([], text,
                                                           max_length=self.max_length,
                                                           input_str=True)
                output = self.get_prediction(instance)
                # int(...) keeps the stored type identical to the "?" shortcut
                # above (np.argmax alone yields a NumPy integer).
                utt.question = int(np.argmax(
                    output["is_question_logits"][0].tolist()))

    def get_prediction(self, instance):
        """Run the model on one built instance and return its raw output."""
        instance = _prepare_model_inputs(instance, self.device)
        return self.model(input_ids=instance["input_ids"],
                          attention_mask=instance["attention_mask"],
                          token_type_ids=instance["token_type_ids"],
                          return_pooler_output=False)


class ReasoningModel:
    """Sets ``utt.reasoning`` for sufficiently long utterances."""

    def __init__(self, device, tokenizer, input_builder, max_length=128, path=REASONING_MODEL):
        """Load the sequence-classification model from ``path`` onto ``device``."""
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)

    def run_inference(self, transcript, min_num_words=8):
        """Annotate utterances having at least ``min_num_words`` words.

        Shorter utterances are left untouched (``reasoning`` stays as it was).
        """
        self.model.eval()
        with torch.no_grad():
            for utt in transcript.utterances:
                if utt.get_num_words() < min_num_words:
                    continue
                instance = self.input_builder.build_inputs([], utt.text,
                                                           max_length=self.max_length,
                                                           input_str=True)
                output = self.get_prediction(instance)
                utt.reasoning = int(np.argmax(output["logits"][0].tolist()))

    def get_prediction(self, instance):
        """Run the model on one built instance and return its raw output."""
        instance = _prepare_model_inputs(instance, self.device)
        return self.model(input_ids=instance["input_ids"],
                          attention_mask=instance["attention_mask"],
                          token_type_ids=instance["token_type_ids"])


class UptakeModel:
    """Sets ``utt.uptake`` from each utterance and the one preceding it."""

    def __init__(self, device, tokenizer, input_builder, max_length=120, path=UPTAKE_MODEL):
        """Load the ``nsp`` multi-head model from ``path`` onto ``device``."""
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.max_length = max_length
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)

    def run_inference(self, transcript, min_prev_words, uptake_speaker=None):
        """Annotate utterances with a binary ``uptake`` value.

        An utterance is scored when its speaker matches ``uptake_speaker``
        (or ``uptake_speaker`` is None) and the previous utterance has at
        least ``min_prev_words`` words. The value is 1 when the softmax
        probability at index 1 of the ``nsp_logits`` exceeds 0.8, else 0.

        Args:
            transcript: Transcript whose utterances are annotated in place.
            min_prev_words: Minimum word count of the preceding utterance.
            uptake_speaker: Speaker to score; None scores every speaker.
        """
        self.model.eval()
        prev_num_words = 0
        prev_utt = None
        with torch.no_grad():
            for utt in transcript.utterances:
                speaker_ok = uptake_speaker is None or utt.speaker == uptake_speaker
                # The prev_utt guard matters when min_prev_words <= 0: the
                # first utterance has no predecessor to pair with.
                if speaker_ok and prev_utt is not None and prev_num_words >= min_prev_words:
                    textA = prev_utt.get_clean_text(remove_punct=False)
                    textB = utt.get_clean_text(remove_punct=False)
                    instance = self.input_builder.build_inputs([textA], textB,
                                                               max_length=self.max_length,
                                                               input_str=True)
                    output = self.get_prediction(instance)
                    utt.uptake = int(
                        softmax(output["nsp_logits"][0].tolist())[1] > .8)
                prev_num_words = utt.get_num_words()
                prev_utt = utt

    def get_prediction(self, instance):
        """Run the model on one built instance and return its raw output."""
        instance = _prepare_model_inputs(instance, self.device)
        return self.model(input_ids=instance["input_ids"],
                          attention_mask=instance["attention_mask"],
                          token_type_ids=instance["token_type_ids"],
                          return_pooler_output=False)


class FocusingQuestionModel:
    """Sets ``utt.focusing_question`` for one speaker's utterances."""

    def __init__(self, device, tokenizer, input_builder, max_length=128,
                 path=FOCUSING_QUESTION_MODEL):
        """Load the sequence-classification model from ``path`` onto ``device``."""
        print("Loading models...")
        self.device = device
        self.tokenizer = tokenizer
        self.input_builder = input_builder
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)
        self.max_length = max_length

    def run_inference(self, transcript, min_focusing_words=0, uptake_speaker=None):
        """Annotate ``uptake_speaker``'s utterances with ``focusing_question``.

        Utterances by other speakers, or with fewer than
        ``min_focusing_words`` words, get ``focusing_question = None``.

        NOTE(review): with ``uptake_speaker=None`` nothing is scored at all,
        unlike ``UptakeModel`` where None means "every speaker". Behavior is
        kept as written - confirm this is intended.
        """
        self.model.eval()
        with torch.no_grad():
            for utt in transcript.utterances:
                if uptake_speaker is None or utt.speaker != uptake_speaker:
                    utt.focusing_question = None
                    continue
                if utt.get_num_words() < min_focusing_words:
                    utt.focusing_question = None
                    continue
                instance = self.input_builder.build_inputs([], utt.text,
                                                           max_length=self.max_length,
                                                           input_str=True)
                output = self.get_prediction(instance)
                utt.focusing_question = int(np.argmax(output["logits"][0].tolist()))

    def get_prediction(self, instance):
        """Run the model on one built instance and return its raw output."""
        instance = _prepare_model_inputs(instance, self.device)
        return self.model(input_ids=instance["input_ids"],
                          attention_mask=instance["attention_mask"],
                          token_type_ids=instance["token_type_ids"])


def load_math_terms():
    """Build the regex patterns used to find math terms.

    Terms of ``MATH_WORDS`` that also appear in ``MATH_PREFIXES`` are wrapped
    in a pattern requiring non-letter boundaries and allowing an ``s``/``es``
    suffix; all other terms are used as patterns verbatim.

    Returns:
        A ``(patterns, pattern_to_term)`` tuple: the list of pattern strings
        and a dict mapping each pattern string back to its original term.
    """
    math_terms = []
    math_terms_dict = {}
    for term in MATH_WORDS:
        if term in MATH_PREFIXES:
            pattern = f"(^|[^a-zA-Z]){term}(s|es)?([^a-zA-Z]|$)"
        else:
            pattern = term
        math_terms.append(pattern)
        math_terms_dict[pattern] = term
    return math_terms, math_terms_dict


def run_math_density(transcript):
    """Set ``num_math_terms`` and ``math_terms`` on every utterance.

    Patterns are tried from longest to shortest pattern string; a match is
    dropped when it starts inside a span already claimed by an earlier
    pattern. ``math_terms`` lists each term that kept at least one match and
    ``num_math_terms`` is the total number of kept matches.
    """
    math_terms, math_terms_dict = load_math_terms()
    sorted_terms = sorted(math_terms, key=len, reverse=True)
    # Compile once per call instead of once per utterance and pattern.
    compiled_terms = [(re.compile(term, re.IGNORECASE), math_terms_dict[term])
                      for term in sorted_terms]
    for utt in transcript.utterances:
        text = utt.get_clean_text(remove_punct=False)
        num_matches = 0
        matched_positions = set()
        match_list = []
        for pattern, term in compiled_terms:
            # Keep only matches that do not start inside a span already
            # claimed by a previously processed (longer) pattern.
            matches = [match for match in pattern.finditer(text)
                       if not any(match.start() in range(start, end)
                                  for start, end in matched_positions)]
            if matches:
                match_list.append(term)
            # Record the spans claimed by this pattern.
            matched_positions.update((match.start(), match.end()) for match in matches)
            num_matches += len(matches)
        utt.num_math_terms = num_matches
        utt.math_terms = match_list


class EndpointHandler():
    """Callable that runs every annotator over a posted transcript."""

    def __init__(self, path="."):
        """Pick the device and build the shared tokenizer and input builder.

        ``path`` is accepted for interface compatibility and is not used.
        """
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
        # Model wrappers are created on first use and reused afterwards so
        # each request does not reload all four models.
        self._models = {}

    def _get_model(self, model_cls):
        """Return the cached ``model_cls`` wrapper, creating it on first use."""
        model = self._models.get(model_cls)
        if model is None:
            model = model_cls(self.device, self.tokenizer, self.input_builder)
            self._models[model_cls] = model
        return model

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Annotate the posted utterances and return the transcript as a dict.

        Args:
            data: Request payload with
                ``inputs``: list of dicts, one per utterance. Each dict is
                passed to ``Utterance(**utt)``, so ``speaker`` and ``text``
                are required, ``uid``/``starttime``/``endtime`` are optional
                and any other keys are kept as custom properties.
                ``parameters``: optional dict. ``uptake_min_num_words`` is
                required; ``filename`` and ``uptake_speaker`` are optional.

        Returns:
            ``Transcript.to_dict()`` output: the annotated utterances plus
            the ``filename`` parameter.

        Raises:
            KeyError: If ``uptake_min_num_words`` is missing from the
                parameters.
        """
        # get inputs
        utterances = data.pop("inputs", data)
        # A request without "parameters" used to fail with AttributeError on
        # None; treat it as an empty parameter dict instead.
        params = data.pop("parameters", None) or {}

        print("EXAMPLES")
        for utt in utterances[:3]:
            print("speaker %s: %s" % (utt["speaker"], utt["text"]))

        transcript = Transcript(filename=params.pop("filename", None))
        for utt in utterances:
            transcript.add_utterance(Utterance(**utt))

        print("Running inference on %d examples..." % transcript.length())
        uptake_speaker = params.pop("uptake_speaker", None)

        # Uptake
        uptake_model = self._get_model(UptakeModel)
        uptake_model.run_inference(transcript,
                                   min_prev_words=params['uptake_min_num_words'],
                                   uptake_speaker=uptake_speaker)

        # Reasoning
        reasoning_model = self._get_model(ReasoningModel)
        reasoning_model.run_inference(transcript)

        # Question
        question_model = self._get_model(QuestionModel)
        question_model.run_inference(transcript)

        # Focusing Question
        focusing_question_model = self._get_model(FocusingQuestionModel)
        focusing_question_model.run_inference(transcript, uptake_speaker=uptake_speaker)

        run_math_density(transcript)

        return transcript.to_dict()