broadwell committed
Commit c3d8208
1 Parent(s): 3d8e28d

Don't run activations viz with ResNet model, to save memory

Files changed (1): app.py (+20 -7)
app.py CHANGED
@@ -27,7 +27,7 @@ from CLIP_Explainability.vit_cam import (
 
 from pytorch_grad_cam.grad_cam import GradCAM
 
-RUN_LITE = False  # Load vision model for CAM viz explainability for M-CLIP only
+RUN_LITE = True  # Load models for CAM viz for M-CLIP and J-CLIP only
 
 MAX_IMG_WIDTH = 500
 MAX_IMG_HEIGHT = 800
@@ -110,6 +110,10 @@ def clip_search(search_query):
 
 
 def string_search():
+    st.session_state.disable_uploader = (
+        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
+    )
+
     if "search_field_value" in st.session_state:
         clip_search(st.session_state.search_field_value)
 
@@ -179,10 +183,9 @@ def init():
     ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
     ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
 
-    if not RUN_LITE:
-        st.session_state.ja_image_model, st.session_state.ja_image_preprocess = (
-            load(ja_model_path, device=device, jit=False)
-        )
+    st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
+        ja_model_path, device=device, jit=False
+    )
 
     st.session_state.ja_model = AutoModel.from_pretrained(
         ja_model_name, trust_remote_code=True
@@ -216,6 +219,9 @@ def init():
     st.session_state.search_image_ids = []
     st.session_state.search_image_scores = {}
     st.session_state.text_table_df = None
+    st.session_state.disable_uploader = (
+        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
+    )
 
     with st.spinner("Loading models and data, please wait..."):
         load_image_features()
@@ -430,7 +436,7 @@ def visualize_gradcam(image):
 
     header_cols = st.columns([80, 20], vertical_alignment="bottom")
     with header_cols[0]:
-        st.title("Image + query details")
+        st.title("Image + query activation gradients")
     with header_cols[1]:
         if st.button("Close"):
            st.rerun()
@@ -457,6 +463,8 @@ def visualize_gradcam(image):
         st.session_state.search_field_value
     )
 
+    st.image(image)
+
     with st.spinner("Calculating..."):
         # info_text = st.text("Calculating activation regions...")
 
@@ -743,6 +751,7 @@ with controls[3]:
         key="uploaded_image",
         label_visibility="collapsed",
         on_change=vis_uploaded_image,
+        disabled=st.session_state.disable_uploader,
    )
 
 
@@ -777,7 +786,9 @@ for image_id in batch:
         <div>""",
         unsafe_allow_html=True,
     )
-    if not RUN_LITE or st.session_state.active_model == "M-CLIP (multilingual ViT)":
+    if not (
+        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
+    ):
         st.button(
             "Explain this",
             on_click=vis_known_image,
@@ -785,4 +796,6 @@ for image_id in batch:
             use_container_width=True,
             key=image_id,
         )
+    else:
+        st.empty()
     col = (col + 1) % row_size
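
Taken together, the diff stops skipping the J-CLIP vision model load under RUN_LITE and instead gates the memory-heavy activation visualizations on the currently selected model: when lite mode is on and the legacy multilingual ResNet model is active, the image uploader is disabled and the per-result "Explain this" buttons are not rendered. Below is a minimal, self-contained sketch of that gating pattern. It is not code from app.py: the viz_disabled helper and the "J-CLIP (Japanese ViT)" label are illustrative placeholders; only the two model labels quoted in the diff and the Streamlit calls (st.file_uploader's disabled argument, st.button, st.empty) are taken from the change.

import streamlit as st

RUN_LITE = True  # same flag flipped in this commit

LEGACY_MODEL = "Legacy (multilingual ResNet)"  # label quoted in the diff


def viz_disabled(active_model: str) -> bool:
    # Hypothetical helper: CAM visualization is unavailable only when running
    # in lite mode with the legacy ResNet model selected.
    return RUN_LITE and active_model == LEGACY_MODEL


active_model = st.selectbox(
    "Model",
    # "J-CLIP (Japanese ViT)" is a placeholder; the other two labels appear in the diff.
    ["M-CLIP (multilingual ViT)", "J-CLIP (Japanese ViT)", LEGACY_MODEL],
)

# The uploader stays visible but is disabled, mirroring the
# disabled=st.session_state.disable_uploader argument added to st.file_uploader.
st.file_uploader("Search by image", disabled=viz_disabled(active_model))

# "Explain this" is rendered only when visualization is available; otherwise an
# empty placeholder is emitted, as in the diff's else branch.
if not viz_disabled(active_model):
    st.button("Explain this")
else:
    st.empty()

As the commit message indicates, the memory saving comes from never running the activation visualizations against the ResNet model, not from skipping model loads, which is why the J-CLIP vision model load becomes unconditional.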