File size: 5,193 Bytes
3b49518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) EPFL VILAB.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on BEiT, timm, DINO, DeiT and MAE-priv code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------

import numpy as np
import torch

try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
except:
    print('albumentations not installed')
# import cv2
import torch.nn.functional as F

from utils import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, NYU_MEAN,
                   NYU_STD, PAD_MASK_VALUE)
from utils.dataset_folder import ImageFolder, MultiTaskImageFolder


def nyu_transform(train, additional_targets, input_size=512, color_aug=False):
    """Build the albumentations pipeline used for NYU-style samples.

    Args:
        train: If truthy, build the training pipeline (random flip, optional
            color augmentation, random crop). Otherwise build the evaluation
            pipeline (center crop only).
        additional_targets: Forwarded unchanged to ``A.Compose`` as its
            ``additional_targets`` argument.
        input_size: Used both as ``max_size`` for ``A.SmallestMaxSize`` and as
            the height/width of the square crop.
        color_aug: Only consulted when ``train`` is truthy; adds
            ``A.ColorJitter`` and ``A.ToGray`` before the crop.

    Returns:
        The assembled ``A.Compose`` object. Both variants end with ImageNet
        mean/std normalization followed by ``ToTensorV2``.
    """
    if not train:
        # Evaluation: deterministic resize + center crop.
        eval_steps = [
            A.SmallestMaxSize(max_size=input_size, p=1),
            A.CenterCrop(height=input_size, width=input_size),
            A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ]
        return A.Compose(eval_steps, additional_targets=additional_targets)

    train_steps = [
        A.SmallestMaxSize(max_size=input_size, p=1),
        A.HorizontalFlip(p=0.5),
    ]

    if color_aug:
        # Color jittering from BYOL https://arxiv.org/pdf/2006.07733.pdf
        jitter = A.ColorJitter(
            brightness=0.1255,
            contrast=0.4,
            saturation=[0.5, 1.5],
            hue=[-0.2, 0.2],
            p=0.5,
        )
        train_steps.extend([jitter, A.ToGray(p=0.3)])

    train_steps.extend([
        A.RandomCrop(height=input_size, width=input_size, p=1),
        A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ToTensorV2(),
    ])

    return A.Compose(train_steps, additional_targets=additional_targets)


def simple_regression_transform(train, additional_targets, input_size=512, pad_value=(128, 128, 128), pad_mask_value=PAD_MASK_VALUE):
    """Build a resize + pad albumentations pipeline for regression tasks.

    Args:
        train: If truthy, build the training pipeline (flip, color jitter,
            random rescale, pad, random crop). Otherwise build the evaluation
            pipeline (resize + pad only).
        additional_targets: Forwarded unchanged to ``A.Compose`` as its
            ``additional_targets`` argument.
        input_size: ``max_size`` for ``A.LongestMaxSize``, the minimum
            height/width for ``A.PadIfNeeded``, and the side of the random
            crop in the training pipeline.
        pad_value: Passed to ``A.PadIfNeeded`` as ``value``.
        pad_mask_value: Passed to ``A.PadIfNeeded`` as ``mask_value``.

    Returns:
        The assembled ``A.Compose`` object. Both variants end with ImageNet
        mean/std normalization followed by ``ToTensorV2``.
    """
    # Numeric value of cv2.BORDER_CONSTANT (OpenCV BorderTypes enum). This
    # module does not import cv2 (the import is commented out), so spelling it
    # as `cv2.BORDER_CONSTANT` raised NameError on every call.
    border_constant = 0

    # Identical padding configuration for the train and eval pipelines.
    pad_kwargs = dict(
        min_height=input_size,
        min_width=input_size,
        position=A.augmentations.PadIfNeeded.PositionType.TOP_LEFT,
        border_mode=border_constant,
        value=pad_value,
        mask_value=pad_mask_value,
    )

    if train:
        transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.LongestMaxSize(max_size=input_size, p=1),
            A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.5),  # Color jittering from MoCo-v3 / DINO
            A.RandomScale(scale_limit=(0.1 - 1, 2.0 - 1), p=1),  # This is LSJ (0.1, 2.0)
            A.PadIfNeeded(**pad_kwargs),
            A.RandomCrop(height=input_size, width=input_size, p=1),
            A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ], additional_targets=additional_targets)

    else:
        transform = A.Compose([
            A.LongestMaxSize(max_size=input_size, p=1),
            A.PadIfNeeded(**pad_kwargs),
            A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ], additional_targets=additional_targets)

    return transform


