# Hugging Face Space: questgen / backup — "Improved Inference" (commit 14cac88)
import pke
from sense2vec import Sense2Vec
import time
import gradio as gr
from transformers import AutoTokenizer
import os
from pathlib import Path
from FastT5 import get_onnx_runtime_sessions, OnnxT5
# commands = [
# "curl -LO https://github.com/explosion/sense2vec/releases/download/v1.0.0/s2v_reddit_2015_md.tar.gz",
# "tar -xvf s2v_reddit_2015_md.tar.gz",
# ]
# for command in commands:
# return_code = os.system(command)
# if return_code == 0:
# print(f"Command '{command}' executed successfully")
# else:
# print(f"Command '{command}' failed with return code {return_code}")
# Load sense2vec vectors from disk; used by generate_mcq to find distractors.
s2v = Sense2Vec().from_disk("s2v_old")
# Directory containing the quantized ONNX export of the T5-SQuAD model.
trained_model_path = './t5_squad_v1/'
# Path.stem of './t5_squad_v1/' yields 't5_squad_v1', used as the file prefix below.
pretrained_model_name = Path(trained_model_path).stem
# The ONNX export is split into three sessions: encoder, decoder, and the
# decoder variant used for the first generation step (no past key/values).
encoder_path = os.path.join(
    trained_model_path, f"{pretrained_model_name}-encoder_quantized.onnx")
decoder_path = os.path.join(
    trained_model_path, f"{pretrained_model_name}-decoder_quantized.onnx")
init_decoder_path = os.path.join(
    trained_model_path, f"{pretrained_model_name}-init-decoder_quantized.onnx")
# NOTE: get_onnx_runtime_sessions expects the paths in this exact order.
model_paths = encoder_path, decoder_path, init_decoder_path
model_sessions = get_onnx_runtime_sessions(model_paths)
# OnnxT5 wraps the three sessions behind a HuggingFace-style generate() API.
model = OnnxT5(trained_model_path, model_sessions)
tokenizer = AutoTokenizer.from_pretrained(trained_model_path)
def get_question(sentence, answer, mdl, tknizer):
    """Generate a question whose answer is ``answer`` within ``sentence``.

    Args:
        sentence: Context passage the question is drawn from.
        answer: Target answer span the generated question must lead to.
        mdl: Seq2seq model exposing a HuggingFace-style ``generate()``.
        tknizer: Tokenizer matching ``mdl``.

    Returns:
        The generated question as a stripped string.
    """
    text = f"context: {sentence} answer: {answer}"
    print(text)
    max_len = 256
    # `pad_to_max_length` is deprecated/removed in recent transformers;
    # `padding=False` is the modern equivalent (no padding, truncate only).
    encoding = tknizer.encode_plus(
        text, max_length=max_len, padding=False, truncation=True,
        return_tensors="pt")
    input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
    outs = mdl.generate(input_ids=input_ids,
                        attention_mask=attention_mask,
                        early_stopping=True,
                        num_beams=5,
                        num_return_sequences=1,
                        no_repeat_ngram_size=2,
                        max_length=300)
    dec = [tknizer.decode(ids, skip_special_tokens=True) for ids in outs]
    question = dec[0]
    # Strip only the leading "question:" prefix; the previous
    # replace("question:", "") also deleted the token if it appeared
    # inside the generated question text itself.
    if question.startswith("question:"):
        question = question[len("question:"):]
    return question.strip()
def generate_question(context, answer):
    """Generate a question for ``answer`` in ``context`` and log latency.

    Thin wrapper around ``get_question`` using the module-level ONNX
    ``model`` and ``tokenizer``.
    """
    # perf_counter is a monotonic clock, immune to wall-clock adjustments,
    # so it is the correct tool for measuring elapsed time (time.time is not).
    start_time = time.perf_counter()
    result = get_question(context, answer, model, tokenizer)
    latency = time.perf_counter() - start_time
    print(f"Latency: {latency} seconds")
    return result
def generate_mcq(context):
    """Build multiple-choice questions from ``context``.

    Extracts up to 10 noun/proper-noun/adjective keyphrases with TopicRank,
    finds two distractors per keyphrase via sense2vec, and generates a
    question for each with the T5 model. Keyphrases with no sense2vec
    entry, or with fewer than two distractors, are skipped.

    Returns:
        list[dict]: entries with keys ``Question``, ``Keyword``,
        ``Distractor1`` and ``Distractor2``.
    """
    extractor = pke.unsupervised.TopicRank()
    extractor.load_document(input=context, language='en')
    extractor.candidate_selection(pos={"NOUN", "PROPN", "ADJ"})
    extractor.candidate_weighting()
    keyphrases = extractor.get_n_best(n=10)

    results = []
    for keyword, _ in keyphrases:
        original_keyword = keyword
        # sense2vec keys are lowercase with underscores, e.g. "new_york".
        keyword = original_keyword.lower().replace(" ", "_")
        sense = s2v.get_best_sense(keyword)
        if sense is None:
            continue  # no vector entry for this phrase -> no distractors possible
        most_similar = s2v.most_similar(sense, n=2)
        distractors = [word.split("|")[0].lower().replace("_", " ")
                       for word, _ in most_similar]
        # Guard: most_similar can return fewer than 2 neighbours, which
        # previously raised IndexError on distractors[1].
        if len(distractors) < 2:
            continue
        question = generate_question(context, original_keyword)
        results.append({
            "Question": question,
            "Keyword": original_keyword,
            "Distractor1": distractors[0],
            "Distractor2": distractors[1],
        })
    return results
# Gradio UI: a single textbox in, the list of MCQ dicts rendered as JSON out.
iface = gr.Interface(
    fn=generate_mcq,
    inputs=gr.Textbox(label="Context", type='text'),
    # Fix: the original passed `value=list` — the builtin `list` *class* —
    # as the component's initial value, which is not valid JSON content.
    outputs=gr.JSON(),
    title="Questgen AI",
    description="Enter a context to generate MCQs for keywords."
)
iface.launch()