hjc-owo
init repo
966ae59
raw
history blame contribute delete
No virus
7.73 kB
# -*- coding: utf-8 -*-
# Author: ximing
# Description: LIVE pipeline
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
import shutil
import subprocess
from collections.abc import Sequence
from pathlib import Path
from typing import AnyStr

from PIL import Image
from tqdm.auto import tqdm
import torch
from torchvision import transforms

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.live import Painter, PainterOptimizer, xing_loss_fn
from pytorch_svgrender.plt import plot_img, plot_couple
class LIVEPipeline(ModelState):
    """LIVE pipeline: fit an SVG to a raster image by adding paths in stages.

    Paths are added group by group according to ``x_cfg.path_schedule``.
    After each group is added, every group added so far is optimized jointly
    against a reconstruction loss plus a weighted self-intersection ("xing")
    loss. Snapshots, optional video frames and the final SVG are written
    under ``self.result_path``.
    """

    def __init__(self, args):
        """Create the run's log directories.

        Args:
            args: run configuration. This constructor reads ``args.seed``,
                ``args.x.image_size`` and ``args.x.num_paths`` (used to name
                the run directory) and ``args.mv`` (enables frame logging
                for a video).
        """
        logdir_ = f"sd{args.seed}" \
                  f"-im{args.x.image_size}" \
                  f"-P{args.x.num_paths}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dir
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

    def get_path_schedule(self, schedule_each):
        """Return the number of paths to add at each optimization stage.

        Args:
            schedule_each: for ``x_cfg.path_schedule == 'repeat'``, the number
                of paths added per stage; for ``'list'``, an explicit
                sequence of per-stage path counts.

        Returns:
            A list of per-stage path counts. For ``'repeat'`` the counts sum
            to ``x_cfg.num_paths``; when that is not a multiple of
            ``schedule_each``, a final smaller stage carries the remainder.

        Raises:
            TypeError: ``'list'`` schedule given something that is not a
                non-string sequence.
            NotImplementedError: unknown ``x_cfg.path_schedule``.
        """
        if self.x_cfg.path_schedule == 'repeat':
            num_stages, remainder = divmod(self.x_cfg.num_paths, schedule_each)
            schedule = [schedule_each] * int(num_stages)
            # Keep the leftover paths as a last, smaller stage so the total
            # number of optimized paths matches the configured num_paths.
            if remainder:
                schedule.append(remainder)
            return schedule
        elif self.x_cfg.path_schedule == 'list':
            # Validate the value that is actually returned.
            if isinstance(schedule_each, (str, bytes)) or not isinstance(schedule_each, Sequence):
                raise TypeError(
                    "'list' path_schedule expects a sequence of path counts, "
                    f"got {type(schedule_each).__name__}"
                )
            return list(schedule_each)
        else:
            raise NotImplementedError(
                f"unknown path_schedule: {self.x_cfg.path_schedule!r}"
            )

    def target_file_preprocess(self, tar_path):
        """Load an image file as a ``(1, C, H, W)`` tensor on ``self.device``.

        The image is converted to RGB and resized to
        ``(x_cfg.image_size, x_cfg.image_size)``.
        """
        process_comp = transforms.Compose([
            transforms.Resize(size=(self.x_cfg.image_size, self.x_cfg.image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
        ])

        # Close the file handle; convert() returns an independent image.
        with Image.open(tar_path) as tar_file:
            tar_pil = tar_file.convert("RGB")
        target_img = process_comp(tar_pil)  # preprocess
        return target_img.to(self.device)

    def _reconstruction_loss(self, raster_img, target_img, loss_weight):
        """Return the reconstruction loss between the render and the target.

        With ``x_cfg.use_l1_loss`` this is plain L1 and ``loss_weight`` is
        ignored. Otherwise the squared error is summed over dim 1 (the
        channel axis of the NCHW target), multiplied by ``loss_weight`` and
        averaged.
        """
        if self.x_cfg.use_l1_loss:
            return torch.nn.functional.l1_loss(raster_img, target_img)
        # default: MSE loss
        loss_mse = (raster_img - target_img) ** 2
        return (loss_mse.sum(1) * loss_weight).mean()

    def _export_video(self):
        """Encode the logged frames into ``live_rendering.mp4`` with ffmpeg.

        A missing ffmpeg binary or a non-zero exit code is reported via
        ``self.print`` rather than raised, so the run can still finish.
        """
        cmd = [
            "ffmpeg",
            "-framerate", f"{self.args.framerate}",
            "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
            "-vb", "20M",
            (self.result_path / "live_rendering.mp4").as_posix(),
        ]
        try:
            result = subprocess.run(cmd, check=False)
        except FileNotFoundError:
            self.print("ffmpeg not found; skipping video export.")
            return
        if result.returncode != 0:
            self.print(f"ffmpeg exited with code {result.returncode}; video may be incomplete.")

    def painterly_rendering(self, img_path: AnyStr):
        """Vectorize the image at ``img_path`` and write the results.

        Side effects:
            - copies the target file into ``self.result_path``;
            - writes an ``init_img_{stage}`` image each time paths are added;
            - every ``args.save_step`` steps (main process only) writes a
              target/render comparison PNG and an SVG snapshot;
            - writes ``final_svg.svg`` and, if ``args.mv`` is set, frames plus
              ``live_rendering.mp4``;
            - finally calls ``self.close``.
        """
        # load target file
        target_file = Path(img_path)
        assert target_file.exists(), f"{target_file} is not exist!"
        shutil.copy(target_file, self.result_path)  # copy target file
        target_img = self.target_file_preprocess(target_file.as_posix())
        self.print(f"load image file from: '{target_file.as_posix()}'")

        # log path_schedule
        path_schedule = self.get_path_schedule(self.x_cfg.schedule_each)
        self.print(f"path_schedule: {path_schedule}")

        renderer = Painter(target_img,
                           self.args.diffvg,
                           self.x_cfg.num_segments,
                           self.x_cfg.segment_init,
                           self.x_cfg.radius,
                           canvas_size=self.x_cfg.image_size,
                           trainable_bg=self.x_cfg.trainable_bg,
                           stroke=self.x_cfg.train_stroke,
                           stroke_width=self.x_cfg.width,
                           device=self.device)
        # first init center
        renderer.component_wise_path_init(pred=None, init_type=self.x_cfg.coord_init)

        num_iter = self.x_cfg.num_iter
        # One optimizer per stage; each is initialized when its stage starts.
        optimizer_list = [
            PainterOptimizer(renderer, num_iter, self.x_cfg.lr_base,
                             self.x_cfg.train_stroke, self.x_cfg.trainable_bg)
            for _ in range(len(path_schedule))
        ]

        pathn_record = []
        loss_weight_keep = 0
        loss_weight = 1

        total_step = len(path_schedule) * num_iter
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            for path_idx, pathn in enumerate(path_schedule):
                # record path
                pathn_record.append(pathn)
                # init graphic
                img = renderer.init_image(num_paths=pathn)
                plot_img(img, self.result_path, fname=f"init_img_{path_idx}")
                # rebuild optimizer
                optimizer_list[path_idx].init_optimizers()
                # All stages added so far keep being optimized together.
                active_optimizers = optimizer_list[:path_idx + 1]

                pbar.write(f"=> adding {pathn} paths, n_path: {sum(pathn_record)}, "
                           f"path_schedule: {self.x_cfg.path_schedule}")

                for t in range(num_iter):
                    raster_img = renderer.get_image(step=t).to(self.device)

                    if self.make_video and (t % self.args.framefreq == 0 or t == num_iter - 1):
                        plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                        self.frame_idx += 1

                    if self.x_cfg.use_distance_weighted_loss:
                        loss_weight = renderer.calc_distance_weight(loss_weight_keep)

                    # UDF Loss for Reconstruction
                    loss_recon = self._reconstruction_loss(raster_img, target_img, loss_weight)
                    # Xing Loss for Self-Interaction Problem
                    loss_xing = xing_loss_fn(renderer.get_point_parameters()) * self.x_cfg.xing_loss_weight
                    # total loss
                    loss = loss_recon + loss_xing

                    pbar.set_description(
                        f"lr: {optimizer_list[path_idx].get_lr():.4f}, "
                        f"L_total: {loss.item():.4f}, "
                        f"L_recon: {loss_recon.item():.4f}, "
                        f"L_xing: {loss_xing.item():.4f}"
                    )

                    # optimization
                    for optimizer in active_optimizers:
                        optimizer.zero_grad_()
                    loss.backward()
                    for optimizer in active_optimizers:
                        optimizer.step_()

                    renderer.clip_curve_shape()

                    if self.x_cfg.lr_schedule:
                        for optimizer in active_optimizers:
                            optimizer.update_lr()

                    if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                        plot_couple(target_img,
                                    raster_img,
                                    self.step,
                                    output_dir=self.png_logs_dir.as_posix(),
                                    fname=f"iter{self.step}")
                        renderer.save_svg(self.svg_logs_dir / f"svg_iter{self.step}.svg")

                    self.step += 1
                    pbar.update(1)

                # end a set of path optimization
                if self.x_cfg.use_distance_weighted_loss:
                    loss_weight_keep = loss_weight.detach().cpu().numpy() * 1
                # recalculate the coordinates for the new join path
                renderer.component_wise_path_init(pred=raster_img, init_type=self.x_cfg.coord_init)

        renderer.save_svg(self.result_path / "final_svg.svg")

        # Encode once, on the main process only.
        if self.make_video and self.accelerator.is_main_process:
            self._export_video()

        self.close(msg="painterly rendering complete.")