|
from math import ceil |
|
|
|
from multilingual_clip import pt_multilingual_clip |
|
import numpy as np |
|
import pandas as pd |
|
import streamlit as st |
|
import torch |
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
# Render the app with Streamlit's "wide" page layout.
st.set_page_config(layout="wide")
|
|
|
|
|
def init():
    """Load models, metadata and image features into ``st.session_state``.

    Sets the following session keys:

    * ``current_page`` -- starts at 1.
    * ``ml_model`` / ``ml_tokenizer`` -- the multilingual CLIP text encoder.
    * ``ja_model`` / ``ja_tokenizer`` -- the Japanese CLIP model (moved to
      CUDA when available).
    * ``search_image_ids`` -- empty until a search is run.
    * ``images_info`` -- ``./metadata.csv`` indexed by its ``filename`` column.
    * ``image_ids`` -- the lines of ``./images_list.txt``.
    * ``ml_image_features`` / ``ja_image_features`` -- image feature tensors
      divided by their L2 norm along the last dimension.
    """
    st.session_state.current_page = 1

    device = "cuda" if torch.cuda.is_available() else "cpu"

    ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
    ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"

    # Only the Japanese model is moved to `device`; the multilingual model
    # stays wherever `from_pretrained` places it.
    st.session_state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
        ml_model_name
    )
    st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)

    st.session_state.ja_model = AutoModel.from_pretrained(
        ja_model_name, trust_remote_code=True
    ).to(device)
    st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
        ja_model_name, trust_remote_code=True
    )

    st.session_state.search_image_ids = []

    st.session_state.images_info = pd.read_csv("./metadata.csv").set_index("filename")

    # Use a context manager so the file handle is closed deterministically.
    with open("./images_list.txt", "r", encoding="utf-8") as ids_file:
        st.session_state.image_ids = ids_file.read().strip().split("\n")

    ml_image_features = torch.from_numpy(np.load("./multilingual_features.npy"))
    ja_image_features = torch.from_numpy(np.load("./hakuhodo_features.npy"))

    # On CPU the features are converted to float32; on other devices the
    # stored dtype is kept.
    # NOTE(review): presumably the .npy files hold half-precision values -- confirm.
    if device == "cpu":
        ml_image_features = ml_image_features.float()
        ja_image_features = ja_image_features.float()
    ml_image_features = ml_image_features.to(device)
    ja_image_features = ja_image_features.to(device)

    # Keep these two assignments last: the module-level guard uses the
    # presence of these keys to decide whether init() still has to run.
    st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
        dim=-1, keepdim=True
    )
    st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
        dim=-1, keepdim=True
    )
|
|
|
|
|
# Run the expensive initialisation only while either feature tensor is still
# missing from the session state.
if not all(
    key in st.session_state
    for key in ("ml_image_features", "ja_image_features")
):
    with st.spinner("Loading models and data, please wait..."):
        init()
|
|
|
|
|
|
|
def encode_search_query(search_query, model_type):
    """Embed a text query with the selected CLIP text encoder.

    Args:
        search_query: Query text handed to the model / tokenizer.
        model_type: Label of the selected model.
            ``"M-CLIP (multiple languages)"`` selects the multilingual model;
            any other value selects the Japanese model.

    Returns:
        The text embedding divided by its L2 norm along the last dimension.
    """
    with torch.no_grad():
        if model_type == "M-CLIP (multiple languages)":
            text_encoded = st.session_state.ml_model.forward(
                search_query, st.session_state.ml_tokenizer
            )
        else:
            ja_model = st.session_state.ja_model
            t_text = st.session_state.ja_tokenizer(
                search_query, padding=True, return_tensors="pt"
            )
            # The model may have been moved to CUDA during initialisation,
            # so the token tensors must follow it to the same device.
            model_device = next(ja_model.parameters()).device
            t_text = t_text.to(model_device)
            text_encoded = ja_model.get_text_features(**t_text)

        # Both encoders share the same normalisation step.
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    return text_encoded
|
|
|
|
|
|
|
def find_best_matches(text_features, image_features, image_ids):
    """Rank every image by its similarity to a single query embedding.

    Args:
        text_features: Query embedding tensor with exactly one row
            (shape ``(1, dim)``).
        image_features: Image embedding matrix, one row per image.
        image_ids: Sequence of ids, aligned with the rows of
            ``image_features``.

    Returns:
        A list of ``[image_id, score]`` pairs covering all images, ordered
        from the highest score to the lowest. Scores are Python floats.
    """
    # The query may come from a model living on another device or using
    # another dtype than the stored image matrix; align it before the matmul.
    text_features = text_features.to(image_features)

    similarities = (image_features @ text_features.T).squeeze(1)

    # Convert once to plain Python lists rather than indexing the tensor
    # (and calling .item()) separately for every image.
    ranking = (-similarities).argsort().tolist()
    scores = similarities.tolist()

    return [[image_ids[i], scores[i]] for i in ranking]
