Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,35 +1,13 @@
|
|
1 |
-
|
2 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
-
import torch
|
4 |
-
|
5 |
-
# Load DialoGPT model and tokenizer
|
6 |
-
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
7 |
-
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
|
8 |
-
|
9 |
-
# Function to generate a response
|
10 |
-
def generate_response(chat_history_ids, new_user_input):
|
11 |
-
new_user_input_ids = tokenizer.encode(new_user_input + tokenizer.eos_token, return_tensors='pt')
|
12 |
-
bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids
|
13 |
-
return model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
# Initialize chat history using st.session_state
|
19 |
-
if 'chat_history_ids' not in st.session_state:
|
20 |
-
st.session_state.chat_history_ids = None
|
21 |
-
|
22 |
-
# User input text box
|
23 |
-
user_input = st.text_input("You:", "")
|
24 |
|
25 |
-
|
26 |
-
if st.button("Send"):
|
27 |
-
if user_input:
|
28 |
-
# Generate response using st.session_state
|
29 |
-
st.session_state.chat_history_ids = generate_response(st.session_state.chat_history_ids, user_input)
|
30 |
|
31 |
-
|
32 |
-
st.text("DialoGPT: {}".format(tokenizer.decode(st.session_state.chat_history_ids[:, -1][0], skip_special_tokens=True)))
|
33 |
|
34 |
-
|
35 |
-
|
|
|
|
|
|
1 |
+
# prompt: write a streamlit app that converts english text into french text when translate button is pressed. The title of page should be "Translate to French"."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
import streamlit as st
|
4 |
+
from transformers import pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
st.title("Translate to French")
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
input_text = st.text_input("Enter text to translate")
|
|
|
9 |
|
10 |
+
if st.button("Translate"):
|
11 |
+
pipe = pipeline(model="facebook/mbart-large-cc25")
|
12 |
+
translation = pipe(input_text, target_lang="fr")[0]["translation_text"]
|
13 |
+
st.write(translation)
|