BK-Lee committed
Commit f019fdd
1 Parent(s): 908a9d5
app.py CHANGED
@@ -18,8 +18,8 @@ from transformers import TextIteratorStreamer
 from torchvision.transforms.functional import pil_to_tensor
 
 # flash attention
-# import subprocess
-# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 # accel
 accel = Accelerator()
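Note: passing env={...} to subprocess.run replaces the child process environment entirely rather than extending it, so variables such as PATH and CUDA_HOME are not inherited by the pip call. A minimal sketch of the same runtime install that merges the flag into the existing environment (the placement and check=False are assumptions, not part of this diff):

    import os
    import subprocess

    # Merge the flash-attn build flag into the current environment instead of
    # replacing it, so the pip subprocess still sees PATH, CUDA_HOME, etc.
    env = {**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env=env,
        shell=True,
        check=False,  # do not abort app startup if the install fails
    )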
trol/arch_internlm2/modeling_internlm2.py CHANGED
@@ -867,13 +867,15 @@ class InternLM2Model(InternLM2PreTrainedModel):
         self.norm = InternLM2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)
 
-        self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1)]*self.config.num_hidden_layers)
-        self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
-
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
 
+    def initialize_trol_gating(self):
+        self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1).cuda()]*self.config.num_hidden_layers)
+        self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
+
+
     def get_input_embeddings(self):
         return self.tok_embeddings
 
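For reference, the gate 0.5*F.tanh(W·x)+0.5 maps each hidden state to a scalar in (0, 1), which can then weight how much of a second feature stream is mixed into a layer. A self-contained sketch of that computation (tensor names, toy sizes, and the final blend are illustrative; the diff does not show how the gate is consumed in the forward pass). The sketch builds one independent gate per layer with a list comprehension, whereas [nn.Linear(...)]*num_hidden_layers in the diff repeats references to a single shared Linear module:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden_size, num_layers = 32, 4  # toy sizes

    # One scalar gate per layer.
    trol_gating = nn.ModuleList([nn.Linear(hidden_size, 1) for _ in range(num_layers)])
    trol_function = lambda x, idx: 0.5 * F.tanh(trol_gating[idx](x)) + 0.5

    hidden = torch.randn(2, 7, hidden_size)       # (batch, seq, hidden)
    extra = torch.randn(2, 7, hidden_size)        # hypothetical second stream
    gate = trol_function(hidden, idx=0)           # (batch, seq, 1), values in (0, 1)
    blended = gate * hidden + (1 - gate) * extra  # one possible gated mix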
trol/arch_phi3/modeling_phi3.py CHANGED
@@ -1031,13 +1031,15 @@ class Phi3Model(Phi3PreTrainedModel):
         self._attn_implementation = "flash_attention_2"
         self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1)]*self.config.num_hidden_layers)
-        self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
-
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
 
+    def initialize_trol_gating(self):
+        self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1).cuda()]*self.config.num_hidden_layers)
+        self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
+
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
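Same refactor as in the InternLM2 backbone: because the gates are no longer created in __init__, they are not part of the base model's expected state dict and only exist after initialize_trol_gating() is called, which also places them directly on the GPU. A sketch of the resulting call order, assuming a CUDA device and a locally available trol_gating.pt (both hypothetical here; load_trol.py below fetches the file from the Hub instead):

    import torch
    from trol.arch_phi3.modeling_phi3 import Phi3Model

    model = Phi3Model.from_pretrained("some/base-checkpoint").cuda()  # hypothetical repo id
    model.initialize_trol_gating()              # gates now exist, already on CUDA
    state = torch.load("trol_gating.pt")        # hypothetical local path
    model.trol_gating.load_state_dict(state)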
trol/load_trol.py CHANGED
@@ -81,11 +81,17 @@ def load_trol(link):
     # setting config
     setting_trol_config(trol, tok_trol, image_special_token)
 
-
     # trol gating load
     from huggingface_hub import hf_hub_download
     try:
+        trol.model.initialize_trol_gating()
         trol.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
     except:
+        trol.language_model.model.initialize_trol_gating()
         trol.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
+
+    # X -> float16 conversion
+    for param in trol.parameters():
+        if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
+            param.data = param.data.to(torch.float16)
     return trol, tok_trol
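About the new conversion loop: it string-matches the parameter dtype, and since 'float16' is a substring of 'bfloat16', the check covers float32, float16 (a no-op cast), and bfloat16 parameters while leaving any non-floating-point parameters untouched. A sketch of the same behavior written against the dtypes directly (the function name is illustrative):

    import torch

    def cast_params_to_fp16(model: torch.nn.Module) -> None:
        # Mirrors the loop above: cast float32 / float16 / bfloat16 parameters
        # to float16 and skip everything else (e.g. integer quantized weights).
        for param in model.parameters():
            if param.dtype in (torch.float32, torch.float16, torch.bfloat16):
                param.data = param.data.to(torch.float16)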