import functools
import os
import shutil
from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from moe_infinity import MoE
from lm_eval.api.registry import register_model

from src.backend.hflm_with_measurement import HFLMWithMeasurement

# Device the MoE-infinity model is placed on; tensor inputs to generate() are
# moved here as well.
_MOE_DEVICE = "cuda:0"

# int(1e30): the placeholder transformers stores in ``tokenizer.model_max_length``
# when the tokenizer config does not define a real limit.
_UNSET_TOKENIZER_MAX_LENGTH = 1000000000000000019884624838656


@register_model("moe-infinity")
class MoEHFLM(HFLMWithMeasurement):
    """lm-eval model adapter that runs a checkpoint through MoE-infinity's ``MoE``.

    Registered with lm-eval under the name ``"moe-infinity"``.
    """

    def __init__(
        self,
        pretrained: str = "mistralai/Mixtral-8x7B-Instruct-v0.1",
        moe_config: Optional[dict] = None,
        offload_path: str = os.path.expanduser("~"),
        device_memory_ratio: float = 0.75,
        use_chat_template: bool = True,
        *args,
        **kwargs,
    ):
        """Configure and build the MoE-infinity model.

        Args:
            pretrained: Checkpoint name or path handed to ``MoE``.
            moe_config: Extra MoE-infinity options. Its keys override the
                defaults built in ``_create_model`` (including ``offload_path``).
            offload_path: Parent directory of the default offload directory
                ``moe-infinity-offloads``.
            device_memory_ratio: Default ``device_memory_ratio`` passed to MoE-infinity.
            use_chat_template: Wrap every prompt in the tokenizer's chat
                template inside ``tok_batch_encode``.
            *args, **kwargs: Forwarded to the parent constructor. A ``device``
                kwarg is dropped because placement is fixed to ``cuda:0``.
        """
        # Set before super().__init__(): _create_model reads these attributes and
        # is presumably invoked by the parent constructor (nothing here calls it).
        self.checkpoint = pretrained
        self.moe_config = moe_config if moe_config is not None else {}
        self.offload_path = offload_path
        self.device_memory_ratio = device_memory_ratio
        self.use_chat_template = use_chat_template

        kwargs.pop("device", None)
        super().__init__(
            *args, **kwargs, pretrained=pretrained, device_map=_MOE_DEVICE
        )

        # NOTE(review): the default offload directory is removed as soon as the
        # parent constructor returns; presumably MoE-infinity no longer needs the
        # files on disk at this point — confirm. The directory may not exist when
        # moe_config supplied its own "offload_path", so only remove it if present.
        default_dir = self._default_offload_dir()
        if os.path.isdir(default_dir):
            shutil.rmtree(default_dir)

    def __del__(self):
        """Best-effort removal of the default offload directory on garbage collection."""
        # __init__ may have failed before offload_path was assigned.
        if getattr(self, "offload_path", None) is None:
            return
        try:
            # The directory is normally already gone (removed in __init__), so a
            # missing path must not be treated as an error.
            shutil.rmtree(self._default_offload_dir(), ignore_errors=True)
        except Exception:
            # Deliberate: module globals can be torn down at interpreter exit and a
            # destructor has no caller to report errors to.
            pass

    def _default_offload_dir(self) -> str:
        """Return the offload directory used unless ``moe_config`` overrides it."""
        return os.path.join(self.offload_path, "moe-infinity-offloads")

    def _create_model(self, *args, **kwargs):
        """Instantiate MoE-infinity's ``MoE`` for ``self.checkpoint`` into ``self._model``.

        Any positional/keyword arguments are accepted only for compatibility
        with the parent's call site and are ignored.
        """
        default_moe_config = {
            "offload_path": self._default_offload_dir(),
            "device_memory_ratio": self.device_memory_ratio,
        }
        # User-provided keys take precedence over the defaults.
        final_moe_config = {**default_moe_config, **self.moe_config}

        # Workaround until MoE-infinity moves inputs to the model device itself.
        def move_inputs_to_device(func):
            """Wrap ``func`` so tensor arguments are moved to ``_MOE_DEVICE`` first."""

            @functools.wraps(func)
            def wrapper(*call_args, **call_kwargs):
                call_args = [
                    arg.to(_MOE_DEVICE) if isinstance(arg, torch.Tensor) else arg
                    for arg in call_args
                ]
                call_kwargs = {
                    key: value.to(_MOE_DEVICE) if isinstance(value, torch.Tensor) else value
                    for key, value in call_kwargs.items()
                }
                return func(*call_args, **call_kwargs)

            return wrapper

        self._model = MoE(self.checkpoint, final_moe_config)
        self._model.generate = move_inputs_to_device(self._model.generate)

    @property
    def max_length(self):
        """Maximum sequence length for this model.

        Resolution order: an explicitly set ``_max_length``; then the first of
        ``n_positions`` / ``max_position_embeddings`` / ``n_ctx`` found on the
        wrapped model's config; then the tokenizer's ``model_max_length``
        (unless it is transformers' "unset" placeholder); else
        ``_DEFAULT_MAX_LENGTH``.
        """
        if self._max_length:  # manually configured length wins
            return self._max_length
        # self.model is the MoE wrapper; the HF model (and its config) sits on .model.
        config = self.model.model.config
        for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
            if hasattr(config, attr):
                return getattr(config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == _UNSET_TOKENIZER_MAX_LENGTH:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a batch of prompts into padded tensors.

        Args:
            strings: Prompts to encode. When ``use_chat_template`` is set, each
                prompt is first wrapped as a single user turn via the tokenizer's
                chat template; if that fails, the raw prompts are used.
            padding_side: Side to pad on for this call only.
            left_truncate_len: If truthy, keep only the last this-many tokens.
            truncation: Forwarded to the tokenizer.

        Returns:
            ``(input_ids, attention_mask)``, padded to the longest prompt.
        """
        if self.use_chat_template:
            try:
                # Built in one expression so ``strings`` is only replaced when
                # every prompt was templated successfully.
                strings = [
                    self.tokenizer.apply_chat_template(
                        [{"role": "user", "content": f"{input_string}"}],
                        tokenize=False,
                    )
                    for input_string in strings
                ]
            except Exception as err:
                print(f"failed to update input string with chat template: {self._model} ({err})")

        # Temporarily switch padding side; always restore it so the shared
        # tokenizer is not left modified if tokenization raises.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            encoding = self.tokenizer(
                strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                add_special_tokens=False,
            )
        finally:
            self.tokenizer.padding_side = old_padding_side

        if left_truncate_len:
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:]

        return encoding["input_ids"], encoding["attention_mask"]