Aytaj commited on
Commit
9aa604e
1 Parent(s): 21ef1b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  # Load DialoGPT model and tokenizer
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
7
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
 
8
  # Function to generate a response
9
  def generate_response(chat_history_ids, new_user_input):
10
  new_user_input_ids = tokenizer.encode(new_user_input + tokenizer.eos_token, return_tensors='pt')
@@ -14,8 +15,9 @@ def generate_response(chat_history_ids, new_user_input):
14
  # Streamlit app
15
  st.title("DialoGPT Chat")
16
 
17
- # Initialize chat history
18
- chat_history_ids = None
 
19
 
20
  # User input text box
21
  user_input = st.text_input("You:", "")
@@ -23,11 +25,11 @@ user_input = st.text_input("You:", "")
23
  # Check if the user pressed Enter
24
  if st.button("Send"):
25
  if user_input:
26
- # Generate response
27
- chat_history_ids = generate_response(chat_history_ids, user_input)
28
 
29
  # Display DialoGPT response
30
- st.text("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, -1][0], skip_special_tokens=True)))
31
 
32
  # Inform the user that the conversation has ended
33
  st.text("Press 'Send' to continue the conversation.")
 
5
  # Load DialoGPT model and tokenizer
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
7
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
8
+
9
  # Function to generate a response
10
  def generate_response(chat_history_ids, new_user_input):
11
  new_user_input_ids = tokenizer.encode(new_user_input + tokenizer.eos_token, return_tensors='pt')
 
15
  # Streamlit app
16
  st.title("DialoGPT Chat")
17
 
18
+ # Initialize chat history using st.session_state
19
+ if 'chat_history_ids' not in st.session_state:
20
+ st.session_state.chat_history_ids = None
21
 
22
  # User input text box
23
  user_input = st.text_input("You:", "")
 
25
  # Check if the user pressed Enter
26
  if st.button("Send"):
27
  if user_input:
28
+ # Generate response using st.session_state
29
+ st.session_state.chat_history_ids = generate_response(st.session_state.chat_history_ids, user_input)
30
 
31
  # Display DialoGPT response
32
+ st.text("DialoGPT: {}".format(tokenizer.decode(st.session_state.chat_history_ids[:, -1][0], skip_special_tokens=True)))
33
 
34
  # Inform the user that the conversation has ended
35
  st.text("Press 'Send' to continue the conversation.")