import matplotlib.pyplot as plt
import numpy as np
import pydiffvg
import torch
from PIL import Image
from pytorch_svgrender.painter.clipascene import u2net_utils
from pytorch_svgrender.painter.clipasso.u2net import U2NET
from scipy import ndimage
from skimage import morphology
from skimage.measure import label
from skimage.transform import resize
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid


def plot_attn_dino(attn, threshold_map, inputs, inds, output_path):
    # currently supports one image (and not a batch)
    plt.figure(figsize=(10, 5))

    plt.subplot(2, attn.shape[0] + 2, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, 2)
    plt.imshow(attn.sum(0).numpy(), interpolation='nearest')
    plt.title("attn map sum")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 3)
    plt.imshow(threshold_map[-1].numpy(), interpolation='nearest')
    plt.title("prob sum")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 4)
    plt.imshow(threshold_map[:-1].sum(0).numpy(), interpolation='nearest')
    plt.title("thresh sum")
    plt.axis("off")

    # one panel per attention map (top row) and its threshold map (bottom row)
    for i in range(attn.shape[0]):
        plt.subplot(2, attn.shape[0] + 2, i + 3)
        plt.imshow(attn[i].numpy())
        plt.axis("off")
        plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 1 + i + 4)
        plt.imshow(threshold_map[i].numpy())
        plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_attn_clip(attn, threshold_map, inputs, inds, output_path):
    # currently supports one image (and not a batch)
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 3, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
    plt.title("attn map")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    threshold_map_ = (threshold_map - threshold_map.min()) / \
                     (threshold_map.max() - threshold_map.min())
    plt.imshow(threshold_map_, interpolation='nearest', vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_attn(attn, threshold_map, inputs, inds, output_path, saliency_model):
    if saliency_model == "dino":
        plot_attn_dino(attn, threshold_map, inputs, inds, output_path)
    elif saliency_model == "clip":
        plot_attn_clip(attn, threshold_map, inputs, inds, output_path)


def fix_image_scale(im):
    # center the image on a white square canvas with a small margin
    im_np = np.array(im) / 255
    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20
    new_background = np.ones((max_len, max_len, 3))
    y, x = max_len // 2 - height // 2, max_len // 2 - width // 2
    new_background[y: y + height, x: x + width] = im_np
    new_background = (new_background / new_background.max() * 255).astype(np.uint8)
    new_im = Image.fromarray(new_background)
    return new_im


def get_size_of_largest_cc(binary_im):
    # pixel count of the largest connected component, excluding the background
    labels, num = label(binary_im, background=0, return_num=True)
    _, counts = np.unique(labels, return_counts=True)
    args = np.argsort(counts)[::-1]
    return counts[args][1]  # counts[args][0] is the background component


def get_num_cc(binary_im):
    _, num = label(binary_im, background=0, return_num=True)
    return num
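
# Usage sketch for fix_image_scale (a hedged example, not part of the original
# module; the file names below are placeholders):
#
#   from PIL import Image
#   im = Image.open("examples/horse.png").convert("RGB")   # any rectangular RGB image
#   square_im = fix_image_scale(im)                        # white square canvas, ~20 px margin
#   square_im.save("examples/horse_square.png")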
def get_obj_bb(binary_im):
    # tight bounding box (x0, x1, y0, y1) of the non-zero (foreground) pixels
    y = np.where(binary_im != 0)[0]
    x = np.where(binary_im != 0)[1]
    x0, x1, y0, y1 = x.min(), x.max(), y.min(), y.max()
    return x0, x1, y0, y1


def cut_and_resize(im, x0, x1, y0, y1, new_height, new_width):
    # crop the object, rescale it, and paste it back at the centre of an empty canvas
    cut_obj = im[y0: y1, x0: x1]
    resized_obj = resize(cut_obj, (new_height, new_width))
    new_mask = np.zeros(im.shape)
    center_y_new = int(new_height / 2)
    center_x_new = int(new_width / 2)
    center_targ_y = int(new_mask.shape[0] / 2)
    center_targ_x = int(new_mask.shape[1] / 2)
    startx, starty = center_targ_x - center_x_new, center_targ_y - center_y_new
    new_mask[starty: starty + resized_obj.shape[0], startx: startx + resized_obj.shape[1]] = resized_obj
    return new_mask


def get_mask_u2net(pil_im, output_dir, u2net_path, resize_obj=0, preprocess=False, device="cpu"):
    w, h = pil_im.size[0], pil_im.size[1]

    test_salobj_dataset = u2net_utils.SalObjDataset(imgs_list=[pil_im],
                                                    lbl_name_list=[],
                                                    transform=transforms.Compose([u2net_utils.RescaleT(320),
                                                                                  u2net_utils.ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1)
    input_im_trans = next(iter(test_salobj_dataloader))

    net = U2NET(3, 1)
    net.load_state_dict(torch.load(u2net_path, map_location=device))
    net.to(device)
    net.eval()

    with torch.no_grad():
        input_im_trans = input_im_trans.type(torch.FloatTensor)
        # d1 is the fused full-resolution saliency map; d2..d7 are side outputs
        d1, d2, d3, d4, d5, d6, d7 = net(input_im_trans.to(device))

    # normalize the saliency map to [0, 1] and binarize at 0.5
    pred = d1[:, 0, :, :]
    pred = (pred - pred.min()) / (pred.max() - pred.min())
    predict = pred
    predict[predict < 0.5] = 0
    predict[predict >= 0.5] = 1

    if preprocess:
        # dilate the binary mask and return it at the original resolution
        predict = torch.tensor(
            ndimage.binary_dilation(predict[0].cpu().numpy(), structure=np.ones((11, 11))).astype(int)).unsqueeze(0)
        mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
        mask = mask.cpu().numpy()
        max_val = mask.max()
        mask[mask > max_val / 2] = 255
        mask = mask.astype(np.uint8)
        mask = resize(mask, (h, w), anti_aliasing=False)
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        return mask

    mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
    mask = mask.cpu().numpy()
    mask = resize(mask, (h, w), anti_aliasing=False)
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    im.save(output_dir / "mask.png")

    im_np = np.array(pil_im)
    im_np = im_np / im_np.max()

    if resize_obj:
        # optionally rescale a small, single-component object to ~70% of the canvas
        params = {}
        mask_np = mask[:, :, 0].astype(int)
        target_np = im_np
        min_size = int(get_size_of_largest_cc(mask_np) / 3)
        mask_np2 = morphology.remove_small_objects((mask_np > 0), min_size=min_size).astype(int)
        num_cc = get_num_cc(mask_np2)

        mask_np3 = np.ones((h, w, 3))
        mask_np3[:, :, 0] = mask_np2
        mask_np3[:, :, 1] = mask_np2
        mask_np3[:, :, 2] = mask_np2

        x0, x1, y0, y1 = get_obj_bb(mask_np2)
        im_width, im_height = x1 - x0, y1 - y0
        max_size = max(im_width, im_height)
        target_size = int(min(h, w) * 0.7)

        if max_size < target_size and num_cc == 1:
            if im_width > im_height:
                new_width, new_height = target_size, int((target_size / im_width) * im_height)
            else:
                new_width, new_height = int((target_size / im_height) * im_width), target_size
            mask = cut_and_resize(mask_np3, x0, x1, y0, y1, new_height, new_width)
            target_np = target_np / target_np.max()
            im_np = cut_and_resize(target_np, x0, x1, y0, y1, new_height, new_width)
            params["original_center_y"] = (y0 + (y1 - y0) / 2) / h
            params["original_center_x"] = (x0 + (x1 - x0) / 2) / w
            params["scale_w"] = new_width / im_width
            params["scale_h"] = new_height / im_height

        np.save(output_dir / "resize_params.npy", params)

    # white out the background and return the masked image together with the mask
    im_np = mask * im_np
    im_np[mask == 0] = 1
    im_final = (im_np / im_np.max() * 255).astype(np.uint8)
    im_final = Image.fromarray(im_final)
    return im_final, mask
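
# Usage sketch for get_mask_u2net (a hedged example; the checkpoint and output
# paths are assumptions, not files shipped with this module). It writes mask.png
# (and resize_params.npy when resize_obj=1) into output_dir and returns the
# background-whitened image together with the mask:
#
#   from pathlib import Path
#   from PIL import Image
#   pil_im = Image.open("examples/horse.png").convert("RGB")
#   output_dir = Path("workdir")
#   output_dir.mkdir(parents=True, exist_ok=True)
#   im_final, mask = get_mask_u2net(pil_im, output_dir,
#                                   u2net_path="checkpoint/u2net/u2net.pth",
#                                   resize_obj=1, device="cpu")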
def is_in_canvas(canvas_width, canvas_height, path, device):
    # render a single black stroke on a white canvas; the returned value is the
    # total amount of "ink", which is greater than zero only if the path is visible
    shapes, shape_groups = [], []
    stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
    shapes.append(path)
    path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
                                     fill_color=None,
                                     stroke_color=stroke_color)
    shape_groups.append(path_group)

    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,  # width
                  canvas_height,  # height
                  2,  # num_samples_x
                  2,  # num_samples_y
                  0,  # seed
                  None,
                  *scene_args)
    # composite the RGBA output over a white background
    img = img[:, :, 3:4] * img[:, :, :3] + \
          torch.ones(img.shape[0], img.shape[1], 3, device=device) * (1 - img[:, :, 3:4])
    img = img[:, :, :3].detach().cpu().numpy()
    return (1 - img).sum()
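
# Usage sketch for is_in_canvas (a hedged example; the control points and canvas
# size are arbitrary illustrative values, and pydiffvg is kept on the CPU so the
# sketch stays device-agnostic):
#
#   device = "cpu"
#   pydiffvg.set_use_gpu(False)
#   points = torch.tensor([[10.0, 10.0], [60.0, 20.0], [120.0, 80.0], [200.0, 150.0]])
#   path = pydiffvg.Path(num_control_points=torch.tensor([2], dtype=torch.int32),
#                        points=points,
#                        stroke_width=torch.tensor(1.5),
#                        is_closed=False)
#   ink = is_in_canvas(224, 224, path, device)   # > 0 when the stroke lands on the canvas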