# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

from .sam.image_encoder import ImageEncoderViT
from .sam.mask_decoder import MaskDecoder
from .sam.prompt_encoder import PromptEncoder
from .detr import box_ops
from .detr.segmentation import dice_loss, sigmoid_focal_loss
from .detr.misc import nested_tensor_from_tensor_list, interpolate
from . import axis_ops, ilnr_loss  # , pwnp_loss
from .vnl_loss import VNL_Loss
from .midas_loss import MidasLoss


class SamTransformer(nn.Module):
    """SAM-based model for per-query articulation, affordance and depth prediction.

    For every image, each keypoint prompt yields the following from ``mask_decoder``:
    a segmentation mask, movable / rigid / kinematic / action logits, and an axis
    prediction. It also yields an affordance map from ``affordance_decoder``. A
    separate ``depth_decoder``, prompted by two learned query embeddings instead of
    keypoints, predicts one depth map per image.
    """

    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        affordance_decoder: MaskDecoder,
        depth_decoder: MaskDecoder,
        transformer_hidden_dim: int,
        backbone_name: str,
        pixel_mean: Sequence[float] = (123.675, 116.28, 103.53),
        pixel_std: Sequence[float] = (58.395, 57.12, 57.375),
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks (plus the movable, rigid,
            kinematic, action and axis heads) from the image embeddings and
            encoded prompts.
          affordance_decoder (MaskDecoder): Predicts one affordance map per
            keypoint prompt.
          depth_decoder (MaskDecoder): Predicts a depth map per image from two
            learned query tokens.
          transformer_hidden_dim (int): Width of the learned depth query tokens.
            It must match the prompt-embedding width the depth decoder consumes.
          backbone_name (str): Currently unused. The constructor does not load
            pretrained weights. A disabled loader used this to pick the SAM
            checkpoint and to copy ``mask_decoder`` weights into the affordance
            and depth decoders. That must happen elsewhere if needed.
          pixel_mean (sequence of float): Mean values for normalizing pixels in
            the input image.
          pixel_std (sequence of float): Std values for normalizing pixels in the
            input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.affordance_decoder = affordance_decoder

        # Depth head: prompted by two learned tokens instead of keypoints.
        self.depth_decoder = depth_decoder
        self.depth_query = nn.Embedding(2, transformer_hidden_dim)

        # (H, W) resolution of the predicted masks / depth maps.
        self._image_size = (768, 1024)
        # Focal length f = (W / 2) / tan(fov / 2) for a fixed fov of 1.0 rad.
        # It is used by the virtual-normal (VNL) depth loss.
        fov = torch.tensor(1.0)
        focal_length = (self._image_size[1] / 2 / torch.tan(fov / 2)).item()
        self.vnl_loss = VNL_Loss(focal_length, focal_length, self._image_size)
        self.midas_loss = MidasLoss(alpha=0.1)

        self.register_buffer("pixel_mean", torch.Tensor(list(pixel_mean)).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(list(pixel_std)).view(-1, 1, 1), False)

        # Not referenced in this file; kept for external users -- confirm.
        self.num_queries = 15
        # Focal-loss alpha for the affordance maps.
        self._affordance_focal_alpha = 0.95
        # Classification targets equal to this value are ignored by the CE losses.
        self._ignore_index = -100

    @property
    def device(self) -> Any:
        """Device of the model (taken from the ``pixel_mean`` buffer)."""
        return self.pixel_mean.device

    def freeze_layers(self, names):
        """
        Freeze layers in 'names'.

        Every parameter whose qualified name contains any of the given
        substrings gets ``requires_grad = False``.
        """
        for name, param in self.named_parameters():
            if any(freeze_name in name for freeze_name in names):
                param.requires_grad = False

    def forward(
        self,
        image: torch.Tensor,
        valid: torch.Tensor,
        keypoints: torch.Tensor,
        bbox: torch.Tensor,
        masks: torch.Tensor,
        movable: torch.Tensor,
        rigid: torch.Tensor,
        kinematic: torch.Tensor,
        action: torch.Tensor,
        affordance: torch.Tensor,
        affordance_map: torch.FloatTensor,
        depth: torch.Tensor,
        axis: torch.Tensor,
        fov: torch.Tensor,
        backward: bool = True,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Run all decoders and, if ``backward`` is True, compute the training losses.

        Arguments:
          image: Batch of images. It is zero-padded (bottom/right) to the
            square ``image_encoder.img_size`` input. No normalization is
            applied here.
          valid, fov, **kwargs: Accepted for interface compatibility; unused.
          keypoints: Per-image point prompts. ``keypoints[i]`` presumably has
            shape (num_queries, 2) -- confirm. Each point becomes one
            foreground prompt.
          bbox: Ground-truth boxes in xyxy format. Their centres anchor the
            axis losses.
          masks: Ground-truth masks. Only masks with more than 10 foreground
            pixels are supervised.
          movable, rigid, kinematic, action: Class targets. Entries equal to
            ``self._ignore_index`` are ignored.
          affordance: An entry is supervised when ``affordance[..., 0] > -0.5``.
          affordance_map: Target affordance maps.
          depth: Ground-truth depth maps, (bs, H, W).
          axis: Ground-truth axes as xyxy line segments. An entry is supervised
            when ``axis[..., 0] > 0``.
          backward: When False, return only the ``pred_*`` outputs.

        Returns:
          Dict with ``pred_*`` tensors stacked over the batch. When
          ``backward`` is True it also contains the ``loss_*`` scalars.
        """
        device = image.device
        multimask_output = False

        # Pad image to the square encoder input.
        # NOTE(review): ``self.preprocess`` is not applied, so pixel normalization
        # is presumably done by the caller/dataset -- confirm.
        h, w = image.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(image, (0, padw, 0, padh))
        image_embeddings = self.image_encoder(x)

        # Image-independent inputs shared by every decoder call; compute once.
        image_pe = self.prompt_encoder.get_dense_pe()
        depth_sparse_embeddings = self.depth_query.weight.unsqueeze(0)

        outputs_seg_masks = []
        outputs_movable = []
        outputs_rigid = []
        outputs_kinematic = []
        outputs_action = []
        outputs_axis = []
        outputs_boxes = []
        outputs_aff_masks = []
        outputs_depth = []

        for idx, curr_embedding in enumerate(image_embeddings):
            curr_embedding = curr_embedding.unsqueeze(0)

            # One single-point prompt per query, all labelled foreground (1).
            point_coords = keypoints[idx].unsqueeze(1)
            point_labels = torch.ones_like(point_coords[:, :, 0])
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=(point_coords, point_labels),
                boxes=None,
                masks=None,
            )

            # Mask decoder: masks plus the per-query articulation heads.
            (
                low_res_masks,
                _,
                output_movable,
                output_rigid,
                output_kinematic,
                output_action,
                output_axis,
            ) = self.mask_decoder(
                image_embeddings=curr_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            output_mask = self.postprocess_masks(
                low_res_masks,
                input_size=image.shape[-2:],
                original_size=self._image_size,
            )
            outputs_seg_masks.append(output_mask[:, 0])
            outputs_movable.append(output_movable[:, 0])
            outputs_rigid.append(output_rigid[:, 0])
            outputs_kinematic.append(output_kinematic[:, 0])
            outputs_action.append(output_action[:, 0])
            outputs_axis.append(output_axis[:, 0])

            # Convert masks to boxes for evaluation.
            # Empty masks become full-image masks, so masks_to_boxes always sees
            # at least one foreground pixel.
            pred_mask_bbox = (output_mask[:, 0] > self.mask_threshold).long()
            empty_mask = pred_mask_bbox.sum(dim=(-2, -1)) == 0
            pred_mask_bbox[empty_mask] += 1
            pred_boxes = torchvision.ops.masks_to_boxes(pred_mask_bbox)
            # NOTE(review): masks_to_boxes returns (x1, y1, x2, y2) pixel coordinates
            # of an (H, W) = self._image_size mask. The scales below are
            # [1 / H, 1 / W]. If rescale_bboxes expects (w, h) scales, as in DETR,
            # this order is swapped. The original order is kept pending
            # confirmation.
            pred_boxes = box_ops.rescale_bboxes(
                pred_boxes, [1 / self._image_size[0], 1 / self._image_size[1]]
            )
            outputs_boxes.append(pred_boxes)

            # Affordance decoder. Maps are produced at 1/4 of the output
            # resolution (192 x 256).
            low_res_aff_masks, _ = self.affordance_decoder(
                image_embeddings=curr_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            output_aff_masks = self.postprocess_masks(
                low_res_aff_masks,
                input_size=image.shape[-2:],
                original_size=(192, 256),
            )
            outputs_aff_masks.append(output_aff_masks[:, 0])

            # Depth decoder: learned depth queries plus an all-zero dense prompt.
            # The dense prompt is shaped like a single prompt's dense embedding
            # (previously hard-coded as (1, 256, 64, 64)).
            depth_dense_embeddings = dense_embeddings.new_zeros((1, *dense_embeddings.shape[1:]))
            low_res_depth, _ = self.depth_decoder(
                image_embeddings=curr_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=depth_sparse_embeddings,
                dense_prompt_embeddings=depth_dense_embeddings,
                multimask_output=multimask_output,
            )
            output_depth = self.postprocess_masks(
                low_res_depth,
                input_size=image.shape[-2:],
                original_size=self._image_size,
            )
            outputs_depth.append(output_depth[:, 0])

        outputs_seg_masks = torch.stack(outputs_seg_masks)
        outputs_movable = torch.stack(outputs_movable)
        outputs_rigid = torch.stack(outputs_rigid)
        outputs_kinematic = torch.stack(outputs_kinematic)
        outputs_action = torch.stack(outputs_action)
        outputs_axis = torch.stack(outputs_axis)
        outputs_boxes = torch.stack(outputs_boxes)
        outputs_aff_masks = torch.stack(outputs_aff_masks)
        outputs_depth = torch.stack(outputs_depth)

        out = {
            'pred_boxes': outputs_boxes,
            'pred_movable': outputs_movable,
            'pred_rigid': outputs_rigid,
            'pred_kinematic': outputs_kinematic,
            'pred_action': outputs_action,
            'pred_masks': outputs_seg_masks,
            'pred_axis': outputs_axis,
            'pred_depth': outputs_depth,
            'pred_affordance': outputs_aff_masks,
        }

        if not backward:
            return out

        # Boxes are derived from masks, so box regression is not trained.
        out['loss_bbox'] = self._zero_loss(device)
        out['loss_giou'] = self._zero_loss(device)

        out['loss_affordance'] = self._affordance_loss(
            outputs_aff_masks, affordance, affordance_map, device
        )

        target_boxes = box_ops.box_xyxy_to_cxcywh(bbox)
        out.update(self._axis_losses(outputs_axis, axis, target_boxes, device))

        out['loss_movable'] = self._masked_cross_entropy(outputs_movable, movable, device)
        out['loss_rigid'] = self._masked_cross_entropy(outputs_rigid, rigid, device)
        out['loss_kinematic'] = self._masked_cross_entropy(outputs_kinematic, kinematic, device)
        out['loss_action'] = self._masked_cross_entropy(outputs_action, action, device)

        out.update(self._depth_losses(outputs_depth, depth, device))
        out.update(self._mask_losses(outputs_seg_masks, masks, device))
        return out

    @staticmethod
    def _zero_loss(device) -> torch.Tensor:
        """Scalar zero loss (requires grad) used when a loss has no valid targets."""
        return torch.tensor(0.0, requires_grad=True).to(device)

    def _masked_cross_entropy(
        self, logits: torch.Tensor, target: torch.Tensor, device
    ) -> torch.Tensor:
        """Cross-entropy over per-query logits, with classes on the last axis.

        Targets equal to ``self._ignore_index`` are skipped. If every target is
        ignored, the mean reduction yields NaN; that case is mapped to a zero
        loss.
        """
        # F.cross_entropy expects the class dimension at position 1.
        loss = F.cross_entropy(logits.permute(0, 2, 1), target, ignore_index=self._ignore_index)
        if torch.isnan(loss):
            loss = self._zero_loss(device)
        return loss

    def _affordance_loss(
        self,
        outputs_aff_masks: torch.Tensor,
        affordance: torch.Tensor,
        affordance_map: torch.Tensor,
        device,
    ) -> torch.Tensor:
        """Sigmoid focal loss between predicted and target affordance maps of valid queries."""
        affordance_valid = affordance[:, :, 0] > -0.5
        num_valid = affordance_valid.sum()
        if num_valid == 0:
            return self._zero_loss(device)
        src_aff_masks = outputs_aff_masks[affordance_valid].flatten(1)
        tgt_aff_masks = affordance_map[affordance_valid].flatten(1)
        return sigmoid_focal_loss(
            src_aff_masks,
            tgt_aff_masks,
            num_valid,
            alpha=self._affordance_focal_alpha,
        )

    def _axis_losses(
        self,
        outputs_axis: torch.Tensor,
        axis: torch.Tensor,
        target_boxes: torch.Tensor,
        device,
    ) -> Dict[str, torch.Tensor]:
        """Axis angle/offset L1 losses and the (1 - EA-score) loss for valid axes.

        ``target_boxes`` are the ground-truth boxes in cxcywh format.
        Returns the keys ``loss_axis_angle``, ``loss_axis_offset`` and
        ``loss_eascore``, in that order.
        """
        axis_valid = axis[:, :, 0] > 0.0
        num_axis = axis_valid.sum()
        if num_axis == 0:
            return {
                'loss_axis_angle': self._zero_loss(device),
                'loss_axis_offset': self._zero_loss(device),
                'loss_eascore': self._zero_loss(device),
            }

        # The first two components (angle) are L2-normalized; the rest
        # (offset) are kept as-is.
        src_axis_angle = outputs_axis[axis_valid]
        src_axis_angle = torch.cat(
            (F.normalize(src_axis_angle[:, :2]), src_axis_angle[:, 2:]), dim=-1
        )

        # Centre of each GT box duplicated to (cx, cy, cx, cy).
        # Presumably this is the centre format axis_ops expects -- confirm.
        axis_center = target_boxes[axis_valid].clone()
        axis_center[:, 2:] = axis_center[:, :2]

        target_axis_angle = axis_ops.line_xyxy_to_angle(axis[axis_valid], center=axis_center)
        loss_axis_angle = F.l1_loss(
            src_axis_angle[:, :2], target_axis_angle[:, :2], reduction='sum'
        ) / num_axis
        loss_axis_offset = F.l1_loss(
            src_axis_angle[:, 2:], target_axis_angle[:, 2:], reduction='sum'
        ) / num_axis

        src_axis_xyxy = axis_ops.line_angle_to_xyxy(src_axis_angle, center=axis_center)
        target_axis_xyxy = axis_ops.line_angle_to_xyxy(target_axis_angle, center=axis_center)
        axis_eascore, _, _ = axis_ops.ea_score(src_axis_xyxy, target_axis_xyxy)

        return {
            'loss_axis_angle': loss_axis_angle,
            'loss_axis_offset': loss_axis_offset,
            'loss_eascore': (1 - axis_eascore).mean(),
        }

    def _depth_losses(
        self, outputs_depth: torch.Tensor, depth: torch.Tensor, device
    ) -> Dict[str, torch.Tensor]:
        """MiDaS-style and virtual-normal losses on images that have depth GT.

        Returns the keys ``loss_depth`` and ``loss_vnl``, in that order.
        """
        # An image counts as having depth GT when its top-left depth value is > 0.
        # Presumably missing depth maps are filled with a non-positive constant
        # -- confirm.
        valid_depth = depth[:, 0, 0] > 0
        if not valid_depth.any():
            return {
                'loss_depth': self._zero_loss(device),
                'loss_vnl': self._zero_loss(device),
            }

        # Predictions are resized to the GT resolution and clamped to [0, 1].
        src_depths = interpolate(
            outputs_depth, size=depth.shape[-2:], mode='bilinear', align_corners=False
        )
        src_depths = src_depths.clamp(min=0.0, max=1.0)[valid_depth]
        tgt_depths = depth.unsqueeze(1)[valid_depth]  # (bs, H, W) -> (bs, 1, H, W)

        depth_mask = tgt_depths > 1e-8
        midas_loss, _ssi_loss, _reg_loss = self.midas_loss(src_depths, tgt_depths, depth_mask)
        loss_vnl = self.vnl_loss(tgt_depths, src_depths)
        return {'loss_depth': midas_loss, 'loss_vnl': loss_vnl}

    def _mask_losses(
        self, outputs_seg_masks: torch.Tensor, masks: torch.Tensor, device
    ) -> Dict[str, torch.Tensor]:
        """Focal and dice mask losses for queries whose GT mask has > 10 foreground pixels.

        Returns the keys ``loss_mask`` and ``loss_dice``, in that order.
        """
        valid_mask = masks.sum(dim=-1).sum(dim=-1) > 10
        num_masks = valid_mask.sum()
        if num_masks == 0:
            return {
                'loss_mask': self._zero_loss(device),
                'loss_dice': self._zero_loss(device),
            }

        src_masks = interpolate(
            outputs_seg_masks, size=masks.shape[-2:], mode='bilinear', align_corners=False
        )
        src_masks = src_masks[valid_mask].flatten(1)
        # Both losses receive float targets. Previously dice_loss got the raw
        # target dtype.
        tgt_masks = masks[valid_mask].flatten(1).view(src_masks.shape).float()
        return {
            'loss_mask': sigmoid_focal_loss(src_masks, tgt_masks, num_masks),
            'loss_dice': dice_loss(src_masks, tgt_masks, num_masks),
        }

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        # Upsample to the padded square encoder resolution, crop away the
        # padding, then resize to the requested output size.
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad (bottom/right) to the square encoder input size.
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x