import cuda_ext
from model import ExLlama, ExLlamaCache
from tokenizer import ExLlamaTokenizer
from lora import ExLlamaLora
import torch
import torch.nn.functional as F

# Upper bound on the number of entries kept in a generator's tokenizer cache
MAX_CACHED_STRINGS = 100


class ExLlamaAltGenerator:
    """Single-sequence streaming generator built on ExLlama.

    Typical use is begin_stream() followed by repeated stream() calls until it
    reports EOS, or generate() to do both and return the full completion.
    Stop conditions may be token IDs or strings; text that could be the start
    of a stop string is held back until it is known whether the stop string
    completes.
    """

    class Settings:
        """Sampling parameters read by ExLlamaAltGenerator.sample()."""

        temperature = 0.95
        top_k = 40                              # Consider the most probable top_k samples, 0 to disable top_k sampling
        top_p = 0.65                            # Consider tokens up to a cumulative probability of top_p, 0.0 to disable top_p sampling
        min_p = 0.0                             # Do not consider tokens with probability less than this
        typical = 0.0                           # Locally typical sampling threshold, 0.0 to disable typical sampling

        token_repetition_penalty_max = 1.15     # Repetition penalty for most recent tokens
        token_repetition_penalty_sustain = -1   # No. most recent tokens to repeat penalty for, -1 to apply to whole context
        token_repetition_penalty_decay = 0      # Gradually decrease penalty over this many tokens

        disallowed_tokens: list[int] = None     # List of tokens to inhibit, e.g. tokenizer.eos_token_id
        lora: ExLlamaLora = None                # LoRA to apply when generating

    # Internal state

    model: ExLlama
    cache: ExLlamaCache
    tokenizer: ExLlamaTokenizer
    tokenizer_cache = {}                        # Class-level fallback only; __init__ gives every instance its own dict

    settings: Settings
    stop_strings: list = []                     # Rebound (never mutated in place) by begin_stream()
    stop_tokens: list = []                      # Rebound (never mutated in place) by begin_stream()
    held_text: str = ""                         # Generated text not yet released because it may begin a stop string
    max_stop_tokens: int = 2                    # Width (in tokens) of the tail window decoded by stream()
    sequence_ids: torch.Tensor = None
    sequence_str: str = None
    remaining_tokens: int = 0

    def __init__(self, model: ExLlama, tokenizer: ExLlamaTokenizer, cache: ExLlamaCache):
        """Bind the generator to a model, tokenizer and cache, with default Settings."""

        self.model = model
        self.tokenizer = tokenizer
        self.cache = cache
        self.settings = ExLlamaAltGenerator.Settings()

        # Per-instance cache: encodings depend on the tokenizer, so they must not be shared between generators
        self.tokenizer_cache = {}

    def cached_tokenize(self, text: str, encode_special_characters = False):
        """Return self.tokenizer.encode(text, ...), memoized per (text, encode_special_characters).

        At most MAX_CACHED_STRINGS results are kept; the oldest entries are
        evicted first.
        """

        # The flag is part of the key because it is passed to encode() and may change the result
        key = (text, encode_special_characters)

        if key in self.tokenizer_cache:
            return self.tokenizer_cache[key]

        while len(self.tokenizer_cache) >= MAX_CACHED_STRINGS:
            del self.tokenizer_cache[next(iter(self.tokenizer_cache))]  # Always removes oldest entry, as of Python 3.7

        new_enc = self.tokenizer.encode(text, encode_special_characters = encode_special_characters)
        self.tokenizer_cache[key] = new_enc
        return new_enc

    def get_num_tokens(self, text: str, encode_special_characters = False):
        """Return the size of the last dimension of the (cached) encoding of text."""

        return self.cached_tokenize(text, encode_special_characters = encode_special_characters).shape[-1]

    # Begin generating
    #
    # prompt: The input prompt. Will be tokenized and then truncated to the model's max sequence length minus max_new_tokens
    # stop_conditions: List of strings or integer token IDs that will end the sequence
    # max_new_tokens: Maximum number of tokens stream() will generate before reporting EOS
    # gen_settings: ExLlamaAltGenerator.Settings
    # encode_special_characters: Passed through to the tokenizer when encoding the prompt

    def begin_stream(self, prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: Settings, encode_special_characters = False):
        """Prepare the generator for streaming a completion of prompt.

        Raises ValueError if stop_conditions contains anything other than
        int or str items.
        """

        assert isinstance(prompt, str), "ExLlamaAltGenerator does not support batched generation"

        # Tokenize prompt and limit length to allow prompt and (max) new tokens within max sequence length

        max_input_tokens = self.model.config.max_seq_len - max_new_tokens
        self.remaining_tokens = max_new_tokens

        input_ids = self.cached_tokenize(prompt, encode_special_characters)
        applied_input_ids = input_ids[:, -max_input_tokens:]

        # Truncation shortens the last (token) dimension, so that is the one to compare. When tokens were
        # dropped, rebuild the text from the tokens actually used so sequence_str matches sequence_ids.
        was_truncated = applied_input_ids.shape[-1] < input_ids.shape[-1]
        self.sequence_str = self.tokenizer.decode(applied_input_ids)[0] if was_truncated else prompt

        # Settings

        self.stop_strings = []
        self.stop_tokens = []
        for t in stop_conditions:
            if isinstance(t, int): self.stop_tokens += [t]
            elif isinstance(t, str): self.stop_strings += [t]
            else: raise ValueError("Unsupported type in stop_conditions")

        self.held_text = ""

        # The decoded tail window must be able to span the longest stop string, plus some slack
        self.max_stop_tokens = 2
        for ss in self.stop_strings:
            self.max_stop_tokens = max(self.max_stop_tokens, self.get_num_tokens(ss) + 2)

        self.settings = gen_settings

        # Start generation

        self.gen_begin_reuse(applied_input_ids, gen_settings)

    # Get the next chunk of text in the stream
    #
    # Returns stream_chunk: str, EOS: bool

    def stream(self):
        """Generate one token and return (text_chunk, eos).

        text_chunk may be empty while text is being held back as a possible
        prefix of a stop string. eos is True when the token budget is used up,
        a stop token was sampled, or a stop string appeared in the held text.
        """

        # Check total response length

        if self.remaining_tokens == 0:
            self.sequence_str += self.held_text
            return self.held_text, True

        self.remaining_tokens -= 1

        # Decode the current tail end of the sequence

        old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.max_stop_tokens:])[0]

        # Generate a single token and append to the sequence

        next_token = self.gen_single_token(self.settings)

        # End immediately if it was a stop token. Compare as a plain int rather than relying on
        # tensor-vs-int equality inside the list membership test.

        if next_token.item() in self.stop_tokens:
            self.sequence_str += self.held_text
            return self.held_text, True

        # Decode the tail end of the sequence with the added token to get (actual) characters added

        new_tail = self.tokenizer.decode(self.sequence_ids[:, -(self.max_stop_tokens + 1):])[0]
        self.held_text += new_tail[len(old_tail):]

        # If held_text fully contains any stop string, cut at the earliest occurrence across all stop
        # strings, so that no stop string can leak into the returned text

        stop_positions = [pos for pos in (self.held_text.find(ss) for ss in self.stop_strings) if pos != -1]
        if stop_positions:
            cut = min(stop_positions)
            self.sequence_str += self.held_text[:cut]
            return self.held_text[:cut], True

        # Hold text as long as its end overlaps the start of some stop string

        partial_ss = any(
            self.held_text.endswith(ss[:j])
            for ss in self.stop_strings
            for j in range(1, min(len(self.held_text), len(ss)) + 1)
        )

        # If holding text because of a partial stop condition, return nothing but also EOS = False

        if partial_ss:
            return "", False

        # No stop condition, so return whatever has been generated

        stream_text = self.held_text
        self.held_text = ""
        self.sequence_str += stream_text
        return stream_text, False

    # Generate and return a full prompt completion. See begin_stream() above

    def generate(self, prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: Settings, encode_special_characters = False):
        """Run begin_stream() and then stream() until EOS; return the concatenated chunks."""

        self.begin_stream(prompt, stop_conditions, max_new_tokens, gen_settings, encode_special_characters)

        chunks = []
        while True:
            chunk, eos = self.stream()
            chunks.append(chunk)
            if eos: break

        return "".join(chunks)

    # Begin generation

    def gen_begin(self, in_tokens, gen_settings):
        """Start a fresh sequence from in_tokens, discarding any cached state.

        Resets cache.current_seq_len to 0 and runs the model over every token
        except the last with preprocess_only = True; the last token is
        forwarded later by gen_single_token().
        """

        self.sequence_ids = in_tokens.clone()
        self.cache.current_seq_len = 0
        self.model.forward(self.sequence_ids[:, :-1], self.cache, preprocess_only = True, lora = gen_settings.lora)

    def gen_begin_reuse(self, in_tokens, gen_settings):
        """Start a sequence from in_tokens, reusing the cache for any prefix shared with the current sequence.

        Falls back to gen_begin() when there is no current sequence, the cache
        is empty, or fewer than two leading tokens match.
        """

        if self.sequence_ids is None or self.cache.current_seq_len == 0:
            self.gen_begin(in_tokens, gen_settings)
            return

        # Count the leading tokens that are identical in the old and new sequences

        reuse = 0
        while reuse < self.sequence_ids.shape[-1] and reuse < in_tokens.shape[-1] and self.sequence_ids[0, reuse] == in_tokens[0, reuse]:
            reuse += 1

        if reuse < 2:
            self.gen_begin(in_tokens, gen_settings)
            return

        # Keep the cache for all matching tokens but the last; that one is forwarded again by the next
        # gen_feed_tokens() / gen_single_token() call

        self.cache.current_seq_len = reuse - 1
        self.sequence_ids = in_tokens[:, :reuse]

        if reuse < in_tokens.shape[-1]: self.gen_feed_tokens(in_tokens[:, reuse:], gen_settings)

    def gen_feed_tokens(self, in_tokens, gen_settings):
        """Append in_tokens to the current sequence and run the model over the not-yet-cached part.

        Everything from cache.current_seq_len up to, but excluding, the new
        last token is forwarded with preprocess_only = True. Delegates to
        gen_begin() when there is no current sequence.
        """

        if self.sequence_ids is None:
            self.gen_begin(in_tokens, gen_settings)
            return

        start = self.cache.current_seq_len
        self.sequence_ids = torch.cat((self.sequence_ids, in_tokens), dim = 1)
        self.model.forward(self.sequence_ids[:, start : -1], self.cache, preprocess_only = True, lora = gen_settings.lora)

    # Generate one token in current sequence

    def gen_single_token(self, gen_settings):
        """Forward the last token, sample the next one, append it to sequence_ids and return it."""

        # Simple sampling case:

        logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, lora = gen_settings.lora)
        token, _ = self.sample(logits, gen_settings)
        self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
        return token

    def sample(self, logits, gen_settings):
        """Sample one token from logits according to gen_settings.

        Applies, in order: repetition penalty, BOS suppression, disallowed
        tokens, temperature, top-k, top-p (with min_p), locally typical
        filtering, then multinomial sampling.

        Returns (token, prob), each with a leading dimension of size 1 added;
        prob is the sampled token's probability within the final renormalized
        candidate set.

        Raises ValueError if logits is neither 2- nor 3-dimensional.
        """

        # Every setting is read from gen_settings, including the repetition penalty
        # NOTE(review): the extension call returns nothing here, so it presumably adjusts logits in place - confirm

        cuda_ext.ext_apply_rep_penalty_mask_cpu(self.sequence_ids,
                                                gen_settings.token_repetition_penalty_max,
                                                gen_settings.token_repetition_penalty_sustain,
                                                gen_settings.token_repetition_penalty_decay,
                                                logits)

        # Suppress BOS. Ellipsis indexing works for both 2D and 3D logits, so the 2D branch below is reachable.

        logits[..., self.tokenizer.bos_token_id] = -10000.0

        if logits.dim() == 3: logits = logits[0, -1, :]
        elif logits.dim() == 2: logits = logits[-1, :]
        else: raise ValueError("Bad logits dimension")

        # Disallow tokens

        if gen_settings.disallowed_tokens is not None:
            logits[gen_settings.disallowed_tokens] = float("-inf")

        # Base probabilities

        logits /= gen_settings.temperature
        logits += 1e-8
        probs = torch.softmax(logits, dim = -1)

        # Top K

        if gen_settings.top_k == 0:
            top_probs, top_indices = torch.sort(probs, descending = True)
        else:
            top_probs, top_indices = torch.topk(probs, gen_settings.top_k)
            top_probs = F.normalize(top_probs, p = 1, dim = -1)

        # Top P
        #
        # Keeps the shortest prefix of the descending-sorted candidates whose cumulative probability
        # exceeds top_p. min_p is only evaluated inside this pass, so it has no effect when top_p is 0.

        if gen_settings.top_p > 0.0:

            num_top_p_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_top_p_probs += 1
                if num_top_p_probs == top_probs.shape[-1]: break
                if top_probs[num_top_p_probs].item() < gen_settings.min_p: break
                cum_prob += top_probs[num_top_p_probs].item()
                if cum_prob > gen_settings.top_p: break

            top_probs = top_probs[:num_top_p_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_top_p_probs]

        # Locally typical sampling
        #
        # Reorders candidates by how far their log-probability is from the negative entropy of the
        # candidate distribution (closest first), then keeps a prefix by cumulative probability.

        if gen_settings.typical > 0.0:

            epsilon = 1e-10
            log_probs = (top_probs + epsilon).log()
            neg_entropy = (top_probs * log_probs).sum()
            entropy_dev = (neg_entropy - log_probs).abs()
            _, entropy_dev_order = torch.sort(entropy_dev)

            top_probs = top_probs.gather(-1, entropy_dev_order)
            top_indices = top_indices.gather(-1, entropy_dev_order)

            num_typical_probs = 0
            cum_prob = top_probs[0].item()
            while True:
                num_typical_probs += 1
                if num_typical_probs == top_probs.shape[-1]: break
                cum_prob += top_probs[num_typical_probs].item()
                if cum_prob > gen_settings.typical: break

            top_probs = top_probs[:num_typical_probs]
            top_probs = F.normalize(top_probs, p = 1, dim = -1)
            top_indices = top_indices[:num_typical_probs]

        # Multinomial sampling from top_probs, kept in same order as top_indices

        sampled_ind = torch.multinomial(top_probs, 1)
        sampled_tokens = top_indices[sampled_ind]
        sampled_probs = top_probs[sampled_ind]

        return sampled_tokens.unsqueeze(0), sampled_probs.unsqueeze(0)