import numpy as np import cv2 from matplotlib import pyplot as plt def get_mpl_colormap(cmap_name): cmap = plt.get_cmap(cmap_name) # Initialize the matplotlib color map sm = plt.cm.ScalarMappable(cmap=cmap) # Obtain linear color range color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1] return color_range.reshape(256, 1, 3) def show_cam_on_image(img, mask, neg_saliency=False): heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) heatmap = np.float32(heatmap) / 255 cam = heatmap + np.float32(img) cam = cam / np.max(cam) return cam def show_overlapped_cam(img, neg_mask, pos_mask): # neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), cv2.COLORMAP_RAINBOW) # pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), cv2.COLORMAP_JET) neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), get_mpl_colormap("Blues")) pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), get_mpl_colormap("Reds")) neg_heatmap = np.float32(neg_heatmap) / 255 pos_heatmap = np.float32(pos_heatmap) / 255 # try different options: sum, average, ... heatmap = neg_heatmap + pos_heatmap cam = heatmap + np.float32(img) cam = cam / np.max(cam) return cam