import re

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from .image_utils import show_cam_on_image, show_overlapped_cam


def rn_relevance(
    image,
    target_features,
    img_encoder,
    method,
    device,
    neg_saliency=False,
    img_dim=224,
):
    """Compute a CAM relevance map for a ResNet image encoder.

    `method` is a pytorch-grad-cam-style CAM class whose call accepts a
    `target_encoding` keyword. Returns the upsampled, min-max-normalized
    relevance map and the normalized input image as numpy arrays.
    """
    # Hook the last block of the final ResNet stage.
    target_layers = [img_encoder.layer4[-1]]
    cam = method(
        model=img_encoder,
        target_layers=target_layers,
        use_cuda=torch.cuda.is_available(),
    )
    # Negate the target to visualize negative (suppressive) saliency.
    if neg_saliency:
        target_encoding = -target_features
    else:
        target_encoding = target_features
    image_relevance = cam(input_tensor=image, target_encoding=target_encoding)[
        0
    ].squeeze()
    image_relevance = torch.FloatTensor(image_relevance)
    # Upsample the low-resolution CAM grid (e.g., 7x7) to the input size.
    resize_dim = image_relevance.shape[0]
    image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=img_dim, mode="bilinear"
    )
    image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
    # Min-max normalize to [0, 1]; the epsilon guards against a constant map.
    image_relevance = (image_relevance - image_relevance.min()) / (
        1e-7 + image_relevance.max() - image_relevance.min()
    )
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    return image_relevance, image


def interpret_rn(
    image,
    target_features,
    img_encoder,
    method,
    device,
    neg_saliency=False,
    img_dim=224,
):
    """Overlay the CAM heatmap on the image; returns a BGR uint8 array."""
    image_relevance, image = rn_relevance(
        image,
        target_features,
        img_encoder,
        method,
        device,
        neg_saliency=neg_saliency,
        img_dim=img_dim,
    )
    vis = show_cam_on_image(image, image_relevance, neg_saliency=neg_saliency)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    return vis


def interpret_rn_overlapped(
    image, target_features, img_encoder, method, device, img_dim=224
):
    """Overlay positive and negative saliency on one image; returns BGR uint8."""
    pos_image_relevance, _ = rn_relevance(
        image,
        target_features,
        img_encoder,
        method,
        device,
        neg_saliency=False,
        img_dim=img_dim,
    )
    neg_image_relevance, image = rn_relevance(
        image,
        target_features,
        img_encoder,
        method,
        device,
        neg_saliency=True,
        img_dim=img_dim,
    )
    vis = show_overlapped_cam(image, neg_image_relevance, pos_image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    return vis


def rn_perword_relevance(
    image,
    text,
    clip_model,
    clip_tokenizer,
    method,
    device,
    masked_word="",
    data_only=False,
    img_dim=224,
):
    """Compute per-word relevance for a CLIP model with a ResNet visual encoder.

    The CAM target is the difference between the text embedding of the full
    caption and that of the caption with `masked_word` removed, so the map
    highlights image regions associated with that word. When `data_only` is
    True, only the relevance map is returned.
    """
    clip_model.eval()
    main_text = clip_tokenizer(text).to(device)
    # Remove the word whose saliency should be visualized.
    masked_text = re.sub(masked_word, "", text)
    masked_text = clip_tokenizer(masked_text).to(device)
    main_text_features = clip_model.encode_text(main_text)
    masked_text_features = clip_model.encode_text(masked_text)
    # L2-normalize both text embeddings before taking their difference.
    main_text_features = main_text_features / main_text_features.norm(
        dim=-1, keepdim=True
    )
    masked_text_features = masked_text_features / masked_text_features.norm(
        dim=-1, keepdim=True
    )
    target_encoding = main_text_features - masked_text_features
    # Hook the last block of the final stage of the visual ResNet.
    target_layers = [clip_model.visual.layer4[-1]]
    cam = method(
        model=clip_model.visual,
        target_layers=target_layers,
        use_cuda=torch.cuda.is_available(),
    )
    image_relevance = cam(input_tensor=image, target_encoding=target_encoding)[
        0
    ].squeeze()
    image_relevance = torch.FloatTensor(image_relevance)
    # Upsample the low-resolution CAM grid to the input size.
    resize_dim = image_relevance.shape[0]
    image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=img_dim, mode="bilinear"
    )
    image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
    image_relevance = (image_relevance - image_relevance.min()) / (
        1e-7 + image_relevance.max() - image_relevance.min()
    )
    if data_only:
        return image_relevance
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    return image_relevance, image


def interpret_perword_rn(
    image,
    text,
    clip_model,
    clip_tokenizer,
    method,
    device,
    masked_word="",
    data_only=False,
    img_dim=224,
):
    """Render the per-word CAM heatmap as a BGR uint8 image.

    With `data_only=True`, the raw relevance map is returned instead of a
    rendered overlay.
    """
    if data_only:
        # rn_perword_relevance returns only the relevance map in this case,
        # so pass it straight through rather than unpacking two values.
        return rn_perword_relevance(
            image,
            text,
            clip_model,
            clip_tokenizer,
            method,
            device,
            masked_word,
            data_only=True,
            img_dim=img_dim,
        )
    image_relevance, image = rn_perword_relevance(
        image,
        text,
        clip_model,
        clip_tokenizer,
        method,
        device,
        masked_word,
        img_dim=img_dim,
    )
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    return vis
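

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the module proper. Assumptions: the
# OpenAI `clip` package provides the RN50 backbone, and GradCAM comes from a
# pytorch-grad-cam build patched to accept the `target_encoding` keyword that
# the calls above rely on; "example.jpg" is a hypothetical image path.
# Because this module uses a relative import, run it as a package module
# (python -m <package>.<module>) rather than as a standalone script.
if __name__ == "__main__":
    import clip
    from pytorch_grad_cam import GradCAM  # assumed patched for target_encoding

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("RN50", device=device)
    img = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)

    # Highlight the regions that the word "cat" contributes to the caption.
    vis = interpret_perword_rn(
        img,
        "a photo of a cat",
        model,
        clip.tokenize,
        GradCAM,
        device,
        masked_word="cat",
    )
    plt.imshow(vis[:, :, ::-1])  # vis is BGR; flip channels for matplotlib
    plt.axis("off")
    plt.show()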