
[PPhuman application] application of action model training in PPhuman #405


Merged · 9 commits · Apr 19, 2022
1 change: 1 addition & 0 deletions README.md
@@ -211,6 +211,7 @@ PaddleVideo is produced by [PaddlePaddle official](https://www.paddlepaddle.org.cn/?fr=paddleEdu_github)
| [EIVideo](applications/EIVideo) | Interactive video segmentation tool |
| [Anti-UAV](applications/Anti-UAV) | UAV detection solution |
| [AbnormalActionDetection](applications/AbnormalActionDetection) | Abnormal action detection solution |
| [PP-Human](applications/PPHuman) | Action recognition solution for pedestrian analysis scenarios |


## Documentation tutorial
1 change: 1 addition & 0 deletions README_en.md
@@ -208,6 +208,7 @@ PaddleVideo is a toolset for video tasks prepared for the industry and academia.
| [EIVideo](applications/EIVideo) | Interactive video segmentation tool|
| [Anti-UAV](applications/Anti-UAV) |UAV detection solution|
| [AbnormalActionDetection](applications/AbnormalActionDetection) |Abnormal action detection solution|
| [PP-Human](applications/PPHuman) | Action recognition solution for pedestrian analysis scenarios |


## Documentation tutorial
143 changes: 143 additions & 0 deletions applications/PPHuman/README.md
@@ -0,0 +1,143 @@
# PP-Human Action Recognition Model

The real-time pedestrian analysis tool [PP-Human](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/pphuman) integrates a skeleton-based action recognition module. This document describes how to train the action recognition model with [PaddleVideo](https://github.com/PaddlePaddle/PaddleVideo/).

## Training the Action Recognition Model
The action recognition model currently used is [ST-GCN](https://arxiv.org/abs/1801.07455), adapted from the [PaddleVideo training pipeline](https://github.com/PaddlePaddle/PaddleVideo/blob/develop/docs/zh-CN/model_zoo/recognition/stgcn.md) to complete model training.

### Preparing Training Data
ST-GCN is a model that makes predictions from sequences of skeleton keypoint coordinates. In PaddleVideo, the training data are `Numpy` arrays stored in `.npy` format, and the labels can be stored in `.npy` or `.pkl` format. The required dimension layout for sequence data is `(N, C, T, V, M)`.

Taking our model in PP-Human as an example, the dimensions are specified as follows:
| Dimension | Size | Description |
| ---- | ---- | ---------- |
| N | variable | Number of sequences in the dataset |
| C | 2 | Dimension of the keypoint coordinates, i.e. (x, y) |
| T | 50 | Temporal length of an action sequence (number of frames) |
| V | 17 | Number of keypoints per person; we use the `COCO` definition, see [here](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareKeypointDataSet_cn.md#COCO%E6%95%B0%E6%8D%AE%E9%9B%86) |
| M | 1 | Number of persons; each action sequence here targets a single person |
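
To make the layout concrete, here is a minimal sketch (illustrative only, with a hypothetical number of sequences) of an array in this `(N, C, T, V, M)` format:
```python
import numpy as np

# A hypothetical batch of 4 sequences in (N, C, T, V, M) layout:
# 4 clips, (x, y) coordinates, 50 frames, 17 COCO keypoints, 1 person.
train_data = np.zeros((4, 2, 50, 17, 1), dtype=np.float32)
print(train_data.shape)  # (4, 2, 50, 17, 1)
```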

#### 1. Obtain skeleton keypoint coordinates for each sequence
For a sequence to be annotated (a sequence here means one action clip, which can be a video or an ordered collection of images), the skeleton (also called keypoint) coordinates can be obtained by model prediction or manual annotation.
- Model prediction: you can directly use models from the [PaddleDetection KeyPoint model zoo](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/keypoint) and follow the steps in `3. Training and Testing - Deployment - Joint deployment of detection + keypoint top-down models` to obtain the 17 keypoint coordinates of the target sequence.
- Manual annotation: if you have different requirements for the number or definition of keypoints, you can also annotate the coordinates of each keypoint manually. Note that occluded or hard-to-annotate points still need an approximate coordinate; otherwise the subsequent network training will be affected.

After obtaining the keypoint coordinates, it is recommended to normalize them with respect to each person's detection box, which removes the convergence difficulty that differences in person position and scale bring to the network; see [here](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/pphuman/pipe_utils.py#L352-L363) for reference.
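
As a reference, the following is a minimal sketch of such a normalization, assuming one detection box per clip in `(x1, y1, w, h)` format; the linked `pipe_utils.py` code and the `self_norm` function in `prepare_dataset.py` below apply the same idea per frame:
```python
import numpy as np

def normalize_by_bbox(kpts, bbox):
    """Express (T, V, 2) keypoints relative to a detection box.

    Subtracting the top-left corner and dividing by width/height removes
    differences in person position and scale.
    """
    x1, y1, w, h = bbox
    out = kpts.astype(np.float32)
    out[..., 0] = (out[..., 0] - x1) / w
    out[..., 1] = (out[..., 1] - y1) / h
    return out
```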

#### 2. Unify the temporal length of the sequences
Since action lengths vary in real data, first fix a temporal length according to your data and scenario (in PP-Human we use 50 frames per action sequence), then process the data as follows:
- If the actual length exceeds the target length, randomly crop a 50-frame segment
- If the actual length is shorter than the target length, pad with zeros until 50 frames are reached
- If the actual length equals the target length, no processing is needed

Note: after this step, strictly verify that the processed data still contains a complete action with no ambiguity for prediction; visualizing the data is the recommended way to confirm this.
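
A minimal sketch of this crop/pad logic (illustrative; `prepare_dataset.py` below uses a variant that first standardizes clips to 100 frames and then subsamples every other frame):
```python
import numpy as np

def fit_length(clip, target=50):
    """Crop or zero-pad a (T, V, C) clip to exactly `target` frames."""
    t = clip.shape[0]
    if t > target:  # randomly crop a target-frame window
        start = np.random.randint(0, t - target + 1)
        return clip[start:start + target]
    if t < target:  # zero-pad at the end
        pad = np.zeros((target - t,) + clip.shape[1:], dtype=clip.dtype)
        return np.concatenate([clip, pad], axis=0)
    return clip  # already the target length
```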

#### 3. Save in a file format usable by PaddleVideo
After the first two steps we have the annotations for each person's action clip, collected in a list `all_kpts` containing multiple keypoint sequence clips, each of shape (T, V, C) ((50, 17, 2) in our example). Next, convert them into the format usable by PaddleVideo.
- Adjust the dimension order: each clip can be converted to the (C, T, V, M) layout with `np.transpose` and `np.expand_dims`.
- Combine all clips and save them as a single file (see the sketch below).

Note: `class_id` here is of type `int`, as in other classification tasks, e.g. `0: falling, 1: others`.

At this point, we have usable training data (`.npy`) and the corresponding label file (`.pkl`).
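
Putting the two bullets above together, a minimal sketch (with illustrative clip and label values) of the conversion and saving step; the label file layout, a pair of parallel lists of sample ids and class ids, matches what `prepare_dataset.py` below produces:
```python
import numpy as np
import pickle

all_kpts = [np.zeros((50, 17, 2), dtype=np.float32)]  # your (T, V, C) clips
class_ids = [0]  # 0: falling, 1: others

# (T, V, C) -> (C, T, V) -> (C, T, V, M), then stack into (N, C, T, V, M)
clips = [np.expand_dims(np.transpose(c, (2, 0, 1)), -1) for c in all_kpts]
np.save("train_data.npy", np.stack(clips, 0))

# label file: ([sample ids], [class ids])
labels = [[str(i) for i in range(len(class_ids))], class_ids]
pickle.dump(labels, open("train_label.pkl", "wb"))
```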

#### Example: preparing falling data based on the UR Fall Detection Dataset
The [UR Fall Detection Dataset](http://fenix.univ.rzeszow.pl/~mkepski/ds/uf.html) is a fall detection dataset covering different camera views and sensors. The dataset itself does not contain keypoint annotations; here we use the RGB images from the eye-level view (camera 0) to show how to complete the data preparation following the steps above.

(1) Detect the keypoint coordinates with a [PaddleDetection keypoint model](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/keypoint)
```bash
# current path is under root of PaddleDetection

# Step 1: download pretrained inference models.
wget https://bj.bcebos.com/v1/paddledet/models/pipeline/mot_ppyoloe_l_36e_pipeline.zip
wget https://bj.bcebos.com/v1/paddledet/models/pipeline/dark_hrnet_w32_256x192.zip
unzip -d output_inference/ mot_ppyoloe_l_36e_pipeline.zip
unzip -d output_inference/ dark_hrnet_w32_256x192.zip

# Step 2: get the keypoint coordinates

# if your data is image sequence
python deploy/python/det_keypoint_unite_infer.py --det_model_dir=output_inference/mot_ppyoloe_l_36e_pipeline/ --keypoint_model_dir=output_inference/dark_hrnet_w32_256x192 --image_dir={your image directory path} --device=GPU --save_res=True

# if your data is video
python deploy/python/det_keypoint_unite_infer.py --det_model_dir=output_inference/mot_ppyoloe_l_36e_pipeline/ --keypoint_model_dir=output_inference/dark_hrnet_w32_256x192 --video_file={your video file path} --device=GPU --save_res=True
```
This produces a detection result file `det_keypoint_unite_image_results.json`; the precise meaning of its contents is described [here](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/python/det_keypoint_unite_infer.py#L108).
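
As a quick orientation (format inferred from the parsing code in `prepare_dataset.py` below), each entry of the result file is a `[frame_id, detection_boxes, keypoint_results]` triple, so it can be inspected like this:
```python
import json

content = json.load(open("det_keypoint_unite_image_results.json"))
for annos in sorted(content, key=lambda x: x[0]):
    frame_id, bboxes = annos[0], annos[1]
    print(frame_id, len(bboxes))  # number of detected persons in this frame
```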

Run the step above for every clip in UR Fall, and after each run move the detection result file into a dedicated folder right away.
```bash

mkdir {root of PaddleVideo}/applications/PPHuman/datasets/annotations
mv det_keypoint_unite_image_results.json {root of PaddleVideo}/applications/PPHuman/datasets/annotations/det_keypoint_unite_image_results_{video_id}_{camera_id}.json
```

(2) Convert the keypoint coordinates into training data


After completing the steps above, the skeleton data are organized as follows:
```
annotations/
├── det_keypoint_unite_image_results_fall-01-cam0-rgb.json
├── det_keypoint_unite_image_results_fall-02-cam0-rgb.json
├── det_keypoint_unite_image_results_fall-03-cam0-rgb.json
├── det_keypoint_unite_image_results_fall-04-cam0-rgb.json
...
├── det_keypoint_unite_image_results_fall-28-cam0-rgb.json
├── det_keypoint_unite_image_results_fall-29-cam0-rgb.json
└── det_keypoint_unite_image_results_fall-30-cam0-rgb.json
```
Use our provided script to convert the data directly into training data, producing the data file `train_data.npy` and the label file `train_label.pkl`. The script parses the json files, arranges the training data as described in the previous steps, and saves the data files.
```bash
# current path is {root of PaddleVideo}/applications/PPHuman/datasets/

python prepare_dataset.py
```
A few notes:
- In UR Fall, one complete action mostly spans about 100 frames; a few videos contain unrelated actions, which can be removed manually or cropped out as negative samples
- The data are uniformly arranged into 100 frames and then subsampled to 50 frames, which keeps the actions complete
- The falling actions above are positive samples; actual training also needs other actions or normal standing as negative samples, prepared with the same steps but with label value 1 (see the sketch below).
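
For those negative samples, a minimal sketch (with hypothetical file names) of merging positive and negative clips into one training set:
```python
import numpy as np
import pickle

pos = np.load("fall_data.npy")    # hypothetical name: falling clips, label 0
neg = np.load("other_data.npy")   # hypothetical name: negative clips, label 1
data = np.concatenate([pos, neg], 0)

ids = [str(i) for i in range(len(data))]
labels = [0] * len(pos) + [1] * len(neg)
np.save("train_data.npy", data)
pickle.dump([ids, labels], open("train_label.pkl", "wb"))
```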

We also provide more comprehensively processed [data](https://bj.bcebos.com/v1/paddledet/data/PPhuman/fall_data.zip), including falling and non-falling actions from other scenarios.

### Training and Testing
In PaddleVideo, start training with the following commands:
```bash
# current path is under root of PaddleVideo
python main.py -c applications/PPHuman/configs/stgcn_pphuman.yaml

# since this task can easily overfit, it is recommended to also enable validation to save the best model
python main.py --validate -c applications/PPHuman/configs/stgcn_pphuman.yaml
```

After training completes, run prediction with:
```bash
python main.py --test -c applications/PPHuman/configs/stgcn_pphuman.yaml -w output/STGCN/STGCN_best.pdparams
```

### Exporting the Model for Inference

- In PaddleVideo, export the model with the following commands, obtaining the model structure file `STGCN.pdmodel` and the model weight file `STGCN.pdiparams`, and add the inference config file:
```bash
# current path is under root of PaddleVideo
python tools/export_model.py -c applications/PPHuman/configs/stgcn_pphuman.yaml \
-p output/STGCN/STGCN_best.pdparams \
-o output_inference/STGCN

cp applications/PPHuman/configs/infer_cfg.yml output_inference/STGCN

# rename the model files to match what PP-Human expects
cd output_inference/STGCN
mv STGCN.pdiparams model.pdiparams
mv STGCN.pdiparams.info model.pdiparams.info
mv STGCN.pdmodel model.pdmodel
```
The exported model directory is structured as follows:
```
STGCN
├── infer_cfg.yml
├── model.pdiparams
├── model.pdiparams.info
└── model.pdmodel
```

At this point, you can run action recognition inference with [PP-Human](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/pphuman).
9 changes: 9 additions & 0 deletions applications/PPHuman/configs/infer_cfg.yml
@@ -0,0 +1,9 @@
mode: fluid
use_dynamic_shape: false
arch: STGCN
min_subgraph_size: 3
Preprocess:
- window_size: 50
  type: AutoPadding
label_list:
- keypoint
72 changes: 72 additions & 0 deletions applications/PPHuman/configs/stgcn_pphuman.yaml
@@ -0,0 +1,72 @@
MODEL: #MODEL field
    framework: "RecognizerGCN" #Mandatory, indicate the type of network, associate to the 'paddlevideo/modeling/framework/'.
    backbone: #Mandatory, indicate the type of backbone, associate to the 'paddlevideo/modeling/backbones/'.
        name: "STGCN" #Mandatory, the name of backbone.
        in_channels: 2
        dropout: 0.5
        layout: 'coco_keypoint'
        data_bn: True
    head:
        name: "STGCNHead" #Mandatory, indicate the type of head, associate to the 'paddlevideo/modeling/heads'
        num_classes: 2 #Optional, the number of classes to be classified.
        if_top5: False

DATASET: #DATASET field
    batch_size: 64 #Mandatory, batch size
    num_workers: 4 #Mandatory, the number of subprocesses on each GPU.
    test_batch_size: 1
    test_num_workers: 0
    train:
        format: "SkeletonDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'
        file_path: "./applications/PPHuman/datasets/train_data.npy" #Mandatory, train data index file path
        label_path: "./applications/PPHuman/datasets/train_label.pkl"

    valid:
        format: "SkeletonDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'
        file_path: "./applications/PPHuman/datasets/val_data.npy" #Mandatory, valid data index file path
        label_path: "./applications/PPHuman/datasets/val_label.pkl"

        test_mode: True
    test:
        format: "SkeletonDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'
        file_path: "./applications/PPHuman/datasets/val_data.npy" #Mandatory, valid data index file path
        label_path: "./applications/PPHuman/datasets/val_label.pkl"

        test_mode: True

PIPELINE: #PIPELINE field
    train: #Mandatory, indicate the pipeline to deal with the training data, associate to the 'paddlevideo/loader/pipelines/'
        transform: #Mandatory, image transform operator
            - Iden:
    valid: #Mandatory, indicate the pipeline to deal with the validation data, associate to the 'paddlevideo/loader/pipelines/'
        transform: #Mandatory, image transform operator
            - Iden:
    test: #Mandatory, indicate the pipeline to deal with the test data, associate to the 'paddlevideo/loader/pipelines/'
        transform: #Mandatory, image transform operator
            - Iden:

OPTIMIZER: #OPTIMIZER field
    name: 'Momentum'
    momentum: 0.9
    learning_rate:
        name: 'CosineAnnealingDecay'
        learning_rate: 0.05
        T_max: 50
    weight_decay:
        name: 'L2'
        value: 1e-4

METRIC:
    name: 'SkeletonMetric'
    top_k: 2

INFERENCE:
    name: 'STGCN_Inference_helper'
    num_channels: 2
    window_size: 50
    vertex_nums: 17
    person_nums: 1

model_name: "STGCN"
log_interval: 10 #Optional, the interval of logger, default: 10
epochs: 50 #Mandatory, total epochs
98 changes: 98 additions & 0 deletions applications/PPHuman/datasets/prepare_dataset.py
@@ -0,0 +1,98 @@
import os
import json
import numpy as np
import pickle
"""
This script converts keypoint detection results of the UR Fall dataset
into training data for PaddleVideo.
"""


def self_norm(kpt, bbox):
    # kpt: (2, T, 17, 1), bbox: (T, 4) in (x1, y1, w, h) format
    tl = bbox[:, 0:2]  # per-frame top-left corner of the detection box
    wh = bbox[:, 2:]  # per-frame box width and height
    tl = np.expand_dims(np.transpose(tl, (1, 0)), (2, 3))
    wh = np.expand_dims(np.transpose(wh, (1, 0)), (2, 3))

    # normalize coordinates to the detection box, then rescale to a
    # fixed 384x512 canvas
    res = (kpt - tl) / wh
    res *= np.expand_dims(np.array([[384.], [512.]]), (2, 3))
    return res


def convert_to_ppvideo(all_kpts, all_scores, all_bbox):
    # shape of all_kpts is (T, 17, 2)
    keypoint = np.expand_dims(np.transpose(all_kpts, [2, 0, 1]),
                              -1)  #(2, T, 17, 1)
    keypoint = self_norm(keypoint, all_bbox)

    scores = all_scores
    if keypoint.shape[1] > 100:
        # longer than 100 frames: center-crop 100 frames, then keep every
        # other frame to get a 50-frame sequence
        frame_start = (keypoint.shape[1] - 100) // 2
        keypoint = keypoint[:, frame_start:frame_start + 100:2, :, :]
        scores = all_scores[frame_start:frame_start + 100:2, :, :]
    elif keypoint.shape[1] < 100:
        # shorter than 100 frames: zero-pad to 100 frames, then subsample to 50
        keypoint = np.concatenate([
            keypoint,
            np.zeros((2, 100 - keypoint.shape[1], 17, 1), dtype=keypoint.dtype)
        ], 1)[:, ::2, :, :]
        scores = np.concatenate([
            all_scores,
            np.zeros((100 - all_scores.shape[0], 17, 1), dtype=keypoint.dtype)
        ], 0)[::2, :, :]
    else:
        # exactly 100 frames: just subsample to 50
        keypoint = keypoint[:, ::2, :, :]
        scores = scores[::2, :, :]
    return keypoint, scores


def decode_json_path(json_path):
    # each entry of the json is [frame_id, detection_boxes, keypoint_results]
    content = json.load(open(json_path))
    content = sorted(content, key=lambda x: x[0])
    all_kpts = []
    all_score = []
    all_bbox = []
    for annos in content:
        bboxes = annos[1]
        kpts = annos[2][0]
        frame_id = annos[0]

        # keep only frames with exactly one detected person
        if len(bboxes) != 1:
            continue
        kpt_res = []
        kpt_score = []
        for kpt in kpts[0]:
            x, y, score = kpt
            kpt_res.append([x, y])
            kpt_score.append([score])
        all_kpts.append(np.array(kpt_res))
        all_score.append(np.array(kpt_score))
        # convert (x1, y1, x2, y2) boxes to (x1, y1, w, h)
        all_bbox.append([
            bboxes[0][0], bboxes[0][1], bboxes[0][2] - bboxes[0][0],
            bboxes[0][3] - bboxes[0][1]
        ])
    all_kpts_np = np.array(all_kpts)
    all_score_np = np.array(all_score)
    all_bbox_np = np.array(all_bbox)
    video_anno, scores = convert_to_ppvideo(all_kpts_np, all_score_np,
                                            all_bbox_np)

    return video_anno, scores


if __name__ == '__main__':
    all_keypoints = []
    all_labels = [[], []]  # [sample ids, class ids]
    all_scores = []
    for i, path in enumerate(os.listdir("annotations")):
        video_anno, score = decode_json_path(os.path.join("annotations", path))

        all_keypoints.append(video_anno)
        all_labels[0].append(str(i))
        all_labels[1].append(0)  # label 0 means falling
        all_scores.append(score)
    all_data = np.stack(all_keypoints, 0)
    all_score_data = np.stack(all_scores, 0)
    np.save("train_data.npy", all_data)
    pickle.dump(all_labels, open("train_label.pkl", "wb"))
    np.save("kptscore_data.npy", all_score_data)
7 changes: 5 additions & 2 deletions paddlevideo/metrics/skeleton_metric.py
@@ -31,26 +31,29 @@ class SkeletonMetric(BaseMetric):
Args:
out_file: str, file to save test results.
"""

     def __init__(self,
                  data_size,
                  batch_size,
                  out_file='submission.csv',
-                 log_interval=1):
+                 log_interval=1,
+                 top_k=5):
         """prepare for metrics
         """
         super().__init__(data_size, batch_size, log_interval)
         self.top1 = []
         self.top5 = []
         self.values = []
         self.out_file = out_file
+        self.k = top_k

     def update(self, batch_id, data, outputs):
         """update metrics during each iter
         """
         if len(data) == 2:  # data with label
             labels = data[1]
             top1 = paddle.metric.accuracy(input=outputs, label=labels, k=1)
-            top5 = paddle.metric.accuracy(input=outputs, label=labels, k=5)
+            top5 = paddle.metric.accuracy(input=outputs, label=labels, k=self.k)
             if self.world_size > 1:
                 top1 = paddle.distributed.all_reduce(
                     top1, op=paddle.distributed.ReduceOp.SUM) / self.world_size