"""Dataset generators for building image-caption webdataset shards.

Each generator iterates one source (CC3M / CC12M tar shards, COCO captions,
Karpathy-split COCO captions, Visual Genome region descriptions) and yields
list samples whose first element is the dataset name.  Run as a script, this
module writes the Visual Genome region samples into webdataset tar shards.
"""
import glob
import itertools
import os
import random
from typing import List

import cv2
import numpy as np
import orjson as json
import webdataset as wds
from PIL import Image
from tqdm import tqdm


def _load_json(path):
    """Read and parse the JSON file at *path*, closing the file afterwards."""
    # orjson accepts raw bytes, so no text decoding step is needed.
    with open(path, "rb") as f:
        return json.loads(f.read())


def _normalize(weights):
    """Return *weights* rescaled to sum to 1 (uniform if they sum to 0)."""
    weights = [float(w) for w in weights]
    total = sum(weights)
    if not weights:
        return []
    if total == 0:
        # Avoid a division by zero; fall back to sampling uniformly.
        return [1.0 / len(weights)] * len(weights)
    return [w / total for w in weights]


class Generator():
    """Base class for dataset generators.

    Attributes:
        dataset_name: label placed first in every yielded sample.
        is_end: set to True by subclasses once a full pass has completed.
    """

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        self.is_end = False


class CC3MGenerator(Generator):
    """Yield ``[dataset_name, image, text]`` from tar shards under *root*."""

    def __init__(self, root: str, dataset_name="cc3m"):
        super().__init__(dataset_name=dataset_name)
        self.tars = glob.glob(os.path.join(root, "cc3m_*", "*.tar"))

    def __len__(self):
        # Hard-coded nominal size; the shards are not counted.
        return 3000000

    def __iter__(self):
        for tar in self.tars:
            dataset = (
                wds.WebDataset(tar)
                .decode("pilrgb")
                .to_tuple("jpg;png;jpeg", "txt")
            )
            for data in dataset:
                yield [self.dataset_name] + list(data)
        self.is_end = True


class CC12MGenerator(CC3MGenerator):
    """Same as :class:`CC3MGenerator`, with tars located directly in *root*."""

    def __init__(self, root: str):
        super().__init__(root, "cc12m")
        # Override the shard list: here the tars sit directly under root.
        self.tars = glob.glob(os.path.join(root, "*.tar"))

    def __len__(self):
        # Hard-coded nominal size; the shards are not counted.
        return 12000000


class COCOGenerator(Generator):
    """Yield ``[dataset_name, image, caption]`` from a COCO-style caption file.

    Args:
        anno: path to a JSON file with ``"annotations"`` and ``"images"`` keys.
        image_dir: directory that the ``file_name`` entries are joined onto.
    """

    def __init__(self, anno: str, image_dir):
        super().__init__(dataset_name="coco")
        data = _load_json(anno)
        self.annotations = data["annotations"]
        self.image_id_to_filename = {}
        for image in data["images"]:
            file_name = image["file_name"]
            image_id = image["id"]
            self.image_id_to_filename[image_id] = os.path.join(image_dir, file_name)

    def __len__(self):
        return len(self.annotations)

    def __iter__(self):
        for anno in self.annotations:
            image_id = anno["image_id"]
            caption = anno["caption"]
            try:
                image = Image.open(self.image_id_to_filename[image_id])
            except Exception:
                # Best effort: skip annotations whose image cannot be opened.
                continue
            yield [self.dataset_name, image, caption]
        self.is_end = True


