broadwell's picture
Upload 3 files
254b186 verified
raw
history blame contribute delete
No virus
9.12 kB
import torch
import numpy as np
# from PIL import Image
# import matplotlib.pyplot as plt
import cv2
import regex as re
from .image_utils import show_cam_on_image, show_overlapped_cam
def vit_block_vis(
    image,
    target_features,
    img_encoder,
    block,
    device,
    grad=False,
    neg_saliency=False,
    img_dim=224,
):
    """Score the image after masking it with one block's attention map.

    Steps, as implemented below:

    1. Encode ``image`` and take the cosine similarity between
       ``image_features[0]`` and ``target_features[0]``; back-propagate it.
    2. Read ``attn_grad`` (if ``grad``) or ``attn_probs`` from the
       ``block``-th child of ``img_encoder.transformer.resblocks``.
    3. Average over the leading axis, keep row 0 / columns 1: of the
       token-by-token map, reshape it to a square grid, bilinearly upsample
       it to ``img_dim`` and min-max rescale it.
    4. Multiply the min-max rescaled image by that map, divide by the
       maximum, re-encode the result and return its cosine similarity with
       ``target_features[0]``.

    NOTE(review): ``attn_probs`` / ``attn_grad`` are not set anywhere in this
    file -- presumably the encoder's attention layers record them during the
    forward / backward pass; confirm against the model code. Row 0 is
    presumably the class token and columns 1: the patch tokens.

    Args:
        image: Input batch fed to ``img_encoder``; only the first sample's
            features are used.
        target_features: Features to compare against; only row 0 is used.
        img_encoder: Module exposing ``transformer.resblocks``.
        block: Index of the residual block whose map is used.
        device: Unused; kept so existing callers keep working.
        grad: Use ``attn_grad`` instead of ``attn_probs`` when true.
        neg_saliency: Unused; kept so existing callers keep working.
        img_dim: Side length the attention map is upsampled to.

    Returns:
        A 0-dim tensor: cosine similarity between the masked image's
        features and ``target_features[0]``.
    """
    img_encoder.eval()

    image_features = img_encoder(image)
    image_features_new = image_features / image_features.norm(dim=-1, keepdim=True)
    target_features_new = target_features / target_features.norm(
        dim=-1, keepdim=True
    )
    similarity = image_features_new[0].dot(target_features_new[0])

    # Rescale the pixels to [0, 1] only after the first forward pass, so the
    # encoder sees the image exactly as the caller provided it.
    image = (image - image.min()) / (image.max() - image.min())

    img_encoder.zero_grad()
    similarity.backward(retain_graph=True)

    image_attn_blocks = list(img_encoder.transformer.resblocks.children())
    selected_block = image_attn_blocks[block]
    cam = (selected_block.attn_grad if grad else selected_block.attn_probs).detach()

    # Collapse every leading axis into one, as the sibling relevance functions
    # do, so the mean below always produces a (tokens, tokens) matrix. For an
    # already 3-D tensor this reshape is a no-op.
    cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
    cam = cam.mean(dim=0)

    image_relevance = cam[0, 1:]
    resize_dim = int(np.sqrt(image_relevance.shape[0]))
    image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=img_dim, mode="bilinear"
    )
    image_relevance = image_relevance.reshape(img_dim, img_dim)
    image_relevance = (image_relevance - image_relevance.min()) / (
        image_relevance.max() - image_relevance.min()
    )

    # The (img_dim, img_dim) map broadcasts over the image's leading axes.
    cam = image_relevance * image
    cam = cam / torch.max(cam)

    # TODO: maybe we can ignore this...
    masked_image_features = img_encoder(cam)
    masked_image_features_new = masked_image_features / masked_image_features.norm(
        dim=-1, keepdim=True
    )
    new_score = masked_image_features_new[0].dot(target_features_new[0])

    return new_score
def _rescale_to_unit(values):
    """Min-max rescale a NumPy array to [0, 1]; a constant array maps to zeros."""
    low = values.min()
    span = values.max() - low
    if span == 0:
        # Every entry is identical: dividing would give 0/0 = NaN everywhere.
        return np.zeros_like(values)
    return (values - low) / span


