from threading import Thread
import gradio as gr
import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    TextIteratorStreamer,
)


def chat_history(history, tokenizer) -> str:
    """Flatten Gradio's [[user, assistant], ...] history into a chat prompt."""
    messages = []

    for dialog in history:
        for i, message in enumerate(dialog):
            role = "user" if i % 2 == 0 else "assistant"
            messages.append({"role": role, "content": message})

    # Drop the trailing None placeholder for the reply being generated.
    messages.pop(-1)

    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
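
# For reference: a hypothetical two-turn history such as
#   [["Hello", "Hi!"], ["How are you?", None]]
# becomes four messages; the trailing None is popped, and the remaining three
# turns are rendered with the model's chat template plus a generation prompt.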


def model_loading_pipeline():
    """Load the model and a streamer that yields tokens as they are generated."""
    model_id = "vilm/vinallama-2.7b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # skip_prompt avoids echoing the input back; timeout (seconds) bounds how
    # long iteration over the streamer waits for each new token.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=5
    )
    pipe = pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={
            "torch_dtype": torch.bfloat16,
        },
        streamer=streamer,
    )
    return pipe, streamer
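
# A quick smoke test (assumes the ~2.7B model fits in memory; no device is
# requested above, so this runs on the CPU by default):
#
#   pipe, streamer = model_loading_pipeline()
#   print(pipe("Xin chào", max_new_tokens=8)[0]["generated_text"])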


def launch_app(pipe, streamer):
    with gr.Blocks() as demo:
        chat = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
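            # Clear the textbox and append the turn with a None reply slot.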
            return "", history + [[user_message, None]]

        def bot(history):
            # Render the prompt from the full history, then replace the None
            # placeholder with an empty string to accumulate streamed tokens.
            prompt = chat_history(history, pipe.tokenizer)
            history[-1][1] = ""
            kwargs = {
                "text_inputs": prompt,
                "max_new_tokens": 64,
                "do_sample": True,
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.95,
            }
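            # Generate in a background thread so this generator can consume
            # the streamer and yield partial replies to the UI.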
            thread = Thread(target=pipe, kwargs=kwargs)
            thread.start()

            for token in streamer:
                history[-1][1] += token
                yield history

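        # On submit: user() appends the turn synchronously, then bot() streams.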
        msg.submit(user, [msg, chat], [msg, chat], queue=False).then(bot, chat, chat)
        clear.click(lambda: None, None, chat, queue=False)

    demo.queue()
    demo.launch(share=True, debug=True)


if __name__ == "__main__":
    pipe, streamer = model_loading_pipeline()
    launch_app(pipe, streamer)
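
# Running this file loads the model and launches the Gradio UI; share=True
# additionally prints a temporary public gradio.live link next to the local one.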