# 3DOI / monoarti / sam_transformer.py
# Author: shengyi-qian -- commit 9afcee2 ("init"), 17.3 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from typing import Any, Dict, List, Tuple
from .sam.image_encoder import ImageEncoderViT
from .sam.mask_decoder import MaskDecoder
from .sam.prompt_encoder import PromptEncoder
from .detr import box_ops
from .detr.segmentation import dice_loss, sigmoid_focal_loss
from .detr.misc import nested_tensor_from_tensor_list, interpolate
from . import axis_ops, ilnr_loss #, pwnp_loss
from .vnl_loss import VNL_Loss
from .midas_loss import MidasLoss
class SamTransformer(nn.Module):
    """SAM-based model with segmentation, articulation, affordance and depth heads.

    A shared SAM image encoder feeds three SAM-style decoders:

    * ``mask_decoder`` -- one mask per keypoint prompt, plus the movable /
      rigid / kinematic / action / axis predictions it also returns;
    * ``affordance_decoder`` -- one affordance map per keypoint prompt;
    * ``depth_decoder`` -- one depth map per image, prompted by two learned
      query embeddings (``depth_query``) instead of keypoints.
    """

    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        affordance_decoder: MaskDecoder,
        depth_decoder: MaskDecoder,
        transformer_hidden_dim: int,
        backbone_name: str,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        Arguments:
          image_encoder (ImageEncoderViT): Backbone that turns the padded image
            into the image embeddings shared by all decoders.
          prompt_encoder (PromptEncoder): Encodes keypoint prompts and provides
            the dense positional encoding passed to every decoder.
          mask_decoder (MaskDecoder): Predicts masks and articulation outputs
            from the image embeddings and encoded prompts.
          affordance_decoder (MaskDecoder): Predicts affordance maps from the
            same prompts.
          depth_decoder (MaskDecoder): Predicts a depth map from the learned
            depth queries.
          transformer_hidden_dim (int): Embedding size of the depth queries.
          backbone_name (str): Currently unused (checkpoint loading below is
            disabled).
          pixel_mean (list(float)): Per-channel mean used by ``preprocess``.
          pixel_std (list(float)): Per-channel std used by ``preprocess``.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.affordance_decoder = affordance_decoder

        # Depth head: its own decoder, prompted by two learned query tokens
        # rather than by the keypoint prompts.
        self.depth_decoder = depth_decoder
        self.depth_query = nn.Embedding(2, transformer_hidden_dim)

        # Depth losses. The pinhole focal length comes from a fixed 1-radian
        # field of view spread over the width of a 768x1024 (H, W) image,
        # i.e. f = (W / 2) / tan(fov / 2).
        fov = torch.tensor(1.0)
        image_size = (768, 1024)
        focal_length = (image_size[1] / 2 / torch.tan(fov / 2)).item()
        self.vnl_loss = VNL_Loss(focal_length, focal_length, image_size)
        self.midas_loss = MidasLoss(alpha=0.1)

        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

        # NOTE: loading the pretrained SAM checkpoint selected by
        # ``backbone_name`` is disabled; kept for reference.
        # if backbone_name == 'vit_h':
        #     checkpoint_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'checkpoints', 'sam_vit_h_4b8939.pth')
        # elif backbone_name == 'vit_l':
        #     checkpoint_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'checkpoints', 'sam_vit_l_0b3195.pth')
        # elif backbone_name == 'vit_b':
        #     checkpoint_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'checkpoints', 'sam_vit_b_01ec64.pth')
        # else:
        #     raise ValueError
        # with open(checkpoint_path, "rb") as f:
        #     state_dict = torch.load(f)
        # self.load_state_dict(state_dict, strict=False)
        # self.affordance_decoder.load_state_dict(self.mask_decoder.state_dict(), strict=False)
        # self.depth_decoder.load_state_dict(self.mask_decoder.state_dict(), strict=False)

        self.num_queries = 15
        self._affordance_focal_alpha = 0.95
        # Classification targets equal to this value are ignored by cross_entropy.
        self._ignore_index = -100

    @property
    def device(self) -> Any:
        """Device the model's buffers live on."""
        return self.pixel_mean.device

    def freeze_layers(self, names):
        """
        Freeze layers in 'names'.

        Every parameter whose qualified name contains any of the substrings in
        ``names`` has ``requires_grad`` set to False.
        """
        for name, param in self.named_parameters():
            if any(freeze_name in name for freeze_name in names):
                param.requires_grad = False

    @staticmethod
    def _zero_loss(device) -> torch.Tensor:
        """Return a scalar zero loss on ``device`` that takes part in autograd."""
        return torch.tensor(0.0, device=device, requires_grad=True)

    @staticmethod
    def _masks_to_normalized_boxes(mask_logits: torch.Tensor) -> torch.Tensor:
        """Convert (N, 768, 1024) mask logits into rescaled xyxy boxes.

        A pixel is foreground when its logit is > 0. ``masks_to_boxes`` cannot
        handle all-zero masks, so those are filled entirely, which yields a
        full-image box.
        """
        binary = (mask_logits > 0.0).long()
        is_empty = binary.sum(dim=(-2, -1)) == 0
        binary[is_empty] += 1
        boxes = torchvision.ops.masks_to_boxes(binary)
        # NOTE(review): whether this should be [1 / H, 1 / W] or [1 / W, 1 / H]
        # depends on box_ops.rescale_bboxes' size convention; masks here are
        # 768x1024 (H, W). Confirm against box_ops.
        return box_ops.rescale_bboxes(boxes, [1 / 768, 1 / 1024])

    def forward(
        self,
        image: torch.Tensor,
        valid: torch.Tensor,
        keypoints: torch.Tensor,
        bbox: torch.Tensor,
        masks: torch.Tensor,
        movable: torch.Tensor,
        rigid: torch.Tensor,
        kinematic: torch.Tensor,
        action: torch.Tensor,
        affordance: torch.Tensor,
        affordance_map: torch.FloatTensor,
        depth: torch.Tensor,
        axis: torch.Tensor,
        fov: torch.Tensor,
        backward: bool = True,
        **kwargs,
    ):
        """Predict every head for a batch and optionally compute training losses.

        Arguments:
          image: Batched images. They are zero-padded on the bottom/right to
            ``image_encoder.img_size`` but NOT normalized here (``preprocess``
            is not called). NOTE(review): presumably the data pipeline
            normalizes; confirm.
          keypoints: Point prompts. ``keypoints[i]`` holds the points for
            image ``i``, and each point is decoded as its own single-point
            prompt with label 1. Presumably (bs, num_points, 2); TODO confirm.
          bbox, masks, movable, rigid, kinematic, action, affordance,
            affordance_map, depth, axis: Training targets, read only when
            ``backward`` is True. Target entries are skipped as follows:
            ``affordance[..., 0] <= -0.5``; ``axis[..., 0] <= 0``; class
            targets equal to -100; images whose ``depth`` at pixel (0, 0) is
            <= 0; masks covering <= 10 pixels.
          valid, fov, kwargs: Accepted for interface compatibility; unused.
          backward: If False, return the predictions only.

        Returns:
          Dict of ``pred_*`` tensors stacked over the batch:
          - masks and depth are at 768x1024;
          - affordance is at 192x256;
          - boxes are derived from the predicted masks.
          When ``backward`` is True, ``loss_*`` scalars are added as well.
          ``loss_bbox`` and ``loss_giou`` are always zero.
        """
        device = image.device
        multimask_output = False

        # Pad bottom/right so the encoder sees a square img_size x img_size input.
        h, w = image.shape[-2:]
        pad_h = self.image_encoder.img_size - h
        pad_w = self.image_encoder.img_size - w
        image_embeddings = self.image_encoder(F.pad(image, (0, pad_w, 0, pad_h)))
        input_size = image.shape[-2:]

        # Loop-invariant decoder inputs.
        image_pe = self.prompt_encoder.get_dense_pe()
        depth_sparse_embeddings = self.depth_query.weight.unsqueeze(0)

        # Per-image predictions, keyed (and ordered) as in the returned dict.
        preds = {
            key: []
            for key in (
                'pred_boxes', 'pred_movable', 'pred_rigid', 'pred_kinematic',
                'pred_action', 'pred_masks', 'pred_axis', 'pred_depth',
                'pred_affordance',
            )
        }

        for idx, curr_embedding in enumerate(image_embeddings):
            curr_embedding = curr_embedding.unsqueeze(0)

            # Every keypoint becomes its own one-point prompt with label 1.
            point_coords = keypoints[idx].unsqueeze(1)
            point_labels = torch.ones_like(point_coords[:, :, 0])
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=(point_coords, point_labels),
                boxes=None,
                masks=None,
            )

            # Segmentation mask plus articulation heads.
            (
                low_res_masks, _, output_movable, output_rigid,
                output_kinematic, output_action, output_axis,
            ) = self.mask_decoder(
                image_embeddings=curr_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            output_mask = self.postprocess_masks(
                low_res_masks, input_size=input_size, original_size=(768, 1024),
            )[:, 0]
            preds['pred_masks'].append(output_mask)
            preds['pred_movable'].append(output_movable[:, 0])
            preds['pred_rigid'].append(output_rigid[:, 0])
            preds['pred_kinematic'].append(output_kinematic[:, 0])
            preds['pred_action'].append(output_action[:, 0])
            preds['pred_axis'].append(output_axis[:, 0])
            # Boxes come from the predicted masks (used for evaluation).
            preds['pred_boxes'].append(self._masks_to_normalized_boxes(output_mask))

            # Affordance map, same prompts, lower output resolution.
            low_res_aff, _ = self.affordance_decoder(
                image_embeddings=curr_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            preds['pred_affordance'].append(
                self.postprocess_masks(
                    low_res_aff, input_size=input_size, original_size=(192, 256),
                )[:, 0]
            )

            # Depth: learned queries and an all-zero dense prompt shaped like
            # a single prompt encoder dense embedding.
            depth_dense_embeddings = dense_embeddings.new_zeros(
                (1, *dense_embeddings.shape[1:])
            )
            low_res_depth, _ = self.depth_decoder(
                image_embeddings=curr_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=depth_sparse_embeddings,
                dense_prompt_embeddings=depth_dense_embeddings,
                multimask_output=multimask_output,
            )
            preds['pred_depth'].append(
                self.postprocess_masks(
                    low_res_depth, input_size=input_size, original_size=(768, 1024),
                )[:, 0]
            )

        out = {key: torch.stack(values) for key, values in preds.items()}
        if not backward:
            return out

        # --- Boxes: regression losses are disabled (always zero). The GT boxes
        # are still needed, in (cx, cy, w, h) form, to center the axis targets.
        target_boxes = box_ops.box_xyxy_to_cxcywh(bbox)
        out['loss_bbox'] = self._zero_loss(device)
        out['loss_giou'] = self._zero_loss(device)

        # --- Affordance: sigmoid focal loss over valid entries.
        affordance_valid = affordance[:, :, 0] > -0.5
        num_affordance = affordance_valid.sum()
        if num_affordance == 0:
            out['loss_affordance'] = self._zero_loss(device)
        else:
            out['loss_affordance'] = sigmoid_focal_loss(
                out['pred_affordance'][affordance_valid].flatten(1),
                affordance_map[affordance_valid].flatten(1),
                num_affordance,
                alpha=self._affordance_focal_alpha,
            )

        # --- Axis: L1 on (direction, offset) plus 1 - EA-score.
        axis_valid = axis[:, :, 0] > 0.0
        num_axis = axis_valid.sum()
        if num_axis == 0:
            out['loss_axis_angle'] = self._zero_loss(device)
            out['loss_axis_offset'] = self._zero_loss(device)
            out['loss_eascore'] = self._zero_loss(device)
        else:
            pred_axis = out['pred_axis'][axis_valid]
            # Unit-normalize the first two components; keep the rest unchanged.
            pred_axis = torch.cat((F.normalize(pred_axis[:, :2]), pred_axis[:, 2:]), dim=-1)
            # Axes are parameterized relative to the GT box center (cx, cy, cx, cy).
            axis_center = target_boxes[axis_valid].clone()
            axis_center[:, 2:] = axis_center[:, :2]
            target_axis = axis_ops.line_xyxy_to_angle(axis[axis_valid], center=axis_center)
            out['loss_axis_angle'] = F.l1_loss(
                pred_axis[:, :2], target_axis[:, :2], reduction='sum'
            ) / num_axis
            out['loss_axis_offset'] = F.l1_loss(
                pred_axis[:, 2:], target_axis[:, 2:], reduction='sum'
            ) / num_axis
            pred_axis_xyxy = axis_ops.line_angle_to_xyxy(pred_axis, center=axis_center)
            target_axis_xyxy = axis_ops.line_angle_to_xyxy(target_axis, center=axis_center)
            axis_eascore, _, _ = axis_ops.ea_score(pred_axis_xyxy, target_axis_xyxy)
            out['loss_eascore'] = (1 - axis_eascore).mean()

        # --- Classification heads: cross-entropy over the class dimension.
        for head, target in (
            ('movable', movable),
            ('rigid', rigid),
            ('kinematic', kinematic),
            ('action', action),
        ):
            loss = F.cross_entropy(
                out[f'pred_{head}'].permute(0, 2, 1), target, ignore_index=self._ignore_index
            )
            # Mean reduction with every target ignored is NaN; count it as no loss.
            if torch.isnan(loss):
                loss = self._zero_loss(device)
            out[f'loss_{head}'] = loss

        # --- Depth: MiDaS-style loss plus virtual-normal loss on valid images.
        valid_depth = depth[:, 0, 0] > 0
        if valid_depth.any():
            pred_depths = interpolate(
                out['pred_depth'], size=depth.shape[-2:], mode='bilinear', align_corners=False
            ).clamp(min=0.0, max=1.0)[valid_depth]
            tgt_depths = depth.unsqueeze(1)[valid_depth]
            depth_mask = tgt_depths > 1e-8
            midas_loss, _, _ = self.midas_loss(pred_depths, tgt_depths, depth_mask)
            out['loss_depth'] = midas_loss
            out['loss_vnl'] = self.vnl_loss(tgt_depths, pred_depths)
        else:
            out['loss_depth'] = self._zero_loss(device)
            out['loss_vnl'] = self._zero_loss(device)

        # --- Segmentation: sigmoid focal + dice over targets with > 10 pixels.
        valid_mask = masks.sum(dim=(-2, -1)) > 10
        num_masks = valid_mask.sum()
        if num_masks == 0:
            out['loss_mask'] = self._zero_loss(device)
            out['loss_dice'] = self._zero_loss(device)
        else:
            src_masks = interpolate(
                out['pred_masks'], size=masks.shape[-2:], mode='bilinear', align_corners=False
            )[valid_mask].flatten(1)
            tgt_masks = masks[valid_mask].flatten(1).float()
            out['loss_mask'] = sigmoid_focal_loss(src_masks, tgt_masks, num_masks)
            out['loss_dice'] = dice_loss(src_masks, tgt_masks, num_masks)

        return out

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from a decoder, in BxCxHxW format.
          input_size (tuple(int, int)): The (H, W) size of the image before
            padding; used to crop the padding away.
          original_size (tuple(int, int)): The (H, W) size to resize to.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is
            given by original_size.
        """
        # First upsample to the padded square encoder resolution, then crop.
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        return F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and zero-pad bottom/right to a square input."""
        x = (x - self.pixel_mean) / self.pixel_std
        h, w = x.shape[-2:]
        pad_h = self.image_encoder.img_size - h
        pad_w = self.image_encoder.img_size - w
        return F.pad(x, (0, pad_w, 0, pad_h))