# Source: hjc-owo, commit 966ae59 ("init repo"), 3.12 kB.
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
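# Description: matplotlib helpers that save visualization figures to disk
#   (generated-sample / style / rendering triplets, and attention maps).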
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.utils import make_grid
def _insert_newline(string, point=9):
    """Wrap ``string`` onto lines of at most ``point`` words each.

    The text is split on runs of whitespace. If it has ``point`` words or
    fewer it is returned unchanged (original spacing preserved). Otherwise
    consecutive chunks of ``point`` words are joined by single spaces, and
    the chunks are separated by newlines.
    """
    words = string.split()
    if len(words) <= point:
        return string
    lines = (" ".join(words[i:i + point]) for i in range(0, len(words), point))
    return "\n".join(lines)


def _grid_to_uint8_image(images, normalize):
    """Tile ``images`` with ``make_grid`` and return a channels-last uint8 array.

    Values are scaled by 255, offset by 0.5 so the integer cast rounds, and
    clamped to [0, 255]. Inputs are therefore expected in [0, 1] (after the
    optional ``make_grid`` normalization); anything outside is clipped.
    """
    # pad_value=2 exceeds 1, presumably so the padding saturates to white
    # after scaling and clamping.
    grid = make_grid(images, normalize=normalize, pad_value=2)
    return (
        grid.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)  # (C, H, W) -> (H, W, C) for imshow
        .to("cpu", torch.uint8)
        .numpy()
    )


def _show_panel(index, images, normalize, title):
    """Draw ``images`` as panel ``index`` of a 1x3 subplot row, with axes hidden."""
    plt.subplot(1, 3, index)  # nrows=1, ncols=3
    plt.imshow(_grid_to_uint8_image(images, normalize))
    plt.axis("off")
    plt.title(title)


def plt_triplet(
        photos: torch.Tensor,
        sketch: torch.Tensor,
        style: torch.Tensor,
        step: int,
        prompt: str,
        output_dir: str,
        fname: str,  # file name
        dpi: int = 300
):
    """Save a side-by-side figure of generated photos, style images and the rendering.

    The figure has three panels, left to right:

    1. ``photos``, tiled with ``make_grid(normalize=True)``.
    2. ``style``, tiled without normalization.
    3. ``sketch``, tiled without normalization, titled with ``step``.

    ``prompt`` becomes the figure title, wrapped to 9 words per line. The
    image is written to ``{output_dir}/{fname}.png`` at ``dpi``.

    Args:
        photos: Generated samples, tiled and normalized by ``make_grid``.
        sketch: Rendering result. Must have the same shape as ``photos``.
        style: Style reference images.
        step: Step count shown in the third panel's title.
        prompt: Text shown as the figure's super-title.
        output_dir: Directory the PNG is written to.
        fname: File name without the ``.png`` extension.
        dpi: Resolution passed to ``savefig``.

    Raises:
        ValueError: If ``photos`` and ``sketch`` shapes differ.
    """
    if photos.shape != sketch.shape:
        raise ValueError("photos and sketch must have the same dimensions")

    fig = plt.figure()
    try:
        # style and sketch are not normalized; they are presumably already
        # in [0, 1] (values outside are clipped) -- TODO confirm with callers.
        _show_panel(1, photos, normalize=True, title="Generated sample")
        _show_panel(2, style, normalize=False, title="Style")
        _show_panel(3, sketch, normalize=False, title=f"Rendering result - {step} steps")

        plt.suptitle(_insert_newline(prompt), fontsize=10)
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi)
    finally:
        # Close even if drawing or saving fails, so repeated calls don't leak figures.
        plt.close(fig)
def plt_attn(attn: np.ndarray,
             threshold_map: np.ndarray,
             inputs: torch.Tensor,
             inds: np.ndarray,
             output_path: str):
    """Save a three-panel attention figure to ``output_path``.

    Panels, left to right:

    1. ``inputs`` tiled by ``make_grid(normalize=True)``, with the points
       in ``inds`` overlaid as red dots.
    2. ``attn``, drawn with color limits [0, 1].
    3. ``threshold_map`` min-max rescaled to [0, 1], with the same points
       overlaid. A constant map, or one containing NaN, yields NaNs here,
       which are shown as zeros.

    Currently supports one image, not a batch. With a batch, the
    ``make_grid`` tiling would presumably no longer line up with the pixel
    coordinates in ``inds`` -- TODO confirm.

    Args:
        attn: 2-D attention map to display.
        threshold_map: Map to rescale and display in the third panel.
        inputs: Input image tensor.
        inds: Integer array of shape (k, 2). Each row is a (row, col) point.
        output_path: Destination file passed to ``savefig``.
    """
    fig = plt.figure(figsize=(10, 5))
    try:
        plt.subplot(1, 3, 1)
        main_im = make_grid(inputs, normalize=True, pad_value=2)
        main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))  # CHW -> HWC
        plt.imshow(main_im, interpolation='nearest')
        # Column 1 of inds is x (col) and column 0 is y (row).
        plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
        plt.title("input img")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
        plt.title("attn map")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        t_min, t_max = threshold_map.min(), threshold_map.max()
        # A constant map gives 0/0. The resulting NaNs are zeroed by
        # nan_to_num below, so silence numpy's RuntimeWarning here.
        with np.errstate(divide="ignore", invalid="ignore"):
            threshold_map_ = (threshold_map - t_min) / (t_max - t_min)
        plt.imshow(np.nan_to_num(threshold_map_), interpolation='nearest', vmin=0, vmax=1)
        plt.title("prob softmax")
        plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
        plt.axis("off")

        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Always release the figure, even when drawing or saving raises.
        plt.close(fig)