File size: 1,876 Bytes
cc28aad
b73e8d0
e7899d4
b73e8d0
c4fe73d
 
 
 
20cda87
 
c4fe73d
b73e8d0
cc28aad
20cda87
b73e8d0
cc28aad
c4fe73d
20cda87
b73e8d0
c4fe73d
cc28aad
c4fe73d
b73e8d0
cc28aad
b73e8d0
cc28aad
 
 
c4fe73d
20cda87
cc28aad
20cda87
 
cc28aad
c4fe73d
b73e8d0
c4fe73d
cc28aad
 
 
b73e8d0
 
 
 
 
20cda87
cc28aad
 
 
 
b73e8d0
c4fe73d
cc28aad
 
20cda87
cc28aad
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
import torch
import gradio as gr

# Hugging Face checkpoint used for the chat. DialoGPT is loaded as a causal LM
# and converse() splits its decoded output on the tokenizer's EOS token, so a
# different model family (e.g. a seq2seq model such as Blenderbot) would likely
# need different loading and decoding code, not just a new name here.
model_name = "microsoft/DialoGPT-medium"
chat_token = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def converse(user_input, chat_history=None):
    """Generate a DialoGPT reply and render the whole conversation as HTML.

    Gradio calls this with the textbox contents and the state value returned
    by the previous call.

    Args:
        user_input: The new message typed by the user.
        chat_history: Token ids of the conversation so far, exactly as
            returned by the previous call (a nested list produced by
            ``Tensor.tolist()``), or ``None``/empty on the first turn.

    Returns:
        A ``(markup, chat_history)`` tuple. ``markup`` is an HTML string with
        one ``<div>`` per turn. ``chat_history`` is the updated token-id
        history to be stored as state for the next call.
    """
    import logging
    from html import escape

    log = logging.getLogger(__name__)

    # A missing state (None) means "no history yet"; using None as the
    # default also avoids a shared mutable default argument.
    if chat_history is None:
        chat_history = []

    user_input_ids = chat_token(
        user_input + chat_token.eos_token, return_tensors="pt"
    ).input_ids

    # Prepend the previous turns so the model sees the whole dialogue. An
    # empty history becomes a 1-D empty tensor, which torch.cat skips.
    bot_input_ids = torch.cat(
        [torch.LongTensor(chat_history), user_input_ids], dim=-1
    )

    # NOTE(review): the history is never truncated, so a long conversation
    # eventually reaches max_length and leaves no room for a reply.
    chat_history = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=chat_token.eos_token_id
    ).tolist()
    log.debug("token history: %s", chat_history)

    # Turns are separated by the EOS token. Split on the tokenizer's own EOS
    # string (the same one appended to the input above), not a literal.
    turns = chat_token.decode(chat_history[0]).split(chat_token.eos_token)
    # A reply that ends with EOS leaves an empty trailing segment. Drop only
    # that final one so it is not rendered as an empty user bubble, while
    # keeping the even/odd user/bot alternation intact.
    if turns and turns[-1] == "":
        turns.pop()
    log.debug("decoded turns: %s", turns)

    # Even indices are user turns; odd indices are bot turns.
    parts = ["<div class='mybot'>"]
    for index, mesg in enumerate(turns):
        if index % 2:
            mesg = "Bot: " + mesg
            clazz = "bot"
        else:
            clazz = "user"
        # Escape the text so user or model output cannot inject markup.
        parts.append(
            "<div class='mesg {}'> {}</div>".format(clazz, escape(mesg))
        )
    parts.append("</div>")
    markup = "".join(parts)
    log.debug("html: %s", markup)
    return markup, chat_history

# Styles for the markup produced by converse(): the wrapper div has class
# "mybot", and each turn is a "mesg" div additionally tagged "user" or "bot".
# Bot turns are pushed to the far side of the flex column via align-self.
# The .footer rule presumably hides Gradio's page footer (confirm against
# the Gradio version in use).
css = """
.mybot {display:flex;flex-direction:column}
.mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
.mesg.user {background-color:lightblue;color:white}
.mesg.bot {background-color:orange;color:white;align-self:self-end}
.footer {display:none !important}
"""
# Free-text input box. The "state" input/output pair carries converse()'s
# token-id history from one call to the next.
text = gr.inputs.Textbox(placeholder="Let's start a chat...")
demo = gr.Interface(
    fn=converse,
    theme="default",
    inputs=[text, "state"],
    outputs=["html", "state"],
    css=css,
)
demo.launch()