import torch import torch.nn import numpy as np import pdb class VNL_Loss(torch.nn.Module): """ Virtual Normal Loss Function. """ def __init__(self, focal_x, focal_y, input_size, delta_cos=0.867, delta_diff_x=0.01, delta_diff_y=0.01, delta_diff_z=0.01, delta_z=0.0001, sample_ratio=0.15): super(VNL_Loss, self).__init__() self.fx = torch.tensor([focal_x], dtype=torch.float32) #.to(cuda0) self.fy = torch.tensor([focal_y], dtype=torch.float32) #.to(cuda0) self.input_size = input_size self.u0 = torch.tensor(input_size[1] // 2, dtype=torch.float32) #.to(cuda0) self.v0 = torch.tensor(input_size[0] // 2, dtype=torch.float32) #.to(cuda0) self.init_image_coor() self.delta_cos = delta_cos self.delta_diff_x = delta_diff_x self.delta_diff_y = delta_diff_y self.delta_diff_z = delta_diff_z self.delta_z = delta_z self.sample_ratio = sample_ratio def init_image_coor(self): x_row = np.arange(0, self.input_size[1]) x = np.tile(x_row, (self.input_size[0], 1)) x = x[np.newaxis, :, :] x = x.astype(np.float32) x = torch.from_numpy(x.copy()) #.to(cuda0) self.u_u0 = x - self.u0 y_col = np.arange(0, self.input_size[0]) # y_col = np.arange(0, height) y = np.tile(y_col, (self.input_size[1], 1)).T y = y[np.newaxis, :, :] y = y.astype(np.float32) y = torch.from_numpy(y.copy()) #.to(cuda0) self.v_v0 = y - self.v0 def transfer_xyz(self, depth): # print('!!!!!!!!!!!!!!!111111 ', self.u_u0.device, torch.abs(depth).device, self.fx.device) x = self.u_u0 * torch.abs(depth) / self.fx y = self.v_v0 * torch.abs(depth) / self.fy z = depth pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c] return pw def select_index(self): valid_width = self.input_size[1] valid_height = self.input_size[0] num = valid_width * valid_height p1 = np.random.choice(num, int(num * self.sample_ratio), replace=True) np.random.shuffle(p1) p2 = np.random.choice(num, int(num * self.sample_ratio), replace=True) np.random.shuffle(p2) p3 = np.random.choice(num, int(num * self.sample_ratio), replace=True) np.random.shuffle(p3) p1_x = p1 % self.input_size[1] p1_y = (p1 / self.input_size[1]).astype(np.int) p2_x = p2 % self.input_size[1] p2_y = (p2 / self.input_size[1]).astype(np.int) p3_x = p3 % self.input_size[1] p3_y = (p3 / self.input_size[1]).astype(np.int) p123 = {'p1_x': p1_x, 'p1_y': p1_y, 'p2_x': p2_x, 'p2_y': p2_y, 'p3_x': p3_x, 'p3_y': p3_y} return p123 def form_pw_groups(self, p123, pw): """ Form 3D points groups, with 3 points in each grouup. :param p123: points index :param pw: 3D points :return: """ p1_x = p123['p1_x'] p1_y = p123['p1_y'] p2_x = p123['p2_x'] p2_y = p123['p2_y'] p3_x = p123['p3_x'] p3_y = p123['p3_y'] pw1 = pw[:, p1_y, p1_x, :] pw2 = pw[:, p2_y, p2_x, :] pw3 = pw[:, p3_y, p3_x, :] # [B, N, 3(x,y,z), 3(p1,p2,p3)] pw_groups = torch.cat([pw1[:, :, :, np.newaxis], pw2[:, :, :, np.newaxis], pw3[:, :, :, np.newaxis]], 3) return pw_groups def filter_mask(self, p123, gt_xyz, delta_cos=0.867, delta_diff_x=0.005, delta_diff_y=0.005, delta_diff_z=0.005): pw = self.form_pw_groups(p123, gt_xyz) pw12 = pw[:, :, :, 1] - pw[:, :, :, 0] pw13 = pw[:, :, :, 2] - pw[:, :, :, 0] pw23 = pw[:, :, :, 2] - pw[:, :, :, 1] ###ignore linear pw_diff = torch.cat([pw12[:, :, :, np.newaxis], pw13[:, :, :, np.newaxis], pw23[:, :, :, np.newaxis]], 3) # [b, n, 3, 3] m_batchsize, groups, coords, index = pw_diff.shape proj_query = pw_diff.view(m_batchsize * groups, -1, index).permute(0, 2, 1) # (B* X CX(3)) [bn, 3(p123), 3(xyz)] proj_key = pw_diff.view(m_batchsize * groups, -1, index) # B X (3)*C [bn, 3(xyz), 3(p123)] q_norm = proj_query.norm(2, dim=2) nm = torch.bmm(q_norm.view(m_batchsize * groups, index, 1), q_norm.view(m_batchsize * groups, 1, index)) #[] energy = torch.bmm(proj_query, proj_key) # transpose check [bn, 3(p123), 3(p123)] norm_energy = energy / (nm + 1e-8) norm_energy = norm_energy.view(m_batchsize * groups, -1) mask_cos = torch.sum((norm_energy > delta_cos) + (norm_energy < -delta_cos), 1) > 3 # igonre mask_cos = mask_cos.view(m_batchsize, groups) ##ignore padding and invilid depth mask_pad = torch.sum(pw[:, :, 2, :] > self.delta_z, 2) == 3 ###ignore near mask_x = torch.sum(torch.abs(pw_diff[:, :, 0, :]) < delta_diff_x, 2) > 0 mask_y = torch.sum(torch.abs(pw_diff[:, :, 1, :]) < delta_diff_y, 2) > 0 mask_z = torch.sum(torch.abs(pw_diff[:, :, 2, :]) < delta_diff_z, 2) > 0 mask_ignore = (mask_x & mask_y & mask_z) | mask_cos mask_near = ~mask_ignore mask = mask_pad & mask_near return mask, pw def select_points_groups(self, gt_depth, pred_depth): pw_gt = self.transfer_xyz(gt_depth) pw_pred = self.transfer_xyz(pred_depth) #pdb.set_trace() B, C, H, W = gt_depth.shape p123 = self.select_index() # mask:[b, n], pw_groups_gt: [b, n, 3(x,y,z), 3(p1,p2,p3)] mask, pw_groups_gt = self.filter_mask(p123, pw_gt, delta_cos=0.867, delta_diff_x=0.005, delta_diff_y=0.005, delta_diff_z=0.005) # [b, n, 3, 3] pw_groups_pred = self.form_pw_groups(p123, pw_pred) pw_groups_pred[pw_groups_pred[:, :, 2, :] == 0] = 0.0001 mask_broadcast = mask.repeat(1, 9).reshape(B, 3, 3, -1).permute(0, 3, 1, 2) pw_groups_pred_not_ignore = pw_groups_pred[mask_broadcast].reshape(1, -1, 3, 3) pw_groups_gt_not_ignore = pw_groups_gt[mask_broadcast].reshape(1, -1, 3, 3) return pw_groups_gt_not_ignore, pw_groups_pred_not_ignore def forward(self, gt_depth, pred_depth, select=True): """ Virtual normal loss. :param pred_depth: predicted depth map, [B,W,H,C] :param data: target label, ground truth depth, [B, W, H, C], padding region [padding_up, padding_down] :return: """ device = gt_depth.device self.fx = self.fx.to(device) self.fy = self.fy.to(device) self.u0 = self.u0.to(device) self.v0 = self.v0.to(device) self.u_u0 = self.u_u0.to(device) self.v_v0 = self.v_v0.to(device) # print("************ ", self.fx.device, self.u_u0.device) gt_points, dt_points = self.select_points_groups(gt_depth, pred_depth) gt_p12 = gt_points[:, :, :, 1] - gt_points[:, :, :, 0] gt_p13 = gt_points[:, :, :, 2] - gt_points[:, :, :, 0] dt_p12 = dt_points[:, :, :, 1] - dt_points[:, :, :, 0] dt_p13 = dt_points[:, :, :, 2] - dt_points[:, :, :, 0] gt_normal = torch.cross(gt_p12, gt_p13, dim=2) dt_normal = torch.cross(dt_p12, dt_p13, dim=2) dt_norm = torch.norm(dt_normal, 2, dim=2, keepdim=True) gt_norm = torch.norm(gt_normal, 2, dim=2, keepdim=True) dt_mask = dt_norm == 0.0 gt_mask = gt_norm == 0.0 dt_mask = dt_mask.to(torch.float32) gt_mask = gt_mask.to(torch.float32) dt_mask *= 0.01 gt_mask *= 0.01 gt_norm = gt_norm + gt_mask dt_norm = dt_norm + dt_mask gt_normal = gt_normal / gt_norm dt_normal = dt_normal / dt_norm #pdb.set_trace() loss = torch.abs(gt_normal - dt_normal) loss = torch.sum(torch.sum(loss, dim=2), dim=0) if select: loss, indices = torch.sort(loss, dim=0, descending=False) loss = loss[int(loss.size(0) * 0.25):] loss = torch.mean(loss) return loss