Spaces:
Sleeping
Sleeping
File size: 1,345 Bytes
a9046df 168ff57 9216764 168ff57 b2c1dfc 168ff57 9216764 168ff57 9216764 168ff57 9216764 168ff57 62b4437 9216764 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load DialoGPT model and tokenizer.
# Streamlit re-executes this whole script on every user interaction; caching the
# loader with st.cache_resource keeps the weights from being reloaded each time.
@st.cache_resource
def _load_dialogpt(model_name="microsoft/DialoGPT-medium"):
    """Load and return ``(tokenizer, model)`` for a Hugging Face causal-LM checkpoint."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_dialogpt()
# Function to generate a response
def generate_response(chat_history_ids, new_user_input, max_length=1000):
    """Append a user turn to the conversation and let the model reply.

    Args:
        chat_history_ids: token-id tensor returned by a previous call, or
            ``None`` to start a new conversation.
        new_user_input: the user's message as plain text.
        max_length: cap on the total sequence length (history + new user
            turn + reply) passed to ``model.generate``.

    Returns:
        The tensor produced by ``model.generate``: the prompt (previous
        history plus the new user turn) followed by the generated reply.
        Pass it back in as ``chat_history_ids`` on the next turn.
    """
    # The user turn is terminated with the EOS token so the model sees where it ends.
    new_user_input_ids = tokenizer.encode(
        new_user_input + tokenizer.eos_token, return_tensors="pt"
    )
    if chat_history_ids is None:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    # A single unpadded sequence: every position is real, so the mask is all
    # ones. Passing it explicitly stops transformers from having to infer the
    # mask when pad_token_id equals eos_token_id (which it warns about).
    attention_mask = torch.ones_like(bot_input_ids)
    return model.generate(
        bot_input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
# Streamlit app
st.title("DialoGPT Chat")
# Streamlit re-runs this whole script from the top on every interaction, so
# anything that must outlive one run (the token history and the visible
# transcript) is kept in st.session_state rather than in plain variables.
if "chat_history_ids" not in st.session_state:
    st.session_state.chat_history_ids = None
if "transcript" not in st.session_state:
    st.session_state.transcript = []
# Restart: forget both the model-side history and the displayed transcript.
if st.button("Restart Conversation"):
    st.session_state.chat_history_ids = None
    st.session_state.transcript = []
# A form with clear_on_submit sends each message exactly once; a bare
# text_input keeps its value and would resend it on every later rerun.
with st.form("chat_form", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submitted = st.form_submit_button("Send")
if submitted and user_input:
    history = st.session_state.chat_history_ids
    # Length of the prompt that generate_response builds (old history + the
    # new user turn, encoded exactly as generate_response encodes it);
    # everything generated after this offset is the bot's reply.
    prompt_len = (0 if history is None else history.shape[-1]) + len(
        tokenizer.encode(user_input + tokenizer.eos_token)
    )
    history = generate_response(history, user_input)
    reply = tokenizer.decode(history[0, prompt_len:], skip_special_tokens=True)
    st.session_state.chat_history_ids = history
    st.session_state.transcript.append((user_input, reply))
# Display the conversation so far, oldest turn first.
for user_msg, bot_msg in st.session_state.transcript:
    st.text("You: {}".format(user_msg))
    st.text("DialoGPT: {}".format(bot_msg))
st.caption("Press the 'Restart Conversation' button to start a new conversation.")
|