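# NOTE: the lines below appear to be file-viewer page residue captured with
# the source (commit/author/size metadata); they are not part of the program.
# shivangibithel's picture
# Update app.py
# ca9d65c
# raw
# history blame
# No virus
# 2.21 kB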
import json
import os
from io import BytesIO

import faiss
import numpy as np
import streamlit as st
import torch
import wget
from datasets import load_dataset
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
# dataset = load_dataset("imagefolder", data_files="https://huggingface.co/datasets/nlphuji/flickr30k/blob/main/flickr30k-images.zip")
# Load the pre-trained sentence encoder (and its tokenizer).
model_name = "sentence-transformers/all-distilroberta-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = SentenceTransformer(model_name)


def _download_if_missing(url, filename):
    """Download ``url`` to ``filename`` unless that file already exists.

    Streamlit re-executes the whole script on every interaction, so an
    unconditional download would hit the network on each rerun.
    NOTE: a stale or corrupt local copy is NOT refreshed; delete the file to
    force a new download.
    """
    if not os.path.exists(filename):
        wget.download(url, filename)


# Load the FAISS index.
# Hugging Face serves raw file contents under ``/resolve/``; ``/blob/`` URLs
# return the HTML file-viewer page, which faiss/json cannot parse.
index_name = 'index.faiss'
index_url = 'https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/resolve/main/faiss_flickr8k.index'
_download_if_missing(index_url, index_name)
index = faiss.read_index(index_name)

# Map the image ids to the corresponding captions.
image_map_name = 'captions.json'
image_map_url = 'https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/resolve/main/captions.json'
_download_if_missing(image_map_url, image_map_name)
with open(image_map_name, 'r', encoding='utf-8') as f:
    caption_dict = json.load(f)

# Keys become the image identifiers and values the captions, both in the
# JSON file's key order.
# NOTE(review): presumably this order matches the row order of the FAISS
# index -- confirm against the script that built the index.
image_list = list(caption_dict.keys())
caption_list = list(caption_dict.values())
def search(query, k=5):
    """Return image URLs for the indexed entries nearest to ``query``.

    Args:
        query: Free-text search string.
        k: Maximum number of neighbours to request from the FAISS index.

    Returns:
        A list of image URL strings in the order FAISS returned the
        neighbours. It may hold fewer than ``k`` items when the index
        cannot supply that many results.
    """
    # SentenceTransformer.encode expects raw text and tokenizes internally;
    # it returns a NumPy array by default, so there is nothing to detach.
    # FAISS needs a contiguous float32 matrix of shape (n_queries, dim).
    query_embedding = np.ascontiguousarray(
        model.encode([query]), dtype=np.float32
    )

    # Search for the nearest neighbours in the FAISS index.
    _, neighbours = index.search(query_embedding, k)

    # ``/resolve/`` serves the raw image bytes; ``/blob/`` is an HTML page.
    base_url = "https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/resolve/main/Images/"

    # FAISS pads missing results with -1; skip those (and any out-of-range
    # id) instead of letting a negative index wrap around to the last image.
    return [
        base_url + str(image_list[i])
        for i in neighbours[0]
        if 0 <= i < len(image_list)
    ]
def run_app():
    """Render the search page: title, query box, button and result images."""
    st.title("Image Search App")
    query = st.text_input("Enter your search query here:")
    if st.button("Search"):
        if query:
            image_urls = search(query)
            # Display the images
            st.image(image_urls, width=200)
        else:
            # Give feedback instead of silently ignoring the click.
            st.warning("Please enter a search query.")


if __name__ == '__main__':
    # set_page_config has to be the first Streamlit command executed in the
    # script, so it runs before any widget is created in run_app().
    st.set_page_config(page_title='Image Search App', layout='wide')
    run_app()