Spaces:
Running
Running
File size: 5,208 Bytes
966ae59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: SVG Painter and its optimizer
from typing import Tuple
import omegaconf
import pydiffvg
import torch
import numpy as np
from pytorch_svgrender.diffvg_warp import DiffVGState
from pytorch_svgrender.utils import get_rgb_from_color
class Painter(DiffVGState):
    """Load an SVG scene, render it, and expose its optimizable tensors.

    The scene is stored on the instance as ``canvas_width``,
    ``canvas_height``, ``shapes`` and ``shape_groups`` once
    :meth:`init_shapes` has been called.  Rendering and SVG export rely on
    ``render_warp`` and ``save_svg``, which are not defined in this class
    (presumably inherited from ``DiffVGState`` -- confirm against the base).
    """

    def __init__(self, device=None):
        """Store ``device`` and forward it to the base-class initializer."""
        super().__init__(device)
        self.device = device
        self.strokes_counter = 0  # num of paths

    def init_shapes(self, path_svg, reinit_cfg: omegaconf.DictConfig = None):
        """Load the scene from ``path_svg`` and return its rendering.

        Args:
            path_svg: Path of the SVG file handed to :meth:`load_svg`.
            reinit_cfg: Optional config forwarded to :meth:`color_init`;
                when ``None`` the colors read from the file are kept.

        Returns:
            The image produced by :meth:`get_image` (NCHW, 3 channels).
        """
        print(f"-> init svg from `{path_svg}` ...")
        self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = self.load_svg(path_svg)
        self.strokes_counter = len(self.shapes)

        # re-init font color
        if reinit_cfg is not None:
            self.color_init(reinit_cfg)

        # Same compositing as get_image(); reuse it instead of duplicating it.
        return self.get_image()

    def color_init(self, reinit_cfg: omegaconf.DictConfig):
        """Overwrite the fill color of every shape group.

        Does nothing when ``reinit_cfg.reinit`` is falsy.  Otherwise
        ``reinit_cfg.reinit_color`` selects the mode:

        * ``'randn'``: an independent random RGB per group,
        * ``'randn_all'``: one random RGB shared by all groups,
        * anything else: the value is converted to a string and resolved
          through ``get_rgb_from_color``.

        Random components are drawn with ``np.random.random`` (uniform in
        ``[0, 1)``, despite the mode names).  Alpha is always set to 1.0.
        """
        if not reinit_cfg.reinit:
            return

        mode = reinit_cfg.reinit_color
        if mode == 'randn':
            for group in self.shape_groups:
                rgba = np.random.random(size=3).tolist() + [1.0]
                group.fill_color = torch.FloatTensor(rgba)
            return

        if mode == 'randn_all':
            rgba = np.random.random(size=3).tolist() + [1.0]
        else:
            rgba = list(get_rgb_from_color(str(mode))) + [1.0]

        # A fresh tensor per group so that the groups do not share storage.
        for group in self.shape_groups:
            group.fill_color = torch.FloatTensor(rgba)

    def clip_curve_shape(self):
        """Clamp fill colors to ``[0, 1]`` in place and force full opacity.

        Groups whose ``fill_color`` is ``None`` are skipped, matching the
        ``None`` handling in :meth:`set_parameters`.
        """
        # NOTE(review): stroke_color is registered for optimization in
        # set_parameters() but is not clamped here -- confirm this is intended.
        for group in self.shape_groups:
            fill_color = group.fill_color
            if fill_color is None:
                continue
            fill_color.data.clamp_(0.0, 1.0)
            # force opacity
            fill_color.data[-1] = 1.0

    def get_image(self):
        """Render the scene and composite it over a white background.

        The rendered image is indexed as HWC with alpha at channel 3.

        Returns:
            A tensor of shape ``(1, 3, H, W)`` moved to ``self.device``.
        """
        img = self.render_warp()
        opacity = img[:, :, 3:4]
        # Alpha-over-white: ``(1 - opacity)`` broadcasts across the three
        # color channels, so no explicit tensor of ones (and no device
        # argument for it) is required.
        img = opacity * img[:, :, :3] + (1 - opacity)
        # Convert img from HWC to NCHW
        img = img.unsqueeze(0)
        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
        return img

    def set_parameters(self):
        """Enable gradients on path points and group colors and collect them.

        Fills ``self.point_vars`` with every path's ``points`` tensor and
        ``self.color_vars`` with every non-``None`` fill and stroke color.
        """
        # the strokes point optimization
        self.point_vars = []
        for path in self.shapes:
            path.points.requires_grad = True
            self.point_vars.append(path.points)

        # the strokes color optimization
        self.color_vars = []
        for group in self.shape_groups:
            if group.fill_color is not None:
                group.fill_color.requires_grad = True
                self.color_vars.append(group.fill_color)
            if group.stroke_color is not None:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)

    def get_point_parameters(self):
        """Return the point tensors collected by :meth:`set_parameters`."""
        return self.point_vars

    def get_color_parameters(self):
        """Return the color tensors collected by :meth:`set_parameters`."""
        return self.color_vars

    def pretty_save_svg(self, filename, width=None, height=None, shapes=None, shape_groups=None):
        """Save a scene to ``filename`` through ``save_svg``.

        Every argument left as ``None`` falls back to the corresponding
        attribute of the current scene.  ``save_svg`` is called with
        ``use_gamma=False`` and ``background=None``.
        """
        width = self.canvas_width if width is None else width
        height = self.canvas_height if height is None else height
        shapes = self.shapes if shapes is None else shapes
        shape_groups = self.shape_groups if shape_groups is None else shape_groups

        self.save_svg(filename, width, height, shapes, shape_groups, use_gamma=False, background=None)

    def load_svg(self, path_svg):
        """Parse ``path_svg`` with ``pydiffvg.svg_to_scene``.

        Returns:
            The tuple ``(canvas_width, canvas_height, shapes, shape_groups)``.
        """
        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
        return canvas_width, canvas_height, shapes, shape_groups
