nugentc commited on
Commit
3e149d5
1 Parent(s): db0e2dc

try wiring up feedback element

Browse files
Files changed (1) hide show
  1. app.py +35 -4
app.py CHANGED
@@ -1,10 +1,14 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
2
  import torch
3
  import gradio as gr
4
 
5
 
6
  def chat(message, history):
7
- history = history or [('hi', 'hello'), ('what ya doing', 'nothing')]
8
  if message.startswith("How many"):
9
  response = random.randint(1, 10)
10
  elif message.startswith("How"):
@@ -14,13 +18,40 @@ def chat(message, history):
14
  else:
15
  response = "I don't know"
16
  history.append((message, response))
17
- return history, history
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  iface = gr.Interface(
20
  chat,
21
  ["text", "state"],
22
- ["chatbot", "state"],
23
  allow_screenshot=False,
24
  allow_flagging="never",
25
  )
26
  iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
2
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
3
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
4
+ grammar_tokenizer = AutoTokenizer.from_pretrained("prithivida/grammar_error_correcter_v1")
5
+ grammar_model = AutoModelForSeq2SeqLM.from_pretrained("prithivida/grammar_error_correcter_v1")
6
  import torch
7
  import gradio as gr
8
 
9
 
10
  def chat(message, history):
11
+ history = history or []
12
  if message.startswith("How many"):
13
  response = random.randint(1, 10)
14
  elif message.startswith("How"):
 
18
  else:
19
  response = "I don't know"
20
  history.append((message, response))
21
+ return history, feedback(message)
22
+
23
+
24
+ def feedback(text):
25
+ tokenized_phrases = grammar_tokenizer([text], return_tensors='pt', padding=True)
26
+ corrections = grammar_model.generate(**tokenized_phrases)
27
+ corrections = grammar_tokenizer.batch_decode(corrections, skip_special_tokens=True)
28
+ print("The corrections are: ", corrections)
29
+ if corrections[0] == text:
30
+ feedback = f'Looks good! Keep up the good work'
31
+ else:
32
+ feedback = f'\'{corrections[0]}\' might be a little better'
33
+ return f'FEEDBACK: {feedback}'
34
 
35
  iface = gr.Interface(
36
  chat,
37
  ["text", "state"],
38
+ ["chatbot", "text"],
39
  allow_screenshot=False,
40
  allow_flagging="never",
41
  )
42
  iface.launch()
43
+
44
+
45
+
46
+ new_user_input_ids = tokenizer.encode(text+tokenizer.eos_token, return_tensors='pt')
47
+ # append the new user input tokens to the chat history
48
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids
49
+
50
+ # generated a response while limiting the total chat history to 1000 tokens,
51
+ chat_history_ids = model.generate(bot_input_ids, max_length=5000, pad_token_id=tokenizer.eos_token_id)
52
+ print("The text is ", [text])
53
+
54
+ # pretty print last ouput tokens from bot
55
+ output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
56
+ print("The outout is :", output)
57
+ text_session.append(output)