import os
from pathlib import Path

import torch
import numpy as np
from PIL import Image
import gradio as gr
from tokenizers import Tokenizer
from torch.utils.data import Dataset
import albumentations as A
from tqdm import tqdm

from fourm.vq.vqvae import VQVAE
from fourm.models.fm import FM
from fourm.models.generate import (
    GenerationSampler,
    build_chained_generation_schedules,
    init_empty_target_modality,
    custom_text,
)
from fourm.utils.plotting_utils import decode_dict
from fourm.data.modality_info import MODALITY_INFO
from fourm.data.modality_transforms import RGBTransform
from torchvision.transforms.functional import center_crop

# Constants and configurations
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 224
TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
IMAGE_DATASET_PATH = "/home/ubuntu/GIT_REPOS/ml-4m/data/custom_data/"

# Load the text tokenizer, the DINOv2-global token decoder, and the 4M-21 model
text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
vqvae = VQVAE.from_pretrained(VQVAE_PATH)
fm_model = FM.from_pretrained(FM_MODEL_PATH).eval().to(DEVICE)

# Generation configuration: condition on caption + metadata, generate the 16
# global DINOv2 tokens in a single decoding step
cond_domains = ["caption", "metadata"]
target_domains = ["tok_dinov2_global"]
tokens_per_target = [16]
generation_config = {
    "autoregression_schemes": ["roar"],
    "decoding_steps": [1],
    "token_decoding_schedules": ["linear"],
    "temps": [2.0],
    "temp_schedules": ["onex:0.5:0.5"],
    "cfg_scales": [1.0],
    "cfg_schedules": ["constant"],
    "cfg_grow_conditioning": True,
}
top_p, top_k = 0.8, 0.0

schedule = build_chained_generation_schedules(
    cond_domains=cond_domains,
    target_domains=target_domains,
    tokens_per_target=tokens_per_target,
    **generation_config,
)
sampler = GenerationSampler(fm_model)


class ImageDataset(Dataset):
    """Candidate gallery images, resized so the shortest side is IMG_SIZE."""

    def __init__(self, path: str, img_sz=IMG_SIZE):
        self.path = Path(path)
        # Sort for a deterministic ordering so indices line up with the cached
        # embeddings, and skip anything that is not a regular file.
        self.files = sorted(f for f in self.path.rglob("*") if f.is_file())
        self.tfms = A.Compose([A.SmallestMaxSize(img_sz)])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert("RGB")
        img = np.array(img)
        img = self.tfms(image=img)["image"]
        return Image.fromarray(img)


dataset = ImageDataset(IMAGE_DATASET_PATH)


@torch.no_grad()
def get_image_embeddings(dataset):
    # Load the precomputed DINOv2 global embeddings for the dataset. The cache
    # is assumed to already exist; fail loudly instead of returning None.
    cache_file = "image_emb.pt"
    if os.path.exists(cache_file):
        return torch.load(cache_file)
    raise FileNotFoundError(
        f"{cache_file} not found: precompute the dataset embeddings first "
        "(see the sketch below for one possible way to build the cache)."
    )
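# NOTE: the cache file "image_emb.pt" is loaded above but never created in this
# script. The helper below is a hypothetical sketch of one way to build it,
# assuming the cached vectors are DINOv2 ViT-B/14 global (CLS) features, i.e.
# the feature space that the tok_dinov2_global tokenizer decodes to. The name
# `precompute_image_embeddings` and the torchvision preprocessing are
# illustrative additions, not part of the original pipeline.
from torchvision import transforms as T


@torch.no_grad()
def precompute_image_embeddings(dataset, cache_file="image_emb.pt"):
    # DINOv2 ViT-B/14 backbone from torch.hub; its forward pass returns the
    # global (CLS) feature for each image.
    dinov2 = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
    dinov2 = dinov2.eval().to(DEVICE)
    preprocess = T.Compose(
        [
            T.CenterCrop(IMG_SIZE),  # dataset images already have shortest side IMG_SIZE
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )
    feats = []
    for i in tqdm(range(len(dataset)), desc="Embedding images"):
        x = preprocess(dataset[i]).unsqueeze(0).to(DEVICE)  # dataset yields PIL images
        feats.append(dinov2(x).cpu())  # [1, 768] global feature
    feats = torch.cat(feats, dim=0)
    torch.save(feats, cache_file)
    return feats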
image_embeddings = get_image_embeddings(dataset).to(DEVICE)
print(image_embeddings.shape)


def get_similar_images(caption, brightness, num_items):
    # Initialize an empty target modality for the 16 global DINOv2 tokens
    batched_sample = {}
    for target_mod, ntoks in zip(target_domains, tokens_per_target):
        batched_sample = init_empty_target_modality(
            batched_sample, MODALITY_INFO, target_mod, 1, ntoks, DEVICE
        )

    # Metadata conditioning string encoding the number of COCO instances
    # and the image brightness
    metadata = f"v1=6 v0={num_items} v1=10 v0={brightness}"
    print(metadata)

    # Add the caption and metadata conditions as text inputs
    batched_sample = custom_text(
        batched_sample,
        input_text=caption,
        eos_token="[EOS]",
        key="caption",
        device=DEVICE,
        text_tokenizer=text_tokenizer,
    )
    batched_sample = custom_text(
        batched_sample,
        input_text=metadata,
        eos_token="[EOS]",
        key="metadata",
        device=DEVICE,
        text_tokenizer=text_tokenizer,
    )

    # Generate DINOv2 global tokens conditioned on caption + metadata
    out_dict = sampler.generate(
        batched_sample,
        schedule,
        text_tokenizer=text_tokenizer,
        verbose=True,
        seed=0,
        top_p=top_p,
        top_k=top_k,
    )

    # Decode the generated tokens back into a DINOv2 global feature vector
    with torch.no_grad():
        dec_dict = decode_dict(
            out_dict,
            {"tok_dinov2_global": vqvae.to(DEVICE)},
            text_tokenizer,
            image_size=IMG_SIZE,
            patch_size=16,
            decoding_steps=1,
        )

    # Rank dataset images by cosine similarity to the generated feature and
    # keep only the single best match
    combined_features = dec_dict["tok_dinov2_global"]
    similarities = torch.nn.functional.cosine_similarity(
        combined_features, image_embeddings
    )
    top_indices = similarities.argsort(descending=True)[:1]
    print(top_indices, similarities[top_indices])
    return [dataset[i] for i in top_indices.cpu().numpy()]


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Retrieval using 4M-21: An Any-to-Any Vision Model")

    with gr.Row():
        with gr.Column(scale=1):
            caption = gr.Textbox(
                label="Caption Description",
                placeholder="Enter image description...",
            )
            brightness = gr.Slider(
                minimum=0,
                maximum=255,
                value=5,
                step=1,
                label="Brightness",
                info="Adjust image brightness (0-255)",
            )
            num_items = gr.Slider(
                minimum=0,
                maximum=50,
                value=5,
                step=1,
                label="Number of Items",
                info="Number of COCO instances in image (0-50)",
            )
        with gr.Column(scale=1):
            output_images = gr.Gallery(
                label="Retrieved Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=2,
                height=512,
            )

    submit_btn = gr.Button("Retrieve Most Similar Image")
    submit_btn.click(
        fn=get_similar_images,
        inputs=[caption, brightness, num_items],
        outputs=output_images,
    )

if __name__ == "__main__":
    demo.launch(share=True)