gustproof commited on
Commit
a819f61
1 Parent(s): dec8f2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -11,18 +11,23 @@ from torchvision.transforms import (
11
  from PIL import Image
12
  import gradio as gr
13
 
 
14
  (ys,) = np.load("embs.npz").values()
 
15
  model = torch.load(
16
  "style-extractor-v0.2.0.ckpt",
17
  map_location="cpu",
18
  )
 
19
  with open("urls.txt") as f:
20
  urls = f.read().splitlines()
 
21
  assert len(urls) == len(ys)
22
  d = ys.shape[1]
23
  index = faiss.IndexFlatL2(d)
24
- index.is_trained
25
  index.add(ys)
 
26
  tf = Compose(
27
  [
28
  Resize(
@@ -51,7 +56,7 @@ def f(im):
51
  D, I = index.search(get_emb(im), n_outputs)
52
  return [f"Distance: {d}\n![]({urls[i]})" for d, i in zip(D[0], I[0])]
53
 
54
-
55
  with gr.Blocks() as demo:
56
  gr.Markdown(
57
  "# Style Similarity Search\n\nFind artworks with a similar style from a small database (10k artists * 6img/artist)"
@@ -64,4 +69,5 @@ with gr.Blocks() as demo:
64
  for _ in range(min(row_size, n_outputs - i * row_size)):
65
  outputs.append(gr.Markdown(label=f"#{len(outputs) + 1}"))
66
  btn.click(f, img, outputs)
 
67
  demo.launch()
 
11
  from PIL import Image
12
  import gradio as gr
13
 
14
+ print("starting...")
15
  (ys,) = np.load("embs.npz").values()
16
+ print("loaded embs")
17
  model = torch.load(
18
  "style-extractor-v0.2.0.ckpt",
19
  map_location="cpu",
20
  )
21
+ print("loaded extractor")
22
  with open("urls.txt") as f:
23
  urls = f.read().splitlines()
24
+ print("loaded urls")
25
  assert len(urls) == len(ys)
26
  d = ys.shape[1]
27
  index = faiss.IndexFlatL2(d)
28
+ print("building index")
29
  index.add(ys)
30
+ print('index built')
31
  tf = Compose(
32
  [
33
  Resize(
 
56
  D, I = index.search(get_emb(im), n_outputs)
57
  return [f"Distance: {d}\n![]({urls[i]})" for d, i in zip(D[0], I[0])]
58
 
59
+ print("preparing gradio")
60
  with gr.Blocks() as demo:
61
  gr.Markdown(
62
  "# Style Similarity Search\n\nFind artworks with a similar style from a small database (10k artists * 6img/artist)"
 
69
  for _ in range(min(row_size, n_outputs - i * row_size)):
70
  outputs.append(gr.Markdown(label=f"#{len(outputs) + 1}"))
71
  btn.click(f, img, outputs)
72
+ print("starting gradio")
73
  demo.launch()