import copy
import os
from datetime import timedelta
from time import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
from accelerate import (
    Accelerator,
    DistributedType,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
from transformers import TextStreamer

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)
from lm_eval.models.huggingface import HFLM


class StopWatch(TextStreamer):
    """Streamer that records timing of a `generate` call instead of printing text.

    It is passed as `streamer=` to `model.generate` and uses the sequence of
    `put` / `end` callbacks as timestamps:

    * 1st `put`  -> prefilling starts,
    * 2nd `put`  -> prefilling is over, decoding starts,
    * every `put` from the 2nd one on bumps `decoding_iterations`,
    * `end`      -> decoding is over.

    NOTE(review): this assumes the first `put` carries the prompt ids and each
    later `put` carries one decoding step -- confirm against the installed
    transformers streamer contract.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_prefilling = None  # wall-clock time of the first `put`
        self.prefilling_time = None  # seconds between 1st and 2nd `put`
        self.start_decoding = None  # wall-clock time of the 2nd `put`
        self.decoding_time = None  # seconds between 2nd `put` and `end`
        self.decoding_iterations = 0  # number of `put` calls after the first

    def put(self, value):
        """Record a timestamp; `value` itself is intentionally ignored."""
        if self.start_prefilling is None:
            self.start_prefilling = time()
            return
        elif self.prefilling_time is None:
            self.prefilling_time = time() - self.start_prefilling
            self.start_decoding = time()
        self.decoding_iterations += 1

    def end(self):
        """Close the decoding interval (only once, and only if it was opened)."""
        if self.decoding_time is None and self.start_decoding is not None:
            self.decoding_time = time() - self.start_decoding


class HFLMWithMeasurement(HFLM):
    """`HFLM` variant whose generation path also reports latency figures."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Run `self.model.generate` and time it.

        Args:
            context: batch of input ids; `context.shape` is read as
                `(batch, sequence)`.
            max_length: forwarded to `generate` as `max_length`.
            stop: stop sequences used to build the stopping criteria.
            **generation_kwargs: forwarded to `generate` after the
                temperature / `do_sample` normalisation below.

        Returns:
            `(res, end_to_end_time, prefilling_time, token_per_sec)` where
            `res` is whatever `generate` returned, the two times are wall-clock
            seconds divided by the batch size, and `token_per_sec` is
            `decoding_iterations / (decoding_time / batch_size)`.
            `prefilling_time` falls back to `0.0` when the stop watch never saw
            a generated token, and `token_per_sec` falls back to `0.0` when no
            non-zero decoding time was measured.
        """
        # temperature = 0.0 if not set
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0,
        # use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        stop_watch = StopWatch(self.tokenizer)
        start = time()
        res = self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            streamer=stop_watch,
            **generation_kwargs,
        )
        end = time()

        batch_size = context.shape[0]
        output_length = stop_watch.decoding_iterations

        end_to_end_time = (end - start) / batch_size

        # The stop watch leaves its fields at None when `generate` never pushed
        # a generated token; dividing None (or dividing by a zero duration)
        # would crash after the expensive generation already succeeded.
        measured_prefill = stop_watch.prefilling_time
        measured_decode = stop_watch.decoding_time
        if measured_prefill is not None:
            prefilling_time = measured_prefill / batch_size
        else:
            prefilling_time = 0.0
        if measured_decode:
            decoding_time = measured_decode / batch_size
            token_per_sec = output_length / decoding_time
        else:
            token_per_sec = 0.0

        return res, end_to_end_time, prefilling_time, token_per_sec

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[Tuple[str, float, float, float]]:
        """Generate a continuation for every request and attach timing data.

        Requests are grouped by their generation kwargs, batched, generated
        with `_model_generate`, truncated at the first stop sequence and put
        back into the original request order.

        Args:
            requests: instances whose `.args` is `(context, gen_kwargs)`.
            disable_tqdm: hide the progress bar.

        Returns:
            One `(text, end_to_end_time, prefilling_time, token_per_sec)` tuple
            per request; the three timing values are those of the batch the
            request was generated in.

        Raises:
            ValueError: if `gen_kwargs` is not a dict, or its `until` entry is
                neither a string nor a list.
        """
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        # a single stop string becomes a one-element list
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # add EOS token to stop sequences
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if not until:
                until = [eos]
            else:
                until.append(eos)
            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont, end_to_end_time, prefilling_time, token_per_sec = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append((s, end_to_end_time, prefilling_time, token_per_sec))

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res