esgClassifier / app.py
dammy's picture
Update app.py
6296560
raw
history blame
No virus
971 Bytes
import gradio as gr
import transformers
# Load a pre-trained model together with its matching tokenizer.
# The model object returned by ``AutoModelForSeq2SeqLM`` does not carry a
# tokenizer, so the tokenizer has to be loaded separately from the same
# checkpoint name.
MODEL_NAME = "facebook/bart-large"
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)


# Define a function to generate text.
def generate_text(text):
    """Generate text based on a given prompt.

    Args:
        text: Prompt string typed by the user. ``None`` or a blank string
            is accepted and produces an empty response.

    Returns:
        The decoded model output as a plain string, with special tokens
        removed.
    """
    # Nothing to generate from; skip the model call entirely.
    if text is None or not str(text).strip():
        return ""

    # Tokenize the input text. Calling the tokenizer (rather than
    # ``encode``) also yields the attention mask that ``generate`` expects;
    # ``truncation=True`` keeps over-long prompts within the tokenizer's
    # configured maximum length.
    encoded = tokenizer(str(text), return_tensors="pt", truncation=True)

    # Generate text with beam search, capped at 100 output tokens.
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=100,
        num_beams=5,
    )

    # Decode the best sequence, dropping markers such as <s> and </s> so
    # they do not show up in the chat response.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output_text
# Define the Gradio interface.
# Components are created inside a ``gr.Blocks`` context because event
# listeners such as ``Button.click`` can only be attached there. The
# top-level ``gr.Textbox`` replaces the deprecated ``gr.inputs`` /
# ``gr.outputs`` namespaces.
with gr.Blocks() as interface:
    chat_box = gr.Textbox(label="Chat Box")
    chat_button = gr.Button("Send")
    chat_response = gr.Textbox(label="Chat Response")

    # Connect the inputs and outputs to the generate_text function:
    # pressing "Send" passes the chat box contents to generate_text and
    # shows the returned string in the response box.
    chat_button.click(generate_text, inputs=chat_box, outputs=chat_response)

# Launch the Gradio interface.
interface.launch()