gustproof commited on
Commit
04573a7
1 Parent(s): 71c674f

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +67 -0
main.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import faiss
3
+ import torch
4
+ from torchvision.transforms import (
5
+ Compose,
6
+ Resize,
7
+ ToTensor,
8
+ Normalize,
9
+ InterpolationMode,
10
+ )
11
+ from PIL import Image
12
+ import gradio as gr
13
+
14
+ (ys,) = np.load("embs.npz").values()
15
+ model = torch.load(
16
+ "style-extractor-v0.2.0.ckpt",
17
+ map_location="cpu",
18
+ )
19
+ with open("urls.txt") as f:
20
+ urls = f.read().splitlines()
21
+ assert len(urls) == len(ys)
22
+ d = ys.shape[1]
23
+ index = faiss.IndexFlatL2(d)
24
+ index.is_trained
25
+ index.add(ys)
26
+ tf = Compose(
27
+ [
28
+ Resize(
29
+ size=336,
30
+ interpolation=InterpolationMode.BICUBIC,
31
+ max_size=None,
32
+ antialias=True,
33
+ ),
34
+ ToTensor(),
35
+ Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
36
+ ]
37
+ )
38
+
39
+
40
+ def get_emb(im: Image):
41
+ model.eval()
42
+ with torch.no_grad():
43
+ return model(tf(im).unsqueeze(0))
44
+
45
+
46
+ n_outputs = 50
47
+ row_size = 5
48
+
49
+
50
+ def f(im):
51
+ D, I = index.search(get_emb(im), n_outputs)
52
+ return [f"Distance: {d}\n![]({urls[i]})" for d, i in zip(D[0], I[0])]
53
+
54
+
55
+ with gr.Blocks() as demo:
56
+ gr.Markdown(
57
+ "# Style Similarity Search\n\nFind artworks with a similar style from a small database (10k artists * 6img/artist)"
58
+ )
59
+ img = gr.Image(type="pil", label="Query", height=500)
60
+ btn = gr.Button(variant="primary", value="search")
61
+ outputs = []
62
+ for i in range(-(n_outputs // (-row_size))):
63
+ with gr.Row():
64
+ for _ in range(min(row_size, n_outputs - i * row_size)):
65
+ outputs.append(gr.Markdown(label=f"#{len(outputs) + 1}"))
66
+ btn.click(f, img, outputs)
67
+ demo.launch()