gustproof's picture
Update app.py
accc6c7 verified
raw
history blame contribute delete
No virus
2.47 kB
import numpy as np
import faiss
import torch
from torchvision.transforms import (
Compose,
Resize,
ToTensor,
Normalize,
InterpolationMode,
CenterCrop,
)
from PIL import Image
import gradio as gr
print("starting...")
# embs.npz must hold exactly one array (the unpacking below enforces that):
# the database embeddings, one row per line of urls.txt. Using the NpzFile as
# a context manager closes the underlying zip archive once the array is read.
with np.load("embs.npz") as npz:
    (ys,) = npz.values()
print("loaded embs")
# The checkpoint is a whole model object (this file calls `model.eval()` and
# `model(...)` on it), not a state dict, so it needs weights_only=False:
# torch >= 2.6 defaults to weights_only=True and would refuse to load it.
# NOTE: this unpickles, i.e. can run arbitrary code -- only acceptable because
# the checkpoint is a trusted file shipped with the app.
model = torch.load(
    "style-extractor-v0.3.0.ckpt",
    map_location="cpu",
    weights_only=False,
)
print("loaded extractor")
with open("urls.txt", encoding="utf-8") as fh:
    urls = fh.read().splitlines()
print("loaded urls")
# Row i of `ys` must describe urls[i]; a mismatch would make every search
# result point at the wrong image. Raise explicitly instead of `assert`,
# which is stripped when Python runs with -O.
if len(urls) != len(ys):
    raise ValueError(
        f"urls.txt has {len(urls)} lines but embs.npz has {len(ys)} embeddings"
    )
d = ys.shape[1]  # embedding dimensionality
# HNSW graph index with M=32 links per node; no metric is passed, so faiss
# uses its default (L2) metric.
index = faiss.IndexHNSWFlat(d, 32)
print("building index")
index.add(ys)
print("index built")
def MyResize(area, d):
def f(im: Image):
w, h = im.size
s = (area / w / h) ** 0.5
wd, hd = int(s * w / d), int(s * h / d)
e = lambda a, b: 1 - min(a, b) / max(a, b)
wd, hd = min(
(
(ww * d, hh * d)
for ww, hh in [(wd + i, hd + j) for i in (0, 1) for j in (0, 1)]
if ww * d * hh * d <= area
),
key=lambda wh: e(wh[0] / wh[1], w / h),
)
return Compose(
[
Resize(
(int(h * wd / w), wd) if wd / w > hd / h else (hd, int(w * hd / h)),
InterpolationMode.BICUBIC,
),
CenterCrop((hd, wd)),
]
)(im)
return f
# Preprocessing applied to every query image before embedding:
#   1. snap it onto a 14-pixel grid within a (518 * 1.3)^2-pixel budget
#      (presumably the extractor's patch size / base resolution -- confirm);
#   2. convert the PIL image to a float tensor;
#   3. normalise each channel with the standard ImageNet mean/std.
tf = Compose(
    transforms=[
        MyResize(area=(518 * 1.3) ** 2, d=14),
        ToTensor(),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def get_emb(im: Image.Image):
    """Embed a single PIL image with the style extractor.

    Args:
        im: the query image; it is preprocessed with the module-level ``tf``.

    Returns:
        The model output for a batch of one image (leading batch dimension of
        size 1), computed without gradient tracking. Presumably shape
        (1, embedding_dim) -- confirm against the checkpoint.
    """
    # Put the model in inference mode (affects layers such as dropout or
    # batch-norm, if the model has any) before every call.
    model.eval()
    with torch.no_grad():
        batch = tf(im).unsqueeze(0)  # add the batch dimension
        return model(batch)
n_outputs = 50
row_size = 5


def f(im):
    """Search the index for the images closest in style to ``im``.

    Args:
        im: the PIL image from the query ``gr.Image`` component, or ``None``
            when nothing has been uploaded.

    Returns:
        A list of ``n_outputs`` Markdown strings, one per result slot, each
        showing the index distance and the matched image. Slots for which
        faiss found no neighbour (label -1) are returned as empty strings.

    Raises:
        gr.Error: if no query image was provided.
    """
    if im is None:
        raise gr.Error("Please upload a query image first.")
    # faiss expects a float32 (n, d) numpy array; convert the torch output
    # explicitly rather than relying on implicit conversion in the wrapper.
    query = np.asarray(get_emb(im), dtype=np.float32)
    dists, ids = index.search(query, n_outputs)
    # faiss marks "no result" with label -1; without this check, urls[-1]
    # would silently show the last image of the database.
    return [
        f"Distance: {dist:.1f}\n![]({urls[i]})" if i >= 0 else ""
        for dist, i in zip(dists[0], ids[0])
    ]
print("preparing gradio")
with gr.Blocks() as demo:
    gr.Markdown(
        "# Style Similarity Search\n\nFind artworks with a similar style from a medium-sized database (10k artists * 30 img/artist)"
    )
    img = gr.Image(type="pil", label="Query", height=500)
    btn = gr.Button(variant="primary", value="search")
    # One Markdown slot per search result, laid out in rows of `row_size`;
    # the last row is shorter when n_outputs is not a multiple of row_size.
    # Slots are labelled #1, #2, ... in result order.
    outputs = []
    for row_start in range(0, n_outputs, row_size):
        row_end = min(row_start + row_size, n_outputs)
        with gr.Row():
            outputs.extend(
                gr.Markdown(label=f"#{rank}")
                for rank in range(row_start + 1, row_end + 1)
            )
    btn.click(fn=f, inputs=img, outputs=outputs)
print("starting gradio")
demo.launch()