class KarpathyCOCOGenerator(Generator):
    """Yield ``[dataset_name, image, caption]`` from a Karpathy-split file.

    Args:
        anno: path to a JSON list of entries with ``"image_id"``, ``"image"``
            and ``"caption"`` keys.
        image_dir: directory that the ``"image"`` entries are joined onto.
    """

    def __init__(self, anno="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/tools/coco_karpathy_train.json", image_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco/images"):
        super().__init__(dataset_name="coco")
        data = _load_json(anno)
        self.annotations = data
        self.image_id_to_filename = {}
        for d in data:
            self.image_id_to_filename[d["image_id"]] = os.path.join(image_dir, d["image"])

    def __len__(self):
        return len(self.annotations)

    def __iter__(self):
        for anno in self.annotations:
            image_id = anno["image_id"]
            caption = anno["caption"]
            try:
                image = Image.open(self.image_id_to_filename[image_id])
            except Exception:
                # Report the entry that failed, then skip it.  Without the
                # `continue`, the previous image (or an unbound name on the
                # first sample) would be yielded with this caption.
                print(self.image_id_to_filename.get(image_id, image_id))
                continue
            yield [self.dataset_name, image, caption]
        self.is_end = True


class VisualGenomeGenerator(Generator):
    """Yield Visual Genome region samples with a box in resized pixel coords.

    Reads ``region_descriptions.json`` and ``image_data.json`` from *root*,
    keeps regions that cover at least 10% of the image and whose phrase has
    more than 7 space-separated words and contains neither " is" nor " are",
    then shuffles the kept regions.

    Each sample is ``[dataset_name, image, phrase, xyxy, image_id]`` where
    ``image`` is resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and ``xyxy`` is a
    numpy array ``[x1, y1, x2, y2]`` in that resized coordinate frame.
    """

    # Side length (pixels) of the square that images are resized to.
    IMAGE_SIZE = 224

    def __init__(self, root: str):
        super().__init__(dataset_name="vg")
        data = _load_json(os.path.join(root, "region_descriptions.json"))
        image_data = _load_json(os.path.join(root, "image_data.json"))
        self.image_id_to_filename = {}
        self.image_id_to_wh = {}
        for image in image_data:
            image_id = image["image_id"]
            # The last two URL components are the sub-folder and file name.
            subfolder, filename = image['url'].split("/")[-2:]
            self.image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
            self.image_id_to_wh[image_id] = (image["width"], image["height"])
        self.regions = []
        total = 0
        total_image = 0
        used_image = 0
        for entry in data:
            total_image += 1
            has_valid_region = False
            for region in entry['regions']:
                total += 1
                image_w, image_h = self.image_id_to_wh[region["image_id"]]
                # Normalise the box to [0, 1] relative to the image size.
                region_w = int(region["width"]) / image_w
                region_h = int(region["height"]) / image_h
                x = int(region["x"]) / image_w
                y = int(region["y"]) / image_h
                if region_w * region_h < 0.1:
                    continue
                phrase = region["phrase"]
                if " is" in phrase or " are" in phrase or len(phrase.split(" ")) <= 7:
                    continue
                region["norm_xywh"] = (x, y, region_w, region_h)
                self.regions.append(region)
                has_valid_region = True
            if has_valid_region:
                used_image += 1
        random.shuffle(self.regions)
        # Guard the ratios so an empty dataset does not raise ZeroDivisionError.
        region_ratio = len(self.regions) / total if total else 0.0
        image_ratio = used_image / total_image if total_image else 0.0
        print("valid region", len(self.regions), total, region_ratio)
        print("valid image", used_image, total_image, image_ratio)

    def __len__(self):
        return len(self.regions)

    def __iter__(self):
        size = self.IMAGE_SIZE
        for region in self.regions:
            image_id = region["image_id"]
            phrase = region["phrase"]
            try:
                image = Image.open(self.image_id_to_filename[image_id])
            except Exception:
                # Best effort: skip regions whose image cannot be opened.
                continue
            image = image.resize((size, size))
            x, y, region_w, region_h = region["norm_xywh"]
            x1 = int(x * size)
            y1 = int(y * size)
            x2 = int(x1 + region_w * size)
            y2 = int(y1 + region_h * size)
            # Debug visualisation, kept for reference:
            # open_cv_image = np.array(image)
            # # Convert RGB to BGR
            # open_cv_image = open_cv_image[:, :, ::-1].copy()
            # open_cv_image = cv2.rectangle(open_cv_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
            # cv2.imwrite("vg.jpg", open_cv_image)
            # import pdb; pdb.set_trace()
            yield [self.dataset_name, image, phrase, np.array([x1, y1, x2, y2]), image_id]
        self.is_end = True


class ShuffleGenerator():
    """Randomly interleave samples from several generators.

    Args:
        generators: the source generators.
        p: one non-negative sampling weight per generator; the weights are
            normalised to probabilities.
    """

    def __init__(self, generators: List[Generator], p: List[int]):
        self.generators = generators
        self.p = _normalize(p)
        self.ids = list(range(len(self.generators)))
        print("rebalance", self.ids, self.p)

    def __len__(self):
        return sum(len(g) for g in self.generators)

    def __iter__(self):
        """Yield samples, picking the source of each one at random.

        A source is dropped once it is exhausted and the remaining weights
        are renormalised; iteration stops when every source is exhausted.
        Works on local copies, so ``self.ids`` / ``self.p`` are not mutated.
        """
        ids = list(self.ids)
        weights = list(self.p)
        iterators = {gen_id: iter(self.generators[gen_id]) for gen_id in ids}
        while ids:
            # Sample a *position* so ids and weights stay aligned on removal.
            pos = int(np.random.choice(len(ids), p=weights))
            gen_id = ids[pos]
            try:
                sample = next(iterators[gen_id])
            except StopIteration:
                print(self.generators[gen_id].dataset_name, "is all done")
                del ids[pos]
                del weights[pos]
                weights = _normalize(weights)
                print("rebalance", ids, weights)
                continue
            yield sample


def main():
    """Write the Visual Genome region samples into webdataset tar shards."""
    OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_withBox_L7_wds"
    os.makedirs(OUT_DIR, exist_ok=True)
    # cc3m_generator = CC3MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc3m")
    # cc12m_generator = CC12MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc12m/tars")
    # coco_generator = KarpathyCOCOGenerator()
    visual_genome_generator = VisualGenomeGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg")
    # generators = [cc3m_generator, cc12m_generator, coco_generator, visual_genome_generator]
    # p = [len(generator) for generator in generators]
    # dataset = ShuffleGenerator(generators, p)
    with wds.ShardWriter(os.path.join(OUT_DIR, "%05d.tar"), maxcount=8500) as sink:
        sink.verbose = False
        pbar = tqdm(visual_genome_generator)
        for i, data in enumerate(pbar):
            dataset_name, image, caption, xyxy, image_id = data
            sink.write({"__key__": f"{dataset_name}_{i}_containBox", "jpg": image, "txt": caption, "xyxy.pyd": xyxy})
            if i % 200 == 0:
                # Periodically echo a sample without breaking the progress bar.
                tqdm.write(f"{caption} {xyxy}")


if __name__ == "__main__":
    main()