broadwell committed on
Commit 3e19435
1 Parent(s): 6f2e4d3

Delete CLIP_Explainability/app.py

Files changed (1)
  1. CLIP_Explainability/app.py +0 -711
CLIP_Explainability/app.py DELETED
@@ -1,711 +0,0 @@
- from base64 import b64encode
- from io import BytesIO
- from math import ceil
-
- import clip
- from multilingual_clip import legacy_multilingual_clip, pt_multilingual_clip
- import numpy as np
- import pandas as pd
- from PIL import Image
- import requests
- import streamlit as st
- import torch
- from torchvision.transforms import ToPILImage
- from transformers import AutoTokenizer, AutoModel, BertTokenizer
-
- from CLIP_Explainability.clip_ import load, tokenize
- from CLIP_Explainability.rn_cam import (
-     # interpret_rn,
-     interpret_rn_overlapped,
-     rn_perword_relevance,
- )
- from CLIP_Explainability.vit_cam import (
-     # interpret_vit,
-     vit_perword_relevance,
-     interpret_vit_overlapped,
- )
-
- from pytorch_grad_cam.grad_cam import GradCAM
-
- MAX_IMG_WIDTH = 500
- MAX_IMG_HEIGHT = 800
-
- st.set_page_config(layout="wide")
-
-
- # The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. The function returns the IDs of the best matching images.
- def find_best_matches(text_features, image_features, image_ids):
-     # Compute the similarity between the search query and each image using the Cosine similarity
-     similarities = (image_features @ text_features.T).squeeze(1)
-
-     # Sort the images by their similarity score
-     best_image_idx = (-similarities).argsort()
-
-     # Return the image IDs of the best matches
-     return [[image_ids[i], similarities[i].item()] for i in best_image_idx]
-
-
- # The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
- def encode_search_query(search_query, model_type):
-     with torch.no_grad():
-         # Encode and normalize the search query using the multilingual model
-         if model_type == "M-CLIP (multilingual ViT)":
-             text_encoded = st.session_state.ml_model.forward(
-                 search_query, st.session_state.ml_tokenizer
-             )
-             text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-         elif model_type == "J-CLIP (日本語 ViT)":
-             t_text = st.session_state.ja_tokenizer(
-                 search_query, padding=True, return_tensors="pt"
-             )
-             text_encoded = st.session_state.ja_model.get_text_features(**t_text)
-             text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-         else:  # model_type == legacy
-             text_encoded = st.session_state.rn_model(search_query)
-             text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-
-     # Retrieve the feature vector
-     return text_encoded
-
-
- def clip_search(search_query):
-     if st.session_state.search_field_value != search_query:
-         st.session_state.search_field_value = search_query
-
-     model_type = st.session_state.active_model
-
-     if len(search_query) >= 1:
-         text_features = encode_search_query(search_query, model_type)
-
-         # Compute the similarity between the description and each photo using the Cosine similarity
-         # similarities = list((text_features @ photo_features.T).squeeze(0))
-
-         # Sort the photos by their similarity score
-         if model_type == "M-CLIP (multilingual ViT)":
-             matches = find_best_matches(
-                 text_features,
-                 st.session_state.ml_image_features,
-                 st.session_state.image_ids,
-             )
-         elif model_type == "J-CLIP (日本語 ViT)":
-             matches = find_best_matches(
-                 text_features,
-                 st.session_state.ja_image_features,
-                 st.session_state.image_ids,
-             )
-         else:  # model_type == legacy
-             matches = find_best_matches(
-                 text_features,
-                 st.session_state.rn_image_features,
-                 st.session_state.image_ids,
-             )
-
-         st.session_state.search_image_ids = [match[0] for match in matches]
-         st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
-
-
- def string_search():
-     if "search_field_value" in st.session_state:
-         clip_search(st.session_state.search_field_value)
-
-
- def load_image_features():
-     # Load the image feature vectors
-     if st.session_state.vision_mode == "tiled":
-         ml_image_features = np.load("./image_features/tiled_ml_features.npy")
-         ja_image_features = np.load("./image_features/tiled_ja_features.npy")
-         rn_image_features = np.load("./image_features/tiled_rn_features.npy")
-     elif st.session_state.vision_mode == "stretched":
-         ml_image_features = np.load("./image_features/resized_ml_features.npy")
-         ja_image_features = np.load("./image_features/resized_ja_features.npy")
-         rn_image_features = np.load("./image_features/resized_rn_features.npy")
-     else:  # st.session_state.vision_mode == "cropped":
-         ml_image_features = np.load("./image_features/cropped_ml_features.npy")
-         ja_image_features = np.load("./image_features/cropped_ja_features.npy")
-         rn_image_features = np.load("./image_features/cropped_rn_features.npy")
-
-     # Convert features to Tensors: Float32 on CPU and Float16 on GPU
-     device = st.session_state.device
-     if device == "cpu":
-         ml_image_features = torch.from_numpy(ml_image_features).float().to(device)
-         ja_image_features = torch.from_numpy(ja_image_features).float().to(device)
-         rn_image_features = torch.from_numpy(rn_image_features).float().to(device)
-     else:
-         ml_image_features = torch.from_numpy(ml_image_features).to(device)
-         ja_image_features = torch.from_numpy(ja_image_features).to(device)
-         rn_image_features = torch.from_numpy(rn_image_features).to(device)
-
-     st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
-         dim=-1, keepdim=True
-     )
-     st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
-         dim=-1, keepdim=True
-     )
-     st.session_state.rn_image_features = rn_image_features / rn_image_features.norm(
-         dim=-1, keepdim=True
-     )
-
-     string_search()
-
-
- def init():
-     st.session_state.current_page = 1
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     st.session_state.device = device
-
-     # Load the open CLIP models
-
-     with st.spinner("Loading models and data, please wait..."):
-         ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
-         ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
-
-         st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
-             ml_model_path, device=device, jit=False
-         )
-
-         st.session_state.ml_model = (
-             pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
-         )
-         st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
-
-         ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
-         ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
-
-         st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
-             ja_model_path, device=device, jit=False
-         )
-
-         st.session_state.ja_model = AutoModel.from_pretrained(
-             ja_model_name, trust_remote_code=True
-         ).to(device)
-         st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
-             ja_model_name, trust_remote_code=True
-         )
-
-         st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
-             clip.load("RN50x4", device=device)
-         )
-
-         st.session_state.rn_model = legacy_multilingual_clip.load_model(
-             "M-BERT-Base-69"
-         )
-         st.session_state.rn_tokenizer = BertTokenizer.from_pretrained(
-             "bert-base-multilingual-cased"
-         )
-
-         # Load the image IDs
-         st.session_state.images_info = pd.read_csv("./metadata.csv")
-         st.session_state.images_info.set_index("filename", inplace=True)
-
-         with open("./images_list.txt", "r", encoding="utf-8") as images_list:
-             st.session_state.image_ids = list(images_list.read().strip().split("\n"))
-
-     st.session_state.active_model = "M-CLIP (multilingual ViT)"
-
-     st.session_state.vision_mode = "tiled"
-     st.session_state.search_image_ids = []
-     st.session_state.search_image_scores = {}
-     st.session_state.activations_image = None
-     st.session_state.text_table_df = None
-
-     with st.spinner("Loading models and data, please wait..."):
-         load_image_features()
-
-
- if "images_info" not in st.session_state:
-     init()
-
-
- def visualize_gradcam(viz_image_id):
-     if "search_field_value" not in st.session_state:
-         return
-
-     header_cols = st.columns([80, 20], vertical_alignment="bottom")
-     with header_cols[0]:
-         st.title("Image + query details")
-     with header_cols[1]:
-         if st.button("Close"):
-             st.rerun()
-
-     st.markdown(
-         f"**Query text:** {st.session_state.search_field_value} | **Image relevance:** {round(st.session_state.search_image_scores[viz_image_id], 3)}"
-     )
-
-     with st.spinner("Calculating..."):
-         # info_text = st.text("Calculating activation regions...")
-
-         image_url = st.session_state.images_info.loc[viz_image_id]["image_url"]
-         image_response = requests.get(image_url)
-         image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF"])
-         image = image.convert("RGB")
-
-         img_dim = 224
-         if st.session_state.active_model == "M-CLIP (multilingual ViT)":
-             img_dim = 240
-         elif st.session_state.active_model == "Legacy (multilingual ResNet)":
-             img_dim = 288
-
-         orig_img_dims = image.size
-
-         ##### If the features are based on tiled image slices
-         tile_behavior = None
-
-         if st.session_state.vision_mode == "tiled":
-             scaled_dims = [img_dim, img_dim]
-
-             if orig_img_dims[0] > orig_img_dims[1]:
-                 scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
-                 if scale_ratio > 1:
-                     scaled_dims = [scale_ratio * img_dim, img_dim]
-                     tile_behavior = "width"
-             elif orig_img_dims[0] < orig_img_dims[1]:
-                 scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
-                 if scale_ratio > 1:
-                     scaled_dims = [img_dim, scale_ratio * img_dim]
-                     tile_behavior = "height"
-
-             resized_image = image.resize(scaled_dims, Image.LANCZOS)
-
-             if tile_behavior == "width":
-                 image_tiles = []
-                 for x in range(0, scale_ratio):
-                     box = (x * img_dim, 0, (x + 1) * img_dim, img_dim)
-                     image_tiles.append(resized_image.crop(box))
-
-             elif tile_behavior == "height":
-                 image_tiles = []
-                 for y in range(0, scale_ratio):
-                     box = (0, y * img_dim, img_dim, (y + 1) * img_dim)
-                     image_tiles.append(resized_image.crop(box))
-
-             else:
-                 image_tiles = [resized_image]
-
-         elif st.session_state.vision_mode == "stretched":
-             image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]
-
-         else:  # vision_mode == "cropped"
-             if orig_img_dims[0] > orig_img_dims[1]:
-                 scale_factor = orig_img_dims[0] / orig_img_dims[1]
-                 resized_img_dims = (round(scale_factor * img_dim), img_dim)
-                 resized_img = image.resize(resized_img_dims)
-             elif orig_img_dims[0] < orig_img_dims[1]:
-                 scale_factor = orig_img_dims[1] / orig_img_dims[0]
-                 resized_img_dims = (img_dim, round(scale_factor * img_dim))
-             else:
-                 resized_img_dims = (img_dim, img_dim)
-
-             resized_img = image.resize(resized_img_dims)
-
-             left = round((resized_img_dims[0] - img_dim) / 2)
-             top = round((resized_img_dims[1] - img_dim) / 2)
-             x_right = round(resized_img_dims[0] - img_dim) - left
-             x_bottom = round(resized_img_dims[1] - img_dim) - top
-             right = resized_img_dims[0] - x_right
-             bottom = resized_img_dims[1] - x_bottom
-
-             # Crop the center of the image
-             image_tiles = [resized_img.crop((left, top, right, bottom))]
-
-         image_visualizations = []
-
-         if st.session_state.active_model == "M-CLIP (multilingual ViT)":
-             # Sometimes used for token importance viz
-             tokenized_text = st.session_state.ml_tokenizer.tokenize(
-                 st.session_state.search_field_value
-             )
-
-             text_features = st.session_state.ml_model.forward(
-                 st.session_state.search_field_value, st.session_state.ml_tokenizer
-             )
-
-             image_model = st.session_state.ml_image_model
-
-             for altered_image in image_tiles:
-                 p_image = (
-                     st.session_state.ml_image_preprocess(altered_image)
-                     .unsqueeze(0)
-                     .to(st.session_state.device)
-                 )
-
-                 vis_t = interpret_vit_overlapped(
-                     p_image.type(st.session_state.ml_image_model.dtype),
-                     text_features,
-                     image_model.visual,
-                     st.session_state.device,
-                     img_dim=img_dim,
-                 )
-
-                 image_visualizations.append(vis_t)
-
-         elif st.session_state.active_model == "J-CLIP (日本語 ViT)":
-             # Sometimes used for token importance viz
-             tokenized_text = st.session_state.ja_tokenizer.tokenize(
-                 st.session_state.search_field_value
-             )
-
-             t_text = st.session_state.ja_tokenizer(
-                 st.session_state.search_field_value, return_tensors="pt"
-             )
-             text_features = st.session_state.ja_model.get_text_features(**t_text)
-
-             image_model = st.session_state.ja_image_model
-
-             for altered_image in image_tiles:
-                 p_image = (
-                     st.session_state.ja_image_preprocess(altered_image)
-                     .unsqueeze(0)
-                     .to(st.session_state.device)
-                 )
-
-                 vis_t = interpret_vit_overlapped(
-                     p_image.type(st.session_state.ja_image_model.dtype),
-                     text_features,
-                     image_model.visual,
-                     st.session_state.device,
-                     img_dim=img_dim,
-                 )
-
-                 image_visualizations.append(vis_t)
-
-         else:  # st.session_state.active_model == Legacy
-             # Sometimes used for token importance viz
-             tokenized_text = st.session_state.rn_tokenizer.tokenize(
-                 st.session_state.search_field_value
-             )
-
-             text_features = st.session_state.rn_model(
-                 st.session_state.search_field_value
-             )
-
-             image_model = st.session_state.rn_image_model
-
-             for altered_image in image_tiles:
-                 p_image = (
-                     st.session_state.rn_image_preprocess(altered_image)
-                     .unsqueeze(0)
-                     .to(st.session_state.device)
-                 )
-
-                 vis_t = interpret_rn_overlapped(
-                     p_image.type(st.session_state.rn_image_model.dtype),
-                     text_features,
-                     image_model.visual,
-                     GradCAM,
-                     st.session_state.device,
-                     img_dim=img_dim,
-                 )
-
-                 image_visualizations.append(vis_t)
-
-         transform = ToPILImage()
-
-         vis_images = [transform(vis_t) for vis_t in image_visualizations]
-
-         if st.session_state.vision_mode == "cropped":
-             resized_img.paste(vis_images[0], (left, top))
-             vis_images = [resized_img]
-
-         if orig_img_dims[0] > orig_img_dims[1]:
-             scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
-             scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
-         else:
-             scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
-             scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]
-
-         if tile_behavior == "width":
-             vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
-             for x, v_img in enumerate(vis_images):
-                 vis_image.paste(v_img, (x * img_dim, 0))
-             st.session_state.activations_image = vis_image.resize(scaled_dims)
-
-         elif tile_behavior == "height":
-             vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
-             for y, v_img in enumerate(vis_images):
-                 vis_image.paste(v_img, (0, y * img_dim))
-             st.session_state.activations_image = vis_image.resize(scaled_dims)
-
-         else:
-             st.session_state.activations_image = vis_images[0].resize(scaled_dims)
-
-         image_io = BytesIO()
-         st.session_state.activations_image.save(image_io, "PNG")
-         dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode(
-             "ascii"
-         )
-
-         st.html(
-             f"""<div style="display: flex; flex-direction: column; align-items: center;">
-                 <img src="{dataurl}" />
-             </div>"""
-         )
-
-     tokenized_text = [tok.replace("▁", "") for tok in tokenized_text if tok != "▁"]
-     tokenized_text = [
-         tok for tok in tokenized_text if tok not in ["s", "ed", "a", "the", "an", "ing"]
-     ]
-
-     if (
-         len(tokenized_text) > 1
-         and len(tokenized_text) < 25
-         and st.button(
-             "Calculate text importance (may take some time)",
-         )
-     ):
-         search_tokens = []
-         token_scores = []
-
-         progress_text = f"Processing {len(tokenized_text)} text tokens"
-         progress_bar = st.progress(0.0, text=progress_text)
-
-         for t, tok in enumerate(tokenized_text):
-             token = tok
-
-             if st.session_state.active_model == "Legacy (multilingual ResNet)":
-                 word_rel = rn_perword_relevance(
-                     p_image,
-                     st.session_state.search_field_value,
-                     image_model,
-                     tokenize,
-                     GradCAM,
-                     st.session_state.device,
-                     token,
-                     data_only=True,
-                     img_dim=img_dim,
-                 )
-             else:
-                 word_rel = vit_perword_relevance(
-                     p_image,
-                     st.session_state.search_field_value,
-                     image_model,
-                     tokenize,
-                     st.session_state.device,
-                     token,
-                     data_only=True,
-                     img_dim=img_dim,
-                 )
-             avg_score = np.mean(word_rel)
-             if avg_score == 0 or np.isnan(avg_score):
-                 continue
-             search_tokens.append(token)
-             token_scores.append(1 / avg_score)
-
-             progress_bar.progress(
-                 (t + 1) / len(tokenized_text),
-                 text=f"Processing token {t+1} of {len(tokenized_text)}",
-             )
-         progress_bar.empty()
-
-         normed_scores = torch.softmax(torch.tensor(token_scores), dim=0)
-
-         token_scores = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
-         st.session_state.text_table_df = pd.DataFrame(
-             {"token": search_tokens, "importance": token_scores}
-         )
-
-     st.markdown("**Importance of each text token to relevance score**")
-     st.table(st.session_state.text_table_df)
-
-
- def format_vision_mode(mode_stub):
-     return mode_stub.capitalize()
-
-
- @st.dialog(" ", width="large")
- def image_modal(vis_image_id):
-     visualize_gradcam(vis_image_id)
-
-
- st.title("Explore Japanese visual aesthetics with CLIP models")
-
- st.markdown(
-     """
-     <style>
-     [data-testid=stImageCaption] {
-         padding: 0 0 0 0;
-     }
-     [data-testid=stVerticalBlockBorderWrapper] {
-         line-height: 1.2;
-     }
-     [data-testid=stVerticalBlock] {
-         gap: .75rem;
-     }
-     [data-testid=baseButton-secondary] {
-         min-height: 1rem;
-         padding: 0 0.75rem;
-         margin: 0 0 1rem 0;
-     }
-     div[aria-label="dialog"]>button[aria-label="Close"] {
-         display: none;
-     }
-     [data-testid=stFullScreenFrame] {
-         display: flex;
-         flex-direction: column;
-         align-items: center;
-     }
-     </style>
-     """,
-     unsafe_allow_html=True,
- )
-
- search_row = st.columns([45, 8, 8, 10, 1, 8, 20], vertical_alignment="center")
- with search_row[0]:
-     search_field = st.text_input(
-         label="search",
-         label_visibility="collapsed",
-         placeholder="Type something, or click a suggested search below.",
-         on_change=string_search,
-         key="search_field_value",
-     )
- with search_row[1]:
-     st.button(
-         "Search", on_click=string_search, use_container_width=True, type="primary"
-     )
- with search_row[2]:
-     st.markdown("**Vision mode:**")
- with search_row[3]:
-     st.selectbox(
-         "Vision mode",
-         options=["tiled", "stretched", "cropped"],
-         key="vision_mode",
-         help="How to consider images that aren't square",
-         on_change=load_image_features,
-         format_func=format_vision_mode,
-         label_visibility="collapsed",
-     )
- with search_row[4]:
-     st.empty()
- with search_row[5]:
-     st.markdown("**CLIP model:**")
- with search_row[6]:
-     st.selectbox(
-         "CLIP Model:",
-         options=[
-             "M-CLIP (multilingual ViT)",
-             "J-CLIP (日本語 ViT)",
-             "Legacy (multilingual ResNet)",
-         ],
-         key="active_model",
-         on_change=string_search,
-         label_visibility="collapsed",
-     )
-
- canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
- with canned_searches[0]:
-     st.markdown("**Suggested searches:**")
- if st.session_state.active_model == "J-CLIP (日本語 ViT)":
-     with canned_searches[1]:
-         st.button(
-             "間",
-             on_click=clip_search,
-             args=["間"],
-             use_container_width=True,
-         )
-     with canned_searches[2]:
-         st.button("奥", on_click=clip_search, args=["奥"], use_container_width=True)
-     with canned_searches[3]:
-         st.button("山", on_click=clip_search, args=["山"], use_container_width=True)
-     with canned_searches[4]:
-         st.button(
-             "花に酔えり 羽織着て刀 さす女",
-             on_click=clip_search,
-             args=["花に酔えり 羽織着て刀 さす女"],
-             use_container_width=True,
-         )
- else:
-     with canned_searches[1]:
-         st.button(
-             "negative space",
-             on_click=clip_search,
-             args=["negative space"],
-             use_container_width=True,
-         )
-     with canned_searches[2]:
-         st.button("間", on_click=clip_search, args=["間"], use_container_width=True)
-     with canned_searches[3]:
-         st.button("음각", on_click=clip_search, args=["음각"], use_container_width=True)
-     with canned_searches[4]:
-         st.button(
-             "αρνητικός χώρος",
-             on_click=clip_search,
-             args=["αρνητικός χώρος"],
-             use_container_width=True,
-         )
-
- controls = st.columns([35, 5, 35, 5, 20], gap="large", vertical_alignment="center")
- with controls[0]:
-     im_per_pg = st.columns([30, 70], vertical_alignment="center")
-     with im_per_pg[0]:
-         st.markdown("**Images/page:**")
-     with im_per_pg[1]:
-         batch_size = st.select_slider(
-             "Images/page:", range(10, 50, 10), label_visibility="collapsed"
-         )
- with controls[1]:
-     st.empty()
- with controls[2]:
-     im_per_row = st.columns([30, 70], vertical_alignment="center")
-     with im_per_row[0]:
-         st.markdown("**Images/row:**")
-     with im_per_row[1]:
-         row_size = st.select_slider(
-             "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
-         )
- num_batches = ceil(len(st.session_state.image_ids) / batch_size)
- with controls[3]:
-     st.empty()
- with controls[4]:
-     pager = st.columns([40, 60], vertical_alignment="center")
-     with pager[0]:
-         st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
-     with pager[1]:
-         st.number_input(
-             "Page",
-             min_value=1,
-             max_value=num_batches,
-             step=1,
-             label_visibility="collapsed",
-             key="current_page",
-         )
-
-
- if len(st.session_state.search_image_ids) == 0:
-     batch = []
- else:
-     batch = st.session_state.search_image_ids[
-         (st.session_state.current_page - 1) * batch_size : st.session_state.current_page
-         * batch_size
-     ]
-
- grid = st.columns(row_size)
- col = 0
- for image_id in batch:
-     with grid[col]:
-         link_text = st.session_state.images_info.loc[image_id]["permalink"].split("/")[
-             2
-         ]
-         # st.image(
-         #     st.session_state.images_info.loc[image_id]["image_url"],
-         #     caption=st.session_state.images_info.loc[image_id]["caption"],
-         # )
-         st.html(
-             f"""<div style="display: flex; flex-direction: column; align-items: center">
-                 <img src="{st.session_state.images_info.loc[image_id]['image_url']}" style="max-width: 100%; max-height: {MAX_IMG_HEIGHT}px" />
-                 <div>{st.session_state.images_info.loc[image_id]['caption']} <b>[{round(st.session_state.search_image_scores[image_id], 3)}]</b></div>
-             </div>"""
-         )
-         st.caption(
-             f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -12px">
-                 <a href="{st.session_state.images_info.loc[image_id]['permalink']}">{link_text}</a>
-             <div>""",
-             unsafe_allow_html=True,
-         )
-         st.button(
-             "Explain this",
-             on_click=image_modal,
-             args=[image_id],
-             use_container_width=True,
-             key=image_id,
-         )
-     col = (col + 1) % row_size