broadwell committed on
Commit 81fa03b
1 Parent(s): 0f31c51

Load all models

Files changed (1)
  1. CLIP_Explainability/app.py +788 -0
CLIP_Explainability/app.py ADDED
@@ -0,0 +1,788 @@
1
+ from base64 import b64encode
2
+ from io import BytesIO
3
+ from math import ceil
4
+
5
+ import clip
6
+ from multilingual_clip import legacy_multilingual_clip, pt_multilingual_clip
7
+ import numpy as np
8
+ import pandas as pd
9
+ from PIL import Image
10
+ import requests
11
+ import streamlit as st
12
+ import torch
13
+ from torchvision.transforms import ToPILImage
14
+ from transformers import AutoTokenizer, AutoModel, BertTokenizer
15
+
16
+ from CLIP_Explainability.clip_ import load, tokenize
17
+ from CLIP_Explainability.rn_cam import (
18
+ # interpret_rn,
19
+ interpret_rn_overlapped,
20
+ rn_perword_relevance,
21
+ )
22
+ from CLIP_Explainability.vit_cam import (
23
+ # interpret_vit,
24
+ vit_perword_relevance,
25
+ interpret_vit_overlapped,
26
+ )
27
+
28
+ from pytorch_grad_cam.grad_cam import GradCAM
29
+
30
+ RUN_LITE = False # If True, load the vision model for CAM visualization explainability only for M-CLIP
31
+
32
+ MAX_IMG_WIDTH = 500
33
+ MAX_IMG_HEIGHT = 800
34
+
35
+ st.set_page_config(layout="wide")
36
+
37
+
38
+ # The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. It returns the IDs of the best-matching images together with their similarity scores.
39
+ def find_best_matches(text_features, image_features, image_ids):
40
+ # Compute the similarity between the search query and each image using the Cosine similarity
41
+ similarities = (image_features @ text_features.T).squeeze(1)
42
+
43
+ # Sort the images by their similarity score
44
+ best_image_idx = (-similarities).argsort()
45
+
46
+ # Return the image IDs and similarity scores of the best matches
47
+ return [[image_ids[i], similarities[i].item()] for i in best_image_idx]
48
+
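+ # Illustrative example (not part of the app): with L2-normalized vectors the
+ # matrix product above is a cosine similarity, so a call such as
+ #   find_best_matches(text_features, image_features, ["img_a", "img_b", "img_c"])
+ # with text_features of shape [1, d] and image_features of shape [3, d] might
+ # return [["img_b", 0.31], ["img_a", 0.27], ["img_c", 0.12]], best match first.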
49
+
50
+ # The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
51
+ def encode_search_query(search_query, model_type):
52
+ with torch.no_grad():
53
+ # Encode and normalize the search query using the selected text model
54
+ if model_type == "M-CLIP (multilingual ViT)":
55
+ text_encoded = st.session_state.ml_model.forward(
56
+ search_query, st.session_state.ml_tokenizer
57
+ )
58
+ text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
59
+ elif model_type == "J-CLIP (日本語 ViT)":
60
+ t_text = st.session_state.ja_tokenizer(
61
+ search_query,
62
+ padding=True,
63
+ return_tensors="pt",
64
+ device=st.session_state.device,
65
+ )
66
+ text_encoded = st.session_state.ja_model.get_text_features(**t_text)
67
+ text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
68
+ else: # model_type == legacy
69
+ text_encoded = st.session_state.rn_model(search_query)
70
+ text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
71
+
72
+ # Retrieve the feature vector
73
+ return text_encoded.to(st.session_state.device)
74
+
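+ # Illustrative example (not part of the app): encode_search_query("negative
+ # space", "M-CLIP (multilingual ViT)") should return a normalized [1, d] text
+ # embedding on st.session_state.device, ready to be compared against the
+ # precomputed image features.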
75
+
76
+ def clip_search(search_query):
77
+ if st.session_state.search_field_value != search_query:
78
+ st.session_state.search_field_value = search_query
79
+
80
+ model_type = st.session_state.active_model
81
+
82
+ if len(search_query) >= 1:
83
+ text_features = encode_search_query(search_query, model_type)
84
+
85
+ # Compute the similarity between the description and each photo using the cosine similarity
86
+ # similarities = list((text_features @ photo_features.T).squeeze(0))
87
+
88
+ # Sort the photos by their similarity score
89
+ if model_type == "M-CLIP (multilingual ViT)":
90
+ matches = find_best_matches(
91
+ text_features,
92
+ st.session_state.ml_image_features,
93
+ st.session_state.image_ids,
94
+ )
95
+ elif model_type == "J-CLIP (日本語 ViT)":
96
+ matches = find_best_matches(
97
+ text_features,
98
+ st.session_state.ja_image_features,
99
+ st.session_state.image_ids,
100
+ )
101
+ else: # model_type == legacy
102
+ matches = find_best_matches(
103
+ text_features,
104
+ st.session_state.rn_image_features,
105
+ st.session_state.image_ids,
106
+ )
107
+
108
+ st.session_state.search_image_ids = [match[0] for match in matches]
109
+ st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
110
+
111
+
112
+ def string_search():
113
+ if "search_field_value" in st.session_state:
114
+ clip_search(st.session_state.search_field_value)
115
+
116
+
117
+ def load_image_features():
118
+ # Load the image feature vectors
119
+ if st.session_state.vision_mode == "tiled":
120
+ ml_image_features = np.load("./image_features/tiled_ml_features.npy")
121
+ ja_image_features = np.load("./image_features/tiled_ja_features.npy")
122
+ rn_image_features = np.load("./image_features/tiled_rn_features.npy")
123
+ elif st.session_state.vision_mode == "stretched":
124
+ ml_image_features = np.load("./image_features/resized_ml_features.npy")
125
+ ja_image_features = np.load("./image_features/resized_ja_features.npy")
126
+ rn_image_features = np.load("./image_features/resized_rn_features.npy")
127
+ else: # st.session_state.vision_mode == "cropped":
128
+ ml_image_features = np.load("./image_features/cropped_ml_features.npy")
129
+ ja_image_features = np.load("./image_features/cropped_ja_features.npy")
130
+ rn_image_features = np.load("./image_features/cropped_rn_features.npy")
131
+
132
+ # Convert features to Tensors: Float32 on CPU and Float16 on GPU
133
+ device = st.session_state.device
134
+ if device == "cpu":
135
+ ml_image_features = torch.from_numpy(ml_image_features).float().to(device)
136
+ ja_image_features = torch.from_numpy(ja_image_features).float().to(device)
137
+ rn_image_features = torch.from_numpy(rn_image_features).float().to(device)
138
+ else:
139
+ ml_image_features = torch.from_numpy(ml_image_features).to(device)
140
+ ja_image_features = torch.from_numpy(ja_image_features).to(device)
141
+ rn_image_features = torch.from_numpy(rn_image_features).to(device)
142
+
143
+ st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
144
+ dim=-1, keepdim=True
145
+ )
146
+ st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
147
+ dim=-1, keepdim=True
148
+ )
149
+ st.session_state.rn_image_features = rn_image_features / rn_image_features.norm(
150
+ dim=-1, keepdim=True
151
+ )
152
+
153
+ string_search()
154
+
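+ # Note (assumption based on how the features are used): each *_features.npy
+ # file is expected to hold one row per entry in images_list.txt, in the same
+ # order (shape [num_images, embedding_dim]), so that the row indices ranked by
+ # find_best_matches map back to the correct image IDs.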
155
+
156
+ def init():
157
+ st.session_state.current_page = 1
158
+
159
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
160
+ device = "cpu"
161
+
162
+ st.session_state.device = device
163
+
164
+ # Load the open CLIP models
165
+
166
+ with st.spinner("Loading models and data, please wait..."):
167
+ ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
168
+ ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
169
+
170
+ st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
171
+ ml_model_path, device=device, jit=False
172
+ )
173
+
174
+ st.session_state.ml_model = (
175
+ pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
176
+ ).to(device)
177
+ st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
178
+
179
+ ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
180
+ ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
181
+
182
+ if not RUN_LITE:
183
+ st.session_state.ja_image_model, st.session_state.ja_image_preprocess = (
184
+ load(ja_model_path, device=device, jit=False)
185
+ )
186
+
187
+ st.session_state.ja_model = AutoModel.from_pretrained(
188
+ ja_model_name, trust_remote_code=True
189
+ ).to(device)
190
+ st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
191
+ ja_model_name, trust_remote_code=True
192
+ )
193
+
194
+ if not RUN_LITE:
195
+ st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
196
+ clip.load("RN50x4", device=device)
197
+ )
198
+
199
+ st.session_state.rn_model = legacy_multilingual_clip.load_model(
200
+ "M-BERT-Base-69"
201
+ ).to(device)
202
+ st.session_state.rn_tokenizer = BertTokenizer.from_pretrained(
203
+ "bert-base-multilingual-cased"
204
+ )
205
+
206
+ # Load the image IDs
207
+ st.session_state.images_info = pd.read_csv("./metadata.csv")
208
+ st.session_state.images_info.set_index("filename", inplace=True)
209
+
210
+ with open("./images_list.txt", "r", encoding="utf-8") as images_list:
211
+ st.session_state.image_ids = list(images_list.read().strip().split("\n"))
212
+
213
+ st.session_state.active_model = "M-CLIP (multilingual ViT)"
214
+
215
+ st.session_state.vision_mode = "tiled"
216
+ st.session_state.search_image_ids = []
217
+ st.session_state.search_image_scores = {}
218
+ st.session_state.text_table_df = None
219
+
220
+ with st.spinner("Loading models and data, please wait..."):
221
+ load_image_features()
222
+
223
+
224
+ if "images_info" not in st.session_state:
225
+ init()
226
+
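+ # Streamlit re-runs this script on every user interaction, so init() is gated
+ # on a session_state key to ensure the models, metadata, and image features
+ # are only loaded once per session.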
227
+
228
+ def get_overlay_vis(image, img_dim, image_model):
229
+ orig_img_dims = image.size
230
+
231
+ ##### If the features are based on tiled image slices
232
+ tile_behavior = None
233
+
234
+ if st.session_state.vision_mode == "tiled":
235
+ scaled_dims = [img_dim, img_dim]
236
+
237
+ if orig_img_dims[0] > orig_img_dims[1]:
238
+ scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
239
+ if scale_ratio > 1:
240
+ scaled_dims = [scale_ratio * img_dim, img_dim]
241
+ tile_behavior = "width"
242
+ elif orig_img_dims[0] < orig_img_dims[1]:
243
+ scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
244
+ if scale_ratio > 1:
245
+ scaled_dims = [img_dim, scale_ratio * img_dim]
246
+ tile_behavior = "height"
247
+
248
+ resized_image = image.resize(scaled_dims, Image.LANCZOS)
249
+
250
+ if tile_behavior == "width":
251
+ image_tiles = []
252
+ for x in range(0, scale_ratio):
253
+ box = (x * img_dim, 0, (x + 1) * img_dim, img_dim)
254
+ image_tiles.append(resized_image.crop(box))
255
+
256
+ elif tile_behavior == "height":
257
+ image_tiles = []
258
+ for y in range(0, scale_ratio):
259
+ box = (0, y * img_dim, img_dim, (y + 1) * img_dim)
260
+ image_tiles.append(resized_image.crop(box))
261
+
262
+ else:
263
+ image_tiles = [resized_image]
264
+
265
+ elif st.session_state.vision_mode == "stretched":
266
+ image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]
267
+
268
+ else: # vision_mode == "cropped"
269
+ if orig_img_dims[0] > orig_img_dims[1]:
270
+ scale_factor = orig_img_dims[0] / orig_img_dims[1]
271
+ resized_img_dims = (round(scale_factor * img_dim), img_dim)
272
+ resized_img = image.resize(resized_img_dims)
273
+ elif orig_img_dims[0] < orig_img_dims[1]:
274
+ scale_factor = orig_img_dims[1] / orig_img_dims[0]
275
+ resized_img_dims = (img_dim, round(scale_factor * img_dim))
276
+ else:
277
+ resized_img_dims = (img_dim, img_dim)
278
+
279
+ resized_img = image.resize(resized_img_dims)
280
+
281
+ left = round((resized_img_dims[0] - img_dim) / 2)
282
+ top = round((resized_img_dims[1] - img_dim) / 2)
283
+ x_right = round(resized_img_dims[0] - img_dim) - left
284
+ x_bottom = round(resized_img_dims[1] - img_dim) - top
285
+ right = resized_img_dims[0] - x_right
286
+ bottom = resized_img_dims[1] - x_bottom
287
+
288
+ # Crop the center of the image
289
+ image_tiles = [resized_img.crop((left, top, right, bottom))]
290
+
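+ # At this point image_tiles holds one or more square img_dim x img_dim
+ # images: several tiles for "tiled" mode on non-square inputs, a single
+ # stretched square for "stretched" mode, and a single center crop for
+ # "cropped" mode.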
291
+ image_visualizations = []
292
+ image_features = []
293
+ image_similarities = []
294
+
295
+ if st.session_state.active_model == "M-CLIP (multilingual ViT)":
296
+ text_features = st.session_state.ml_model.forward(
297
+ st.session_state.search_field_value, st.session_state.ml_tokenizer
298
+ )
299
+
300
+ if st.session_state.device == "cpu":
301
+ text_features = text_features.float().to(st.session_state.device)
302
+ else:
303
+ text_features = text_features.to(st.session_state.device)
304
+
305
+ for altered_image in image_tiles:
306
+ p_image = (
307
+ st.session_state.ml_image_preprocess(altered_image)
308
+ .unsqueeze(0)
309
+ .to(st.session_state.device)
310
+ )
311
+
312
+ vis_t, img_feats, similarity = interpret_vit_overlapped(
313
+ p_image.type(image_model.dtype),
314
+ text_features.type(image_model.dtype),
315
+ image_model.visual,
316
+ st.session_state.device,
317
+ img_dim=img_dim,
318
+ )
319
+
320
+ image_visualizations.append(vis_t)
321
+ image_features.append(img_feats)
322
+ image_similarities.append(similarity.item())
323
+
324
+ elif st.session_state.active_model == "J-CLIP (日本語 ViT)":
325
+ t_text = st.session_state.ja_tokenizer(
326
+ st.session_state.search_field_value,
327
+ return_tensors="pt",
328
+ device=st.session_state.device,
329
+ )
330
+
331
+ text_features = st.session_state.ja_model.get_text_features(**t_text)
332
+
333
+ if st.session_state.device == "cpu":
334
+ text_features = text_features.float().to(st.session_state.device)
335
+ else:
336
+ text_features = text_features.to(st.session_state.device)
337
+
338
+ for altered_image in image_tiles:
339
+ p_image = (
340
+ st.session_state.ja_image_preprocess(altered_image)
341
+ .unsqueeze(0)
342
+ .to(st.session_state.device)
343
+ )
344
+
345
+ vis_t, img_feats, similarity = interpret_vit_overlapped(
346
+ p_image.type(image_model.dtype),
347
+ text_features.type(image_model.dtype),
348
+ image_model.visual,
349
+ st.session_state.device,
350
+ img_dim=img_dim,
351
+ )
352
+
353
+ image_visualizations.append(vis_t)
354
+ image_features.append(img_feats)
355
+ image_similarities.append(similarity.item())
356
+
357
+ else: # st.session_state.active_model == Legacy
358
+ text_features = st.session_state.rn_model(st.session_state.search_field_value)
359
+
360
+ if st.session_state.device == "cpu":
361
+ text_features = text_features.float().to(st.session_state.device)
362
+ else:
363
+ text_features = text_features.to(st.session_state.device)
364
+
365
+ for altered_image in image_tiles:
366
+ p_image = (
367
+ st.session_state.rn_image_preprocess(altered_image)
368
+ .unsqueeze(0)
369
+ .to(st.session_state.device)
370
+ )
371
+
372
+ vis_t = interpret_rn_overlapped(
373
+ p_image.type(image_model.dtype),
374
+ text_features.type(image_model.dtype),
375
+ image_model.visual,
376
+ GradCAM,
377
+ st.session_state.device,
378
+ img_dim=img_dim,
379
+ )
380
+
381
+ text_features_norm = text_features.norm(dim=-1, keepdim=True)
382
+ text_features_new = text_features / text_features_norm
383
+
384
+ image_feats = image_model.encode_image(p_image.type(image_model.dtype))
385
+ image_feats_norm = image_feats.norm(dim=-1, keepdim=True)
386
+ image_feats_new = image_feats / image_feats_norm
387
+
388
+ similarity = image_feats_new[0].dot(text_features_new[0])
389
+
390
+ image_visualizations.append(vis_t)
391
+ image_features.append(p_image)
392
+ image_similarities.append(similarity.item())
393
+
394
+ transform = ToPILImage()
395
+
396
+ vis_images = [transform(vis_t) for vis_t in image_visualizations]
397
+
398
+ if st.session_state.vision_mode == "cropped":
399
+ resized_img.paste(vis_images[0], (left, top))
400
+ vis_images = [resized_img]
401
+
402
+ if orig_img_dims[0] > orig_img_dims[1]:
403
+ scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
404
+ scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
405
+ else:
406
+ scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
407
+ scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]
408
+
409
+ if tile_behavior == "width":
410
+ vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
411
+ for x, v_img in enumerate(vis_images):
412
+ vis_image.paste(v_img, (x * img_dim, 0))
413
+ activations_image = vis_image.resize(scaled_dims)
414
+
415
+ elif tile_behavior == "height":
416
+ vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
417
+ for y, v_img in enumerate(vis_images):
418
+ vis_image.paste(v_img, (0, y * img_dim))
419
+ activations_image = vis_image.resize(scaled_dims)
420
+
421
+ else:
422
+ activations_image = vis_images[0].resize(scaled_dims)
423
+
424
+ return activations_image, image_features, np.mean(image_similarities)
425
+
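+ # get_overlay_vis returns the stitched heatmap overlay, the per-tile features
+ # (preprocessed tiles on the ResNet path), and the mean image-text similarity
+ # across tiles.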
426
+
427
+ def visualize_gradcam(image):
428
+ if "search_field_value" not in st.session_state:
429
+ return
430
+
431
+ header_cols = st.columns([80, 20], vertical_alignment="bottom")
432
+ with header_cols[0]:
433
+ st.title("Image + query details")
434
+ with header_cols[1]:
435
+ if st.button("Close"):
436
+ st.rerun()
437
+
438
+ if st.session_state.active_model == "M-CLIP (multilingual ViT)":
439
+ img_dim = 240
440
+ image_model = st.session_state.ml_image_model
441
+ # Sometimes used for token importance viz
442
+ tokenized_text = st.session_state.ml_tokenizer.tokenize(
443
+ st.session_state.search_field_value
444
+ )
445
+ elif st.session_state.active_model == "Legacy (multilingual ResNet)":
446
+ img_dim = 288
447
+ image_model = st.session_state.rn_image_model
448
+ # Sometimes used for token importance viz
449
+ tokenized_text = st.session_state.rn_tokenizer.tokenize(
450
+ st.session_state.search_field_value
451
+ )
452
+ else: # J-CLIP
453
+ img_dim = 224
454
+ image_model = st.session_state.ja_image_model
455
+ # Sometimes used for token importance viz
456
+ tokenized_text = st.session_state.ja_tokenizer.tokenize(
457
+ st.session_state.search_field_value
458
+ )
459
+
460
+ with st.spinner("Calculating..."):
461
+ # info_text = st.text("Calculating activation regions...")
462
+
463
+ activations_image, image_features, similarity_score = get_overlay_vis(
464
+ image, img_dim, image_model
465
+ )
466
+
467
+ st.markdown(
468
+ f"**Query text:** {st.session_state.search_field_value} | **Approx. image relevance:** {round(similarity_score.item(), 3)}"
469
+ )
470
+
471
+ st.image(activations_image)
472
+
473
+ # image_io = BytesIO()
474
+ # activations_image.save(image_io, "PNG")
475
+ # dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode(
476
+ # "ascii"
477
+ # )
478
+
479
+ # st.html(
480
+ # f"""<div style="display: flex; flex-direction: column; align-items: center;">
481
+ # <img src="{dataurl}" />
482
+ # </div>"""
483
+ # )
484
+
485
+ tokenized_text = [
486
+ tok.replace("▁", "").replace("#", "") for tok in tokenized_text if tok != "▁"
487
+ ]
488
+ tokenized_text = [
489
+ tok
490
+ for tok in tokenized_text
491
+ if tok
492
+ not in ["s", "ed", "a", "the", "an", "ing", "て", "に", "の", "は", "と", "た"]
493
+ ]
494
+
495
+ if (
496
+ len(tokenized_text) > 1
497
+ and len(tokenized_text) < 25
498
+ and st.button(
499
+ "Calculate text importance (may take some time)",
500
+ )
501
+ ):
502
+ scores_per_token = {}
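+ # For each image tile, a token's score is the reciprocal of the mean of the
+ # relevance map returned by the *_perword_relevance helper for that token;
+ # per-tile scores are averaged and softmax-normalized below. The exact
+ # semantics of the relevance maps are defined in CLIP_Explainability/rn_cam.py
+ # and vit_cam.py.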
503
+
504
+ progress_text = f"Processing {len(tokenized_text)} text tokens"
505
+ progress_bar = st.progress(0.0, text=progress_text)
506
+
507
+ for t, tok in enumerate(tokenized_text):
508
+ token = tok
509
+
510
+ for img_feats in image_features:
511
+ if st.session_state.active_model == "Legacy (multilingual ResNet)":
512
+ word_rel = rn_perword_relevance(
513
+ img_feats,
514
+ st.session_state.search_field_value,
515
+ image_model,
516
+ tokenize,
517
+ GradCAM,
518
+ st.session_state.device,
519
+ token,
520
+ data_only=True,
521
+ img_dim=img_dim,
522
+ )
523
+ else:
524
+ word_rel = vit_perword_relevance(
525
+ img_feats,
526
+ st.session_state.search_field_value,
527
+ image_model,
528
+ tokenize,
529
+ st.session_state.device,
530
+ token,
531
+ img_dim=img_dim,
532
+ )
533
+ avg_score = np.mean(word_rel)
534
+ if avg_score == 0 or np.isnan(avg_score):
535
+ continue
536
+
537
+ if token not in scores_per_token:
538
+ scores_per_token[token] = [1 / avg_score]
539
+ else:
540
+ scores_per_token[token].append(1 / avg_score)
541
+
542
+ progress_bar.progress(
543
+ (t + 1) / len(tokenized_text),
544
+ text=f"Processing token {t+1} of {len(tokenized_text)}",
545
+ )
546
+ progress_bar.empty()
547
+
548
+ avg_scores_per_token = [
549
+ np.mean(scores_per_token[tok]) for tok in list(scores_per_token.keys())
550
+ ]
551
+
552
+ normed_scores = torch.softmax(torch.tensor(avg_scores_per_token), dim=0)
553
+
554
+ token_scores = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
555
+ st.session_state.text_table_df = pd.DataFrame(
556
+ {"token": list(scores_per_token.keys()), "importance": token_scores}
557
+ )
558
+
559
+ st.markdown("**Importance of each text token to relevance score**")
560
+ st.table(st.session_state.text_table_df)
561
+
562
+
563
+ @st.dialog(" ", width="large")
564
+ def image_modal(image):
565
+ visualize_gradcam(image)
566
+
567
+
568
+ def vis_known_image(vis_image_id):
569
+ image_url = st.session_state.images_info.loc[vis_image_id]["image_url"]
570
+ image_response = requests.get(image_url)
571
+ image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF", "PNG"])
572
+ image = image.convert("RGB")
573
+
574
+ image_modal(image)
575
+
576
+
577
+ def vis_uploaded_image():
578
+ uploaded_file = st.session_state.uploaded_image
579
+ if uploaded_file is not None:
580
+ # To read file as bytes:
581
+ bytes_data = uploaded_file.getvalue()
582
+ image = Image.open(BytesIO(bytes_data), formats=["JPEG", "GIF", "PNG"])
583
+ image = image.convert("RGB")
584
+
585
+ image_modal(image)
586
+
587
+
588
+ def format_vision_mode(mode_stub):
589
+ return mode_stub.capitalize()
590
+
591
+
592
+ st.title("Explore Japanese visual aesthetics with CLIP models")
593
+
594
+ st.markdown(
595
+ """
596
+ <style>
597
+ [data-testid=stImageCaption] {
598
+ padding: 0 0 0 0;
599
+ }
600
+ [data-testid=stVerticalBlockBorderWrapper] {
601
+ line-height: 1.2;
602
+ }
603
+ [data-testid=stVerticalBlock] {
604
+ gap: .75rem;
605
+ }
606
+ [data-testid=baseButton-secondary] {
607
+ min-height: 1rem;
608
+ padding: 0 0.75rem;
609
+ margin: 0 0 1rem 0;
610
+ }
611
+ div[aria-label="dialog"]>button[aria-label="Close"] {
612
+ display: none;
613
+ }
614
+ [data-testid=stFullScreenFrame] {
615
+ display: flex;
616
+ flex-direction: column;
617
+ align-items: center;
618
+ }
619
+ </style>
620
+ """,
621
+ unsafe_allow_html=True,
622
+ )
623
+
624
+ search_row = st.columns([45, 8, 8, 10, 1, 8, 20], vertical_alignment="center")
625
+ with search_row[0]:
626
+ search_field = st.text_input(
627
+ label="search",
628
+ label_visibility="collapsed",
629
+ placeholder="Type something, or click a suggested search below.",
630
+ on_change=string_search,
631
+ key="search_field_value",
632
+ )
633
+ with search_row[1]:
634
+ st.button(
635
+ "Search", on_click=string_search, use_container_width=True, type="primary"
636
+ )
637
+ with search_row[2]:
638
+ st.markdown("**Vision mode:**")
639
+ with search_row[3]:
640
+ st.selectbox(
641
+ "Vision mode",
642
+ options=["tiled", "stretched", "cropped"],
643
+ key="vision_mode",
644
+ help="How to consider images that aren't square",
645
+ on_change=load_image_features,
646
+ format_func=format_vision_mode,
647
+ label_visibility="collapsed",
648
+ )
649
+ with search_row[4]:
650
+ st.empty()
651
+ with search_row[5]:
652
+ st.markdown("**CLIP model:**")
653
+ with search_row[6]:
654
+ st.selectbox(
655
+ "CLIP Model:",
656
+ options=[
657
+ "M-CLIP (multilingual ViT)",
658
+ "J-CLIP (日本語 ViT)",
659
+ "Legacy (multilingual ResNet)",
660
+ ],
661
+ key="active_model",
662
+ on_change=string_search,
663
+ label_visibility="collapsed",
664
+ )
665
+
666
+ canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
667
+ with canned_searches[0]:
668
+ st.markdown("**Suggested searches:**")
669
+ if st.session_state.active_model == "J-CLIP (日本語 ViT)":
670
+ with canned_searches[1]:
671
+ st.button(
672
+ "間",
673
+ on_click=clip_search,
674
+ args=["間"],
675
+ use_container_width=True,
676
+ )
677
+ with canned_searches[2]:
678
+ st.button("奥", on_click=clip_search, args=["奥"], use_container_width=True)
679
+ with canned_searches[3]:
680
+ st.button("山", on_click=clip_search, args=["山"], use_container_width=True)
681
+ with canned_searches[4]:
682
+ st.button(
683
+ "花に酔えり 羽織着て刀 さす女",
684
+ on_click=clip_search,
685
+ args=["花に酔えり 羽織着て刀 さす女"],
686
+ use_container_width=True,
687
+ )
688
+ else:
689
+ with canned_searches[1]:
690
+ st.button(
691
+ "negative space",
692
+ on_click=clip_search,
693
+ args=["negative space"],
694
+ use_container_width=True,
695
+ )
696
+ with canned_searches[2]:
697
+ st.button("間", on_click=clip_search, args=["間"], use_container_width=True)
698
+ with canned_searches[3]:
699
+ st.button("음각", on_click=clip_search, args=["음각"], use_container_width=True)
700
+ with canned_searches[4]:
701
+ st.button(
702
+ "αρνητικός χώρος",
703
+ on_click=clip_search,
704
+ args=["αρνητικός χώρος"],
705
+ use_container_width=True,
706
+ )
707
+
708
+ controls = st.columns([25, 25, 20, 35], gap="large", vertical_alignment="center")
709
+ with controls[0]:
710
+ im_per_pg = st.columns([30, 70], vertical_alignment="center")
711
+ with im_per_pg[0]:
712
+ st.markdown("**Images/page:**")
713
+ with im_per_pg[1]:
714
+ batch_size = st.select_slider(
715
+ "Images/page:", range(10, 50, 10), label_visibility="collapsed"
716
+ )
717
+ with controls[1]:
718
+ im_per_row = st.columns([30, 70], vertical_alignment="center")
719
+ with im_per_row[0]:
720
+ st.markdown("**Images/row:**")
721
+ with im_per_row[1]:
722
+ row_size = st.select_slider(
723
+ "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
724
+ )
725
+ num_batches = ceil(len(st.session_state.image_ids) / batch_size)
726
+ with controls[2]:
727
+ pager = st.columns([40, 60], vertical_alignment="center")
728
+ with pager[0]:
729
+ st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
730
+ with pager[1]:
731
+ st.number_input(
732
+ "Page",
733
+ min_value=1,
734
+ max_value=num_batches,
735
+ step=1,
736
+ label_visibility="collapsed",
737
+ key="current_page",
738
+ )
739
+ with controls[3]:
740
+ st.file_uploader(
741
+ "Upload an image",
742
+ type=["jpg", "jpeg", "gif", "png"],
743
+ key="uploaded_image",
744
+ label_visibility="collapsed",
745
+ on_change=vis_uploaded_image,
746
+ )
747
+
748
+
749
+ if len(st.session_state.search_image_ids) == 0:
750
+ batch = []
751
+ else:
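+ # Show only the slice of the ranked search results that belongs to the
+ # current page.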
752
+ batch = st.session_state.search_image_ids[
753
+ (st.session_state.current_page - 1) * batch_size : st.session_state.current_page
754
+ * batch_size
755
+ ]
756
+
757
+ grid = st.columns(row_size)
758
+ col = 0
759
+ for image_id in batch:
760
+ with grid[col]:
761
+ link_text = st.session_state.images_info.loc[image_id]["permalink"].split("/")[
762
+ 2
763
+ ]
764
+ # st.image(
765
+ # st.session_state.images_info.loc[image_id]["image_url"],
766
+ # caption=st.session_state.images_info.loc[image_id]["caption"],
767
+ # )
768
+ st.html(
769
+ f"""<div style="display: flex; flex-direction: column; align-items: center">
770
+ <img src="{st.session_state.images_info.loc[image_id]['image_url']}" style="max-width: 100%; max-height: {MAX_IMG_HEIGHT}px" />
771
+ <div>{st.session_state.images_info.loc[image_id]['caption']} <b>[{round(st.session_state.search_image_scores[image_id], 3)}]</b></div>
772
+ </div>"""
773
+ )
774
+ st.caption(
775
+ f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -12px">
776
+ <a href="{st.session_state.images_info.loc[image_id]['permalink']}">{link_text}</a>
777
+ </div>""",
778
+ unsafe_allow_html=True,
779
+ )
780
+ if not RUN_LITE or st.session_state.active_model == "M-CLIP (multilingual ViT)":
781
+ st.button(
782
+ "Explain this",
783
+ on_click=vis_known_image,
784
+ args=[image_id],
785
+ use_container_width=True,
786
+ key=image_id,
787
+ )
788
+ col = (col + 1) % row_size