# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: SVGDreamer pipeline -- optimizes a set of SVG "particles" with
#              Vectorized Particle-based Score Distillation (VPSD), optionally
#              refining the phi (LoRA) model with reward feedback (ReFL).

import pathlib
from PIL import Image
from typing import AnyStr

import numpy as np
from tqdm.auto import tqdm
import torch
from torch.optim.lr_scheduler import LambdaLR
import torchvision
from torchvision import transforms

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.libs.solver.optim import get_optimizer
from pytorch_svgrender.painter.svgdreamer import Painter, PainterOptimizer
from pytorch_svgrender.painter.svgdreamer.painter_params import CosineWithWarmupLRLambda
from pytorch_svgrender.painter.live import xing_loss_fn
from pytorch_svgrender.painter.svgdreamer import VectorizedParticleSDSPipeline
from pytorch_svgrender.plt import plot_img
from pytorch_svgrender.utils.color_attrs import init_tensor_with_color
from pytorch_svgrender.token2attn.ptp_utils import view_images
from pytorch_svgrender.diffusers_warp import model2res

import ImageReward as RM


class SVGDreamerPipeline(ModelState):
    """Text-guided SVG synthesis / fine-tuning with VPSD.

    Supported ``args.x.style`` values: iconography, pixelart, low-poly,
    painting, sketch, ink.
    """

    def __init__(self, args):
        """Validate the config, create log directories, and build the models.

        Args:
            args: experiment config; ``args.x`` holds the SVGDreamer settings and
                ``args.diffuser`` the diffusion settings passed to the VPSD pipeline.
        """
        assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"]
        assert args.x.guidance.n_particle >= args.x.guidance.vsd_n_particle
        assert args.x.guidance.n_particle >= args.x.guidance.phi_n_particle
        assert args.x.guidance.n_phi_sample >= 1

        logdir_ = (f"sd{args.seed}"
                   f"-{'vpsd' if args.x.skip_sive else 'sive'}"
                   f"-{args.x.model_id}"
                   f"-{args.x.style}"
                   f"-P{args.x.num_paths}"
                   f"{'-RePath' if args.x.path_reinit.use else ''}")
        super().__init__(args, log_path_suffix=logdir_)

        # create log dir
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        self.ft_png_logs_dir = self.result_path / "ft_png_logs"
        self.ft_svg_logs_dir = self.result_path / "ft_svg_logs"
        self.sd_sample_dir = self.result_path / 'sd_samples'
        self.reinit_dir = self.result_path / "reinit_logs"
        self.init_stage_two_dir = self.result_path / "stage_two_init_logs"
        self.phi_samples_dir = self.result_path / "phi_sampling_logs"
        if self.accelerator.is_main_process:
            for log_dir in (self.png_logs_dir, self.svg_logs_dir,
                            self.ft_png_logs_dir, self.ft_svg_logs_dir,
                            self.sd_sample_dir, self.reinit_dir,
                            self.init_stage_two_dir, self.phi_samples_dir):
                log_dir.mkdir(parents=True, exist_ok=True)

        self.select_fpth = self.result_path / 'select_sample.png'

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        # seeded generator used when sampling from the phi model (ReFL)
        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)

        self.pipeline = VectorizedParticleSDSPipeline(args, args.diffuser, self.x_cfg.guidance, self.device)

        # load reward model (only needed for reward feedback learning)
        self.reward_model = None
        if self.x_cfg.guidance.phi_ReFL:
            self.reward_model = RM.load("ImageReward-v1.0",
                                        device=self.device,
                                        download_root=self.x_cfg.reward_path)

        self.style = self.x_cfg.style
        # pixel art disables the learning-rate schedules of both stages
        if self.style == "pixelart":
            self.x_cfg.lr_stage_one.lr_schedule = False
            self.x_cfg.lr_stage_two.lr_schedule = False

    def target_file_preprocess(self, tar_path: AnyStr):
        """Load an image file as a ``(1, 3, image_size, image_size)`` tensor on ``self.device``.

        The image is converted to RGB, resized to ``x_cfg.image_size`` and
        scaled to [0, 1] by ``ToTensor``.
        """
        process_comp = transforms.Compose([
            transforms.Resize(size=(self.x_cfg.image_size, self.x_cfg.image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
        ])

        # close the file handle promptly; this is called repeatedly during ReFL
        with Image.open(tar_path) as img:
            tar_pil = img.convert("RGB")
        target_img = process_comp(tar_pil)  # preprocess
        target_img = target_img.to(self.device)
        return target_img

    def SIVE_stage(self, text_prompt: str):
        """Stage one (SIVE): text -> image -> SVG. Not implemented yet.

        Expected to return ``(target_img, final_svg_path)``.

        Raises:
            NotImplementedError: always. Previously this returned ``None``, which made
                the caller fail with an opaque unpacking error.
        """
        # TODO: SIVE implementation
        raise NotImplementedError(
            "SIVE stage is not implemented yet; set `x.skip_sive=True` "
            "or pass an existing SVG file as `target_file`."
        )

    def painterly_rendering(self, text_prompt: str, target_file: AnyStr = None):
        """Optimize ``n_particle`` SVG renderers with VPSD for ``guidance.num_iter`` steps.

        Modes:
            1. ``skip_sive`` and no existing ``target_file``: start from scratch,
               with path colors initialised from random noise or ``color_init``.
            2. existing ``target_file`` (an SVG): load it and fine-tune it.
            3. otherwise: run ``SIVE_stage`` first (currently not implemented).

        Per-step logs, final SVGs (``finetune_final_p_{i}.svg``), a PNG of all
        particles, and optionally a video are written under ``self.result_path``.
        """
        # log prompts
        self.print(f"prompt: {text_prompt}")
        self.print(f"neg_prompt: {self.args.neg_prompt}\n")

        # for convenience
        im_size = self.x_cfg.image_size
        guidance_cfg = self.x_cfg.guidance
        n_particle = guidance_cfg.n_particle
        total_step = guidance_cfg.num_iter
        path_reinit = self.x_cfg.path_reinit

        init_from_target = bool(target_file and pathlib.Path(target_file).exists())
        if target_file and not init_from_target:
            self.print(f"target file {target_file} does not exist, ignore it.")

        # switch mode
        if self.x_cfg.skip_sive and not init_from_target:
            # mode 1: optimization with VPSD from scratch
            # randomly init
            self.print("optimization with VPSD from scratch...")
            if self.x_cfg.color_init == 'rand':
                target_img = torch.randn(1, 3, im_size, im_size)
                self.print("color: randomly init")
            else:
                target_img = init_tensor_with_color(self.x_cfg.color_init, 1, im_size, im_size)
                self.print(f"color: {self.x_cfg.color_init}")

            # log init target_img
            plot_img(target_img, self.result_path, fname='init_target_img')
            final_svg_path = None
        elif init_from_target:
            # mode 2: load the SVG file and finetune it
            self.print(f"load svg from {target_file} ...")
            self.print("SVG fine-tuning via VPSD...")
            final_svg_path = target_file
            if self.x_cfg.color_init == 'target_randn':
                # special order: init newly paths color use random color
                target_img = torch.randn(1, 3, im_size, im_size)
                self.print("color: randomly init")
            else:
                # load the SVG and init newly paths color use target_img
                # note: load_renderer rasterizes the SVG to 'target_img.png'
                target_img = None
        else:
            # mode 3: text-to-img-to-svg (two stage)
            target_img, final_svg_path = self.SIVE_stage(text_prompt)
            self.x_cfg.path_svg = final_svg_path
            self.print("\n SVG fine-tuning via VPSD...")
            plot_img(target_img, self.result_path, fname='init_target_img')

        # create svg renderer
        renderers = [self.load_renderer(final_svg_path) for _ in range(n_particle)]

        # randomly initialize the particles
        if self.x_cfg.skip_sive or init_from_target:
            if target_img is None:
                target_img = self.target_file_preprocess(self.result_path / 'target_img.png')
            for render in renderers:
                render.component_wise_path_init(gt=target_img, pred=None, init_type='random')

        # log init images
        for i, r in enumerate(renderers):
            init_imgs = r.init_image(stage=0, num_paths=self.x_cfg.num_paths)
            plot_img(init_imgs, self.init_stage_two_dir, fname=f"init_img_stage_two_{i}")

        # init renderer optimizer
        optimizers = []
        for renderer in renderers:
            optim_ = PainterOptimizer(renderer,
                                      self.style,
                                      guidance_cfg.num_iter,
                                      self.x_cfg.lr_stage_two,
                                      self.x_cfg.trainable_bg)
            optim_.init_optimizers()
            optimizers.append(optim_)

        # init phi_model optimizer
        phi_optimizer = get_optimizer('adamW',
                                      self.pipeline.phi_params,
                                      guidance_cfg.phi_lr,
                                      guidance_cfg.phi_optim)
        # init phi_model lr scheduler
        phi_scheduler = None
        schedule_cfg = guidance_cfg.phi_schedule
        if schedule_cfg.use:
            phi_lr_lambda = CosineWithWarmupLRLambda(num_steps=schedule_cfg.total_step,
                                                     warmup_steps=schedule_cfg.warmup_steps,
                                                     warmup_start_lr=schedule_cfg.warmup_start_lr,
                                                     warmup_end_lr=schedule_cfg.warmup_end_lr,
                                                     cosine_end_lr=schedule_cfg.cosine_end_lr)
            phi_scheduler = LambdaLR(phi_optimizer, lr_lambda=phi_lr_lambda, last_epoch=-1)

        self.print(f"-> Painter point Params: {len(renderers[0].get_point_parameters())}")
        self.print(f"-> Painter color Params: {len(renderers[0].get_color_parameters())}")
        self.print(f"-> Painter width Params: {len(renderers[0].get_width_parameters())}")

        # Both are logged every step; bind them up front so logging works even when
        # phi_update_step == 0 or no ReFL step has run yet.
        L_lora = torch.tensor(0.)
        L_reward = torch.tensor(0.)

        self.step = 0  # reset global step
        self.print(f"\ntotal VPSD optimization steps: {total_step}")
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                # set particles
                particles = [renderer.get_image() for renderer in renderers]
                raster_imgs = torch.cat(particles, dim=0)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(raster_imgs, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                L_guide, grad, latents, t_step = self.pipeline.variational_score_distillation(
                    raster_imgs,
                    self.step,
                    prompt=[text_prompt],
                    negative_prompt=self.args.neg_prompt,
                    grad_scale=guidance_cfg.grad_scale,
                    enhance_particle=guidance_cfg.particle_aug,
                    im_size=model2res(self.x_cfg.model_id)
                )

                # Xing Loss for Self-Interaction Problem.
                # Created on the model device and accumulated out-of-place so no
                # in-place CPU/accelerator tensor mixing happens.
                L_add = torch.tensor(0., device=self.device)
                if self.style == "iconography" or self.x_cfg.xing_loss.use:
                    for r in renderers:
                        L_add = L_add + xing_loss_fn(r.get_point_parameters()) * self.x_cfg.xing_loss.weight

                loss = L_guide + L_add

                # optimization
                for opt_ in optimizers:
                    opt_.zero_grad_()
                loss.backward()
                for opt_ in optimizers:
                    opt_.step_()

                # phi_model optimization
                for _ in range(guidance_cfg.phi_update_step):
                    L_lora = self.pipeline.train_phi_model(latents, guidance_cfg.phi_t, as_latent=True)

                    phi_optimizer.zero_grad()
                    L_lora.backward()
                    phi_optimizer.step()

                # reward learning
                if guidance_cfg.phi_ReFL and self.step % guidance_cfg.phi_sample_step == 0:
                    L_reward = self._phi_reward_feedback(text_prompt, guidance_cfg, phi_optimizer)

                # update the learning rate of the phi_model
                if phi_scheduler is not None:
                    phi_scheduler.step()

                # curve regularization
                for r in renderers:
                    r.clip_curve_shape()

                # re-init paths
                if self.step % path_reinit.freq == 0 and self.step < path_reinit.stop_step and self.step != 0:
                    for i, r in enumerate(renderers):
                        r.reinitialize_paths(path_reinit.use,  # on-off
                                             path_reinit.opacity_threshold,
                                             path_reinit.area_threshold,
                                             fpath=self.reinit_dir / f"reinit-{self.step}_p{i}.svg")

                # update lr
                if self.x_cfg.lr_stage_two.lr_schedule:
                    for opt_ in optimizers:
                        opt_.update_lr()

                # log pretrained model lr
                lr_str = "".join(f"{k}_lr: {lr:.4f}, " for k, lr in optimizers[0].get_lr().items())
                # log phi model lr
                cur_phi_lr = phi_optimizer.param_groups[0]['lr']
                lr_str += f"phi_lr: {cur_phi_lr:.3e}, "

                pbar.set_description(
                    lr_str +
                    f"t: {t_step.item():.2f}, "
                    f"L_total: {loss.item():.4f}, "
                    f"L_add: {L_add.item():.4e}, "
                    f"L_lora: {L_lora.item():.4f}, "
                    f"L_reward: {L_reward.item():.4f}, "
                    f"vpsd: {grad.item():.4e}"
                )

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    # save png
                    torchvision.utils.save_image(raster_imgs,
                                                 fp=self.ft_png_logs_dir / f'iter{self.step}.png')

                    # save svg
                    for i, r in enumerate(renderers):
                        r.pretty_save_svg(self.ft_svg_logs_dir / f"svg_iter{self.step}_p{i}.svg")

                self.step += 1
                pbar.update(1)

        # save final
        for i, r in enumerate(renderers):
            final_svg_path = self.result_path / f"finetune_final_p_{i}.svg"
            r.pretty_save_svg(final_svg_path)

        # Re-render so the PNG matches the final SVGs. The last in-loop raster
        # predates the final optimizer step and curve clipping, and it does not
        # exist at all when total_step == 0.
        with torch.no_grad():
            raster_imgs = torch.cat([r.get_image() for r in renderers], dim=0)
        torchvision.utils.save_image(raster_imgs, fp=self.result_path / 'all_particles.png')

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "svgdreamer_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")

    def _phi_reward_feedback(self, text_prompt: str, guidance_cfg, phi_optimizer):
        """Sample from the phi model, score the samples with ImageReward, and run one ReFL update per sample.

        Returns:
            The reward loss of the last update (for logging).
        """
        with torch.no_grad():
            phi_outputs = []
            phi_sample_paths = []
            for idx in range(guidance_cfg.n_phi_sample):
                phi_output = self.pipeline.sample(text_prompt,
                                                  num_inference_steps=guidance_cfg.phi_infer_step,
                                                  generator=self.g_device)
                sample_path = (self.phi_samples_dir / f'iter{idx}.png').as_posix()
                phi_output.images[0].save(sample_path)
                phi_sample_paths.append(sample_path)
                phi_outputs.append(np.array(phi_output.images[0]))

            # save all samples
            view_images(phi_outputs, save_image=True,
                        num_rows=max(len(phi_outputs) // 6, 1),
                        fp=self.phi_samples_dir / f'samples_iter{self.step}.png')

        ranking, rewards = self.reward_model.inference_rank(text_prompt, phi_sample_paths)
        self.print(f"ranking: {ranking}, reward score: {rewards}")

        # ImageReward returns per-sample lists: ranking[i] is the 1-based rank of
        # sample i and rewards[i] is sample i's score. The old code indexed paths by
        # rank but weighted them with rewards[k], which mismatched images and rewards.
        # NOTE(review): with a single sample the reward may come back as a bare
        # scalar, so normalise it to a list.
        if not isinstance(rewards, (list, tuple)):
            rewards = [rewards]

        # update in ranked order (best first), pairing each image with its own reward
        order = sorted(range(len(phi_sample_paths)), key=lambda i: rewards[i], reverse=True)
        L_reward = None
        for k in order:
            phi = self.target_file_preprocess(phi_sample_paths[k])
            L_reward = self.pipeline.train_phi_model_refl(phi, weight=rewards[k])

            phi_optimizer.zero_grad()
            L_reward.backward()
            phi_optimizer.step()
        return L_reward

    def load_renderer(self, path_svg=None):
        """Build a ``Painter``; when an SVG path is given, also rasterize it once to ``target_img.png``.

        The rasterization is skipped if ``target_img.png`` already exists in ``result_path``.
        """
        renderer = Painter(self.args.diffvg,
                           self.style,
                           self.x_cfg.num_segments,
                           self.x_cfg.segment_init,
                           self.x_cfg.radius,
                           self.x_cfg.image_size,
                           self.x_cfg.grid,
                           self.x_cfg.trainable_bg,
                           self.x_cfg.width,
                           path_svg=path_svg,
                           device=self.device)

        # if load a svg file, then rasterize it
        save_path = self.result_path / 'target_img.png'
        if path_svg is not None and (not save_path.exists()):
            canvas_width, canvas_height, shapes, shape_groups = renderer.load_svg(path_svg)
            render_img = renderer.render_image(canvas_width, canvas_height, shapes, shape_groups)
            torchvision.utils.save_image(render_img, fp=save_path)
        return renderer