"""Gradio demo: style-similarity search over a precomputed embedding database.

At import time this script loads the database embeddings (``embs.npz``), the
style-extractor checkpoint, and the image URLs (``urls.txt``, one per
embedding row). It then builds an in-memory FAISS HNSW index and defines a
Gradio UI that embeds a query image and shows its ``n_outputs`` nearest
database images.
"""

import numpy as np
import faiss
import torch
from torchvision.transforms import (
    Compose,
    Resize,
    ToTensor,
    Normalize,
    InterpolationMode,
    CenterCrop,
)
from PIL import Image
import gradio as gr

print("starting...")

# The archive must hold exactly one array (the unpacking below enforces
# that). The context manager closes the underlying zip file once the array
# has been read.
with np.load("embs.npz") as _npz:
    (ys,) = _npz.values()
# FAISS works on contiguous float32 data; convert once, up front.
ys = np.ascontiguousarray(ys, dtype=np.float32)
print("loaded embs")

# The checkpoint is used directly as a callable module (see get_emb), i.e. a
# pickled object rather than a state_dict. PyTorch >= 2.6 defaults to
# weights_only=True, which refuses such checkpoints, so opt out explicitly.
# NOTE: unpickling can execute arbitrary code -- only load trusted checkpoints.
model = torch.load(
    "style-extractor-v0.3.0.ckpt",
    map_location="cpu",
    weights_only=False,
)
model.eval()  # inference only; nothing in this script switches back to train mode
print("loaded extractor")

with open("urls.txt", encoding="utf-8") as url_file:
    urls = url_file.read().splitlines()
print("loaded urls")

# Row i of `ys` must describe urls[i]. Check with a real exception because
# `assert` is stripped under `python -O`, and a mismatch would silently map
# search hits to the wrong images.
if len(urls) != len(ys):
    raise ValueError(
        f"urls.txt has {len(urls)} lines but embs.npz has {len(ys)} embeddings"
    )

d = ys.shape[1]
# HNSW graph index storing the full vectors; 32 is the graph degree (M).
index = faiss.IndexHNSWFlat(d, 32)
print("building index")
index.add(ys)
print("index built")


def MyResize(area, d):
    """Return a transform that resizes and center-crops an image onto a grid.

    Both output dimensions are multiples of ``d``. The output's pixel count
    stays ``<= area`` whenever possible. Among the grid sizes adjacent to the
    "scale to exactly ``area`` pixels" size, the one whose aspect ratio best
    matches the input's is chosen.

    Args:
        area: Target (maximum) pixel count of the output image.
        d: Granularity of both output dimensions (14 below -- presumably the
            model's patch size; TODO confirm).

    Returns:
        A callable mapping a PIL image to a resized, center-cropped PIL image.
    """

    def aspect_error(a, b):
        # Relative mismatch between two positive ratios; 0 when they are equal.
        return 1 - min(a, b) / max(a, b)

    def f(im: Image.Image) -> Image.Image:
        w, h = im.size
        # Scale factor that would make the image cover exactly `area` pixels.
        s = (area / w / h) ** 0.5
        # Floor to whole grid cells, but never below one cell: for extreme
        # aspect ratios a side would otherwise round to 0 and the aspect-ratio
        # comparison below would divide by zero.
        wd, hd = max(1, int(s * w / d)), max(1, int(s * h / d))
        # Candidate sizes: base grid, optionally one extra cell per side,
        # restricted to those that fit the pixel budget. Order matches the
        # tie-breaking of the original (i outer, j inner).
        candidates = [
            ((wd + i) * d, (hd + j) * d)
            for i in (0, 1)
            for j in (0, 1)
            if (wd + i) * d * (hd + j) * d <= area
        ]
        wd, hd = min(
            candidates,
            key=lambda wh: aspect_error(wh[0] / wh[1], w / h),
            # Fallback when nothing fits, e.g. after the one-cell clamp above.
            default=(wd * d, hd * d),
        )
        # Scale so the image fully covers the (hd, wd) box along its tighter
        # side, then center-crop the overflow on the other side.
        if wd / w > hd / h:
            size = (int(h * wd / w), wd)
        else:
            size = (hd, int(w * hd / h))
        resized = Resize(size, InterpolationMode.BICUBIC)(im)
        return CenterCrop((hd, wd))(resized)

    return f


# Query preprocessing: grid-aligned resize to about (518 * 1.3)^2 pixels,
# conversion to a tensor, then normalisation with the standard ImageNet
# channel mean/std.
tf = Compose(
    [
        MyResize((518 * 1.3) ** 2, 14),
        ToTensor(),
        Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ]
)


def get_emb(im: Image.Image):
    """Embed one PIL image with the style extractor, as a batch of one.

    The image is converted to RGB first, because the normalisation above
    expects exactly three channels (RGBA or greyscale input would break it).
    """
    with torch.no_grad():
        return model(tf(im.convert("RGB")).unsqueeze(0))


n_outputs = 50
row_size = 5


def f(im):
    """Gradio callback: return exactly ``n_outputs`` Markdown strings.

    Each non-empty cell shows the distance reported by FAISS (squared L2 for
    this index) and the matching image. Cells are cleared when there is no
    query image or when FAISS returns fewer than ``n_outputs`` hits.
    """
    if im is None:
        return [""] * n_outputs
    query = np.ascontiguousarray(get_emb(im), dtype=np.float32)
    D, I = index.search(query, n_outputs)
    # FAISS marks missing results with label -1; urls[-1] would silently show
    # the last database image, so drop those hits and pad instead.
    cells = [
        f"Distance: {dist:.1f}\n![]({urls[i]})"
        for dist, i in zip(D[0], I[0])
        if i >= 0
    ]
    return cells + [""] * (n_outputs - len(cells))


print("preparing gradio")
with gr.Blocks() as demo:
    gr.Markdown(
        "# Style Similarity Search\n\nFind artworks with a similar style from a medium-sized database (10k artists * 30 img/artist)"
    )
    img = gr.Image(type="pil", label="Query", height=500)
    btn = gr.Button(variant="primary", value="search")
    outputs = []
    n_rows = -(-n_outputs // row_size)  # ceil(n_outputs / row_size)
    for i in range(n_rows):
        with gr.Row():
            # The last row may be partial when row_size does not divide n_outputs.
            for _ in range(min(row_size, n_outputs - i * row_size)):
                outputs.append(gr.Markdown(label=f"#{len(outputs) + 1}"))
    btn.click(f, img, outputs)

print("starting gradio")
demo.launch()