|
|
|
|
|
def clip_search(search_query):
    """Run a CLIP text search and store the ranked image ids in session state.

    The query is mirrored into ``st.session_state.search_field_value`` when it
    differs from it (e.g. when triggered by a suggested-search button). An
    empty query leaves the previous results untouched. Otherwise the ranked
    ids are written to ``st.session_state.search_image_ids`` and the pager is
    reset to the first page.

    Args:
        search_query: Text to search for.
    """
    state = st.session_state

    if state.search_field_value != search_query:
        state.search_field_value = search_query

    model_type = state.active_model

    if not search_query:
        return

    text_features = encode_search_query(search_query, model_type)

    # Each model has its own precomputed image feature matrix.
    if model_type == "M-CLIP (multiple languages)":
        image_features = state.ml_image_features
    else:
        image_features = state.ja_image_features

    matches = find_best_matches(text_features, image_features, state.image_ids)
    state.search_image_ids = [image_id for image_id, _score in matches]

    # A fresh ranking should be viewed from its top, not from whatever page
    # the previous search was left on.
    state.current_page = 1
|
|
|
|
|
def string_search():
    """Widget callback: search for the text currently held in the search box."""
    current_query = st.session_state.search_field_value
    clip_search(current_query)
|
|
|
|
|
st.title("Explore Japanese visual aesthetics with CLIP models")

# Search bar layout: text box | Search button | spacer | label | model picker.
query_col, button_col, spacer_col, label_col, model_col = st.columns(
    [45, 10, 13, 7, 25], vertical_alignment="center"
)

# Every control re-runs the search through the same callback.
search_field = query_col.text_input(
    label="search",
    label_visibility="collapsed",
    placeholder="Type something, or click a suggested search below.",
    on_change=string_search,
    key="search_field_value",
)
button_col.button("Search", on_click=string_search, use_container_width=True)
spacer_col.empty()
label_col.markdown("**CLIP Model:**")
model_col.radio(
    "CLIP Model",
    options=["M-CLIP (multiple languages)", "J-CLIP (日本語 only)"],
    key="active_model",
    on_change=string_search,
    horizontal=True,
    label_visibility="collapsed",
)
|
|
|
# One-click example queries; the set shown depends on the active model.
if st.session_state.active_model == "M-CLIP (multiple languages)":
    suggested_queries = ["negative space", "間", "음각", "αρνητικός χώρος"]
else:
    suggested_queries = ["間", "奥", "山", "花に酔えり 羽織着て刀 さす女"]

canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="center")
with canned_searches[0]:
    st.markdown("**Suggested searches:**")

# The first column holds the label; each remaining column gets one button
# that runs the search for its own text.
for suggestion_col, query in zip(canned_searches[1:], suggested_queries):
    with suggestion_col:
        st.button(
            query,
            on_click=clip_search,
            args=[query],
            use_container_width=True,
        )
|
|
|
# Controls row: images per page | spacer | images per row | spacer | pager.
controls = st.columns([35, 5, 35, 5, 20], gap="large", vertical_alignment="center")

with controls[0]:
    im_per_pg = st.columns([30, 70], vertical_alignment="center")
    with im_per_pg[0]:
        st.markdown("**Images/page:**")
    with im_per_pg[1]:
        batch_size = st.select_slider(
            "Images/page:", range(10, 50, 10), label_visibility="collapsed"
        )

with controls[1]:
    st.empty()

with controls[2]:
    im_per_row = st.columns([30, 70], vertical_alignment="center")
    with im_per_row[0]:
        st.markdown("**Images/row:**")
    with im_per_row[1]:
        row_size = st.select_slider(
            "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
        )

# Number of pages needed to show every known image at the chosen page size.
num_batches = ceil(len(st.session_state.image_ids) / batch_size)

# Raising images/page shrinks the page count; without clamping, a page number
# kept from the smaller page size would point past the last page and slice to
# an empty grid. This must happen before the pager widget below is created.
if st.session_state.current_page > num_batches:
    st.session_state.current_page = max(1, num_batches)

with controls[3]:
    st.empty()

with controls[4]:
    pager = st.columns([40, 60], vertical_alignment="center")
    with pager[0]:
        st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
    with pager[1]:
        st.number_input(
            "Page",
            min_value=1,
            max_value=num_batches,
            step=1,
            label_visibility="collapsed",
            key="current_page",
        )
|
|
|
|
|
# Slice the current page out of the ranked results. Slicing an empty result
# list simply yields an empty page, so no special case is needed.
page_number = st.session_state.current_page
batch = st.session_state.search_image_ids[
    (page_number - 1) * batch_size : page_number * batch_size
]

# Lay the page out row by row: image i goes into column i modulo row_size.
grid = st.columns(row_size)
for position, image_id in enumerate(batch):
    # Look the metadata row up once instead of once per field.
    info = st.session_state.images_info.loc[image_id]
    permalink = info["permalink"]
    # The link label is the third "/"-separated piece of the permalink.
    # NOTE(review): assumes every permalink has at least two "/" -- confirm.
    link_text = permalink.split("/")[2]

    with grid[position % row_size]:
        # NOTE(review): metadata values are interpolated into the markup
        # unescaped -- confirm the CSV content is trusted.
        st.html(
            f"""<div style="display: flex; flex-direction: column; align-items: center">
                <img src="{info['image_url']}" style="max-width: 100%; max-height: 800px" />
                <div>{info['caption']}</div>
            </div>"""
        )
        # The wrapper is closed with </div>; the markup previously ended with
        # an opening <div>, leaving the wrapper unclosed.
        st.caption(
            f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -20px">
                <a href="{permalink}">{link_text}</a>
            </div>""",
            unsafe_allow_html=True,
        )
|
|