import matplotlib.pyplot as plt
import numpy as np
import pydiffvg
import torch
from PIL import Image
from pytorch_svgrender.painter.clipascene import u2net_utils
from pytorch_svgrender.painter.clipasso.u2net import U2NET
from scipy import ndimage
from skimage import morphology
from skimage.measure import label
from skimage.transform import resize
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid


def _draw_input_panel(inputs, inds):
    """Draw the input image grid with the sampled points on the current axes."""
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    # make_grid output is transposed with (1, 2, 0) so imshow gets channels last.
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    # inds is indexed as (row, col); scatter expects (x, y).
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")


def plot_attn_dino(attn, threshold_map, inputs, inds, output_path):
    """Save a two-row figure of the attention maps and their thresholded maps.

    Row one shows the input with the sampled points, the sum of the attention
    maps, and each ``attn[i]``. Row two shows ``threshold_map[-1]``, the sum of
    ``threshold_map[:-1]``, and each ``threshold_map[i]``.

    Args:
        attn: Attention maps indexed along the first axis; each entry must
            support ``.numpy()`` (presumably CPU tensors -- confirm with caller).
        threshold_map: Maps indexed like ``attn`` with one extra trailing
            entry, plotted as "prob sum".
        inputs: Image tensor(s) passed to ``make_grid``.
        inds: Array of point coordinates, column 0 = row, column 1 = column.
        output_path: Destination passed to ``plt.savefig``.
    """
    # currently supports one image (and not a batch)
    num_maps = attn.shape[0]
    num_cols = num_maps + 2
    plt.figure(figsize=(10, 5))

    plt.subplot(2, num_cols, 1)
    _draw_input_panel(inputs, inds)

    plt.subplot(2, num_cols, 2)
    plt.imshow(attn.sum(0).numpy(), interpolation='nearest')
    plt.title("atn map sum")
    plt.axis("off")

    # First cell of the second row.
    plt.subplot(2, num_cols, num_maps + 3)
    plt.imshow(threshold_map[-1].numpy(), interpolation='nearest')
    plt.title("prob sum")
    plt.axis("off")

    plt.subplot(2, num_cols, num_maps + 4)
    plt.imshow(threshold_map[:-1].sum(0).numpy(), interpolation='nearest')
    plt.title("thresh sum")
    plt.axis("off")

    for i in range(num_maps):
        # Per-map panels start at the third column of each row.
        plt.subplot(2, num_cols, i + 3)
        plt.imshow(attn[i].numpy())
        plt.axis("off")
        plt.subplot(2, num_cols, num_maps + 1 + i + 4)
        plt.imshow(threshold_map[i].numpy())
        plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_attn_clip(attn, threshold_map, inputs, inds, output_path):
    """Save a three-panel figure: input, attention map, normalised probability map.

    Args:
        attn: Single attention map accepted by ``plt.imshow`` (shown in [0, 1]).
        threshold_map: Map that is min-max normalised before being displayed.
        inputs: Image tensor(s) passed to ``make_grid``.
        inds: Array of point coordinates, column 0 = row, column 1 = column.
        output_path: Destination passed to ``plt.savefig``.
    """
    # currently supports one image (and not a batch)
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 3, 1)
    _draw_input_panel(inputs, inds)

    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
    plt.title("attn map")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    lowest, highest = threshold_map.min(), threshold_map.max()
    threshold_map_ = (threshold_map - lowest) / (highest - lowest)
    plt.imshow(threshold_map_, interpolation='nearest', vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_attn(attn, threshold_map, inputs, inds, output_path, saliency_model):
    """Dispatch to the plotting routine matching ``saliency_model``.

    Supported values are ``"dino"`` and ``"clip"``; any other value is
    ignored and nothing is plotted.
    """
    plotters = {"dino": plot_attn_dino, "clip": plot_attn_clip}
    plotter = plotters.get(saliency_model)
    if plotter is not None:
        plotter(attn, threshold_map, inputs, inds, output_path)


def fix_image_scale(im):
    """Centre ``im`` on a white square canvas and return it as a PIL image.

    The canvas side is the longer image side plus 20 pixels.

    Args:
        im: Image convertible with ``np.array``; values are divided by 255.
            Assumes three channels, since it is pasted into an (N, N, 3)
            array -- TODO confirm with callers.

    Returns:
        PIL.Image.Image: The padded ``uint8`` image.
    """
    im_np = np.array(im) / 255
    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20

    new_background = np.ones((max_len, max_len, 3))
    y = max_len // 2 - height // 2
    x = max_len // 2 - width // 2
    new_background[y: y + height, x: x + width] = im_np

    new_background = (new_background / new_background.max() * 255).astype(np.uint8)
    return Image.fromarray(new_background)


def get_size_of_largest_cc(binary_im):
    """Return the pixel count of the largest foreground connected component.

    Only non-zero labels are counted, so the result is correct when the
    object is larger than the background or when there is no background.

    Args:
        binary_im: Array whose zero entries are background.

    Returns:
        int: Size of the largest component, or 0 if there is no foreground.
    """
    labels = label(binary_im, background=0)
    # bincount index 0 is the background label; drop it.
    component_sizes = np.bincount(labels.ravel())[1:]
    if component_sizes.size == 0:
        return 0
    return int(component_sizes.max())


def get_num_cc(binary_im):
    """Return the number of foreground connected components in ``binary_im``."""
    _, num = label(binary_im, background=0, return_num=True)
    return num


def get_obj_bb(binary_im):
    """Return the bounding box ``(x0, x1, y0, y1)`` of the non-zero pixels.

    Axis 0 is treated as y and axis 1 as x. Raises ``ValueError`` (from
    ``min``/``max`` on an empty array) when ``binary_im`` has no non-zero pixel.
    """
    coords = np.where(binary_im != 0)
    y, x = coords[0], coords[1]
    x0, x1, y0, y1 = x.min(), x.max(), y.min(), y.max()
    return x0, x1, y0, y1


def cut_and_resize(im, x0, x1, y0, y1, new_height, new_width):
    """Crop ``im[y0:y1, x0:x1]``, resize it, and centre it on a zero canvas.

    Args:
        im: Source array; the output canvas has the same shape.
        x0, x1, y0, y1: Crop bounds (end-exclusive slicing).
        new_height, new_width: Size the crop is resized to.

    Returns:
        np.ndarray: Zero array of ``im.shape`` with the resized crop pasted
        around its centre.
    """
    cut_obj = im[y0: y1, x0: x1]
    resized_obj = resize(cut_obj, (new_height, new_width))
    new_mask = np.zeros(im.shape)

    center_y_new = int(new_height / 2)
    center_x_new = int(new_width / 2)
    center_targ_y = int(new_mask.shape[0] / 2)
    center_targ_x = int(new_mask.shape[1] / 2)
    startx = center_targ_x - center_x_new
    starty = center_targ_y - center_y_new

    new_mask[starty: starty + resized_obj.shape[0],
             startx: startx + resized_obj.shape[1]] = resized_obj
    return new_mask


def get_mask_u2net(pil_im, output_dir, u2net_path, resize_obj=0, preprocess=False, device="cpu"):
    """Compute a binary saliency mask for ``pil_im`` with U2-Net.

    Args:
        pil_im: Input PIL image.
        output_dir: Directory object supporting the ``/`` operator (e.g.
            ``pathlib.Path``); ``mask.png`` and, when ``resize_obj`` is set,
            ``resize_params.npy`` are written there.
        u2net_path: Path to the U2-Net checkpoint loaded with ``torch.load``.
        resize_obj: When truthy, a single small object is cropped and
            rescaled so its longer side is 70% of the shorter image side.
        preprocess: When true, the thresholded prediction is dilated with an
            11x11 structuring element and only the mask is returned.
        device: Device the network and input are run on.

    Returns:
        If ``preprocess`` is true, the ``(h, w, 3)`` mask array only.
        Otherwise a tuple ``(im_final, mask)``: the masked image on a white
        background as a PIL image, and the mask array.
    """
    w, h = pil_im.size[0], pil_im.size[1]

    test_salobj_dataset = u2net_utils.SalObjDataset(
        imgs_list=[pil_im],
        lbl_name_list=[],
        transform=transforms.Compose([u2net_utils.RescaleT(320),
                                      u2net_utils.ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)
    input_im_trans = next(iter(test_salobj_dataloader))

    net = U2NET(3, 1)
    # map_location lets a checkpoint saved on GPU load on the requested device.
    net.load_state_dict(torch.load(u2net_path, map_location=device))
    net.to(device)
    net.eval()
    with torch.no_grad():
        # Run on the requested device instead of assuming CUDA is available.
        input_im_trans = input_im_trans.float().to(device)
        d1, *_ = net(input_im_trans)

    # Min-max normalise the first output and binarise at 0.5.
    pred = d1[:, 0, :, :]
    pred = (pred - pred.min()) / (pred.max() - pred.min())
    predict = pred
    predict[predict < 0.5] = 0
    predict[predict >= 0.5] = 1

    if preprocess:
        dilated = ndimage.binary_dilation(predict[0].cpu().numpy(),
                                          structure=np.ones((11, 11)))
        predict = torch.tensor(dilated.astype(int)).unsqueeze(0)
        # Replicate the single channel three times, channels last.
        mask = torch.cat([predict, predict, predict], axis=0).permute(1, 2, 0)
        mask = mask.cpu().numpy()
        max_val = mask.max()
        mask[mask > max_val / 2] = 255
        mask = mask.astype(np.uint8)
        mask = resize(mask, (h, w), anti_aliasing=False, order=0)
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        return mask

    mask = torch.cat([predict, predict, predict], axis=0).permute(1, 2, 0)
    mask = mask.cpu().numpy()
    mask = resize(mask, (h, w), anti_aliasing=False)
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    im.save(output_dir / "mask.png")

    im_np = np.array(pil_im)
    im_np = im_np / im_np.max()

    if resize_obj:
        params = {}
        mask_np = mask[:, :, 0].astype(int)
        target_np = im_np
        # Drop components smaller than a third of the largest one.
        min_size = int(get_size_of_largest_cc(mask_np) / 3)
        mask_np2 = morphology.remove_small_objects((mask_np > 0), min_size=min_size).astype(int)
        num_cc = get_num_cc(mask_np2)

        mask_np3 = np.ones((h, w, 3))
        mask_np3[:, :, 0] = mask_np2
        mask_np3[:, :, 1] = mask_np2
        mask_np3[:, :, 2] = mask_np2

        x0, x1, y0, y1 = get_obj_bb(mask_np2)
        im_width, im_height = x1 - x0, y1 - y0
        max_size = max(im_width, im_height)
        target_size = int(min(h, w) * 0.7)

        # Only enlarge a single object that is smaller than the target size.
        if max_size < target_size and num_cc == 1:
            if im_width > im_height:
                new_width = target_size
                new_height = int((target_size / im_width) * im_height)
            else:
                new_width = int((target_size / im_height) * im_width)
                new_height = target_size

            mask = cut_and_resize(mask_np3, x0, x1, y0, y1, new_height, new_width)
            target_np = target_np / target_np.max()
            im_np = cut_and_resize(target_np, x0, x1, y0, y1, new_height, new_width)

            params["original_center_y"] = (y0 + (y1 - y0) / 2) / h
            params["original_center_x"] = (x0 + (x1 - x0) / 2) / w
            params["scale_w"] = new_width / im_width
            params["scale_h"] = new_height / im_height

        # NOTE(review): the original indentation was lost; the save is placed
        # at the `resize_obj` level, so an empty dict is written when no
        # rescale happened -- confirm against the consumer of this file.
        np.save(output_dir / "resize_params.npy", params)

    # Keep the object pixels and paint everything outside the mask white.
    im_np = mask * im_np
    im_np[mask == 0] = 1
    im_final = (im_np / im_np.max() * 255).astype(np.uint8)
    im_final = Image.fromarray(im_final)

    return im_final, mask


def is_in_canvas(canvas_width, canvas_height, path, device):
    """Render ``path`` alone and return how much of it is drawn on the canvas.

    The path is stroked in opaque black with pydiffvg and composited over a
    white background. The returned value is the sum of ``1 - pixel`` over the
    rendered RGB image, so it is 0 when the render has zero alpha everywhere.

    Args:
        canvas_width: Canvas width passed to the renderer.
        canvas_height: Canvas height passed to the renderer.
        path: pydiffvg shape to render.
        device: Device used for the white background tensor.
    """
    shapes = [path]
    stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
    path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
                                     fill_color=None,
                                     stroke_color=stroke_color)
    shape_groups = [path_group]

    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,  # width
                  canvas_height,  # height
                  2,  # num_samples_x
                  2,  # num_samples_y
                  0,  # seed
                  None,
                  *scene_args)

    # Alpha-composite the RGBA render over a white background.
    alpha = img[:, :, 3:4]
    white = torch.ones(img.shape[0], img.shape[1], 3, device=device)
    img = alpha * img[:, :, :3] + white * (1 - alpha)
    img = img[:, :, :3].detach().cpu().numpy()
    return (1 - img).sum()