3DOI / monoarti /axis_ops.py
shengyi-qian's picture
init
9afcee2
raw
history blame contribute delete
No virus
9.75 kB
import torch
import numpy as np
import pdb
def ea_score(src_line, tgt_line):
    """
    Differentiable EA-Score between row-paired 2D line segments.

    The score is the squared product of a Euclidean (midpoint) similarity
    and an angular similarity, after

    Kai Zhao*, Qi Han*, Chang-Bin Zhang, Jun Xu, Ming-Ming Cheng.
    Deep Hough Transform for Semantic Line Detection.
    TPAMI 2021.

    - src_line: tensor shape Nx4, XYXY format
    - tgt_line: tensor shape Nx4, XYXY format

    Returns a tuple (line_eascore, line_se, line_sa) with one entry per row:
    - line_eascore: (line_se * line_sa) ** 2, with NaN entries set to 0
    - line_se: 1 - distance between the two midpoints, clamped to be >= 0
    - line_sa: 1 - (angle between the two lines) / (pi / 2)
    """

    def _midpoint(line):
        # (x, y) of the segment midpoint, one value per row.
        return (line[:, 0] + line[:, 2]) / 2, (line[:, 1] + line[:, 3]) / 2

    def _orientation(line):
        # Slope angle via atan, so it lies in [-pi/2, pi/2]. The small offset
        # on dx keeps an exactly vertical segment (dx == 0) from dividing by 0.
        rise = line[:, 1] - line[:, 3]
        run = line[:, 0] - line[:, 2] + 1e-5
        return torch.atan(rise / run)

    # Euclidean similarity: 1 minus the midpoint distance, floored at 0.
    src_mx, src_my = _midpoint(src_line)
    tgt_mx, tgt_my = _midpoint(tgt_line)
    midpoint_dist = torch.sqrt((src_mx - tgt_mx) ** 2 + (src_my - tgt_my) ** 2)
    euclid_sim = (1 - midpoint_dist).clamp(min=0)

    # Angular similarity: lines are undirected, so the angle between them is
    # the smaller of the raw difference and its supplement.
    gap = torch.abs(_orientation(src_line) - _orientation(tgt_line))
    gap = torch.min(gap, torch.pi - gap)
    angle_sim = 1 - gap / (torch.pi / 2)

    score = (euclid_sim * angle_sim) ** 2
    # Only the combined score is sanitised; the two factors are returned as-is.
    score[torch.isnan(score)] = 0.0
    return score, euclid_sim, angle_sim
def sine_to_angle(sin, cos, r, eps=1e-5):
    """
    Recover an angle from an unnormalized (sin, cos) pair and rescale r.

    `eps` is added to both `sin` and `cos`, the shifted pair is divided by
    its Euclidean magnitude, and `r` is divided by that same magnitude.

    - sin, cos, r: tensors (combined elementwise; the callers in this file
      pass Nx1 columns)
    - eps: offset added to sin and cos before normalizing

    Returns (theta, r): theta in radians within [0, 2*pi], and the rescaled r.
    """
    shifted_sin = sin + eps
    shifted_cos = cos + eps
    norm = torch.sqrt(shifted_sin ** 2 + shifted_cos ** 2)

    # acos only covers [0, pi]; the sign of asin tells which half of the
    # circle the angle is in, and negative-sine angles are mirrored to
    # 2*pi - acos.
    upper_angle = torch.acos(shifted_cos / norm)
    in_lower_half = torch.asin(shifted_sin / norm) < 0
    theta = torch.where(in_lower_half, torch.pi * 2 - upper_angle, upper_angle)
    return theta, r / norm
