File size: 5,208 Bytes
966ae59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: SVG Painter and its optimizer

from typing import Tuple

import omegaconf
import pydiffvg
import torch
import numpy as np

from pytorch_svgrender.diffvg_warp import DiffVGState
from pytorch_svgrender.utils import get_rgb_from_color


class Painter(DiffVGState):
    """Differentiable SVG painter built on top of ``DiffVGState``.

    Loads an SVG into pydiffvg shapes / shape groups, optionally re-initializes
    their fill colors, renders the scene over a white background, and exposes
    path points and colors as trainable tensors for the optimizer.
    """

    def __init__(self, device=None):
        super().__init__(device)
        self.device = device

        self.strokes_counter = 0  # number of paths currently loaded

    def init_shapes(self, path_svg, reinit_cfg: omegaconf.DictConfig = None):
        """Load shapes from ``path_svg`` and return the initial render.

        Args:
            path_svg: SVG file passed to ``pydiffvg.svg_to_scene`` (via ``load_svg``).
            reinit_cfg: optional config forwarded to :meth:`color_init`; when
                ``None`` the colors stored in the SVG are kept.

        Returns:
            The rendered image as a ``(1, 3, H, W)`` tensor, exactly as
            produced by :meth:`get_image`.
        """
        print(f"-> init svg from `{path_svg}` ...")
        self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = self.load_svg(path_svg)
        self.strokes_counter = len(self.shapes)

        # Re-init fill colors before the first render so the render reflects them.
        if reinit_cfg is not None:
            self.color_init(reinit_cfg)

        return self.get_image()

    def color_init(self, reinit_cfg: omegaconf.DictConfig):
        """Overwrite the fill color of every shape group according to ``reinit_cfg``.

        Does nothing unless ``reinit_cfg.reinit`` is truthy. ``reinit_cfg.reinit_color``
        selects the strategy:

        * ``'randn'``: an independent random RGB per group, drawn with
          ``np.random.random`` (uniform in [0, 1), despite the name);
        * ``'randn_all'``: one random RGB shared by all groups;
        * anything else: a color string resolved by ``get_rgb_from_color``.

        Alpha is always 1.0. Groups that had no fill before are given one as well.
        """
        if not reinit_cfg.reinit:
            return

        if reinit_cfg.reinit_color == 'randn':
            for group in self.shape_groups:
                color_val = np.random.random(size=3).tolist() + [1.0]
                group.fill_color = torch.FloatTensor(color_val)
        elif reinit_cfg.reinit_color == 'randn_all':
            color_val = np.random.random(size=3).tolist() + [1.0]
            # Same color values, but a separate tensor per group so each
            # group's color is optimized independently.
            for group in self.shape_groups:
                group.fill_color = torch.FloatTensor(color_val)
        else:
            rgb = get_rgb_from_color(str(reinit_cfg.reinit_color))
            color_val = list(rgb) + [1.0]
            for group in self.shape_groups:
                group.fill_color = torch.FloatTensor(color_val)

    def clip_curve_shape(self):
        """Project colors back into their valid range, in place.

        Fill colors are clamped to [0, 1] and their last (alpha) entry is forced
        to 1.0. Stroke colors, which :meth:`set_parameters` also makes trainable,
        are clamped to [0, 1]. Groups whose ``fill_color`` / ``stroke_color`` is
        ``None`` are skipped rather than raising.
        """
        for group in self.shape_groups:
            if group.fill_color is not None:
                group.fill_color.data.clamp_(0.0, 1.0)
                # force opacity
                group.fill_color.data[-1] = 1.0
            if group.stroke_color is not None:
                group.stroke_color.data.clamp_(0.0, 1.0)

    def get_image(self):
        """Render the current scene and return it as a ``(1, 3, H, W)`` tensor.

        The render is alpha-composited over a white background.
        """
        img = self.render_warp()
        # assumes render_warp returns an HWC tensor with alpha in channel 3 — TODO confirm
        opacity = img[:, :, 3:4]
        # Blend over white: rgb * a + 1 * (1 - a); (1 - a) broadcasts over the 3 channels.
        img = opacity * img[:, :, :3] + (1 - opacity)
        # Convert img from HWC to NCHW
        img = img.unsqueeze(0).permute(0, 3, 1, 2)
        return img.to(self.device)

    def set_parameters(self):
        """Enable gradients on path points and colors and collect them.

        Populates ``self.point_vars`` (every path's ``points``) and
        ``self.color_vars`` (every non-``None`` fill and stroke color).
        """
        self.point_vars = []
        # the strokes point optimization
        for path in self.shapes:
            path.points.requires_grad = True
            self.point_vars.append(path.points)

        # the strokes color optimization
        self.color_vars = []
        for group in self.shape_groups:
            if group.fill_color is not None:
                group.fill_color.requires_grad = True
                self.color_vars.append(group.fill_color)
            if group.stroke_color is not None:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)

    def get_point_parameters(self):
        """Return the point tensors collected by :meth:`set_parameters`."""
        return self.point_vars

    def get_color_parameters(self):
        """Return the color tensors collected by :meth:`set_parameters`."""
        return self.color_vars

    def pretty_save_svg(self, filename, width=None, height=None, shapes=None, shape_groups=None):
        """Save a scene to ``filename`` via ``save_svg``.

        Each argument left as ``None`` falls back to this painter's own canvas
        size / shapes / shape groups. ``use_gamma=False`` and ``background=None``
        are always passed.
        """
        width = self.canvas_width if width is None else width
        height = self.canvas_height if height is None else height
        shapes = self.shapes if shapes is None else shapes
        shape_groups = self.shape_groups if shape_groups is None else shape_groups

        self.save_svg(filename, width, height, shapes, shape_groups, use_gamma=False, background=None)

    def load_svg(self, path_svg):
        """Parse ``path_svg`` with ``pydiffvg.svg_to_scene``.

        Returns:
            ``(canvas_width, canvas_height, shapes, shape_groups)``.
        """
        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
        return canvas_width, canvas_height, shapes, shape_groups


