import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MEADSTD_TANH_NORM_Loss(nn.Module):
    """
    Image-level normalized regression (ILNR) loss.

    The implementation comes from
    https://github.com/aim-uofa/AdelaiDepth/blob/main/LeReS/Train/lib/models/ILNR_loss.py

    loss = MAE((d-u)/s, d') + MAE(tanh(0.01*(d-u)/s), tanh(0.01*d'))

    where u and s are the trimmed mean and standard deviation of the valid
    ground-truth depths d, and d' is the predicted depth.
    """
    def __init__(self, valid_threshold=-1e-8, max_threshold=1e8):
        super(MEADSTD_TANH_NORM_Loss, self).__init__()
        self.valid_threshold = valid_threshold
        self.max_threshold = max_threshold
        # self.thres1 = 0.9

    def transform(self, gt):
        # Per-sample trimmed mean and standard deviation of the valid depths.
        data_mean = []
        data_std_dev = []
        for i in range(gt.shape[0]):
            gt_i = gt[i]
            mask = gt_i > 0
            depth_valid = gt_i[mask]
            if depth_valid.shape[0] < 10:
                # Too few valid pixels: fall back to identity normalization.
                data_mean.append(torch.tensor(0.0).cuda())
                data_std_dev.append(torch.tensor(1.0).cuda())
                continue
            size = depth_valid.shape[0]
            depth_valid_sort, _ = torch.sort(depth_valid, 0)
            # Drop the lowest and highest 10% for robustness to outliers.
            depth_valid_mask = depth_valid_sort[int(size * 0.1): -int(size * 0.1)]
            data_mean.append(depth_valid_mask.mean())
            data_std_dev.append(depth_valid_mask.std())
        data_mean = torch.stack(data_mean, dim=0).cuda()
        data_std_dev = torch.stack(data_std_dev, dim=0).cuda()

        return data_mean, data_std_dev

    def forward(self, pred, gt):
        """
        Calculate loss.
        """
        mask = (gt > self.valid_threshold) & (gt < self.max_threshold)  # [b, c, h, w]
        mask_sum = torch.sum(mask, dim=(1, 2, 3))
        # Drop samples with too few (<= 100) valid pixels.
        mask_batch = mask_sum > 100
        if not mask_batch.any():
            return torch.tensor(0.0, dtype=torch.float).cuda()
        mask_maskbatch = mask[mask_batch]
        pred_maskbatch = pred[mask_batch]
        gt_maskbatch = gt[mask_batch]

        # Normalize the ground truth with its per-sample trimmed statistics.
        gt_mean, gt_std = self.transform(gt_maskbatch)
        gt_trans = (gt_maskbatch - gt_mean[:, None, None, None]) / (gt_std[:, None, None, None] + 1e-8)

        B, C, H, W = gt_maskbatch.shape
        loss = 0
        loss_tanh = 0
        for i in range(B):
            mask_i = mask_maskbatch[i, ...]
            pred_depth_i = pred_maskbatch[i, ...][mask_i]
            gt_trans_i = gt_trans[i, ...][mask_i]

            # L1 term on the normalized depths.
            depth_diff = torch.abs(gt_trans_i - pred_depth_i)
            loss += torch.mean(depth_diff)

            # L1 term on tanh-compressed depths, which down-weights large values.
            tanh_norm_gt = torch.tanh(0.01 * gt_trans_i)
            tanh_norm_pred = torch.tanh(0.01 * pred_depth_i)
            loss_tanh += torch.mean(torch.abs(tanh_norm_gt - tanh_norm_pred))
        loss_out = loss / B + loss_tanh / B
        return loss_out.float()
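

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original LeReS file). It assumes
# 4D [B, C, H, W] depth tensors and an available CUDA device, matching the
# hard-coded .cuda() calls above; the shapes and values below are
# illustrative only.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    criterion = MEADSTD_TANH_NORM_Loss()
    pred = torch.rand(2, 1, 64, 64).cuda()        # predicted (normalized) depth
    gt = torch.rand(2, 1, 64, 64).cuda() * 10.0   # ground-truth depth, > 0 where valid
    gt[:, :, :8, :] = 0.0                         # simulate invalid (zero-depth) pixels
    loss = criterion(pred, gt)
    print("ILNR loss:", loss.item())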