def line_xyxy_to_angle(line_xyxy, center=(0.5, 0.5), debug=False):
    """
    Convert [X1, Y1, X2, Y2] representation of a 2D line to a
    [sin, cos, offset] representation relative to `center`.

    The line is written as r = x * cos(theta) + y * sin(theta). For two
    points (x1, y1) and (x2, y2), before normalization:
    - cos(theta) = y1 - y2
    - sin(theta) = x2 - x1
    - r = x2 * y1 - x1 * y2
    `sine_to_angle` normalizes these and yields theta. The sin/cos that are
    returned are those of the DOUBLED angle 2 * theta (line_angle_to_xyxy
    halves the angle again), while r is the normalized offset.

    Shengyi Qian, Linyi Jin, Chris Rockwell, Siyi Chen, David Fouhey.
    Understanding 3D Object Articulation in Internet Videos.
    CVPR 2022.

    - line_xyxy: tensor shape Nx4, XYXY format
    - center: origin of the representation. Either a (w, h) list/tuple, or
      a tensor that is subtracted from `line_xyxy` directly, so it must
      broadcast against Nx4 (presumably laid out as [w, h, w, h] - confirm
      against callers).
    - debug: if True, drop into pdb after theta and r are computed.

    Returns a tensor of shape Nx3: [sin(2 * theta), cos(2 * theta), r].

    Raises NotImplementedError for any other `center` type, and
    AssertionError if r falls outside (-sqrt(2), sqrt(2)).
    """
    device = line_xyxy.device

    # Shift coordinates so that `center` becomes the origin.
    if isinstance(center, (list, tuple)):
        center_w, center_h = center
        origin = torch.as_tensor([center_w, center_h, center_w, center_h]).to(device)
        line_xyxy = line_xyxy - origin
    elif isinstance(center, torch.Tensor):
        line_xyxy = line_xyxy - center
    else:
        raise NotImplementedError(
            "center must be a list, tuple or torch.Tensor, got %s" % type(center).__name__
        )

    # Centered coordinates are limited to [-1, 1]; this is what the
    # sqrt(2) bound on r asserted below corresponds to.
    line_xyxy = line_xyxy.clamp(min=-1.0, max=1.0)
    x1 = line_xyxy[:, 0:1]
    y1 = line_xyxy[:, 1:2]
    x2 = line_xyxy[:, 2:3]
    y2 = line_xyxy[:, 3:4]

    # Unnormalized normal direction and offset of the line through both points.
    cos = y1 - y2
    sin = x2 - x1
    r = x2 * y1 - x1 * y2
    theta, r = sine_to_angle(sin, cos, r)

    if debug:
        pdb.set_trace()

    # Encode the doubled angle.
    theta = theta * 2
    sin = torch.sin(theta)
    cos = torch.cos(theta)

    # Sanity checks on the output ranges:
    # sin(2 * theta), cos(2 * theta) in [-1, 1]; r in (-sqrt(2), sqrt(2)).
    assert (r > (- np.sqrt(2))).all()
    assert (r < (np.sqrt(2))).all()
    assert (sin >= -1).all()
    assert (sin <= 1).all()
    assert (cos >= -1).all()
    assert (cos <= 1).all()

    return torch.cat((sin, cos, r), dim=1)
