import torch
import torch.nn as nn
import torch.nn.functional as F


class TetTexNet(nn.Module):
    """Look up per-query features from a rolled-out triplane feature map.

    The three feature planes ('xy', 'yz', 'zx') are expected side by side
    along the last dimension of the input feature map.  Each query point is
    bilinearly sampled on every plane and the three results are either
    concatenated along the channel dimension or summed.

    Args:
        plane_reso: Width (last-dimension extent) of the 'xy' and 'yz'
            planes inside the rolled-out map; the 'zx' plane takes all
            remaining columns.
        padding: Stored on the module but not used by any method of this
            class.
        fea_concat: If True the three plane features are concatenated along
            the channel dimension, otherwise they are summed element-wise.
    """

    def __init__(self, plane_reso=64, padding=0.1, fea_concat=True):
        super().__init__()
        self.plane_reso = plane_reso
        self.padding = padding
        self.fea_concat = fea_concat

    def forward(self, rolled_out_feature, query):
        """Sample all three planes at ``query`` and merge the results.

        Args:
            rolled_out_feature: Rolled-out triplane feature, a 4-D tensor
                ``(B, C, H, W_total)`` whose last dimension holds the 'xy',
                'yz' and 'zx' planes in that order.
            query: Queried xyz coordinates, ``(B, N, 3)`` (should be scaled
                consistently with the point cloud).

        Returns:
            Tensor of shape ``(B, N, 3 * C)`` when ``fea_concat`` is True,
            otherwise ``(B, N, C)``.
        """
        reso = self.plane_reso
        # Column range occupied by each plane in the rolled-out map.
        plane_columns = (
            ('xy', slice(0, reso)),
            ('yz', slice(reso, 2 * reso)),
            ('zx', slice(2 * reso, None)),
        )
        feat_xy, feat_yz, feat_zx = (
            self.sample_plane_feature(
                query, rolled_out_feature[:, :, :, columns], plane
            )
            for plane, columns in plane_columns
        )

        if self.fea_concat:
            query_feature = torch.cat((feat_xy, feat_yz, feat_zx), dim=1)
        else:
            query_feature = feat_xy + feat_yz + feat_zx

        # (B, C', N) -> (B, N, C')
        return query_feature.permute(0, 2, 1)

    def sample_plane_feature(self, query, plane_feature, plane):
        """Bilinearly interpolate ``plane_feature`` at the projected queries.

        Args:
            query: ``(B, N, 3)`` tensor of xyz coordinates.
            plane_feature: ``(B, C, H, W)`` feature map of a single plane.
            plane: One of ``'xy'``, ``'yz'`` or ``'zx'``; selects which two
                query coordinates are used as the sampling location.

        Returns:
            Tensor of shape ``(B, C, N)`` with the sampled features.

        Raises:
            ValueError: If ``plane`` is not one of the supported names.
        """
        # CYF note:
        # for pretraining, query are uniformly sampled positions within
        # [-scale, scale]; for training, query are essentially tetrahedra
        # grid vertices, which are also within [-scale, scale] in the
        # current version.
        if plane == 'xy':
            axes = [0, 1]
        elif plane == 'yz':
            axes = [1, 2]
        elif plane == 'zx':
            axes = [2, 0]
        else:
            raise ValueError("Error! Invalid plane type!")

        # grid_sample requires the grid and the input to share a dtype.
        # Keep the historical float32 grid for every feature dtype except
        # float64, where a float32 grid would make grid_sample raise.
        if plane_feature.dtype == torch.float64:
            grid_dtype = torch.float64
        else:
            grid_dtype = torch.float32

        # (B, N, 2) -> (B, N, 1, 2): grid_sample wants (B, H_out, W_out, 2).
        vgrid = query[:, :, axes][:, :, None].to(grid_dtype)

        # NOTE(review): the coordinates are passed to grid_sample without
        # rescaling.  grid_sample interprets [-1, 1] as the full extent of
        # plane_feature, so this presumes the queries are already expressed
        # in that normalized range (per the original author's note about the
        # encoder's coordinate2index()) -- confirm against the caller.
        # Locations outside the range take the border value.
        sampled_feat = F.grid_sample(
            plane_feature,
            vgrid,
            padding_mode='border',
            align_corners=True,
            mode='bilinear',
        ).squeeze(-1)
        return sampled_feat