File size: 1,280 Bytes
a6b26e3
 
254b186
 
 
 
 
 
 
 
 
 
 
 
 
 
a6b26e3
 
 
254b186
a6b26e3
 
 
 
 
254b186
a6b26e3
254b186
 
 
 
a6b26e3
 
 
 
 
 
254b186
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import cv2
from matplotlib import pyplot as plt


def get_mpl_colormap(cmap_name):
    """Return a matplotlib colormap as a 256-entry BGR lookup table for OpenCV.

    The colormap named ``cmap_name`` is sampled at 256 evenly spaced points
    in [0, 1] and converted to 8-bit RGBA. The alpha channel is dropped, the
    channels are reordered to BGR, and the result is reshaped to
    ``(256, 1, 3)`` so it can be passed to ``cv2.applyColorMap`` as a
    user-supplied colormap.

    Args:
        cmap_name: Name of a registered matplotlib colormap (e.g. "Reds").

    Returns:
        Array of shape (256, 1, 3) holding the BGR colour for each of the
        256 input intensities.
    """
    mappable = plt.cm.ScalarMappable(cmap=plt.get_cmap(cmap_name))
    rgba_lut = mappable.to_rgba(np.linspace(0, 1, 256), bytes=True)
    # Channels 2, 1, 0: discards alpha and flips RGB into OpenCV's BGR order.
    bgr_lut = rgba_lut[:, 2::-1]
    return bgr_lut.reshape(256, 1, 3)


def show_cam_on_image(img, mask, neg_saliency=False):
    """Overlay a JET-coloured saliency mask on an image.

    The mask is quantised to uint8, colour-mapped with OpenCV's JET map,
    scaled back to [0, 1] and added to the image. The sum is then divided by
    its maximum so the brightest value becomes 1.

    Args:
        img: Image added element-wise to the heatmap after conversion to
            float32. NOTE(review): presumably a 3-channel BGR image in
            [0, 1] with the same height/width as ``mask``; confirm with callers.
        mask: Saliency map. It is multiplied by 255 and cast to uint8, so
            values are expected in [0, 1]. Values outside that range are
            clipped to [0, 1].
        neg_saliency: Currently unused; kept so existing callers that pass
            it keep working.

    Returns:
        float32 array equal to ``heatmap + img`` divided by its maximum.
    """
    # Clip before quantising: casting out-of-range floats to uint8 is
    # undefined in NumPy and typically wraps, e.g. 1.02 * 255 ~= 260 would
    # become 4 and be drawn as a cold colour instead of the hottest one.
    # In-range values are unaffected.
    mask = np.clip(np.asarray(mask), 0, 1)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)

    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    # Rescale so the maximum value of the overlay is 1.
    cam = cam / np.max(cam)
    return cam


def show_overlapped_cam(img, neg_mask, pos_mask):
    """Overlay negative and positive saliency masks on an image.

    ``neg_mask`` is colour-mapped with matplotlib's "Blues" colormap and
    ``pos_mask`` with "Reds" (both converted via ``get_mpl_colormap``). The
    two heatmaps are scaled to [0, 1], summed and added to the image. The
    result is divided by its maximum so the brightest value becomes 1.

    Args:
        img: Image added element-wise to the combined heatmap after
            conversion to float32. NOTE(review): presumably a 3-channel BGR
            image in [0, 1] with the same height/width as the masks; confirm.
        neg_mask: Negative saliency map. It is multiplied by 255 and cast to
            uint8, so values are expected in [0, 1]. Values outside that
            range are clipped.
        pos_mask: Positive saliency map, same convention as ``neg_mask``.

    Returns:
        float32 array equal to ``neg_heatmap + pos_heatmap + img`` divided
        by its maximum.
    """
    # Clip before quantising: out-of-range float -> uint8 casts are
    # undefined in NumPy and typically wrap, which would paint hot regions
    # with cold colours. In-range values are unaffected.
    neg_mask = np.clip(np.asarray(neg_mask), 0, 1)
    pos_mask = np.clip(np.asarray(pos_mask), 0, 1)

    neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), get_mpl_colormap("Blues"))
    pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), get_mpl_colormap("Reds"))
    neg_heatmap = np.float32(neg_heatmap) / 255
    pos_heatmap = np.float32(pos_heatmap) / 255

    # The heatmaps are combined by plain summation; other blends (e.g. an
    # average) are possible alternatives left open by the original author.
    heatmap = neg_heatmap + pos_heatmap
    cam = heatmap + np.float32(img)
    # Rescale so the maximum value of the overlay is 1.
    cam = cam / np.max(cam)
    return cam