# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: the DiffSketcher pipeline (text-guided vector sketch synthesis)

import pathlib
from functools import partial

from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets.folder import is_image_file
from tqdm.auto import tqdm
import numpy as np
from skimage.color import rgb2gray
import diffusers

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.libs.metric.lpips_origin import LPIPS
from pytorch_svgrender.libs.metric.piq.perceptual import DISTS as DISTS_PIQ
from pytorch_svgrender.libs.metric.clip_score import CLIPScoreWrapper
from pytorch_svgrender.painter.diffsketcher import (
    Painter, SketchPainterOptimizer,
    Token2AttnMixinASDSPipeline, Token2AttnMixinASDSSDXLPipeline
)
from pytorch_svgrender.plt import plot_img, plot_couple
from pytorch_svgrender.painter.diffsketcher.sketch_utils import plt_attn
from pytorch_svgrender.painter.clipasso.sketch_utils import get_mask_u2net, fix_image_scale
from pytorch_svgrender.painter.diffsketcher.stroke_pruning import paths_pruning
from pytorch_svgrender.token2attn.attn_control import AttentionStore, EmptyControl
from pytorch_svgrender.token2attn.ptp_utils import view_images
from pytorch_svgrender.diffusers_warp import init_StableDiffusion_pipeline, model2res


class DiffSketcherPipeline(ModelState):

    def __init__(self, args):
        attn_log_ = ""
        if args.x.attention_init:
            attn_log_ = f"-tk{args.x.token_ind}" \
                        f"{'-XDoG' if args.x.xdog_intersec else ''}" \
                        f"-atc{args.x.attn_coeff}-tau{args.x.softmax_temp}"
        logdir_ = f"sd{args.seed}-im{args.x.image_size}" \
                  f"-P{args.x.num_paths}W{args.x.width}{'OP' if args.x.optim_opacity else 'BL'}" \
                  f"{attn_log_}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dirs
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        self.attn_logs_dir = self.result_path / "attn_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)
            self.attn_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        if self.x_cfg.model_id == "sdxl":
            # The default LSDSSDXLPipeline scheduler is EulerDiscreteScheduler.
            # When LSDSSDXLPipeline is called, scheduler.timesteps changes at step 4,
            # which breaks the SDS add_noise() call because a randomly sampled t
            # may no longer be in scheduler.timesteps. Use DPMSolver instead.
            custom_pipeline = Token2AttnMixinASDSSDXLPipeline
            custom_scheduler = diffusers.DPMSolverMultistepScheduler
            self.x_cfg.cross_attn_res = self.x_cfg.cross_attn_res * 2
        elif self.x_cfg.model_id == 'sd21':
            custom_pipeline = Token2AttnMixinASDSPipeline
            custom_scheduler = diffusers.DDIMScheduler
        else:  # sd14, sd15
            custom_pipeline = Token2AttnMixinASDSPipeline
            custom_scheduler = diffusers.DDIMScheduler

        self.diffusion = init_StableDiffusion_pipeline(
            self.x_cfg.model_id,
            custom_pipeline=custom_pipeline,
            custom_scheduler=custom_scheduler,
            device=self.device,
            local_files_only=not args.diffuser.download,
            force_download=args.diffuser.force_download,
            resume_download=args.diffuser.resume_download,
            ldm_speed_up=self.x_cfg.ldm_speed_up,
            enable_xformers=self.x_cfg.enable_xformers,
            gradient_checkpoint=self.x_cfg.gradient_checkpoint,
        )

        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)

        # init CLIP model and CLIP score wrapper
        self.cargs = self.x_cfg.clip
        self.clip_score_fn = CLIPScoreWrapper(self.cargs.model_name,
                                              device=self.device,
                                              visual_score=True,
                                              feats_loss_type=self.cargs.feats_loss_type,
                                              feats_loss_weights=self.cargs.feats_loss_weights,
                                              fc_loss_weight=self.cargs.fc_loss_weight)
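    # A minimal usage sketch (hedged): in the repo, `args` is assembled by the
    # project's config machinery, so the attributes referenced above (args.x,
    # args.seed, args.mv, args.diffuser, ...) are assumed to come from those
    # YAML configs rather than be built by hand:
    #
    #   pipe = DiffSketcherPipeline(args)
    #   pipe.painterly_rendering("a horse is drinking water by the lake")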
    def load_render(self, target_img, attention_map, mask=None):
        renderer = Painter(self.x_cfg,
                           self.args.diffvg,
                           num_strokes=self.x_cfg.num_paths,
                           num_segments=self.x_cfg.num_segments,
                           canvas_size=self.x_cfg.image_size,
                           device=self.device,
                           target_im=target_img,
                           attention_map=attention_map,
                           mask=mask)
        return renderer

    def extract_ldm_attn(self, prompts):
        # init controller
        controller = AttentionStore() if self.x_cfg.attention_init else EmptyControl()

        height = width = model2res(self.x_cfg.model_id)
        outputs = self.diffusion(prompt=[prompts],
                                 negative_prompt=self.args.neg_prompt,
                                 height=height,
                                 width=width,
                                 controller=controller,
                                 num_inference_steps=self.x_cfg.num_inference_steps,
                                 guidance_scale=self.x_cfg.guidance_scale,
                                 generator=self.g_device)

        target_file = self.result_path / "ldm_generated_image.png"
        view_images([np.array(img) for img in outputs.images], save_image=True, fp=target_file)

        if self.x_cfg.attention_init:
            """ldm cross-attention map"""
            cross_attention_maps, tokens = \
                self.diffusion.get_cross_attention([prompts],
                                                   controller,
                                                   res=self.x_cfg.cross_attn_res,
                                                   from_where=("up", "down"),
                                                   save_path=self.result_path / "cross_attn.png")

            self.print(f"the length of tokens is {len(tokens)}, select the {self.x_cfg.token_ind}-th token")
            # cross_attention_maps: [res, res, seq_len]
            self.print(f"origin cross_attn_map shape: {cross_attention_maps.shape}")
            # cross_attn_map: [res, res]
            cross_attn_map = cross_attention_maps[:, :, self.x_cfg.token_ind]
            self.print(f"select cross_attn_map shape: {cross_attn_map.shape}\n")
            cross_attn_map = 255 * cross_attn_map / cross_attn_map.max()
            # [res, res] -> [res, res, 3]
            cross_attn_map = cross_attn_map.unsqueeze(-1).expand(*cross_attn_map.shape, 3)
            # [res, res, 3] -> [1, 3, res, res]
            cross_attn_map = cross_attn_map.permute(2, 0, 1).unsqueeze(0)
            # [1, 3, res, res] -> [1, 3, image_size, image_size]
            cross_attn_map = F.interpolate(cross_attn_map, size=self.x_cfg.image_size, mode='bicubic')
            cross_attn_map = torch.clamp(cross_attn_map, min=0, max=255)
            # rgb to gray (skimage returns a numpy array)
            cross_attn_map = rgb2gray(cross_attn_map.squeeze(0).permute(1, 2, 0)).astype(np.float32)
            # ensure shape [image_size, image_size]
            if cross_attn_map.shape[-1] != self.x_cfg.image_size and cross_attn_map.shape[-2] != self.x_cfg.image_size:
                cross_attn_map = cross_attn_map.reshape(self.x_cfg.image_size, self.x_cfg.image_size)
            # to [0, 1]
            cross_attn_map = (cross_attn_map - cross_attn_map.min()) / (cross_attn_map.max() - cross_attn_map.min())

            """ldm self-attention map"""
            self_attention_maps, svd, vh_ = \
                self.diffusion.get_self_attention_comp([prompts],
                                                       controller,
                                                       res=self.x_cfg.self_attn_res,
                                                       from_where=("up", "down"),
                                                       img_size=self.x_cfg.image_size,
                                                       max_com=self.x_cfg.max_com,
                                                       save_path=self.result_path)

            # select the self-attention component(s)
            if self.x_cfg.mean_comp:
                self_attn = np.mean(vh_, axis=0)
                self.print(f"use the mean of {self.x_cfg.max_com} comps.")
            else:
                self_attn = vh_[self.x_cfg.comp_idx]
                self.print(f"select the {self.x_cfg.comp_idx}-th comp.")
            # to [0, 1]
            self_attn = (self_attn - self_attn.min()) / (self_attn.max() - self_attn.min())

            # visualize the final self-attention map
            self_attn_vis = np.copy(self_attn)
            self_attn_vis = self_attn_vis * 255
            self_attn_vis = np.repeat(np.expand_dims(self_attn_vis, axis=2), 3, axis=2).astype(np.uint8)
            view_images(self_attn_vis, save_image=True, fp=self.result_path / "self-attn-final.png")

            """attention map fusion"""
            attn_map = self.x_cfg.attn_coeff * cross_attn_map + (1 - self.x_cfg.attn_coeff) * self_attn
            # to [0, 1]
            attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())

            self.print(f"-> fusion attn_map: {attn_map.shape}")
        else:
            attn_map = None

        return target_file.as_posix(), attn_map
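    # The fusion step above is a convex combination of the two maps followed by
    # min-max rescaling. A self-contained numpy sketch of just that step
    # (array sizes and the 0.5 coefficient are illustrative, not repo defaults):
    #
    #   import numpy as np
    #   coeff = 0.5                                    # plays the role of attn_coeff
    #   cross = np.random.rand(224, 224).astype(np.float32)
    #   self_ = np.random.rand(224, 224).astype(np.float32)
    #   fused = coeff * cross + (1 - coeff) * self_
    #   fused = (fused - fused.min()) / (fused.max() - fused.min())  # -> [0, 1]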
    @property
    def clip_norm_(self):
        return transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                    (0.26862954, 0.26130258, 0.27577711))

    def clip_pair_augment(self,
                          x: torch.Tensor,
                          y: torch.Tensor,
                          im_res: int,
                          augments: str = "affine_norm",
                          num_aug: int = 4):
        # init augmentations
        augment_list = []
        if "affine" in augments:
            augment_list.append(
                transforms.RandomPerspective(fill=0, p=1.0, distortion_scale=0.5)
            )
            augment_list.append(
                transforms.RandomResizedCrop(im_res, scale=(0.8, 0.8), ratio=(1.0, 1.0))
            )
        augment_list.append(self.clip_norm_)  # CLIP Normalize

        # compose augmentations
        augment_compose = transforms.Compose(augment_list)
        # make augmentation pairs
        x_augs, y_augs = [self.clip_score_fn.normalize(x)], [self.clip_score_fn.normalize(y)]
        # repeat N times
        for n in range(num_aug):
            augmented_pair = augment_compose(torch.cat([x, y]))
            x_augs.append(augmented_pair[0].unsqueeze(0))
            y_augs.append(augmented_pair[1].unsqueeze(0))
        xs = torch.cat(x_augs, dim=0)
        ys = torch.cat(y_augs, dim=0)
        return xs, ys
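    # Why torch.cat([x, y]) above: torchvision transforms sample their random
    # parameters once per call and apply them to the whole batch, so stacking
    # the sketch and the target guarantees both receive the *same* perspective
    # warp and crop. A standalone sketch (shapes are illustrative):
    #
    #   x = torch.rand(1, 3, 224, 224)   # raster sketch
    #   y = torch.rand(1, 3, 224, 224)   # target image
    #   aug = transforms.RandomPerspective(fill=0, p=1.0, distortion_scale=0.5)
    #   pair = aug(torch.cat([x, y]))    # one random warp, shared by the pair
    #   x_aug, y_aug = pair[0:1], pair[1:2]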
    def painterly_rendering(self, prompt):
        # log prompts
        self.print(f"prompt: {prompt}")
        self.print(f"negative_prompt: {self.args.neg_prompt}\n")

        # init attention
        target_file, attention_map = self.extract_ldm_attn(prompt)

        timesteps_ = self.diffusion.scheduler.timesteps.cpu().numpy().tolist()
        self.print(f"{len(timesteps_)} denoising steps, {timesteps_}")

        perceptual_loss_fn = None
        if self.x_cfg.perceptual.coeff > 0:
            if self.x_cfg.perceptual.name == "lpips":
                lpips_loss_fn = LPIPS(net=self.x_cfg.perceptual.lpips_net).to(self.device)
                perceptual_loss_fn = partial(lpips_loss_fn.forward, return_per_layer=False, normalize=False)
            elif self.x_cfg.perceptual.name == "dists":
                perceptual_loss_fn = DISTS_PIQ()

        inputs, mask = self.get_target(target_file,
                                       self.x_cfg.image_size,
                                       self.result_path,
                                       self.x_cfg.u2net_path,
                                       self.x_cfg.mask_object,
                                       self.x_cfg.fix_scale,
                                       self.device)
        inputs = inputs.detach()  # inputs as GT
        self.print("inputs shape: ", inputs.shape)

        # load renderer
        renderer = self.load_render(inputs, attention_map, mask=mask)
        # init image
        img = renderer.init_image(stage=0)
        self.print("init_image shape: ", img.shape)
        plot_img(img, self.result_path, fname="init_sketch")
        # load optimizer
        optimizer = SketchPainterOptimizer(renderer,
                                           self.x_cfg.lr,
                                           self.x_cfg.optim_opacity,
                                           self.x_cfg.optim_rgba,
                                           self.x_cfg.color_lr,
                                           self.x_cfg.optim_width,
                                           self.x_cfg.width_lr)
        optimizer.init_optimizers()

        # log params
        self.print(f"-> Painter point Params: {len(renderer.get_points_params())}")
        self.print(f"-> Painter width Params: {len(renderer.get_width_parameters())}")
        self.print(f"-> Painter opacity Params: {len(renderer.get_color_parameters())}")

        total_iter = self.x_cfg.num_iter
        best_visual_loss, best_semantic_loss = 100, 100
        min_delta = 1e-6

        self.print(f"\ntotal optimization steps: {total_iter}")
        with tqdm(initial=self.step, total=total_iter,
                  disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_iter:
                raster_sketch = renderer.get_image().to(self.device)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_iter - 1):
                    plot_img(raster_sketch, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # ASDS loss
                sds_loss, grad = torch.tensor(0), torch.tensor(0)
                if self.step >= self.x_cfg.sds.warmup:
                    grad_scale = self.x_cfg.sds.grad_scale if self.step > self.x_cfg.sds.warmup else 0
                    sds_loss, grad = self.diffusion.score_distillation_sampling(
                        raster_sketch,
                        crop_size=self.x_cfg.sds.crop_size,
                        augments=self.x_cfg.sds.augmentations,
                        prompt=[prompt],
                        negative_prompt=self.args.neg_prompt,
                        guidance_scale=self.x_cfg.sds.guidance_scale,
                        grad_scale=grad_scale,
                        t_range=list(self.x_cfg.sds.t_range),
                    )

                # CLIP data augmentation
                raster_sketch_aug, inputs_aug = self.clip_pair_augment(
                    raster_sketch, inputs,
                    im_res=224,
                    augments=self.cargs.augmentations,
                    num_aug=self.cargs.num_aug
                )

                # CLIP visual loss
                total_visual_loss = torch.tensor(0)
                l_clip_fc, l_clip_conv, clip_conv_loss_sum = torch.tensor(0), [], torch.tensor(0)
                if self.x_cfg.clip.vis_loss > 0:
                    l_clip_fc, l_clip_conv = self.clip_score_fn.compute_visual_distance(
                        raster_sketch_aug, inputs_aug, clip_norm=False
                    )
                    clip_conv_loss_sum = sum(l_clip_conv)
                    total_visual_loss = self.x_cfg.clip.vis_loss * (clip_conv_loss_sum + l_clip_fc)

                # text-visual loss
                l_tvd = torch.tensor(0.)
                if self.cargs.text_visual_coeff > 0:
                    l_tvd = self.clip_score_fn.compute_text_visual_distance(
                        raster_sketch_aug, prompt
                    ) * self.cargs.text_visual_coeff

                # perceptual loss
                l_percep = torch.tensor(0.)
                if perceptual_loss_fn is not None:
                    l_perceptual = perceptual_loss_fn(raster_sketch, inputs).mean()
                    l_percep = l_perceptual * self.x_cfg.perceptual.coeff

                # total loss
                loss = sds_loss + total_visual_loss + l_tvd + l_percep

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                # update lr
                if self.x_cfg.lr_schedule:
                    optimizer.update_lr(self.step, self.x_cfg.decay_steps)

                # records
                pbar.set_description(
                    f"lr: {optimizer.get_lr():.2f}, "
                    f"l_total: {loss.item():.4f}, "
                    f"l_clip_fc: {l_clip_fc.item():.4f}, "
                    f"l_clip_conv({len(l_clip_conv)}): {clip_conv_loss_sum.item():.4f}, "
                    f"l_tvd: {l_tvd.item():.4f}, "
                    f"l_percep: {l_percep.item():.4f}, "
                    f"sds: {grad.item():.4e}"
                )

                # log raster and svg
                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    # log png
                    plot_couple(inputs,
                                raster_sketch,
                                self.step,
                                prompt=prompt,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    # log svg
                    renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")
                    # log cross attn
                    if self.x_cfg.log_cross_attn:
                        controller = AttentionStore()
                        _, _ = self.diffusion.get_cross_attention([prompt],
                                                                  controller,
                                                                  res=self.x_cfg.cross_attn_res,
                                                                  from_where=("up", "down"),
                                                                  save_path=self.attn_logs_dir / f"iter{self.step}.png")

                # logging the best raster images and SVG
                if self.step % self.args.eval_step == 0 and self.accelerator.is_main_process:
                    with torch.no_grad():
                        # visual metric
                        l_clip_fc, l_clip_conv = self.clip_score_fn.compute_visual_distance(
                            raster_sketch_aug, inputs_aug, clip_norm=False
                        )
                        loss_eval = sum(l_clip_conv) + l_clip_fc

                        cur_delta = loss_eval.item() - best_visual_loss
                        if abs(cur_delta) > min_delta and cur_delta < 0:
                            best_visual_loss = loss_eval.item()
                            best_iter_v = self.step
                            plot_couple(inputs,
                                        raster_sketch,
                                        best_iter_v,
                                        prompt=prompt,
                                        output_dir=self.result_path.as_posix(),
                                        fname="visual_best")
                            renderer.save_svg(self.result_path.as_posix(), "visual_best")

                        # semantic metric
                        loss_eval = self.clip_score_fn.compute_text_visual_distance(
                            raster_sketch_aug, prompt
                        )
                        cur_delta = loss_eval.item() - best_semantic_loss
                        if abs(cur_delta) > min_delta and cur_delta < 0:
                            best_semantic_loss = loss_eval.item()
                            best_iter_s = self.step
                            plot_couple(inputs,
                                        raster_sketch,
                                        best_iter_s,
                                        prompt=prompt,
                                        output_dir=self.result_path.as_posix(),
                                        fname="semantic_best")
                            renderer.save_svg(self.result_path.as_posix(), "semantic_best")

                # log attention, just once
                if self.step == 0 and self.x_cfg.attention_init and self.accelerator.is_main_process:
                    plt_attn(renderer.get_attn(),
                             renderer.get_thresh(),
                             inputs,
                             renderer.get_inds(),
                             (self.result_path / "attention_map.png").as_posix())

                self.step += 1
                pbar.update(1)

        # save the final result
        renderer.save_svg(self.svg_logs_dir.as_posix(), "final_svg_tmp")
        # stroke pruning
        if self.x_cfg.opacity_delta != 0:
            paths_pruning(self.svg_logs_dir / "final_svg_tmp.svg",
                          self.result_path / "final_result.svg",
                          self.x_cfg.opacity_delta)

        final_raster_sketch = renderer.get_image().to(self.device)
        plot_img(final_raster_sketch, output_dir=self.result_path, fname='final_render')

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "diffsketcher_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")
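    # The subprocess call above amounts to running, for example (the actual
    # frame rate and paths come from the run configuration, not these values):
    #
    #   ffmpeg -framerate 24 -i <result_path>/frame_logs/iter%d.png \
    #          -vb 20M <result_path>/diffsketcher_rendering.mp4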
    def get_target(self,
                   target_file,
                   image_size,
                   output_dir,
                   u2net_path,
                   mask_object,
                   fix_scale,
                   device):
        if not is_image_file(target_file):
            raise TypeError(f"{target_file} is not an image file.")

        target = Image.open(target_file)

        if target.mode == "RGBA":
            # create a white RGBA background
            new_image = Image.new("RGBA", target.size, "WHITE")
            # paste the image onto the background
            new_image.paste(target, (0, 0), target)
            target = new_image
        target = target.convert("RGB")

        # U2Net mask
        mask = target
        if mask_object:
            if pathlib.Path(u2net_path).exists():
                masked_im, mask = get_mask_u2net(target, output_dir, u2net_path, device)
                target = masked_im
            else:
                self.print(f"'{u2net_path}' does not exist, disabling target masking.")

        if fix_scale:
            target = fix_image_scale(target)

        # define image transforms
        transforms_ = []
        if target.size[0] != target.size[1]:
            transforms_.append(transforms.Resize((image_size, image_size)))
        else:
            transforms_.append(transforms.Resize(image_size))
            transforms_.append(transforms.CenterCrop(image_size))
        transforms_.append(transforms.ToTensor())

        # preprocess
        data_transforms = transforms.Compose(transforms_)
        target_ = data_transforms(target).unsqueeze(0).to(self.device)

        return target_, mask
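# A self-contained sketch of the preprocessing performed by get_target()
# ("input.png" and the 224px canvas are hypothetical; the repo reads its
# target from the diffusion output and uses x_cfg.image_size):
#
#   from PIL import Image
#   from torchvision import transforms
#
#   im = Image.open("input.png")
#   if im.mode == "RGBA":
#       bg = Image.new("RGBA", im.size, "WHITE")
#       bg.paste(im, (0, 0), im)   # alpha-composite onto a white background
#       im = bg
#   im = im.convert("RGB")
#   if im.size[0] != im.size[1]:   # non-square: force-resize both sides
#       t = transforms.Compose([transforms.Resize((224, 224)),
#                               transforms.ToTensor()])
#   else:                          # square: resize the short side, then crop
#       t = transforms.Compose([transforms.Resize(224),
#                               transforms.CenterCrop(224),
#                               transforms.ToTensor()])
#   target = t(im).unsqueeze(0)    # [1, 3, 224, 224]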