Spaces:
Running
Running
# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/facebookresearch/moco-v3
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/BUPT-PRIV/MAE-priv
# https://github.com/facebookresearch/mae
# --------------------------------------------------------
from typing import Dict, List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange, repeat | |
from .multimae_utils import build_2d_sincos_posemb, pair, trunc_normal_ | |
class PatchedInputAdapter(nn.Module):
    """Adapter for spatial inputs, like images or feature maps.

    Creates tokens from patches over the image and adds positional embeddings.

    :param num_channels: Number of input channels of the image/feature map
    :param stride_level: Stride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size_full: Int or tuple of the patch size over the full image size.
        Patch size for smaller inputs will be computed accordingly.
    :param dim_tokens: Dimension of output tokens. Can be set using init method.
    :param sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
    :param learnable_pos_emb: Only consulted when sincos_pos_emb is True: controls whether
        the sin-cos embedding requires gradients. When sincos_pos_emb is False the
        embedding is always created as a trainable parameter.
    :param image_size: Default image size. Used to initialize size of positional embeddings.
    """

    def __init__(self,
                 num_channels: int,
                 stride_level: int,
                 patch_size_full: Union[int, Tuple[int, int]],
                 dim_tokens: Optional[int] = None,
                 sincos_pos_emb: bool = True,
                 learnable_pos_emb: bool = False,
                 image_size: Union[int, Tuple[int, int]] = 224):
        super().__init__()
        self.num_channels = num_channels
        self.stride_level = stride_level
        self.patch_size_full = pair(patch_size_full)
        self.dim_tokens = dim_tokens
        self.sincos_pos_emb = sincos_pos_emb
        self.learnable_pos_emb = learnable_pos_emb
        self.image_size = pair(image_size)

        # Use the normalized (height, width) pair here: dividing by the raw
        # `patch_size_full` argument fails with a TypeError when a tuple is passed.
        self.num_patches = (
            (self.image_size[0] // self.patch_size_full[0])
            * (self.image_size[1] // self.patch_size_full[1])
        )

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size_full[0] // stride_level)
        self.P_W = max(1, self.patch_size_full[1] // stride_level)

        if self.dim_tokens is not None:
            self.init(dim_tokens=dim_tokens)

    def init(self, dim_tokens: int = 768):
        """
        Initialize parts of encoder that are dependent on dimension of tokens.
        Should be called when setting up MultiMAE. Safe to call more than once;
        each call re-creates the positional embedding and the projection.

        :param dim_tokens: Dimension of tokens
        """
        self.dim_tokens = dim_tokens

        # Fixed-size positional embeddings. Can be interpolated to different input sizes
        h_posemb = self.image_size[0] // (self.stride_level * self.P_H)
        w_posemb = self.image_size[1] // (self.stride_level * self.P_W)
        if self.sincos_pos_emb:
            # Build into a local first: once `pos_emb` is a registered parameter,
            # assigning a plain tensor to `self.pos_emb` raises a TypeError, which
            # would break a second call to init().
            pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
            self.pos_emb = nn.Parameter(pos_emb, requires_grad=self.learnable_pos_emb)
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, self.dim_tokens, h_posemb, w_posemb))
            trunc_normal_(self.pos_emb, std=0.02)

        # Image -> tokens projection: one non-overlapping patch per output location
        self.proj = nn.Conv2d(
            in_channels=self.num_channels, out_channels=self.dim_tokens,
            kernel_size=(self.P_H, self.P_W), stride=(self.P_H, self.P_W)
        )

    def no_weight_decay(self):
        """Return the attribute names this adapter asks to exclude from weight decay.

        NOTE(review): presumably consumed by the optimizer setup - verify against caller.
        """
        return {'pos_emb'}

    def forward(self, x):
        """
        Forward pass through input adapter, transforming image to sequence of tokens.
        Adds positional encodings.

        :param x: Input image tensor of shape [B, C, H, W]
        :return: Token tensor of shape [B, N_H * N_W, dim_tokens]
        """
        _, _, H, W = x.shape
        assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
        assert (H % self.P_H == 0) and (W % self.P_W == 0), f'Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}'
        N_H, N_W = H // self.P_H, W // self.P_W  # Number of patches in height and width

        # Create patches [B, C, H, W] -> [B, (N_H*N_W), D]
        x_patch = rearrange(self.proj(x), 'b d nh nw -> b (nh nw) d')

        # Resize the stored positional embedding to the current patch grid
        x_pos_emb = F.interpolate(self.pos_emb, size=(N_H, N_W), mode='bicubic', align_corners=False)
        x_pos_emb = rearrange(x_pos_emb, 'b d nh nw -> b (nh nw) d')

        # Add patches and positional embeddings
        x = x_patch + x_pos_emb

        return x
class SemSegInputAdapter(nn.Module):
    """
    Adapter for semantic segmentation maps given as per-pixel class indices.

    Looks up a learned embedding for every pixel's class, creates tokens from
    patches over the resulting embedding map and adds positional embeddings.

    :param num_classes: Number of input semantic classes
    :param stride_level: Stride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size_full: Int or tuple of the patch size over the full image size.
        Patch size for smaller inputs will be computed accordingly.
    :param dim_tokens: Dimension of output tokens. Can be set using init method.
    :param sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
    :param learnable_pos_emb: Only consulted when sincos_pos_emb is True: controls whether
        the sin-cos embedding requires gradients. When sincos_pos_emb is False the
        embedding is always created as a trainable parameter.
    :param image_size: Default image size. Used to initialize size of positional embeddings.
    :param dim_class_emb: Dimension of learned class embedding
    :param interpolate_class_emb: Set to True to bilinearly downsample the per-pixel class
        embeddings by the patch size and project them with a 1x1 convolution, instead of
        using a patch-sized strided convolution.
    :param emb_padding_idx: Padding index (e.g. image border), default is None.
        When given, one extra embedding row is allocated and the index is passed
        to the embedding table as its padding index.
    """

    def __init__(self,
                 num_classes: int,
                 stride_level: int,
                 patch_size_full: Union[int, Tuple[int, int]],
                 dim_tokens: Optional[int] = None,
                 sincos_pos_emb: bool = True,
                 learnable_pos_emb: bool = False,
                 image_size: Union[int, Tuple[int, int]] = 224,
                 dim_class_emb: int = 64,
                 interpolate_class_emb: bool = False,
                 emb_padding_idx: Optional[int] = None
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.stride_level = stride_level
        self.patch_size_full = pair(patch_size_full)
        self.dim_tokens = dim_tokens
        self.sincos_pos_emb = sincos_pos_emb
        self.learnable_pos_emb = learnable_pos_emb
        self.image_size = pair(image_size)
        self.dim_class_emb = dim_class_emb
        self.interpolate_class_emb = interpolate_class_emb
        self.emb_padding_idx = emb_padding_idx
        if self.emb_padding_idx is not None:
            # Reserve one additional embedding row for the padding class
            self.num_classes += 1

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size_full[0] // stride_level)
        self.P_W = max(1, self.patch_size_full[1] // stride_level)

        if self.dim_tokens is not None:
            self.init(dim_tokens=dim_tokens)

    def init(self, dim_tokens: int = 768):
        """
        Initialize parts of encoder that are dependent on dimension of tokens.
        Should be called when setting up MultiMAE. Safe to call more than once;
        each call re-creates the positional embedding, class embedding and projection.

        :param dim_tokens: Dimension of tokens
        """
        self.dim_tokens = dim_tokens

        # Fixed-size positional embeddings. Can be interpolated to different input sizes
        h_posemb = self.image_size[0] // (self.stride_level * self.P_H)
        w_posemb = self.image_size[1] // (self.stride_level * self.P_W)
        if self.sincos_pos_emb:
            # Build into a local first: once `pos_emb` is a registered parameter,
            # assigning a plain tensor to `self.pos_emb` raises a TypeError, which
            # would break a second call to init().
            pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
            self.pos_emb = nn.Parameter(pos_emb, requires_grad=self.learnable_pos_emb)
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, self.dim_tokens, h_posemb, w_posemb))
            trunc_normal_(self.pos_emb, std=0.02)

        # Class index -> learned class embedding
        self.class_emb = nn.Embedding(
            num_embeddings=self.num_classes,
            embedding_dim=self.dim_class_emb,
            padding_idx=self.emb_padding_idx,
        )
        trunc_normal_(self.class_emb.weight, std=0.02)

        # Embedding map -> tokens projection
        if self.interpolate_class_emb:
            self.proj = nn.Sequential(
                nn.Upsample(scale_factor=(1 / self.P_H, 1 / self.P_W),
                            mode='bilinear'),  # Actually a downsample operation
                nn.Conv2d(in_channels=self.dim_class_emb, out_channels=self.dim_tokens,
                          kernel_size=1, stride=1),
            )
        else:
            self.proj = nn.Conv2d(
                in_channels=self.dim_class_emb, out_channels=self.dim_tokens,
                kernel_size=(self.P_H, self.P_W), stride=(self.P_H, self.P_W)
            )

    def no_weight_decay(self):
        """Return the attribute names this adapter asks to exclude from weight decay.

        NOTE(review): 'class_emb' names a submodule whose parameter is
        'class_emb.weight' - confirm the consumer matches by prefix rather than
        by exact parameter name.
        """
        return {'pos_emb', 'class_emb'}

    def forward(self, x):
        """
        Forward pass through input adapter, transforming a class-index map to a
        sequence of tokens. Adds positional encodings.

        :param x: Input tensor of shape [B, H, W] holding class indices
        :return: Token tensor of shape [B, N_H * N_W, dim_tokens]
        """
        _, H, W = x.shape
        assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
        assert (H % self.P_H == 0) and (
                W % self.P_W == 0), f'Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}'
        N_H, N_W = H // self.P_H, W // self.P_W  # Number of patches in height and width

        # Map to embedding and move the embedding dimension to the channel axis
        x = rearrange(self.class_emb(x), 'b nh nw c -> b c nh nw')

        # Create patches [B, C, H, W] -> [B, (N_H*N_W), D]
        x_patch = rearrange(self.proj(x), 'b d nh nw -> b (nh nw) d')

        # Resize the stored positional embedding to the current patch grid.
        # align_corners=False is the documented default for bilinear mode; it is
        # spelled out here for clarity only.
        x_pos_emb = F.interpolate(self.pos_emb, size=(N_H, N_W), mode='bilinear', align_corners=False)
        x_pos_emb = rearrange(x_pos_emb, 'b d nh nw -> b (nh nw) d')

        # Add patches and positional embeddings
        x = x_patch + x_pos_emb

        return x