# samla-photos / app.py
# Revision 4657c90 (broadwell): "Update Python library versions and Gradio syntax"
#!/usr/bin/env python
# coding: utf-8
# Norsk (Multilingual) Image Search
#
# Based on [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search)
# by [Vladimir Haltakov](https://twitter.com/haltakov).
# In[ ]:
import clip
import gradio as gr
from multilingual_clip import pt_multilingual_clip, legacy_multilingual_clip
import numpy as np
import os
import pandas as pd
from PIL import Image
import requests
import torch
from transformers import AutoTokenizer
# In[ ]:
# Choose the compute device, then load the multilingual CLIP (M-CLIP) text
# encoder and its tokenizer from the same checkpoint name.
# NOTE(review): the model is not moved to `device` here, so text encoding
# runs on CPU; only the image features are placed on `device`.
model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
# In[ ]:
# Load the per-image metadata table and the list of image IDs. The ID list is
# assumed to follow the same order as the rows of image_features.npy, because
# row indices from the similarity search are used to index image_ids.
images_info = pd.read_csv("./metadata.csv")
with open("./images_list.txt", "r", encoding="utf-8") as _image_list_file:
    # splitlines() also copes with Windows "\r\n" endings, which split("\n")
    # would leave as a trailing "\r" on every ID (breaking the filename lookup).
    image_ids = _image_list_file.read().strip().splitlines()

# Load the precomputed image feature vectors as a tensor.
image_features = torch.from_numpy(np.load("./image_features.npy"))
# Float32 on CPU; on GPU the dtype stored in the .npy file is kept as-is
# (presumably float16 — TODO confirm against how the file was produced).
if device == "cpu":
    image_features = image_features.float()
image_features = image_features.to(device)
# L2-normalise along the last dimension so that a dot product with a
# normalised query vector is a cosine similarity.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# ## Define Functions
#
# Some important functions for processing the data are defined here.
#
#
# The `encode_search_query` function takes a text description (in any language supported by the multilingual model) and encodes it into a normalized feature vector using the M-CLIP text encoder.
# In[ ]:
def encode_search_query(search_query):
    """Embed a text query with the multilingual CLIP text encoder.

    Args:
        search_query: Query text, passed directly to the module-level model
            together with the module-level tokenizer.

    Returns:
        The model's text embedding tensor, L2-normalised along its last
        dimension. For a single string this is presumably shape (1, D)
        — TODO confirm.
    """
    with torch.no_grad():
        raw_embedding = model.forward(search_query, tokenizer)
        embedding_norm = raw_embedding.norm(dim=-1, keepdim=True)
        return raw_embedding / embedding_norm
# The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. It returns a list of `[image_id, similarity]` pairs, best match first.
# In[ ]:
def find_best_matches(text_features, image_features, image_ids, results_count=3):
    """Rank images by similarity to a single text query.

    Args:
        text_features: Query embedding tensor, either 1-D ``(D,)`` or a
            single-row ``(1, D)`` batch. Expected to be L2-normalised so the
            score is a cosine similarity.
        image_features: 2-D tensor ``(N, D)`` of image embeddings, also
            expected to be L2-normalised; row ``i`` belongs to ``image_ids[i]``.
        image_ids: Sequence of image identifiers aligned with the rows of
            ``image_features``.
        results_count: Maximum number of matches to return. ``None`` returns
            every image; values larger than ``N`` are clamped to ``N``.

    Returns:
        A list of ``[image_id, similarity]`` pairs (similarity as a Python
        float), highest similarity first.
    """
    if text_features.dim() == 1:
        # Accept a bare query vector as well as a (1, D) batch.
        text_features = text_features.unsqueeze(0)
    # The query may come from a model running on CPU in float32, while the
    # image features can live on another device or in another dtype (e.g. on
    # GPU). Matrix multiplication requires both operands to match, so move the
    # query over to the image features.
    text_features = text_features.to(
        device=image_features.device, dtype=image_features.dtype
    )
    # (N, D) @ (D, 1) -> (N, 1) -> (N,): one score per image.
    similarities = (image_features @ text_features.T).squeeze(1)

    total = similarities.shape[0]
    k = total if results_count is None else max(0, min(results_count, total))
    # topk returns the k largest scores already sorted in descending order,
    # without sorting the full similarity vector.
    scores, indices = torch.topk(similarities, k)
    return [
        [image_ids[idx], score]
        for idx, score in zip(indices.tolist(), scores.tolist())
    ]
# In[ ]:
def clip_search(search_query, results_count=15):
    """Search the image collection for a free-text query.

    Args:
        search_query: Query text from the UI. Surrounding whitespace is
            ignored, and queries shorter than three characters are not run.
        results_count: Maximum number of images to return (default 15).

    Returns:
        A list of ``[image_url, permalink]`` pairs for the gallery, best match
        first, or ``None`` when the query is too short.
    """
    query = (search_query or "").strip()
    if len(query) < 3:
        return None

    text_features = encode_search_query(query)
    # Rank every image by cosine similarity to the query.
    matches = find_best_matches(
        text_features, image_features, image_ids, results_count=results_count
    )
    matched_ids = [image_id for image_id, _score in matches]

    # Fetch metadata for the matched filenames in a single pass instead of
    # re-scanning the whole DataFrame twice per result. keep="first" keeps the
    # first row when a filename appears more than once.
    metadata = (
        images_info[images_info["filename"].isin(matched_ids)]
        .drop_duplicates(subset="filename", keep="first")
        .set_index("filename")
    )

    images = []
    for image_id in matched_ids:
        if image_id not in metadata.index:
            # NOTE(review): listed in images_list.txt but absent from
            # metadata.csv; skip it rather than fail the whole search.
            continue
        row = metadata.loc[image_id]
        # Gallery entries are (image, caption): the image URL plus its permalink.
        images.append([row["image_url"], row["permalink"]])
    return images
# Hide Gradio's footer and drop the container's minimum height.
css = (
    "footer {display: none !important;} .gradio-container {min-height: 0px !important;}"
)

with gr.Blocks(css=css) as gr_app:
    with gr.Column(variant="panel"):
        # Free-text query box with its search button.
        with gr.Row(variant="compact"):
            search_string = gr.Textbox(
                label="Evocative Search",
                show_label=True,
                max_lines=1,
                placeholder="Type something, or click a suggested search below.",
                container=False,
            )
            btn = gr.Button("Search", variant="primary")
        # Example queries in several languages; each button's label is its query.
        with gr.Row(variant="compact"):
            suggest1, suggest2, suggest3, suggest4 = (
                gr.Button(suggestion_text, variant="secondary", size="sm")
                for suggestion_text in (
                    "två hundar som leker i snön",
                    "en fisker til sjøs i en båt",
                    "cold dark alone on the street",
                    "도로 위의 자동차들",
                )
            )
        gallery = gr.Gallery(
            label=False, show_label=False, elem_id="gallery", height="100%", columns=6
        )
    # Clicking a suggestion runs a search using the button's own label text.
    for suggest_button in (suggest1, suggest2, suggest3, suggest4):
        suggest_button.click(clip_search, inputs=suggest_button, outputs=gallery)
    # Both the button and pressing Enter in the textbox run the typed query.
    btn.click(clip_search, inputs=search_string, outputs=gallery)
    search_string.submit(clip_search, search_string, gallery)

if __name__ == "__main__":
    gr_app.launch(share=False)