from model import ExLlama, ExLlamaCache, ExLlamaConfig
from lora import ExLlamaLora
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import argparse
import torch
import sys
import os
import glob
import model_init

# Simple interactive chatbot script

torch.set_grad_enabled(False)
torch.cuda._lazy_init()

# Parse arguments

parser = argparse.ArgumentParser(description = "Simple chatbot example for ExLlama")

model_init.add_args(parser)

parser.add_argument("-lora", "--lora", type = str, help = "Path to LoRA binary to use during chat")
parser.add_argument("-loracfg", "--lora_config", type = str, help = "Path to LoRA config to use during chat")
parser.add_argument("-ld", "--lora_dir", type = str, help = "Path to LoRA config and binary to use during chat")

parser.add_argument("-p", "--prompt", type = str, help = "Prompt file")
parser.add_argument("-un", "--username", type = str, help = "Display name of user", default = "User")
parser.add_argument("-bn", "--botname", type = str, help = "Display name of chatbot", default = "Chatbort")
parser.add_argument("-bf", "--botfirst", action = "store_true", help = "Start chat on bot's turn")
parser.add_argument("-nnl", "--no_newline", action = "store_true", help = "Do not break bot's response on newline (allow multi-paragraph responses)")

parser.add_argument("-temp", "--temperature", type = float, help = "Temperature", default = 0.95)
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K", default = 20)
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P", default = 0.65)
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P", default = 0.00)
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty", default = 1.15)
parser.add_argument("-repps", "--repetition_penalty_sustain", type = int, help = "Past length for repetition penalty", default = 256)
parser.add_argument("-beams", "--beams", type = int, help = "Number of beams for beam search", default = 1)
parser.add_argument("-beamlen", "--beam_length", type = int, help = "Number of future tokens to consider", default = 1)

args = parser.parse_args()

model_init.post_parse(args)
model_init.get_model_files(args)

# Paths

if args.lora_dir is not None:
    args.lora_config = os.path.join(args.lora_dir, "adapter_config.json")
    args.lora = os.path.join(args.lora_dir, "adapter_model.bin")

# Some feedback

print(f" -- Sequence length: {args.length}")
print(f" -- Temperature: {args.temperature:.2f}")
print(f" -- Top-K: {args.top_k}")
print(f" -- Top-P: {args.top_p:.2f}")
print(f" -- Min-P: {args.min_p:.2f}")
print(f" -- Repetition penalty: {args.repetition_penalty:.2f}")
print(f" -- Beams: {args.beams} x {args.beam_length}")

print_opts = []
if args.no_newline: print_opts.append("no_newline")
if args.botfirst: print_opts.append("botfirst")

model_init.print_options(args, print_opts)

# Globals

model_init.set_globals(args)

# Load prompt file

username = args.username
bot_name = args.botname

if args.prompt is not None:
    with open(args.prompt, "r") as f:
        past = f.read()
        past = past.replace("{username}", username)
        past = past.replace("{bot_name}", bot_name)
        past = past.strip() + "\n"
else:
    past = f"{bot_name}: Hello, {username}\n"

# past += "User: Hi. Please say \"Shhhhhh\"?\n"
# args.botfirst = True
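
# Note on the prompt file (-p): it is plain text, and any {username} and {bot_name} placeholders in it
# are replaced with the display names above before the chat starts. A hypothetical prompt file might
# look like:
#
#   This is a conversation between {username} and {bot_name}. {bot_name} is a friendly chatbot.
#   {bot_name}: Hello, {username}.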
Please say \"Shhhhhh\"?\n" # args.botfirst = True # Instantiate model and generator config = model_init.make_config(args) model = ExLlama(config) cache = ExLlamaCache(model) tokenizer = ExLlamaTokenizer(args.tokenizer) model_init.print_stats(model) # Load LoRA lora = None if args.lora: print(f" -- LoRA config: {args.lora_config}") print(f" -- Loading LoRA: {args.lora}") if args.lora_config is None: print(f" ## Error: please specify lora path to adapter_config.json") sys.exit() lora = ExLlamaLora(model, args.lora_config, args.lora) if lora.bias_ignored: print(f" !! Warning: LoRA zero bias ignored") # Generator generator = ExLlamaGenerator(model, tokenizer, cache) generator.settings = ExLlamaGenerator.Settings() generator.settings.temperature = args.temperature generator.settings.top_k = args.top_k generator.settings.top_p = args.top_p generator.settings.min_p = args.min_p generator.settings.token_repetition_penalty_max = args.repetition_penalty generator.settings.token_repetition_penalty_sustain = args.repetition_penalty_sustain generator.settings.token_repetition_penalty_decay = generator.settings.token_repetition_penalty_sustain // 2 generator.settings.beams = args.beams generator.settings.beam_length = args.beam_length generator.lora = lora break_on_newline = not args.no_newline # Be nice to Chatbort min_response_tokens = 4 max_response_tokens = 256 extra_prune = 256 print(past, end = "") ids = tokenizer.encode(past) generator.gen_begin(ids) next_userprompt = username + ": " first_round = True while True: res_line = bot_name + ":" res_tokens = tokenizer.encode(res_line) num_res_tokens = res_tokens.shape[-1] # Decode from here if first_round and args.botfirst: in_tokens = res_tokens else: # Read and format input in_line = input(next_userprompt) in_line = username + ": " + in_line.strip() + "\n" next_userprompt = username + ": " # No need for this, really, unless we were logging the chat. The actual history we work on is kept in the # tokenized sequence in the generator and the state in the cache. past += in_line # SentencePiece doesn't tokenize spaces separately so we can't know from individual tokens if they start a new word # or not. Instead, repeatedly decode the generated response as it's being built, starting from the last newline, # and print out the differences between consecutive decodings to stream out the response. in_tokens = tokenizer.encode(in_line) in_tokens = torch.cat((in_tokens, res_tokens), dim = 1) # If we're approaching the context limit, prune some whole lines from the start of the context. Also prune a # little extra so we don't end up rebuilding the cache on every line when up against the limit. expect_tokens = in_tokens.shape[-1] + max_response_tokens max_tokens = config.max_seq_len - expect_tokens if generator.gen_num_tokens() >= max_tokens: generator.gen_prune_to(config.max_seq_len - expect_tokens - extra_prune, tokenizer.newline_token_id) # Feed in the user input and "{bot_name}:", tokenized generator.gen_feed_tokens(in_tokens) # Generate with streaming print(res_line, end = "") sys.stdout.flush() generator.begin_beam_search() for i in range(max_response_tokens): # Disallowing the end condition tokens seems like a clean way to force longer replies. 
        if i < min_response_tokens:
            generator.disallow_tokens([tokenizer.newline_token_id, tokenizer.eos_token_id])
        else:
            generator.disallow_tokens(None)

        # Get a token

        gen_token = generator.beam_search()

        # If token is EOS, replace it with newline before continuing

        if gen_token.item() == tokenizer.eos_token_id:
            generator.replace_last_token(tokenizer.newline_token_id)

        # Decode the current line and print any characters added

        num_res_tokens += 1
        text = tokenizer.decode(generator.sequence_actual[:, -num_res_tokens:][0])
        new_text = text[len(res_line):]

        skip_space = res_line.endswith("\n") and new_text.startswith(" ")  # Bit prettier console output
        res_line += new_text
        if skip_space: new_text = new_text[1:]

        print(new_text, end = "")  # (character streaming output is here)
        sys.stdout.flush()

        # End conditions

        if break_on_newline and gen_token.item() == tokenizer.newline_token_id: break
        if gen_token.item() == tokenizer.eos_token_id: break

        # Some models will not (or will inconsistently) emit EOS tokens but in a chat sequence will often begin
        # generating for the user instead. Try to catch this and roll back a few tokens to begin the user round.

        if res_line.endswith(f"{username}:"):
            plen = tokenizer.encode(f"{username}:").shape[-1]
            generator.gen_rewind(plen)
            next_userprompt = " "
            break

    generator.end_beam_search()

    past += res_line
    first_round = False
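
# Example invocation (a sketch only: the script file name and all paths below are placeholders, and the
# model-directory flag -d is assumed to be one of the arguments added by model_init.add_args):
#
#   python example_chatbot.py -d /path/to/model_dir -p /path/to/prompt.txt -un "User" -bn "Chatbort"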