Spaces:
Running
Running
from typing import Dict | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as nnf | |
import torchvision | |
import numpy as np | |
from scipy.spatial import Delaunay | |
from shapely.geometry import Point | |
from shapely.geometry.polygon import Polygon | |
class ToneLoss(nn.Module):
    """Penalise tonal drift between the current raster and a reference image.

    Both images are passed through the same Gaussian blur and compared with
    a mean-squared error; the result is scaled by a (optionally
    step-dependent) weight.

    ``cfg`` must provide ``dist_loss_weight``, ``pixel_dist_kernel_blur`` and
    ``pixel_dist_sigma``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dist_loss_weight = cfg.dist_loss_weight
        self.im_init = None
        self.mse_loss = nn.MSELoss()
        # Square kernel, same sigma on both axes.
        self.blur = torchvision.transforms.GaussianBlur(
            kernel_size=(cfg.pixel_dist_kernel_blur,
                         cfg.pixel_dist_kernel_blur),
            sigma=(cfg.pixel_dist_sigma, cfg.pixel_dist_sigma),
        )
        self.init_blurred = None

    def set_image_init(self, im_init):
        """Store the reference image and cache its blurred version."""
        self.im_init = im_init
        self.init_blurred = self.blur(self.im_init)

    def get_scheduler(self, step=None):
        """Return the loss weight for ``step``.

        Without a step the constant ``dist_loss_weight`` is returned.  With a
        step the weight follows a bell curve that peaks (at the full
        ``dist_loss_weight``) at step 300 and decays on both sides:
        ``w * exp(-(1/5) * ((step - 300) / 20) ** 2)``.
        """
        if step is None:
            return self.dist_loss_weight
        return self.dist_loss_weight * np.exp(-(1 / 5) * ((step - 300) / (20)) ** 2)

    def forward(self, cur_raster, step=None):
        """Return the weighted MSE between the blurred reference and raster.

        Raises:
            RuntimeError: if ``set_image_init`` has not been called yet.
        """
        if self.init_blurred is None:
            # Fail with an actionable message instead of an AttributeError
            # on ``None.detach()``.
            raise RuntimeError(
                "ToneLoss.set_image_init() must be called before computing the loss"
            )
        blurred_cur = self.blur(cur_raster)
        # The reference is a constant target: no gradient flows into it.
        target = self.init_blurred.detach()
        return self.mse_loss(target, blurred_cur) * self.get_scheduler(step)
class ConformalLoss:
    """Angle-preservation loss over a triangulation of the letter outlines.

    At construction the control points in ``parameters`` are triangulated
    (Delaunay) and, per character of ``target_letter``, the triangles whose
    centroid lies inside that letter's polygon are kept.  Calling the
    instance returns the summed MSE between the current per-triangle corner
    angles and the reference angles captured by the last ``reset``.

    ``parameters`` is a sequence of point tensors (one per shape) and
    ``shape_groups`` a sequence of objects exposing ``shape_ids``; group
    ``k`` is paired with ``target_letter[k]``.
    """

    def __init__(self, parameters, shape_groups, target_letter: str, device: torch.device):
        self.parameters = parameters
        self.device = device
        self.target_letter = target_letter
        self.shape_groups = shape_groups
        self.faces = self.init_faces(device)
        # Same triangles with the vertex order rotated by one position, so
        # that ``rolled - original`` yields the three edge vectors.
        self.faces_roll_a = [torch.roll(face, 1, 1) for face in self.faces]
        self.angles = []
        self.reset(device)

    def get_angles(self, points: torch.Tensor) -> "list[torch.Tensor]":
        """Return one tensor of per-triangle corner angles per letter.

        ``points`` is the concatenation of all shape point tensors; it is
        indexed with the face index tensors built by ``init_faces``.
        """
        angles_ = []
        for face, face_rolled in zip(self.faces, self.faces_roll_a):
            edges = points[face_rolled] - points[face]
            length = edges.norm(dim=-1)
            # The 1e-1 offset keeps the division finite for zero-length
            # edges (and the normalised edges strictly shorter than 1).
            edges = edges / (length + 1e-1)[:, :, None]
            edges_roll = torch.roll(edges, 1, 1)
            cosine = torch.einsum('ned,ned->ne', edges, edges_roll)
            angles_.append(torch.arccos(cosine))
        return angles_

    def get_letter_inds(self, letter_to_insert):
        """Return ``(first_id, last_id, count)`` of the letter's shape ids.

        Returns ``None`` when no group is paired with ``letter_to_insert``.
        NOTE(review): for a letter that occurs more than once in
        ``target_letter`` this always yields the first occurrence's group —
        confirm that is intended.
        """
        for group, letter in zip(self.shape_groups, self.target_letter):
            if letter == letter_to_insert:
                letter_inds = group.shape_ids
                return letter_inds[0], letter_inds[-1], len(letter_inds)
        return None

    def reset(self, device):
        """Capture the current angles as the (constant) reference angles."""
        # The reference must never carry autograd history, regardless of
        # the context ``reset`` is called from.
        with torch.no_grad():
            points = torch.cat([point.to(device) for point in self.parameters])
            self.angles = self.get_angles(points)

    def init_faces(self, device: torch.device) -> "list[torch.Tensor]":
        """Build, per letter, the int64 index tensor of triangles inside it.

        Raises:
            ValueError: if a character of ``target_letter`` has no matching
                entry in ``shape_groups``.
        """
        # Resolve every letter first so a bad configuration fails before
        # any geometry work is done.
        spans = []
        for c in self.target_letter:
            span = self.get_letter_inds(c)
            if span is None:
                raise ValueError(
                    f"letter {c!r} has no matching entry in shape_groups"
                )
            first, last, shapes_per_letter = span
            # int() accepts both plain ints and 0-d integer tensors.
            spans.append((c, int(first), int(last), shapes_per_letter))
        if not spans:
            return []

        # The triangulation and the triangle centroids do not depend on the
        # letter, so they are computed once for all letters.
        contours = [p.clone().detach().cpu().numpy() for p in self.parameters]
        all_points = np.concatenate(contours)
        simplices = Delaunay(all_points).simplices
        centroids = [Point(all_points[face].mean(0)) for face in simplices]

        faces_ = []
        for c, start_ind, end_ind, shapes_per_letter in spans:
            print(c, "start_ind: ", start_ind, ", end_ind: ", end_ind)
            holes = []
            if shapes_per_letter > 1:
                # NOTE(review): this slice excludes the shape at end_ind
                # itself, so a two-shape letter gets no hole — possibly an
                # off-by-one; kept as in the original pending confirmation.
                holes = contours[start_ind + 1:end_ind]
            poly = Polygon(contours[start_ind], holes=holes)
            # buffer(0) is applied to the polygon before containment tests.
            poly = poly.buffer(0)
            is_intersect = np.array(
                [poly.contains(centroid) for centroid in centroids], dtype=bool
            )
            faces_.append(
                torch.from_numpy(simplices[is_intersect]).to(device, dtype=torch.int64)
            )
        return faces_

    def __call__(self) -> torch.Tensor:
        """Return the summed MSE between current and reference angles."""
        points = torch.cat(self.parameters).to(self.device)
        angles = self.get_angles(points)
        loss_angles = 0
        for current, reference in zip(angles, self.angles):
            loss_angles += nnf.mse_loss(current, reference)
        return loss_angles