import os
from pathlib import Path

import torch
import numpy as np
from PIL import Image
import gradio as gr
from tokenizers import Tokenizer
from torch.utils.data import Dataset
import albumentations as A
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from datasets import load_dataset

from fourm.vq.vqvae import VQVAE
from fourm.models.fm import FM
from fourm.models.generate import (
    GenerationSampler,
    build_chained_generation_schedules,
    init_empty_target_modality,
    custom_text,
)
from fourm.utils.plotting_utils import decode_dict
from fourm.data.modality_info import MODALITY_INFO
from fourm.data.modality_transforms import RGBTransform
from torchvision.transforms.functional import center_crop

# Constants and configurations
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 224
TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
IMAGE_DATASET_PATH = "./data"

# Load models
text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
vqvae = VQVAE.from_pretrained(VQVAE_PATH)
fm_model = FM.from_pretrained(FM_MODEL_PATH).eval().to(DEVICE)

# Generation configuration: condition on caption + metadata text and generate
# the 16 DINOv2 global feature tokens.
cond_domains = ["caption", "metadata"]
target_domains = ["tok_dinov2_global"]
tokens_per_target = [16]
generation_config = {
    "autoregression_schemes": ["roar"],
    "decoding_steps": [1],
    "token_decoding_schedules": ["linear"],
    "temps": [2.0],
    "temp_schedules": ["onex:0.5:0.5"],
    "cfg_scales": [1.0],
    "cfg_schedules": ["constant"],
    "cfg_grow_conditioning": True,
}
top_p, top_k = 0.8, 0.0

schedule = build_chained_generation_schedules(
    cond_domains=cond_domains,
    target_domains=target_domains,
    tokens_per_target=tokens_per_target,
    **generation_config,
)
sampler = GenerationSampler(fm_model)


class HuggingFaceImageDataset(Dataset):
    """Gallery of retrieval candidates, resized so the shortest side is img_sz."""

    def __init__(self, dataset_name, split="train", img_sz=224):
        self.dataset = load_dataset(dataset_name, split=split)
        self.tfms = A.Compose([A.SmallestMaxSize(img_sz)])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = np.array(self.dataset[idx]["image"])
        img = self.tfms(image=img)["image"]
        return Image.fromarray(img)


dataset = HuggingFaceImageDataset("aroraaman/4m-21-demo")


def load_image_embeddings():
    # Download the precomputed gallery embeddings and load them on CPU;
    # the caller moves them to DEVICE.
    file_path = hf_hub_download(repo_id="aroraaman/img-tensor", filename="image_emb.pt")
    return torch.load(file_path, map_location="cpu")


# Precomputed embeddings for the gallery images, one feature vector per image.
image_embeddings = load_image_embeddings().to(DEVICE)
print(image_embeddings.shape)
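# ---------------------------------------------------------------------------
# Optional: precomputing gallery embeddings.
# `image_emb.pt` is expected to hold one global feature vector per gallery
# image, in the same feature space as the `tok_dinov2_global` output decoded
# in `get_similar_images` below. The commented sketch that follows shows one
# *possible* way to build such a tensor with a DINOv2 ViT-B/14 backbone from
# torch.hub and standard ImageNet preprocessing; these choices are assumptions
# for illustration and not necessarily how the published `image_emb.pt` was
# produced.
#
# import torchvision.transforms as T
#
# dinov2 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14").eval().to(DEVICE)
# preprocess = T.Compose([
#     T.Resize(256),
#     T.CenterCrop(224),
#     T.ToTensor(),
#     T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
# ])
# with torch.no_grad():
#     embs = torch.stack([
#         dinov2(preprocess(dataset[i].convert("RGB")).unsqueeze(0).to(DEVICE)).squeeze(0)
#         for i in range(len(dataset))
#     ])
# torch.save(embs.cpu(), "image_emb.pt")
# ---------------------------------------------------------------------------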
def get_similar_images(caption, brightness, num_items):
    # Register an empty target slot for each target modality.
    batched_sample = {}
    for target_mod, ntoks in zip(target_domains, tokens_per_target):
        batched_sample = init_empty_target_modality(
            batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
        )

    # Metadata conditioning string in 4M's text metadata format; the first
    # v1/v0 pair carries the instance count and the second the brightness.
    metadata = f"v1=6 v0={num_items} v1=10 v0={brightness}"
    print(metadata)

    # Condition on the caption and metadata text.
    batched_sample = custom_text(
        batched_sample,
        input_text=caption,
        eos_token="[EOS]",
        key="caption",
        device=DEVICE,
        text_tokenizer=text_tokenizer,
    )
    batched_sample = custom_text(
        batched_sample,
        input_text=metadata,
        eos_token="[EOS]",
        key="metadata",
        device=DEVICE,
        text_tokenizer=text_tokenizer,
    )

    # Generate DINOv2 global feature tokens conditioned on the text inputs.
    out_dict = sampler.generate(
        batched_sample,
        schedule,
        text_tokenizer=text_tokenizer,
        verbose=True,
        seed=0,
        top_p=top_p,
        top_k=top_k,
    )

    # Detokenize the generated tokens into a global feature vector and rank the
    # gallery embeddings by cosine similarity.
    with torch.no_grad():
        dec_dict = decode_dict(
            out_dict,
            {"tok_dinov2_global": vqvae.to(DEVICE)},
            text_tokenizer,
            image_size=IMG_SIZE,
            patch_size=16,
            decoding_steps=1,
        )
    query_features = dec_dict["tok_dinov2_global"]
    similarities = torch.nn.functional.cosine_similarity(query_features, image_embeddings)
    top_indices = similarities.argsort(descending=True)[:1]  # keep the single best match
    print(top_indices, similarities[top_indices])
    return [dataset[int(i)] for i in top_indices.cpu().numpy()]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Retrieval using 4M-21: An Any-to-Any Vision Model")
    gr.Markdown("""
This app demonstrates image retrieval with 4M-21, an any-to-any vision model.
Enter a caption, adjust the brightness, and set the number of COCO instances;
the app retrieves the most similar image from the demo gallery.
The retrieval dataset for this demo is available at:
https://huggingface.co/datasets/aroraaman/4m-21-demo
""")

    with gr.Row():
        with gr.Column(scale=1):
            caption = gr.Textbox(
                label="Caption Description",
                placeholder="Enter image description...",
            )
            brightness = gr.Slider(
                minimum=0,
                maximum=255,
                value=5,
                step=1,
                label="Brightness",
                info="Adjust image brightness (0-255)",
            )
            num_items = gr.Slider(
                minimum=0,
                maximum=50,
                value=5,
                step=1,
                label="Number of Items",
                info="Number of COCO instances in image (0-50)",
            )
        with gr.Column(scale=1):
            output_images = gr.Gallery(
                label="Retrieved Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=2,
                height=512,
            )

    submit_btn = gr.Button("Retrieve Most Similar Image")
    submit_btn.click(
        fn=get_similar_images,
        inputs=[caption, brightness, num_items],
        outputs=output_images,
    )

    # Add examples
    gr.Examples(
        examples=[
            ["swimming pool", 27, 7],
            ["swimming pool", 255, 7],
            ["dining room", 22, 7],
            ["dining room", 5, 7],
            ["dining room", 5, 46],
        ],
        inputs=[caption, brightness, num_items],
    )

if __name__ == "__main__":
    demo.launch()
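# To launch the demo, run this script directly with Python; the `__main__`
# guard above calls `demo.launch()`. Pass `share=True` to `demo.launch()` to
# get a temporary public Gradio link. A CUDA GPU is recommended, since
# 4M-21_L inference can be slow on CPU.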