import sys

min_version = (3, 9)
if sys.version_info < min_version:
    print("")
    print(f" ## Warning: this project requires Python {min_version[0]}.{min_version[1]} or higher.")
    print("")

import torch
from torch import nn
import torch.nn.functional as F
from safetensors import safe_open
import cuda_ext
import json
import math
import gc
from enum import Enum

try:
    from flash_attn import flash_attn_func
except ImportError:
    pass


class ParsedEnum(Enum):

    def __str__(self):
        return self.name.lower()

    def __repr__(self):
        return str(self)

    @classmethod
    def argparse(cls, s):
        try:
            return cls[s.upper()]
        except KeyError:
            return s


class ExLlamaConfig:

    # Load config from Llama config.json

    def __init__(self, model_config_path):

        with open(model_config_path) as f:
            read_config = json.load(f)

        # Loaded/automatic settings

        self.bos_token_id = read_config["bos_token_id"] if "bos_token_id" in read_config else 1
        self.eos_token_id = read_config["eos_token_id"] if "eos_token_id" in read_config else 2
        self.pad_token_id = read_config["pad_token_id"] if "pad_token_id" in read_config else 0

        self.hidden_size = read_config["hidden_size"]
        self.initializer_range = read_config["initializer_range"]
        self.intermediate_size = read_config["intermediate_size"]
        self.num_attention_heads = read_config["num_attention_heads"]
        self.num_hidden_layers = read_config["num_hidden_layers"]
        self.rms_norm_eps = read_config["rms_norm_eps"]
        self.vocab_size = read_config["vocab_size"]

        if "num_key_value_heads" in read_config:
            self.num_key_value_heads = read_config["num_key_value_heads"]
            self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        else:
            self.num_key_value_heads = self.num_attention_heads
            self.num_key_value_groups = 1

        self.rotary_embedding_base = read_config["rope_theta"] if "rope_theta" in read_config else 10000.0

        self.head_dim = self.hidden_size // self.num_attention_heads

        self.groupsize = None  # Autodetected
        self.act_order = False  # Autodetected
        self.empty_g_idx = False  # Autodetected

        # Required settings

        self.model_path = None  # str or list[str]
        self.device_map = ExLlamaDeviceMap(self.num_hidden_layers)

        # Optional settings

        self.max_seq_len = 2048  # Reduce to save memory. Can also be increased, ideally while also using compress_pos_emb and a compatible model/LoRA
        self.max_input_len = 2048  # Maximum length of input IDs in a single forward pass. Sequences longer than this will be processed in multiple steps
        self.max_attention_size = 2048**2  # Sequences will be processed in chunks to keep the size of the attention weights matrix <= this
        self.compress_pos_emb = 1.0  # Increase to compress positional embeddings applied to sequence
        self.alpha_value = 1.0  # Alpha value for NTK RoPE scaling. Similar to compress_pos_emb, higher values increase context length but add perplexity
        self.gpu_peer_fix = False  # Apparently Torch can have problems transferring tensors directly from one GPU to another sometimes. Enable this to explicitly move tensors via system RAM instead, where needed
        self.auto_map = None  # List of floats with memory allocation in GB, per CUDA device, overrides device_map

        # Tuning

        self.use_flash_attn_2 = False
        self.matmul_recons_thd = 8
        self.fused_mlp_thd = 2
        self.sdp_thd = 8
        self.fused_attn = True
        self.matmul_fused_remap = False
        self.rmsnorm_no_half2 = False
        self.rope_no_half2 = False
        self.matmul_no_half2 = False
        self.silu_no_half2 = False
        self.concurrent_streams = False

    # Copy tuning params to C++ extension

    def set_tuning_params(self):

        cuda_ext.exllama_ext.set_tuning_params(self.matmul_recons_thd,
                                               self.fused_mlp_thd,
                                               self.sdp_thd,
                                               self.matmul_fused_remap,
                                               self.rmsnorm_no_half2,
                                               self.rope_no_half2,
                                               self.matmul_no_half2,
                                               self.silu_no_half2,
                                               self.concurrent_streams)

    # Parse and set list of GPU VRAM allocations

    def set_auto_map(self, map_string):

        if map_string is None: self.auto_map = None
        else: self.auto_map = [float(alloc) for alloc in map_string.split(",")]

    def calculate_rotary_embedding_base(self):

        self.rotary_embedding_base = self.rotary_embedding_base * self.alpha_value ** (self.head_dim / (self.head_dim - 2))


# 4-bit linear layer implementation

class Ex4bitLinear:

    def __init__(self, config, in_features, out_features, has_bias, tensors, key):

        self.config = config
        self.key = key

        self.in_features = in_features
        self.out_features = out_features
        self.qweight = tensors[key + ".qweight"]
        self.qzeros = tensors[key + ".qzeros"]
        self.scales = tensors[key + ".scales"]
        self.g_idx = tensors[key + ".g_idx"].cpu() if key + ".g_idx" in tensors else None
        self.bias = tensors[key + ".bias"] if has_bias else None

        if self.g_idx is not None and (self.g_idx == 0).all():
            self.config.empty_g_idx = True
            self.g_idx = None

        self.device = self.qweight.device
        self.device_index = self.device.index

        self.q4 = cuda_ext.ext_make_q4(self.qweight,
                                       self.qzeros,
                                       self.scales,
                                       self.g_idx,
                                       self.device_index)

        self.height = tensors[key + ".qweight"].shape[0] * 8
        self.width = tensors[key + ".qweight"].shape[1]

        # Infer groupsize from height of qzeros

        self.groupsize = None
        if self.qzeros.shape[0] > 1:
            self.groupsize = (self.qweight.shape[0] * 8) // self.qzeros.shape[0]

        if self.config.groupsize is None:
            self.config.groupsize = self.groupsize

        # Handle act-order matrix

        if self.g_idx is not None:
            if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
            self.config.act_order = True

    def lora_applies(self, lora):

        if lora is None: return False
        return self.key + ".lora_A.weight" in lora.tensors

    def lora_apply(self, lora, x):

        lora_a = lora.tensors[self.key + ".lora_A.weight"]
        lora_b = lora.tensors[self.key + ".lora_B.weight"]

        out = torch.matmul(x, lora_a)
        out = torch.matmul(out, lora_b)
        # out = cuda_ext.ext_half_matmul(x, lora_a.contiguous(), cublas = True)
        # out = cuda_ext.ext_half_matmul(out, lora_b.contiguous(), cublas = True)

        return out

    def get_lora_tensors_or_meta(self, lora):

        if not self.lora_applies(lora):
            return cuda_ext.none_tensor, cuda_ext.none_tensor
        else:
            lora_a = lora.tensors[self.key + ".lora_A.weight"]
            lora_b = lora.tensors[self.key + ".lora_B.weight"]
            return lora_a, lora_b

    def forward(self, x, lora):

        if self.lora_applies(lora):

            lora_a = lora.tensors[self.key + ".lora_A.weight"]
            lora_b = lora.tensors[self.key + ".lora_B.weight"]
            out = cuda_ext.ext_q4_matmul(x, self.q4, self.width, lora_a, lora_b)

        else:

            out = cuda_ext.ext_q4_matmul(x, self.q4, self.width)

        # out = cuda_ext.ext_q4_matmul(x, self.q4, self.width)
        # if self.lora_applies(lora): out += self.lora_apply(lora, x)

        if self.bias is not None: out.add_(self.bias)
        return out


# Llama MLP

class ExLlamaMLP:

    def __init__(self, config, tensors, key):

        self.config = config

        self.gate_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".gate_proj")
        self.up_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".up_proj")
        self.down_proj = Ex4bitLinear(config, self.config.intermediate_size, self.config.hidden_size, False, tensors, key + ".down_proj")

        self.act_fn = nn.SiLU()

    def fused(self, x, buffer, post_attention_layernorm, lora):

        bsz, q_len, _ = x.size()

        gate_a, gate_b = self.gate_proj.get_lora_tensors_or_meta(lora)
        up_a, up_b = self.up_proj.get_lora_tensors_or_meta(lora)
        down_a, down_b = self.down_proj.get_lora_tensors_or_meta(lora)

        temp_size = 0
        if not gate_a.is_meta: temp_size = max(temp_size, bsz * q_len * gate_a.shape[1])
        if not up_a.is_meta: temp_size = max(temp_size, bsz * q_len * up_a.shape[1])
        if not down_a.is_meta: temp_size = max(temp_size, bsz * q_len * down_a.shape[1])

        if temp_size > 0: lora_temp = torch.empty((1, temp_size), dtype = torch.float16, device = x.device)
        else: lora_temp = cuda_ext.none_tensor

        cuda_ext.exllama_ext.q4_mlp(x.view(-1, x.shape[-1]),
                                    post_attention_layernorm.weight,
                                    self.config.rms_norm_eps,
                                    self.gate_proj.q4,
                                    self.up_proj.q4,
                                    self.down_proj.q4,
                                    gate_a, gate_b,
                                    up_a, up_b,
                                    down_a, down_b,
                                    lora_temp)

    def forward(self, x, buffer, lora):

        y = self.gate_proj.forward(x, lora)
        y = self.act_fn(y)
        y *= self.up_proj.forward(x, lora)
        y = self.down_proj.forward(y, lora)

        return y
# RMS layer norm

class ExLlamaRMSNorm:

    def __init__(self, config, tensors, key):

        self.config = config
        self.variance_epsilon = self.config.rms_norm_eps
        self.weight = tensors[key]

    def forward(self, hidden_states, buffer):

        hidden_states = cuda_ext.ext_rms_norm(hidden_states, self.weight, self.variance_epsilon)
        return hidden_states


# Llama attention

class ExLlamaAttention:

    def __init__(self, config, tensors, key, sin, cos, index):

        self.config = config
        self.sin = sin
        self.cos = cos
        self.index = index

        self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj")
        self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_key_value_heads * self.config.head_dim, False, tensors, key + ".k_proj")
        self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_key_value_heads * self.config.head_dim, False, tensors, key + ".v_proj")
        self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj")

    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

        # TODO: This seems inefficient. It should be possible to broadcast in the attention matmul to avoid building
        # temporary K/V tensors like this

        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1: return hidden_states

        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    def fused(self, hidden_states, cache, buffer, input_layernorm, lora):

        bsz, q_len, _ = hidden_states.size()
        past_len = cache.current_seq_len

        # LoRA tensors

        q_a, q_b = self.q_proj.get_lora_tensors_or_meta(lora)
        k_a, k_b = self.k_proj.get_lora_tensors_or_meta(lora)
        v_a, v_b = self.v_proj.get_lora_tensors_or_meta(lora)
        o_a, o_b = self.o_proj.get_lora_tensors_or_meta(lora)

        temp_size = 0
        if not q_a.is_meta: temp_size = max(temp_size, bsz * q_len * q_a.shape[1])
        if not k_a.is_meta: temp_size = max(temp_size, bsz * q_len * k_a.shape[1])
        if not v_a.is_meta: temp_size = max(temp_size, bsz * q_len * v_a.shape[1])
        if not o_a.is_meta: temp_size = max(temp_size, bsz * q_len * o_a.shape[1])

        if temp_size > 0: lora_temp = torch.empty((1, temp_size), dtype = torch.float16, device = hidden_states.device)
        else: lora_temp = cuda_ext.none_tensor

        # Project q, k, v, apply position embeddings to k and v, update cache

        query_states = torch.empty((bsz, q_len, self.config.num_attention_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)
        key_states = torch.empty((bsz, q_len, self.config.num_key_value_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)
        value_states = torch.empty((bsz, q_len, self.config.num_key_value_heads * self.config.head_dim), dtype = torch.float16, device = hidden_states.device)

        cuda_ext.exllama_ext.q4_attn(hidden_states,
                                     input_layernorm.weight,
                                     self.config.rms_norm_eps,
                                     query_states,
                                     key_states,
                                     value_states,
                                     self.q_proj.q4,
                                     self.k_proj.q4,
                                     self.v_proj.q4,
                                     self.sin,
                                     self.cos,
                                     q_len,
                                     past_len,
                                     self.config.num_attention_heads,
                                     self.config.num_key_value_heads,
                                     self.config.head_dim,
                                     cache.key_states[self.index],
                                     cache.value_states[self.index],
                                     cache.max_seq_len,
                                     q_a, q_b,
                                     k_a, k_b,
                                     v_a, v_b,
                                     lora_temp)

        query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim)

        # Get k, v with past

        key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
        value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)

        # Repeat K/V heads if n_kv_heads < n_heads

        query_states.transpose_(1, 2)
        key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
        value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)

        # Attention
        # TODO: Figure out if we can use cublasHgemmStridedBatched() to do this matmul without reshaping. Torch uses
        # gemmStridedBatchedEx() internally, so it should be possible.

        # -- Flash Attention 2.0

        if self.config.use_flash_attn_2 and (past_len == 0 or q_len == 1):

            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
            query_states = query_states.transpose(1, 2)
            attn_output = flash_attn_func(query_states, key_states, value_states, causal = (past_len == 0))

        # -- HF Transformers regular attention, faster on shorter sequences, same VRAM usage

        else:

            key_states.transpose_(2, 3)
            attn_weights = torch.matmul(query_states, key_states)
            attn_weights /= math.sqrt(self.config.head_dim)
            attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
            attn_output = torch.matmul(attn_weights, value_states)
            attn_output = attn_output.transpose(1, 2)

        attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)

        # Output projection

        cuda_ext.exllama_ext.q4_attn_2(hidden_states,
                                       attn_output,
                                       self.o_proj.q4,
                                       o_a, o_b,
                                       lora_temp)

        # return hidden_states

    def forward(self, hidden_states, cache, buffer, lora):

        bsz, q_len, _ = hidden_states.size()
        past_len = cache.current_seq_len

        # Project q, k, v, apply position embeddings to k and v

        query_states = self.q_proj.forward(hidden_states, lora)
        key_states = self.k_proj.forward(hidden_states, lora)

        cuda_ext.exllama_ext.rope_(query_states, self.sin, self.cos, past_len, self.config.num_attention_heads, self.config.head_dim)
        cuda_ext.exllama_ext.rope_(key_states, self.sin, self.cos, past_len, self.config.num_key_value_heads, self.config.head_dim)

        query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.config.num_key_value_heads, self.config.head_dim).transpose(1, 2)
        value_states = self.v_proj.forward(hidden_states, lora).view(bsz, q_len, self.config.num_key_value_heads, self.config.head_dim).transpose(1, 2)

        # Add keys and values to cache

        new_keys = cache.key_states[self.index].narrow(2, past_len, q_len).narrow(0, 0, bsz)
        new_values = cache.value_states[self.index].narrow(2, past_len, q_len).narrow(0, 0, bsz)
        new_keys.copy_(key_states)
        new_values.copy_(value_states)

        # Key/value tensors with past

        key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
        value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)

        # Attention

        # -- Flash Attention 2.0

        if self.config.use_flash_attn_2 and (past_len == 0 or q_len == 1):

            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
            query_states = query_states.transpose(1, 2)
            attn_output = flash_attn_func(query_states, key_states, value_states, causal = (past_len == 0))

        # -- HF Transformers regular attention, faster on shorter sequences, same VRAM usage

        elif self.config.sdp_thd == 0 or q_len < self.config.sdp_thd:

            key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
            value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)

            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
            attn_weights /= math.sqrt(self.config.head_dim)
            if buffer.attn_mask is not None: attn_weights = attn_weights + buffer.attn_mask
            attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16)
            attn_output = torch.matmul(attn_weights, value_states)
            attn_output = attn_output.transpose(1, 2)

        # -- Scaled dot-product attention from PyTorch 2, should be comparable to xformers (?)

        else:

            # Torch's SDP attention has a built-in causal mask feature which we can use only when there is no past, i.e.
            # it can only apply a square attention mask. It saves quite a bit of VRAM but in practice Torch seems to use
            # the same amount of memory at peak anyway.
            #
            # TODO: Apparently flash attention is disabled when supplying an attention mask tensor. Figure out if this
            # is true and maybe drop SDP altogether. If causal masking in flash-attn is updated eventually there should
            # be no need for this anyway.

            key_states = self.repeat_kv(key_states, self.config.num_key_value_groups)
            value_states = self.repeat_kv(value_states, self.config.num_key_value_groups)

            if past_len > 0 or (bsz > 1 and buffer.attn_mask is not None):
                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = buffer.attn_mask, is_causal = False)
            else:
                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = None, is_causal = True)

            attn_output = attn_output.transpose(1, 2)

        # Output projection

        attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
        attn_output = self.o_proj.forward(attn_output, lora)

        return attn_output


def _rows(x):

    xdp = 1
    for y in x.shape[:-1]: xdp *= y
    return xdp


class ExLlamaDecoderLayer:

    def __init__(self, config, tensors, key, index, sin, cos):

        self.config = config
        self.index = index

        self.self_attn = ExLlamaAttention(self.config, tensors, key + ".self_attn", sin, cos, self.index)
        self.mlp = ExLlamaMLP(self.config, tensors, key + ".mlp")

        self.input_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".input_layernorm.weight")
        self.post_attention_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".post_attention_layernorm.weight")

    def forward(self, hidden_states, cache, buffer, lora):

        # Self-attention

        if self.config.fused_attn and _rows(hidden_states) == 1:

            self.self_attn.fused(hidden_states, cache, buffer, self.input_layernorm, lora)

        else:

            residual = hidden_states
            hidden_states = self.input_layernorm.forward(hidden_states, buffer)
            hidden_states = self.self_attn.forward(hidden_states, cache, buffer, lora)
            hidden_states = residual + hidden_states

        # MLP

        if self.config.fused_mlp_thd > 0 and _rows(hidden_states) <= self.config.fused_mlp_thd:

            self.mlp.fused(hidden_states, buffer, self.post_attention_layernorm, lora)

        else:

            residual = hidden_states
            hidden_states = self.post_attention_layernorm.forward(hidden_states, buffer)
            hidden_states = self.mlp.forward(hidden_states, buffer, lora)
            hidden_states = residual + hidden_states

        return hidden_states
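
# For reference, a small standalone sketch (shapes chosen arbitrarily, not used by the model) of the
# expand + reshape trick ExLlamaAttention.repeat_kv uses for grouped-query attention: each of the
# n_kv_heads is repeated n_rep times so the K/V tensors line up with the query heads.

def _repeat_kv_sketch():

    kv = torch.arange(2 * 2 * 3 * 4, dtype = torch.float16).view(2, 2, 3, 4)  # (bsz, n_kv_heads, seq_len, head_dim)
    n_rep = 4
    out = kv[:, :, None, :, :].expand(2, 2, n_rep, 3, 4).reshape(2, 2 * n_rep, 3, 4)
    assert torch.equal(out[:, 0], out[:, 1])  # heads 0..3 are all copies of K/V head 0
    return out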
# Persistent cache for inference. Allocate the whole thing up front.

class ExLlamaCache:

    def __init__(self, model, batch_size = 1, max_seq_len = -1, copy_from = None):

        self.model = model
        self.config = self.model.config
        self.max_seq_len = max_seq_len if max_seq_len != -1 else self.config.max_seq_len
        self.batch_size = batch_size

        self.key_states = []
        self.value_states = []
        self.current_seq_len = 0

        # Preallocate full-length cache

        for i in range(self.config.num_hidden_layers):

            if copy_from is None:

                p_key_states = torch.zeros(self.batch_size, self.config.num_key_value_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])
                p_value_states = torch.zeros(self.batch_size, self.config.num_key_value_heads, self.max_seq_len, self.config.head_dim, dtype = torch.float16, device = self.model.config.device_map.layers[i])

            else:

                p_key_states = copy_from.key_states[i].clone()
                p_value_states = copy_from.value_states[i].clone()

            self.key_states.append(p_key_states)
            self.value_states.append(p_value_states)

    def zero(self):

        for i in range(self.config.num_hidden_layers):
            self.key_states[i].zero_()
            self.value_states[i].zero_()

    def clone(self):

        new = ExLlamaCache(self.model, batch_size = self.batch_size, max_seq_len = self.max_seq_len, copy_from = self)
        return new

    def roll_left(self):

        for i in range(self.config.num_hidden_layers):
            self.key_states[i] = torch.roll(self.key_states[i], shifts = -1, dims = 2)
            self.value_states[i] = torch.roll(self.value_states[i], shifts = -1, dims = 2)

        self.current_seq_len -= 1

    def copy_states(self, target, from_column, from_columns, to_column, to_columns, from_row, from_rows, to_row, to_rows):

        assert from_rows == 1
        assert from_columns == to_columns
        assert to_column + to_columns <= target.max_seq_len
        assert from_column + from_columns <= self.max_seq_len

        for i in range(self.config.num_hidden_layers):

            source_view_k = self.key_states[i].narrow(0, from_row, from_rows).narrow(2, from_column, from_columns)
            source_view_v = self.value_states[i].narrow(0, from_row, from_rows).narrow(2, from_column, from_columns)
            target_view_k = target.key_states[i].narrow(0, to_row, to_rows).narrow(2, to_column, to_columns)
            target_view_v = target.value_states[i].narrow(0, to_row, to_rows).narrow(2, to_column, to_columns)

            if to_rows > 1:

                source_view_k = source_view_k.expand_as(target_view_k)
                source_view_v = source_view_v.expand_as(target_view_v)

            target_view_k.copy_(source_view_k)
            target_view_v.copy_(source_view_v)
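
# A rough sketch (helper name is illustrative, not part of the original API) of one way copy_states() can
# be used: replicate the first prompt_len cached positions of batch row 0 in src into every row of dst,
# e.g. to evaluate a shared prompt once and then let a batch of sequences diverge from it.

def _share_prompt_across_batch(src, dst, prompt_len):

    src.copy_states(dst,
                    0, prompt_len,          # from_column, from_columns
                    0, prompt_len,          # to_column, to_columns
                    0, 1,                   # from_row, from_rows
                    0, dst.batch_size)      # to_row, to_rows
    dst.current_seq_len = prompt_len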
# Device map for the model

class ExLlamaDeviceMap:

    def __init__(self, num_layers):

        self.num_layers = num_layers

        self.embed_tokens = "cpu"  # Embedding table on CPU saves 400 MB on the 30B model with no measurable impact on performance
        self.lm_head = "cuda:0"
        self.norm = "cuda:0"
        self.layers = ["cuda:0"] * self.num_layers

    def get_layers_devs(self):

        return sorted(list(set(self.layers)))

    def get_all_devs(self):

        return sorted(list(set(self.layers + [self.lm_head, self.norm, self.embed_tokens])))

    def map(self, key):

        if key.startswith("lm_head."): return self.lm_head
        if key.startswith("model.embed_tokens."): return self.embed_tokens
        if key.startswith("model.norm."): return self.norm

        if key.startswith("model.layers."):
            num = int(key.split(".")[2])
            return self.layers[num]

        raise ValueError("Unknown key: " + key)


class ExLlamaBuffer:

    config: ExLlamaConfig

    def __init__(self, config):
        self.config = config

    # Attention mask

    attn_mask: torch.Tensor = None

    # Move to device

    def to(self, device):

        new = ExLlamaBuffer(self.config)
        new.attn_mask = None if self.attn_mask is None else _move_tensor(self.attn_mask, device, "attn_mask", self.config)
        return new


def _device_to_int(device):

    return int(device[device.find(":") + 1:])


def _skip_key(key):

    if key.endswith("_proj.bias"): return True
    if key.endswith(".rotary_emb.inv_freq"): return True
    return False


def _move_tensor(tensor, new_device, name, config):

    device = str(tensor.device)
    if device == new_device: return tensor
    if config.gpu_peer_fix:
        if str(device).startswith("cuda:") and str(new_device).startswith("cuda:"):
            tensor = tensor.to("cpu")
    return tensor.to(new_device)


def _layer_dtype_size(key):

    if key.endswith(".weight"): return 2
    if key.endswith(".qweight"): return 4
    if key.endswith(".qzeros"): return 4
    if key.endswith(".scales"): return 2
    if key.endswith(".g_idx"): return 0
    raise ValueError("Unrecognized layer: " + key)


class ExLlama:

    def __init__(self, config):

        self.config = config

        # Copy tuning parameters to C++ extension

        self.config.set_tuning_params()

        # Read tensor list from file(s)

        if isinstance(self.config.model_path, str): model_path = [self.config.model_path]
        else: model_path = self.config.model_path

        # Read tensor list from file(s), and measure layer sizes

        load_keys = {}
        decoder_size = 0
        norm_size = 0
        head_size = 0

        for path in model_path:
            with safe_open(path, framework = "pt", device = "cpu") as f:

                for key in f.keys():

                    if _skip_key(key): continue
                    load_keys[key] = path

                    if key.startswith("model.layers.0."):
                        tensor_slice = f.get_slice(key)
                        shape = tensor_slice.get_shape()
                        decoder_size += math.prod(shape) * _layer_dtype_size(key)
                        del tensor_slice

                    if key.startswith("model.norm."):
                        tensor_slice = f.get_slice(key)
                        shape = tensor_slice.get_shape()
                        norm_size += math.prod(shape) * _layer_dtype_size(key)
                        del tensor_slice

                    if key.startswith("lm_head."):
                        tensor_slice = f.get_slice(key)
                        shape = tensor_slice.get_shape()
                        head_size += math.prod(shape) * _layer_dtype_size(key)
                        del tensor_slice

        # Begin auto mapping if enabled

        if self.config.auto_map is not None:

            self.config.device_map.embed_tokens = "cpu"
            self.config.device_map.layers = ["cuda:0"] + ["?"] * (self.config.num_hidden_layers - 1)

            # Assign layers automatically

            device_usage = 0
            device_index = 0
            layer_index_device = 0
            max_usage = self.config.auto_map[device_index] * (1024 ** 3)

            for layer in range(self.config.num_hidden_layers + 2):

                this_layer_size = decoder_size
                if layer == self.config.num_hidden_layers + 0: this_layer_size = norm_size
                elif layer == self.config.num_hidden_layers + 1: this_layer_size = head_size

                while device_usage + this_layer_size > max_usage:
                    device_index += 1
                    if device_index >= len(self.config.auto_map): raise ValueError("Model too large for device allocation scheme.")
                    device_usage = 0
                    layer_index_device = 0
                    max_usage = self.config.auto_map[device_index] * (1024 ** 3)

                target = f"cuda:{device_index}"
                if layer == self.config.num_hidden_layers + 0: self.config.device_map.norm = target
                elif layer == self.config.num_hidden_layers + 1: self.config.device_map.lm_head = target
                else: self.config.device_map.layers[layer] = target

                device_usage += this_layer_size
                layer_index_device += 1

        # Load up to 1 GB of tensors at a time, closing and reopening the file in between each chunk

        max_dq_buffer_size = 0

        tensors = {}
        st_mem = 0
        MAX_ST_MEM = 1024**3
        f = None
        prev_path = ""

        for key, path in load_keys.items():

            device = self.config.device_map.map(key)

            if f is None or st_mem > MAX_ST_MEM or path != prev_path:
                if f is not None: del f
                f = safe_open(path, framework = "pt", device = "cpu")
                prev_path = path
                st_mem = 0

            tensor = f.get_tensor(key)
            size = tensor.numel() * tensor.element_size()
            st_mem += size

            if key.endswith(".scales"): tensor = tensor.half()
            if key == "lm_head.weight": tensor = tensor.float() if device == "cpu" else tensor.half()
            if key == "model.norm.weight": tensor = tensor.half()
            if key.endswith(".embed_tokens.weight"): tensor = tensor.half()
            if key.endswith(".input_layernorm.weight"): tensor = tensor.half()
            if key.endswith(".post_attention_layernorm.weight"): tensor = tensor.half()

            if device == "cpu": keep_tensor = tensor.clone()
            else: keep_tensor = tensor.to(device)
            del tensor

            if key.endswith(".qweight"): max_dq_buffer_size = max(max_dq_buffer_size, keep_tensor.numel() * 8)

            tensors[key] = keep_tensor

        del f

        # Head

        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias = False, device = "meta")
        self.lm_head.weight = nn.Parameter(tensors["lm_head.weight"])
        # self.lm_head_data = tensors["lm_head.weight"].transpose(0, 1).contiguous()

        # Token embeddings

        self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size, self.config.pad_token_id, device = "meta")
        self.embed_tokens.weight = nn.Parameter(tensors["model.embed_tokens.weight"])

        with torch.no_grad():
            self.embed_tokens.weight[self.config.pad_token_id] = 0

        # Norm

        self.norm = ExLlamaRMSNorm(self.config, tensors, "model.norm.weight")

        # Prepare position embeddings for max seq length

        devs = self.config.device_map.get_layers_devs()

        self.sincos = {}
        for device in devs:

            inv_freq = 1.0 / (self.config.rotary_embedding_base ** (torch.arange(0, self.config.head_dim, 2, device = device).float() / self.config.head_dim))
            t = torch.arange(self.config.max_seq_len, device = device, dtype = torch.float32)
            if self.config.compress_pos_emb != 1.0: t /= self.config.compress_pos_emb

            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim = -1)

            sin = emb.sin()[None, None, :, :].half()
            cos = emb.cos()[None, None, :, :].half()

            self.sincos[device] = (sin, cos)

        # Decoder layers

        modules = []
        device_layer_index = [0] * len(devs)

        for i in range(self.config.num_hidden_layers):

            device = self.config.device_map.layers[i]
            sin, cos = self.sincos[device]
            layer = ExLlamaDecoderLayer(self.config, tensors, f"model.layers.{i}", i, sin, cos)

            modules.append(layer)

        self.layers = modules

        # Prepare CUDA buffers

        self.buffers = []
        for dev in self.config.device_map.get_layers_devs():

            device_buffers = {}
            self.buffers.append(device_buffers)

            temp_state = torch.zeros((config.max_input_len, config.intermediate_size), dtype = torch.float16, device = dev)
            temp_mlp = torch.zeros((config.fused_mlp_thd * 2, config.intermediate_size), dtype = torch.float16, device = dev)
            temp_zeros_float = torch.zeros((1, 65536), dtype = torch.float32, device = dev)
            temp_dq = torch.zeros((1, max_dq_buffer_size), dtype = torch.float16, device = dev)

            device_buffers["temp_state"] = temp_state
            device_buffers["temp_mlp"] = temp_mlp
            device_buffers["temp_zeros_float"] = temp_zeros_float
            device_buffers["temp_dq"] = temp_dq

            cuda_ext.exllama_ext.prepare_buffers(torch.device(dev),
                                                 temp_state,
                                                 temp_mlp,
                                                 temp_zeros_float,
                                                 temp_dq)

        # Clear the cache

        torch.cuda.empty_cache()

    def forward(self,
                input_ids,
                cache,
                last_id_only = True,
                preprocess_only = False,
                lora = None,
                output_device = None,
                input_mask = None):

        q_len = input_ids.shape[-1]
        remaining_q_len = q_len
        bsz = input_ids.shape[0]

        assert input_mask is None or (input_mask.shape[-1] >= input_ids.shape[-1] and input_mask.shape[-2] == input_ids.shape[-2])

        # The buffers can only fit max_input_len tokens, so with larger batch sizes we reduce our work size correspondingly

        effective_max_input_len = self.config.max_input_len // bsz

        # Split sequence

        result = None
        chunk_begin = 0
        while chunk_begin < q_len:

            # Limit chunk_size to max_input_len

            chunk_size = min(remaining_q_len, effective_max_input_len)

            # Limit chunk_size to keep the size of the attention operation <= max_attention_size, unless using flash-attn

            if not self.config.use_flash_attn_2 or chunk_begin > 0:

                past_len = cache.current_seq_len
                attn_size = (past_len + remaining_q_len) * remaining_q_len
                max_a = self.config.max_attention_size
                if attn_size > max_a:
                    # Solve (past_len + cs) * cs = max_a for cs to get the largest chunk that still fits
                    cs = (math.sqrt(past_len ** 2 + 4 * max_a) - past_len) / 2
                    chunk_size = min(chunk_size, math.floor(cs))

            # Process chunk

            chunk_end = min(chunk_begin + chunk_size, q_len)

            _last_id_only = last_id_only
            _preprocess_only = preprocess_only or (chunk_end < q_len and last_id_only)

            r = self._forward(input_ids[:, chunk_begin : chunk_end],
                              cache,
                              _last_id_only,
                              _preprocess_only,
                              lora,
                              output_device,
                              input_mask)

            if not _preprocess_only:
                result = r if result is None else torch.cat((result, r), dim = 1)

            chunk_begin = chunk_end
            remaining_q_len -= chunk_size

        return result

    def _forward(self,
                 input_ids,
                 cache,
                 last_id_only = True,
                 preprocess_only = False,
                 lora = None,
                 output_device = None,
                 input_mask = None):

        # if torch.is_grad_enabled():
        #     raise ValueError("Forward pass called with gradients enabled. Back propagation is not supported yet.")

        with torch.no_grad():

            batch_size, seq_len = input_ids.shape
            past_len = cache.current_seq_len
            if output_device is None: output_device = input_ids.device

            buffer = ExLlamaBuffer(self.config)

            # Build attention mask on first device, copy to others if necessary

            devs = self.config.device_map.get_layers_devs()

            # if not self.config.use_flash_attn_2:

            if seq_len > 1 or input_mask is not None:

                attn_mask = torch.zeros(batch_size, 1, seq_len, past_len + seq_len, dtype = torch.float16, device = devs[0])
                attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), -65504.))
                attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu

                if input_mask is not None:

                    input_mask = input_mask[:, :past_len + seq_len]
                    input_mask = _move_tensor(input_mask, devs[0], "input_mask", self.config)
                    input_mask = torch.where(input_mask, 0, -65504.).half()
                    input_mask = input_mask.unsqueeze(1).unsqueeze(2)
                    attn_mask = torch.minimum(attn_mask, input_mask)

            else:

                attn_mask = None
                # attn_mask = torch.zeros(batch_size, 1, seq_len, seq_len + past_len, dtype = torch.float16, device = devs[0])

            buffer.attn_mask = attn_mask

            # else:
            #
            #     buffer.attn_mask = None

            # Embeddings
            # TODO: Allow passing input embeddings instead of IDs

            input_ids = _move_tensor(input_ids, self.config.device_map.embed_tokens, "input_ids", self.config)
            hidden_states = self.embed_tokens(input_ids)

            # Split buffers to devices

            buffers = {devs[0]: buffer}
            for device in devs[1:]:
                buffers[device] = buffer.to(device)

            # Decoder layers

            for i, decoder_layer in enumerate(self.layers):

                device = self.config.device_map.layers[i]
                hidden_states = _move_tensor(hidden_states, device, "hidden_states", self.config)
                hidden_states = decoder_layer.forward(hidden_states, cache, buffers[device], lora)

            cache.current_seq_len += seq_len

            # Early exit when we don't need logits

            if preprocess_only: return None

            # Norm

            hidden_states = _move_tensor(hidden_states, self.config.device_map.norm, "hidden_states", self.config)
            hidden_states = self.norm.forward(hidden_states, buffer)

            # Head

            if last_id_only: hidden_states = hidden_states[:, -1:, :].contiguous()
            if self.config.device_map.lm_head == "cpu": hidden_states = hidden_states.float()

            hidden_states = _move_tensor(hidden_states, self.config.device_map.lm_head, "hidden_states", self.config)
            logits = self.lm_head(hidden_states)
            # logits = cuda_ext.matmul_half(hidden_states, self.lm_head_data, cublas = False)

            logits = logits.float()
            logits = _move_tensor(logits, output_device, "logits", self.config)
            return logits

    # Free unmanaged resources allocated by the C++ extension. Call this before dereferencing the ExLlama object,
    # e.g. if you intend to create a new instance to load another model, but don't call it in a destructor that wraps
    # the object, since it relies on CUDA function calls and the CUDA context is one of the first things to go when
    # a PyTorch application terminates, before other managed objects are destroyed.

    def free_unmanaged(self):

        cuda_ext.exllama_ext.cleanup()
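
# A rough usage sketch (paths are placeholders; token IDs would normally come from the project's tokenizer,
# which is not part of this module): load a quantized model, evaluate a prompt, then generate a few tokens
# greedily from the cached state.

if __name__ == "__main__":

    config = ExLlamaConfig("/path/to/config.json")
    config.model_path = "/path/to/model.safetensors"
    # config.set_auto_map("16,24")  # optionally split layers across two GPUs by VRAM budget in GB

    model = ExLlama(config)
    cache = ExLlamaCache(model)

    input_ids = torch.tensor([[config.bos_token_id]], dtype = torch.long)  # stand-in for a tokenized prompt

    logits = model.forward(input_ids, cache)  # (1, 1, vocab_size): logits for the last position only
    for _ in range(16):
        next_id = torch.argmax(logits[:, -1, :], dim = -1, keepdim = True)
        logits = model.forward(next_id, cache)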