broadwell commited on
Commit
3e811d1
1 Parent(s): 1922759

Disable activation viz for RN model, to save memory

Browse files
Files changed (1) hide show
  1. CLIP_Explainability/app.py +17 -6
CLIP_Explainability/app.py CHANGED
@@ -27,7 +27,7 @@ from CLIP_Explainability.vit_cam import (
27
 
28
  from pytorch_grad_cam.grad_cam import GradCAM
29
 
30
- RUN_LITE = False # Load vision model for CAM viz explainability for M-CLIP only
31
 
32
  MAX_IMG_WIDTH = 500
33
  MAX_IMG_HEIGHT = 800
@@ -110,6 +110,10 @@ def clip_search(search_query):
110
 
111
 
112
  def string_search():
 
 
 
 
113
  if "search_field_value" in st.session_state:
114
  clip_search(st.session_state.search_field_value)
115
 
@@ -179,10 +183,9 @@ def init():
179
  ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
180
  ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
181
 
182
- if not RUN_LITE:
183
- st.session_state.ja_image_model, st.session_state.ja_image_preprocess = (
184
- load(ja_model_path, device=device, jit=False)
185
- )
186
 
187
  st.session_state.ja_model = AutoModel.from_pretrained(
188
  ja_model_name, trust_remote_code=True
@@ -216,6 +219,9 @@ def init():
216
  st.session_state.search_image_ids = []
217
  st.session_state.search_image_scores = {}
218
  st.session_state.text_table_df = None
 
 
 
219
 
220
  with st.spinner("Loading models and data, please wait..."):
221
  load_image_features()
@@ -745,6 +751,7 @@ with controls[3]:
745
  key="uploaded_image",
746
  label_visibility="collapsed",
747
  on_change=vis_uploaded_image,
 
748
  )
749
 
750
 
@@ -779,7 +786,9 @@ for image_id in batch:
779
  <div>""",
780
  unsafe_allow_html=True,
781
  )
782
- if not RUN_LITE or st.session_state.active_model == "M-CLIP (multilingual ViT)":
 
 
783
  st.button(
784
  "Explain this",
785
  on_click=vis_known_image,
@@ -787,4 +796,6 @@ for image_id in batch:
787
  use_container_width=True,
788
  key=image_id,
789
  )
 
 
790
  col = (col + 1) % row_size
 
27
 
28
  from pytorch_grad_cam.grad_cam import GradCAM
29
 
30
+ RUN_LITE = True # Load models for CAM viz for M-CLIP and J-CLIP only
31
 
32
  MAX_IMG_WIDTH = 500
33
  MAX_IMG_HEIGHT = 800
 
110
 
111
 
112
  def string_search():
113
+ st.session_state.disable_uploader = (
114
+ RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
115
+ )
116
+
117
  if "search_field_value" in st.session_state:
118
  clip_search(st.session_state.search_field_value)
119
 
 
183
  ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
184
  ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
185
 
186
+ st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
187
+ ja_model_path, device=device, jit=False
188
+ )
 
189
 
190
  st.session_state.ja_model = AutoModel.from_pretrained(
191
  ja_model_name, trust_remote_code=True
 
219
  st.session_state.search_image_ids = []
220
  st.session_state.search_image_scores = {}
221
  st.session_state.text_table_df = None
222
+ st.session_state.disable_uploader = (
223
+ RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
224
+ )
225
 
226
  with st.spinner("Loading models and data, please wait..."):
227
  load_image_features()
 
751
  key="uploaded_image",
752
  label_visibility="collapsed",
753
  on_change=vis_uploaded_image,
754
+ disabled=st.session_state.disable_uploader,
755
  )
756
 
757
 
 
786
  <div>""",
787
  unsafe_allow_html=True,
788
  )
789
+ if not (
790
+ RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
791
+ ):
792
  st.button(
793
  "Explain this",
794
  on_click=vis_known_image,
 
796
  use_container_width=True,
797
  key=image_id,
798
  )
799
+ else:
800
+ st.empty()
801
  col = (col + 1) % row_size