# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from abc import ABC, abstractmethod

import numpy as np
import torchvision

from fourm.utils import to_2tuple


class AbstractImageAugmenter(ABC):
    """Abstract class for image augmenters.
    """

    @abstractmethod
    def __call__(self, mod_dict, crop_settings):
        pass


class RandomCropImageAugmenter(AbstractImageAugmenter):

    def __init__(self, target_size=224, hflip=0.5, crop_scale=(0.2, 1.0), crop_ratio=(0.75, 1.3333), main_domain='rgb'):
        self.target_size = to_2tuple(target_size)
        self.hflip = hflip
        self.crop_scale = crop_scale
        self.crop_ratio = crop_ratio
        self.main_domain = main_domain

    def __call__(self, mod_dict, crop_settings):
        if crop_settings is not None:
            raise ValueError("Crop settings are provided but not used by this augmenter.")

        image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
        orig_width, orig_height = image.size
        orig_size = (orig_height, orig_width)

        top, left, h, w = torchvision.transforms.RandomResizedCrop.get_params(
            image, scale=self.crop_scale, ratio=self.crop_ratio
        )
        crop_coords = top, left, h, w
        flip = random.random() < self.hflip
        rand_aug_idx = None

        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx


class NoImageAugmenter(AbstractImageAugmenter):
    # This is for non-image modalities like poses where we don't do any augs, e.g. during tokenization.

    def __init__(self, no_aug=True, main_domain='human_poses'):
        self.target_size = None  # to_2tuple(target_size)
        self.no_aug = no_aug
        self.main_domain = main_domain

    def __call__(self, mod_dict, crop_settings):
        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
        orig_size = (224, 224)
        rand_aug_idx = 0
        top, left, h, w, flip = 0, 0, 224, 224, 0
        crop_coords = (top, left, h, w)

        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx


class PreTokenizedImageAugmenter(AbstractImageAugmenter):

    def __init__(self, target_size, no_aug=False, main_domain='rgb'):
        self.target_size = to_2tuple(target_size)
        self.no_aug = no_aug
        self.main_domain = main_domain

    def __call__(self, mod_dict, crop_settings):
        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
        if self.main_domain in mod_dict and 'tok' not in self.main_domain:
            image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
            orig_width, orig_height = image.size
            orig_size = (orig_height, orig_width)
        else:
            orig_size = None

        rand_aug_idx = 0 if self.no_aug else np.random.randint(len(crop_settings))
        top, left, h, w, flip = crop_settings[rand_aug_idx]
        crop_coords = (top, left, h, w)

        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx


class CenterCropImageAugmenter(AbstractImageAugmenter):

    def __init__(self, target_size, hflip=0.0, main_domain='rgb'):
        self.target_size = to_2tuple(target_size)
        self.hflip = hflip
        self.main_domain = main_domain

    def __call__(self, mod_dict, crop_settings=None):
        image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
        orig_width, orig_height = image.size
        orig_size = (orig_height, orig_width)

        if orig_height > orig_width:
            h = w = orig_width
            top = (orig_height - orig_width) // 2
            left = 0
        else:
            h = w = orig_height
            top = 0
            left = (orig_width - orig_height) // 2

        crop_coords = (top, left, h, w)
        flip = random.random() < self.hflip
        rand_aug_idx = None

        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx


class PaddingImageAugmenter(AbstractImageAugmenter):

    def __init__(self, target_size, hflip=0.0, main_domain='rgb'):
        self.target_size = to_2tuple(target_size)
        self.hflip = hflip
        self.main_domain = main_domain

    def __call__(self, mod_dict, crop_settings):
        image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
        orig_width, orig_height = image.size
        orig_size = (orig_height, orig_width)

        h = w = max(orig_width, orig_height)
        top = left = 0
        crop_coords = (top, left, h, w)
        flip = random.random() < self.hflip
        rand_aug_idx = None

        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx


class ScaleJitteringImageAugmenter(AbstractImageAugmenter):

    def __init__(self, target_size, hflip=0.0, scale=(0.1, 2.0), main_domain='rgb'):
        self.target_size = to_2tuple(target_size)
        self.hflip = hflip
        self.scale = scale
        self.main_domain = main_domain

    def scale_jitter(self, orig_height, orig_width):
        rand_scale = np.random.uniform(self.scale[0], self.scale[1])
        max_hw = max(orig_height, orig_width)
        h = w = round(max_hw / rand_scale)
        top = round(max(0, np.random.uniform(0, orig_height - h)))
        left = round(max(0, np.random.uniform(0, orig_width - w)))
        return top, left, h, w

    def __call__(self, mod_dict, crop_settings):
        if crop_settings is not None:
            raise ValueError("Crop settings are provided but not used by this augmenter.")

        image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
        orig_width, orig_height = image.size
        orig_size = (orig_height, orig_width)

        crop_coords = self.scale_jitter(orig_height, orig_width)
        flip = random.random() < self.hflip
        rand_aug_idx = None

        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx


class EmptyAugmenter(AbstractImageAugmenter):

    def __init__(self):
        pass

    def __call__(self, mod_dict, crop_settings):
        return None, None, None, None, None
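

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original augmenters):
# a minimal example of how these augmenters might be invoked, assuming the
# modality dict maps domain names (here 'rgb', an assumption) to PIL images.
# The dummy image size and the crop settings below are hypothetical values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    # Dummy modality dict with a single 640x480 RGB image.
    mod_dict = {'rgb': Image.new('RGB', (640, 480))}

    # Training-style augmentation: random resized crop + random horizontal flip.
    train_aug = RandomCropImageAugmenter(target_size=224, hflip=0.5)
    crop_coords, flip, orig_size, target_size, rand_aug_idx = train_aug(mod_dict, crop_settings=None)
    print('random crop:', crop_coords, flip, orig_size, target_size, rand_aug_idx)

    # Evaluation-style augmentation: deterministic center crop, no flipping.
    eval_aug = CenterCropImageAugmenter(target_size=224)
    crop_coords, flip, orig_size, target_size, rand_aug_idx = eval_aug(mod_dict)
    print('center crop:', crop_coords, flip, orig_size, target_size, rand_aug_idx)

    # Pre-tokenized path: reuse crop settings computed earlier (e.g. at
    # tokenization time); each row is (top, left, h, w, flip).
    crop_settings = np.array([[0, 0, 480, 480, 0], [10, 20, 400, 400, 1]])
    pretok_aug = PreTokenizedImageAugmenter(target_size=224)
    crop_coords, flip, orig_size, target_size, rand_aug_idx = pretok_aug(mod_dict, crop_settings)
    print('pre-tokenized:', crop_coords, flip, orig_size, target_size, rand_aug_idx)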