import numpy as np import matplotlib.pyplot as plt import matplotlib.image as mpimg import cv2 import skimage import torch from PIL import Image joints = [ 'left ankle', 'left knee', 'left hip', 'right hip', 'right knee', 'right ankle', 'belly', 'chest', 'neck', 'head', 'left wrist', 'left elbow', 'left shoulder', 'right shoulder', 'right elbow', 'right wrist' ] def generate_heatmap(heatmap, pt, sigma=(33, 33), sigma_valu=7): ''' :param heatmap: should be a np zeros array with shape (H,W) (only i channel), not (H,W,1) :param pt: point coords, np array :param sigma: should be a tuple with odd values (obsolete) :param sigma_valu: vaalue for gaussian blur :return: a np array of one joint heatmap with shape (H,W) This function is obsolete, use 'generate_heatmaps()' instead. ''' heatmap[int(pt[1])][int(pt[0])] = 1 # heatmap = cv2.GaussianBlur(heatmap, sigma, 0) #(H,W,1) -> (H,W) heatmap = skimage.filters.gaussian( heatmap, sigma=sigma_valu) # (H,W,1) -> (H,W) am = np.amax(heatmap) heatmap = heatmap/am return heatmap def generate_heatmaps(img, pts, sigma=(33, 33), sigma_valu=7): ''' :param img: np arrray img, (H,W,C) :param pts: joint points coords, np array, same resolu as img :param sigma: should be a tuple with odd values (obsolete) :param sigma_valu: vaalue for gaussian blur :return: np array heatmaps, (H,W,num_pts) ''' H, W = img.shape[0], img.shape[1] num_pts = pts.shape[0] heatmaps = np.zeros((H, W, num_pts)) for i, pt in enumerate(pts): # Filter unavailable heatmaps if pt[0] == 0 and pt[1] == 0: continue # Filter some points out of the image if pt[0] >= W: pt[0] = W-1 if pt[1] >= H: pt[1] = H-1 heatmap = heatmaps[:, :, i] heatmap[int(pt[1])][int(pt[0])] = 1 # heatmap = cv2.GaussianBlur(heatmap, sigma, 0) #(H,W,1) -> (H,W) heatmap = skimage.filters.gaussian( heatmap, sigma=sigma_valu) # (H,W,1) -> (H,W) am = np.amax(heatmap) heatmap = heatmap / am heatmaps[:, :, i] = heatmap return heatmaps def load_image(path_image): img = mpimg.imread(path_image) # Return a np array (H,W,C) return img def crop(img, ele_anno, use_randscale=True, use_randflipLR=False, use_randcolor=False): ''' :param img: np array of the origin image, (H,W,C) :param ele_anno: one element of json annotation :return: img_crop, ary_pts_crop, c_crop after cropping ''' H, W = img.shape[0], img.shape[1] s = ele_anno['scale_provided'] c = ele_anno['objpos'] # Adjust center and scale if c[0] != -1: c[1] = c[1] + 15 * s s = s * 1.25 ary_pts = np.array(ele_anno['joint_self']) # (16, 3) ary_pts_temp = ary_pts[np.any(ary_pts != [0, 0, 0], axis=1)] if use_randscale: scale_rand = np.random.uniform(low=1.0, high=3.0) else: scale_rand = 1 W_min = max(np.amin(ary_pts_temp, axis=0)[0] - s * 15 * scale_rand, 0) H_min = max(np.amin(ary_pts_temp, axis=0)[1] - s * 15 * scale_rand, 0) W_max = min(np.amax(ary_pts_temp, axis=0)[0] + s * 15 * scale_rand, W) H_max = min(np.amax(ary_pts_temp, axis=0)[1] + s * 15 * scale_rand, H) W_len = W_max - W_min H_len = H_max - H_min window_len = max(H_len, W_len) pad_updown = (window_len - H_len)/2 pad_leftright = (window_len - W_len)/2 # Calculate 4 corner position W_low = max((W_min - pad_leftright), 0) W_high = min((W_max + pad_leftright), W) H_low = max((H_min - pad_updown), 0) H_high = min((H_max + pad_updown), H) # Update joint points and center ary_pts_crop = np.where( ary_pts == [0, 0, 0], ary_pts, ary_pts - np.array([W_low, H_low, 0])) c_crop = c - np.array([W_low, H_low]) img_crop = img[int(H_low):int(H_high), int(W_low):int(W_high), :] # Pad when H, W different H_new, W_new = img_crop.shape[0], img_crop.shape[1] window_len_new = max(H_new, W_new) pad_updown_new = int((window_len_new - H_new)/2) pad_leftright_new = int((window_len_new - W_new)/2) # ReUpdate joint points and center (because of the padding) ary_pts_crop = np.where(ary_pts_crop == [ 0, 0, 0], ary_pts_crop, ary_pts_crop + np.array([pad_leftright_new, pad_updown_new, 0])) c_crop = c_crop + np.array([pad_leftright_new, pad_updown_new]) img_crop = cv2.copyMakeBorder(img_crop, pad_updown_new, pad_updown_new, pad_leftright_new, pad_leftright_new, cv2.BORDER_CONSTANT, value=0) # change dtype and num scale img_crop = img_crop / 255. img_crop = img_crop.astype(np.float64) if use_randflipLR: flip = np.random.random() > 0.5 # print('rand_flipLR', flip) if flip: # (H,W,C) img_crop = np.flip(img_crop, 1) # Calculate flip pts, remember to filter [0,0] which is no available heatmap ary_pts_crop = np.where(ary_pts_crop == [0, 0, 0], ary_pts_crop, [window_len_new, 0, 0] + ary_pts_crop * [-1, 1, 0]) c_crop = [window_len_new, 0] + c_crop * [-1, 1] # Rearrange pts ary_pts_crop = np.concatenate( (ary_pts_crop[5::-1], ary_pts_crop[6:10], ary_pts_crop[15:9:-1])) if use_randcolor: randcolor = np.random.random() > 0.5 # print('rand_color', randcolor) if randcolor: img_crop[..., 0] *= np.clip(np.random.uniform(low=0.8, high=1.2), 0., 1.) img_crop[..., 1] *= np.clip(np.random.uniform(low=0.8, high=1.2), 0., 1.) img_crop[..., 2] *= np.clip(np.random.uniform(low=0.8, high=1.2), 0., 1.) return img_crop, ary_pts_crop, c_crop def change_resolu(img, pts, c, resolu_out=(256, 256)): ''' :param img: np array of the origin image :param pts: joint points np array corresponding to the image, same resolu as img :param c: center :param resolu_out: a list or tuple :return: img_out, pts_out, c_out under resolu_out ''' H_in = img.shape[0] W_in = img.shape[1] H_out = resolu_out[0] W_out = resolu_out[1] H_scale = H_in/H_out W_scale = W_in/W_out pts_out = pts/np.array([W_scale, H_scale, 1]) c_out = c/np.array([W_scale, H_scale]) img_out = skimage.transform.resize(img, tuple(resolu_out)) return img_out, pts_out, c_out def heatmaps_to_coords(heatmaps, resolu_out=[64, 64], prob_threshold=0.2): ''' :param heatmaps: tensor with shape (64,64,16) :param resolu_out: output resolution list :return coord_joints: np array, shape (16,2) ''' num_joints = heatmaps.shape[2] # Resize heatmaps = skimage.transform.resize(heatmaps, tuple(resolu_out)) coord_joints = np.zeros((num_joints, 3)) for i in range(num_joints): heatmap = heatmaps[..., i] max = np.max(heatmap) # Only keep points larger than a threshold if max >= prob_threshold: idx = np.where(heatmap == max) H = idx[0][0] W = idx[1][0] else: H = 0 W = 0 coord_joints[i] = [W, H, max] return coord_joints def show_heatmaps(img, heatmaps, c=np.zeros((2)), num_fig=1): ''' :param img: np array (H,W,3) :param heatmaps: np array (H,W,num_pts) :param c: center, np array (2,) ''' H, W = img.shape[0], img.shape[1] if heatmaps.shape[0] != H: heatmaps = skimage.transform.resize(heatmaps, (H, W)) plt.figure(num_fig) for i in range(heatmaps.shape[2] + 1): plt.subplot(4, 5, i + 1) if i == 0: plt.title('Origin') else: plt.title(joints[i-1]) if i == 0: plt.imshow(img) else: plt.imshow(heatmaps[:, :, i - 1]) plt.axis('off') plt.subplot(4, 5, 20) plt.axis('off') plt.show() def heatmap2rgb(heatmap): """ : heatmap: (h,w) """ heatmap = heatmap.detach().cpu().numpy() # plt.figure(figsize=(1,1)) # plt.axis('off') # plt.imshow(heatmap) # plt.savefig('tmp/tmp.jpg', bbox_inches='tight', pad_inches=0, dpi=70) # plt.close() # plt.clf() # img = Image.open('tmp/tmp.jpg') cm = plt.get_cmap('jet') normed_data = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap + 1e-8)) mapped_data = cm(normed_data) # (h,w,c) # img = np.array(img) img = np.array(mapped_data) img = img[:,:,:3] img = torch.tensor(img).permute(2, 0, 1) return img def heatmaps2rgb(heatmaps): """ : heatmaps: (b,h,w) """ out_imgs = [] for heatmap in heatmaps: out_imgs.append(heatmap2rgb(heatmap)) return torch.stack(out_imgs) # def draw_joints(img, pts): # scores = pts[:,2] # pts = np.array(pts).astype(int) # for i in range(pts.shape[0]): # if pts[i, 0] != 0 and pts[i, 1] != 0: # img = cv2.circle(img, (pts[i, 0], pts[i, 1]), radius=3, # color=(255, 0, 0), thickness=-1) # print('img',img.max(),img.min()) # # img = cv2.putText(img, f'{joints[i]}: {scores[i]:.2f}', ( # # pts[i, 0]+5, pts[i, 1]-5), cv2.FONT_HERSHEY_SIMPLEX, .25, (255, 0, 0)) # # Left arm # for i in range(10, 13-1): # if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: # img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], # pts[i+1, 1]), color=(255, 0, 0), thickness=1) # # Right arm # for i in range(13, 16-1): # if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: # img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], # pts[i+1, 1]), color=(255, 0, 0), thickness=1) # # Left leg # for i in range(0, 3-1): # if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: # img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], # pts[i+1, 1]), color=(255, 0, 0), thickness=1) # # Right leg # for i in range(3, 6-1): # if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: # img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], # pts[i+1, 1]), color=(255, 0, 0), thickness=1) # # Body # for i in range(6, 10-1): # if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: # img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], # pts[i+1, 1]), color=(255, 0, 0), thickness=1) # if pts[2, 0] != 0 and pts[2, 1] != 0 and pts[3, 0] != 0 and pts[3, 1] != 0: # img = cv2.line(img, (pts[2, 0], pts[2, 1]), (pts[2+1, 0], # pts[2+1, 1]), color=(255, 0, 0), thickness=1) # if pts[12, 0] != 0 and pts[12, 1] != 0 and pts[13, 0] != 0 and pts[13, 1] != 0: # img = cv2.line(img, (pts[12, 0], pts[12, 1]), (pts[12+1, 0], # pts[12+1, 1]), color=(255, 0, 0), thickness=1) # return img def draw_joints(img, pts): # Convert the image to the range [0, 255] for visualization img_visualization = (img).astype(np.uint8) # Draw lines for the body parts for i in range(10, 13 - 1): draw_line(img_visualization, pts[i], pts[i + 1]) for i in range(13, 16 - 1): draw_line(img_visualization, pts[i], pts[i + 1]) for i in range(0, 3 - 1): draw_line(img_visualization, pts[i], pts[i + 1]) for i in range(3, 6 - 1): draw_line(img_visualization, pts[i], pts[i + 1]) for i in range(6, 10 - 1): draw_line(img_visualization, pts[i], pts[i + 1]) draw_line(img_visualization, pts[2], pts[3]) draw_line(img_visualization, pts[12], pts[13]) return img_visualization / 255.0 def draw_line(img, pt1, pt2): if pt1[0] != 0 and pt1[1] != 0 and pt2[0] != 0 and pt2[1] != 0: cv2.line(img, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1])), color=(57, 255, 20), thickness=2)