Aytaj commited on
Commit
168ff57
1 Parent(s): b2c1dfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -1,21 +1,36 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
- pipe = pipeline('text-generation', model='Pclanglais/MonadGPT', device='cuda')
4
 
 
 
 
5
 
6
- def main():
7
- st.title("MonadGPT Streamlit App")
8
- st.markdown("This is a streamlit app that allows you to have a conversation abbout any topic")
 
 
9
 
10
- user_input = st.text_input("Your question:")
 
11
 
12
- if user_input:
13
- response = pipe(user_input, max_length=256, do_sample=True, top_k=50, top_p=0.95, early_stopping=True)
14
- st.write(response[0]['generated_text'])
15
 
16
- if st.button("Restart"):
17
- st.empty()
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- if __name__ == "__main__":
20
- main()
21
 
 
1
  import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
+ # Load DialoGPT model and tokenizer
6
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
7
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
8
 
9
+ # Function to generate a response
10
+ def generate_response(chat_history_ids, new_user_input):
11
+ new_user_input_ids = tokenizer.encode(new_user_input + tokenizer.eos_token, return_tensors='pt')
12
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids
13
+ return model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
14
 
15
+ # Streamlit app
16
+ st.title("DialoGPT Chat")
17
 
18
+ # Initialize chat history
19
+ chat_history_ids = None
 
20
 
21
+ # Conversation loop
22
+ while st.button("Restart Conversation"):
23
+ # Get user input
24
+ user_input = st.text_input("You:", "")
25
+
26
+ if user_input:
27
+ # Generate response
28
+ chat_history_ids = generate_response(chat_history_ids, user_input)
29
+
30
+ # Display DialoGPT response
31
+ st.text("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, -1][0], skip_special_tokens=True)))
32
+
33
+ # Inform the user that the conversation has ended
34
+ st.text("Conversation Ended. Press the 'Restart Conversation' button to start a new conversation.")
35
 
 
 
36