#!/usr/bin/env python
# coding: utf-8

# # Norsk (Multilingual) Image Search
#
# Based on [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search)
# by [Vladimir Haltakov](https://twitter.com/haltakov).

# In[ ]:


import gradio as gr
import numpy as np
import pandas as pd
import torch
from multilingual_clip import pt_multilingual_clip
from transformers import AutoTokenizer


# In[ ]:


# Load the multilingual CLIP text encoder and its tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# In[ ]:


# Load the image metadata and the list of image IDs
images_info = pd.read_csv("./metadata.csv")
image_ids = list(
    open("./images_list.txt", "r", encoding="utf-8").read().strip().split("\n")
)

# Load the precomputed image feature vectors
image_features = np.load("./image_features.npy")

# Convert the features to tensors: float32 on CPU, the stored half precision on GPU
if device == "cpu":
    image_features = torch.from_numpy(image_features).float().to(device)
else:
    image_features = torch.from_numpy(image_features).to(device)

# Normalize so that dot products with text features give cosine similarities
image_features = image_features / image_features.norm(dim=-1, keepdim=True)


# ## Define Functions
#
# The functions that encode search queries and rank the images are defined here.
#
# The `encode_search_query` function takes a text description and encodes it into a
# feature vector using the multilingual CLIP text encoder.

# In[ ]:


def encode_search_query(search_query):
    with torch.no_grad():
        # Encode and normalize the search query using the multilingual model
        text_encoded = model.forward(search_query, tokenizer)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Match the device and dtype of the precomputed image features so the
    # similarity computation works on both CPU and GPU
    return text_encoded.to(device=image_features.device, dtype=image_features.dtype)
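# As a quick sanity check, the query encoder can be run on a sample string and compared
# against the precomputed image features. This cell is purely illustrative; the sample
# query below is made up. It shows that the text and image embeddings share the same
# dimensionality and that the text vector is unit-length, which is what the cosine
# similarity ranking in `find_best_matches` relies on.

# In[ ]:


sample_features = encode_search_query("en hund som leker i parken")  # "a dog playing in the park"
print(sample_features.shape, image_features.shape)  # the last dimensions must match
print(sample_features.norm(dim=-1))  # approximately 1.0 after normalization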
# The `find_best_matches` function compares the text feature vector to the feature
# vectors of all images and finds the best matches. It returns the IDs of the best
# matching images together with their similarity scores.

# In[ ]:


def find_best_matches(text_features, image_features, image_ids, results_count=3):
    # Compute the cosine similarity between the search query and each image
    # (both feature sets are already normalized, so a dot product suffices)
    similarities = (image_features @ text_features.T).squeeze(1)

    # Sort the images by their similarity score, best first
    best_image_idx = (-similarities).argsort()

    # Return [image ID, similarity score] pairs for the best matches
    return [
        [image_ids[i], similarities[i].item()] for i in best_image_idx[:results_count]
    ]


# In[ ]:


def clip_search(search_query):
    # Ignore very short queries and return an empty gallery instead of None
    if len(search_query) < 3:
        return []

    text_features = encode_search_query(search_query)

    # Rank the images by cosine similarity to the query and keep the top 15
    matches = find_best_matches(
        text_features, image_features, image_ids, results_count=15
    )

    images = []
    for image_id, _score in matches:
        # Look up the image URL and permalink for this image ID
        row = images_info[images_info["filename"] == image_id]
        images.append((row["image_url"].values[0], row["permalink"].values[0]))

    return images


css = (
    "footer {display: none !important;} .gradio-container {min-height: 0px !important;}"
)

with gr.Blocks(css=css) as gr_app:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            search_string = gr.Textbox(
                label="Evocative Search",
                show_label=True,
                max_lines=1,
                placeholder="Type something, or click a suggested search below.",
                container=False,
            )
            btn = gr.Button("Search", variant="primary")
        with gr.Row(variant="compact"):
            # Suggested searches in several languages; clicking a button sends its label as the query
            suggest1 = gr.Button(
                "två hundar som leker i snön", variant="secondary", size="sm"
            )  # Swedish: "two dogs playing in the snow"
            suggest2 = gr.Button(
                "en fisker til sjøs i en båt", variant="secondary", size="sm"
            )  # Norwegian: "a fisherman at sea in a boat"
            suggest3 = gr.Button(
                "cold dark alone on the street", variant="secondary", size="sm"
            )
            suggest4 = gr.Button(
                "도로 위의 자동차들", variant="secondary", size="sm"
            )  # Korean: "cars on the road"

        gallery = gr.Gallery(
            show_label=False, elem_id="gallery", height="100%", columns=6
        )

    # Each suggestion button passes its own label text to clip_search
    suggest1.click(clip_search, inputs=suggest1, outputs=gallery)
    suggest2.click(clip_search, inputs=suggest2, outputs=gallery)
    suggest3.click(clip_search, inputs=suggest3, outputs=gallery)
    suggest4.click(clip_search, inputs=suggest4, outputs=gallery)

    # The text box triggers a search on button click or on Enter
    btn.click(clip_search, inputs=search_string, outputs=gallery)
    search_string.submit(clip_search, inputs=search_string, outputs=gallery)


if __name__ == "__main__":
    gr_app.launch(share=False)
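# ## Appendix: Precomputing `image_features.npy`
#
# The app only loads precomputed image features. The sketch below shows one way such
# features could be produced. It is illustrative rather than the exact pipeline used
# for this dataset: it assumes the image encoder paired with
# `M-CLIP/XLM-Roberta-Large-Vit-B-16Plus` is OpenCLIP's `ViT-B-16-plus-240` model with
# the `laion400m_e32` weights (the pairing listed in the multilingual-clip
# documentation), and that the files named in `images_list.txt` exist in a local
# directory. The helper is only defined here, never called, so it does not affect the
# app above.

# In[ ]:


def precompute_image_features(image_dir, image_ids, output_path="./image_features.npy"):
    """Encode the listed images with the paired OpenCLIP vision model and save the features."""
    # These dependencies are only needed for the offline precomputation step
    import os

    import open_clip
    from PIL import Image

    vision_model, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-B-16-plus-240", pretrained="laion400m_e32"
    )
    vision_model = vision_model.to(device).eval()

    features = []
    with torch.no_grad():
        for image_id in image_ids:
            image = Image.open(os.path.join(image_dir, image_id)).convert("RGB")
            encoded = vision_model.encode_image(
                preprocess(image).unsqueeze(0).to(device)
            )
            # Normalize so a dot product with a text vector is a cosine similarity
            encoded = encoded / encoded.norm(dim=-1, keepdim=True)
            features.append(encoded.cpu().numpy())

    np.save(output_path, np.concatenate(features, axis=0))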