class DataAugmentationForRegression(object):
    """Apply a joint transform to a multi-task sample and post-process it.

    The wrapped ``transform`` is called with keyword arguments, with the
    sample's ``'rgb'`` entry passed under the key ``'image'``. Its outputs are
    used as torch tensors (``.float()``, ``.unsqueeze`` and ``.to`` are called
    on them), so it is presumably a pipeline ending in ``ToTensorV2``.

    Args:
        transform: Callable invoked as ``transform(**arrays)`` that returns a
            dict containing at least ``'image'`` and ``'depth'``.
        mask_value: Value written into depth pixels that the validity mask
            marks as invalid.
    """

    def __init__(self, transform, mask_value=0.0):
        self.transform = transform
        self.mask_value = mask_value

    def __call__(self, task_dict):
        """Transform one sample.

        Args:
            task_dict: Mapping with at least ``'rgb'`` and ``'depth'``, and
                optionally ``'mask_valid'``. It is not modified.

        Returns:
            The dict returned by the transform, with ``'image'`` renamed back
            to ``'rgb'`` (cast to float), ``'depth'`` normalized with
            ``NYU_MEAN`` / ``NYU_STD`` and given a leading axis, and, when
            present, ``'mask_valid'`` turned into a boolean mask
            (``== 255``) with a leading axis.

        Raises:
            KeyError: If ``'rgb'`` is missing from ``task_dict`` or
                ``'depth'`` is missing from the transform output.
        """
        # Build a fresh dict of numpy arrays instead of popping 'rgb' from the
        # caller's dict, so the input sample is left untouched. 'rgb' is
        # passed under the key 'image'.
        arrays = {k: np.array(v) for k, v in task_dict.items() if k != 'rgb'}
        arrays['image'] = np.array(task_dict['rgb'])

        out = self.transform(**arrays)

        depth = (out['depth'].float() - NYU_MEAN) / NYU_STD

        # And then replace it back to rgb
        out['rgb'] = out.pop('image').to(torch.float)

        # The validity mask is treated as optional throughout; previously it
        # was indexed unconditionally and only afterwards checked for presence.
        if 'mask_valid' in out:
            out['mask_valid'] = (out['mask_valid'] == 255)[None]
            # Overwrite invalid pixels after normalization so they hold
            # exactly `mask_value`.
            depth[~out['mask_valid'].squeeze()] = self.mask_value

        out['depth'] = depth.unsqueeze(0)

        return out


def build_regression_dataset(args, data_path, transform, max_images=None):
    """Create a ``MultiTaskImageFolder`` whose samples go through ``transform``.

    Args:
        args: Object exposing ``all_domains``, which is passed as the second
            positional argument to ``MultiTaskImageFolder``.
        data_path: Passed as the first positional argument to
            ``MultiTaskImageFolder``.
        transform: Transform wrapped in ``DataAugmentationForRegression``
            (with its default ``mask_value``) before being handed to the
            dataset.
        max_images: Forwarded to ``MultiTaskImageFolder`` as ``max_images``.

    Returns:
        The constructed ``MultiTaskImageFolder``.
    """
    augmentation = DataAugmentationForRegression(transform=transform)
    dataset = MultiTaskImageFolder(
        data_path,
        args.all_domains,
        transform=augmentation,
        prefixes=None,
        max_images=max_images,
    )
    return dataset