def line_angle_to_xyxy(line_angle, center=(0.5, 0.5), use_bins=False, debug=False):
    """
    Convert [sin, cos, offset] representation of a 2D line to
    [X1, Y1, X2, Y2] representation.

    The incoming sin/cos are treated as those of a doubled angle: they are
    normalized by `sine_to_angle`, the recovered angle is halved, and the
    line r = x * cos(theta) + y * sin(theta) (in coordinates relative to
    `center`) is intersected with the borders of the unit square.

    Shengyi Qian, Linyi Jin, Chris Rockwell, Siyi Chen, David Fouhey.
    Understanding 3D Object Articulation in Internet Videos.
    CVPR 2022.

    - line_angle: tensor shape Nx3, columns [sin, cos, r]
    - center: origin of the representation. Either a (w, h) list/tuple, or
      a tensor whose columns 0 and 1 hold w and h for each row.
    - use_bins: accepted for interface compatibility; not used here.
    - debug: if True, drop into pdb right before returning.

    Returns a tensor of shape Nx4 in XYXY format, with the two end points
    of every row ordered by ascending x. The result is rebuilt from Python
    lists, so it is a new tensor created by torch.as_tensor.

    Raises NotImplementedError for any other `center` type.
    """
    eps = 1e-5
    device = line_angle.device

    if isinstance(center, (list, tuple)):
        center_w, center_h = center
    elif isinstance(center, torch.Tensor):
        center_w = center[:, 0:1]
        center_h = center[:, 1:2]
    else:
        raise NotImplementedError(
            "center must be a list, tuple or torch.Tensor, got %s" % type(center).__name__
        )

    sin = line_angle[:, :1]
    cos = line_angle[:, 1:2]
    r = line_angle[:, 2:]

    # Normalize sin and cos, then undo the angle doubling.
    theta, r = sine_to_angle(sin, cos, r)
    theta = theta / 2
    sin = torch.sin(theta)
    cos = torch.cos(theta)

    # Make sure r is not out of boundary.
    # NOTE(review): line_xyxy_to_angle asserts |r| < sqrt(2) but this clamps
    # to sqrt(2) / 2 - confirm the tighter bound is intended.
    r = r.clamp(min=(- np.sqrt(2) / 2), max=(np.sqrt(2) / 2))

    # Intersect the line with the four borders of the unit square:
    # y at x = 0 and x = 1, then x at y = 0 and y = 1.
    y1 = (r - cos * (0.0 - center_w)) / sin + center_h
    y2 = (r - cos * (1.0 - center_w)) / sin + center_h
    x3 = (r - sin * (0.0 - center_h)) / cos + center_w
    x4 = (r - sin * (1.0 - center_h)) / cos + center_w

    def _on_border(value):
        # True when the coordinate lies within [0, 1] up to eps
        # (NaN compares False and is therefore rejected).
        return value > - eps and value < (1.0 + eps)

    line_xyxy = []
    for i in range(line_angle.shape[0]):
        # Left/right borders are always tried; top/bottom only while fewer
        # than two points have been collected.
        line = []
        if _on_border(y1[i]):
            line.append([0.0, y1[i]])
        if _on_border(y2[i]):
            line.append([1.0, y2[i]])
        if len(line) < 2 and _on_border(x3[i]):
            line.append([x3[i], 0.0])
        if len(line) < 2 and _on_border(x4[i]):
            line.append([x4[i], 1.0])

        # Mathematically, we should only have two boundary points.
        # However, in training time, it is not guaranteed the represented
        # line is within the image plane. Even if r < sqrt(2)/2, it
        # can be out of boundary. But it's rare and we want to ignore it,
        # so fall back to the x = 0 and y = 0 intersections.
        if len(line) != 2:
            line = [[0.0, y1[i]], [x3[i], 0.0]]

        line = torch.as_tensor(line, device=device)

        # NOTE(review): debugging hook kept from the original - a NaN end
        # point opens an interactive pdb session, which blocks unattended runs.
        if torch.isnan(line.mean()):
            pdb.set_trace()

        # Make sure it is sorted by x, so that the model does not get
        # confused, then flatten [[x1, y1], [x2, y2]] to [x1, y1, x2, y2].
        sort_idx = torch.sort(line[:, 0], dim=0, descending=False)[1]
        line_xyxy.append(line[sort_idx].flatten())

    line_xyxy = torch.stack(line_xyxy)

    if debug:
        pdb.set_trace()

    return line_xyxy
