yijiu committed on
Commit 93a6fff
1 Parent(s): c5890a7

feat: upload project

app.py CHANGED
@@ -1,7 +1,46 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import time
+ import os

+ import skimage.transform
+ import torch

+ from models.hr_net import hr_w32
+ from tool_utils import heatmaps_to_coords, draw_joints

+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

+ # Create the example list from the 'examples/' directory
+ example_list = [["./examples/" + example] for example in os.listdir("examples")]


+ def predict(numpy_img):
+     # Resize the input image to (256, 256); skimage.transform.resize also rescales it to [0, 1]
+     img_np = skimage.transform.resize(numpy_img, [256, 256])
+     # Convert the numpy image (H, W, C) to a (1, C, H, W) float tensor
+     img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)
+     # Build the HRNet-W32 model
+     model = hr_w32().to(device)
+     # Load the trained weights (map_location so CPU-only hosts can load GPU-trained weights)
+     model.load_state_dict(torch.load('./weights/HRNet_epoch20_loss0.000474.pth', map_location=device)['model'])
+     # Switch the model to inference mode
+     model.eval()
+     # Predict the joint heatmaps
+     start_time = time.time()
+     with torch.no_grad():
+         heatmaps_pred = model(img)
+     # Convert the output to a numpy array, (1, 16, 64, 64) -> (64, 64, 16)
+     heatmaps_pred_np = heatmaps_pred.squeeze(0).permute(1, 2, 0).cpu().numpy()
+     # Convert the heatmaps to joint locations at the output resolution
+     coord_joints = heatmaps_to_coords(heatmaps_pred_np, resolu_out=[256, 256], prob_threshold=0.1)
+     inference_time = time.time() - start_time
+     inference_time_text = "model inference time: {:.4f}s".format(inference_time)
+     # Draw the predicted joints on the resized image
+     img_rgb = draw_joints(img_np, coord_joints)
+     return img_rgb, inference_time_text


+ demo = gr.Interface(fn=predict,
+                     inputs=gr.Image(),
+                     outputs=[gr.Image(type='numpy', width=256, height=256), "text"],
+                     examples=example_list)

+ if __name__ == "__main__":
+     demo.launch(show_api=False)
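
For reference, a quick shape trace of predict (a sketch; the 64x64 heatmap size follows from the two stride-2 convolutions at the top of models/hr_net.py):

# input image            -> (256, 256, 3)   after skimage resize, values in [0, 1]
# tensor fed to the net  -> (1, 3, 256, 256)
# heatmaps_pred          -> (1, 16, 64, 64)  (the network downsamples by 4)
# heatmaps_pred_np       -> (64, 64, 16)
# coord_joints           -> (16, 3)          [W, H, score] per joint at 256x256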
examples/000000000016.jpg ADDED
examples/000000000552.jpg ADDED
models/hr_net.py ADDED
@@ -0,0 +1,163 @@
+ import torch
+ from torch import nn
+ 
+ from models.modules.blocks.bottleneck import Bottleneck
+ from models.modules.stage_module import StageModule
+ 
+ 
+ def weights_init(m):
+     if isinstance(m, nn.Conv2d):
+         nn.init.normal_(m.weight, std=.01)
+         if m.bias is not None:
+             nn.init.constant_(m.bias, 0)
+     elif isinstance(m, nn.BatchNorm2d):
+         nn.init.constant_(m.weight, 1)
+         if m.bias is not None:
+             nn.init.constant_(m.bias, 0)
+ 
+ 
+ class HRNet(nn.Module):
+ 
+     def __init__(self, c=48, nof_joints=16, bn_momentum=.1):
+         super(HRNet, self).__init__()
+ 
+         # Stem: (b,3,y,x) -> (b,64,y/4,x/4)
+         # (y,x in the comments below refer to the stem output size)
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
+                                stride=2, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64, momentum=bn_momentum)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3,
+                                stride=2, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(64, momentum=bn_momentum)
+         self.relu = nn.ReLU(inplace=True)
+ 
+         # (b,64,y,x) -> (b,256,y,x)
+         downsample = nn.Sequential(
+             nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
+             nn.BatchNorm2d(256),
+         )
+         self.layer1 = nn.Sequential(
+             Bottleneck(64, 64, downsample=downsample),
+             Bottleneck(256, 64),
+             Bottleneck(256, 64),
+             Bottleneck(256, 64),
+         )
+ 
+         # (b,256,y,x) ---+---> (b,c,y,x)
+         #                +---> (b,c*2,y/2,x/2)
+         self.transition1 = nn.ModuleList([
+             nn.Sequential(
+                 nn.Conv2d(256, c, kernel_size=3,
+                           stride=1, padding=1, bias=False),
+                 nn.BatchNorm2d(c),
+                 nn.ReLU(inplace=True),
+             ),
+             nn.Sequential(nn.Sequential(  # double Sequential to fit the official pretrained weights
+                 nn.Conv2d(256, c * 2, kernel_size=3,
+                           stride=2, padding=1, bias=False),
+                 nn.BatchNorm2d(c * 2),
+                 nn.ReLU(inplace=True),
+             ))
+         ])
+ 
+         # The branches are fused inside each StageModule
+         # (b,c,y,x) ------+---> (b,c,y,x)
+         # (b,c*2,y/2,x/2) +---> (b,c*2,y/2,x/2)
+         self.stage2 = nn.Sequential(
+             StageModule(stage=2, output_branches=2, c=c, bn_momentum=bn_momentum)
+         )
+ 
+         # (b,c,y,x) ----------> (b,c,y,x)
+         # (b,c*2,y/2,x/2) +---> (b,c*2,y/2,x/2)
+         #                 +---> (b,c*4,y/4,x/4)
+         self.transition2 = nn.ModuleList([
+             nn.Sequential(),  # used in place of "None" because it is callable
+             nn.Sequential(),
+             nn.Sequential(nn.Sequential(
+                 nn.Conv2d(c * 2, c * 4, kernel_size=3,
+                           stride=2, padding=1, bias=False),
+                 nn.BatchNorm2d(c * 4),
+                 nn.ReLU(inplace=True),
+             ))
+         ])
+ 
+         # (b,c,y,x) ------++++---> (b,c,y,x)
+         # (b,c*2,y/2,x/2) ++++---> (b,c*2,y/2,x/2)
+         # (b,c*4,y/4,x/4) ++++---> (b,c*4,y/4,x/4)
+         self.stage3 = nn.Sequential(
+             StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
+         )
+ 
+         # (b,c,y,x) ----------> (b,c,y,x)
+         # (b,c*2,y/2,x/2) ----> (b,c*2,y/2,x/2)
+         # (b,c*4,y/4,x/4) +---> (b,c*4,y/4,x/4)
+         #                 +---> (b,c*8,y/8,x/8)
+         self.transition3 = nn.ModuleList([
+             nn.Sequential(),
+             nn.Sequential(),
+             nn.Sequential(),
+             nn.Sequential(nn.Sequential(  # double Sequential to fit the official pretrained weights
+                 nn.Conv2d(c * 4, c * 8, kernel_size=3,
+                           stride=2, padding=1, bias=False),
+                 nn.BatchNorm2d(c * 8),
+                 nn.ReLU(inplace=True),
+             )),
+         ])
+ 
+         # (b,c,y,x) ------+++---> (b,c,y,x)
+         # (b,c*2,y/2,x/2) +++
+         # (b,c*4,y/4,x/4) +++
+         # (b,c*8,y/8,x/8) +++
+         # The final StageModule has output_branches=1, so stage4 only returns
+         # the highest-resolution branch
+         self.stage4 = nn.Sequential(
+             StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
+             StageModule(stage=4, output_branches=1, c=c, bn_momentum=bn_momentum),
+         )
+ 
+         # Map the highest-resolution branch to one heatmap per joint
+         # (b,c,y,x) -> (b,nof_joints,y,x)
+         self.final_layer = nn.Conv2d(c, nof_joints, kernel_size=1, stride=1)
+ 
+         self.apply(weights_init)
+ 
+     def forward(self, x):
+         x = self.relu(self.bn1(self.conv1(x)))
+         x = self.relu(self.bn2(self.conv2(x)))
+ 
+         x = self.layer1(x)
+         x = [trans(x) for trans in self.transition1]
+ 
+         x = self.stage2(x)
+         x = [
+             self.transition2[0](x[0]),
+             self.transition2[1](x[1]),
+             self.transition2[2](x[1]),
+         ]
+ 
+         x = self.stage3(x)
+         x = [
+             self.transition3[0](x[0]),
+             self.transition3[1](x[1]),
+             self.transition3[2](x[2]),
+             self.transition3[3](x[2]),
+         ]
+ 
+         x = self.stage4(x)
+ 
+         x = x[0]
+         out = self.final_layer(x)
+ 
+         return out
+ 
+ 
+ def hr_w32():
+     return HRNet(32)
+ 
+ 
+ if __name__ == '__main__':
+     model = hr_w32()
+     x = torch.randn(1, 3, 256, 256)
+     output = model(x)
+     print(output.size())  # torch.Size([1, 16, 64, 64])
models/modules/__init__.py ADDED
File without changes
models/modules/blocks/__init__.py ADDED
File without changes
models/modules/blocks/basic_block.py ADDED
@@ -0,0 +1,37 @@
+ from torch import nn
+ 
+ 
+ class BasicBlock(nn.Module):
+     """
+     (b,c,y,x) -> (b,c,y,x)
+     """
+     expansion = 1
+ 
+     def __init__(self, planes, bn_momentum=.1):
+         super(BasicBlock, self).__init__()
+ 
+         self.conv1 = nn.Conv2d(planes, planes, kernel_size=3,
+                                stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+         self.relu = nn.ReLU(inplace=True)
+ 
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                                stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+ 
+     def forward(self, x):
+         residual = x
+ 
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+ 
+         out += residual
+         return self.relu(out)
+ 
+ 
+ if __name__ == '__main__':
+     import torch
+ 
+     model = BasicBlock(256)
+     x = torch.randn(1, 256, 128, 128)
+     print(model(x).size())  # torch.Size([1, 256, 128, 128])
models/modules/blocks/bottleneck.py ADDED
@@ -0,0 +1,57 @@
+ from torch import nn
+ 
+ 
+ class Bottleneck(nn.Module):
+     """
+     (b,c_in,y,x) -> (b,4*c_out,y,x)
+     """
+ 
+     expansion = 4
+ 
+     def __init__(self, inplanes, planes, downsample=None, bn_momentum=.1):
+         super(Bottleneck, self).__init__()
+ 
+         self.conv1 = nn.Conv2d(inplanes, planes,
+                                kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+ 
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                                stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)
+ 
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion,
+                                kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+                                   momentum=bn_momentum)
+ 
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+ 
+     def forward(self, x):
+         residual = x
+ 
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.relu(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+ 
+         if self.downsample is not None:
+             residual = self.downsample(x)
+ 
+         out += residual
+         return self.relu(out)
+ 
+ 
+ if __name__ == '__main__':
+     import torch
+ 
+     downsample = nn.Sequential(
+         nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
+         nn.BatchNorm2d(256),
+     )
+     model = Bottleneck(64, 64, downsample=downsample)
+     x = torch.randn(1, 64, 128, 128)
+     print(model(x).size())  # torch.Size([1, 256, 128, 128])
+ 
+     model = Bottleneck(256, 64)
+     x = torch.randn(1, 256, 128, 128)
+     print(model(x).size())  # torch.Size([1, 256, 128, 128])
models/modules/stage_module.py ADDED
@@ -0,0 +1,104 @@
+ from torch import nn
+ 
+ from models.modules.blocks.basic_block import BasicBlock
+ 
+ 
+ class StageModule(nn.Module):
+ 
+     def __init__(self, stage, output_branches, c, bn_momentum):
+         super(StageModule, self).__init__()
+ 
+         self.stage = stage
+         self.output_branches = output_branches
+ 
+         # Build one branch per stage
+         # e.g. with stage=3, c=32:
+         # i = 0, 1, 2
+         # i = 0 -> 4*BasicBlock(32)
+         # i = 1 -> 4*BasicBlock(64)
+         # i = 2 -> 4*BasicBlock(128)
+         #
+         # -+--- 4*BasicBlock(32) ---->
+         #  +--- 4*BasicBlock(64) ---->
+         #  +--- 4*BasicBlock(128) --->
+         self.branches = nn.ModuleList()
+         for i in range(self.stage):
+             w = c * (2**i)
+             branch = nn.Sequential(
+                 BasicBlock(w, bn_momentum=bn_momentum),
+                 BasicBlock(w, bn_momentum=bn_momentum),
+                 BasicBlock(w, bn_momentum=bn_momentum),
+                 BasicBlock(w, bn_momentum=bn_momentum),
+             )
+             self.branches.append(branch)
+ 
+         self.fuse_layers = nn.ModuleList()
+ 
+         # Build i*j fuse layers, where layer (i, j) transforms branch j into output branch i
+         # If i < j, the output branch has fewer channels and a higher resolution, so upsample
+         # If i > j, the output branch has more channels and a lower resolution, so downsample
+         #                     +---output branch 0 (c=32)---->
+         #                     +(upsample)
+         # ---branch 1 (c=64)--+---output branch 1 (c=64)---->
+         #                     +(downsample)
+         #                     +---output branch 2 (c=128)--->
+         # For each output branch i
+         for i in range(self.output_branches):
+             self.fuse_layers.append(nn.ModuleList())
+ 
+             # For each input branch j
+             for j in range(self.stage):
+ 
+                 # If the input branch matches the output branch, pass it through unchanged
+                 if i == j:
+                     self.fuse_layers[-1].append(nn.Sequential())
+ 
+                 # If the output branch index is smaller than the input branch index, upsample
+                 elif i < j:
+                     self.fuse_layers[-1].append(nn.Sequential(
+                         nn.Conv2d(c * (2**j), c * (2**i), kernel_size=1,
+                                   stride=1, bias=False),
+                         nn.BatchNorm2d(c * (2**i)),
+                         nn.Upsample(scale_factor=(2.**(j-i))),
+                     ))
+ 
+                 # If the output branch index is larger than the input branch index, downsample
+                 elif i > j:
+                     ops = []
+                     for _ in range(i - j - 1):
+                         ops.append(nn.Sequential(
+                             nn.Conv2d(c * (2**j), c * (2**j), kernel_size=3,
+                                       stride=2, padding=1, bias=False),
+                             nn.BatchNorm2d(c * (2**j)),
+                             nn.ReLU(inplace=True),
+                         ))
+                     ops.append(nn.Sequential(
+                         nn.Conv2d(c * (2**j), c * (2**i), kernel_size=3,
+                                   stride=2, padding=1, bias=False),
+                         nn.BatchNorm2d(c * (2**i)),
+                     ))
+                     self.fuse_layers[-1].append(nn.Sequential(*ops))
+ 
+         self.relu = nn.ReLU(inplace=True)
+ 
+     def forward(self, x):
+         # Pass each input through its branch
+         x = [branch(b) for branch, b in zip(self.branches, x)]
+ 
+         x_fused = []
+         # For each output branch
+         for i in range(len(self.fuse_layers)):
+             # For each input branch
+             for j in range(len(self.branches)):
+                 # The first input branch initializes output branch i
+                 if j == 0:
+                     x_fused.append(self.fuse_layers[i][0](x[0]))
+                 # Every other input branch is transformed and added to output branch i
+                 else:
+                     x_fused[i] = x_fused[i] + self.fuse_layers[i][j](x[j])
+ 
+         # Apply ReLU to each fused output
+         for i in range(len(x_fused)):
+             x_fused[i] = self.relu(x_fused[i])
+ 
+         return x_fused
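
As a sanity check, a minimal sketch of how a StageModule fuses two branches (hypothetical inputs, run from the repo root; each output branch keeps its own shape):

import torch
from models.modules.stage_module import StageModule

# Two branches at c=32: full resolution, and half resolution with doubled channels
m = StageModule(stage=2, output_branches=2, c=32, bn_momentum=0.1)
x = [torch.randn(1, 32, 64, 64), torch.randn(1, 64, 32, 32)]
out = m(x)
print([o.size() for o in out])  # [torch.Size([1, 32, 64, 64]), torch.Size([1, 64, 32, 32])]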
models/modules/stem.py ADDED
@@ -0,0 +1,29 @@
+ from torch import nn
+ 
+ 
+ class Stem(nn.Module):
+     """
+     The stem downsamples the input by 4x and expands the channels to 64
+     (b,3,y,x) -> (b,64,y/4,x/4)
+     """
+     def __init__(self, bn_momentum=.1):
+         super(Stem, self).__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
+                                stride=2, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64, momentum=bn_momentum)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3,
+                                stride=2, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(64, momentum=bn_momentum)
+         self.relu = nn.ReLU(inplace=True)
+ 
+     def forward(self, x):
+         out = self.bn1(self.conv1(x))
+         out = self.bn2(self.conv2(out))
+         return self.relu(out)
+ 
+ 
+ if __name__ == '__main__':
+     import torch
+ 
+     model = Stem()
+     x = torch.randn(1, 3, 128, 64)
+     print(model(x).size())  # torch.Size([1, 64, 32, 16])
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ Cython==3.0.5
+ gradio==4.8.0
+ gradio_client==0.7.1
+ huggingface-hub==0.19.4
+ imageio==2.33.0
+ numpy==1.24.3
+ opencv-python==4.8.1.78
+ opendatalab==0.0.10
+ Pillow==10.0.1
+ pip==23.3
+ pycocotools==2.0.7
+ scikit-image==0.21.0
+ scipy==1.10.1
+ torch==1.12.0+cu113
+ torchaudio==0.12.0+cu113
+ torchvision==0.13.0+cu113
+ tqdm==4.65.2
tool_utils.py ADDED
@@ -0,0 +1,372 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib.image as mpimg
+ import cv2
+ import skimage.filters
+ import skimage.transform
+ import torch
+ 
+ joints = [
+     'left ankle',
+     'left knee',
+     'left hip',
+     'right hip',
+     'right knee',
+     'right ankle',
+     'belly',
+     'chest',
+     'neck',
+     'head',
+     'left wrist',
+     'left elbow',
+     'left shoulder',
+     'right shoulder',
+     'right elbow',
+     'right wrist'
+ ]
+ 
+ 
+ def generate_heatmap(heatmap, pt, sigma=(33, 33), sigma_valu=7):
+     '''
+     :param heatmap: a np zeros array with shape (H,W) (a single channel), not (H,W,1)
+     :param pt: point coords, np array
+     :param sigma: should be a tuple with odd values (obsolete)
+     :param sigma_valu: value for the gaussian blur
+     :return: a np array of one joint heatmap with shape (H,W)
+ 
+     This function is obsolete; use 'generate_heatmaps()' instead.
+     '''
+     heatmap[int(pt[1])][int(pt[0])] = 1
+     heatmap = skimage.filters.gaussian(
+         heatmap, sigma=sigma_valu)  # (H,W)
+     am = np.amax(heatmap)
+     heatmap = heatmap / am
+     return heatmap
+ 
+ 
+ def generate_heatmaps(img, pts, sigma=(33, 33), sigma_valu=7):
+     '''
+     :param img: np array img, (H,W,C)
+     :param pts: joint point coords, np array, same resolution as img
+     :param sigma: should be a tuple with odd values (obsolete)
+     :param sigma_valu: value for the gaussian blur
+     :return: np array heatmaps, (H,W,num_pts)
+     '''
+     H, W = img.shape[0], img.shape[1]
+     num_pts = pts.shape[0]
+     heatmaps = np.zeros((H, W, num_pts))
+     for i, pt in enumerate(pts):
+         # Filter unavailable heatmaps
+         if pt[0] == 0 and pt[1] == 0:
+             continue
+         # Clamp points that fall outside the image
+         if pt[0] >= W:
+             pt[0] = W-1
+         if pt[1] >= H:
+             pt[1] = H-1
+         heatmap = heatmaps[:, :, i]
+         heatmap[int(pt[1])][int(pt[0])] = 1
+         heatmap = skimage.filters.gaussian(
+             heatmap, sigma=sigma_valu)  # (H,W)
+         am = np.amax(heatmap)
+         heatmap = heatmap / am
+         heatmaps[:, :, i] = heatmap
+     return heatmaps
+ 
+ 
+ def load_image(path_image):
+     img = mpimg.imread(path_image)
+     # Returns a np array (H,W,C)
+     return img
+ 
+ 
+ def crop(img, ele_anno, use_randscale=True, use_randflipLR=False, use_randcolor=False):
+     '''
+     :param img: np array of the origin image, (H,W,C)
+     :param ele_anno: one element of the json annotation
+     :return: img_crop, ary_pts_crop, c_crop after cropping
+     '''
+ 
+     H, W = img.shape[0], img.shape[1]
+     s = ele_anno['scale_provided']
+     c = ele_anno['objpos']
+ 
+     # Adjust center and scale
+     if c[0] != -1:
+         c[1] = c[1] + 15 * s
+         s = s * 1.25
+     ary_pts = np.array(ele_anno['joint_self'])  # (16, 3)
+     ary_pts_temp = ary_pts[np.any(ary_pts != [0, 0, 0], axis=1)]
+ 
+     if use_randscale:
+         scale_rand = np.random.uniform(low=1.0, high=3.0)
+     else:
+         scale_rand = 1
+ 
+     W_min = max(np.amin(ary_pts_temp, axis=0)[0] - s * 15 * scale_rand, 0)
+     H_min = max(np.amin(ary_pts_temp, axis=0)[1] - s * 15 * scale_rand, 0)
+     W_max = min(np.amax(ary_pts_temp, axis=0)[0] + s * 15 * scale_rand, W)
+     H_max = min(np.amax(ary_pts_temp, axis=0)[1] + s * 15 * scale_rand, H)
+     W_len = W_max - W_min
+     H_len = H_max - H_min
+     window_len = max(H_len, W_len)
+     pad_updown = (window_len - H_len)/2
+     pad_leftright = (window_len - W_len)/2
+ 
+     # Calculate the 4 corner positions
+     W_low = max((W_min - pad_leftright), 0)
+     W_high = min((W_max + pad_leftright), W)
+     H_low = max((H_min - pad_updown), 0)
+     H_high = min((H_max + pad_updown), H)
+ 
+     # Update joint points and center
+     ary_pts_crop = np.where(
+         ary_pts == [0, 0, 0], ary_pts, ary_pts - np.array([W_low, H_low, 0]))
+     c_crop = c - np.array([W_low, H_low])
+ 
+     img_crop = img[int(H_low):int(H_high), int(W_low):int(W_high), :]
+ 
+     # Pad when H and W differ
+     H_new, W_new = img_crop.shape[0], img_crop.shape[1]
+     window_len_new = max(H_new, W_new)
+     pad_updown_new = int((window_len_new - H_new)/2)
+     pad_leftright_new = int((window_len_new - W_new)/2)
+ 
+     # Re-update joint points and center (because of the padding)
+     ary_pts_crop = np.where(ary_pts_crop == [
+         0, 0, 0], ary_pts_crop, ary_pts_crop + np.array([pad_leftright_new, pad_updown_new, 0]))
+     c_crop = c_crop + np.array([pad_leftright_new, pad_updown_new])
+ 
+     img_crop = cv2.copyMakeBorder(img_crop, pad_updown_new, pad_updown_new,
+                                   pad_leftright_new, pad_leftright_new, cv2.BORDER_CONSTANT, value=0)
+ 
+     # Change dtype and value range
+     img_crop = img_crop / 255.
+     img_crop = img_crop.astype(np.float64)
+ 
+     if use_randflipLR:
+         flip = np.random.random() > 0.5
+         if flip:
+             # (H,W,C)
+             img_crop = np.flip(img_crop, 1)
+             # Calculate the flipped pts; remember to filter [0,0], which marks no available heatmap
+             ary_pts_crop = np.where(ary_pts_crop == [0, 0, 0], ary_pts_crop,
+                                     [window_len_new, 0, 0] + ary_pts_crop * [-1, 1, 0])
+             c_crop = [window_len_new, 0] + c_crop * [-1, 1]
+             # Rearrange pts so left/right joints swap
+             ary_pts_crop = np.concatenate(
+                 (ary_pts_crop[5::-1], ary_pts_crop[6:10], ary_pts_crop[15:9:-1]))
+ 
+     if use_randcolor:
+         randcolor = np.random.random() > 0.5
+         if randcolor:
+             img_crop[...,
+                      0] *= np.clip(np.random.uniform(low=0.8, high=1.2), 0., 1.)
+             img_crop[...,
+                      1] *= np.clip(np.random.uniform(low=0.8, high=1.2), 0., 1.)
+             img_crop[...,
+                      2] *= np.clip(np.random.uniform(low=0.8, high=1.2), 0., 1.)
+ 
+     return img_crop, ary_pts_crop, c_crop
+ 
+ 
+ def change_resolu(img, pts, c, resolu_out=(256, 256)):
+     '''
+     :param img: np array of the origin image
+     :param pts: joint points np array corresponding to the image, same resolution as img
+     :param c: center
+     :param resolu_out: a list or tuple
+     :return: img_out, pts_out, c_out under resolu_out
+     '''
+     H_in = img.shape[0]
+     W_in = img.shape[1]
+     H_out = resolu_out[0]
+     W_out = resolu_out[1]
+     H_scale = H_in/H_out
+     W_scale = W_in/W_out
+ 
+     pts_out = pts/np.array([W_scale, H_scale, 1])
+     c_out = c/np.array([W_scale, H_scale])
+     img_out = skimage.transform.resize(img, tuple(resolu_out))
+ 
+     return img_out, pts_out, c_out
+ 
+ 
+ def heatmaps_to_coords(heatmaps, resolu_out=[64, 64], prob_threshold=0.2):
+     '''
+     :param heatmaps: np array with shape (H,W,16)
+     :param resolu_out: output resolution list
+     :return coord_joints: np array, shape (16,3), each row is [W, H, score]
+     '''
+ 
+     num_joints = heatmaps.shape[2]
+     # Resize to the output resolution
+     heatmaps = skimage.transform.resize(heatmaps, tuple(resolu_out))
+ 
+     coord_joints = np.zeros((num_joints, 3))
+     for i in range(num_joints):
+         heatmap = heatmaps[..., i]
+         max_val = np.max(heatmap)
+         # Only keep points whose peak is above the threshold
+         if max_val >= prob_threshold:
+             idx = np.where(heatmap == max_val)
+             H = idx[0][0]
+             W = idx[1][0]
+         else:
+             H = 0
+             W = 0
+         coord_joints[i] = [W, H, max_val]
+     return coord_joints
+ 
+ 
+ def show_heatmaps(img, heatmaps, c=np.zeros((2)), num_fig=1):
+     '''
+     :param img: np array (H,W,3)
+     :param heatmaps: np array (H,W,num_pts)
+     :param c: center, np array (2,)
+     '''
+     H, W = img.shape[0], img.shape[1]
+ 
+     if heatmaps.shape[0] != H:
+         heatmaps = skimage.transform.resize(heatmaps, (H, W))
+ 
+     plt.figure(num_fig)
+     for i in range(heatmaps.shape[2] + 1):
+         plt.subplot(4, 5, i + 1)
+         if i == 0:
+             plt.title('Origin')
+         else:
+             plt.title(joints[i-1])
+ 
+         if i == 0:
+             plt.imshow(img)
+         else:
+             plt.imshow(heatmaps[:, :, i - 1])
+ 
+         plt.axis('off')
+     plt.subplot(4, 5, 20)
+     plt.axis('off')
+     plt.show()
+ 
+ 
+ def heatmap2rgb(heatmap):
+     """
+     :param heatmap: tensor (h,w)
+     :return: tensor (3,h,w)
+     """
+     heatmap = heatmap.detach().cpu().numpy()
+ 
+     # Normalize to [0, 1] and map through the 'jet' colormap
+     cm = plt.get_cmap('jet')
+     normed_data = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap) + 1e-8)
+     mapped_data = cm(normed_data)
+ 
+     # (h,w,4) -> keep RGB -> (3,h,w)
+     img = np.array(mapped_data)
+     img = img[:, :, :3]
+     img = torch.tensor(img).permute(2, 0, 1)
+ 
+     return img
+ 
+ 
+ def heatmaps2rgb(heatmaps):
+     """
+     :param heatmaps: tensor (b,h,w)
+     :return: tensor (b,3,h,w)
+     """
+     out_imgs = []
+     for heatmap in heatmaps:
+         out_imgs.append(heatmap2rgb(heatmap))
+ 
+     return torch.stack(out_imgs)
+ 
+ 
+ def draw_joints(img, pts):
+     '''
+     :param img: np array (H,W,3), values in [0,1]
+     :param pts: joint coords, np array (16,3)
+     :return: np array (H,W,3) with the skeleton drawn, values in [0,1]
+     '''
+     # Convert the image to the range [0, 255] for visualization
+     img_visualization = (img * 255).astype(np.uint8)
+ 
+     # Left arm: wrist -> elbow -> shoulder
+     for i in range(10, 13 - 1):
+         draw_line(img_visualization, pts[i], pts[i + 1])
+ 
+     # Right arm: shoulder -> elbow -> wrist
+     for i in range(13, 16 - 1):
+         draw_line(img_visualization, pts[i], pts[i + 1])
+ 
+     # Left leg: ankle -> knee -> hip
+     for i in range(0, 3 - 1):
+         draw_line(img_visualization, pts[i], pts[i + 1])
+ 
+     # Right leg: hip -> knee -> ankle
+     for i in range(3, 6 - 1):
+         draw_line(img_visualization, pts[i], pts[i + 1])
+ 
+     # Body: belly -> chest -> neck -> head
+     for i in range(6, 10 - 1):
+         draw_line(img_visualization, pts[i], pts[i + 1])
+ 
+     # Connect the hips and the shoulders
+     draw_line(img_visualization, pts[2], pts[3])
+     draw_line(img_visualization, pts[12], pts[13])
+ 
+     return img_visualization / 255.0
+ 
+ 
+ def draw_line(img, pt1, pt2):
+     # Skip joints at (0, 0), which mark undetected keypoints
+     if pt1[0] != 0 and pt1[1] != 0 and pt2[0] != 0 and pt2[1] != 0:
+         cv2.line(img, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1])),
+                  color=(255, 0, 0), thickness=1)
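
Taken together, a minimal end-to-end sketch of the pipeline that app.py wraps (assumes it is run from the repo root; uses randomly initialized weights, so the drawn joints are meaningless):

import numpy as np
import torch
from models.hr_net import hr_w32
from tool_utils import heatmaps_to_coords, draw_joints

model = hr_w32().eval()
img_np = np.random.rand(256, 256, 3)  # stand-in for a resized input image in [0, 1]
img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float()
with torch.no_grad():
    heatmaps = model(img)  # (1, 16, 64, 64)
heatmaps_np = heatmaps.squeeze(0).permute(1, 2, 0).numpy()
coords = heatmaps_to_coords(heatmaps_np, resolu_out=[256, 256], prob_threshold=0.1)  # (16, 3)
vis = draw_joints(img_np, coords)  # (256, 256, 3), skeleton drawn in red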