"""Gradio demo: style-similarity search over a precomputed embedding database.

At startup this script loads precomputed style embeddings (``embs.npz``), the
style-extractor model checkpoint, and one image URL per embedding
(``urls.txt``). It builds a FAISS HNSW index over the embeddings and serves a
Gradio UI that shows the ``n_outputs`` nearest database images for a query
image.
"""

import numpy as np
import faiss
import torch
from torchvision.transforms import (
    Compose,
    Resize,
    ToTensor,
    Normalize,
    InterpolationMode,
)
from PIL import Image
import gradio as gr

print("starting...")

# The archive is expected to hold exactly one array, with one embedding row per
# database image. Use NpzFile as a context manager so its file handle is closed.
with np.load("embs.npz") as npz:
    (ys,) = npz.values()
# FAISS operates on C-contiguous float32 data; convert once up front.
ys = np.ascontiguousarray(ys, dtype=np.float32)
print("loaded embs")

# The checkpoint is a whole pickled model object, not a state_dict. torch>=2.6
# defaults torch.load to weights_only=True, which refuses to unpickle it, so
# opt out explicitly. Only do this for a trusted, locally shipped checkpoint:
# unpickling can execute arbitrary code.
model = torch.load(
    "style-extractor-v0.2.0.ckpt",
    map_location="cpu",
    weights_only=False,
)
# The model is used for inference only, so switch to eval mode once here
# rather than on every request.
model.eval()
print("loaded extractor")

# One URL per line. Line i is the image for embedding row i.
with open("urls.txt", encoding="utf-8") as fh:
    urls = fh.read().splitlines()
print("loaded urls")

# Use an explicit check rather than ``assert`` so it still runs under -O.
if len(urls) != len(ys):
    raise ValueError(
        f"urls.txt has {len(urls)} entries but embs.npz has {len(ys)} embeddings"
    )

d = ys.shape[1]
# HNSW graph index with 32 links per node that stores the full vectors
# ("Flat"). It uses FAISS's default L2 metric, so the search returns squared
# L2 distances.
index = faiss.IndexHNSWFlat(d, 32)
print("building index")
index.add(ys)
print('index built')

# Preprocessing for the extractor:
# - resize the shorter side to 336 px (bicubic, antialiased);
# - convert to a [0, 1] tensor;
# - normalise with the ImageNet mean/std.
tf = Compose(
    [
        Resize(
            size=336,
            interpolation=InterpolationMode.BICUBIC,
            max_size=None,
            antialias=True,
        ),
        ToTensor(),
        Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ]
)


def get_emb(im: Image.Image):
    """Run the style extractor on a single PIL image.

    The image is converted to RGB first, so grayscale, RGBA and palette
    uploads match the 3-channel normalisation in ``tf``.

    Returns the model output for a batch of one. Presumably this is a
    (1, d) tensor (TODO confirm against the checkpoint).
    """
    with torch.no_grad():
        return model(tf(im.convert("RGB")).unsqueeze(0))


# Number of neighbours to display, and how many result cells go in each UI row.
n_outputs = 50
row_size = 5


def f(im):
    """Gradio callback: return one Markdown string per result slot.

    Each cell shows the index distance and the matched image. When no query
    image is supplied, or the index returns fewer than ``n_outputs`` hits, the
    remaining slots are blank. Gradio needs exactly one value per output
    component.
    """
    if im is None:
        return [""] * n_outputs
    # FAISS wants a C-contiguous float32 NumPy matrix, not a torch tensor.
    query = np.ascontiguousarray(
        get_emb(im).float().cpu().numpy(), dtype=np.float32
    )
    D, I = index.search(query, n_outputs)
    cells = [
        f"Distance: {dist:.1f}\n![]({urls[i]})"
        for dist, i in zip(D[0], I[0])
        if i >= 0  # FAISS pads missing results with label -1
    ]
    return cells + [""] * (n_outputs - len(cells))


print("preparing gradio")
with gr.Blocks() as demo:
    gr.Markdown(
        "# Style Similarity Search\n\nFind artworks with a similar style from a medium-sized database (10k artists * 30 img/artist)"
    )
    img = gr.Image(type="pil", label="Query", height=500)
    btn = gr.Button(variant="primary", value="search")
    outputs = []
    # Ceil division: full rows of ``row_size`` cells, plus a partial last row
    # for any remainder.
    n_rows = (n_outputs + row_size - 1) // row_size
    for row_idx in range(n_rows):
        with gr.Row():
            for _ in range(min(row_size, n_outputs - row_idx * row_size)):
                outputs.append(gr.Markdown(label=f"#{len(outputs) + 1}"))
    btn.click(f, img, outputs)
print("starting gradio")
demo.launch()