class PainterOptimizer:
    """Drive two Adam optimizers over a renderer's points and colors.

    ``lr_cfg.point`` and ``lr_cfg.color`` provide the two learning rates.
    The optimizers are only created by :meth:`init_optimizers`; until then
    both optimizer attributes are ``None``.
    """

    def __init__(self, renderer: Painter, lr_cfg: omegaconf.DictConfig):
        """Remember the renderer and the learning rates read from ``lr_cfg``."""
        self.renderer = renderer
        self.point_lr = lr_cfg.point
        self.color_lr = lr_cfg.color
        self.point_optimizer = None
        self.color_optimizer = None

    def init_optimizers(self):
        """Collect the renderer's parameters and build both optimizers."""
        self.renderer.set_parameters()

        point_group = {
            'params': self.renderer.get_point_parameters(),
            'lr': self.point_lr,
        }
        self.point_optimizer = torch.optim.Adam([point_group])

        color_group = {
            'params': self.renderer.get_color_parameters(),
            'lr': self.color_lr,
        }
        self.color_optimizer = torch.optim.Adam([color_group])

    def update_lr(self, step):
        """Placeholder: no learning-rate schedule is applied."""
        return None

    def zero_grad_(self):
        """Clear gradients, points first and colors second."""
        for optimizer in (self.point_optimizer, self.color_optimizer):
            optimizer.zero_grad()

    def step_(self):
        """Apply one update, points first and colors second."""
        for optimizer in (self.point_optimizer, self.color_optimizer):
            optimizer.step()

    def get_lr(self) -> Tuple[float, float]:
        """Return ``(point_lr, color_lr)`` from each optimizer's first param group."""
        point_lr = self.point_optimizer.param_groups[0]['lr']
        color_lr = self.color_optimizer.param_groups[0]['lr']
        return point_lr, color_lr
|