del gradient_checkpointing_enable()

#60
Files changed (1) hide show
  1. modeling_chatglm.py +0 -3
modeling_chatglm.py CHANGED
@@ -797,9 +797,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
797
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
798
  return position_ids
799
 
800
- def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
801
- if not self.supports_gradient_checkpointing:
802
- raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
803
 
804
 
805
  class Embedding(torch.nn.Module):
 
797
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
798
  return position_ids
799
 
 
 
 
800
 
801
 
802
  class Embedding(torch.nn.Module):