import numpy as np


def get_2d_projection(activation_batch):
    """Project each item's activations onto their first principal component.

    For every item in the batch, each item is reshaped to
    ``(shape[0], -1)`` and transposed. Spatial positions therefore become
    samples and the leading per-item axis becomes the features (presumably
    channels). The centered samples are then projected onto the top
    right-singular vector, giving one value per spatial position.

    Args:
        activation_batch: Array-like indexed as ``(batch, features, *spatial)``.
            NaN entries are treated as 0. The caller's array is NOT modified.

    Returns:
        ``float32`` array of shape ``(batch, *spatial)``.

    Note:
        The sign of a singular vector is arbitrary, so the overall sign of
        each projection is not guaranteed by the SVD itself.
    """
    # TODO: use a batched (e.g. PyTorch) SVD implementation.
    # np.array copies by default, so the NaN replacement below never mutates
    # the caller's data.
    activation_batch = np.array(activation_batch)
    activation_batch[np.isnan(activation_batch)] = 0

    projections = []
    for activations in activation_batch:
        # (features, *spatial) -> (num_positions, features)
        samples = activations.reshape(activations.shape[0], -1).T
        # Centering before the SVD seems to be important here;
        # otherwise the image returned is negative.
        samples = samples - samples.mean(axis=0)
        # Only the first right-singular vector is needed. The reduced SVD
        # gives it without also building a (num_positions x num_positions) U.
        _, _, vt = np.linalg.svd(samples, full_matrices=False)
        projection = samples @ vt[0, :]
        projections.append(projection.reshape(activations.shape[1:]))
    return np.float32(projections)