"""Solara chat page that streams replies from Qwen2-0.5B-Instruct into the UI."""

import re
from typing import List

import solara
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing_extensions import TypedDict

MODEL_ID = "Qwen/Qwen2-0.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Marker removed from message content before it is rendered. Raw string so the
# backslashes reach the regex engine instead of being (invalid) string escapes.
_IM_END_RE = re.compile(r"<\|im_end\|>")


class MessageDict(TypedDict):
    """One chat message: who said it (`role`) and what was said (`content`)."""

    role: str
    content: str


# Chat history shared between the streamer (writer) and the Page component (reader).
messages: solara.Reactive[List[MessageDict]] = solara.reactive([])


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()


class TextStreamer(BaseStreamer):
    """
    Simple text streamer that forwards decoded text to `on_finalized_text` as soon as
    entire words are formed. In this file `on_finalized_text` appends the text to the
    last entry of the module-level `messages` reactive instead of printing it.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> streamer = TextStreamer(tokenizer, skip_prompt=True)
        >>> inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and emits them as soon as they form entire words.

        Raises:
            ValueError: if `value` has more than one row (batch size > 1).
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # The first call after construction / `end()` is treated as the prompt.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we emit the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, emit until the last space char (simple heuristic to avoid emitting incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and resets the streamer for the next generation."""
        # Flush the cache, if it exists
        if self.token_cache:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """
        Appends `text` to the content of the last entry in `messages`.

        The last entry is replaced by a new dict with role "assistant" (rather than
        mutated) and a new list is assigned to `messages.value`. `stream_end` is
        accepted for interface compatibility and is not used.
        """
        if not text:
            # Nothing to append; avoid reassigning an identical list.
            return
        messages.value = [
            *messages.value[:-1],
            {
                "role": "assistant",
                "content": messages.value[-1]["content"] + text,
            },
        ]

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        return (
            0x4E00 <= cp <= 0x9FFF
            or 0x3400 <= cp <= 0x4DBF
            or 0x20000 <= cp <= 0x2A6DF
            or 0x2A700 <= cp <= 0x2B73F
            or 0x2B740 <= cp <= 0x2B81F
            or 0x2B820 <= cp <= 0x2CEAF
            or 0xF900 <= cp <= 0xFAFF
            or 0x2F800 <= cp <= 0x2FA1F
        )


streamer = TextStreamer(tokenizer, skip_prompt=True)


@solara.component
def Page():
    """Chat page: shows the message history and an input box, and generates a reply per user message."""
    solara.lab.theme.themes.light.primary = "#0000ff"
    solara.lab.theme.themes.light.secondary = "#0000ff"
    solara.lab.theme.themes.dark.primary = "#0000ff"
    solara.lab.theme.themes.dark.secondary = "#0000ff"
    title = "Qwen2-0.5B"
    with solara.Head():
        solara.Title(title)
    with solara.Column(align="center"):
        # Used as the task dependency so a generation starts once per new user message.
        user_message_count = sum(1 for m in messages.value if m["role"] == "user")

        def send(message):
            """Append the user's message to the history."""
            messages.value = [*messages.value, {"role": "user", "content": message}]

        def response(message):
            """Add an empty assistant message and let the streamer fill it while generating."""
            messages.value = [*messages.value, {"role": "assistant", "content": ""}]
            text = tokenizer.apply_chat_template(
                [{"role": "user", "content": message}],
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = tokenizer(text, return_tensors="pt")
            _ = model.generate(**inputs, streamer=streamer, max_new_tokens=512)

        def generate_reply():
            """Reply to the latest message, but only when it actually came from the user."""
            history = messages.value
            if history and history[-1]["role"] == "user":
                response(history[-1]["content"])

        solara.lab.use_task(generate_reply, dependencies=[user_message_count])

        with solara.lab.ChatBox(
            style={
                "position": "fixed",
                "overflow-y": "scroll",
                "scrollbar-width": "none",
                "-ms-overflow-style": "none",
                "top": "0",
                "bottom": "10rem",
                "width": "70%",
            }
        ):
            for item in messages.value:
                is_user = item["role"] == "user"
                with solara.lab.ChatMessage(
                    user=is_user,
                    name="User" if is_user else "Qwen2-0.5B-Instruct",
                    avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                    border_radius="20px",
                    style="background-color:lightgrey!important;",
                ):
                    # Strip the marker for display only; do not mutate the stored message.
                    solara.Markdown(_IM_END_RE.sub("", item["content"]))
        solara.lab.ChatInput(
            send_callback=send,
            style={"position": "fixed", "bottom": "3rem", "width": "70%"},
        )