BK-Lee committed
Commit 908a9d5
1 Parent(s): 9b5bdb0
app.py CHANGED
@@ -1,5 +1,5 @@
  # A100 Zero GPU
- # import spaces
+ import spaces
 
  # TroL Package
  import torch
@@ -33,10 +33,10 @@ question="What is the troll doing? Provide the detail in the image and imagine w
  model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')
 
  # loading model
- # model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
+ model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
 
  # loading model
- # model_7, tokenizer_7 = load_trol(link='TroL-7B')
+ model_7, tokenizer_7 = load_trol(link='TroL-7B')
 
  def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
 
@@ -55,7 +55,7 @@ def threading_function(inputs, image_token_number, streamer, device, model, toke
  generation_kwargs.update({'use_cache': True})
  return model.generate(**generation_kwargs)
 
- # @spaces.GPU
+ @spaces.GPU
  def bot_streaming(message, history, link, temperature, new_max_token, top_p):
 
  # model selection
@@ -70,9 +70,9 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
  tokenizer = tokenizer_7
 
  # cpu -> gpu
- # for param in model.parameters():
- #     if not param.is_cuda:
- #         param.data = param.to(accel.device)
+ for param in model.parameters():
+     if not param.is_cuda:
+         param.data = param.to(accel.device)
 
  # prompt type -> input prompt
  image_token_number = None
@@ -131,11 +131,11 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
  buffer = ""
  for character in response:
      buffer += character
-     time.sleep(0.015)
+     time.sleep(0.012)
      yield buffer
 
  demo = gr.ChatInterface(fn=bot_streaming,
-                         additional_inputs = [gr.Radio(["1.8B"], label="Size", info="Select one model size", value="1.8B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
+                         additional_inputs = [gr.Radio(["1.8B", "3.8B", "7B"], label="Size", info="Select one model size", value="7B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
                          additional_inputs_accordion="Generation Hyperparameters",
                          theme=gr.themes.Soft(),
                          title="TroL",
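
Note on the app.py changes above: they move the demo from commented-out CPU-only paths to Hugging Face Spaces ZeroGPU. The `spaces` import plus the `@spaces.GPU` decorator on `bot_streaming` attach a GPU only while the handler runs, and the parameter loop moves any weights still on CPU to `accel.device`. A minimal sketch of the same pattern, not the full app; the toy model, the echoed response, and the `Accelerator` usage are stand-ins for what app.py actually builds:

import spaces                    # available on ZeroGPU-backed Hugging Face Spaces
import gradio as gr
import torch
from accelerate import Accelerator

accel = Accelerator()            # assumption: app.py exposes an `accel` object like this
model = torch.nn.Linear(8, 8)    # stand-in for a TroL model loaded at import time

@spaces.GPU                      # a GPU is attached only for the duration of each call
def bot_streaming(message, history):
    # move any weights still on CPU onto the accelerator, as in the diff
    for param in model.parameters():
        if not param.is_cuda:
            param.data = param.to(accel.device)
    buffer = ""
    for character in f"echo: {message}":   # stand-in for the streamed model response
        buffer += character
        yield buffer

demo = gr.ChatInterface(fn=bot_streaming)
demo.launch()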
trol/arch_internlm2/modeling_internlm2.py CHANGED
@@ -857,7 +857,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
  self.vocab_size = config.vocab_size
  self.config = config
 
- self.tok_embeddings = nn.Embedding(config.vocab_size,
+ self.tok_embeddings = nn.Embedding(config.vocab_size+1,
                                     config.hidden_size,
                                     self.padding_idx)
  self.layers = nn.ModuleList([
trol/arch_internlm2/modeling_trol.py CHANGED
@@ -30,7 +30,7 @@ class TroLForCausalLM(InternLM2PreTrainedModel):
  # Model
  self.model = InternLM2Model(config)
  self.vocab_size = config.vocab_size
- self.output = nn.Linear(config.hidden_size, config.vocab_size-1, bias=False)
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  self.max_length = config.max_length
 
  # Initialize weights and apply final processing
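
The two modeling edits above keep the token dimensions consistent: the InternLM2 embedding table gains one row (`config.vocab_size+1`), presumably to leave room for one extra special token such as the image token registered in load_trol.py, and the TroL output head now projects to the full `config.vocab_size` instead of `config.vocab_size-1`. A small shape sketch with illustrative numbers (the real hidden size and vocab size come from the checkpoint config):

import torch.nn as nn

hidden_size, vocab_size, padding_idx = 2048, 92544, 2   # illustrative values only

# embedding table with one extra row for an added special token
tok_embeddings = nn.Embedding(vocab_size + 1, hidden_size, padding_idx)

# output head now covers every vocabulary id (previously vocab_size - 1)
output = nn.Linear(hidden_size, vocab_size, bias=False)

print(tok_embeddings.weight.shape)   # torch.Size([92545, 2048])
print(output.weight.shape)           # torch.Size([92544, 2048])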
trol/load_trol.py CHANGED
@@ -1,11 +1,16 @@
  import torch
  import warnings
  from config import *
- from peft import LoraConfig
  from transformers import BitsAndBytesConfig
 
  warnings.filterwarnings(action='ignore')
 
+ def setting_trol_config(trol, tok_trol, image_special_token):
+     trol.config.image_token_index = tok_trol.convert_tokens_to_ids(image_special_token)
+     trol.config.ignore_index = -100
+     trol.config.pad_token_id = tok_trol.eos_token_id
+     trol.config.eos_token_id = tok_trol.eos_token_id
+
  def load_trol(link):
 
  """
@@ -16,21 +21,24 @@ def load_trol(link):
  from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
  bits = 4
  path = TROL_1_8B
- bit_quant_skip = ["vit", "vision_proj", "ffn", "output"]
+ image_special_token = "<image>"
+ bit_quant_skip = ["vit", "vision_proj", "ffn", "output", "trol_gating"]
 
  elif link == 'TroL-3.8B':
  from trol.arch_phi3.modeling_trol import TroLForCausalLM
  from transformers import LlamaTokenizerFast as TroLTokenizer
  bits = 8
  path = TROL_3_8B
- bit_quant_skip = ["vision_model", "vision_proj", "lm_head"]
+ image_special_token = "<IMG_CONTEXT>"
+ bit_quant_skip = ["vision_model", "vision_proj", "lm_head", "trol_gating"]
 
  elif link == 'TroL-7B':
  from .arch_internlm2.modeling_trol import TroLForCausalLM
  from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
  bits = 4
  path = TROL_7B
- bit_quant_skip = ["vit", "vision_proj", "ffn", "output"]
+ image_special_token = "<image>"
+ bit_quant_skip = ["vit", "vision_proj", "ffn", "output", "trol_gating"]
  else:
  raise Exception("Unsupported Link")
 
@@ -68,10 +76,16 @@ def load_trol(link):
  except:
  del huggingface_config["attn_implementation"]
  trol = TroLForCausalLM.from_pretrained(path, **huggingface_config)
+ trol.config.llm_config.use_cache = False
+
+ # setting config
+ setting_trol_config(trol, tok_trol, image_special_token)
+
 
- # wrapping
+ # trol gating load
+ from huggingface_hub import hf_hub_download
  try:
-     trol = trol.cuda()
+     trol.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
  except:
-     pass
+     trol.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
  return trol, tok_trol
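
With the load_trol.py changes above, `load_trol` now records the image special token id and pad/eos ids on the model config via `setting_trol_config`, skips the new `trol_gating` module during BitsAndBytes quantization, and restores its weights from a `trol_gating.pt` file in the checkpoint repo via `hf_hub_download`. A caller-side sketch, assuming the `trol.load_trol` import path implied by the file layout and that the `TROL_*` paths in config.py point at the checkpoints:

from trol.load_trol import load_trol

# returns the quantized model (with image_token_index, pad/eos ids and
# trol_gating weights already set) plus its tokenizer
model_7, tokenizer_7 = load_trol(link='TroL-7B')

# any other size string still raises, as in the diff
try:
    load_trol(link='TroL-13B')
except Exception as err:
    print(err)   # Unsupported Link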