Guanzheng committed
Commit e8f1578
1 Parent(s): 481b512

Update modeling_llama.py

Files changed (1)
  1. modeling_llama.py +55 -21
modeling_llama.py CHANGED
@@ -32,19 +32,52 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_clex import CLEXLlamaConfig
 from .clex_layer import LlamaCLEXScalingRotaryEmbedding
-
-
 from einops import rearrange
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
-# from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+import importlib.metadata
+import importlib.util
 
 
 logger = logging.get_logger(__name__)
 
+def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
+    # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
+    package_exists = importlib.util.find_spec(pkg_name) is not None
+    package_version = "N/A"
+    if package_exists:
+        try:
+            package_version = importlib.metadata.version(pkg_name)
+            package_exists = True
+        except importlib.metadata.PackageNotFoundError:
+            package_exists = False
+        logger.info(f"Detected {pkg_name} version {package_version}")
+    if return_version:
+        return package_exists, package_version
+    else:
+        return package_exists
+
+def is_flash_attn_available():
+    if not _is_package_available("torch", return_version=True):
+        return False
+
+    # Let's add an extra check to see if cuda is available
+    import torch
+
+    return _is_package_available("flash_attn") and torch.cuda.is_available()
+
+if is_flash_attn_available():
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
+    # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+    from flash_attn.bert_padding import unpad_input, pad_input
+
+
+
+
 _CONFIG_FOR_DOC = "CLEXLlamaConfig"
 
 
+
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -137,13 +170,13 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+def apply_rotary_pos_emb(q, k, cos, sin, q_len, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q * cos) + (rotate_half(q) * sin)
+    q_embed = (q * cos[:, :, -q_len:, :]) + (rotate_half(q) * sin[:, :, -q_len:, :])
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
@@ -247,19 +280,17 @@ class LlamaAttention(nn.Module):
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-        # [bsz, nh, t, hd]
 
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
 
         if pack_cos_sin is not None:
-            cos, sin = pack_cos_sin
+            cos, sin = pack_cos_sin.to(query_states.device)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        key_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, q_len, key_position_ids)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
@@ -267,12 +298,13 @@ class LlamaAttention(nn.Module):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
+        use_flashatn = self.config.use_flashattn and is_flash_attn_available()
 
         if self.log_scale:
             log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
                 torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
             query_states = query_states * log_n
-        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] or not self.config.use_flashattn:
+        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] or use_flashatn:
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -308,6 +340,7 @@ class LlamaAttention(nn.Module):
             attn_weights = None
 
             return attn_output, attn_weights, past_key_value
+        # use flash attention
         elif past_key_value is not None:
             output = flash_attn_with_kvcache(
                 query_states.transpose(1, 2),
@@ -614,13 +647,15 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
-            )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
+        # if attention_mask is None:
+        #     attention_mask = torch.ones(
+        #         (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+        #     )
+        # attention_mask = self._prepare_decoder_attention_mask(
+        #     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        # )
+        attention_mask = None
+
 
         hidden_states = inputs_embeds
 
@@ -802,7 +837,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
-
         if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
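
For readers checking out this commit: the flash-attention imports are now gated behind is_flash_attn_available(), so the model falls back to the plain PyTorch attention path when flash_attn or a CUDA device is missing. Below is a minimal standalone sketch of that detection pattern, kept outside the diff; the names _package_available and flash_attn_usable are illustrative stand-ins, not identifiers from this repository.

import importlib.metadata
import importlib.util


def _package_available(pkg_name: str) -> bool:
    # find_spec alone can be fooled by a same-named local directory, so also
    # require that an installed distribution version can be read.
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        importlib.metadata.version(pkg_name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


def flash_attn_usable() -> bool:
    # Mirrors the spirit of is_flash_attn_available(): torch installed,
    # flash_attn installed, and at least one CUDA device visible.
    if not _package_available("torch"):
        return False
    import torch  # imported lazily so the check itself never hard-fails
    return _package_available("flash_attn") and torch.cuda.is_available()


if __name__ == "__main__":
    print("flash-attn path enabled:", flash_attn_usable())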
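
The change to apply_rotary_pos_emb is easy to misread in the diff: keys are now rotated with positions 0..kv_seq_len-1 (covering any cached tokens), while the queries reuse only the last q_len of those positions, which is what the new key_position_ids plus the cos[:, :, -q_len:, :] slice provide during incremental decoding. Here is a shape-only sketch of that slicing, assuming nothing beyond torch; the cos/sin tensors are random stand-ins for the rotary caches, not the model's actual values.

import torch

def rotate_half(x):
    # same rotation as in the model: swap and negate the two halves of the head dim
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, n_heads, head_dim = 2, 4, 8
past_len, q_len = 5, 1                       # e.g. one new token during decoding
kv_seq_len = past_len + q_len

# stand-ins for the cos/sin caches returned by the rotary embedding module
cos = torch.randn(1, 1, kv_seq_len, head_dim)
sin = torch.randn(1, 1, kv_seq_len, head_dim)

key_position_ids = torch.arange(kv_seq_len).unsqueeze(0)           # [1, kv_seq_len]
cos = cos.squeeze(1).squeeze(0)[key_position_ids].unsqueeze(1)     # [1, 1, kv_seq_len, head_dim]
sin = sin.squeeze(1).squeeze(0)[key_position_ids].unsqueeze(1)

q = torch.randn(bsz, n_heads, q_len, head_dim)        # only the new tokens
k = torch.randn(bsz, n_heads, kv_seq_len, head_dim)   # cached + new tokens

# queries use only the trailing q_len positions; keys use all kv_seq_len positions
q_embed = (q * cos[:, :, -q_len:, :]) + (rotate_half(q) * sin[:, :, -q_len:, :])
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)  # torch.Size([2, 4, 1, 8]) torch.Size([2, 4, 6, 8])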