# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: SVG Painter and its optimizer

from typing import Optional, Tuple

import omegaconf
import pydiffvg
import torch
import numpy as np

from pytorch_svgrender.diffvg_warp import DiffVGState
from pytorch_svgrender.utils import get_rgb_from_color


class Painter(DiffVGState):
    """Load an SVG through pydiffvg and expose its points/colors for optimization.

    Rendering (`render_warp`) and saving (`save_svg`) are not defined here;
    they are expected to come from `DiffVGState`.
    """

    def __init__(self, device=None):
        super().__init__(device)
        self.device = device
        self.strokes_counter = 0  # num of paths

    def init_shapes(self, path_svg, reinit_cfg: Optional[omegaconf.DictConfig] = None):
        """Load the scene from `path_svg` and return its first render.

        Args:
            path_svg: path of the SVG file handed to `load_svg`.
            reinit_cfg: optional config; when given it is forwarded to
                `color_init` to re-initialize the fill colors before rendering.

        Returns:
            The rendered image composited on white, laid out as NCHW with a
            batch dimension of 1.
        """
        print(f"-> init svg from `{path_svg}` ...")
        self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = self.load_svg(path_svg)
        self.strokes_counter = len(self.shapes)

        # optionally re-init the fill colors
        if reinit_cfg is not None:
            self.color_init(reinit_cfg)

        return self._render_on_white()

    def color_init(self, reinit_cfg: omegaconf.DictConfig):
        """Overwrite the fill color of every shape group with an opaque color.

        Does nothing unless `reinit_cfg.reinit` is truthy. The mode is read
        from `reinit_cfg.reinit_color`:

        * ``'randn'``     - an independent random RGB per group,
        * ``'randn_all'`` - one random RGB shared by all groups,
        * anything else   - converted to RGB via `get_rgb_from_color`.

        Random channels are drawn with `np.random.random`; alpha is always 1.0.
        """
        if not reinit_cfg.reinit:
            return

        mode = reinit_cfg.reinit_color

        if mode == 'randn':
            for group in self.shape_groups:
                group.fill_color = torch.FloatTensor(self._random_opaque_rgba())
            return

        if mode == 'randn_all':
            color_val = self._random_opaque_rgba()
        else:
            color_val = list(get_rgb_from_color(str(mode))) + [1.0]

        # Every group gets its own tensor so the colors stay independent
        # parameters even though they start from the same value.
        for group in self.shape_groups:
            group.fill_color = torch.FloatTensor(color_val)

    @staticmethod
    def _random_opaque_rgba():
        """Return `[r, g, b, 1.0]` with r, g, b drawn from `np.random.random`."""
        return np.random.random(size=3).tolist() + [1.0]

    def clip_curve_shape(self):
        """Clamp fill colors to [0, 1] in place and force their alpha to 1."""
        for group in self.shape_groups:
            # `set_parameters` treats `fill_color` as optional, so do the same
            # here instead of failing on groups that have no fill.
            if group.fill_color is None:
                continue
            group.fill_color.data.clamp_(0.0, 1.0)
            # force opacity
            group.fill_color.data[-1] = 1.0

    def get_image(self):
        """Render the current shapes on a white background.

        Returns:
            The composited image, laid out as NCHW with a batch dimension of 1
            and moved to `self.device`.
        """
        return self._render_on_white()

    def _render_on_white(self):
        """Render via `render_warp`, composite on white and convert to NCHW.

        The render is indexed as H x W x 4 with opacity in the last channel.
        """
        img = self.render_warp()
        opacity = img[:, :, 3:4]
        # White background: 1 * (1 - opacity). Broadcasting the single opacity
        # channel over RGB gives the same values as multiplying by an explicit
        # all-ones image, without allocating one on a possibly different device.
        img = opacity * img[:, :, :3] + (1 - opacity)
        # HWC -> NHWC -> NCHW
        img = img.unsqueeze(0)
        img = img.permute(0, 3, 1, 2).to(self.device)
        return img

    def set_parameters(self):
        """Enable gradients on path points and group colors and collect them.

        Populates `self.point_vars` with every path's `points` tensor and
        `self.color_vars` with each group's fill and stroke colors that are
        not None.
        """
        # the strokes point optimization
        self.point_vars = []
        for path in self.shapes:
            path.points.requires_grad = True
            self.point_vars.append(path.points)

        # the strokes color optimization
        self.color_vars = []
        for group in self.shape_groups:
            if group.fill_color is not None:
                group.fill_color.requires_grad = True
                self.color_vars.append(group.fill_color)
            if group.stroke_color is not None:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)

    def get_point_parameters(self):
        """Return the point tensors collected by `set_parameters`."""
        return self.point_vars

    def get_color_parameters(self):
        """Return the color tensors collected by `set_parameters`."""
        return self.color_vars

    def pretty_save_svg(self, filename, width=None, height=None, shapes=None, shape_groups=None):
        """Save an SVG via `save_svg`, defaulting omitted arguments to the painter's state.

        `save_svg` is always called with `use_gamma=False` and `background=None`.
        """
        width = self.canvas_width if width is None else width
        height = self.canvas_height if height is None else height
        shapes = self.shapes if shapes is None else shapes
        shape_groups = self.shape_groups if shape_groups is None else shape_groups

        self.save_svg(filename, width, height, shapes, shape_groups, use_gamma=False, background=None)

    def load_svg(self, path_svg):
        """Parse `path_svg` with `pydiffvg.svg_to_scene`.

        Returns:
            The tuple `(canvas_width, canvas_height, shapes, shape_groups)`.
        """
        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
        return canvas_width, canvas_height, shapes, shape_groups


class PainterOptimizer:
    """Hold two Adam optimizers, one for path points and one for colors.

    `init_optimizers` must be called before `zero_grad_`, `step_` or `get_lr`;
    until then both optimizers are None.
    """

    def __init__(self, renderer: Painter, lr_cfg: omegaconf.DictConfig):
        self.renderer = renderer
        # learning rates are read from `lr_cfg.point` and `lr_cfg.color`
        self.point_lr = lr_cfg.point
        self.color_lr = lr_cfg.color
        self.point_optimizer, self.color_optimizer = None, None

    def init_optimizers(self):
        """Collect the renderer's parameters and build both Adam optimizers."""
        self.renderer.set_parameters()
        self.point_optimizer = torch.optim.Adam([
            {'params': self.renderer.get_point_parameters(), 'lr': self.point_lr}])
        self.color_optimizer = torch.optim.Adam([
            {'params': self.renderer.get_color_parameters(), 'lr': self.color_lr}])

    def update_lr(self, step):
        """No-op: `step` is ignored and the learning rates are left unchanged."""
        pass

    def zero_grad_(self):
        """Clear the gradients held by both optimizers."""
        self.point_optimizer.zero_grad()
        self.color_optimizer.zero_grad()

    def step_(self):
        """Apply one update step with both optimizers (points first, then colors)."""
        self.point_optimizer.step()
        self.color_optimizer.step()

    def get_lr(self) -> Tuple[float, float]:
        """Return `(point_lr, color_lr)` read from each optimizer's first param group."""
        return self.point_optimizer.param_groups[0]['lr'], self.color_optimizer.param_groups[0]['lr']