File size: 4,293 Bytes
14e4843
 
 
794e78c
33e1c9d
14e4843
 
 
 
3020792
 
d6d7ec6
 
3020792
14e4843
 
 
 
d6d7ec6
 
14e4843
 
d6d7ec6
14e4843
 
 
 
 
 
 
794e78c
 
d6d7ec6
 
 
14e4843
 
 
 
 
 
 
 
 
 
 
 
 
33e1c9d
 
 
 
14e4843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6d7ec6
14e4843
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import os
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from moe_infinity import MoE
from typing import List, Tuple, Optional, Union

from lm_eval.api.registry import register_model

from src.backend.hflm_with_measurement import HFLMWithMeasurement


@register_model("moe-infinity")
class MoEHFLM(HFLMWithMeasurement):
    """lm-eval model adapter whose underlying model is built by MoE-Infinity.

    Registered with lm-eval under the name ``"moe-infinity"``. Compared to the
    parent class it overrides model construction (``_create_model``), the
    ``max_length`` lookup, and batch tokenization (optionally wrapping every
    prompt in the tokenizer's chat template).
    """

    def __init__(
        self,
        pretrained: str = "mistralai/Mixtral-8x7B-Instruct-v0.1",
        moe_config: Optional[dict] = None,
        offload_path=os.path.expanduser("~"),
        device_memory_ratio=0.75,
        use_chat_template=True,
        *args,
        **kwargs,
    ):
        """Store the MoE-Infinity settings and initialize the parent class.

        Args:
            pretrained: Checkpoint name or path handed to ``MoE`` and to the
                parent constructor.
            moe_config: Extra MoE-Infinity options; entries here override the
                defaults assembled in ``_create_model``. ``None`` means no
                overrides.
            offload_path: Base directory under which the
                ``moe-infinity-offloads`` directory is placed.
            device_memory_ratio: Default ``device_memory_ratio`` entry of the
                MoE-Infinity config.
            use_chat_template: When true, ``tok_batch_encode`` wraps each
                prompt as a single user message via the tokenizer's chat
                template.
            *args: Forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor, except ``device``,
                which is discarded.
        """
        # These attributes are assigned before super().__init__() on purpose:
        # presumably the parent constructor calls _create_model(), which reads
        # them -- TODO confirm against HFLMWithMeasurement.
        self.checkpoint = pretrained
        self.moe_config = moe_config if moe_config is not None else {}
        self.offload_path = offload_path
        self.device_memory_ratio = device_memory_ratio
        self.use_chat_template = use_chat_template
        # A caller-supplied "device" is dropped; the parent is always given
        # device_map="cuda:0" instead.
        kwargs.pop("device", None)
        super().__init__(
            *args, **kwargs, pretrained=pretrained, device_map="cuda:0"
        )

    def _create_model(self, *args, **kwargs):
        """Build the MoE-Infinity model and store it in ``self._model``.

        The positional and keyword arguments are accepted for signature
        compatibility and are not used. The effective config is the defaults
        below overlaid with ``self.moe_config`` (user values win).
        """
        default_moe_config = {
            "offload_path": os.path.join(self.offload_path, "moe-infinity-offloads"),
            "device_memory_ratio": self.device_memory_ratio,
        }
        # Later mapping wins, so user-provided keys override the defaults.
        final_moe_config = {**default_moe_config, **self.moe_config}
        self._model = MoE(self.checkpoint, final_moe_config)

    @property
    def max_length(self):
        """Return the maximum sequence length to use for this model.

        Resolution order: a manually set (truthy) ``self._max_length``; the
        first of ``n_positions`` / ``max_position_embeddings`` / ``n_ctx``
        found on ``self.model.model.config``; the tokenizer's
        ``model_max_length`` unless it holds the "unset" sentinel; otherwise
        ``self._DEFAULT_MAX_LENGTH``.
        """
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        # The config is reached through self.model.model, i.e. one level below
        # the object stored by _create_model (the MoE wrapper).
        for attr in seqlen_config_attrs:
            if hasattr(self.model.model.config, attr):
                return getattr(self.model.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            # This literal equals int(1e30); presumably the value Hugging Face
            # tokenizers report when no real limit is configured -- confirm.
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a batch of strings into padded tensors.

        Args:
            strings: Prompts to encode.
            padding_side: Padding side set on the tokenizer for the duration
                of this call; the previous value is restored afterwards.
            left_truncate_len: If truthy, keep only the last
                ``left_truncate_len`` positions of every row.
            truncation: Forwarded to the tokenizer call.

        Returns:
            The ``input_ids`` and ``attention_mask`` entries of the encoding.
        """
        if self.use_chat_template:
            # Best effort: if templating fails for any prompt, the whole batch
            # falls back to the raw strings (``strings`` is only rebound once
            # every prompt has been templated successfully).
            try:
                strings = [
                    self.tokenizer.apply_chat_template(
                        [{"role": "user", "content": f"{input_string}"}],
                        tokenize=False,
                    )
                    for input_string in strings
                ]
            except Exception:
                print(f"failed to update input string with chat template: {self._model}")

        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            encoding = self.tokenizer(
                strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                # Presumably because the chat template already inserts the
                # special tokens -- TODO confirm for non-templated input.
                add_special_tokens=False,
            )
        finally:
            # Restore the shared tokenizer's state even if encoding raises.
            self.tokenizer.padding_side = old_padding_side

        if left_truncate_len:
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:]

        return encoding["input_ids"], encoding["attention_mask"]