# def line_xyxy_to_angle(line_xyxy):
# """
# Convert [X1, Y1, X2, Y2] representation of a 2D line to
# [sin(theta), cos(theta), offset] representation.
# xcos(theta) + ysin(theta) = r
# For two points (x1, y1) and (x2, y2) within image plane [0, 1],
# - cos(theta) = y1 - y2
# - sin(theta) = x2 - x1
# - r = x2y1 - x1y2
# Shengyi Qian, Linyi Jin, Chris Rockwell, Siyi Chen, David Fouhey.
# Understanding 3D Object Articulation in Internet Videos.
# CVPR 2022.
# """
# eps = 1e-5
# device = line_xyxy.device
# line_xyxy = line_xyxy - torch.as_tensor([0.5, 0.5, 0.5, 0.5]).to(device)
# line_xyxy = line_xyxy.clamp(min=-0.5, max=0.5)
# x1, y1, x2, y2 = line_xyxy[:,:1], line_xyxy[:,1:2], line_xyxy[:,2:3], line_xyxy[:,3:4]
# cos = y1 - y2
# sin = x2 - x1
# r = x2 * y1 - x1 * y2
# # normalize, and ensure
# # sin(theta) in [0, 1]
# # cos(theta) in [-1, 1]
# # r in (-sqrt(2) / 2, sqrt(2) / 2)
# sign = torch.sign(r)
# mag = torch.sqrt(sin ** 2 + cos ** 2) + eps
# r = r / mag * sign
# sin = sin / mag * sign
# cos = cos / mag * sign
# #assert (r > (- torch.sqrt(2) / 2)).all()
# assert (r >= 0).all()
# assert (r < (np.sqrt(2) / 2)).all()
# assert (sin >= -1).all()
# assert (sin <= 1).all()
# assert (cos >= -1).all()
# assert (cos <= 1).all()
# return torch.cat((sin, cos, r), dim=1)
# def line_angle_to_xyxy(line_angle, use_bins=False):
# """
# Convert [sin(theta), cos(theta), offset] representation of a 2D line to
# [X1, Y1, X2, Y2] representation.
# Shengyi Qian, Linyi Jin, Chris Rockwell, Siyi Chen, David Fouhey.
# Understanding 3D Object Articulation in Internet Videos.
# CVPR 2022.
# """
# eps = 1e-5
# device = line_angle.device
# sin = line_angle[:, :1] + eps
# cos = line_angle[:, 1:2] + eps
# r = line_angle[:, 2:]
# # normalize sin and cos
# # make sure r is not out of boundary
# mag = torch.sqrt(sin ** 2 + cos ** 2)
# sin = sin / mag
# cos = cos / mag
# r = r.clamp(min=0, max=(np.sqrt(2) / 2))
# # intersect line with four boundaries
# y1 = (r - cos * (- 0.5)) / sin + 0.5
# y2 = (r - cos * 0.5) / sin + 0.5
# x3 = (r - sin * (- 0.5)) / cos + 0.5
# x4 = (r - sin * 0.5) / cos + 0.5
# line_xyxy = []
# for i in range(line_angle.shape[0]):
# line = []
# if y1[i] > - eps and y1[i] < (1.0 + eps):
# line.append([0.0, y1[i]])
# if y2[i] > - eps and y2[i] < (1.0 + eps):
# line.append([1.0, y2[i]])
# if len(line) < 2 and x3[i] > - eps and x3[i] < (1.0 + eps):
# line.append([x3[i], 0.0])
# if len(line) < 2 and x4[i] > - eps and x4[i] < (1.0 + eps):
# line.append([x4[i], 1.0])
# # Mathematically, we should only have two boundary points.
# # However, in training time, it is not guaranteed the represented
# # line is within the image plane. Even if r < sqrt(2)/2, it
# # can be out of boundary. But it's rare and we want to ignore it.
# if len(line) != 2:
# line = [[0.0, y1[i]], [x3[i], 0.0]]
# line = torch.as_tensor(line, device=device)
# if torch.isnan(line.mean()):
# pdb.set_trace()
# pass
# # make sure it is sorted, so that model does not get confused.
# # flat [[x1, y1], [x2, y2]] to [x1, y1, x2, y2]
# sort_idx = torch.sort(line[:, 0], dim=0, descending=False)[1]
# line = line[sort_idx]
# line = line.flatten()
# line_xyxy.append(line)
# line_xyxy = torch.stack(line_xyxy)
# return line_xyxy