CabraLlama3 / app.py
nicolasdec's picture
Update app.py
4482a0d verified
raw
history blame contribute delete
No virus
4.81 kB
import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
from transformers import GemmaTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# HTML header shown above the chat window.
DESCRIPTION = """
<div>
<h1 style="text-align: center;">🤖 BotBot Cabra 🐐 Llama 3 8b</h1>
<p>Conversa com o modelo <a href="https://huggingface.co/botbot-ai/CabraLlama3-8b"><b>BotBot Cabra Llama3 8b</b></a>.</p>
<p>🐐 Conheça os nossos outros <a href="https://huggingface.co/collections/botbot-ai/models-6604c2069ceef04f834ba99b3">modelos Cabra</a>.</p>
<p></p>
</div>
"""

# Disclaimer rendered as Markdown below the chat window.
LICENSE = """
<p/>
---
Esse modelo pode gerar inverdades, mentiras ou ofensas. Somente para teste e validação de modelos de linguagem. Proibido para uso comercial.
"""

# HTML shown inside the chatbot before the first message is sent.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://uploads-ssl.webflow.com/65f77c0240ae1c68f8192771/66299ba8957d9bb8fb5d1d12_image.png" style="width: 70%; max-width: 550px; height: auto; opacity: 0.6; ">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">BotBot Cabra</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Faça uma pergunta...</p>
</div>
"""

# Custom CSS passed to gr.Blocks: centred headings and a styled
# element with id "duplicate-button".
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Load the tokenizer and model once at import time; every chat request reuses them.
# HF_TOKEN (None when unset) is forwarded so gated/private repositories can be
# downloaded. torch_dtype="auto" keeps the dtype the checkpoint was saved in;
# without it transformers loads the weights as float32.
MODEL_ID = "botbot-ai/CabraLlama3-8b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    token=HF_TOKEN,
)
# Token ids that end generation: the tokenizer's EOS token and the
# "<|eot_id|>" end-of-turn marker.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
def _history_to_messages(history: list) -> list:
    """
    Convert gr.ChatInterface history into chat-template messages.

    Accepts both history shapes Gradio can supply: ``[user, assistant]`` pairs
    (tuples format) and ``{"role": ..., "content": ...}`` dicts (messages
    format). ``None`` entries inside a pair, e.g. an assistant reply that was
    never produced, are skipped instead of being passed to the chat template.

    Args:
        history (list): The conversation history from ChatInterface; may be
            None or empty.
    Returns:
        list: A new list of ``{"role": ..., "content": ...}`` dicts.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Copy only role/content; Gradio message dicts can carry extra
            # keys (e.g. metadata) that the chat template does not need.
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user, assistant = turn
        if user is not None:
            messages.append({"role": "user", "content": user})
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})
    return messages


@spaces.GPU(duration=120)
def chat_llama3_8b(message: str,
                   history: list,
                   temperature: float,
                   max_new_tokens: int
                   ) -> Iterator[str]:
    """
    Stream a response from the CabraLlama3-8b model for one chat turn.

    Args:
        message (str): The new user message.
        history (list): The prior conversation as supplied by ChatInterface,
            either as ``[user, assistant]`` pairs or as role/content dicts.
        temperature (float): Sampling temperature; 0 selects greedy decoding.
        max_new_tokens (int): The maximum number of new tokens to generate.
    Yields:
        str: The complete response text generated so far; each yielded value
            extends the previous one.
    """
    conversation = _history_to_messages(history)
    conversation.append({"role": "user", "content": message})
    # add_generation_prompt=True ends the prompt with the assistant turn header.
    # Without it the model must emit that header itself, and with a
    # Llama-3-style template the (non-special) role name can leak into the reply.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    # generate() runs in a background thread and pushes decoded text into the
    # streamer; this generator consumes it as it arrives.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    # Sampling with temperature 0 crashes, so enforce greedy decoding
    # (do_sample=False) when the slider is at 0.
    if temperature == 0:
        generate_kwargs["do_sample"] = False
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    # Yield the cumulative text so the chat UI always shows the whole reply so far.
    chunks = []
    for chunk in streamer:
        chunks.append(chunk)
        yield "".join(chunks)
    # The streamer is exhausted once generate() finishes; reap the worker thread.
    thread.join()
# Gradio UI: a chat window plus an accordion with the sampling parameters.
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='BotBot Cabra Llama 3')
with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        # render=False defers rendering so ChatInterface can place the
        # accordion and the sliders itself.
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parâmetros", open=False, render=False),
        # Passed to chat_llama3_8b after (message, history), in this order:
        # temperature, then max_new_tokens.
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1,
                      step=0.1,
                      value=0.6,
                      label="Temperatura",
                      render=False),
            gr.Slider(minimum=128,
                      maximum=4096,
                      step=1,
                      value=512,
                      label="Max novos tokens",
                      render=False),
        ],
        examples=[
            ['Como criar uma base humana em Marte, em 5 passos?'],
            ['Who is Elon Musk?'],
            ['Quem desenhou e criou Brasília?'],
            ['Traduz o seguinte texto: "The quick brown fox jumps over the lazy dog."'],
            ['Me conta um pouco sobre o rio Amazonas'],
        ],
        # Examples are not pre-run at startup; they only fill the input box.
        cache_examples=False,
    )
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.launch()