phoebeklett committed
Commit 8123135
1 parent: 43430cc

Delete modeling_llama.py

Files changed (1)
  1. modeling_llama.py +0 -1164
modeling_llama.py DELETED
@@ -1,1164 +0,0 @@
1
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
2
- #
3
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
- # and OPT implementations in this library. It has been modified from its
5
- # original forms to accommodate minor architectural differences compared
6
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
- #
8
- # Licensed under the Apache License, Version 2.0 (the "License");
9
- # you may not use this file except in compliance with the License.
10
- # You may obtain a copy of the License at
11
- #
12
- # http://www.apache.org/licenses/LICENSE-2.0
13
- #
14
- # Unless required by applicable law or agreed to in writing, software
15
- # distributed under the License is distributed on an "AS IS" BASIS,
16
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
- # See the License for the specific language governing permissions and
18
- # limitations under the License.
19
- """ PyTorch LLaMA model."""
20
- import math
21
- from typing import List, Optional, Tuple, Union
22
- import faiss
23
- from einops import rearrange
24
-
25
- import torch
26
- import torch.utils.checkpoint
27
- from torch import nn
28
- import torch.nn.functional as F
29
- from torch.linalg import vector_norm
30
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
- from transformers.activations import ACT2FN
32
- from transformers.modeling_outputs import (
33
- BaseModelOutputWithPast,
34
- CausalLMOutputWithPast,
35
- SequenceClassifierOutputWithPast,
36
- )
37
- from transformers.modeling_utils import PreTrainedModel
38
- from transformers.utils import (
39
- add_start_docstrings,
40
- add_start_docstrings_to_model_forward,
41
- logging,
42
- replace_return_docstrings,
43
- )
44
- from .configuration_llama import ExtendedLlamaConfig
45
-
46
-
47
- logger = logging.get_logger(__name__)
48
-
49
- _CONFIG_FOR_DOC = "ExtendedLlamaConfig"
50
-
51
-
52
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
53
- def _make_causal_mask(
54
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
55
- ):
56
- """
57
- Make causal mask used for bi-directional self-attention.
58
- """
59
- bsz, tgt_len = input_ids_shape
60
- mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
61
- mask_cond = torch.arange(mask.size(-1), device=device)
62
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
63
- mask = mask.to(dtype)
64
-
65
- if past_key_values_length > 0:
66
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
67
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
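For reference, a minimal sketch (not part of the original file) of what this additive causal mask produces for a three-token query with one cached past token: allowed positions are 0 and blocked positions are the dtype minimum.

```python
# Minimal sketch of the additive causal mask built above (illustrative only).
import torch

dtype = torch.float32
tgt_len, past_len = 3, 1

mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
cond = torch.arange(tgt_len)
mask.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 0)          # zero out j <= i
mask = torch.cat([torch.zeros(tgt_len, past_len), mask], dim=-1)  # cached tokens stay visible

print(mask)  # 0.0 where attention is allowed, dtype-min where it is blocked
```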
68
-
69
-
70
- # Copied from transformers.models.bart.modeling_bart._expand_mask
71
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
72
- """
73
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
74
- """
75
- bsz, src_len = mask.size()
76
- tgt_len = tgt_len if tgt_len is not None else src_len
77
-
78
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
79
-
80
- inverted_mask = 1.0 - expanded_mask
81
-
82
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
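A small illustration (assuming a batch of one with the last of three tokens being padding) of how a `[bsz, seq_len]` padding mask becomes an additive `[bsz, 1, tgt_len, src_len]` bias:

```python
# Illustrative only: expanding a padding mask into an additive attention bias.
import torch

dtype = torch.float32
mask = torch.tensor([[1, 1, 0]])                      # last token is padding
expanded = mask[:, None, None, :].expand(1, 1, 3, 3).to(dtype)
inverted = 1.0 - expanded
bias = inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)
print(bias[0, 0])  # column for the padded token is dtype-min, others are 0
```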
83
-
84
-
85
- class LlamaRMSNorm(nn.Module):
86
- def __init__(self, hidden_size, eps=1e-6):
87
- """
88
- LlamaRMSNorm is equivalent to T5LayerNorm
89
- """
90
- super().__init__()
91
- self.weight = nn.Parameter(torch.ones(hidden_size))
92
- self.variance_epsilon = eps
93
-
94
- def forward(self, hidden_states):
95
- input_dtype = hidden_states.dtype
96
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
97
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
98
-
99
- return (self.weight * hidden_states).to(input_dtype)
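In equation form this is `y = w * x / sqrt(mean(x^2) + eps)`; a quick sketch (illustrative, not from the file) of the same computation without the module wrapper:

```python
# RMSNorm in a few lines, matching the forward above (illustrative).
import torch

x = torch.randn(2, 5, 8)          # (batch, seq, hidden)
w = torch.ones(8)
eps = 1e-6

rms = x.to(torch.float32).pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
y = (w * (x * rms)).to(x.dtype)
print(y.shape)                    # torch.Size([2, 5, 8]); per-token RMS is normalized to ~1
```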
100
-
101
-
102
- class LlamaRotaryEmbedding(torch.nn.Module):
103
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
104
- super().__init__()
105
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
106
- self.register_buffer("inv_freq", inv_freq)
107
-
108
- # Build here to make `torch.jit.trace` work.
109
- self.max_seq_len_cached = max_position_embeddings
110
- t = torch.arange(
111
- self.max_seq_len_cached,
112
- device=self.inv_freq.device,
113
- dtype=self.inv_freq.dtype,
114
- )
115
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
116
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
117
- emb = torch.cat((freqs, freqs), dim=-1)
118
- dtype = torch.get_default_dtype()
119
- self.register_buffer(
120
- "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
121
- )
122
- self.register_buffer(
123
- "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
124
- )
125
-
126
- def forward(self, x, seq_len=None):
127
- # x: [bs, num_attention_heads, seq_len, head_size]
128
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
129
- if seq_len > self.max_seq_len_cached:
130
- self.max_seq_len_cached = seq_len
131
- t = torch.arange(
132
- self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
133
- )
134
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
135
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
136
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
137
- self.register_buffer(
138
- "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False
139
- )
140
- self.register_buffer(
141
- "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False
142
- )
143
- return (
144
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
145
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
146
- )
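A compact sketch of what the cached `cos`/`sin` tables contain (assumes `head_dim=8` and four positions; illustrative only):

```python
# How the cos/sin caches are laid out for rotary embeddings (illustrative).
import torch

dim, base, seq_len = 8, 10000, 4
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))   # (dim/2,)
t = torch.arange(seq_len, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)                         # (seq_len, dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                              # (seq_len, dim)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]              # (1, 1, seq_len, dim)
print(cos.shape, sin.shape)
```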
147
-
148
- class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
149
- def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
150
- super().__init__()
151
- self.ntk = ntk
152
- self.base = base
153
- self.dim = dim
154
- self.max_position_embeddings = max_position_embeddings
155
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
156
- self.register_buffer("inv_freq", inv_freq)
157
-
158
- # Build here to make `torch.jit.trace` work.
159
- self.max_seq_len_cached = max_position_embeddings
160
- t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
161
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
162
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
163
- emb = torch.cat((freqs, freqs), dim=-1)
164
- dtype = torch.get_default_dtype()
165
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
166
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
167
-
168
- def forward(self, x, seq_len=None):
169
- # x: [bs, num_attention_heads, seq_len, head_size]
170
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
171
- if seq_len > self.max_seq_len_cached:
172
- self.max_seq_len_cached = seq_len
173
- if self.ntk:
174
- base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))
175
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
176
- self.register_buffer("inv_freq", inv_freq)
177
- t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
178
- if not self.ntk:
179
- t *= self.max_position_embeddings / seq_len
180
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
181
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
182
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
183
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
184
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
185
- return (
186
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
187
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
188
- )
189
-
190
- class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):
191
- def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
192
- super().__init__()
193
- self.scale = scale
194
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
195
- self.register_buffer("inv_freq", inv_freq)
196
-
197
- # Build here to make `torch.jit.trace` work.
198
- self.max_seq_len_cached = max_position_embeddings
199
- t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
200
- t /= self.scale
201
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
202
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
203
- emb = torch.cat((freqs, freqs), dim=-1)
204
- dtype = torch.get_default_dtype()
205
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
206
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
207
-
208
- def forward(self, x, seq_len=None):
209
- # x: [bs, num_attention_heads, seq_len, head_size]
210
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
211
- if seq_len > self.max_seq_len_cached:
212
- self.max_seq_len_cached = seq_len
213
- t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
214
- t /= self.scale
215
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
216
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
217
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
218
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
219
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
220
- return (
221
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
222
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
223
- )
224
-
225
- class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):
226
- def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
227
- super().__init__()
228
- base = base * alpha ** (dim / (dim-2))
229
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
230
- self.register_buffer("inv_freq", inv_freq)
231
-
232
- # Build here to make `torch.jit.trace` work.
233
- self.max_seq_len_cached = max_position_embeddings
234
- t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
235
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
236
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
237
- emb = torch.cat((freqs, freqs), dim=-1)
238
- dtype = torch.get_default_dtype()
239
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
240
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
241
-
242
- def forward(self, x, seq_len=None):
243
- # x: [bs, num_attention_heads, seq_len, head_size]
244
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
245
- if seq_len > self.max_seq_len_cached:
246
- self.max_seq_len_cached = seq_len
247
- t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
248
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
249
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
250
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
251
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
252
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
253
- return (
254
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
255
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
256
- )
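The three scaled variants above differ only in how they stretch the position grid or the frequency base; a side-by-side sketch of the frequency computations (parameter values are arbitrary examples, not taken from the file):

```python
# Comparing the scaling strategies on the inverse-frequency spectrum (illustrative).
import torch

dim, base = 128, 10000.0
max_pos, seq_len = 2048, 8192

# Plain RoPE frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# Linear (position interpolation): positions are divided by a scale factor
scale = seq_len / max_pos
t_linear = torch.arange(seq_len).float() / scale

# NTK-aware: the base is enlarged instead, with alpha chosen up front
alpha = 4.0
ntk_base = base * alpha ** (dim / (dim - 2))
inv_freq_ntk = 1.0 / (ntk_base ** (torch.arange(0, dim, 2).float() / dim))

# Dynamic NTK: the base grows with the observed sequence length
ntk = 2.0
dyn_base = base * ((ntk * seq_len / max_pos) - (ntk - 1)) ** (dim / (dim - 2))
inv_freq_dyn = 1.0 / (dyn_base ** (torch.arange(0, dim, 2).float() / dim))

print(inv_freq[-1], inv_freq_ntk[-1], inv_freq_dyn[-1])  # lowest frequencies get stretched
```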
257
-
258
- def rotate_half(x):
259
- """Rotates half the hidden dims of the input."""
260
- x1 = x[..., : x.shape[-1] // 2]
261
- x2 = x[..., x.shape[-1] // 2 :]
262
- return torch.cat((-x2, x1), dim=-1)
263
-
264
-
265
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
266
- # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
267
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
268
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
269
-
270
- s_q = q.size(-2) #Since we apply rotary pos emb after reading from cache, queries may be shorter
271
- _q_position_ids = position_ids[:, -s_q:]
272
- _q_cos = cos[_q_position_ids].unsqueeze(1)
273
- _q_sin = sin[_q_position_ids].unsqueeze(1)
274
- q_embed = (q * _q_cos) + (rotate_half(q) * _q_sin)
275
-
276
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
277
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
278
- k_embed = (k * cos) + (rotate_half(k) * sin)
279
- return q_embed, k_embed
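Because rotation is applied after the cache is read, queries can be shorter than keys and only take the positions at the tail of `position_ids`. A toy check of that behavior (illustrative; `rotate_half` is restated so the snippet stands alone):

```python
# Illustrative: one new query token attending over four cached keys.
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

dim, s_k = 4, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.einsum("i,j->ij", torch.arange(s_k).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()                    # (s_k, dim)

q = torch.randn(1, 1, 1, dim)                      # only the newest token
k = torch.randn(1, 1, s_k, dim)                    # full key history
position_ids = torch.arange(s_k)[None]             # positions 0..3

q_cos, q_sin = cos[position_ids[:, -1:]][:, None], sin[position_ids[:, -1:]][:, None]
k_cos, k_sin = cos[position_ids][:, None], sin[position_ids][:, None]
q_rot = q * q_cos + rotate_half(q) * q_sin         # rotated with its true position (3)
k_rot = k * k_cos + rotate_half(k) * k_sin
print(q_rot.shape, k_rot.shape)
```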
280
-
281
-
282
- class LlamaMLP(nn.Module):
283
- def __init__(
284
- self,
285
- hidden_size: int,
286
- intermediate_size: int,
287
- hidden_act: str,
288
- ):
289
- super().__init__()
290
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
291
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
292
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
293
- self.act_fn = ACT2FN[hidden_act]
294
-
295
- def forward(self, x):
296
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
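This is the SwiGLU feed-forward, `down(silu(gate(x)) * up(x))`. A bare sketch with toy sizes (illustrative only):

```python
# SwiGLU MLP in three linear layers (illustrative, toy sizes).
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, inter = 16, 44
gate = nn.Linear(hidden, inter, bias=False)
up = nn.Linear(hidden, inter, bias=False)
down = nn.Linear(inter, hidden, bias=False)

x = torch.randn(2, 3, hidden)
y = down(F.silu(gate(x)) * up(x))   # same computation as LlamaMLP.forward with hidden_act="silu"
print(y.shape)                      # torch.Size([2, 3, 16])
```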
297
-
298
-
299
- class ExtendedLlamaAttention(nn.Module):
300
- """Multi-headed attention from 'Attention Is All You Need' paper"""
301
-
302
- def __init__(self, config: ExtendedLlamaConfig):
303
- super().__init__()
304
- self.config = config
305
- self.hidden_size = config.hidden_size
306
- self.num_heads = config.num_attention_heads
307
- self.head_dim = self.hidden_size // self.num_heads
308
- self.max_position_embeddings = config.max_position_embeddings
309
- self.num_hidden_layers = config.num_hidden_layers
310
-
311
- if (self.head_dim * self.num_heads) != self.hidden_size:
312
- raise ValueError(
313
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
314
- f" and `num_heads`: {self.num_heads})."
315
- )
316
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
317
- self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
318
- self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
319
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
320
- self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
321
-
322
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
323
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
324
-
325
- def forward(
326
- self,
327
- hidden_states: torch.Tensor,
328
- attention_mask: Optional[torch.Tensor] = None,
329
- position_ids: Optional[torch.LongTensor] = None,
330
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
331
- output_attentions: bool = False,
332
- use_cache: bool = False,
333
-
334
- long_range_past_key_value=None,
335
- faiss_indexes=None,
336
- mask_by_sim=False,
337
- sim_threshold=0.0,
338
- topk=None,
339
- current_layer=None,
340
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
341
- bsz, q_len, _ = hidden_states.size()
342
-
343
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
344
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
345
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
346
- if past_key_value is not None:
347
- # reuse k, v, self_attention
348
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
349
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
350
-
351
- past_key_value = (key_states, value_states) if use_cache else None
352
-
353
- kv_seq_len = key_states.shape[-2]
354
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
355
-
356
- query_states_no_rotary = query_states.clone() # use queries wo positional info for memory retrieval
357
-
358
- query_states, key_states = apply_rotary_pos_emb(
359
- query_states, key_states, cos, sin, position_ids
360
- )
361
- # [bsz, nh, t, hd]
362
- bsz, nh, s_q, hd = query_states.shape
363
- s_k = key_states.size(-2)
364
-
365
- attn_weights = torch.matmul(
366
- query_states, key_states.transpose(2, 3)
367
- ) / math.sqrt(self.head_dim)
368
-
369
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
370
- raise ValueError(
371
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
372
- f" {attn_weights.size()}"
373
- )
374
-
375
- if long_range_past_key_value is not None or faiss_indexes is not None:
376
- if long_range_past_key_value is not None: #manual memories
377
-
378
- k_cache, v_cache = long_range_past_key_value
379
- s_cache = k_cache.size(-2)
380
-
381
- k_cache = k_cache.to(key_states.device)
382
- v_cache = v_cache.to(key_states.device)
383
-
384
- q_n = query_states_no_rotary/vector_norm(query_states_no_rotary, ord=2, dim=-1, keepdim=True)
385
- k_n = k_cache/vector_norm(k_cache, ord=2, dim=-1, keepdim=True)
386
-
387
- sim = q_n.matmul(k_n.transpose(2,3))
388
- if s_cache<topk:
389
- topk = s_cache #number of tokens in cache < topk
390
- val, idx = torch.topk(sim, k=topk, dim=-1)
391
-
392
- reshaped_idx = idx.reshape(bsz, nh, s_q * topk)
393
-
394
- cos_m, sin_m = self.rotary_emb(value_states, seq_len=self.max_position_embeddings) # use max pos emb for memories
395
- cos_m = cos_m[:,:,-1,...].repeat(1,1,s_q * topk,1)
396
- sin_m = sin_m[:,:,-1,...].repeat(1,1,s_q * topk,1)
397
-
398
- selected_k = k_cache.gather(dim=-2, index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, hd))
399
- _, selected_k = apply_rotary_pos_emb(
400
- torch.ones(selected_k.shape, device=key_states.device), selected_k, cos_m, sin_m, position_ids=torch.arange(s_q * topk, device=key_states.device).unsqueeze(0)
401
- ) # Apply rotary pos emb to selected memory keys, use dummy input for queries
402
-
403
- selected_v = v_cache.gather(dim=-2, index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, hd))
404
-
405
- sim_mask = rearrange(~ (val > sim_threshold).bool(), 'b h s i -> b h (s i)').unsqueeze(-2).expand(-1, -1, s_q, -1)
406
-
407
- elif faiss_indexes is not None: #faiss indexes
408
-
409
- kn_index, kv_index = faiss_indexes
410
- q_n = query_states_no_rotary/vector_norm(query_states_no_rotary, ord=2, dim=-1, keepdim=True)
411
-
412
- one_hot_encodings = F.one_hot(torch.arange(0, nh*self.num_hidden_layers, device=query_states.device))*10
413
- q_n = torch.concat([rearrange(q_n, 'b h s d -> b (h s) d', h=nh), one_hot_encodings[nh*current_layer:nh*(current_layer+1)].unsqueeze(0).repeat_interleave(repeats=query_states.size(-2), dim=-2)], dim=-1).squeeze()
414
-
415
- D, I = kn_index.search(q_n.to('cpu').numpy(), k=topk)
416
-
417
- selected_k=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,:hd], '(h s) d -> 1 h s d', h=nh).to(query_states.device)
418
- cos_m, sin_m = self.rotary_emb(value_states, seq_len=self.max_position_embeddings) # use max pos emb for memories
419
- cos_m = cos_m[:,:,-1,...].repeat(1,1,s_q * topk,1)
420
- sin_m = sin_m[:,:,-1,...].repeat(1,1,s_q * topk,1)
421
-
422
- _, selected_k = apply_rotary_pos_emb(
423
- torch.ones(selected_k.shape, device=key_states.device), selected_k, cos_m, sin_m, position_ids=torch.arange(s_q * topk, device=key_states.device).unsqueeze(0)
424
- ) # Apply rotary pos emb to selected memory keys, use dummy input for queries
425
-
426
- selected_v=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,hd:], '(h s) d -> 1 h s d', h=nh).to(query_states.device)
427
-
428
- attn_weight_cache = torch.matmul(query_states, selected_k.transpose(2, 3)) / math.sqrt(self.head_dim)
429
- if mask_by_sim:
430
- attn_weight_cache = attn_weight_cache.masked_fill(sim_mask, torch.finfo(selected_k.dtype).min)
431
-
432
- attn_weights = torch.cat([attn_weight_cache, attn_weights], dim=-1)
433
- value_states = torch.cat([selected_v, value_states], dim=-2)
434
-
435
- min_val = torch.finfo(attn_weights.dtype).min
436
- def _create_active_externalism_mask(k, s_q, device, min_val=min_val):
437
- mask = torch.ones(s_q, s_q * k, device=device, dtype=torch.float32)
438
- for i in range(s_q):
439
- mask[i, i * k : (i + 1) * k] = 0
440
-
441
- filled = mask.masked_fill(mask.bool(), min_val)
442
- return filled
443
-
444
- if attention_mask is not None:
445
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
446
- raise ValueError(
447
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
448
- )
449
- if long_range_past_key_value is not None:
450
- memory_mask = _create_active_externalism_mask(k=topk,s_q=s_q, device=attn_weights.device)
451
- attention_mask = torch.cat([memory_mask, attention_mask[:,:,:,-s_k:].squeeze(dim=[0,1])], dim=1).unsqueeze(dim=0).unsqueeze(dim=1)
452
-
453
- attn_weights = attn_weights + attention_mask
454
- attn_weights = torch.max(
455
- attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
456
- )
457
-
458
- # upcast attention to fp32
459
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
460
- attn_output = torch.matmul(attn_weights, value_states)
461
-
462
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
463
- raise ValueError(
464
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
465
- f" {attn_output.size()}"
466
- )
467
-
468
- attn_output = attn_output.transpose(1, 2)
469
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
470
-
471
- attn_output = self.o_proj(attn_output)
472
-
473
- if not output_attentions:
474
- attn_weights = None
475
-
476
- if long_range_past_key_value is None and faiss_indexes is None:
477
- reshaped_idx=None
478
-
479
- return attn_output, attn_weights, past_key_value, reshaped_idx
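The core of the "extended" attention is the retrieval step: queries without positional rotation are cosine-compared against cached keys, the top-k per query are gathered, and their scores are prepended to the local attention scores. A stripped-down sketch of that retrieval (illustrative tensor sizes; rotary terms and masking omitted):

```python
# Stripped-down memory retrieval in the spirit of ExtendedLlamaAttention (illustrative).
import math
import torch
from torch.linalg import vector_norm

bsz, nh, s_q, s_cache, hd, topk = 1, 2, 3, 10, 4, 2

q = torch.randn(bsz, nh, s_q, hd)          # queries without positional info
k_cache = torch.randn(bsz, nh, s_cache, hd)
v_cache = torch.randn(bsz, nh, s_cache, hd)

q_n = q / vector_norm(q, ord=2, dim=-1, keepdim=True)
k_n = k_cache / vector_norm(k_cache, ord=2, dim=-1, keepdim=True)
sim = q_n.matmul(k_n.transpose(2, 3))                 # cosine similarity (b, h, s_q, s_cache)
val, idx = torch.topk(sim, k=topk, dim=-1)            # best memories per query

reshaped_idx = idx.reshape(bsz, nh, s_q * topk)
selected_k = k_cache.gather(-2, reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, hd))
selected_v = v_cache.gather(-2, reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, hd))

scores = q.matmul(selected_k.transpose(2, 3)) / math.sqrt(hd)   # (b, h, s_q, s_q*topk)
# A block-diagonal mask (see _create_active_externalism_mask above) then restricts
# each query to the topk entries it actually retrieved before the softmax.
print(scores.shape, selected_v.shape)
```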
480
-
481
- class ExtendedLlamaDecoderLayer(nn.Module):
482
- def __init__(self, config: ExtendedLlamaConfig):
483
- super().__init__()
484
- self.hidden_size = config.hidden_size
485
- self.self_attn = ExtendedLlamaAttention(config=config)
486
- self.mlp = LlamaMLP(
487
- hidden_size=self.hidden_size,
488
- intermediate_size=config.intermediate_size,
489
- hidden_act=config.hidden_act,
490
- )
491
- self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
492
- self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
493
-
494
- def forward(
495
- self,
496
- hidden_states: torch.Tensor,
497
- attention_mask: Optional[torch.Tensor] = None,
498
- position_ids: Optional[torch.LongTensor] = None,
499
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
500
- output_attentions: Optional[bool] = False,
501
- use_cache: Optional[bool] = False,
502
-
503
- long_range_past_key_value:Optional[Tuple[torch.Tensor]] = None,
504
- faiss_indexes:Tuple=None,
505
- mask_by_sim:bool=False,
506
- sim_threshold:float=None,
507
- topk:int=None,
508
- current_layer=None
509
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
510
- """
511
- Args:
512
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
513
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
514
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
515
- output_attentions (`bool`, *optional*):
516
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
517
- returned tensors for more detail.
518
- use_cache (`bool`, *optional*):
519
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
520
- (see `past_key_values`).
521
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
522
- """
523
-
524
- residual = hidden_states
525
-
526
- hidden_states = self.input_layernorm(hidden_states)
527
-
528
- # Self Attention
529
- hidden_states, self_attn_weights, present_key_value, selected_idx = self.self_attn(
530
- hidden_states=hidden_states,
531
- attention_mask=attention_mask,
532
- position_ids=position_ids,
533
- past_key_value=past_key_value,
534
- output_attentions=output_attentions,
535
- use_cache=use_cache,
536
-
537
- long_range_past_key_value=long_range_past_key_value,
538
- faiss_indexes=faiss_indexes,
539
- mask_by_sim=mask_by_sim,
540
- sim_threshold=sim_threshold,
541
- topk=topk,
542
- current_layer=current_layer,
543
- )
544
- hidden_states = residual + hidden_states
545
-
546
- # Fully Connected
547
- residual = hidden_states
548
- hidden_states = self.post_attention_layernorm(hidden_states)
549
- hidden_states = self.mlp(hidden_states)
550
- hidden_states = residual + hidden_states
551
-
552
- outputs = (hidden_states,)
553
-
554
- if output_attentions:
555
- outputs += (self_attn_weights,)
556
-
557
- if use_cache:
558
- outputs += (present_key_value,)
559
-
560
- if output_attentions:
561
- outputs += (selected_idx,)
562
-
563
- return outputs
564
-
565
-
566
- LLAMA_START_DOCSTRING = r"""
567
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
568
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
569
- etc.)
570
-
571
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
572
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
573
- and behavior.
574
-
575
- Parameters:
576
- config ([`LlamaConfig`]):
577
- Model configuration class with all the parameters of the model. Initializing with a config file does not
578
- load the weights associated with the model, only the configuration. Check out the
579
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
580
- """
581
-
582
-
583
- @add_start_docstrings(
584
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
585
- LLAMA_START_DOCSTRING,
586
- )
587
- class LlamaPreTrainedModel(PreTrainedModel):
588
- config_class = ExtendedLlamaConfig
589
- base_model_prefix = "model"
590
- supports_gradient_checkpointing = True
591
- _no_split_modules = ["LlamaDecoderLayer"]
592
- _skip_keys_device_placement = "past_key_values"
593
-
594
- def _init_weights(self, module):
595
- std = self.config.initializer_range
596
- if isinstance(module, nn.Linear):
597
- module.weight.data.normal_(mean=0.0, std=std)
598
- if module.bias is not None:
599
- module.bias.data.zero_()
600
- elif isinstance(module, nn.Embedding):
601
- module.weight.data.normal_(mean=0.0, std=std)
602
- if module.padding_idx is not None:
603
- module.weight.data[module.padding_idx].zero_()
604
-
605
- def _set_gradient_checkpointing(self, module, value=False):
606
- if isinstance(module, ExtendedLlamaModel):
607
- module.gradient_checkpointing = value
608
-
609
-
610
- LLAMA_INPUTS_DOCSTRING = r"""
611
- Args:
612
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
613
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
614
- it.
615
-
616
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
617
- [`PreTrainedTokenizer.__call__`] for details.
618
-
619
- [What are input IDs?](../glossary#input-ids)
620
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
621
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
622
-
623
- - 1 for tokens that are **not masked**,
624
- - 0 for tokens that are **masked**.
625
-
626
- [What are attention masks?](../glossary#attention-mask)
627
-
628
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
629
- [`PreTrainedTokenizer.__call__`] for details.
630
-
631
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
632
- `past_key_values`).
633
-
634
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
635
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
636
- information on the default strategy.
637
-
638
- - 1 indicates the head is **not masked**,
639
- - 0 indicates the head is **masked**.
640
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
641
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
642
- config.n_positions - 1]`.
643
-
644
- [What are position IDs?](../glossary#position-ids)
645
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
646
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
647
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
648
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
649
-
650
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
651
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
652
-
653
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
654
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
655
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
656
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
657
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
658
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
659
- model's internal embedding lookup matrix.
660
- use_cache (`bool`, *optional*):
661
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
662
- `past_key_values`).
663
- output_attentions (`bool`, *optional*):
664
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
665
- tensors for more detail.
666
- output_hidden_states (`bool`, *optional*):
667
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
668
- more detail.
669
- return_dict (`bool`, *optional*):
670
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
671
- """
672
-
673
-
674
- @add_start_docstrings(
675
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
676
- LLAMA_START_DOCSTRING,
677
- )
678
- class ExtendedLlamaModel(LlamaPreTrainedModel):
679
- """
680
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
681
-
682
- Args:
683
- config: LlamaConfig
684
- """
685
-
686
- def __init__(self, config: ExtendedLlamaConfig):
687
- super().__init__(config)
688
- self.padding_idx = config.pad_token_id
689
- self.vocab_size = config.vocab_size
690
-
691
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
692
- self.layers = nn.ModuleList([ExtendedLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
693
- self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
694
-
695
- self.gradient_checkpointing = False
696
- # Initialize weights and apply final processing
697
-
698
- self.mask_by_sim = config.mask_by_sim
699
- self.sim_threshold = config.sim_threshold
700
- self.topk = config.topk
701
- self.use_active_externalism = config.use_active_externalism
702
- self.use_active_externalism_by_layer = config.use_active_externalism_by_layer
703
-
704
- self.post_init()
705
-
706
- def get_input_embeddings(self):
707
- return self.embed_tokens
708
-
709
- def set_input_embeddings(self, value):
710
- self.embed_tokens = value
711
-
712
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
713
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
714
- # create causal mask
715
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
716
- combined_attention_mask = None
717
- if input_shape[-1] > 1:
718
- combined_attention_mask = _make_causal_mask(
719
- input_shape,
720
- inputs_embeds.dtype,
721
- device=inputs_embeds.device,
722
- past_key_values_length=past_key_values_length,
723
- )
724
-
725
- if attention_mask is not None:
726
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
727
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
728
- inputs_embeds.device
729
- )
730
- combined_attention_mask = (
731
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
732
- )
733
-
734
- return combined_attention_mask
735
-
736
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
737
- def forward(
738
- self,
739
- input_ids: torch.LongTensor = None,
740
- attention_mask: Optional[torch.Tensor] = None,
741
- position_ids: Optional[torch.LongTensor] = None,
742
- past_key_values: Optional[List[torch.FloatTensor]] = None,
743
- inputs_embeds: Optional[torch.FloatTensor] = None,
744
- use_cache: Optional[bool] = None,
745
- output_attentions: Optional[bool] = None,
746
- output_hidden_states: Optional[bool] = None,
747
- return_dict: Optional[bool] = None,
748
-
749
- use_active_externalism:Optional[bool]=None,
750
- long_range_past_key_values:Optional[List[Tuple[torch.FloatTensor]]] = None,
751
- faiss_indexes:Tuple=None,
752
- topk:int=None,
753
- ) -> Union[Tuple, BaseModelOutputWithPast]:
754
- output_attentions = (
755
- output_attentions
756
- if output_attentions is not None
757
- else self.config.output_attentions
758
- )
759
- output_hidden_states = (
760
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
761
- )
762
- use_cache = use_cache if use_cache is not None else self.config.use_cache
763
-
764
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
765
- use_active_externalism = (use_active_externalism if use_active_externalism is not None else self.use_active_externalism)
766
- topk = (topk if topk is not None else self.topk)
767
-
768
- # retrieve input_ids and inputs_embeds
769
- if input_ids is not None and inputs_embeds is not None:
770
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
771
- elif input_ids is not None:
772
- batch_size, seq_length = input_ids.shape
773
- elif inputs_embeds is not None:
774
- batch_size, seq_length, _ = inputs_embeds.shape
775
- else:
776
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
777
-
778
- seq_length_with_past = seq_length
779
- past_key_values_length = 0
780
-
781
- if past_key_values is not None:
782
- past_key_values_length = past_key_values[0][0].shape[2]
783
- seq_length_with_past = seq_length_with_past + past_key_values_length
784
-
785
- if position_ids is None:
786
- device = input_ids.device if input_ids is not None else inputs_embeds.device
787
- position_ids = torch.arange(
788
- seq_length_with_past, dtype=torch.long, device=device #range of position ids is total seq length since we apply rotary pos emb after reading from cache
789
- )
790
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length_with_past)
791
- else:
792
- position_ids = position_ids.view(-1, seq_length).long()
793
-
794
- if inputs_embeds is None:
795
- inputs_embeds = self.embed_tokens(input_ids)
796
- # embed positions
797
- if attention_mask is None:
798
- attention_mask = torch.ones(
799
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
800
- )
801
- attention_mask = self._prepare_decoder_attention_mask(
802
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
803
- )
804
-
805
- hidden_states = inputs_embeds
806
-
807
- if self.gradient_checkpointing and self.training:
808
- if use_cache:
809
- logger.warning_once(
810
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
811
- )
812
- use_cache = False
813
-
814
- # decoder layers
815
- all_hidden_states = () if output_hidden_states else None
816
- all_self_attns = () if output_attentions else None
817
- next_decoder_cache = () if use_cache else None
818
- all_idx = () if output_attentions else None
819
-
820
- for idx, decoder_layer in enumerate(self.layers):
821
- if output_hidden_states:
822
- all_hidden_states += (hidden_states,)
823
-
824
- past_key_value = past_key_values[idx] if past_key_values is not None else None
825
-
826
- long_range_past_key_value = (long_range_past_key_values[idx]
827
- if (long_range_past_key_values is not None and self.use_active_externalism_by_layer[idx] and use_active_externalism is True) else None)
828
-
829
- if long_range_past_key_value is not None and faiss_indexes is not None:
830
- raise NotImplementedError(
831
- 'Using faiss and passing key value pairs manually are mutually exclusive right now.')
832
-
833
- if self.gradient_checkpointing and self.training:
834
-
835
- def create_custom_forward(module):
836
- def custom_forward(*inputs):
837
- # None for past_key_value
838
- return module(*inputs, output_attentions, None)
839
-
840
- return custom_forward
841
-
842
- layer_outputs = torch.utils.checkpoint.checkpoint(
843
- create_custom_forward(decoder_layer),
844
- hidden_states,
845
- attention_mask,
846
- position_ids,
847
- None,
848
- )
849
- else:
850
- layer_outputs = decoder_layer(
851
- hidden_states,
852
- attention_mask=attention_mask,
853
- position_ids=position_ids,
854
- past_key_value=past_key_value,
855
- output_attentions=output_attentions,
856
- use_cache=use_cache,
857
-
858
- topk=topk,
859
- long_range_past_key_value=long_range_past_key_value,
860
- faiss_indexes=faiss_indexes,
861
- mask_by_sim=self.mask_by_sim,
862
- sim_threshold=self.sim_threshold,
863
- current_layer=idx,
864
- )
865
-
866
- hidden_states = layer_outputs[0]
867
-
868
- if use_cache:
869
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
870
-
871
- if output_attentions:
872
- all_self_attns += (layer_outputs[1],)
873
-
874
- all_idx += (layer_outputs[3],) # record which memories were retrieved
875
- hidden_states = self.norm(hidden_states)
876
-
877
- # add hidden states from the last decoder layer
878
- if output_hidden_states:
879
- all_hidden_states += (hidden_states,)
880
-
881
- next_cache = next_decoder_cache if use_cache else None
882
- if not return_dict:
883
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
884
- return BaseModelOutputWithPast(
885
- last_hidden_state=hidden_states,
886
- past_key_values=next_cache,
887
- hidden_states=all_hidden_states,
888
- attentions=(all_self_attns, all_idx)
889
- )
890
-
891
-
892
- class ExtendedLlamaForCausalLM(LlamaPreTrainedModel):
893
- _tied_weights_keys = ["lm_head.weight"]
894
-
895
- def __init__(self, config, external_memories=None, **kwargs):
896
- super().__init__(config)
897
- self.model = ExtendedLlamaModel(config)
898
-
899
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
900
-
901
- self.use_active_externalism = config.use_active_externalism
902
- self.memory_type = config.memory_type
903
- self.memory_device = config.memory_device
904
- self._memories = None
905
- if external_memories is not None:
906
- self._memories = external_memories
907
- self.memories = None
908
-
909
- # Initialize weights and apply final processing
910
- self.post_init()
911
-
912
- def get_input_embeddings(self):
913
- return self.model.embed_tokens
914
-
915
- def set_input_embeddings(self, value):
916
- self.model.embed_tokens = value
917
-
918
- def get_output_embeddings(self):
919
- return self.lm_head
920
-
921
- def set_output_embeddings(self, new_embeddings):
922
- self.lm_head = new_embeddings
923
-
924
- def set_decoder(self, decoder):
925
- self.model = decoder
926
-
927
- def get_decoder(self):
928
- return self.model
929
-
930
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
931
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
932
- def forward(
933
- self,
934
- input_ids: torch.LongTensor = None,
935
- attention_mask: Optional[torch.Tensor] = None,
936
- position_ids: Optional[torch.LongTensor] = None,
937
- past_key_values: Optional[List[torch.FloatTensor]] = None,
938
- inputs_embeds: Optional[torch.FloatTensor] = None,
939
- labels: Optional[torch.LongTensor] = None,
940
- use_cache: Optional[bool] = None,
941
- output_attentions: Optional[bool] = None,
942
- output_hidden_states: Optional[bool] = None,
943
- return_dict: Optional[bool] = None,
944
-
945
- use_active_externalism: Optional[bool]=None,
946
- topk:int=None
947
- ) -> Union[Tuple, CausalLMOutputWithPast]:
948
- r"""
949
- Args:
950
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
951
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
952
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
953
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
954
-
955
- Returns:
956
-
957
- Example:
958
-
959
- ```python
960
- >>> from transformers import AutoTokenizer, LlamaForCausalLM
961
-
962
- >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
963
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
964
-
965
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
966
- >>> inputs = tokenizer(prompt, return_tensors="pt")
967
-
968
- >>> # Generate
969
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
970
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
971
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
972
- ```"""
973
-
974
- if self._memories is not None and self.memories is None: #init memories once on first call
975
- self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)
976
-
977
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
978
- output_hidden_states = (
979
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
980
- )
981
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
982
-
983
- use_active_externalism = (use_active_externalism
984
- if use_active_externalism is not None else self.use_active_externalism
985
- )
986
- topk = topk if topk is not None else None
987
-
988
- long_range_past_key_values = None
989
- faiss_indexes = None
990
- if hasattr(self, "memories") and isinstance(self.memories, list):
991
- long_range_past_key_values = self.memories
992
- faiss_indexes = None
993
- elif hasattr(self, "memories"):
994
- long_range_past_key_values = None
995
- faiss_indexes = self.memories
996
-
997
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
998
- outputs = self.model(
999
- input_ids=input_ids,
1000
- attention_mask=attention_mask,
1001
- position_ids=position_ids,
1002
- past_key_values=past_key_values,
1003
- inputs_embeds=inputs_embeds,
1004
- use_cache=use_cache,
1005
- output_attentions=output_attentions,
1006
- output_hidden_states=output_hidden_states,
1007
- return_dict=return_dict,
1008
-
1009
- long_range_past_key_values=long_range_past_key_values,
1010
- faiss_indexes=faiss_indexes,
1011
- use_active_externalism=use_active_externalism,
1012
- topk=topk
1013
- )
1014
-
1015
- hidden_states = outputs[0]
1016
- logits = self.lm_head(hidden_states)
1017
-
1018
- loss = None
1019
- if labels is not None:
1020
- # Shift so that tokens < n predict n
1021
- shift_logits = logits[..., :-1, :].contiguous()
1022
- shift_labels = labels[..., 1:].contiguous()
1023
- # Flatten the tokens
1024
- loss_fct = CrossEntropyLoss()
1025
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1026
- shift_labels = shift_labels.view(-1)
1027
- # Enable model parallelism
1028
- shift_labels = shift_labels.to(shift_logits.device)
1029
- loss = loss_fct(shift_logits, shift_labels)
1030
-
1031
- if not return_dict:
1032
- output = (logits,) + outputs[1:]
1033
- return (loss,) + output if loss is not None else output
1034
-
1035
- return CausalLMOutputWithPast(
1036
- loss=loss,
1037
- logits=logits,
1038
- past_key_values=outputs.past_key_values,
1039
- hidden_states=outputs.hidden_states,
1040
- attentions=outputs.attentions,
1041
- )
1042
-
1043
- def generate_cache(self,
1044
- input_ids:torch.LongTensor,
1045
- stride:int=512,
1046
- max_len:int=2048,
1047
- cache_type:str='manual'):
1048
- if cache_type not in ['manual', 'faiss']:
1049
- raise NotImplementedError(f"Cache type {cache_type} not implemented.")
1050
-
1051
- prev_end_loc=0
1052
- long_range_past_key_values = None
1053
- faiss_indexes= None
1054
- for b_idx in range(0, input_ids.size(-1), stride): #generate kv-pairs using stride
1055
- end_loc = min(b_idx + max_len, input_ids.size(-1))
1056
- trg_len = end_loc - prev_end_loc
1057
- subseq = input_ids[:, b_idx:end_loc].to(self.model.device)
1058
- with torch.no_grad():
1059
- outputs = self.model(subseq, use_cache=True, use_active_externalism=False)
1060
- to_cache = [(
1061
- kv[0][:,:,-trg_len:],
1062
- kv[1][:,:,-trg_len:])
1063
- for kv in outputs.past_key_values
1064
- ]
1065
- long_range_past_key_values, faiss_indexes = self.cache(to_cache, cache_type, long_range_past_key_values=long_range_past_key_values, faiss_indexes=faiss_indexes)
1066
-
1067
- prev_end_loc = end_loc
1068
- if end_loc == input_ids.size(-1):
1069
- break
1070
- if long_range_past_key_values is not None:
1071
- return long_range_past_key_values
1072
- else:
1073
- return faiss_indexes
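The cache is built by sliding a window over the memory tokens; only the last `trg_len` positions of each window are kept so nothing is cached twice. A quick sketch of the schedule this loop produces (pure Python, illustrative):

```python
# Window schedule used by generate_cache (illustrative).
def cache_schedule(n_tokens, stride=512, max_len=2048):
    prev_end = 0
    for b_idx in range(0, n_tokens, stride):
        end = min(b_idx + max_len, n_tokens)
        trg_len = end - prev_end          # only these positions are newly cached
        yield b_idx, end, trg_len
        prev_end = end
        if end == n_tokens:
            break

print(list(cache_schedule(5000)))
# [(0, 2048, 2048), (512, 2560, 512), (1024, 3072, 512), ...]
```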
1074
-
1075
- def cache(self,
1076
- to_cache:List,
1077
- cache_type:str='manual',
1078
- long_range_past_key_values:List=None,
1079
- faiss_indexes:faiss.IndexFlatIP=None,
1080
- max_length_cache=100000,
1081
- verbose=False):
1082
- if long_range_past_key_values is not None and faiss_indexes is not None:
1083
- raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")
1084
-
1085
- if cache_type=='faiss': #add one-hot encoding to match layer, head indices
1086
- one_hot_encodings = F.one_hot(torch.arange(0, self.config.n_heads*self.config.num_hidden_layers))*10
1087
- if faiss_indexes is None:
1088
- faiss_indexes = (faiss.IndexFlatIP(to_cache[0][0].size(-1)+one_hot_encodings.size(-1)), faiss.IndexFlatIP(to_cache[0][1].size(-1)*2))
1089
- kn_index, kv_index = faiss_indexes
1090
- for b_idx, (k, v) in enumerate(to_cache):
1091
- k_n = (k/vector_norm(k, ord=2, dim=-1, keepdim=True)).to('cpu')
1092
- k_n = torch.concat([rearrange(k_n, 'b h s d -> b (h s) d', h=self.config.n_heads), one_hot_encodings[self.config.n_heads*b_idx:self.config.n_heads*(b_idx+1)].unsqueeze(0).repeat_interleave(repeats=k.size(-2), dim=-2)], dim=-1)
1093
- kn_index.add(k_n.squeeze().numpy())
1094
-
1095
- k= rearrange(k, 'b h s d -> b (h s) d', h=self.config.n_heads)
1096
- v= rearrange(v, 'b h s d -> b (h s) d', h=self.config.n_heads)
1097
- kv_index.add(torch.concat([v.squeeze(), k.squeeze()], dim=1).to('cpu').numpy())
1098
- else:
1099
- if long_range_past_key_values is None:
1100
- long_range_past_key_values = [(k.to(self.memory_device),v.to(self.memory_device)) for k,v in to_cache]
1101
- else:
1102
- long_range_past_key_values = [
1103
- (
1104
- torch.concat([kv[0], to_cache[ind][0].to(self.memory_device)], dim=2),
1105
- torch.concat([kv[1], to_cache[ind][1].to(self.memory_device)], dim=2)
1106
- )
1107
- for ind, kv in enumerate(long_range_past_key_values)
1108
- ]
1109
- if long_range_past_key_values is not None: #set a limit on manual memory length
1110
- if long_range_past_key_values[0][0].size(-2) > max_length_cache:
1111
- long_range_past_key_values = [
1112
- (
1113
- kv[0][:, :, -max_length_cache:],
1114
- kv[1][:, :, -max_length_cache:]
1115
- )
1116
- for kv in long_range_past_key_values]
1117
- if verbose:
1118
- if cache_type == 'faiss':
1119
- print(f"{kn_index.ntotal} keys in faiss index")
1120
- else:
1121
- print(f"{long_range_past_key_values[0][0].size(-2)} cached kvs")
1122
-
1123
- return long_range_past_key_values, (kn_index, kv_index) if cache_type == 'faiss' else None
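For the faiss path, keys are L2-normalized and stored in an inner-product index (together with a one-hot layer/head tag), so the inner product acts as cosine similarity and retrieval stays within the right layer and head. A minimal sketch of that pattern, assuming the `faiss` package is installed (illustrative sizes, no layer/head tagging):

```python
# Inner-product index over normalized keys, as in the faiss cache path (illustrative).
import faiss
import numpy as np

n_keys, head_dim = 100, 8
keys = np.random.randn(n_keys, head_dim).astype("float32")
keys /= np.linalg.norm(keys, axis=-1, keepdims=True)   # cosine similarity via inner product

index = faiss.IndexFlatIP(head_dim)
index.add(keys)

query = keys[:1]                    # a normalized query vector
scores, ids = index.search(query, 4)
print(ids[0], scores[0])            # the query itself comes back with score ~1.0
```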
1124
-
1125
- def prepare_inputs_for_generation(
1126
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1127
- ):
1128
- if past_key_values:
1129
- input_ids = input_ids[:, -1:]
1130
-
1131
- position_ids = kwargs.get("position_ids", None)
1132
- if attention_mask is not None and position_ids is None:
1133
- # create position_ids on the fly for batch generation
1134
- position_ids = attention_mask.long().cumsum(-1) - 1
1135
- position_ids.masked_fill_(attention_mask == 0, 1)
1136
- if past_key_values:
1137
- position_ids = position_ids[:, -1].unsqueeze(-1)
1138
-
1139
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1140
- if inputs_embeds is not None and past_key_values is None:
1141
- model_inputs = {"inputs_embeds": inputs_embeds}
1142
- else:
1143
- model_inputs = {"input_ids": input_ids}
1144
-
1145
- model_inputs.update(
1146
- {
1147
- "position_ids": position_ids,
1148
- "past_key_values": past_key_values,
1149
- "use_cache": kwargs.get("use_cache"),
1150
- "attention_mask": attention_mask,
1151
- 'use_active_externalism': kwargs.get('use_active_externalism'), #add a few more kwargs for active externalism
1152
- 'topk': kwargs.get('topk', None),
1153
- }
1154
- )
1155
- return model_inputs
1156
-
1157
- @staticmethod
1158
- def _reorder_cache(past_key_values, beam_idx):
1159
- reordered_past = ()
1160
- for layer_past in past_key_values:
1161
- reordered_past += (
1162
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1163
- )
1164
- return reordered_past
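For context, the deleted module was typically used roughly as sketched below: memory text is tokenized once, passed in as `external_memories`, turned into per-layer key/value caches on the first forward call, and consulted during generation. The class names come from this file, but the checkpoint paths, prompt, and keyword values are placeholders, so treat this as a hypothetical sketch rather than a confirmed recipe.

```python
# Hypothetical usage sketch for the deleted module (paths and values are placeholders,
# mirroring the docstring example above; not a confirmed, tested recipe).
from transformers import AutoTokenizer

from configuration_llama import ExtendedLlamaConfig   # assumes the files are available locally
from modeling_llama import ExtendedLlamaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("PATH_TO_CONVERTED_TOKENIZER")
memories = tokenizer("Facts the model should be able to look up later.", return_tensors="pt").input_ids

config = ExtendedLlamaConfig.from_pretrained("PATH_TO_CONVERTED_WEIGHTS")
model = ExtendedLlamaForCausalLM.from_pretrained(
    "PATH_TO_CONVERTED_WEIGHTS", config=config, external_memories=memories
)

prompt = tokenizer("Hey, what do you remember?", return_tensors="pt")
generate_ids = model.generate(prompt.input_ids, max_length=30, topk=2)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
```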