Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,7 @@ import torch
|
|
5 |
# Load DialoGPT model and tokenizer
|
6 |
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
7 |
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
|
|
|
8 |
# Function to generate a response
|
9 |
def generate_response(chat_history_ids, new_user_input):
|
10 |
new_user_input_ids = tokenizer.encode(new_user_input + tokenizer.eos_token, return_tensors='pt')
|
@@ -14,8 +15,9 @@ def generate_response(chat_history_ids, new_user_input):
|
|
14 |
# Streamlit app
|
15 |
st.title("DialoGPT Chat")
|
16 |
|
17 |
-
# Initialize chat history
|
18 |
-
chat_history_ids
|
|
|
19 |
|
20 |
# User input text box
|
21 |
user_input = st.text_input("You:", "")
|
@@ -23,11 +25,11 @@ user_input = st.text_input("You:", "")
|
|
23 |
# Check if the user pressed Enter
|
24 |
if st.button("Send"):
|
25 |
if user_input:
|
26 |
-
# Generate response
|
27 |
-
chat_history_ids = generate_response(chat_history_ids, user_input)
|
28 |
|
29 |
# Display DialoGPT response
|
30 |
-
st.text("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, -1][0], skip_special_tokens=True)))
|
31 |
|
32 |
# Inform the user that the conversation has ended
|
33 |
st.text("Press 'Send' to continue the conversation.")
|
|
|
5 |
# Load DialoGPT model and tokenizer
|
6 |
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
7 |
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
|
8 |
+
|
9 |
# Function to generate a response
|
10 |
def generate_response(chat_history_ids, new_user_input):
|
11 |
new_user_input_ids = tokenizer.encode(new_user_input + tokenizer.eos_token, return_tensors='pt')
|
|
|
15 |
# Streamlit app
|
16 |
st.title("DialoGPT Chat")
|
17 |
|
18 |
+
# Initialize chat history using st.session_state
|
19 |
+
if 'chat_history_ids' not in st.session_state:
|
20 |
+
st.session_state.chat_history_ids = None
|
21 |
|
22 |
# User input text box
|
23 |
user_input = st.text_input("You:", "")
|
|
|
25 |
# Check if the user pressed Enter
|
26 |
if st.button("Send"):
|
27 |
if user_input:
|
28 |
+
# Generate response using st.session_state
|
29 |
+
st.session_state.chat_history_ids = generate_response(st.session_state.chat_history_ids, user_input)
|
30 |
|
31 |
# Display DialoGPT response
|
32 |
+
st.text("DialoGPT: {}".format(tokenizer.decode(st.session_state.chat_history_ids[:, -1][0], skip_special_tokens=True)))
|
33 |
|
34 |
# Inform the user that the conversation has ended
|
35 |
st.text("Press 'Send' to continue the conversation.")
|