broadwell committed on
Commit d4b83d9
Parent: 526dc26

Files from pytorch_grad_cam for legacy ResNet activations viz

pytorch_grad_cam/activations_and_gradients.py ADDED
@@ -0,0 +1,45 @@
+ class ActivationsAndGradients:
+     """ Class for extracting activations and
+     registering gradients from targeted intermediate layers """
+
+     def __init__(self, model, target_layers, reshape_transform):
+         self.model = model
+         self.gradients = []
+         self.activations = []
+         self.reshape_transform = reshape_transform
+         self.handles = []
+         for target_layer in target_layers:
+             self.handles.append(
+                 target_layer.register_forward_hook(
+                     self.save_activation))
+             # Backward compatibility with older pytorch versions:
+             if hasattr(target_layer, 'register_full_backward_hook'):
+                 self.handles.append(
+                     target_layer.register_full_backward_hook(
+                         self.save_gradient))
+             else:
+                 self.handles.append(
+                     target_layer.register_backward_hook(
+                         self.save_gradient))
+
+     def save_activation(self, module, input, output):
+         activation = output
+         if self.reshape_transform is not None:
+             activation = self.reshape_transform(activation)
+         self.activations.append(activation.cpu().detach())
+
+     def save_gradient(self, module, grad_input, grad_output):
+         # Gradients are computed in reverse order
+         grad = grad_output[0]
+         if self.reshape_transform is not None:
+             grad = self.reshape_transform(grad)
+         self.gradients = [grad.cpu().detach()] + self.gradients
+
+     def __call__(self, x):
+         self.gradients = []
+         self.activations = []
+         return self.model(x)
+
+     def release(self):
+         for handle in self.handles:
+             handle.remove()
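For orientation, here is a minimal sketch of how ActivationsAndGradients can be driven on its own. The torchvision resnet18 and the choice of layer4 as the hooked layer are illustrative assumptions, not part of this commit:

import torch
from torchvision.models import resnet18  # illustrative model choice
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients

model = resnet18().eval()  # random weights are fine for a shape check
# Hook the last convolutional block, a typical Grad-CAM target.
wrapper = ActivationsAndGradients(model, [model.layer4], reshape_transform=None)

x = torch.randn(1, 3, 224, 224)
output = wrapper(x)                    # forward pass fills wrapper.activations
output[0, output.argmax()].backward()  # backward pass fills wrapper.gradients

print(wrapper.activations[0].shape)    # torch.Size([1, 512, 7, 7])
print(wrapper.gradients[0].shape)      # same shape as the activations
wrapper.release()                      # remove the hooks when done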
pytorch_grad_cam/base_cam.py ADDED
@@ -0,0 +1,222 @@
+ import cv2
+ import numpy as np
+ import torch
+ import ttach as tta
+ from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
+ from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+
+
+ class BaseCAM:
+     def __init__(self,
+                  model,
+                  target_layers,
+                  use_cuda=False,
+                  reshape_transform=None,
+                  compute_input_gradient=False,
+                  uses_gradients=True):
+         self.model = model.eval()
+         self.target_layers = target_layers
+         self.cuda = use_cuda
+         if self.cuda:
+             self.model = model.cuda()
+         self.reshape_transform = reshape_transform
+         self.compute_input_gradient = compute_input_gradient
+         self.uses_gradients = uses_gradients
+         self.activations_and_grads = ActivationsAndGradients(
+             self.model, target_layers, reshape_transform)
+
+     def get_cam_weights(self,
+                         input_tensor,
+                         target_layers,
+                         target_category,
+                         activations,
+                         grads):
+         """ Get a vector of weights for every channel in the target layer.
+         Methods that return channel weights will typically only need
+         to implement this function. """
+         raise NotImplementedError(
+             "get_cam_weights must be implemented by subclasses")
+
+     def get_objective(self, input_encoding, target_encoding):
+         # Normalize both encodings, then take their dot product:
+         # the cosine similarity between input and target.
+         input_encoding_norm = input_encoding.norm(dim=-1, keepdim=True)
+         input_encoding = input_encoding / input_encoding_norm
+
+         target_encoding_norm = target_encoding.norm(dim=-1, keepdim=True)
+         target_encoding = target_encoding / target_encoding_norm
+
+         return input_encoding[0].dot(target_encoding[0])
+
+     def get_cam_image(self,
+                       input_tensor,
+                       target_layer,
+                       target_category,
+                       activations,
+                       grads,
+                       eigen_smooth=False):
+         weights = self.get_cam_weights(input_tensor, target_layer,
+                                        target_category, activations, grads)
+         weighted_activations = weights[:, :, None, None] * activations
+         if eigen_smooth:
+             cam = get_2d_projection(weighted_activations)
+         else:
+             cam = weighted_activations.sum(axis=1)
+         return cam
+
+     def forward(self, input_tensor, target_encoding, target_category=None, eigen_smooth=False):
+         if self.cuda:
+             input_tensor = input_tensor.cuda()
+
+         if self.compute_input_gradient:
+             input_tensor = torch.autograd.Variable(input_tensor,
+                                                    requires_grad=True)
+         # The model output is the image encoding
+         output = self.activations_and_grads(input_tensor)
+         if isinstance(target_category, int):
+             target_category = [target_category] * input_tensor.size(0)
+
+         if target_category is None:
+             target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
+         else:
+             assert len(target_category) == input_tensor.size(0)
+
+         if self.uses_gradients:
+             self.model.zero_grad()
+             # Inlined from get_objective: cosine similarity between the
+             # image encoding and the target encoding.
+             output_norm = output.norm(dim=-1, keepdim=True)
+             output = output / output_norm
+
+             target_encoding_norm = target_encoding.norm(dim=-1, keepdim=True)
+             target_encoding = target_encoding / target_encoding_norm
+
+             objective = output[0].dot(target_encoding[0])
+             objective.backward(retain_graph=True)
+
+         # In most of the saliency attribution papers, the saliency is
+         # computed with a single target layer.
+         # Commonly it is the last convolutional layer.
+         # Here we support passing a list with multiple target layers.
+         # It will compute the saliency image for every layer,
+         # and then aggregate them (with a default mean aggregation).
+         # This gives you more flexibility in case you want to use,
+         # for example, all conv layers or all batchnorm layers.
+         cam_per_layer = self.compute_cam_per_layer(input_tensor,
+                                                    target_category,
+                                                    eigen_smooth)
+
+         # Aggregation is muted; the per-layer CAMs are returned directly.
+         # return self.aggregate_multi_layers(cam_per_layer)
+         return cam_per_layer
+
+     def get_target_width_height(self, input_tensor):
+         width, height = input_tensor.size(-1), input_tensor.size(-2)
+         return width, height
+
+     def compute_cam_per_layer(
+             self,
+             input_tensor,
+             target_category,
+             eigen_smooth):
+         activations_list = [a.cpu().data.numpy()
+                             for a in self.activations_and_grads.activations]
+         grads_list = [g.cpu().data.numpy()
+                       for g in self.activations_and_grads.gradients]
+         target_size = self.get_target_width_height(input_tensor)
+
+         cam_per_target_layer = []
+         # Loop over the saliency image from every layer
+         for target_layer, layer_activations, layer_grads in \
+                 zip(self.target_layers, activations_list, grads_list):
+             cam = self.get_cam_image(input_tensor,
+                                      target_layer,
+                                      target_category,
+                                      layer_activations,
+                                      layer_grads,
+                                      eigen_smooth)
+             # ReLU; the min-max scaling normally done in scale_cam_image
+             # is intentionally skipped here.
+             cam = np.maximum(cam, 0)
+             scaled = cam  # muted: self.scale_cam_image(cam, target_size)
+             cam_per_target_layer.append(scaled[:, None, :])
+
+         return cam_per_target_layer
+
+     def aggregate_multi_layers(self, cam_per_target_layer):
+         cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
+         cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
+         result = np.mean(cam_per_target_layer, axis=1)
+         return self.scale_cam_image(result)
+
+     def scale_cam_image(self, cam, target_size=None):
+         result = []
+         for img in cam:
+             img = img - np.min(img)
+             img = img / (1e-7 + np.max(img))
+             img = np.float32(img)
+             if target_size is not None:
+                 img = cv2.resize(img, target_size)
+             result.append(img)
+         result = np.float32(result)
+
+         return result
+
+     def forward_augmentation_smoothing(self,
+                                        input_tensor,
+                                        target_encoding,
+                                        target_category=None,
+                                        eigen_smooth=False):
+         transforms = tta.Compose(
+             [
+                 tta.HorizontalFlip(),
+                 tta.Multiply(factors=[0.9, 1, 1.1]),
+             ]
+         )
+         cams = []
+         for transform in transforms:
+             augmented_tensor = transform.augment_image(input_tensor)
+             # Note: forward() now returns per-layer CAMs (aggregation is
+             # muted above), while the smoothing below assumes an
+             # aggregated BxHxW array.
+             cam = self.forward(augmented_tensor, target_encoding,
+                                target_category, eigen_smooth)
+
+             # The ttach library expects a tensor of size BxCxHxW
+             cam = cam[:, None, :, :]
+             cam = torch.from_numpy(cam)
+             cam = transform.deaugment_mask(cam)
+
+             # Back to numpy float32, HxW
+             cam = cam.numpy()
+             cam = cam[:, 0, :, :]
+             cams.append(cam)
+
+         cam = np.mean(np.float32(cams), axis=0)
+         return cam
+
+     def __call__(self,
+                  input_tensor,
+                  target_encoding,
+                  target_category=None,
+                  aug_smooth=False,
+                  eigen_smooth=False):
+         # Smooth the CAM result with test time augmentation
+         if aug_smooth is True:
+             return self.forward_augmentation_smoothing(
+                 input_tensor, target_encoding, target_category, eigen_smooth)
+
+         return self.forward(input_tensor, target_encoding,
+                             target_category, eigen_smooth)
+
+     def __del__(self):
+         self.activations_and_grads.release()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, exc_tb):
+         self.activations_and_grads.release()
+         if isinstance(exc_value, IndexError):
+             # Suppress IndexError raised inside the `with` block
+             print(
+                 f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
+             return True
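The gradient objective in forward above is the cosine similarity between the image encoding and a target encoding. A standalone sketch of just that computation, with the 512-dimensional shapes assumed for illustration:

import torch

image_encoding = torch.randn(1, 512, requires_grad=True)  # assumed encoder output
target_encoding = torch.randn(1, 512)                     # assumed target embedding

# L2-normalize both encodings; their dot product is then the
# cosine similarity between the two vectors.
image_norm = image_encoding / image_encoding.norm(dim=-1, keepdim=True)
target_norm = target_encoding / target_encoding.norm(dim=-1, keepdim=True)
objective = image_norm[0].dot(target_norm[0])

objective.backward()  # gradients flow back toward the hooked layers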
pytorch_grad_cam/grad_cam.py ADDED
@@ -0,0 +1,25 @@
+ import numpy as np
+ from pytorch_grad_cam.base_cam import BaseCAM
+
+
+ class GradCAM(BaseCAM):
+     def __init__(self, model, target_layers, use_cuda=False,
+                  reshape_transform=None):
+         super(
+             GradCAM,
+             self).__init__(
+             model,
+             target_layers,
+             use_cuda,
+             reshape_transform)
+
+     def get_cam_weights(self,
+                         input_tensor,
+                         target_layer,
+                         target_category,
+                         activations,
+                         grads):
+         # Grad-CAM channel weights: global-average-pool the gradients
+         # over the spatial dimensions.
+         res = np.mean(grads, axis=(2, 3))
+         return res
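Putting the pieces together: a hedged end-to-end sketch of how this GradCAM variant might be called. The stand-in encoder and the 512-dimensional target encoding are assumptions for illustration; the commit itself does not define an encoder:

import torch
import torch.nn as nn
from pytorch_grad_cam.grad_cam import GradCAM

# Hypothetical stand-in encoder: a conv trunk plus a linear head that
# outputs a 512-d embedding, mimicking an image-encoder interface.
encoder = nn.Sequential(
    nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 512))

cam = GradCAM(model=encoder, target_layers=[encoder[0]])

image = torch.randn(1, 3, 224, 224)    # dummy input batch
target_encoding = torch.randn(1, 512)  # e.g. a text embedding to match

# forward() returns one CAM per target layer, since aggregation is muted.
cam_per_layer = cam(image, target_encoding)
print(cam_per_layer[0].shape)  # (1, 1, 112, 112) for this stand-in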
pytorch_grad_cam/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from pytorch_grad_cam.utils.image import deprocess_image
+ from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
pytorch_grad_cam/utils/find_layers.py ADDED
@@ -0,0 +1,30 @@
+ def replace_layer_recursive(model, old_layer, new_layer):
+     for name, layer in model._modules.items():
+         if layer == old_layer:
+             model._modules[name] = new_layer
+             return True
+         elif replace_layer_recursive(layer, old_layer, new_layer):
+             return True
+     return False
+
+
+ def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
+     for name, layer in model._modules.items():
+         if isinstance(layer, old_layer_type):
+             model._modules[name] = new_layer
+         replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
+
+
+ def find_layer_types_recursive(model, layer_types):
+     def predicate(layer):
+         return type(layer) in layer_types
+     return find_layer_predicate_recursive(model, predicate)
+
+
+ def find_layer_predicate_recursive(model, predicate):
+     result = []
+     for name, layer in model._modules.items():
+         if predicate(layer):
+             result.append(layer)
+         result.extend(find_layer_predicate_recursive(layer, predicate))
+     return result
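A quick sketch of how these helpers can collect candidate target layers automatically; the torchvision resnet18 is an illustrative assumption:

import torch.nn as nn
from torchvision.models import resnet18
from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive

model = resnet18()
# Collect every Conv2d in the network, e.g. to pass them all to a CAM
# object as target_layers.
conv_layers = find_layer_types_recursive(model, [nn.Conv2d])
print(len(conv_layers))  # 20 for resnet18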
pytorch_grad_cam/utils/image.py ADDED
@@ -0,0 +1,49 @@
+ import cv2
+ import numpy as np
+ import torch
+ from torchvision.transforms import Compose, Normalize, ToTensor
+
+
+ def preprocess_image(img: np.ndarray, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> torch.Tensor:
+     preprocessing = Compose([
+         ToTensor(),
+         Normalize(mean=mean, std=std)
+     ])
+     return preprocessing(img.copy()).unsqueeze(0)
+
+
+ def deprocess_image(img):
+     """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """
+     img = img - np.mean(img)
+     img = img / (np.std(img) + 1e-5)
+     img = img * 0.1
+     img = img + 0.5
+     img = np.clip(img, 0, 1)
+     return np.uint8(img * 255)
+
+
+ def show_cam_on_image(img: np.ndarray,
+                       mask: np.ndarray,
+                       use_rgb: bool = False,
+                       colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
+     """ This function overlays the cam mask on the image as a heatmap.
+     By default the heatmap is in BGR format.
+
+     :param img: The base image in RGB or BGR format.
+     :param mask: The cam mask.
+     :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format.
+     :param colormap: The OpenCV colormap to be used.
+     :returns: The default image with the cam overlay.
+     """
+     heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+     if use_rgb:
+         heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+     heatmap = np.float32(heatmap) / 255
+
+     if np.max(img) > 1:
+         raise Exception(
+             "The input image should be np.float32 in the range [0, 1]")
+
+     cam = heatmap + img
+     cam = cam / np.max(cam)
+     return np.uint8(255 * cam)
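These helpers bracket the CAM computation: preprocess_image builds the network input and show_cam_on_image renders the result. A minimal round-trip sketch with synthetic image data (the ImageNet normalization constants are an assumed choice):

import numpy as np
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image

# Synthetic float32 RGB image in [0, 1], as show_cam_on_image requires.
rgb_img = np.random.rand(224, 224, 3).astype(np.float32)
input_tensor = preprocess_image(rgb_img,
                                mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

# Stand-in grayscale CAM in [0, 1] with the same spatial size.
grayscale_cam = np.random.rand(224, 224).astype(np.float32)
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
print(visualization.shape, visualization.dtype)  # (224, 224, 3) uint8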
pytorch_grad_cam/utils/svd_on_activations.py ADDED
@@ -0,0 +1,19 @@
+ import numpy as np
+
+
+ def get_2d_projection(activation_batch):
+     # TBD: use pytorch batch svd implementation
+     activation_batch[np.isnan(activation_batch)] = 0
+     projections = []
+     for activations in activation_batch:
+         reshaped_activations = (activations).reshape(
+             activations.shape[0], -1).transpose()
+         # Centering before the SVD seems to be important here;
+         # otherwise the image returned is negative.
+         reshaped_activations = reshaped_activations - \
+             reshaped_activations.mean(axis=0)
+         U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True)
+         projection = reshaped_activations @ VT[0, :]
+         projection = projection.reshape(activations.shape[1:])
+         projections.append(projection)
+     return np.float32(projections)
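get_2d_projection is what the eigen_smooth option calls: each activation map is flattened to an (H*W, C) matrix, centered, and projected onto its first right singular vector, keeping the dominant spatial pattern. A quick sketch on random data, with shapes assumed for illustration:

import numpy as np
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection

# Assumed batch of weighted activations: (batch, channels, height, width).
weighted_activations = np.random.randn(1, 64, 7, 7).astype(np.float32)
projection = get_2d_projection(weighted_activations)
print(projection.shape)  # (1, 7, 7): one 2D saliency map per batch item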