# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: StyleCLIPDraw pipeline: text-driven vector sketch synthesis with style transfer.

import shutil
from pathlib import Path

import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from tqdm.auto import tqdm
import clip

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.style_clipdraw import (
    Painter, PainterOptimizer, VGG16Extractor, StyleLoss, sample_indices
)
from pytorch_svgrender.plt import plot_img, plot_couple


class StyleCLIPDrawPipeline(ModelState):

    def __init__(self, args):
        logdir_ = f"sd{args.seed}" \
                  f"-P{args.x.num_paths}" \
                  f"-style{args.x.style_strength}" \
                  f"-n{args.x.num_aug}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dirs
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.clip, self.tokenize_fn = self.init_clip()

        self.style_extractor = VGG16Extractor(space="normal").to(self.device)
        self.style_loss = StyleLoss()

    def init_clip(self):
        model, _ = clip.load('ViT-B/32', self.device, jit=False)
        return model, clip.tokenize

    def drawing_augment(self, image):
        # image augmentation transformation
        augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))
        ])

        img_augs = []
        for n in range(self.x_cfg.num_aug):
            img_augs.append(augment_trans(image))
        im_batch = torch.cat(img_augs)

        # CLIP visual encoding
        image_features = self.clip.encode_image(im_batch)

        return image_features

    def style_file_preprocess(self, style_file):
        process_comp = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
            # rescale pixel values from [0, 1] to [0.5, 1]
            transforms.Lambda(lambda t: (t + 1) / 2),
        ])
        style_file = process_comp(style_file)
        style_file = style_file.to(self.device)
        return style_file

    def painterly_rendering(self, prompt, style_fpath):
        # load style file
        style_path = Path(style_fpath)
        assert style_path.exists(), f"{style_fpath} does not exist!"
self.print(f"load style file from: {style_path.as_posix()}") style_pil = Image.open(style_path.as_posix()).convert("RGB") style_img = self.style_file_preprocess(style_pil) shutil.copy(style_fpath, self.result_path) # copy style file # extract style features from style image feat_style = None for i in range(5): with torch.no_grad(): # r is region of interest (mask) feat_e = self.style_extractor.forward_samples_hypercolumn(style_img, samps=1000) feat_style = feat_e if feat_style is None else torch.cat((feat_style, feat_e), dim=2) # text prompt encoding self.print(f"prompt: {prompt}") text_tokenize = self.tokenize_fn(prompt).to(self.device) with torch.no_grad(): text_features = self.clip.encode_text(text_tokenize) renderer = Painter(self.x_cfg, self.args.diffvg, num_strokes=self.x_cfg.num_paths, canvas_size=self.x_cfg.image_size, device=self.device) img = renderer.init_image(stage=0) self.print("init_image shape: ", img.shape) plot_img(img, self.result_path, fname="init_img") optimizer = PainterOptimizer(renderer, self.x_cfg.lr, self.x_cfg.width_lr, self.x_cfg.color_lr) optimizer.init_optimizers() style_weight = 4 * (self.x_cfg.style_strength / 100) self.print(f'style_weight: {style_weight}') total_step = self.x_cfg.num_iter with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar: while self.step < total_step: rendering = renderer.get_image(self.step).to(self.device) if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1): plot_img(rendering, self.frame_log_dir, fname=f"iter{self.frame_idx}") self.frame_idx += 1 rendering_aug = self.drawing_augment(rendering) loss = torch.tensor(0., device=self.device) # do clip optimization if self.step < 0.9 * total_step: for n in range(self.x_cfg.num_aug): loss -= torch.cosine_similarity(text_features, rendering_aug[n:n + 1], dim=1).mean() # do style optimization # extract style features based on the approach from STROTSS [Kolkin et al., 2019]. feat_content = self.style_extractor(rendering) xx, xy = sample_indices(feat_content[0], feat_style) np.random.shuffle(xx) np.random.shuffle(xy) L_style = self.style_loss.forward(feat_content, feat_content, feat_style, [xx, xy], 0) loss += L_style * style_weight pbar.set_description( f"lr: {optimizer.get_lr():.3f}, " f"L_train: {loss.item():.4f}, " f"L_style: {L_style.item():.4f}" ) # optimization optimizer.zero_grad_() loss.backward() optimizer.step_() renderer.clip_curve_shape() if self.x_cfg.lr_schedule: optimizer.update_lr(self.step) if self.step % self.args.save_step == 0 and self.accelerator.is_main_process: plot_couple(style_img, rendering, self.step, prompt=prompt, output_dir=self.png_logs_dir.as_posix(), fname=f"iter{self.step}") renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}") self.step += 1 pbar.update(1) plot_couple(style_img, rendering, self.step, prompt=prompt, output_dir=self.result_path.as_posix(), fname=f"final_iter") renderer.save_svg(self.result_path.as_posix(), "final_svg") if self.make_video: from subprocess import call call([ "ffmpeg", "-framerate", f"{self.args.framerate}", "-i", (self.frame_log_dir / "iter%d.png").as_posix(), "-vb", "20M", (self.result_path / "styleclipdraw_rendering.mp4").as_posix() ]) self.close(msg="painterly rendering complete.")