File size: 4,812 Bytes
6bcba58
 
 
 
 
 
 
 
 
 
5103369
6bcba58
 
08d3dac
4482a0d
08d3dac
c3a1556
6bcba58
 
 
6b02e11
 
0653671
6b02e11
08d3dac
6b02e11
 
454b0bf
 
c3a1556
 
 
454b0bf
 
 
8861375
e17f0b6
 
 
 
 
 
 
 
 
 
 
 
 
 
6bcba58
923e263
 
2932ae3
 
 
 
6bcba58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2932ae3
6bcba58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08d3dac
6bcba58
e17f0b6
6bcba58
96ac3aa
6bcba58
 
 
 
c3a1556
6bcba58
 
 
 
c3a1556
 
6bcba58
 
 
 
 
c3a1556
6bcba58
 
 
c3a1556
 
 
2ad25d4
 
6bcba58
 
 
 
6b02e11
 
6bcba58
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
from transformers import GemmaTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Read the optional Hugging Face access token from the environment.
# The value is None when the HF_TOKEN variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")


# HTML header shown at the top of the page; passed to gr.Markdown inside the
# Blocks layout. User-facing text is in Portuguese.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">🤖 BotBot Cabra 🐐 Llama 3 8b</h1>
<p>Conversa com o modelo <a href="https://huggingface.co/botbot-ai/CabraLlama3-8b"><b>BotBot Cabra Llama3 8b</b></a>.</p>
<p>🐐 Conheça os nossos outros <a href="https://huggingface.co/collections/botbot-ai/models-6604c2069ceef04f834ba99b3">modelos Cabra</a>.</p>
<p></p>
</div>
'''

# Disclaimer / usage notice shown at the bottom of the page; passed to
# gr.Markdown after the chat interface.
LICENSE = """
<p/>

---
Esse modelo pode gerar inverdades, mentiras ou ofensas. Somente para teste e validação de modelos de linguagem. Proibido para uso comercial.
"""

# HTML passed as the `placeholder` of the gr.Chatbot component
# (logo, title and a "ask a question" prompt).
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <img src="https://uploads-ssl.webflow.com/65f77c0240ae1c68f8192771/66299ba8957d9bb8fb5d1d12_image.png" style="width: 70%; max-width: 550px; height: auto; opacity: 0.6;  "> 
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">BotBot Cabra</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Faça uma pergunta...</p>
</div>
"""


# Custom stylesheet passed to gr.Blocks(css=...).
# NOTE(review): no element with id "duplicate-button" is created in the
# visible code, so that rule looks unused here -- confirm before removing.
css = """
h1 {
  text-align: center;
  display: block;
}

#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""

# Load the tokenizer and the causal LM once, at import time, from the same
# Hub repository. Device placement is delegated to device_map="auto".
_MODEL_ID = "botbot-ai/CabraLlama3-8b"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID, device_map="auto")

# Token ids handed to `generate` as `eos_token_id`: the tokenizer's EOS id
# plus the id of the "<|eot_id|>" token.
_eot_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
terminators = [tokenizer.eos_token_id, _eot_token_id]

@spaces.GPU(duration=120)
def chat_llama3_8b(message: str,
                   history: list,
                   temperature: float,
                   max_new_tokens: int
                   ) -> Iterator[str]:
    """
    Stream a response to ``message`` from the module-level model.

    The previous turns plus the new user message are rendered with the
    tokenizer's chat template, generation runs on a background thread, and
    the reply accumulated so far is yielded every time the streamer
    produces a new piece of text.

    Args:
        message (str): The new user message.
        history (list): The conversation history used by ChatInterface,
            iterated here as ``(user, assistant)`` pairs.
        temperature (float): Sampling temperature. A value of 0 (or less)
            selects greedy decoding instead of sampling.
        max_new_tokens (int): The maximum number of new tokens to generate.
    Yields:
        str: The response generated so far (cumulative text, not a delta).
    """
    conversation = []
    for user, assistant in history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True makes the template end with the assistant
    # header so the model answers instead of continuing the user's turn.
    # return_dict=True also yields the attention mask, which is forwarded
    # to `generate` together with the input ids.
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # A temperature of 0 is not valid for sampling, so fall back to greedy
    # decoding and do not pass `temperature` at all in that case.
    do_sample = temperature > 0
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        eos_token_id=terminators,
    )
    if do_sample:
        generate_kwargs["temperature"] = temperature

    # `generate` blocks until completion, so it runs on a worker thread
    # while this generator consumes the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
        

# Gradio block
# The chat display component is created up front and passed to the
# ChatInterface below.
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='BotBot Cabra Llama 3')

# Page layout: description header, chat interface, then the disclaimer.
with gr.Blocks(fill_height=True, css=css) as demo:

    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parâmetros", open=False, render=False),
        # The sliders are listed in the order of the extra parameters of
        # chat_llama3_8b: temperature first, then max_new_tokens.
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1,
                      step=0.1,
                      value=0.6,
                      label="Temperatura",
                      render=False),
            gr.Slider(minimum=128,
                      maximum=4096,
                      step=1,
                      value=512,
                      label="Max novos tokens",
                      render=False),
        ],
        examples=[
            ['Como criar uma base humana em marte, em 5 passos?'],
            ['Who is Elon Musk?'],
            ['Quem desenhou e criou Brasilia?'],
            ['Traduz o seguinte texto: "The quick brown fox jumps over the lazy dog."'],
            ['Me conta um pouco sobre o rio amazonas'],
        ],
        cache_examples=False,
    )

    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.launch()