Skip to content

Commit 716755d

Browse files
authored
tinypose3d && modelzoo (#7844)
* metro con reverse tinypose3d fix readme modelzoo * fix tinypose3d
1 parent 5984726 commit 716755d

File tree

6 files changed

+45
-49
lines changed

6 files changed

+45
-49
lines changed

configs/pose3d/README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424

2525
PaddleDetection 中提供了两种3D Pose算法(稀疏关键点),分别是适用于服务器端的大模型Metro3D和移动端的TinyPose3D。其中Metro3D基于[End-to-End Human Pose and Mesh Reconstruction with Transformers](https://arxiv.org/abs/2012.09760)进行了稀疏化改造,TinyPose3D是在TinyPose基础上修改输出3D关键点。
2626

27-
## 模型推荐(待补充)
27+
## 模型推荐
2828

29-
|模型|适用场景|human3.6m精度|模型下载|
30-
|:--:|:--:|:--:|:--:|
31-
|Metro3D|服务器端|-|-|
32-
|TinyPose3D|移动端|-|-|
29+
|模型|适用场景|human3.6m精度(14关键点)|human3.6m精度(17关键点)|模型下载|
30+
|:--:|:--:|:--:|:--:|:--:|
31+
|Metro3D|服务器端|56.014|46.619|[metro3d_24kpts.pdparams](https://bj.bcebos.com/v1/paddledet/models/pose3d/metro3d_24kpts.pdparams)|
32+
|TinyPose3D|移动端|86.381|71.223|[tinypose3d_human36M.pdparams](https://bj.bcebos.com/v1/paddledet/models/pose3d/tinypose3d_human36M.pdparams)|
3333

3434
注:
3535
1. 训练数据基于 [MeshTransformer](https://github.com/microsoft/MeshTransformer) 中的训练数据。
@@ -137,13 +137,14 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer.py -c configs/pose3d/metro3d_24kpts.y
137137

138138
我们的训练数据中包含大量低精度的自动生成数据,用户可以在基于此数据训练的模型基础上,标注自己的高精度目标动作数据进行finetune,即可得到相对稳定、效果较好的模型。
139139

140-
我们在医疗康复高精度数据上的训练效果展示如下
140+
我们在医疗康复高精度数据上的训练效果展示如下([高清视频](https://user-images.githubusercontent.com/31800336/218949226-22e6ab25-facb-4cc6-8eca-38d4bfd973e5.mp4)):
141141

142142
<div align="center">
143-
<img src="https://user-images.githubusercontent.com/31800336/218949226-22e6ab25-facb-4cc6-8eca-38d4bfd973e5.mp4" width='600'/>
143+
<img src="https://user-images.githubusercontent.com/31800336/221747019-ceacfd64-e218-476b-a369-c6dc259816b2.gif" width='600'/>
144144
</div>
145145

146146

147+
147148
## 引用
148149

149150
```

configs/pose3d/tinypose3d_human36M.yml

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@ train_width: &train_width 128
1313
trainsize: &trainsize [*train_width, *train_height]
1414

1515
#####model
16-
architecture: TinyPose3DHRNet
16+
architecture: TinyPose3DHRHeatmapNet
1717
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/keypoint/tinypose_128x96.pdparams
1818

19-
TinyPose3DHRNet:
19+
TinyPose3DHRHeatmapNet:
2020
backbone: LiteHRNet
2121
post_process: HR3DNetPostProcess
22-
fc_channel: 1024
2322
num_joints: *num_joints
2423
width: &width 40
2524
loss: Pose3DLoss
@@ -56,17 +55,17 @@ OptimizerBuilder:
5655
#####data
5756
TrainDataset:
5857
!Pose3DDataset
59-
dataset_dir: Human3.6M
60-
image_dirs: ["Images"]
61-
anno_list: ['Human3.6m_train.json']
58+
dataset_dir: dataset/traindata/
59+
image_dirs: ["human3.6m"]
60+
anno_list: ['pose3d/Human3.6m_train.json']
6261
num_joints: *num_joints
6362
test_mode: False
6463

6564
EvalDataset:
6665
!Pose3DDataset
67-
dataset_dir: Human3.6M
68-
image_dirs: ["Images"]
69-
anno_list: ['Human3.6m_valid.json']
66+
dataset_dir: dataset/traindata/
67+
image_dirs: ["human3.6m"]
68+
anno_list: ['pose3d/Human3.6m_valid.json']
7069
num_joints: *num_joints
7170
test_mode: True
7271

ppdet/data/source/pose3d_cmb.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""
15-
this code is base on https://github.com/open-mmlab/mmpose
16-
"""
14+
1715
import os
1816
import cv2
1917
import numpy as np
@@ -80,7 +78,7 @@ def get_mask(self, mvm_percent=0.3):
8078
mjm_mask[indices, :] = 0.0
8179
# return mjm_mask
8280

83-
num_joints = 1
81+
num_joints = 10
8482
mvm_mask = np.ones((num_joints, 1)).astype(np.float)
8583
if self.test_mode == False:
8684
num_vertices = num_joints

ppdet/metrics/pose3d_metrics.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,6 @@ def all_gather(data):
137137

138138

139139
class Pose3DEval(object):
140-
"""refer to
141-
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
142-
Copyright (c) Microsoft, under the MIT License.
143-
"""
144-
145140
def __init__(self, output_eval, save_prediction_only=False):
146141
super(Pose3DEval, self).__init__()
147142
self.output_eval = output_eval

ppdet/modeling/architectures/keypoint_hrnet.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self,
4646
use_dark=True):
4747
"""
4848
HRNet network, see https://arxiv.org/abs/1902.09212
49-
49+
5050
Args:
5151
backbone (nn.Layer): backbone instance
5252
post_process (object): `HRNetPostProcess` instance
@@ -132,10 +132,10 @@ def __init__(self, use_dark=True):
132132

133133
def get_max_preds(self, heatmaps):
134134
'''get predictions from score maps
135-
135+
136136
Args:
137137
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
138-
138+
139139
Returns:
140140
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
141141
maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints
@@ -220,12 +220,12 @@ def dark_postprocess(self, hm, coords, kernelsize):
220220
def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
221221
"""the highest heatvalue location with a quarter offset in the
222222
direction from the highest response to the second highest response.
223-
223+
224224
Args:
225225
heatmaps (numpy.ndarray): The predicted heatmaps
226226
center (numpy.ndarray): The boxes center
227227
scale (numpy.ndarray): The scale factor
228-
228+
229229
Returns:
230230
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
231231
maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
@@ -341,10 +341,7 @@ def __init__(
341341
self.deploy = False
342342
self.num_joints = num_joints
343343

344-
self.final_conv = L.Conv2d(width, num_joints, 1, 1, 0, bias=True)
345-
# for heatmap output
346-
self.final_conv_new = L.Conv2d(
347-
width, num_joints * 32, 1, 1, 0, bias=True)
344+
self.final_conv = L.Conv2d(width, num_joints * 32, 1, 1, 0, bias=True)
348345

349346
@classmethod
350347
def from_config(cls, cfg, *args, **kwargs):
@@ -356,20 +353,19 @@ def from_config(cls, cfg, *args, **kwargs):
356353
def _forward(self):
357354
feats = self.backbone(self.inputs) # feats:[[batch_size, 40, 32, 24]]
358355

359-
hrnet_outputs = self.final_conv_new(feats[0])
356+
hrnet_outputs = self.final_conv(feats[0])
360357
res = soft_argmax(hrnet_outputs, self.num_joints)
361-
362-
if self.training:
363-
return self.loss(res, self.inputs)
364-
else: # export model need
365-
return res
358+
return res
366359

367360
def get_loss(self):
368-
return self._forward()
361+
pose3d = self._forward()
362+
loss = self.loss(pose3d, None, self.inputs)
363+
outputs = {'loss': loss}
364+
return outputs
369365

370366
def get_pred(self):
371367
res_lst = self._forward()
372-
outputs = {'keypoint': res_lst}
368+
outputs = {'pose3d': res_lst}
373369
return outputs
374370

375371
def flip_back(self, output_flipped, matched_parts):
@@ -427,16 +423,23 @@ def from_config(cls, cfg, *args, **kwargs):
427423
return {'backbone': backbone, }
428424

429425
def _forward(self):
430-
feats = self.backbone(self.inputs) # feats:[[batch_size, 40, 32, 24]]
426+
'''
427+
self.inputs is a dict
428+
'''
429+
feats = self.backbone(
430+
self.inputs) # feats:[[batch_size, 40, width/4, height/4]]
431+
432+
hrnet_outputs = self.final_conv(
433+
feats[0]) # hrnet_outputs: [batch_size, num_joints*32,32,32]
431434

432-
hrnet_outputs = self.final_conv(feats[0])
433435
flatten_res = self.flatten(
434-
hrnet_outputs) # [batch_size, 24, (height/4)*(width/4)]
436+
hrnet_outputs) # [batch_size,num_joints*32,32*32]
437+
435438
res = self.fc1(flatten_res)
436439
res = self.act1(res)
437440
res = self.fc2(res)
438441
res = self.act2(res)
439-
res = self.fc3(res) # [batch_size, 24, 3]
442+
res = self.fc3(res)
440443

441444
if self.training:
442445
return self.loss(res, self.inputs)
@@ -448,7 +451,7 @@ def get_loss(self):
448451

449452
def get_pred(self):
450453
res_lst = self._forward()
451-
outputs = {'keypoint': res_lst}
454+
outputs = {'pose3d': res_lst}
452455
return outputs
453456

454457
def flip_back(self, output_flipped, matched_parts):

ppdet/modeling/architectures/pose3d_metro.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
trans_encoder='',
5454
loss='Pose3DLoss', ):
5555
"""
56-
METRO network, see https://arxiv.org/abs/
56+
Modified from METRO network, see https://arxiv.org/abs/2012.09760
5757
5858
Args:
5959
backbone (nn.Layer): backbone instance
@@ -65,7 +65,7 @@ def __init__(
6565
self.deploy = False
6666

6767
self.trans_encoder = trans_encoder
68-
self.conv_learn_tokens = paddle.nn.Conv1D(49, num_joints + 1, 1)
68+
self.conv_learn_tokens = paddle.nn.Conv1D(49, num_joints + 10, 1)
6969
self.cam_param_fc = paddle.nn.Linear(3, 2)
7070

7171
@classmethod

0 commit comments

Comments
 (0)