# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: CLIPDraw pipeline — optimizes differentiable SVG strokes so
# their rasterization matches a text prompt under the CLIP joint embedding.
import torch
from tqdm.auto import tqdm
from torchvision import transforms

import clip

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.clipdraw import Painter, PainterOptimizer
from pytorch_svgrender.plt import plot_img, plot_couple


class CLIPDrawPipeline(ModelState):
    """CLIPDraw: text-to-SVG synthesis by maximizing CLIP image/text similarity.

    Renders a set of differentiable vector strokes, applies random
    augmentations to the raster output, and back-propagates the negative
    cosine similarity between the CLIP embeddings of the augmented renders
    and the encoded text prompt.
    """

    def __init__(self, args):
        # The log-dir suffix encodes the run's seed, canvas size and path count.
        logdir_ = f"sd{args.seed}" \
                  f"-im{args.x.image_size}" \
                  f"-P{args.x.num_paths}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dirs for raster (png) and vector (svg) snapshots
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.clip, self.tokenize_fn = self.init_clip()

        # Built once here instead of on every drawing_augment call (it is
        # loop-invariant); the random transforms still draw fresh randomness
        # each time the pipeline is applied. Normalization constants are the
        # standard CLIP image statistics; 224 is the ViT-B/32 input size.
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))
        ])

    def init_clip(self):
        """Load the ViT-B/32 CLIP model and return ``(model, tokenizer)``."""
        model, _ = clip.load('ViT-B/32', self.device, jit=False)
        return model, clip.tokenize

    def drawing_augment(self, image):
        """Encode ``num_aug`` randomly augmented views of *image* with CLIP.

        Args:
            image: rendered raster batch fed to the augmentation pipeline
                (assumed (1, C, H, W) — TODO confirm against Painter.get_image).

        Returns:
            CLIP visual features, one row per augmented view.
        """
        # image augmentation transformation (PERF401: comprehension over append loop)
        im_batch = torch.cat(
            [self.augment_trans(image) for _ in range(self.x_cfg.num_aug)]
        )
        # clip visual encoding
        image_features = self.clip.encode_image(im_batch)
        return image_features

    def painterly_rendering(self, prompt):
        """Run the CLIPDraw optimization loop for *prompt*.

        Saves periodic png/svg snapshots, optionally collects video frames,
        and writes the final SVG (plus an mp4 when video logging is on).
        """
        self.print(f"prompt: {prompt}")

        # text prompt encoding — the text tower is frozen, so no grads needed
        text_tokenize = self.tokenize_fn(prompt).to(self.device)
        with torch.no_grad():
            text_features = self.clip.encode_text(text_tokenize)

        # init SVG Painter
        renderer = Painter(self.x_cfg,
                           self.args.diffvg,
                           num_strokes=self.x_cfg.num_paths,
                           canvas_size=self.x_cfg.image_size,
                           device=self.device)
        img = renderer.init_image(stage=0)
        self.print("init_image shape: ", img.shape)
        plot_img(img, self.result_path, fname="init_img")

        # init painter optimizer
        optimizer = PainterOptimizer(renderer,
                                     self.x_cfg.lr,
                                     self.x_cfg.width_lr,
                                     self.x_cfg.color_lr)
        optimizer.init_optimizers()

        total_step = self.x_cfg.num_iter
        with tqdm(initial=self.step, total=total_step,
                  disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                rendering = renderer.get_image(self.step).to(self.device)

                # video frame logging (always capture the final frame)
                if self.make_video and (self.step % self.args.framefreq == 0
                                        or self.step == total_step - 1):
                    plot_img(rendering, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # data augmentation
                aug_svg_batch = self.drawing_augment(rendering)

                # maximize cosine similarity between each augmented view and
                # the prompt, accumulated as a negative loss
                loss = torch.tensor(0., device=self.device)
                for n in range(self.x_cfg.num_aug):
                    loss -= torch.cosine_similarity(
                        text_features, aug_svg_batch[n:n + 1], dim=1
                    ).mean()

                pbar.set_description(
                    f"lr: {optimizer.get_lr():.3f}, "
                    f"L_train: {loss.item():.4f}"
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                # keep stroke parameters inside their valid range
                renderer.clip_curve_shape()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr(self.step)

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(img,
                                rendering,
                                self.step,
                                prompt=prompt,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")

                self.step += 1
                pbar.update(1)

        renderer.save_svg(self.result_path.as_posix(), "final_svg")

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-y",  # overwrite an existing mp4 instead of hanging on an interactive prompt
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "clipdraw_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")