broadwell committed
Commit e547128
Parent: b56a0e7

Upload 6 files

Files changed (5)
  1. app.py +173 -0
  2. image_features.npy +3 -0
  3. images_list.txt +0 -0
  4. metadata.csv +0 -0
  5. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,173 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # Norsk (Multilingual) Image Search
+ #
+ # Based on [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search)
+ # by [Vladimir Haltakov](https://twitter.com/haltakov).
+
+ # In[ ]:
+
+
+ import clip
+ import gradio as gr
+ from multilingual_clip import pt_multilingual_clip, legacy_multilingual_clip
+ import numpy as np
+ import os
+ import pandas as pd
+ from PIL import Image
+ import requests
+ import torch
+ from transformers import AutoTokenizer
+
+
+ # In[ ]:
+
+
+ # Load the multilingual CLIP text model and its tokenizer
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
+
+ model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+
+ # In[ ]:
+
+
+ # Load the image metadata and the image IDs
+ images_info = pd.read_csv("./metadata.csv")
+ image_ids = list(
+     open("./images_list.txt", "r", encoding="utf-8").read().strip().split("\n")
+ )
+
+ # Load the precomputed image feature vectors
+ image_features = np.load("./image_features.npy")
+
+ # Convert features to tensors: float32 on CPU and float16 on GPU
+ if device == "cpu":
+     image_features = torch.from_numpy(image_features).float().to(device)
+ else:
+     image_features = torch.from_numpy(image_features).to(device)
+
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+ # ## Define Functions
+ #
+ # Some important functions for processing the data are defined here.
+ #
+ #
+
+ # The `encode_search_query` function takes a text description and encodes it into a feature vector using the multilingual CLIP text encoder.
+
+ # In[ ]:
+
+
+ def encode_search_query(search_query):
+     with torch.no_grad():
+         # Encode and normalize the search query using the multilingual model
+         text_encoded = model.forward(search_query, tokenizer)
+         text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
+
+     # Retrieve the feature vector
+     return text_encoded
+
+
+ # The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. The function returns the IDs of the best matching images.
+
+ # In[ ]:
+
+
+ def find_best_matches(text_features, image_features, image_ids, results_count=3):
+     # Compute the cosine similarity between the search query and each image
+     similarities = (image_features @ text_features.T).squeeze(1)
+
+     # Sort the images by their similarity score
+     best_image_idx = (-similarities).argsort()
+
+     # Return the image IDs of the best matches
+     return [
+         [image_ids[i], similarities[i].item()] for i in best_image_idx[:results_count]
+     ]
+
+
+ # In[ ]:
+
+
+ def clip_search(search_query):
+     if len(search_query) >= 3:
+         text_features = encode_search_query(search_query)
+
+         # Compute the cosine similarity between the description and each photo
+         # similarities = list((text_features @ photo_features.T).squeeze(0))
+
+         # Sort the photos by their similarity score
+         matches = find_best_matches(
+             text_features, image_features, image_ids, results_count=15
+         )
+
+         images = []
+         for i in range(15):
+             # Retrieve the photo ID
+             image_id = matches[i][0]
+             image_url = images_info[images_info["filename"] == image_id][
+                 "image_url"
+             ].values[0]
+
+             # response = requests.get(image_url)
+             # img = PIL.open(response.raw)
+
+             images.append(
+                 [
+                     (image_url),
+                     images_info[images_info["filename"] == image_id][
+                         "permalink"
+                     ].values[0],
+                 ]
+             )
+
+         # print(images)
+         return images
+
+
+ css = (
+     "footer {display: none !important;} .gradio-container {min-height: 0px !important;}"
+ )
+ with gr.Blocks(css=css) as gr_app:
+     with gr.Column(variant="panel"):
+         with gr.Row(variant="compact"):
+             search_string = gr.Textbox(
+                 label="Evocative Search",
+                 show_label=True,
+                 max_lines=1,
+                 placeholder="Type something, or click a suggested search below.",
+             ).style(
+                 container=False,
+             )
+             btn = gr.Button("Search", variant="primary").style(full_width=False)
+         with gr.Row(variant="compact"):
+             suggest1 = gr.Button(
+                 "två hundar som leker i snön", variant="secondary"
+             ).style(size="sm")
+             suggest2 = gr.Button(
+                 "en fisker til sjøs i en båt", variant="secondary"
+             ).style(size="sm")
+             suggest3 = gr.Button(
+                 "cold dark alone on the street", variant="secondary"
+             ).style(size="sm")
+             suggest4 = gr.Button("도로 위의 자동차들", variant="secondary").style(size="sm")
+     gallery = gr.Gallery(label=False, show_label=False, elem_id="gallery").style(
+         grid=[6],
+         height="100%",
+     )
+
+     suggest1.click(clip_search, inputs=suggest1, outputs=gallery)
+     suggest2.click(clip_search, inputs=suggest2, outputs=gallery)
+     suggest3.click(clip_search, inputs=suggest3, outputs=gallery)
+     suggest4.click(clip_search, inputs=suggest4, outputs=gallery)
+     btn.click(clip_search, inputs=search_string, outputs=gallery)
+     search_string.submit(clip_search, search_string, gallery)
+
+ if __name__ == "__main__":
+     gr_app.launch(share=True)
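
The suggested-search buttons are deliberately multilingual, matching the app's premise: "två hundar som leker i snön" is Swedish for "two dogs playing in the snow", "en fisker til sjøs i en båt" is Norwegian for "a fisherman at sea in a boat", and "도로 위의 자동차들" is Korean for "cars on the road"; the third suggestion is already in English.
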
image_features.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:374756fa5293ce32e31d9b9681c53b374d3024adbfb1c6aaa4791aa3937fec40
+ size 51210368
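
image_features.npy is stored through Git LFS, so only the pointer shows up in the diff, and the script that produced it is not part of this commit. Judging from how app.py consumes it (one row per entry in images_list.txt, kept as float16 on GPU and normalized after loading), the 51 MB size is consistent with roughly 40,000 640-dimensional float16 vectors, most likely from the image tower that M-CLIP documents as the counterpart of XLM-Roberta-Large-Vit-B-16Plus: OpenCLIP's ViT-B-16-plus-240 (laion400m_e32). A minimal preprocessing sketch under those assumptions (the ./images/ directory and the open_clip dependency are illustrative and not part of this repo):

# Hypothetical preprocessing sketch; not part of this commit.
# Assumes the raw images sit in ./images/ under the filenames listed in images_list.txt.
import numpy as np
import open_clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Image tower that M-CLIP pairs with the XLM-Roberta-Large-Vit-B-16Plus text encoder
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-16-plus-240", pretrained="laion400m_e32"
)
model = model.to(device).eval()

image_ids = open("./images_list.txt", encoding="utf-8").read().strip().split("\n")

features = []
with torch.no_grad():
    for image_id in image_ids:
        image = preprocess(Image.open(f"./images/{image_id}")).unsqueeze(0).to(device)
        features.append(model.encode_image(image).half().cpu())

# app.py normalizes the vectors after loading, so they are saved un-normalized here
np.save("./image_features.npy", torch.cat(features).numpy())
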
images_list.txt ADDED
The diff for this file is too large to render. See raw diff
 
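Based on how app.py reads it, images_list.txt is a plain newline-separated list of image filenames (one per line, e.g. something like image_0001.jpg) whose order matches the rows of image_features.npy and whose entries are looked up in the filename column of metadata.csv.
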
metadata.csv ADDED
The diff for this file is too large to render. See raw diff
 
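metadata.csv is likewise only available as a raw diff, but app.py depends on three of its columns: filename (matched against the entries of images_list.txt), image_url (the image shown in the gallery), and permalink (used as that gallery item's caption). A hypothetical row, with illustrative values only:

filename,image_url,permalink
image_0001.jpg,https://example.org/full/image_0001.jpg,https://example.org/items/0001
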
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ clip==1.0
+ gradio==3.27.0
+ multilingual_clip==1.0.10
+ numpy==1.23.5
+ pandas==2.0.1
+ Pillow==9.5.0
+ Requests==2.28.2
+ torch==1.13.1
+ transformers==4.19.2
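
The gradio==3.27.0 pin matters here: app.py relies on the Gradio 3.x-era .style() calls and the grid= option on the Gallery, both of which later Gradio releases removed. With these pins installed (pip install -r requirements.txt), python app.py runs the __main__ block at the bottom of app.py, and gr_app.launch(share=True) serves the interface locally while also printing a temporary public share link.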