File size: 2,052 Bytes
4eb87d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
import torch
import random

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelWithLMHead
from sentence_splitter import SentenceSplitter, split_text_into_sentences
# Shared English sentence splitter, constructed once when the module is imported.
splitter = SentenceSplitter(language="en")

# Device string the models and tokenized batches are moved to:
# the first CUDA GPU when one is available, otherwise the CPU.
torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Pegasus paraphrase checkpoint; its tokenizer/model pair is used by get_answer().
_PARAPHRASE_CHECKPOINT = "tuner007/pegasus_paraphrase"
ptokenizer = AutoTokenizer.from_pretrained(_PARAPHRASE_CHECKPOINT)
pmodel = AutoModelForSeq2SeqLM.from_pretrained(_PARAPHRASE_CHECKPOINT).to(torch_device)

def get_answer(input_text, num_return_sequences, num_beams, max_length=60):
  """Return beam-search paraphrases of ``input_text`` from the Pegasus model.

  Args:
    input_text: text to paraphrase; it is wrapped in a one-element batch.
    num_return_sequences: number of candidate paraphrases to return. With
      beam search this cannot exceed ``num_beams``.
    num_beams: beam width used by ``generate``.
    max_length: token limit for both the tokenized input and each generated
      paraphrase. Longer input is truncated, so its tail is dropped. Defaults
      to 60, the value previously hard-coded.

  Returns:
    A list of ``num_return_sequences`` decoded strings with special tokens
    removed.
  """
  batch = ptokenizer([input_text], truncation=True, padding='longest',
                     max_length=max_length, return_tensors="pt").to(torch_device)
  # No ``temperature`` argument: it only affects sampling (do_sample=True).
  # The former temperature=1.5 was ignored by this pure beam search and can
  # trigger a warning in recent transformers versions.
  translated = pmodel.generate(**batch, max_length=max_length, num_beams=num_beams,
                               num_return_sequences=num_return_sequences)
  return ptokenizer.batch_decode(translated, skip_special_tokens=True)

# T5 question-generation checkpoint; its tokenizer/model pair is used by get_question().
# Loaded via AutoModelForSeq2SeqLM (T5 is an encoder-decoder model) instead of the
# deprecated AutoModelWithLMHead; both resolve a T5 checkpoint to the same model class.
_QUESTION_CHECKPOINT = "mrm8488/t5-base-finetuned-question-generation-ap"
qtokenizer = AutoTokenizer.from_pretrained(_QUESTION_CHECKPOINT)
qmodel = AutoModelForSeq2SeqLM.from_pretrained(_QUESTION_CHECKPOINT).to(torch_device)

def get_question(answer, context, max_length=64):
  """Generate a question for ``answer`` within ``context`` using the T5 model.

  Args:
    answer: the answer text the question should target.
    context: passage the answer comes from; it is interpolated into the
      prompt with string formatting.
    max_length: maximum length of the generated sequence.

  Returns:
    The first generated sequence decoded *without* skipping special tokens,
    so markers such as ``<pad>`` / ``</s>`` may remain. Callers strip them.
  """
  # The prompt ends with a literal "</s>". Presumably this mirrors the
  # checkpoint's expected input format; confirm against the model card.
  prompt = f"answer: {answer}  context: {context} </s>"
  encoded = qtokenizer([prompt], return_tensors='pt').to(torch_device)

  generated_ids = qmodel.generate(
      input_ids=encoded['input_ids'],
      attention_mask=encoded['attention_mask'],
      max_length=max_length,
  )
  first_sequence = generated_ids[0]
  return qtokenizer.decode(first_sequence)

def getqna(input):
  """Paraphrase ``input`` and generate a question answered by the paraphrase.

  Steps:
    1. Split the text into sentences.
    2. Paraphrase each sentence with ``get_answer`` and keep one of the 10
       beam candidates, chosen at random.
    3. Join the results and paraphrase the joined text once more.
    4. Ask ``get_question`` for a question whose answer is that final
       paraphrase, using the original text as context.

  Args:
    input: the text from the Gradio text box. The name shadows the builtin
      but is kept so keyword callers keep working.

  Returns:
    The question (with "<pad>" and "</s>" removed), followed by " ", a
    newline, " answer:" and the final paraphrase. Returns "" when the input
    is empty or whitespace only.
  """
  text = input
  if not text or not text.strip():
    # Nothing to paraphrase. Previously the empty split result (a list) was
    # handed to the tokenizer, which failed.
    return ""

  # Reuse the module-level splitter rather than letting
  # split_text_into_sentences build a new SentenceSplitter on every request.
  sentences = splitter.split(text=text) or [text]

  paraphrased = " ".join(random.choice(get_answer(s, 10, 10)) for s in sentences)
  # Second paraphrasing pass over the joined text.
  # NOTE(review): get_answer truncates its input at 60 tokens, so for long
  # texts only the beginning survives this pass. Confirm this is intended.
  answer = random.choice(get_answer(paraphrased, 10, 10))

  # Use the original string as context. The old code passed the list of
  # sentences, whose repr (brackets and quotes) leaked into the prompt.
  question = get_question(answer, text).replace("<pad>", "").replace("</s>", "")
  return "%s \n answer:%s" % (question, answer)

# Gradio UI: one text box in, one text box out; each submission runs getqna.
app = gr.Interface(
  fn=getqna,
  inputs="text",
  outputs="text",
)
app.launch()