# Copyright (c) EPFL VILAB.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on BEiT, timm, DINO, DeiT and MAE-priv code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------
from typing import Dict, Tuple

import numpy as np
import torch

try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
except ImportError:
    print('albumentations not installed')
import cv2
import torch.nn.functional as F

from utils import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PAD_MASK_VALUE,
                   SEG_IGNORE_INDEX)

from .dataset_folder import ImageFolder, MultiTaskImageFolder


def simple_transform(train: bool,
                     additional_targets: Dict[str, str],
                     input_size: int = 512,
                     pad_value: Tuple[int, int, int] = (128, 128, 128),
                     pad_mask_value: int = PAD_MASK_VALUE):
    """Default transform for semantic segmentation, applied on all modalities

    During training:
        1. Random horizontal Flip
        2. Rescaling so that longest side matches input size
        3. Color jitter (for RGB-modality only)
        4. Large scale jitter (LSJ)
        5. Padding
        6. Random crop to given size
        7. Normalization with ImageNet mean and std dev

    During validation / test:
        1. Rescaling so that longest side matches given size
        2. Padding
        3. Normalization with ImageNet mean and std dev
     """

    if train:
        transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.LongestMaxSize(max_size=input_size, p=1),
            A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.5),  # Color jittering from MoCo-v3 / DINO
            A.RandomScale(scale_limit=(0.1 - 1, 2.0 - 1), p=1),  # This is LSJ (0.1, 2.0)
            A.PadIfNeeded(min_height=input_size, min_width=input_size,
                          position=A.augmentations.PadIfNeeded.PositionType.TOP_LEFT,
                          border_mode=cv2.BORDER_CONSTANT,
                          value=pad_value, mask_value=pad_mask_value),
            A.RandomCrop(height=input_size, width=input_size, p=1),
            A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ], additional_targets=additional_targets)

    else:
        transform = A.Compose([
            A.LongestMaxSize(max_size=input_size, p=1),
            A.PadIfNeeded(min_height=input_size, min_width=input_size,
                          position=A.augmentations.PadIfNeeded.PositionType.TOP_LEFT,
                          border_mode=cv2.BORDER_CONSTANT,
                          value=pad_value, mask_value=pad_mask_value),
            A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ], additional_targets=additional_targets)

    return transform
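
# Usage sketch (a minimal illustration; the array names and target choices are
# assumptions, not part of this module). In albumentations, additional_targets
# maps each extra key to a built-in target type; 'mask' targets are resized
# with nearest-neighbor interpolation and skip color jitter and normalization:
#
#   additional_targets = {'semseg': 'mask', 'depth': 'mask'}
#   transform = simple_transform(train=True, additional_targets=additional_targets)
#   out = transform(image=rgb_np, semseg=semseg_np, depth=depth_np)
#   # out['image'] is a normalized (C, H, W) float tensor; the mask outputs
#   # keep their integer labels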


class DataAugmentationForSemSeg(object):
    """Data transform / augmentation for semantic segmentation downstream tasks.
    """

    def __init__(self, transform, seg_num_classes, seg_ignore_index=SEG_IGNORE_INDEX, standardize_depth=True,
                 seg_reduce_zero_label=False, seg_use_void_label=False):

        self.transform = transform
        self.seg_num_classes = seg_num_classes
        self.seg_ignore_index = seg_ignore_index
        self.standardize_depth = standardize_depth
        self.seg_reduce_zero_label = seg_reduce_zero_label
        self.seg_use_void_label = seg_use_void_label

    @staticmethod
    def standardize_depth_map(img, mask_valid=None, trunc_value=0.1):
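        """Robustly standardize a depth map in place: padded and invalid pixels
        become NaN, values are centered and scaled by the mean and variance of
        the central quantile range, and NaNs are replaced by that mean."""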
        img[img == PAD_MASK_VALUE] = torch.nan
        if mask_valid is not None:
            # This is if we want to apply masking before standardization
            img[~mask_valid] = torch.nan
        sorted_img = torch.sort(torch.flatten(img))[0]
        # Remove NaNs (they are sorted to the end)
        num_nan = sorted_img.isnan().sum()
        if num_nan > 0:
            sorted_img = sorted_img[:-num_nan]
        # Remove outliers by truncating the top and bottom trunc_value fraction
        trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))]
        trunc_mean = trunc_img.mean()
        trunc_var = trunc_img.var()
        eps = 1e-6
        # Replace NaNs by the truncated mean
        img = torch.nan_to_num(img, nan=trunc_mean)
        # Standardize
        img = (img - trunc_mean) / torch.sqrt(trunc_var + eps)
        return img
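
    # Illustrative sanity check (an assumption about typical inputs, not a test
    # shipped with this repo); note the method mutates its input, hence .clone():
    #
    #   depth = torch.rand(480, 640) * 10.0
    #   z = DataAugmentationForSemSeg.standardize_depth_map(depth.clone())
    #   # z.mean() is approximately 0 and z.var() approximately 1, up to the
    #   # trunc_value truncation at each tail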

    def seg_adapt_labels(self, img):
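        """Remap raw segmentation labels: padded pixels go to the void or
        ignore label and, if seg_reduce_zero_label is set, class ids shift down
        by one with the original zero label mapped to seg_ignore_index."""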
        if self.seg_use_void_label:
            # Set void label to num_classes
            if self.seg_reduce_zero_label:
                pad_replace = self.seg_num_classes + 1
            else:
                pad_replace = self.seg_num_classes
        else:
            pad_replace = self.seg_ignore_index
        img[img == PAD_MASK_VALUE] = pad_replace

        if self.seg_reduce_zero_label:
            img[img == 0] = self.seg_ignore_index
            img = img - 1
            img[img == self.seg_ignore_index - 1] = self.seg_ignore_index

        return img
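
    # Worked example (illustrative values, assuming seg_ignore_index == 255,
    # seg_reduce_zero_label=True and seg_use_void_label=False): a raw label row
    #   [0, 1, 5, PAD_MASK_VALUE]
    # becomes
    #   [255, 0, 4, 255]
    # since padding and the zero label map to the ignore index and every other
    # class id shifts down by one.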

    def __call__(self, task_dict):

        # Albumentations expects the main image under the 'image' key
        task_dict['image'] = task_dict.pop('rgb')
        # Convert all modalities to numpy arrays
        task_dict = {k: np.array(v) for k, v in task_dict.items()}

        task_dict = self.transform(**task_dict)

        # Rename the transformed image back to 'rgb'
        task_dict['rgb'] = task_dict.pop('image')

        for task in task_dict:
            if task in ['depth']:
                img = task_dict[task].to(torch.float)
                if self.standardize_depth:
                    # Mask valid set to None here, as masking is applied after standardization
                    img = self.standardize_depth_map(img, mask_valid=None)
                if 'mask_valid' in task_dict:
                    mask_valid = (task_dict['mask_valid'] == 255).squeeze()
                    img[~mask_valid] = 0.0
                task_dict[task] = img.unsqueeze(0)
            elif task in ['rgb']:
                task_dict[task] = task_dict[task].to(torch.float)
            elif task in ['semseg']:
                img = task_dict[task].to(torch.long)
                img = self.seg_adapt_labels(img)
                task_dict[task] = img
            elif task in ['pseudo_semseg']:
                # If it's pseudo-semseg, then it's an input modality and should therefore be resized
                img = task_dict[task]
                img = F.interpolate(img[None, None, :, :], scale_factor=0.25, mode='nearest').long()[0, 0]
                task_dict[task] = img

        return task_dict
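

# Usage sketch (a minimal illustration; the PIL inputs are assumptions matching
# what MultiTaskImageFolder yields per sample):
#
#   transform = simple_transform(train=True,
#                                additional_targets={'semseg': 'mask', 'depth': 'mask'})
#   aug = DataAugmentationForSemSeg(transform=transform, seg_num_classes=150)
#   out = aug({'rgb': rgb_pil, 'semseg': semseg_pil, 'depth': depth_pil})
#   # out['rgb']: (C, H, W) float, out['depth']: (1, H, W) standardized float,
#   # out['semseg']: (H, W) long with adapted labels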


def build_semseg_dataset(args, data_path, transform, max_images=None):
    transform = DataAugmentationForSemSeg(transform=transform, seg_num_classes=args.num_classes,
                                          standardize_depth=args.standardize_depth,
                                          seg_reduce_zero_label=args.seg_reduce_zero_label,
                                          seg_use_void_label=args.seg_use_void_label)
    prefixes = {'depth': 'pseudo_'} if args.load_pseudo_depth else None
    return MultiTaskImageFolder(data_path, args.all_domains, transform=transform, prefixes=prefixes, max_images=max_images)
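

# Minimal sketch of the args namespace build_semseg_dataset reads (field names
# taken from the attribute accesses above; the values and path are illustrative):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(num_classes=150, standardize_depth=True,
#                          seg_reduce_zero_label=True, seg_use_void_label=False,
#                          load_pseudo_depth=False,
#                          all_domains=['rgb', 'semseg', 'depth'])
#   train_transform = simple_transform(train=True,
#                                      additional_targets={'semseg': 'mask', 'depth': 'mask'})
#   dataset = build_semseg_dataset(args, data_path='data/ade20k/train',
#                                  transform=train_transform)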


def ade_classes():
    """ADE20K class names for external use."""
    return [
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
        'clock', 'flag'
    ]


def hypersim_classes():
    """Hypersim class names for external use."""
    return [
        'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 
        'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 
        'curtain', 'dresser', 'pillow', 'mirror', 'floor-mat', 'clothes', 
        'ceiling', 'books', 'fridge', 'TV', 'paper', 'towel', 'shower-curtain', 
        'box', 'white-board', 'person', 'night-stand', 'toilet', 'sink', 'lamp',
        'bathtub', 'bag', 'other-struct', 'other-furntr', 'other-prop'
    ]


def nyu_v2_40_classes():
    """NYUv2 40 class names for external use."""
    return [
        'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 
        'window', 'bookshelf', 'picture', 'counter', 'blinds', 'desk', 'shelves', 
        'curtain', 'dresser', 'pillow', 'mirror', 'floor-mat', 'clothes', 
        'ceiling', 'books', 'fridge', 'TV', 'paper', 'towel', 'shower-curtain', 
        'box', 'white-board', 'person', 'night-stand', 'toilet', 'sink', 'lamp',
        'bathtub', 'bag', 'other-struct', 'other-furntr', 'other-prop'
    ]