ma-images / pytorch_grad_cam /utils /svd_on_activations.py
broadwell's picture
Files from pytorch_grad_cam for legacy ResNet activations viz
d4b83d9 verified
raw
history blame
No virus
810 Bytes
import numpy as np
def get_2d_projection(activation_batch):
    """Project each sample's feature maps onto its first principal direction.

    For every sample, the feature axis (axis 1) is collapsed to a single value
    per spatial position.  Steps:
    1. Treat each spatial position as a row of feature values.
    2. Center each feature over the positions.
    3. Project the rows onto the first right singular vector of that matrix.

    Args:
        activation_batch: array of shape ``(N, C, *spatial)``.  NaN entries are
            treated as 0.  The input array is not modified.

    Returns:
        ``np.float32`` array of shape ``(N, *spatial)``.

    NOTE(review): singular vectors are only defined up to sign, so the sign of
    each returned map is whatever LAPACK produces and is not guaranteed to be
    "positive"; callers needing a fixed orientation should normalize it.
    """
    batch = np.asarray(activation_batch)
    # Zero NaNs on a copy so the caller's array is left untouched.
    batch = np.where(np.isnan(batch), 0, batch)

    num, channels = batch.shape[:2]
    spatial_shape = batch.shape[2:]
    if num == 0:
        # Keep the output shape consistent with the non-empty case.
        return np.zeros((0,) + spatial_shape, dtype=np.float32)

    # (N, C, *spatial) -> (N, P, C): one row per spatial position.
    flat = batch.reshape(num, channels, -1).transpose(0, 2, 1)
    # Center each feature over the spatial positions of its own sample.
    flat = flat - flat.mean(axis=1, keepdims=True)

    # Reduced, batched SVD: only the first right singular vector is needed,
    # so the (P x P) U matrix of a full SVD would be wasted memory.
    _, _, vt = np.linalg.svd(flat, full_matrices=False)
    projections = np.einsum("npc,nc->np", flat, vt[:, 0, :])
    return np.float32(projections.reshape((num,) + spatial_shape))