broadwell committed
Commit
1dce8bf
1 Parent(s): 1563ea0

Upload app.py

Files changed (1):
  app.py +18 -13
app.py CHANGED
@@ -27,6 +27,8 @@ from CLIP_Explainability.vit_cam import (
 
 from pytorch_grad_cam.grad_cam import GradCAM
 
+RUN_LITE = True  # Load vision model for CAM viz explainability for M-CLIP only
+
 MAX_IMG_WIDTH = 500
 MAX_IMG_HEIGHT = 800
 
@@ -172,9 +174,10 @@ def init():
     ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
     ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
 
-    st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
-        ja_model_path, device=device, jit=False
-    )
+    if not RUN_LITE:
+        st.session_state.ja_image_model, st.session_state.ja_image_preprocess = (
+            load(ja_model_path, device=device, jit=False)
+        )
 
     st.session_state.ja_model = AutoModel.from_pretrained(
         ja_model_name, trust_remote_code=True
@@ -183,9 +186,10 @@ def init():
         ja_model_name, trust_remote_code=True
     )
 
-    st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
-        clip.load("RN50x4", device=device)
-    )
+    if not RUN_LITE:
+        st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
+            clip.load("RN50x4", device=device)
+        )
 
     st.session_state.rn_model = legacy_multilingual_clip.load_model(
         "M-BERT-Base-69"
@@ -701,11 +705,12 @@ for image_id in batch:
             <div>""",
             unsafe_allow_html=True,
         )
-        st.button(
-            "Explain this",
-            on_click=image_modal,
-            args=[image_id],
-            use_container_width=True,
-            key=image_id,
-        )
+        if not RUN_LITE or st.session_state.active_model == "M-CLIP (multilingual ViT)":
+            st.button(
+                "Explain this",
+                on_click=image_modal,
+                args=[image_id],
+                use_container_width=True,
+                key=image_id,
+            )
         col = (col + 1) % row_size
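For readers of the diff: the new RUN_LITE flag turns the heavy Japanese-CLIP (ViT-H-14) and RN50x4 vision backbones into optional loads, and the "Explain this" button is rendered only when the active model's backbone was actually loaded (in lite mode, M-CLIP only). A minimal sketch of that gating pattern, assuming a Streamlit app with the same session-state names; the init body and the active_model default below are illustrative, not the app's full loader:

    import streamlit as st
    import clip  # OpenAI CLIP, as imported by app.py

    RUN_LITE = True  # lite deployment: CAM explainability for M-CLIP only

    def init():
        # st.session_state caches the models across Streamlit reruns.
        if not RUN_LITE and "rn_image_model" not in st.session_state:
            # RN50x4 is only needed to render CAM heatmaps for this model family.
            st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
                clip.load("RN50x4", device="cpu")
            )

    init()
    st.session_state.setdefault("active_model", "M-CLIP (multilingual ViT)")  # illustrative default

    # Offer explainability only when the matching vision backbone is present.
    if not RUN_LITE or st.session_state.active_model == "M-CLIP (multilingual ViT)":
        st.button("Explain this")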