broadwell committed
Commit f5b714f
1 Parent(s): ff0ced4

Legacy ResNet CAM visualization functionality

Files changed (2):
  1. app.py +287 -204
  2. requirements.txt +1 -0
app.py CHANGED
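For orientation before the diff: this commit adds a third search option, a legacy multilingual ResNet pipeline that pairs OpenAI's RN50x4 CLIP image encoder with the M-BERT-Base-69 text encoder from multilingual_clip's legacy loader, alongside the existing M-CLIP and J-CLIP ViT paths. The following is a minimal sketch of how the new pieces fit together, distilled from the changes below; the variable names and example query are illustrative, not an excerpt from the app.

import clip
import torch
from multilingual_clip import legacy_multilingual_clip
from transformers import BertTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Legacy image tower: OpenAI CLIP RN50x4 (288-pixel inputs, per img_dim in the diff)
rn_image_model, rn_image_preprocess = clip.load("RN50x4", device=device)

# Legacy multilingual text tower; the BERT tokenizer is only needed later for
# per-token relevance, since the text model tokenizes queries itself
rn_model = legacy_multilingual_clip.load_model("M-BERT-Base-69")
rn_tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

# Same normalization as the legacy branch of encode_search_query()
with torch.no_grad():
    text_encoded = rn_model("negative space")
    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)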
--- a/app.py
@@ -2,7 +2,8 @@ from base64 import b64encode
from io import BytesIO
from math import ceil

- from multilingual_clip import pt_multilingual_clip
import numpy as np
import pandas as pd
from PIL import Image
@@ -10,13 +11,21 @@ import requests
import streamlit as st
import torch
from torchvision.transforms import ToPILImage
- from transformers import AutoTokenizer, AutoModel

from CLIP_Explainability.clip_ import load, tokenize
from CLIP_Explainability.vit_cam import (
-     interpret_vit,
    vit_perword_relevance,
- )  # , interpret_vit_overlapped

MAX_IMG_WIDTH = 500
MAX_IMG_HEIGHT = 800
@@ -40,17 +49,20 @@ def find_best_matches(text_features, image_features, image_ids):
def encode_search_query(search_query, model_type):
    with torch.no_grad():
        # Encode and normalize the search query using the multilingual model
-         if model_type == "M-CLIP (multiple languages)":
            text_encoded = st.session_state.ml_model.forward(
                search_query, st.session_state.ml_tokenizer
            )
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-         else:  # model_type == "J-CLIP (日本語 only)"
            t_text = st.session_state.ja_tokenizer(
                search_query, padding=True, return_tensors="pt"
            )
            text_encoded = st.session_state.ja_model.get_text_features(**t_text)
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Retrieve the feature vector
    return text_encoded
@@ -69,18 +81,24 @@ def clip_search(search_query):
    # similarities = list((text_features @ photo_features.T).squeeze(0))

    # Sort the photos by their similarity score
-     if model_type == "M-CLIP (multiple languages)":
        matches = find_best_matches(
            text_features,
            st.session_state.ml_image_features,
            st.session_state.image_ids,
        )
-     else:  # model_type == "J-CLIP (日本語 only)"
        matches = find_best_matches(
            text_features,
            st.session_state.ja_image_features,
            st.session_state.image_ids,
        )

    st.session_state.search_image_ids = [match[0] for match in matches]
    st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
@@ -96,21 +114,26 @@ def load_image_features():
    if st.session_state.vision_mode == "tiled":
        ml_image_features = np.load("./image_features/tiled_ml_features.npy")
        ja_image_features = np.load("./image_features/tiled_ja_features.npy")
    elif st.session_state.vision_mode == "stretched":
        ml_image_features = np.load("./image_features/resized_ml_features.npy")
        ja_image_features = np.load("./image_features/resized_ja_features.npy")
    else:  # st.session_state.vision_mode == "cropped":
        ml_image_features = np.load("./image_features/cropped_ml_features.npy")
        ja_image_features = np.load("./image_features/cropped_ja_features.npy")

    # Convert features to Tensors: Float32 on CPU and Float16 on GPU
    device = st.session_state.device
    if device == "cpu":
        ml_image_features = torch.from_numpy(ml_image_features).float().to(device)
        ja_image_features = torch.from_numpy(ja_image_features).float().to(device)
    else:
        ml_image_features = torch.from_numpy(ml_image_features).to(device)
        ja_image_features = torch.from_numpy(ja_image_features).to(device)

    st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
        dim=-1, keepdim=True
@@ -118,6 +141,9 @@ def load_image_features():
    st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
        dim=-1, keepdim=True
    )

    string_search()

@@ -129,10 +155,11 @@ def init():
    st.session_state.device = device

    # Load the open CLIP models
-     ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
-     ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"

    with st.spinner("Loading models and data, please wait..."):
        st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
            ml_model_path, device=device, jit=False
        )
@@ -156,6 +183,17 @@ def init():
            ja_model_name, trust_remote_code=True
        )

        # Load the image IDs
        st.session_state.images_info = pd.read_csv("./metadata.csv")
        st.session_state.images_info.set_index("filename", inplace=True)
@@ -163,7 +201,7 @@ def init():
        with open("./images_list.txt", "r", encoding="utf-8") as images_list:
            st.session_state.image_ids = list(images_list.read().strip().split("\n"))

-         st.session_state.active_model = "M-CLIP (multiple languages)"

        st.session_state.vision_mode = "tiled"
        st.session_state.search_image_ids = []
@@ -194,195 +232,223 @@ def visualize_gradcam(viz_image_id):
        f"**Query text:** {st.session_state.search_field_value} | **Image relevance:** {round(st.session_state.search_image_scores[viz_image_id], 3)}"
    )

-     # with st.spinner("Calculating..."):
-     info_text = st.text("Calculating activation regions...")
-
-     image_url = st.session_state.images_info.loc[viz_image_id]["image_url"]
-     image_response = requests.get(image_url)
-     image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF"])
-     image = image.convert("RGB")
-
-     img_dim = 224
-     if st.session_state.active_model == "M-CLIP (multiple languages)":
-         img_dim = 240
-
-     orig_img_dims = image.size
-
-     ##### If the features are based on tiled image slices
-     tile_behavior = None
-
-     if st.session_state.vision_mode == "tiled":
-         scaled_dims = [img_dim, img_dim]
-
-         if orig_img_dims[0] > orig_img_dims[1]:
-             scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
-             if scale_ratio > 1:
-                 scaled_dims = [scale_ratio * img_dim, img_dim]
-                 tile_behavior = "width"
-         elif orig_img_dims[0] < orig_img_dims[1]:
-             scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
-             if scale_ratio > 1:
-                 scaled_dims = [img_dim, scale_ratio * img_dim]
-                 tile_behavior = "height"
-
-         resized_image = image.resize(scaled_dims, Image.LANCZOS)
-
-         if tile_behavior == "width":
-             image_tiles = []
-             for x in range(0, scale_ratio):
-                 box = (x * img_dim, 0, (x + 1) * img_dim, img_dim)
-                 image_tiles.append(resized_image.crop(box))
-
-         elif tile_behavior == "height":
-             image_tiles = []
-             for y in range(0, scale_ratio):
-                 box = (0, y * img_dim, img_dim, (y + 1) * img_dim)
-                 image_tiles.append(resized_image.crop(box))
-
-         else:
-             image_tiles = [resized_image]
-
-     elif st.session_state.vision_mode == "stretched":
-         image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]

-     else:  # vision_mode == "cropped"
-         if orig_img_dims[0] > orig_img_dims[1]:
-             scale_factor = orig_img_dims[0] / orig_img_dims[1]
-             resized_img_dims = (round(scale_factor * img_dim), img_dim)
            resized_img = image.resize(resized_img_dims)
-         elif orig_img_dims[0] < orig_img_dims[1]:
-             scale_factor = orig_img_dims[1] / orig_img_dims[0]
-             resized_img_dims = (img_dim, round(scale_factor * img_dim))
-         else:
-             resized_img_dims = (img_dim, img_dim)
-
-         resized_img = image.resize(resized_img_dims)
-
-         left = round((resized_img_dims[0] - img_dim) / 2)
-         top = round((resized_img_dims[1] - img_dim) / 2)
-         x_right = round(resized_img_dims[0] - img_dim) - left
-         x_bottom = round(resized_img_dims[1] - img_dim) - top
-         right = resized_img_dims[0] - x_right
-         bottom = resized_img_dims[1] - x_bottom

-         # Crop the center of the image
-         image_tiles = [resized_img.crop((left, top, right, bottom))]

-     image_visualizations = []

-     if st.session_state.active_model == "M-CLIP (multiple languages)":
-         # Sometimes used for token importance viz
-         tokenized_text = st.session_state.ml_tokenizer.tokenize(
-             st.session_state.search_field_value
-         )
-
-         text_features = st.session_state.ml_model.forward(
-             st.session_state.search_field_value, st.session_state.ml_tokenizer
-         )

-         image_model = st.session_state.ml_image_model
-         # tokenize = st.session_state.ml_tokenizer.tokenize
-         image_model.eval()
-
-         for altered_image in image_tiles:
-             image_model.zero_grad()
-
-             p_image = (
-                 st.session_state.ml_image_preprocess(altered_image)
-                 .unsqueeze(0)
-                 .to(st.session_state.device)
            )

-             vis_t = interpret_vit(
-                 p_image.type(st.session_state.ml_image_model.dtype),
-                 text_features,
-                 image_model.visual,
-                 st.session_state.device,
-                 img_dim=img_dim,
            )

-             image_visualizations.append(vis_t)
-
-     else:
-         # Sometimes used for token importance viz
-         tokenized_text = st.session_state.ja_tokenizer.tokenize(
-             st.session_state.search_field_value
-         )
-
-         t_text = st.session_state.ja_tokenizer(
-             st.session_state.search_field_value, return_tensors="pt"
-         )
-         text_features = st.session_state.ja_model.get_text_features(**t_text)
-
-         image_model = st.session_state.ja_image_model
-         image_model.eval()
-
-         for altered_image in image_tiles:
-             image_model.zero_grad()

-             p_image = (
-                 st.session_state.ja_image_preprocess(altered_image)
-                 .unsqueeze(0)
-                 .to(st.session_state.device)
            )

-             vis_t = interpret_vit(
-                 p_image.type(st.session_state.ja_image_model.dtype),
-                 text_features,
-                 image_model.visual,
-                 st.session_state.device,
-                 img_dim=img_dim,
            )

-             image_visualizations.append(vis_t)

-     transform = ToPILImage()

-     vis_images = [transform(vis_t) for vis_t in image_visualizations]

-     if st.session_state.vision_mode == "cropped":
-         resized_img.paste(vis_images[0], (left, top))
-         vis_images = [resized_img]

-     if orig_img_dims[0] > orig_img_dims[1]:
-         scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
-         scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
-     else:
-         scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
-         scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]

-     if tile_behavior == "width":
-         vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
-         for x, v_img in enumerate(vis_images):
-             vis_image.paste(v_img, (x * img_dim, 0))
-         st.session_state.activations_image = vis_image.resize(scaled_dims)

-     elif tile_behavior == "height":
-         vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
-         for y, v_img in enumerate(vis_images):
-             vis_image.paste(v_img, (0, y * img_dim))
-         st.session_state.activations_image = vis_image.resize(scaled_dims)

-     else:
-         st.session_state.activations_image = vis_images[0].resize(scaled_dims)

-     image_io = BytesIO()
-     st.session_state.activations_image.save(image_io, "PNG")
-     dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode("ascii")

-     st.html(
-         f"""<div style="display: flex; flex-direction: column; align-items: center;">
-         <img src="{dataurl}" />
-         </div>"""
-     )

-     info_text.empty()

-     tokenized_text = [tok for tok in tokenized_text if tok != "▁"]

    if (
        len(tokenized_text) > 1
-         and len(tokenized_text) < 15
        and st.button(
            "Calculate text importance (may take some time)",
        )
@@ -394,17 +460,31 @@ def visualize_gradcam(viz_image_id):
        progress_bar = st.progress(0.0, text=progress_text)

        for t, tok in enumerate(tokenized_text):
-             token = tok.replace("▁", "")
-             word_rel = vit_perword_relevance(
-                 p_image,
-                 st.session_state.search_field_value,
-                 image_model,
-                 tokenize,
-                 st.session_state.device,
-                 token,
-                 data_only=True,
-                 img_dim=img_dim,
-             )
            avg_score = np.mean(word_rel)
            if avg_score == 0 or np.isnan(avg_score):
                continue
@@ -429,7 +509,7 @@ def visualize_gradcam(viz_image_id):


def format_vision_mode(mode_stub):
-     return f"Vision mode: {mode_stub.capitalize()}"


@st.dialog(" ", width="large")
@@ -469,7 +549,7 @@ st.markdown(
    unsafe_allow_html=True,
)

- search_row = st.columns([45, 5, 1, 15, 1, 8, 25], vertical_alignment="center")
with search_row[0]:
    search_field = st.text_input(
        label="search",
@@ -483,10 +563,10 @@ with search_row[1]:
        "Search", on_click=string_search, use_container_width=True, type="primary"
    )
with search_row[2]:
-     st.empty()
with search_row[3]:
    st.selectbox(
-         "Vision mode:",
        options=["tiled", "stretched", "cropped"],
        key="vision_mode",
        help="How to consider images that aren't square",
@@ -497,56 +577,59 @@ with search_row[3]:
with search_row[4]:
    st.empty()
with search_row[5]:
-     st.markdown("**CLIP Model:**")
with search_row[6]:
-     st.radio(
-         "CLIP Model",
-         options=["M-CLIP (multiple languages)", "J-CLIP (日本語)"],
        key="active_model",
        on_change=string_search,
-         horizontal=True,
        label_visibility="collapsed",
    )

canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
with canned_searches[0]:
    st.markdown("**Suggested searches:**")
- if st.session_state.active_model == "M-CLIP (multiple languages)":
    with canned_searches[1]:
        st.button(
-             "negative space",
            on_click=clip_search,
-             args=["negative space"],
            use_container_width=True,
        )
    with canned_searches[2]:
-         st.button("", on_click=clip_search, args=[""], use_container_width=True)
    with canned_searches[3]:
-         st.button("음각", on_click=clip_search, args=["음각"], use_container_width=True)
    with canned_searches[4]:
        st.button(
-             "αρνητικός χώρος",
            on_click=clip_search,
-             args=["αρνητικός χώρος"],
            use_container_width=True,
        )
else:
    with canned_searches[1]:
        st.button(
-             "",
            on_click=clip_search,
-             args=[""],
            use_container_width=True,
        )
    with canned_searches[2]:
-         st.button("", on_click=clip_search, args=[""], use_container_width=True)
    with canned_searches[3]:
-         st.button("", on_click=clip_search, args=[""], use_container_width=True)
    with canned_searches[4]:
        st.button(
-             "花に酔えり 羽織着て刀 さす女",
            on_click=clip_search,
-             args=["花に酔えり 羽織着て刀 さす女"],
            use_container_width=True,
        )
 
+++ b/app.py
@@ -2,7 +2,8 @@ from base64 import b64encode
from io import BytesIO
from math import ceil

+ import clip
+ from multilingual_clip import legacy_multilingual_clip, pt_multilingual_clip
import numpy as np
import pandas as pd
from PIL import Image
@@ -10,13 +11,21 @@ import requests
import streamlit as st
import torch
from torchvision.transforms import ToPILImage
+ from transformers import AutoTokenizer, AutoModel, BertTokenizer

from CLIP_Explainability.clip_ import load, tokenize
+ from CLIP_Explainability.rn_cam import (
+     # interpret_rn,
+     interpret_rn_overlapped,
+     rn_perword_relevance,
+ )
from CLIP_Explainability.vit_cam import (
+     # interpret_vit,
    vit_perword_relevance,
+     interpret_vit_overlapped,
+ )
+
+ from pytorch_grad_cam.grad_cam import GradCAM

MAX_IMG_WIDTH = 500
MAX_IMG_HEIGHT = 800
@@ -40,17 +49,20 @@ def find_best_matches(text_features, image_features, image_ids):
def encode_search_query(search_query, model_type):
    with torch.no_grad():
        # Encode and normalize the search query using the multilingual model
+         if model_type == "M-CLIP (multilingual ViT)":
            text_encoded = st.session_state.ml_model.forward(
                search_query, st.session_state.ml_tokenizer
            )
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
+         elif model_type == "J-CLIP (日本語 ViT)":
            t_text = st.session_state.ja_tokenizer(
                search_query, padding=True, return_tensors="pt"
            )
            text_encoded = st.session_state.ja_model.get_text_features(**t_text)
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
+         else:  # model_type == legacy
+             text_encoded = st.session_state.rn_model(search_query)
+             text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Retrieve the feature vector
    return text_encoded
@@ -69,18 +81,24 @@ def clip_search(search_query):
    # similarities = list((text_features @ photo_features.T).squeeze(0))

    # Sort the photos by their similarity score
+     if model_type == "M-CLIP (multilingual ViT)":
        matches = find_best_matches(
            text_features,
            st.session_state.ml_image_features,
            st.session_state.image_ids,
        )
+     elif model_type == "J-CLIP (日本語 ViT)":
        matches = find_best_matches(
            text_features,
            st.session_state.ja_image_features,
            st.session_state.image_ids,
        )
+     else:  # model_type == legacy
+         matches = find_best_matches(
+             text_features,
+             st.session_state.rn_image_features,
+             st.session_state.image_ids,
+         )

    st.session_state.search_image_ids = [match[0] for match in matches]
    st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
@@ -96,21 +114,26 @@ def load_image_features():
    if st.session_state.vision_mode == "tiled":
        ml_image_features = np.load("./image_features/tiled_ml_features.npy")
        ja_image_features = np.load("./image_features/tiled_ja_features.npy")
+         rn_image_features = np.load("./image_features/tiled_rn_features.npy")
    elif st.session_state.vision_mode == "stretched":
        ml_image_features = np.load("./image_features/resized_ml_features.npy")
        ja_image_features = np.load("./image_features/resized_ja_features.npy")
+         rn_image_features = np.load("./image_features/resized_rn_features.npy")
    else:  # st.session_state.vision_mode == "cropped":
        ml_image_features = np.load("./image_features/cropped_ml_features.npy")
        ja_image_features = np.load("./image_features/cropped_ja_features.npy")
+         rn_image_features = np.load("./image_features/cropped_rn_features.npy")

    # Convert features to Tensors: Float32 on CPU and Float16 on GPU
    device = st.session_state.device
    if device == "cpu":
        ml_image_features = torch.from_numpy(ml_image_features).float().to(device)
        ja_image_features = torch.from_numpy(ja_image_features).float().to(device)
+         rn_image_features = torch.from_numpy(rn_image_features).float().to(device)
    else:
        ml_image_features = torch.from_numpy(ml_image_features).to(device)
        ja_image_features = torch.from_numpy(ja_image_features).to(device)
+         rn_image_features = torch.from_numpy(rn_image_features).to(device)

    st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
        dim=-1, keepdim=True
@@ -118,6 +141,9 @@ def load_image_features():
    st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
        dim=-1, keepdim=True
    )
+     st.session_state.rn_image_features = rn_image_features / rn_image_features.norm(
+         dim=-1, keepdim=True
+     )

    string_search()

@@ -129,10 +155,11 @@ def init():
    st.session_state.device = device

    # Load the open CLIP models

    with st.spinner("Loading models and data, please wait..."):
+         ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
+         ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
+
        st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
            ml_model_path, device=device, jit=False
        )
@@ -156,6 +183,17 @@ def init():
            ja_model_name, trust_remote_code=True
        )

+         st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
+             clip.load("RN50x4", device=device)
+         )
+
+         st.session_state.rn_model = legacy_multilingual_clip.load_model(
+             "M-BERT-Base-69"
+         )
+         st.session_state.rn_tokenizer = BertTokenizer.from_pretrained(
+             "bert-base-multilingual-cased"
+         )
+
        # Load the image IDs
        st.session_state.images_info = pd.read_csv("./metadata.csv")
        st.session_state.images_info.set_index("filename", inplace=True)
@@ -163,7 +201,7 @@ def init():
        with open("./images_list.txt", "r", encoding="utf-8") as images_list:
            st.session_state.image_ids = list(images_list.read().strip().split("\n"))

+         st.session_state.active_model = "M-CLIP (multilingual ViT)"

        st.session_state.vision_mode = "tiled"
        st.session_state.search_image_ids = []
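load_image_features() now expects a *_rn_features.npy matrix alongside the existing ML/JA ones, with one row per entry in images_list.txt. The commit does not show how those features are produced; the following is a hypothetical offline sketch (local image paths, file name, and the lack of batching are assumptions, not part of this commit), and app.py re-normalizes after loading, so storing the raw encoder output is enough.

import clip
import numpy as np
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN50x4", device=device)

image_ids = open("./images_list.txt", encoding="utf-8").read().strip().split("\n")

rows = []
with torch.no_grad():
    for image_id in image_ids:
        img = Image.open(f"./images/{image_id}").convert("RGB")  # hypothetical local copy of each image
        feats = model.encode_image(preprocess(img).unsqueeze(0).to(device))
        rows.append(feats.squeeze(0).cpu().numpy())

np.save("./image_features/cropped_rn_features.npy", np.stack(rows))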
 
@@ -194,195 +232,223 @@ def visualize_gradcam(viz_image_id):
        f"**Query text:** {st.session_state.search_field_value} | **Image relevance:** {round(st.session_state.search_image_scores[viz_image_id], 3)}"
    )

+     with st.spinner("Calculating..."):
+         # info_text = st.text("Calculating activation regions...")
+
+         image_url = st.session_state.images_info.loc[viz_image_id]["image_url"]
+         image_response = requests.get(image_url)
+         image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF"])
+         image = image.convert("RGB")
+
+         img_dim = 224
+         if st.session_state.active_model == "M-CLIP (multilingual ViT)":
+             img_dim = 240
+         elif st.session_state.active_model == "Legacy (multilingual ResNet)":
+             img_dim = 288
+
+         orig_img_dims = image.size
+
+         ##### If the features are based on tiled image slices
+         tile_behavior = None
+
+         if st.session_state.vision_mode == "tiled":
+             scaled_dims = [img_dim, img_dim]
+
+             if orig_img_dims[0] > orig_img_dims[1]:
+                 scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
+                 if scale_ratio > 1:
+                     scaled_dims = [scale_ratio * img_dim, img_dim]
+                     tile_behavior = "width"
+             elif orig_img_dims[0] < orig_img_dims[1]:
+                 scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
+                 if scale_ratio > 1:
+                     scaled_dims = [img_dim, scale_ratio * img_dim]
+                     tile_behavior = "height"
+
+             resized_image = image.resize(scaled_dims, Image.LANCZOS)
+
+             if tile_behavior == "width":
+                 image_tiles = []
+                 for x in range(0, scale_ratio):
+                     box = (x * img_dim, 0, (x + 1) * img_dim, img_dim)
+                     image_tiles.append(resized_image.crop(box))
+
+             elif tile_behavior == "height":
+                 image_tiles = []
+                 for y in range(0, scale_ratio):
+                     box = (0, y * img_dim, img_dim, (y + 1) * img_dim)
+                     image_tiles.append(resized_image.crop(box))
+
+             else:
+                 image_tiles = [resized_image]
+
+         elif st.session_state.vision_mode == "stretched":
+             image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]
+
+         else:  # vision_mode == "cropped"
+             if orig_img_dims[0] > orig_img_dims[1]:
+                 scale_factor = orig_img_dims[0] / orig_img_dims[1]
+                 resized_img_dims = (round(scale_factor * img_dim), img_dim)
+                 resized_img = image.resize(resized_img_dims)
+             elif orig_img_dims[0] < orig_img_dims[1]:
+                 scale_factor = orig_img_dims[1] / orig_img_dims[0]
+                 resized_img_dims = (img_dim, round(scale_factor * img_dim))
+             else:
+                 resized_img_dims = (img_dim, img_dim)

            resized_img = image.resize(resized_img_dims)

+             left = round((resized_img_dims[0] - img_dim) / 2)
+             top = round((resized_img_dims[1] - img_dim) / 2)
+             x_right = round(resized_img_dims[0] - img_dim) - left
+             x_bottom = round(resized_img_dims[1] - img_dim) - top
+             right = resized_img_dims[0] - x_right
+             bottom = resized_img_dims[1] - x_bottom

+             # Crop the center of the image
+             image_tiles = [resized_img.crop((left, top, right, bottom))]

+         image_visualizations = []

+         if st.session_state.active_model == "M-CLIP (multilingual ViT)":
+             # Sometimes used for token importance viz
+             tokenized_text = st.session_state.ml_tokenizer.tokenize(
+                 st.session_state.search_field_value
            )

+             text_features = st.session_state.ml_model.forward(
+                 st.session_state.search_field_value, st.session_state.ml_tokenizer
            )

+             image_model = st.session_state.ml_image_model
+
+             for altered_image in image_tiles:
+                 p_image = (
+                     st.session_state.ml_image_preprocess(altered_image)
+                     .unsqueeze(0)
+                     .to(st.session_state.device)
+                 )
+
+                 vis_t = interpret_vit_overlapped(
+                     p_image.type(st.session_state.ml_image_model.dtype),
+                     text_features,
+                     image_model.visual,
+                     st.session_state.device,
+                     img_dim=img_dim,
+                 )
+
+                 image_visualizations.append(vis_t)
+
+         elif st.session_state.active_model == "J-CLIP (日本語 ViT)":
+             # Sometimes used for token importance viz
+             tokenized_text = st.session_state.ja_tokenizer.tokenize(
+                 st.session_state.search_field_value
+             )

+             t_text = st.session_state.ja_tokenizer(
+                 st.session_state.search_field_value, return_tensors="pt"
+             )
+             text_features = st.session_state.ja_model.get_text_features(**t_text)
+
+             image_model = st.session_state.ja_image_model
+
+             for altered_image in image_tiles:
+                 p_image = (
+                     st.session_state.ja_image_preprocess(altered_image)
+                     .unsqueeze(0)
+                     .to(st.session_state.device)
+                 )
+
+                 vis_t = interpret_vit_overlapped(
+                     p_image.type(st.session_state.ja_image_model.dtype),
+                     text_features,
+                     image_model.visual,
+                     st.session_state.device,
+                     img_dim=img_dim,
+                 )
+
+                 image_visualizations.append(vis_t)
+
+         else:  # st.session_state.active_model == Legacy
+             # Sometimes used for token importance viz
+             tokenized_text = st.session_state.rn_tokenizer.tokenize(
+                 st.session_state.search_field_value
            )

+             text_features = st.session_state.rn_model(
+                 st.session_state.search_field_value
            )

+             image_model = st.session_state.rn_image_model

+             for altered_image in image_tiles:
+                 p_image = (
+                     st.session_state.rn_image_preprocess(altered_image)
+                     .unsqueeze(0)
+                     .to(st.session_state.device)
+                 )

+                 vis_t = interpret_rn_overlapped(
+                     p_image.type(st.session_state.rn_image_model.dtype),
+                     text_features,
+                     image_model.visual,
+                     GradCAM,
+                     st.session_state.device,
+                     img_dim=img_dim,
+                 )

+                 image_visualizations.append(vis_t)

+         transform = ToPILImage()

+         vis_images = [transform(vis_t) for vis_t in image_visualizations]

+         if st.session_state.vision_mode == "cropped":
+             resized_img.paste(vis_images[0], (left, top))
+             vis_images = [resized_img]

+         if orig_img_dims[0] > orig_img_dims[1]:
+             scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
+             scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
+         else:
+             scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
+             scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]

+         if tile_behavior == "width":
+             vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
+             for x, v_img in enumerate(vis_images):
+                 vis_image.paste(v_img, (x * img_dim, 0))
+             st.session_state.activations_image = vis_image.resize(scaled_dims)

+         elif tile_behavior == "height":
+             vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
+             for y, v_img in enumerate(vis_images):
+                 vis_image.paste(v_img, (0, y * img_dim))
+             st.session_state.activations_image = vis_image.resize(scaled_dims)

+         else:
+             st.session_state.activations_image = vis_images[0].resize(scaled_dims)
+
+         image_io = BytesIO()
+         st.session_state.activations_image.save(image_io, "PNG")
+         dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode(
+             "ascii"
+         )
+
+         st.html(
+             f"""<div style="display: flex; flex-direction: column; align-items: center;">
+             <img src="{dataurl}" />
+             </div>"""
+         )

+     tokenized_text = [tok.replace("▁", "") for tok in tokenized_text if tok != "▁"]
+     tokenized_text = [
+         tok for tok in tokenized_text if tok not in ["s", "ed", "a", "the", "an", "ing"]
+     ]

    if (
        len(tokenized_text) > 1
+         and len(tokenized_text) < 25
        and st.button(
            "Calculate text importance (may take some time)",
        )
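The legacy branch above hands the GradCAM class from pytorch_grad_cam into interpret_rn_overlapped rather than instantiating a CAM object itself. The following is a stripped-down sketch of that call path outside Streamlit, reusing the objects created in init(); here `tile` stands for any PIL image tile and the query string is illustrative, so treat this as a sketch of the signatures used in the diff, not app code.

from pytorch_grad_cam.grad_cam import GradCAM
from torchvision.transforms import ToPILImage

from CLIP_Explainability.rn_cam import interpret_rn_overlapped

# rn_image_model, rn_image_preprocess, rn_model, device as created in init();
# tile is a 288x288 PIL.Image crop like the ones built in visualize_gradcam().
p_image = rn_image_preprocess(tile).unsqueeze(0).to(device)
text_features = rn_model("negative space")  # legacy multilingual text encoder

vis_t = interpret_rn_overlapped(
    p_image.type(rn_image_model.dtype),
    text_features,
    rn_image_model.visual,  # RN50x4 visual tower
    GradCAM,                # CAM implementation handed in as a class
    device,
    img_dim=288,
)

heatmap = ToPILImage()(vis_t)  # same tensor-to-image conversion the app applies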
@@ -394,17 +460,31 @@ def visualize_gradcam(viz_image_id):
        progress_bar = st.progress(0.0, text=progress_text)

        for t, tok in enumerate(tokenized_text):
+             token = tok
+
+             if st.session_state.active_model == "Legacy (multilingual ResNet)":
+                 word_rel = rn_perword_relevance(
+                     p_image,
+                     st.session_state.search_field_value,
+                     image_model,
+                     tokenize,
+                     GradCAM,
+                     st.session_state.device,
+                     token,
+                     data_only=True,
+                     img_dim=img_dim,
+                 )
+             else:
+                 word_rel = vit_perword_relevance(
+                     p_image,
+                     st.session_state.search_field_value,
+                     image_model,
+                     tokenize,
+                     st.session_state.device,
+                     token,
+                     data_only=True,
+                     img_dim=img_dim,
+                 )
            avg_score = np.mean(word_rel)
            if avg_score == 0 or np.isnan(avg_score):
                continue
@@ -429,7 +509,7 @@ def visualize_gradcam(viz_image_id):


def format_vision_mode(mode_stub):
+     return mode_stub.capitalize()


@st.dialog(" ", width="large")
@@ -469,7 +549,7 @@ st.markdown(
    unsafe_allow_html=True,
)

+ search_row = st.columns([45, 8, 8, 10, 1, 8, 20], vertical_alignment="center")
with search_row[0]:
    search_field = st.text_input(
        label="search",
@@ -483,10 +563,10 @@ with search_row[1]:
        "Search", on_click=string_search, use_container_width=True, type="primary"
    )
with search_row[2]:
+     st.markdown("**Vision mode:**")
with search_row[3]:
    st.selectbox(
+         "Vision mode",
        options=["tiled", "stretched", "cropped"],
        key="vision_mode",
        help="How to consider images that aren't square",
@@ -497,56 +577,59 @@ with search_row[3]:
with search_row[4]:
    st.empty()
with search_row[5]:
+     st.markdown("**CLIP model:**")
with search_row[6]:
+     st.selectbox(
+         "CLIP Model:",
+         options=[
+             "M-CLIP (multilingual ViT)",
+             "J-CLIP (日本語 ViT)",
+             "Legacy (multilingual ResNet)",
+         ],
        key="active_model",
        on_change=string_search,
        label_visibility="collapsed",
    )

canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
with canned_searches[0]:
    st.markdown("**Suggested searches:**")
+ if st.session_state.active_model == "J-CLIP (日本語 ViT)":
    with canned_searches[1]:
        st.button(
+             "",
            on_click=clip_search,
+             args=[""],
            use_container_width=True,
        )
    with canned_searches[2]:
+         st.button("", on_click=clip_search, args=[""], use_container_width=True)
    with canned_searches[3]:
+         st.button("", on_click=clip_search, args=[""], use_container_width=True)
    with canned_searches[4]:
        st.button(
+             "花に酔えり 羽織着て刀 さす女",
            on_click=clip_search,
+             args=["花に酔えり 羽織着て刀 さす女"],
            use_container_width=True,
        )
else:
    with canned_searches[1]:
        st.button(
+             "negative space",
            on_click=clip_search,
+             args=["negative space"],
            use_container_width=True,
        )
    with canned_searches[2]:
+         st.button("", on_click=clip_search, args=[""], use_container_width=True)
    with canned_searches[3]:
+         st.button("음각", on_click=clip_search, args=["음각"], use_container_width=True)
    with canned_searches[4]:
        st.button(
+             "αρνητικός χώρος",
            on_click=clip_search,
+             args=["αρνητικός χώρος"],
            use_container_width=True,
        )
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
+ clip==1.0
ftfy==6.2.0
multilingual_clip==1.0.10
numpy==1.26
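One packaging note on the new pin: app.py expects `clip` to be the OpenAI-style CLIP package that provides clip.load("RN50x4", device=device). If the installed distribution named clip is a different project that happens to share the name, the import will succeed but the loader will be missing, so a quick sanity check may be worth running; this check is an addition here, not part of the commit. pytorch_grad_cam is also newly imported by app.py, and whether it is pinned elsewhere in requirements.txt is not visible in this hunk.

# Sanity check: confirm the installed `clip` module exposes the OpenAI CLIP
# loader that app.py calls as clip.load("RN50x4", device=device).
import clip

assert hasattr(clip, "load"), (
    "`clip` does not provide load(); install the OpenAI CLIP package "
    "(commonly from its GitHub repository) or adjust requirements.txt."
)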