ma-images / pytorch_grad_cam /activations_and_gradients.py
broadwell's picture
Files from pytorch_grad_cam for legacy ResNet activations viz
d4b83d9 verified
raw
history blame
No virus
1.73 kB
class ActivationsAndGradients:
    """Record activations and gradients of selected intermediate layers.

    A forward hook is registered on every module in ``target_layers``. On
    each forward pass it stores a detached CPU copy of the layer output in
    ``self.activations`` and, when that output requires grad, attaches a
    tensor hook that stores the gradient with respect to that output in
    ``self.gradients`` during the backward pass.

    Args:
        model: Callable invoked by ``__call__`` (presumably a
            ``torch.nn.Module``).
        target_layers: Iterable of modules to hook; each must provide
            ``register_forward_hook``.
        reshape_transform: Optional callable applied to every activation
            and gradient before it is stored, or ``None``.
    """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Gradients are captured through a tensor hook on the layer output
            # rather than a module backward hook: a full backward hook forbids
            # in-place modification of the module output, so a layer followed
            # by e.g. ReLU(inplace=True) raises a RuntimeError, and the legacy
            # register_backward_hook is deprecated. Tensor.register_hook is
            # available on all PyTorch versions, so no fallback is needed.
            self.handles.append(
                target_layer.register_forward_hook(self.save_gradient))

    def _snapshot(self, tensor):
        """Return a detached CPU copy of ``tensor`` after reshape_transform."""
        if self.reshape_transform is not None:
            tensor = self.reshape_transform(tensor)
        # copy=True matters for tensors already on the CPU: ``.cpu()`` would
        # return the same storage, and a later in-place op (e.g. an in-place
        # ReLU right after the target layer) would alter the stored value.
        return tensor.detach().to("cpu", copy=True)

    def save_activation(self, module, input, output):
        """Forward hook: append a snapshot of the layer output to activations."""
        self.activations.append(self._snapshot(output))

    def save_gradient(self, module, input, output):
        """Forward hook: arrange for the gradient w.r.t. ``output`` to be stored.

        Does nothing when ``output`` is not a tensor that requires grad,
        because a tensor hook can only be registered on such a tensor.
        """
        if not getattr(output, "requires_grad", False):
            return

        def _store_grad(grad):
            # Backward visits layers in reverse forward order; prepending
            # keeps gradients[i] aligned with activations[i].
            self.gradients = [self._snapshot(grad)] + self.gradients

        output.register_hook(_store_grad)

    def __call__(self, x):
        """Clear previously recorded tensors and run ``self.model`` on ``x``."""
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        """Remove every forward hook registered by this instance."""
        for handle in self.handles:
            handle.remove()
        self.handles = []