from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from .task_configs import task_parameters

# Rescaling helpers: Normalize computes (x - mean) / std, so these map raw
# value ranges into network-friendly ranges.
MAKE_RESCALE_0_1_NEG1_POS1 = lambda n_chan: transforms.Normalize([0.5] * n_chan, [0.5] * n_chan)
RESCALE_0_1_NEG1_POS1 = transforms.Normalize([0.5], [0.5])  # This needs to be different depending on num out chans
MAKE_RESCALE_0_MAX_NEG1_POS1 = lambda maxx: transforms.Normalize([maxx / 2.], [maxx * 1.0])
RESCALE_0_255_NEG1_POS1 = transforms.Normalize([127.5, 127.5, 127.5], [255, 255, 255])
MAKE_RESCALE_0_MAX_0_POS1 = lambda maxx: transforms.Normalize([0.0], [maxx * 1.0])
STD_IMAGENET = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# For semantic segmentation: keep integer class ids as a LongTensor;
# deliberately avoids ToTensor's [0, 1] normalization.
transform_dense_labels = lambda img: torch.Tensor(np.array(img)).long()

# Transforms an 8-bit image to a float tensor scaled to [0, 1].
transform_8bit = transforms.Compose([
    transforms.ToTensor(),
])


def transform_8bit_n_channel(n_channel=1, crop_channels=True):
    """Build a transform that converts an 8-bit image to a float tensor in [0, 1].

    Args:
        n_channel: Number of leading channels to keep when cropping.
        crop_channels: If True, tensors with more than ``n_channel`` channels
            are truncated to the first ``n_channel`` channels; otherwise the
            tensor is passed through unchanged.

    Returns:
        A ``transforms.Compose`` callable.
    """
    if crop_channels:
        crop_channels_fn = lambda x: x[:n_channel] if x.shape[0] > n_channel else x
    else:
        crop_channels_fn = lambda x: x
    return transforms.Compose([
        transforms.ToTensor(),
        crop_channels_fn,
    ])


def transform_16bit_single_channel(im):
    """Convert a 16-bit single-channel image to a float tensor in [0, 1].

    NOTE(review): ``ToTensor`` only rescales uint8 input, so the explicit
    division by (2**16 - 1) maps the full 16-bit range to [0, 1]. This
    presumably relies on torch accepting the numpy array's 16-bit dtype —
    verify against the torch version in use.
    """
    im = transforms.ToTensor()(np.array(im))
    im = im.float() / (2 ** 16 - 1.0)
    return im


def make_valid_mask(mask_float, max_pool_size=4):
    '''
    Creates a mask indicating the valid parts of the image(s).
    Enlargens masked area using a max pooling operation.

    Args:
        mask_float: A (b x c x h x w) mask as loaded from the Taskonomy loader.
        max_pool_size: Parameter to choose how much to enlarge masked area.

    Returns:
        Boolean tensor of the same spatial size; True where the pixel is valid.
    '''
    squeeze = False
    # Accept an unbatched (c x h x w) mask by temporarily adding a batch dim.
    if len(mask_float.shape) == 3:
        mask_float = mask_float.unsqueeze(0)
        squeeze = True
    _, _, h, w = mask_float.shape
    # Invert so that invalid pixels are 1, then max-pool to dilate them.
    mask_float = 1 - mask_float
    mask_float = F.max_pool2d(mask_float, kernel_size=max_pool_size)
    # Restore the original resolution; 'nearest' keeps hard mask edges.
    mask_float = F.interpolate(mask_float, (h, w), mode='nearest')
    mask_valid = mask_float == 0
    mask_valid = mask_valid[0] if squeeze else mask_valid
    return mask_valid


def task_transform(file, task: str, image_size: Optional[int] = None):
    """Apply the task-appropriate preprocessing transform to ``file``.

    Args:
        file: The loaded image/label (PIL image, array, or list, depending
            on the task).
        task: Task name; selects the transform pipeline below.
        image_size: Optional target size for resizing the result. ``None``
            (the default) skips resizing.
            BUGFIX: the original signature was ``image_size=Optional[int]``,
            which made the *typing object* the default value, so the
            ``image_size is not None`` check below always passed when the
            argument was omitted.

    Returns:
        The transformed tensor (or ``file`` unchanged if no transform applies).
    """
    transform = None

    if task in ['rgb']:
        transform = transforms.Compose([
            transform_8bit,
            STD_IMAGENET
        ])
    elif task in ['normal']:
        transform = transform_8bit
    elif task in ['mask_valid']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            make_valid_mask
        ])
    elif task in ['keypoints2d', 'keypoints3d', 'depth_euclidean', 'depth_zbuffer', 'edge_texture']:
        transform = transform_16bit_single_channel
    elif task in ['edge_occlusion']:
        transform = transforms.Compose([
            transform_16bit_single_channel,
            transforms.GaussianBlur(3, sigma=1)
        ])
    elif task in ['principal_curvature', 'curvature']:
        transform = transform_8bit_n_channel(2)
    elif task in ['reshading']:
        transform = transform_8bit_n_channel(1)
    elif task in ['segment_semantic', 'segment_instance', 'segment_panoptic', 'fragments', 'segment_unsup2d', 'segment_unsup25d']:
        # this is stored as 1 channel image (H,W) where each pixel value is a different class
        transform = transform_dense_labels
    elif task in ['class_object', 'class_scene']:
        transform = torch.Tensor
        image_size = None  # class labels have no spatial extent to resize
    else:
        transform = None

    # Per-task value clipping / rescaling, driven by the task_parameters table.
    if 'threshold_min' in task_parameters[task]:
        threshold = task_parameters[task]['threshold_min']
        transform = transforms.Compose([
            transform,
            lambda x: torch.threshold(x, threshold, 0.0)
        ])
    if 'clamp_to' in task_parameters[task]:
        minn, maxx = task_parameters[task]['clamp_to']
        if minn > 0:
            raise NotImplementedError(
                "Rescaling (min1, max1) -> (min2, max2) not implemented for min1, min2 != 0 (task {})".format(task))
        transform = transforms.Compose([
            transform,
            lambda x: torch.clamp(x, minn, maxx),
            MAKE_RESCALE_0_MAX_0_POS1(maxx)
        ])

    if image_size is not None:
        if task == 'fragments':
            # Fragments are (H, W, C) integer maps: move channels first for
            # interpolate, resize with 'nearest' to keep ids intact, move back.
            resize_frag = lambda frag: F.interpolate(
                frag.permute(2, 0, 1).unsqueeze(0).float(), image_size, mode='nearest'
            ).long()[0].permute(1, 2, 0)
            transform = transforms.Compose([
                transform,
                resize_frag
            ])
        else:
            # Bilinear only for RGB; label-like tasks must use nearest so
            # interpolation never invents values between classes.
            resize_method = transforms.InterpolationMode.BILINEAR if task in ['rgb'] else transforms.InterpolationMode.NEAREST
            transform = transforms.Compose([
                transforms.Resize(image_size, resize_method),
                transform
            ])

    if transform is not None:
        file = transform(file)

    return file