Jaehan committed on
Commit
c4fe73d
1 Parent(s): 20cda87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -2,40 +2,39 @@ from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForCondit
2
  import torch
3
  import gradio as gr
4
 
5
- chat_tkn = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
6
- mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
7
-
8
- #chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
9
- #mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
10
 
11
  def converse(user_input, chat_history=[]):
12
- user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
13
 
14
  # keep history in the tensor
15
  bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
16
 
17
  # get response
18
- chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
19
  print (chat_history)
20
 
21
- response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
22
 
23
- print("starting to print response")
24
  print(response)
25
 
26
  # html for display
27
  html = "<div class='mybot'>"
28
  for x, mesg in enumerate(response):
29
  if x%2!=0 :
30
- mesg="Bot:"+mesg
31
  clazz="bot"
32
  else :
33
  clazz="user"
34
 
35
 
36
- print("value of x")
37
  print(x)
38
- print("message")
39
  print (mesg)
40
 
41
  html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
@@ -50,7 +49,7 @@ css = """
50
  .mesg.bot {background-color:orange;color:white,align-self:self-end}
51
  .footer {display:none !important}
52
  """
53
- text=gr.inputs.Textbox(placeholder="Lets chat")
54
  gr.Interface(fn=converse,
55
  theme="default",
56
  inputs=[text, "state"],
 
2
  import torch
3
  import gradio as gr
4
 
5
+ #model_name = "facebook/blenderbot-400M-distill"
6
+ model_name = "microsoft/DialoGPT-medium"
7
+ chat_token = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForCausalLM.from_pretrained(model_name)
 
9
 
10
  def converse(user_input, chat_history=[]):
11
+ user_input_ids = chat_token(user_input + chat_token.eos_token, return_tensors='pt').input_ids
12
 
13
  # keep history in the tensor
14
  bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
15
 
16
  # get response
17
+ chat_history = model.generate(bot_input_ids, max_length=1000, pad_token_id=chat_token.eos_token_id).tolist()
18
  print (chat_history)
19
 
20
+ response = chat_token.decode(chat_history[0]).split("<|endoftext|>")
21
 
22
+ print("Starting to print response...")
23
  print(response)
24
 
25
  # html for display
26
  html = "<div class='mybot'>"
27
  for x, mesg in enumerate(response):
28
  if x%2!=0 :
29
+ mesg="Bot: " + mesg
30
  clazz="bot"
31
  else :
32
  clazz="user"
33
 
34
 
35
+ print("Value of x: ")
36
  print(x)
37
+ print("Message: ")
38
  print (mesg)
39
 
40
  html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
 
49
  .mesg.bot {background-color:orange;color:white,align-self:self-end}
50
  .footer {display:none !important}
51
  """
52
+ text=gr.inputs.Textbox(placeholder="Let's start a chat...")
53
  gr.Interface(fn=converse,
54
  theme="default",
55
  inputs=[text, "state"],