BK-Lee committed
Commit 9b5bdb0
1 Parent(s): a3e7589
Files changed (2):
  1. app.py +7 -7
  2. trol/load_trol.py +2 -2
app.py CHANGED
@@ -1,5 +1,5 @@
 # A100 Zero GPU
-import spaces
+# import spaces
 
 # TroL Package
 import torch
@@ -18,8 +18,8 @@ from transformers import TextIteratorStreamer
 from torchvision.transforms.functional import pil_to_tensor
 
 # flash attention
-import subprocess
-subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+# import subprocess
+# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 # accel
 accel = Accelerator()
@@ -33,10 +33,10 @@ question="What is the troll doing? Provide the detail in the image and imagine w
 model_1_8, tokenizer_1_8 = load_trol(link='TroL-1.8B')
 
 # loading model
-model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
+# model_3_8, tokenizer_3_8 = load_trol(link='TroL-3.8B')
 
 # loading model
-model_7, tokenizer_7 = load_trol(link='TroL-7B')
+# model_7, tokenizer_7 = load_trol(link='TroL-7B')
 
 def threading_function(inputs, image_token_number, streamer, device, model, tokenizer, temperature, new_max_token, top_p):
 
@@ -55,7 +55,7 @@ def threading_function(inputs, image_token_number, streamer, device, model, toke
     generation_kwargs.update({'use_cache': True})
     return model.generate(**generation_kwargs)
 
-@spaces.GPU
+# @spaces.GPU
 def bot_streaming(message, history, link, temperature, new_max_token, top_p):
 
     # model selection
@@ -135,7 +135,7 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
         yield buffer
 
 demo = gr.ChatInterface(fn=bot_streaming,
-                        additional_inputs = [gr.Radio(["1.8B", "3.8B"], label="Size", info="Select one model size", value="3.8B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
+                        additional_inputs = [gr.Radio(["1.8B"], label="Size", info="Select one model size", value="1.8B"), gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
                         additional_inputs_accordion="Generation Hyperparameters",
                         theme=gr.themes.Soft(),
                         title="TroL",
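Net effect on app.py: the ZeroGPU decorator, the runtime flash-attn install, and the 3.8B/7B checkpoints are all commented out, so the Space now loads and serves only TroL-1.8B, and the size selector offers a single choice. The "# model selection" block inside bot_streaming is not part of the hunks shown; the sketch below is only a guess at its shape after this commit, and every branch in it is an assumption.

# Hypothetical sketch of the "# model selection" dispatch in bot_streaming
# after this commit (the real block is outside the diff hunks shown above).
# With only TroL-1.8B loaded and the gr.Radio limited to "1.8B", the other
# branches would be dead code.
if link == "1.8B":
    model, tokenizer = model_1_8, tokenizer_1_8
# elif link == "3.8B":   # unreachable: model_3_8 is no longer loaded
#     model, tokenizer = model_3_8, tokenizer_3_8
# elif link == "7B":     # unreachable: model_7 is no longer loaded
#     model, tokenizer = model_7, tokenizer_7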
trol/load_trol.py CHANGED
@@ -14,14 +14,14 @@ def load_trol(link):
     if link == 'TroL-1.8B':
         from .arch_internlm2.modeling_trol import TroLForCausalLM
         from .arch_internlm2.tokenization_internlm2 import InternLM2Tokenizer as TroLTokenizer
-        bits = 16
+        bits = 4
         path = TROL_1_8B
         bit_quant_skip = ["vit", "vision_proj", "ffn", "output"]
 
     elif link == 'TroL-3.8B':
         from trol.arch_phi3.modeling_trol import TroLForCausalLM
         from transformers import LlamaTokenizerFast as TroLTokenizer
-        bits = 16
+        bits = 8
         path = TROL_3_8B
         bit_quant_skip = ["vision_model", "vision_proj", "lm_head"]
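A note on the bits change: the body of load_trol past these hunks is not shown, but in a typical transformers setup bits = 4 / bits = 8 would feed a bitsandbytes quantization config, with bit_quant_skip naming the modules to keep unquantized (vision tower, projector, output head). A minimal sketch of that wiring, assuming transformers' BitsAndBytesConfig; the helper name make_quant_config is hypothetical, not part of this repo:

import torch
from transformers import BitsAndBytesConfig

def make_quant_config(bits, bit_quant_skip):
    # Hypothetical helper: map the bits flag to a bitsandbytes config.
    if bits == 4:
        # TroL-1.8B after this commit: 4-bit NF4 weights, bf16 compute
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            llm_int8_skip_modules=bit_quant_skip,  # leave these modules unquantized
        )
    if bits == 8:
        # TroL-3.8B after this commit: 8-bit weights
        return BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_skip_modules=bit_quant_skip,
        )
    return None  # bits == 16: load in half precision, no quantization

Dropping from 16-bit to 4-bit/8-bit weights roughly quarters/halves GPU memory for the language model, which fits the rest of the commit: the Space no longer relies on ZeroGPU and keeps only the smallest checkpoint resident.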