import json
import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data.dataset import Dataset


class CC15M(Dataset):
    """Image/caption dataset driven by a JSON annotation file.

    The annotation file is parsed with ``json.load`` and is indexed by
    integer position; each entry is read through its ``'file_path'`` and
    ``'text'`` keys.

    Args:
        json_path: Path to the JSON annotation file.
        video_folder: Optional root directory joined in front of every
            ``file_path``. When ``None`` the ``file_path`` is used as-is.
        resolution: Either an int (used for both dimensions) or an
            iterable that is converted to a tuple and passed to
            ``CenterCrop``; its first element is passed to ``Resize``.
        enable_bucket: When true, ``__getitem__`` skips the torchvision
            transforms and returns the image as a NumPy array instead.
    """

    def __init__(
        self,
        json_path,
        video_folder=None,
        resolution=512,
        enable_bucket=False,
    ):
        super().__init__()

        print(f"loading annotations from {json_path} ...")
        # Use a context manager so the annotation file handle is closed
        # as soon as parsing finishes instead of being left to the GC.
        with open(json_path, 'r', encoding="utf-8") as annotation_file:
            self.dataset = json.load(annotation_file)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.video_folder = video_folder

        # Normalise ``resolution`` to a 2-tuple-like value.
        resolution = (
            tuple(resolution)
            if not isinstance(resolution, int)
            else (resolution, resolution)
        )
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(resolution[0]),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True
            ),
        ])

    def get_batch(self, idx):
        """Load the RGB image and caption for annotation entry ``idx``.

        Returns:
            A ``(image, text)`` pair where ``image`` is a PIL image
            converted to ``"RGB"`` and ``text`` is the entry's ``'text'``
            value.

        Raises:
            Whatever indexing the annotations, the key lookups, or
            ``Image.open`` / ``convert`` raise; ``__getitem__`` catches
            these and resamples.
        """
        video_dict = self.dataset[idx]
        video_id, name = video_dict['file_path'], video_dict['text']

        if self.video_folder is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.video_folder, video_id)

        # ``convert`` yields a new, fully loaded image, so the source file
        # can be closed right away rather than leaking a descriptor per
        # sample.
        with Image.open(video_dir) as image:
            pixel_values = image.convert("RGB")
        return pixel_values, name

    def __len__(self):
        """Return the number of annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Return ``dict(pixel_values=..., text=...)`` for one sample.

        If loading entry ``idx`` fails, the exception is printed and a
        uniformly random index is tried instead, repeating until a sample
        loads successfully.
        """
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                # Best-effort loading: report the failure and resample.
                print(e)
                idx = random.randint(0, self.length - 1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        else:
            # Bucketing path: hand back the raw pixels untransformed.
            pixel_values = np.array(pixel_values)

        sample = dict(pixel_values=pixel_values, text=name)
        return sample


if __name__ == "__main__":
    # The constructor's parameter is ``json_path``; the previous
    # ``csv_path=`` keyword raised TypeError before anything was loaded.
    dataset = CC15M(
        json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
        resolution=512,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=4, num_workers=0,
    )
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))