# text_generation / app.py
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
model_name = "gpt2-large"
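# gpt2-large is a ~774M-parameter checkpoint; the first run downloads the weights,
# so initial startup can take a while.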
# Cache the tokenizer/model so they are loaded only once per session.
@st.cache(allow_output_mutation=True)
def load_pipeline(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
    return pipeline('text-generation', model=model, tokenizer=tokenizer)

# Use a distinct name so we don't shadow the imported transformers.pipeline function.
generator = load_pipeline(model_name)
default_value = "But not just any roof cleaning will do."
#prompts
st.title("Text Extension or Generation")
st.write("Placeholder for some other texts, like instructions...")
sent = st.text_area("Text", default_value, height=250)
max_length = st.sidebar.slider("Max Length", value=50, min_value=30, max_value=150)
# Temperature must stay strictly positive when sampling is enabled.
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05)
num_return_sequences = st.sidebar.slider("Num Return Sequences", min_value=1, max_value=4, value=1)
num_beams = st.sidebar.slider("Num Beams", min_value=4, max_value=6, value=4)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=100, value=90)
top_p = st.sidebar.slider("Top-p", min_value=0.4, max_value=1.0, step=0.05, value=0.9)
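# These sidebar values map directly onto Hugging Face generate() arguments:
# temperature rescales the next-token distribution, top-k/top-p truncate it,
# and num_beams sets how many beam-search hypotheses are tracked.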
if len(sent)<40:
print ("Input prompt is too small to generate")
def infer(input_text, **generator_args):
    # Thin wrapper around the cached text-generation pipeline.
    output_sequences = generator(input_text, **generator_args)
    return output_sequences
output_sequences = infer(
    sent,
    max_length=max_length,
    num_return_sequences=num_return_sequences,
    num_beams=num_beams,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    do_sample=True,  # without sampling, temperature/top_k/top_p are ignored
    early_stopping=False,
)
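# The text-generation pipeline returns a list of dicts, one per returned sequence,
# each with a 'generated_text' key holding the prompt plus its continuation.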
generated_sequences = []
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
    print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
    generated_text = generated_sequence["generated_text"]
    generated_sequences.append(generated_text.replace('\n', ' '))
st.write('\n'.join(generated_sequences))
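# Launch locally with: streamlit run app.py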