Guanzheng committed on
Commit
1f5d87f
1 Parent(s): 22df4da

Update modeling_llama.py

Browse files
Files changed (1) hide show
  1. modeling_llama.py +8 -8
modeling_llama.py CHANGED
@@ -646,14 +646,14 @@ class LlamaModel(LlamaPreTrainedModel):
646
  if inputs_embeds is None:
647
  inputs_embeds = self.embed_tokens(input_ids)
648
  # embed positions
649
- # if attention_mask is None:
650
- # attention_mask = torch.ones(
651
- # (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
652
- # )
653
- # attention_mask = self._prepare_decoder_attention_mask(
654
- # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
655
- # )
656
- attention_mask = None
657
 
658
 
659
  hidden_states = inputs_embeds
 
646
  if inputs_embeds is None:
647
  inputs_embeds = self.embed_tokens(input_ids)
648
  # embed positions
649
+ if attention_mask is None:
650
+ attention_mask = torch.ones(
651
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
652
+ )
653
+ attention_mask = self._prepare_decoder_attention_mask(
654
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
655
+ )
656
+ # attention_mask = None
657
 
658
 
659
  hidden_states = inputs_embeds