Spaces:
Running
Running
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: Matplotlib helpers for saving triplet (sample / style / sketch) and attention-map figures.
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from torchvision.utils import make_grid | |
def _grid_to_uint8_image(images: torch.Tensor, normalize: bool) -> np.ndarray:
    """Tile ``images`` with ``make_grid`` and convert the grid to a uint8 array.

    The grid is scaled by 255, rounded by adding 0.5 before truncation,
    clamped to [0, 255], and its first axis is moved last so that the result
    can be handed directly to ``plt.imshow``.
    """
    grid = make_grid(images, normalize=normalize, pad_value=2)
    return (
        grid.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )


def _wrap_prompt(text: str, words_per_line: int = 9) -> str:
    """Break ``text`` into lines of at most ``words_per_line`` words.

    Text with ``words_per_line`` words or fewer is returned unchanged
    (original whitespace included); longer text is re-joined with single
    spaces inside a line and newlines between lines.
    """
    words = text.split()
    if len(words) <= words_per_line:
        return text
    return "\n".join(
        " ".join(words[start:start + words_per_line])
        for start in range(0, len(words), words_per_line)
    )


def plt_triplet(
        photos: torch.Tensor,
        sketch: torch.Tensor,
        style: torch.Tensor,
        step: int,
        prompt: str,
        output_dir: str,
        fname: str,  # file name
        dpi: int = 300
) -> None:
    """Save a 1x3 figure: generated sample, style image, and rendered sketch.

    Args:
        photos: images for the left panel; shown with ``normalize=True``.
        sketch: images for the right panel; must have the same shape as
            ``photos``. Shown without normalization, so values are scaled by
            255 as-is (presumably expected in [0, 1] -- confirm with caller).
        style: images for the middle panel; shown without normalization
            (same range assumption as ``sketch``).
        step: step count shown in the right panel's title.
        prompt: text used as the figure's super-title; wrapped every 9 words.
        output_dir: directory the image is written to.
        fname: output file name without extension; ``.png`` is appended.
        dpi: resolution passed to ``plt.savefig``.

    Raises:
        ValueError: if ``photos`` and ``sketch`` differ in shape.
    """
    if photos.shape != sketch.shape:
        raise ValueError("photos and sketch must have the same dimensions")

    # (images, normalize, title) for each column, left to right.
    panels = (
        (photos, True, "Generated sample"),
        (style, False, "Style"),
        (sketch, False, f"Rendering result - {step} steps"),
    )

    fig = plt.figure()
    try:
        for index, (images, normalize, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, index)  # nrows=1, ncols=3
            plt.imshow(_grid_to_uint8_image(images, normalize))
            plt.axis("off")
            plt.title(title)

        plt.suptitle(_wrap_prompt(prompt), fontsize=10)
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi)
    finally:
        # Close this figure even when rendering or saving fails; otherwise
        # it stays registered in pyplot and leaks across repeated calls.
        plt.close(fig)
def plt_attn(attn: np.ndarray,
             threshold_map: np.ndarray,
             inputs: torch.Tensor,
             inds: np.ndarray,
             output_path: str) -> None:
    """Save a 1x3 figure: input image, attention map, and min-max scaled threshold map.

    Currently supports one image (and not a batch).

    Args:
        attn: map drawn in the middle panel with a fixed [0, 1] color range.
        threshold_map: map drawn in the right panel after min-max scaling.
        inputs: image tensor drawn in the left panel (normalized by
            ``make_grid``).
        inds: 2-D array of points overlaid in red on the left and right
            panels; column 1 is plotted on the x axis and column 0 on the
            y axis (presumably (row, col) pixel coordinates -- confirm).
        output_path: path passed to ``plt.savefig``.
    """
    fig = plt.figure(figsize=(10, 5))
    try:
        plt.subplot(1, 3, 1)
        main_im = make_grid(inputs, normalize=True, pad_value=2)
        main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
        plt.imshow(main_im, interpolation='nearest')
        plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
        plt.title("input img")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
        plt.title("attn map")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        lowest = threshold_map.min()
        highest = threshold_map.max()
        # A constant map yields 0/0 here; the resulting NaNs are replaced by
        # nan_to_num below, so silence the floating-point warning.
        with np.errstate(divide="ignore", invalid="ignore"):
            threshold_map_ = (threshold_map - lowest) / (highest - lowest)
        plt.imshow(np.nan_to_num(threshold_map_), interpolation='nearest', vmin=0, vmax=1)
        plt.title("prob softmax")
        plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
        plt.axis("off")

        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Close this figure even when rendering or saving fails; otherwise
        # it stays registered in pyplot and leaks across repeated calls.
        plt.close(fig)