def vit_relevance(
    image,
    target_features,
    img_encoder,
    device,
    method="last grad",
    neg_saliency=False,
    img_dim=224,
):
    """Compute a relevance map of ``image`` with respect to ``target_features``.

    The objective is the cosine similarity between ``image_features[0]`` and
    ``target_features[0]`` (or ``1 - similarity`` when ``neg_saliency`` is
    true). It is back-propagated, after which the ``attn_probs`` and
    ``attn_grad`` attributes of the children of
    ``img_encoder.transformer.resblocks`` are combined according to ``method``:

    * ``"gradcam"``: last block only -- ``grad * attn``, clamped at zero and
      averaged over the leading axis.
    * ``"last grad"``: start from the identity and, for each block, add
      ``cam @ R`` where ``cam`` is that block's attention weighted by the
      *last* block's gradient (clamped at zero, averaged).
    * ``"all grads"``: as above, but each block uses its own gradient.

    Row 0 / columns 1: of the resulting matrix is reshaped to a square grid,
    bilinearly upsampled to ``img_dim`` and min-max rescaled.

    NOTE(review): ``attn_probs`` / ``attn_grad`` are not set in this file --
    presumably the encoder's attention layers record them; confirm against
    the model code. Row 0 is presumably the class token.

    Args:
        image: Input batch fed to ``img_encoder``; ``image[0]`` is returned
            as an array with its first axis moved last.
        target_features: Features to compare against; only row 0 is used.
        img_encoder: Module exposing ``transformer.resblocks``.
        device: Device on which the identity matrix is created.
        method: One of ``"gradcam"``, ``"last grad"``, ``"all grads"``.
        neg_saliency: Back-propagate ``1 - similarity`` instead of
            ``similarity``.
        img_dim: Side length of the returned relevance map.

    Returns:
        Tuple ``(image_relevance, image, image_features, similarity)`` where
        the first two are NumPy arrays rescaled to [0, 1] and the last two
        are the tensors produced by the forward pass.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Validate before the forward/backward pass instead of printing and
    # returning None halfway through, which made callers fail on unpacking.
    if method not in ("gradcam", "last grad", "all grads"):
        raise ValueError(
            f"Unknown method {method!r}. "
            "The available visualization methods are: "
            "'gradcam', 'last grad', 'all grads'."
        )

    img_encoder.eval()

    image_features = img_encoder(image)
    image_features_new = image_features / image_features.norm(dim=-1, keepdim=True)
    target_features_new = target_features / target_features.norm(
        dim=-1, keepdim=True
    )
    similarity = image_features_new[0].dot(target_features_new[0])
    objective = 1 - similarity if neg_saliency else similarity

    img_encoder.zero_grad()
    objective.backward(retain_graph=True)

    image_attn_blocks = list(img_encoder.transformer.resblocks.children())
    first_attn = image_attn_blocks[0].attn_probs
    num_tokens = first_attn.shape[-1]

    # Flatten all leading axes so every map is (maps, tokens, tokens).
    last_attn = image_attn_blocks[-1].attn_probs.detach()
    last_attn = last_attn.reshape(-1, last_attn.shape[-1], last_attn.shape[-1])
    last_grad = image_attn_blocks[-1].attn_grad.detach()
    last_grad = last_grad.reshape(-1, last_grad.shape[-1], last_grad.shape[-1])

    if method == "gradcam":
        cam = (last_grad * last_attn).clamp(min=0).mean(dim=0)
        image_relevance = cam[0, 1:]
    else:
        R = torch.eye(num_tokens, num_tokens, dtype=first_attn.dtype, device=device)
        for blk in image_attn_blocks:
            cam = blk.attn_probs.detach()
            cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
            if method == "last grad":
                grad = last_grad
            else:
                # Reshape like ``cam`` so the product stays 3-D even when the
                # recorded gradient carries extra leading axes.
                grad = blk.attn_grad.detach()
                grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
            cam = (grad * cam).clamp(min=0).mean(dim=0)
            R += torch.matmul(cam, R)
        image_relevance = R[0, 1:]

    resize_dim = int(np.sqrt(image_relevance.shape[0]))
    image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=img_dim, mode="bilinear"
    )
    image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
    image_relevance = _rescale_to_unit(image_relevance)

    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = _rescale_to_unit(image)

    return image_relevance, image, image_features, similarity
def interpret_vit(
    image,
    target_features,
    img_encoder,
    device,
    method="last grad",
    neg_saliency=False,
    img_dim=224,
):
    """Render the relevance map of ``image`` as a colour overlay.

    Delegates the relevance computation to ``vit_relevance`` (all arguments
    are forwarded unchanged), overlays the map on the image with
    ``show_cam_on_image``, scales the result by 255 into ``uint8`` and
    converts it with OpenCV's ``COLOR_RGB2BGR``.

    Returns:
        Tuple ``(vis, image_features, similarity)``: the converted overlay
        plus the features and similarity reported by ``vit_relevance``.
    """
    relevance_map, base_image, image_features, similarity = vit_relevance(
        image,
        target_features,
        img_encoder,
        device,
        method=method,
        neg_saliency=neg_saliency,
        img_dim=img_dim,
    )

    overlay = show_cam_on_image(base_image, relevance_map, neg_saliency=neg_saliency)
    # presumably the overlay is float-valued in [0, 1] -- TODO confirm in image_utils
    overlay_bytes = np.uint8(255 * overlay)
    overlay_bgr = cv2.cvtColor(np.array(overlay_bytes), cv2.COLOR_RGB2BGR)

    return overlay_bgr, image_features, similarity
# plt.imshow(vis)
def interpret_vit_overlapped(
    image, target_features, img_encoder, device, method="last grad", img_dim=224
):
    """Render positive and negative relevance maps of ``image`` in one overlay.

    ``vit_relevance`` is called twice with identical inputs: first with
    ``neg_saliency=False``, then with ``neg_saliency=True``. Both maps are
    combined by ``show_overlapped_cam``; the result is scaled by 255 into
    ``uint8`` and converted with OpenCV's ``COLOR_RGB2BGR``.

    Returns:
        Tuple ``(vis, image_features, similarity)``; the features and the
        similarity are those reported by the first (positive) call.
    """
    shared_options = {"method": method, "img_dim": img_dim}

    # Positive pass: supplies the features and similarity that are returned.
    positive_map, _, image_features, similarity = vit_relevance(
        image,
        target_features,
        img_encoder,
        device,
        neg_saliency=False,
        **shared_options,
    )
    # Negative pass: supplies the rescaled image used as the backdrop.
    negative_map, base_image, _, _ = vit_relevance(
        image,
        target_features,
        img_encoder,
        device,
        neg_saliency=True,
        **shared_options,
    )

    overlay = show_overlapped_cam(base_image, negative_map, positive_map)
    overlay_bytes = np.uint8(255 * overlay)
    overlay_bgr = cv2.cvtColor(np.array(overlay_bytes), cv2.COLOR_RGB2BGR)

    return overlay_bgr, image_features, similarity
# plt.imshow(vis)
def vit_perword_relevance(
    image_features,
    text,
    clip_model,
    clip_tokenizer,
    device,
    masked_word="",
    use_last_grad=True,
    img_dim=224,
):
    """Compute a relevance map for the contribution of one word of ``text``.

    Two prompts are encoded with ``clip_model.encode_text``: the full
    ``text`` and ``text`` with every match of ``masked_word`` removed. The
    objective is the dot product of the normalised ``image_features[0]`` with
    the difference of the two normalised text features. It is back-propagated
    and the ``attn_probs`` / ``attn_grad`` attributes of the children of
    ``clip_model.visual.transformer.resblocks`` are accumulated as
    ``R += cam @ R`` starting from the identity, where ``cam`` is each block's
    attention weighted by a gradient, clamped at zero and averaged over the
    leading axis. Row 0 / columns 1: is reshaped to a square grid, bilinearly
    upsampled to ``img_dim`` and min-max rescaled.

    NOTE(review): ``image_features`` must still be attached to the visual
    encoder's autograd graph for the backward pass to reach the attention
    layers; ``attn_probs`` / ``attn_grad`` are presumably recorded by those
    layers -- confirm against the model code.

    Args:
        image_features: Image features; only row 0 is used.
        text: Full prompt.
        clip_model: Model exposing ``encode_text`` and
            ``visual.transformer.resblocks``.
        clip_tokenizer: Callable turning a string into a tensor of tokens.
        device: Device for the tokens and the identity matrix.
        masked_word: Passed to ``re.sub`` as the *pattern*, so regex
            metacharacters are interpreted and substring matches are removed
            too. With the default ``""`` both prompts are identical and the
            objective is zero.
        use_last_grad: Weight every block by the last block's gradient when
            true, otherwise by the block's own gradient.
        img_dim: Side length of the returned map.

    Returns:
        NumPy array of shape ``(img_dim, img_dim)`` rescaled to [0, 1]; all
        zeros when the raw map is constant.
    """
    clip_model.eval()

    main_text = clip_tokenizer(text).to(device)
    # remove the word for which you want to visualize the saliency
    masked_text = re.sub(masked_word, "", text)
    masked_text = clip_tokenizer(masked_text).to(device)

    main_text_features = clip_model.encode_text(main_text)
    masked_text_features = clip_model.encode_text(masked_text)

    image_features_new = image_features / image_features.norm(dim=-1, keepdim=True)
    main_text_features_new = main_text_features / main_text_features.norm(
        dim=-1, keepdim=True
    )
    masked_text_features_new = masked_text_features / masked_text_features.norm(
        dim=-1, keepdim=True
    )

    objective = image_features_new[0].dot(
        main_text_features_new[0] - masked_text_features_new[0]
    )

    clip_model.visual.zero_grad()
    objective.backward(retain_graph=True)

    image_attn_blocks = list(clip_model.visual.transformer.resblocks.children())
    first_attn = image_attn_blocks[0].attn_probs
    num_tokens = first_attn.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=first_attn.dtype, device=device)

    last_grad = image_attn_blocks[-1].attn_grad.detach()
    last_grad = last_grad.reshape(-1, last_grad.shape[-1], last_grad.shape[-1])

    for blk in image_attn_blocks:
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        if use_last_grad:
            grad = last_grad
        else:
            # Reshape like ``cam`` so the product stays 3-D even when the
            # recorded gradient carries extra leading axes.
            grad = blk.attn_grad.detach()
            grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = (grad * cam).clamp(min=0).mean(dim=0)
        R += torch.matmul(cam, R)

    image_relevance = R[0, 1:]
    resize_dim = int(np.sqrt(image_relevance.shape[0]))
    image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=img_dim, mode="bilinear"
    )
    image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()

    lowest = image_relevance.min()
    value_range = image_relevance.max() - lowest
    if value_range == 0:
        # Constant map (e.g. zero objective with the default masked_word):
        # dividing would turn the whole map into NaN.
        return np.zeros_like(image_relevance)
    return (image_relevance - lowest) / value_range