3DOI / monoarti /transformer.py
shengyi-qian's picture
init
9afcee2
raw
history blame
No virus
18.6 kB
from typing import List, Optional, Tuple
import torch
from torch import nn
import torch.nn.functional as F
from . import axis_ops, ilnr_loss
from .vnl_loss import VNL_Loss
from .midas_loss import MidasLoss
from .detr.detr import MLP
from .detr.transformer import Transformer
from .detr.backbone import Backbone, Joiner
from .detr.position_encoding import PositionEmbeddingSine
from .detr.misc import nested_tensor_from_tensor_list, interpolate
from .detr import box_ops
from .detr.segmentation import (
MHAttentionMap, MaskHeadSmallConv, dice_loss, sigmoid_focal_loss
)
class INTR(torch.nn.Module):
    """
    Implement Interaction 3D Transformer.

    DETR-style model. A CNN backbone and a transformer encoder process the
    image. The decoder runs one query per input 2D keypoint, plus, when
    ``depth_on`` is set, one extra learned depth query appended last.

    Each keypoint query is decoded into:
      - a box,
      - movable / rigid / kinematic / action logits,
      - an axis,
      - a segmentation mask,
      - an affordance map.

    The depth query is decoded into a one-channel depth map. ``forward`` also
    computes the training losses unless it is called with ``backward=False``.
    """

    def __init__(
        self,
        backbone_name='resnet50',
        image_size=None,
        ignore_index=-100,
        num_classes=1,
        num_queries=15,
        freeze_backbone=False,
        transformer_hidden_dim=256,
        transformer_dropout=0.1,
        transformer_nhead=8,
        transformer_dim_feedforward=2048,
        transformer_num_encoder_layers=6,
        transformer_num_decoder_layers=6,
        transformer_normalize_before=False,
        transformer_return_intermediate_dec=True,
        layers_movable=3,
        layers_rigid=3,
        layers_kinematic=3,
        layers_action=3,
        layers_axis=2,
        layers_affordance=3,
        affordance_focal_alpha=0.95,
        axis_bins=30,
        depth_on=True,
    ):
        """
        Build the backbone, transformer, prediction heads and (optionally) depth losses.

        Parameters:
            backbone_name: name passed to ``Backbone`` (e.g. 'resnet50').
            image_size: input image size; defaults to ``[192, 256]``. Index 0
                normalises keypoint y and index 1 normalises keypoint x, so it
                is presumably (height, width).
            ignore_index: target value skipped by the movable / rigid /
                kinematic / action cross-entropy losses.
            num_classes: currently unused.
            num_queries: stored as ``self.num_queries``; at runtime the query
                count is taken from ``keypoints.shape[1]``.
            freeze_backbone: passed negated as ``Backbone``'s second argument
                (the train flag in DETR's Backbone).
            transformer_*: forwarded to ``Transformer``.
            layers_movable, layers_rigid, layers_kinematic, layers_action,
            layers_axis, layers_affordance: depth of each prediction head.
                1 -> ``nn.Linear``; >1 -> ``MLP``.
            affordance_focal_alpha: ``alpha`` of the affordance focal loss.
            axis_bins: stored as ``self._axis_bins``; not otherwise used here.
            depth_on: add the depth query, depth head and depth losses.

        Raises:
            ValueError: if any ``layers_*`` head depth is smaller than 1.
        """
        super().__init__()
        if image_size is None:
            # Fresh list per instance, so no mutable default is shared between models.
            image_size = [192, 256]
        self._ignore_index = ignore_index
        self._image_size = image_size
        self._axis_bins = axis_bins
        self._affordance_focal_alpha = affordance_focal_alpha

        # backbone: CNN features + sine positional encoding
        backbone_base = Backbone(backbone_name, not freeze_backbone, True, False)
        N_steps = transformer_hidden_dim // 2
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
        backbone = Joiner(backbone_base, position_embedding)
        backbone.num_channels = backbone_base.num_channels
        self.backbone = backbone

        self.transformer = Transformer(
            d_model=transformer_hidden_dim,
            dropout=transformer_dropout,
            nhead=transformer_nhead,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_num_encoder_layers,
            num_decoder_layers=transformer_num_decoder_layers,
            normalize_before=transformer_normalize_before,
            return_intermediate_dec=transformer_return_intermediate_dec,
        )
        hidden_dim = self.transformer.d_model
        self.hidden_dim = hidden_dim
        nheads = self.transformer.nhead
        self.num_queries = num_queries

        # NOTE: submodules are registered in a fixed order; keep it, since it
        # determines parameter order for optimizer-state checkpoints.

        # input_proj maps the backbone's output channels to the transformer
        # width (hidden_dim) with a 1x1 conv.
        self.input_proj = nn.Conv2d(self.backbone.num_channels, hidden_dim, kernel_size=1)
        # query_mlp maps normalised 2D keypoint coordinates to hidden_dim-dim queries.
        self.query_mlp = MLP(2, hidden_dim, hidden_dim, 2)
        # bbox MLP (4 outputs, sigmoid-ed in forward)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.movable_embed = self._make_head(hidden_dim, 3, layers_movable)
        if layers_rigid == 1:
            # NOTE(review): the single-layer rigid head predicts 3 classes while
            # the MLP variant below predicts 2. Kept as-is for checkpoint
            # compatibility; confirm which is intended.
            self.rigid_embed = nn.Linear(hidden_dim, 3)
        else:
            self.rigid_embed = self._make_head(hidden_dim, 2, layers_rigid)
        self.kinematic_embed = self._make_head(hidden_dim, 3, layers_kinematic)
        self.action_embed = self._make_head(hidden_dim, 3, layers_action)
        # axis regression head (3 outputs, mapped to [-1, 1] in forward)
        self.axis_embed = self._make_head(hidden_dim, 3, layers_axis)
        # affordance embedding; note that forward() does not use it
        self.aff_embed = self._make_head(hidden_dim, 2, layers_affordance)

        # affordance head
        # NOTE(review): fpn dims [1024, 512, 256] match resnet50/101 intermediate
        # channels; other backbones may need different values.
        self.aff_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
        self.aff_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim, nheads)
        # mask head
        self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
        self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim, nheads)

        # depth head
        self._depth_on = depth_on
        if self._depth_on:
            self.depth_query = nn.Embedding(1, hidden_dim)
            self.depth_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
            self.depth_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim, nheads)
            # constructed but not used by forward()
            self.depth_loss = ilnr_loss.MEADSTD_TANH_NORM_Loss()
            # Pinhole focal length f = (image_size[1] / 2) / tan(fov / 2) for a
            # fixed nominal fov of 1.0 rad.
            fov = torch.tensor(1.0)
            focal_length = (image_size[1] / 2 / torch.tan(fov / 2)).item()
            self.vnl_loss = VNL_Loss(focal_length, focal_length, image_size)
            self.midas_loss = MidasLoss(alpha=0.1)

    @staticmethod
    def _make_head(in_dim, out_dim, num_layers):
        """
        Return a prediction head mapping ``in_dim`` features to ``out_dim`` outputs.

        ``num_layers == 1`` gives an ``nn.Linear``. ``num_layers > 1`` gives an
        ``MLP`` with hidden width ``in_dim``.

        Raises:
            ValueError: if ``num_layers < 1``.
        """
        if num_layers > 1:
            return MLP(in_dim, in_dim, out_dim, num_layers)
        if num_layers == 1:
            return nn.Linear(in_dim, out_dim)
        raise ValueError(f"not supported: num_layers={num_layers}")

    @staticmethod
    def _zero_loss(device):
        """Return a scalar zero loss on ``device`` that requires grad (placeholder for absent supervision)."""
        return torch.zeros((), device=device, requires_grad=True)

    def _nan_safe_cross_entropy(self, logits, target):
        """
        Compute cross-entropy over per-query logits, skipping ``self._ignore_index`` targets.

        ``logits`` has the class dimension last; forward passes
        (bs, num_queries, num_classes) logits and (bs, num_queries) targets.
        When every target is ignored, the mean reduction of ``F.cross_entropy``
        is NaN (0 / 0). A zero loss is returned instead.
        """
        loss = F.cross_entropy(logits.permute(0, 2, 1), target, ignore_index=self._ignore_index)
        if torch.isnan(loss):
            return self._zero_loss(logits.device)
        return loss

    def freeze_layers(self, names):
        """
        Freeze layers in ``names``.

        Any parameter whose qualified name (from ``named_parameters()``)
        contains at least one entry of ``names`` as a substring gets
        ``requires_grad = False``.
        """
        for name, param in self.named_parameters():
            if any(freeze_name in name for freeze_name in names):
                param.requires_grad = False

    def forward(
        self,
        image: torch.Tensor,
        valid: torch.Tensor,
        keypoints: torch.Tensor,
        bbox: torch.Tensor,
        masks: torch.Tensor,
        movable: torch.Tensor,
        rigid: torch.Tensor,
        kinematic: torch.Tensor,
        action: torch.Tensor,
        affordance: torch.Tensor,
        affordance_map: torch.FloatTensor,
        depth: torch.Tensor,
        axis: torch.Tensor,
        fov: torch.Tensor,
        backward: bool = True,
        **kwargs,
    ) -> dict:
        """
        Model forward. Set backward = False if the model is inference only.

        With ``backward=False`` only ``image`` and ``keypoints`` are read.
        ``valid``, ``fov`` and ``kwargs`` are never read.

        Parameters:
            image: image batch (tensor or list), wrapped into a NestedTensor.
            keypoints: (bs, num_queries, 2) pixel coordinates (x, y); one
                decoder query per keypoint.
            bbox: (bs, num_queries, 4) xyxy target boxes. A row is valid when
                ``bbox[..., 0] > -0.5``. Presumably normalised to [0, 1], since
                they are compared with sigmoid outputs.
            masks: (bs, num_queries, H, W) target masks. A query is supervised
                when its mask sums to more than 10.
            movable, rigid, kinematic, action: (bs, num_queries) class
                targets; ``ignore_index`` entries are skipped.
            affordance: only ``affordance[:, :, 0] > -0.5`` is read, as a
                validity flag.
            affordance_map: per-query affordance targets. They must already
                match the affordance head's output resolution; no resizing is
                done here.
            depth: (bs, H, W) target depth. A sample is supervised when
                ``depth[:, 0, 0] > 0``.
            axis: (bs, num_queries, 4) target axis line (xyxy). Valid when
                ``axis[..., 0] > 0``.
            backward: when False, return predictions only.

        Returns:
            dict with ``pred_*`` entries. When ``backward`` is True it also has
            ``loss_*`` entries (scalar tensors).
        """
        device = image.device
        # The number of queries comes from the input and can differ from
        # self.num_queries at runtime.
        num_queries = keypoints.shape[1]

        # ---- DETR backbone ----
        samples = image
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)
        bs = features[-1].tensors.shape[0]
        src, mask = features[-1].decompose()
        assert mask is not None

        # ---- keypoint queries ----
        # Map pixel coordinates to [-1, 1]: x (index 0) is divided by
        # image_size[1], y (index 1) by image_size[0]. clone() keeps the
        # in-place writes from touching the caller's tensor.
        anchors_float = keypoints.float().clone()
        anchors_float[:, :, 0] = ((anchors_float[:, :, 0] / self._image_size[1]) - 0.5) * 2
        anchors_float[:, :, 1] = ((anchors_float[:, :, 1] / self._image_size[0]) - 0.5) * 2
        use_sine = False  # hard-coded switch; the learned-MLP path is the active one
        if use_sine:
            # Sample the backbone's sine positional encoding at each keypoint.
            # repeat_interleave orders rows as (image, query), matching the
            # row order of the flattened anchors.
            grid = anchors_float.reshape(-1, 1, 1, 2)
            keypoint_queries = F.grid_sample(
                pos[-1].repeat_interleave(num_queries, dim=0),
                grid,
                mode='nearest',
                align_corners=True,
            )
            keypoint_queries = keypoint_queries.reshape(-1, num_queries, self.hidden_dim)
        else:
            # Learned MLP maps normalised coordinates to query embeddings.
            keypoint_queries = self.query_mlp(anchors_float)

        # Append the learned depth query (always the last query) when learning depth.
        if self._depth_on:
            depth_query = self.depth_query.weight.unsqueeze(0).repeat(keypoint_queries.shape[0], 1, 1)
            keypoint_queries = torch.cat((keypoint_queries, depth_query), dim=1)

        # ---- transformer ----
        src_proj = self.input_proj(src)
        hs, memory = self.transformer(src_proj, mask, keypoint_queries, pos[-1])
        # Only the last decoder layer's output is used.
        if self._depth_on:
            depth_hs = hs[-1][:, -1:]
            ord_hs = hs[-1][:, :-1]
        else:
            ord_hs = hs[-1]

        # ---- per-query heads ----
        outputs_coord = self.bbox_embed(ord_hs).sigmoid()
        outputs_movable = self.movable_embed(ord_hs)
        outputs_rigid = self.rigid_embed(ord_hs)
        outputs_kinematic = self.kinematic_embed(ord_hs)
        outputs_action = self.action_embed(ord_hs)
        # sigmoid range is 0 to 1; rescale the axis to -1 to 1
        outputs_axis = (self.axis_embed(ord_hs).sigmoid() - 0.5) * 2

        # Backbone feature maps in the order the mask-style heads expect (features[2], [1], [0]).
        fpn_features = [features[2].tensors, features[1].tensors, features[0].tensors]

        # affordance forward
        bbox_aff = self.aff_attention(ord_hs, memory, mask=mask)
        aff_masks = self.aff_head(src_proj, bbox_aff, fpn_features)
        outputs_aff_masks = aff_masks.view(bs, num_queries, aff_masks.shape[-2], aff_masks.shape[-1])

        # mask forward
        bbox_mask = self.bbox_attention(ord_hs, memory, mask=mask)
        seg_masks = self.mask_head(src_proj, bbox_mask, fpn_features)
        outputs_seg_masks = seg_masks.view(bs, num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        # depth forward
        outputs_depth = None
        if self._depth_on:
            depth_att = self.depth_attention(depth_hs, memory, mask=mask)
            depth_masks = self.depth_head(src_proj, depth_att, fpn_features)
            outputs_depth = depth_masks.view(bs, 1, depth_masks.shape[-2], depth_masks.shape[-1])

        out = {
            'pred_boxes': box_ops.box_cxcywh_to_xyxy(outputs_coord),
            'pred_movable': outputs_movable,
            'pred_rigid': outputs_rigid,
            'pred_kinematic': outputs_kinematic,
            'pred_action': outputs_action,
            'pred_masks': outputs_seg_masks,
            'pred_axis': outputs_axis,
            'pred_depth': outputs_depth,
            'pred_affordance': outputs_aff_masks,
        }
        if not backward:
            return out

        # ---- box losses (L1 + GIoU) over valid rows only ----
        target_boxes = box_ops.box_xyxy_to_cxcywh(bbox)
        bbox_valid = bbox[:, :, 0] > -0.5
        num_boxes = bbox_valid.sum()
        if num_boxes == 0:
            out['loss_bbox'] = self._zero_loss(device)
            out['loss_giou'] = self._zero_loss(device)
        else:
            # Filter before computing: no query-count assumption, no
            # (bs*nq)^2 pairwise IoU matrix, and invalid rows can't leak NaN.
            src_valid = outputs_coord[bbox_valid]
            tgt_valid = target_boxes[bbox_valid]
            out['loss_bbox'] = F.l1_loss(src_valid, tgt_valid, reduction='sum') / num_boxes
            loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_valid),
                box_ops.box_cxcywh_to_xyxy(tgt_valid),
            ))
            out['loss_giou'] = loss_giou.sum() / num_boxes

        # ---- affordance ----
        affordance_valid = affordance[:, :, 0] > -0.5
        if not affordance_valid.any():
            out['loss_affordance'] = self._zero_loss(device)
        else:
            src_aff_masks = outputs_aff_masks[affordance_valid].flatten(1)
            tgt_aff_masks = affordance_map[affordance_valid].flatten(1)
            out['loss_affordance'] = sigmoid_focal_loss(
                src_aff_masks,
                tgt_aff_masks,
                affordance_valid.sum(),
                alpha=self._affordance_focal_alpha,
            )

        # ---- axis ----
        axis_valid = axis[:, :, 0] > 0.0
        num_axis = axis_valid.sum()
        if num_axis == 0:
            out['loss_axis_angle'] = self._zero_loss(device)
            out['loss_axis_offset'] = self._zero_loss(device)
            out['loss_eascore'] = self._zero_loss(device)
        else:
            # Regress the angle: L2-normalise the first two components
            # ("angle" part) and keep the rest ("offset" part).
            src_axis_angle = outputs_axis[axis_valid]
            src_axis_angle_norm = F.normalize(src_axis_angle[:, :2])
            src_axis_angle = torch.cat((src_axis_angle_norm, src_axis_angle[:, 2:]), dim=-1)
            target_axis_xyxy = axis[axis_valid]
            # Reference point for the parameterisation: the target box centre,
            # duplicated into (cx, cy, cx, cy).
            axis_center = target_boxes[axis_valid].clone()
            axis_center[:, 2:] = axis_center[:, :2]
            target_axis_angle = axis_ops.line_xyxy_to_angle(target_axis_xyxy, center=axis_center)
            out['loss_axis_angle'] = F.l1_loss(
                src_axis_angle[:, :2], target_axis_angle[:, :2], reduction='sum') / num_axis
            out['loss_axis_offset'] = F.l1_loss(
                src_axis_angle[:, 2:], target_axis_angle[:, 2:], reduction='sum') / num_axis
            src_axis_xyxy = axis_ops.line_angle_to_xyxy(src_axis_angle, center=axis_center)
            target_axis_xyxy = axis_ops.line_angle_to_xyxy(target_axis_angle, center=axis_center)
            axis_eascore, _, _ = axis_ops.ea_score(src_axis_xyxy, target_axis_xyxy)
            out['loss_eascore'] = (1 - axis_eascore).mean()

        # ---- per-query classification ----
        out['loss_movable'] = self._nan_safe_cross_entropy(outputs_movable, movable)
        out['loss_rigid'] = self._nan_safe_cross_entropy(outputs_rigid, rigid)
        out['loss_kinematic'] = self._nan_safe_cross_entropy(outputs_kinematic, kinematic)
        out['loss_action'] = self._nan_safe_cross_entropy(outputs_action, action)

        # ---- depth ----
        if self._depth_on:
            # (bs, 1, H, W) at the target resolution, clamped to [0, 1]
            src_depths = interpolate(outputs_depth, size=depth.shape[-2:], mode='bilinear', align_corners=False)
            src_depths = src_depths.clamp(min=0.0, max=1.0)
            tgt_depths = depth.unsqueeze(1)  # (bs, 1, H, W)
            # A sample is supervised only if its top-left depth pixel is positive.
            valid_depth = depth[:, 0, 0] > 0
            if valid_depth.any():
                src_depths = src_depths[valid_depth]
                tgt_depths = tgt_depths[valid_depth]
                depth_mask = tgt_depths > 1e-8
                midas_loss, _ssi_loss, _reg_loss = self.midas_loss(src_depths, tgt_depths, depth_mask)
                out['loss_depth'] = midas_loss
                out['loss_vnl'] = self.vnl_loss(tgt_depths, src_depths)
            else:
                out['loss_depth'] = self._zero_loss(device)
                out['loss_vnl'] = self._zero_loss(device)
        else:
            out['loss_depth'] = self._zero_loss(device)
            out['loss_vnl'] = self._zero_loss(device)

        # ---- segmentation masks ----
        tgt_masks = masks
        src_masks = interpolate(outputs_seg_masks, size=tgt_masks.shape[-2:], mode='bilinear', align_corners=False)
        # Only queries whose target mask sums to more than 10 are supervised.
        valid_mask = tgt_masks.sum(dim=-1).sum(dim=-1) > 10
        num_masks = valid_mask.sum()
        if num_masks == 0:
            out['loss_mask'] = self._zero_loss(device)
            out['loss_dice'] = self._zero_loss(device)
        else:
            src_masks = src_masks[valid_mask].flatten(1)
            tgt_masks = tgt_masks[valid_mask].flatten(1).float()
            out['loss_mask'] = sigmoid_focal_loss(src_masks, tgt_masks, num_masks)
            out['loss_dice'] = dice_loss(src_masks, tgt_masks, num_masks)
        return out