
Commit 8ebb0ee

Merge pull request #405 from zoooo0820/pphuman_stgcn
[PPhuman application] Application of action model training in PP-Human
2 parents a6b7a22 + 2040efa commit 8ebb0ee

File tree

9 files changed: +362 -4 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -211,6 +211,7 @@ PaddleVideo is [PaddlePaddle official](https://www.paddlepaddle.org.cn/?fr=paddleEdu_githu
  | [EIVideo](applications/EIVideo) | Interactive video segmentation tool |
  | [Anti-UAV](applications/Anti-UAV) | UAV detection solution |
  | [AbnormalActionDetection](applications/AbnormalActionDetection) | Abnormal action detection solution |
+ | [PP-Human](applications/PPHuman) | Action recognition solution for pedestrian analysis scenes |

  ## Documentation tutorials

README_en.md

Lines changed: 1 addition & 0 deletions
@@ -208,6 +208,7 @@ PaddleVideo is a toolset for video tasks prepared for the industry and academia.
  | [EIVideo](applications/EIVideo) | Interactive video segmentation tool |
  | [Anti-UAV](applications/Anti-UAV) | UAV detection solution |
  | [AbnormalActionDetection](applications/AbnormalActionDetection) | Abnormal action detection solution |
+ | [PP-Human](applications/PPHuman) | Action recognition solution for pedestrian analysis scenes |

  ## Documentation tutorial

applications/PPHuman/README.md

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@

# PP-Human Action Recognition Model

The real-time pedestrian analysis tool [PP-Human](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/pphuman) integrates a skeleton-based action recognition module. This document describes how to complete the training workflow for the action recognition model with [PaddleVideo](https://github.com/PaddlePaddle/PaddleVideo/).

## Action recognition model training

The action recognition model currently in use is [ST-GCN](https://arxiv.org/abs/1801.07455), trained with an adapted version of the [PaddleVideo training pipeline](https://github.com/PaddlePaddle/PaddleVideo/blob/develop/docs/zh-CN/model_zoo/recognition/stgcn.md).

### Prepare the training data

ST-GCN predicts from sequences of skeleton keypoint coordinates. In PaddleVideo, the training data is `Numpy` data stored in `.npy` format, and the labels can be files stored in `.npy` or `.pkl` format. Sequence data must have the shape `(N, C, T, V, M)`.

Taking our model in PP-Human as an example, the dimensions are as follows:

| Dimension | Size | Description |
| ---- | ---- | ---------- |
| N | variable | Number of sequences in the dataset |
| C | 2 | Dimension of the keypoint coordinates, i.e. (x, y) |
| T | 50 | Temporal length of an action sequence (number of frames) |
| V | 17 | Number of keypoints per person; here we use the `COCO` dataset definition, see [here](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareKeypointDataSet_cn.md#COCO%E6%95%B0%E6%8D%AE%E9%9B%86) |
| M | 1 | Number of persons; here each action sequence is predicted for a single person |
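For concreteness, a batch of training data in this layout can be mocked up as follows (a minimal sketch; the dataset size `N=10` is arbitrary):

```python
import numpy as np

# (N, C, T, V, M): 10 sequences, (x, y) channels, 50 frames,
# 17 COCO keypoints, 1 person per sequence.
dummy_data = np.zeros((10, 2, 50, 17, 1), dtype=np.float32)
print(dummy_data.shape)  # (10, 2, 50, 17, 1)
```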
#### 1. Obtain the skeleton keypoint coordinates of the sequence

For a sequence to be annotated (a sequence here means one action clip, which can be a video or an ordered collection of images), the skeleton keypoint (also called keypoint) coordinates can be obtained by model prediction or manual annotation.

- Model prediction: you can directly pick a model from the [PaddleDetection KeyPoint model series](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/keypoint) model zoo and follow the steps in `3. Training and testing - Deployment - Joint deployment of detection + keypoint top-down models` to obtain the 17 keypoint coordinates of the target sequence.
- Manual annotation: if you have different requirements for the number or definition of keypoints, you can also annotate the coordinates of each keypoint manually. Note that occluded or hard-to-annotate points still need an approximate coordinate; otherwise the subsequent training of the network will be affected.

After obtaining the keypoint coordinates, we recommend normalizing them by each person's detection box to remove the convergence difficulty that differences in person position and scale bring to the network; see [here](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/pphuman/pipe_utils.py#L352-L363) for reference, and the sketch below.
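A minimal sketch of this box-relative normalization, assuming keypoints stored as a `(T, 17, 2)` array and per-frame boxes as `(x1, y1, w, h)` (the linked `pipe_utils.py` is the authoritative version):

```python
import numpy as np

def normalize_by_bbox(kpts: np.ndarray, bboxes: np.ndarray) -> np.ndarray:
    """Map absolute keypoints into each frame's detection box.

    kpts:   (T, 17, 2) absolute (x, y) coordinates
    bboxes: (T, 4) per-frame boxes as (x1, y1, w, h)
    """
    tl = bboxes[:, None, 0:2]  # top-left corners, broadcast over keypoints
    wh = bboxes[:, None, 2:4]  # box widths/heights
    return (kpts - tl) / wh    # box-relative coordinates, roughly in [0, 1]
```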
#### 2. Unify the temporal length of the sequences

Since the actions in real data vary in length, you first need to set a temporal length according to your data and scenario (in PP-Human we use 50 frames per action sequence) and process the data as follows (see the sketch after this list):

- If a sequence is longer than the target length, randomly crop a 50-frame clip from it.
- If a sequence is shorter than the target length, pad it with zeros until it reaches 50 frames.
- If a sequence is exactly the target length, no processing is needed.

Note: after this step, strictly verify that the processed data still contains a complete action with no ambiguity for prediction; we recommend confirming this by visualizing the data.
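A minimal sketch of this step, assuming each clip is stored as a `(T, 17, 2)` array:

```python
import numpy as np

def unify_length(seq: np.ndarray, target_len: int = 50) -> np.ndarray:
    """Crop or zero-pad a (T, 17, 2) clip to exactly target_len frames."""
    t = seq.shape[0]
    if t > target_len:  # randomly crop a target_len-frame window
        start = np.random.randint(0, t - target_len + 1)
        return seq[start:start + target_len]
    if t < target_len:  # zero-pad at the end
        pad = np.zeros((target_len - t, *seq.shape[1:]), dtype=seq.dtype)
        return np.concatenate([seq, pad], axis=0)
    return seq  # already the target length
```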
#### 3. Save in a file format usable by PaddleVideo

After the first two steps, we have the annotation of each person's action clip: a list `all_kpts` containing multiple keypoint sequence clips, each of shape (T, V, C) (in our example, (50, 17, 2)). Next, convert it into the format PaddleVideo can use.

- Adjust the dimension order: use `np.transpose` and `np.expand_dims` to convert each clip to the (C, T, V, M) layout.
- Combine all clips and save them as a single file.

Note: `class_id` here is of `int` type, as in other classification tasks, e.g. `0: falling, 1: other`.

At this point, we have usable training data (`.npy`) and the corresponding label file (`.pkl`); see the sketch below.
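A minimal sketch of this conversion, assuming four dummy clips; the label file layout `[[sample names], [class ids]]` follows what `prepare_dataset.py` (shown later in this diff) writes:

```python
import pickle
import numpy as np

# Dummy clips of shape (T, V, C) = (50, 17, 2) and one int label per clip.
all_kpts = [np.zeros((50, 17, 2), dtype=np.float32) for _ in range(4)]
class_ids = [0, 0, 1, 1]  # 0: falling, 1: other

clips = []
for kpt in all_kpts:
    clip = np.transpose(kpt, (2, 0, 1))  # (T, V, C) -> (C, T, V)
    clip = np.expand_dims(clip, -1)      # (C, T, V) -> (C, T, V, M=1)
    clips.append(clip)

data = np.stack(clips, 0)                # (N, C, T, V, M)
np.save("train_data.npy", data)

labels = [[str(i) for i in range(len(class_ids))], class_ids]
with open("train_label.pkl", "wb") as f:
    pickle.dump(labels, f)
```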
#### Example: processing falling data based on the UR Fall Detection Dataset

The [UR Fall Detection Dataset](http://fenix.univ.rzeszow.pl/~mkepski/ds/uf.html) is a fall detection dataset covering different camera views and different sensors. The dataset itself does not include keypoint annotations; here we use the RGB images from the eye-level view (camera 0) to walk through the data preparation steps described above.

(1) Detect the keypoint coordinates with the [PaddleDetection keypoint models](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/keypoint)

```bash
# current path is under root of PaddleDetection

# Step 1: download pretrained inference models.
wget https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip
wget https://bj.bcebos.com/v1/paddledet/models/pipeline/dark_hrnet_w32_256x192.zip
unzip -d output_inference/ mot_ppyoloe_l_36e_pipeline.zip
unzip -d output_inference/ dark_hrnet_w32_256x192.zip

# Step 2: get the keypoint coordinates

# if your data is an image sequence
python deploy/python/det_keypoint_unite_infer.py --det_model_dir=output_inference/mot_ppyoloe_l_36e_pipeline/ --keypoint_model_dir=output_inference/dark_hrnet_w32_256x192 --image_dir={your image directory path} --device=GPU --save_res=True

# if your data is a video
python deploy/python/det_keypoint_unite_infer.py --det_model_dir=output_inference/mot_ppyoloe_l_36e_pipeline/ --keypoint_model_dir=output_inference/dark_hrnet_w32_256x192 --video_file={your video file path} --device=GPU --save_res=True
```
65+
这样我们会得到一个`det_keypoint_unite_image_results.json`的检测结果文件。内容的具体含义请见[这里](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/python/det_keypoint_unite_infer.py#L108)
66+
67+
这里我们需要对UR Fall中的每一段数据执行上面介绍的步骤,在每一段执行完成后及时将检测结果文件妥善保存到一个文件夹中。
68+
```bash
69+
70+
mkdir {root of PaddleVideo}/applications/PPHuman/datasets/annotations
71+
mv det_keypoint_unite_image_results.json {root of PaddleVideo}/applications/PPHuman/datasets/annotations/det_keypoint_unite_image_results_{video_id}_{camera_id}.json
72+
```
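A minimal sketch of inspecting one record of this file; the record layout is inferred from `decode_json_path` in `prepare_dataset.py` (shown later in this diff), so treat it as an assumption rather than a spec:

```python
import json

# Assumed record layout: [frame_id, bboxes as (x1, y1, x2, y2), keypoint results].
content = json.load(open("det_keypoint_unite_image_results.json"))
record = sorted(content, key=lambda x: x[0])[0]  # records sorted by frame id
frame_id, bboxes, kpt_results = record[0], record[1], record[2]
person_kpts = kpt_results[0][0]  # 17 [x, y, score] entries for the first person
print(frame_id, len(bboxes), len(person_kpts))
```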
(2) Convert the keypoint coordinates into training data

After the steps above, the skeleton keypoint data is organized as follows:

```
annotations/
├── det_keypoint_unite_image_results_fall-01-cam0-rgb.json
├── det_keypoint_unite_image_results_fall-02-cam0-rgb.json
├── det_keypoint_unite_image_results_fall-03-cam0-rgb.json
├── det_keypoint_unite_image_results_fall-04-cam0-rgb.json
...
├── det_keypoint_unite_image_results_fall-28-cam0-rgb.json
├── det_keypoint_unite_image_results_fall-29-cam0-rgb.json
└── det_keypoint_unite_image_results_fall-30-cam0-rgb.json
```
Use the script we provide to convert this data directly into training data, producing the data file `train_data.npy` and the label file `train_label.pkl`. The script parses the JSON files, organizes the training data as described in the previous steps, and saves the data files.

```bash
# current path is {root of PaddleVideo}/applications/PPHuman/datasets/

python prepare_dataset.py
```

A few notes:
- In UR Fall, a complete action mostly spans about 100 frames; a few videos contain unrelated actions, which can be removed manually or cropped out as negative samples.
- The data is first unified to 100 frames and then subsampled to 50 frames, which preserves the completeness of each action.
- The falling actions above are positive samples; in actual training you also need other actions or normal standing as negative samples. The steps are the same, but note that their label should take the value 1.

We also provide more comprehensive processed [data](https://bj.bcebos.com/v1/paddledet/data/PPhuman/fall_data.zip), including falling and non-falling action scenes from other scenarios.
### Training and testing

In PaddleVideo, start training with the following commands:

```bash
# current path is under root of PaddleVideo
python main.py -c applications/PPHuman/configs/stgcn_pphuman.yaml

# since this task can easily overfit, we recommend enabling validation to save the best model
python main.py --validate -c applications/PPHuman/configs/stgcn_pphuman.yaml
```
After training completes, run prediction with the following command:

```bash
python main.py --test -c applications/PPHuman/configs/stgcn_pphuman.yaml -w output/STGCN/STGCN_best.pdparams
```
### Export the model for inference

- In PaddleVideo, export the model with the following commands to obtain the model structure file `STGCN.pdmodel` and the model weights file `STGCN.pdiparams`, then add the configuration file:

```bash
# current path is under root of PaddleVideo
python tools/export_model.py -c applications/PPHuman/configs/stgcn_pphuman.yaml \
                             -p output/STGCN/STGCN_best.pdparams \
                             -o output_inference/STGCN

cp applications/PPHuman/configs/infer_cfg.yml output_inference/STGCN

# rename the model files to match what PP-Human expects
cd output_inference/STGCN
mv STGCN.pdiparams model.pdiparams
mv STGCN.pdiparams.info model.pdiparams.info
mv STGCN.pdmodel model.pdmodel
```

The exported model directory then looks like this:

```
STGCN
├── infer_cfg.yml
├── model.pdiparams
├── model.pdiparams.info
└── model.pdmodel
```

At this point, you can use [PP-Human](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/pphuman) to run action recognition inference.
applications/PPHuman/configs/infer_cfg.yml

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
mode: fluid
use_dynamic_shape: false
arch: STGCN
min_subgraph_size: 3
Preprocess:
- window_size: 50
  type: AutoPadding
label_list:
- keypoint
applications/PPHuman/configs/stgcn_pphuman.yaml

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
MODEL: #MODEL field
    framework: "RecognizerGCN" #Mandatory, indicate the type of network, associate to the 'paddlevideo/modeling/framework/'.
    backbone: #Mandatory, indicate the type of backbone, associate to the 'paddlevideo/modeling/backbones/'.
        name: "STGCN" #Mandatory, the name of the backbone.
        in_channels: 2
        dropout: 0.5
        layout: 'coco_keypoint'
        data_bn: True
    head:
        name: "STGCNHead" #Mandatory, indicate the type of head, associate to the 'paddlevideo/modeling/heads'.
        num_classes: 2 #Optional, the number of classes to be classified.
        if_top5: False

DATASET: #DATASET field
    batch_size: 64 #Mandatory, batch size
    num_workers: 4 #Mandatory, the number of subprocesses on each GPU.
    test_batch_size: 1
    test_num_workers: 0
    train:
        format: "SkeletonDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'.
        file_path: "./applications/PPHuman/datasets/train_data.npy" #Mandatory, train data index file path
        label_path: "./applications/PPHuman/datasets/train_label.pkl"
    valid:
        format: "SkeletonDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'.
        file_path: "./applications/PPHuman/datasets/val_data.npy" #Mandatory, valid data index file path
        label_path: "./applications/PPHuman/datasets/val_label.pkl"
        test_mode: True
    test:
        format: "SkeletonDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'.
        file_path: "./applications/PPHuman/datasets/val_data.npy" #Mandatory, valid data index file path
        label_path: "./applications/PPHuman/datasets/val_label.pkl"
        test_mode: True

PIPELINE: #PIPELINE field
    train: #Mandatory, indicate the pipeline to deal with the training data, associate to the 'paddlevideo/loader/pipelines/'.
        transform: #Mandatory, image transform operator
            - Iden:
    valid: #Mandatory, indicate the pipeline to deal with the validation data, associate to the 'paddlevideo/loader/pipelines/'.
        transform: #Mandatory, image transform operator
            - Iden:
    test: #Mandatory, indicate the pipeline to deal with the test data, associate to the 'paddlevideo/loader/pipelines/'.
        transform: #Mandatory, image transform operator
            - Iden:

OPTIMIZER: #OPTIMIZER field
    name: 'Momentum'
    momentum: 0.9
    learning_rate:
        name: 'CosineAnnealingDecay'
        learning_rate: 0.05
        T_max: 50
    weight_decay:
        name: 'L2'
        value: 1e-4

METRIC:
    name: 'SkeletonMetric'
    top_k: 2

INFERENCE:
    name: 'STGCN_Inference_helper'
    num_channels: 2
    window_size: 50
    vertex_nums: 17
    person_nums: 1

model_name: "STGCN"
log_interval: 10 #Optional, the interval of the logger, default: 10
epochs: 50 #Mandatory, total epochs
applications/PPHuman/datasets/prepare_dataset.py

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
import os
import json
import numpy as np
import pickle
"""
This python script is used to convert keypoint results of the UR FALL dataset
for training by PaddleVideo
"""


def self_norm(kpt, bbox):
    # kpt: (2, T, 17, 1), bbox: (T, 4) as (x1, y1, w, h)
    # Normalize keypoints into each frame's detection box,
    # then rescale to a fixed 384x512 canvas.
    tl = bbox[:, 0:2]
    wh = bbox[:, 2:]
    tl = np.expand_dims(np.transpose(tl, (1, 0)), (2, 3))
    wh = np.expand_dims(np.transpose(wh, (1, 0)), (2, 3))

    res = (kpt - tl) / wh
    res *= np.expand_dims(np.array([[384.], [512.]]), (2, 3))
    return res


def convert_to_ppvideo(all_kpts, all_scores, all_bbox):
    # shape of all_kpts is (T, 17, 2)
    keypoint = np.expand_dims(np.transpose(all_kpts, [2, 0, 1]),
                              -1)  # (2, T, 17, 1)
    keypoint = self_norm(keypoint, all_bbox)

    scores = all_scores
    if keypoint.shape[1] > 100:
        # longer than 100 frames: center-crop 100 frames,
        # then keep every other frame to get a 50-frame sequence
        frame_start = (keypoint.shape[1] - 100) // 2
        keypoint = keypoint[:, frame_start:frame_start + 100:2, :, :]
        scores = all_scores[frame_start:frame_start + 100:2, :, :]
    elif keypoint.shape[1] < 100:
        # shorter than 100 frames: zero-pad to 100, then subsample to 50
        keypoint = np.concatenate([
            keypoint,
            np.zeros((2, 100 - keypoint.shape[1], 17, 1), dtype=keypoint.dtype)
        ], 1)[:, ::2, :, :]
        scores = np.concatenate([
            all_scores,
            np.zeros((100 - all_scores.shape[0], 17, 1), dtype=keypoint.dtype)
        ], 0)[::2, :, :]
    else:
        # exactly 100 frames: subsample to 50
        keypoint = keypoint[:, ::2, :, :]
        scores = scores[::2, :, :]
    return keypoint, scores


def decode_json_path(json_path):
    content = json.load(open(json_path))
    content = sorted(content, key=lambda x: x[0])  # sort records by frame id
    all_kpts = []
    all_score = []
    all_bbox = []
    for annos in content:
        bboxes = annos[1]
        kpts = annos[2][0]
        frame_id = annos[0]

        # keep only frames containing exactly one person
        if len(bboxes) != 1:
            continue
        kpt_res = []
        kpt_score = []
        for kpt in kpts[0]:
            x, y, score = kpt
            kpt_res.append([x, y])
            kpt_score.append([score])
        all_kpts.append(np.array(kpt_res))
        all_score.append(np.array(kpt_score))
        # convert the box from (x1, y1, x2, y2) to (x1, y1, w, h)
        all_bbox.append([
            bboxes[0][0], bboxes[0][1], bboxes[0][2] - bboxes[0][0],
            bboxes[0][3] - bboxes[0][1]
        ])
    all_kpts_np = np.array(all_kpts)
    all_score_np = np.array(all_score)
    all_bbox_np = np.array(all_bbox)
    video_anno, scores = convert_to_ppvideo(all_kpts_np, all_score_np,
                                            all_bbox_np)

    return video_anno, scores


if __name__ == '__main__':
    all_keypoints = []
    all_labels = [[], []]
    all_scores = []
    for i, path in enumerate(os.listdir("annotations")):
        video_anno, score = decode_json_path(os.path.join("annotations", path))

        all_keypoints.append(video_anno)
        all_labels[0].append(str(i))
        all_labels[1].append(0)  # label 0 means falling
        all_scores.append(score)
    all_data = np.stack(all_keypoints, 0)
    all_score_data = np.stack(all_scores, 0)
    np.save("train_data.npy", all_data)
    pickle.dump(all_labels, open("train_label.pkl", "wb"))
    np.save("kptscore_data.npy", all_score_data)
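Note that this script hardcodes label 0 (falling) for every annotation file. Following the note above about negative samples, a hypothetical adaptation might derive the label from the file name (the `normal` naming convention here is an assumption, not part of the dataset):

```python
import os

# Hypothetical sketch: derive the label per annotation file, assuming
# negative-sample result files were saved with "normal" in their names.
for i, path in enumerate(os.listdir("annotations")):
    label = 1 if "normal" in path else 0  # 0: falling, 1: other
    print(i, path, label)
```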

paddlevideo/metrics/skeleton_metric.py

Lines changed: 5 additions & 2 deletions

@@ -31,26 +31,29 @@ class SkeletonMetric(BaseMetric):
     Args:
         out_file: str, file to save test results.
     """
+
     def __init__(self,
                  data_size,
                  batch_size,
                  out_file='submission.csv',
-                 log_interval=1):
+                 log_interval=1,
+                 top_k=5):
         """prepare for metrics
         """
         super().__init__(data_size, batch_size, log_interval)
         self.top1 = []
         self.top5 = []
         self.values = []
         self.out_file = out_file
+        self.k = top_k

     def update(self, batch_id, data, outputs):
         """update metrics during each iter
         """
         if len(data) == 2:  # data with label
             labels = data[1]
             top1 = paddle.metric.accuracy(input=outputs, label=labels, k=1)
-            top5 = paddle.metric.accuracy(input=outputs, label=labels, k=5)
+            top5 = paddle.metric.accuracy(input=outputs, label=labels, k=self.k)
             if self.world_size > 1:
                 top1 = paddle.distributed.all_reduce(
                     top1, op=paddle.distributed.ReduceOp.SUM) / self.world_size
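This change makes the metric's k configurable because the PP-Human model has only two classes, so a fixed top-5 accuracy is meaningless; the `METRIC` section of `stgcn_pphuman.yaml` above sets `top_k: 2`. A minimal construction sketch (the sizes are arbitrary and the import path is an assumption):

```python
from paddlevideo.metrics import SkeletonMetric  # assumed import path

# With num_classes = 2, pass top_k=2 so the second accuracy uses k=2.
metric = SkeletonMetric(data_size=1000, batch_size=64, top_k=2)
```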
