from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from .task_configs import task_parameters

MAKE_RESCALE_0_1_NEG1_POS1   = lambda n_chan: transforms.Normalize([0.5] * n_chan, [0.5] * n_chan)   # [0, 1] -> [-1, 1]
RESCALE_0_1_NEG1_POS1        = transforms.Normalize([0.5], [0.5])  # single-channel; use the lambda above for other channel counts
MAKE_RESCALE_0_MAX_NEG1_POS1 = lambda maxx: transforms.Normalize([maxx / 2.], [maxx / 2.])           # [0, max] -> [-1, 1]
RESCALE_0_255_NEG1_POS1      = transforms.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])    # [0, 255] -> [-1, 1]
MAKE_RESCALE_0_MAX_0_POS1    = lambda maxx: transforms.Normalize([0.0], [maxx * 1.0])                # [0, max] -> [0, 1]
STD_IMAGENET                 = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])    # ImageNet statistics

# For semantic segmentation
transform_dense_labels = lambda img: torch.from_numpy(np.array(img)).long()  # class indices; no normalization

# Converts an 8-bit image to a float tensor, rescaling [0, 255] -> [0, 1]
transform_8bit = transforms.Compose([
    transforms.ToTensor(),
])

# Converts an 8-bit image to an n-channel float tensor in [0, 1], optionally keeping only the first n channels
def transform_8bit_n_channel(n_channel=1, crop_channels=True):
    if crop_channels:
        crop_channels_fn = lambda x: x[:n_channel] if x.shape[0] > n_channel else x
    else:
        crop_channels_fn = lambda x: x
    return transforms.Compose([
        transforms.ToTensor(),
        crop_channels_fn,
    ])
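
# Illustrative example (not from the original code): transform_8bit_n_channel(2) turns
# an RGBA PIL image into a (4 x H x W) float tensor in [0, 1] and keeps only its first
# two channels; with crop_channels=False the channel count is left untouched.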

# Converts a 16-bit single-channel image to a float tensor, rescaling [0, 2^16 - 1] -> [0, 1]
def transform_16bit_single_channel(im):
    # ToTensor only rescales uint8 inputs, so divide by the 16-bit maximum explicitly
    im = transforms.ToTensor()(np.array(im))
    im = im.float() / (2 ** 16 - 1.0)
    return im

def make_valid_mask(mask_float, max_pool_size=4):
    '''
        Creates a boolean mask indicating the valid parts of the image(s).
        Enlarges the masked (invalid) area using a max-pooling operation.

        Args:
            mask_float: A (b x c x h x w) or (c x h x w) mask as loaded from the Taskonomy loader.
            max_pool_size: Parameter to choose how much to enlarge the masked area.

        Returns:
            A boolean tensor of the same spatial size, True where the input is valid.
    '''
    squeeze = False
    if len(mask_float.shape) == 3:
        mask_float = mask_float.unsqueeze(0)
        squeeze = True
    _, _, h, w = mask_float.shape
    # Invert so invalid pixels are 1, dilate them with max pooling, then upsample back
    mask_float = 1 - mask_float
    mask_float = F.max_pool2d(mask_float, kernel_size=max_pool_size)
    mask_float = F.interpolate(mask_float, (h, w), mode='nearest')
    mask_valid = mask_float == 0
    mask_valid = mask_valid[0] if squeeze else mask_valid
    return mask_valid
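
# A minimal sketch of how make_valid_mask dilates the invalid region (shapes and
# values are illustrative, not taken from the Taskonomy loader):
#
#   mask = torch.ones(1, 8, 8)        # fully valid (c x h x w) mask
#   mask[0, 4, 4] = 0.                # mark one pixel invalid
#   valid = make_valid_mask(mask)     # bool (1 x 8 x 8); the entire 4 x 4 pooling
#                                     # block containing the invalid pixel is False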


def task_transform(file, task: str, image_size: Optional[int] = None):
    '''Applies the transform appropriate for `task` to `file`, optionally resizing to `image_size`.'''
    transform = None

    if task in ['rgb']:
        transform = transforms.Compose([
            transform_8bit,
            STD_IMAGENET
        ])
    elif task in ['normal']:
        transform = transform_8bit
    elif task in ['mask_valid']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            make_valid_mask
        ])
    elif task in ['keypoints2d', 'keypoints3d', 'depth_euclidean', 'depth_zbuffer', 'edge_texture']:
        transform = transform_16bit_single_channel
    elif task in ['edge_occlusion']:
        transform = transforms.Compose([
            transform_16bit_single_channel,
            transforms.GaussianBlur(3, sigma=1)
        ])
    elif task in ['principal_curvature', 'curvature']:
        transform = transform_8bit_n_channel(2)
    elif task in ['reshading']:
        transform = transform_8bit_n_channel(1)
    # These are stored as 1-channel images (H, W) where each pixel value is a class index
    elif task in ['segment_semantic', 'segment_instance', 'segment_panoptic', 'fragments', 'segment_unsup2d', 'segment_unsup25d']:
        transform = transform_dense_labels
    elif task in ['class_object', 'class_scene']:
        transform = torch.Tensor
        image_size = None
    else:
        transform = None

    # Guard the parameter lookups so unknown tasks pass through instead of raising KeyError
    if task in task_parameters and 'threshold_min' in task_parameters[task]:
        threshold = task_parameters[task]['threshold_min']
        transform = transforms.Compose([
            transform,
            lambda x: torch.threshold(x, threshold, 0.0)
        ])
    if task in task_parameters and 'clamp_to' in task_parameters[task]:
        minn, maxx = task_parameters[task]['clamp_to']
        if minn > 0:
            raise NotImplementedError("Rescaling (min1, max1) -> (min2, max2) not implemented for min1, min2 != 0 (task {})".format(task))
        transform = transforms.Compose([
            transform,
            lambda x: torch.clamp(x, minn, maxx),
            MAKE_RESCALE_0_MAX_0_POS1(maxx)
        ])

    if image_size is not None:
        if task == 'fragments':
            # Fragments are (H, W, C) integer maps; use nearest-neighbor resizing to keep labels intact
            resize_frag = lambda frag: F.interpolate(frag.permute(2, 0, 1).unsqueeze(0).float(), image_size, mode='nearest').long()[0].permute(1, 2, 0)
            transform = transforms.Compose([
                transform,
                resize_frag
            ])
        else:
            # Bilinear for RGB; nearest for label/metric maps so values are not interpolated
            resize_method = transforms.InterpolationMode.BILINEAR if task in ['rgb'] else transforms.InterpolationMode.NEAREST
            transform = transforms.Compose([
                transforms.Resize(image_size, interpolation=resize_method),
                transform
            ])

    if transform is not None:
        file = transform(file)
        
    return file
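
# A minimal usage sketch (the file path is hypothetical; any task defined in
# task_parameters should work analogously):
#
#   from PIL import Image
#   img = Image.open('point_0_view_0_domain_rgb.png')
#   x = task_transform(img, task='rgb', image_size=256)
#   # -> ImageNet-normalized float tensor whose shorter side is resized to 256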