class PainterOptimizer:
    """Two Adam optimizers driving a :class:`Painter`: one for points, one for colors."""

    def __init__(self, renderer: Painter, lr_cfg: omegaconf.DictConfig):
        self.renderer = renderer
        # Learning rates come from the config's ``point`` and ``color`` entries.
        self.point_lr = lr_cfg.point
        self.color_lr = lr_cfg.color
        # Built by init_optimizers(); None until then.
        self.point_optimizer = None
        self.color_optimizer = None

    def init_optimizers(self):
        """Make the renderer's tensors trainable and create both Adam optimizers."""
        self.renderer.set_parameters()
        point_group = {'params': self.renderer.get_point_parameters(), 'lr': self.point_lr}
        color_group = {'params': self.renderer.get_color_parameters(), 'lr': self.color_lr}
        self.point_optimizer = torch.optim.Adam([point_group])
        self.color_optimizer = torch.optim.Adam([color_group])

    def update_lr(self, step):
        """No learning-rate schedule is applied; ``step`` is ignored."""
        return None

    def zero_grad_(self):
        """Clear gradients in the point optimizer, then the color optimizer."""
        for optimizer in (self.point_optimizer, self.color_optimizer):
            optimizer.zero_grad()

    def step_(self):
        """Step the point optimizer, then the color optimizer."""
        for optimizer in (self.point_optimizer, self.color_optimizer):
            optimizer.step()

    def get_lr(self) -> Tuple[float, float]:
        """Return ``(point_lr, color_lr)`` as currently stored in each optimizer."""
        point_lr = self.point_optimizer.param_groups[0]['lr']
        color_lr = self.color_optimizer.param_groups[0]['lr']
        return point_lr, color_lr