Commit c8e9c5f

Merge pull request #415 from HydrogenSulfate/add_ppTSM_quant
add PTQ for PP-TSM
2 parents 8ebb0ee + 3b3ca75 commit c8e9c5f

4 files changed, +418 −0 lines changed

configs/recognition/pptsm/pptsm_k400_frames_uniform_quantization.yaml

Lines changed: 33 additions & 0 deletions
```yaml
DATASET:
    batch_size: 1
    num_workers: 0
    quant:
        format: "FrameDataset"
        data_prefix: "../../data/k400/rawframes"
        file_path: "../../data/k400/val_small_frames.list"
        suffix: 'img_{:05}.jpg'

PIPELINE:
    quant:
        decode:
            name: "FrameDecoder"
        sample:
            name: "Sampler"
            num_seg: 8
            seg_len: 1
            valid_mode: True
        transform:
            - Scale:
                short_size: 256
            - CenterCrop:
                target_size: 224
            - Image2Array:
            - Normalization:
                mean: [0.485, 0.456, 0.406]
                std: [0.229, 0.224, 0.225]

inference_model_dir: "../../inference/ppTSM"
quant_output_dir: "../../inference/ppTSM/quant_model"

model_name: "ppTSM"
log_level: "INFO" # Optional, the logger level. default: "INFO"
```
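A note on the `suffix` field above: it is a Python format template for the per-frame image names under each rawframes directory. A quick illustration (not part of the commit) of what it expands to:

```python
# 'img_{:05}.jpg' zero-pads the frame index to five digits.
suffix = 'img_{:05}.jpg'
print(suffix.format(1))    # img_00001.jpg
print(suffix.format(123))  # img_00123.jpg
```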

deploy/slim/quant_post_static.py

Lines changed: 120 additions & 0 deletions
```python
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import os.path as osp
import sys

import numpy as np
import paddle
from paddleslim.quant import quant_post_static

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

from paddlevideo.loader.builder import build_dataloader, build_dataset
from paddlevideo.utils import get_config, get_logger


def parse_args():
    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser("PaddleVideo post-training quantization script")
    parser.add_argument(
        '-c',
        '--config',
        type=str,
        default=
        '../../configs/recognition/pptsm/pptsm_k400_frames_uniform_quantization.yaml',
        help='quantization config file path')
    parser.add_argument('-o',
                        '--override',
                        action='append',
                        default=[],
                        help='config options to be overridden')
    parser.add_argument("--use_gpu",
                        type=str2bool,
                        default=True,
                        help="whether to use gpu during quantization")

    return parser.parse_args()


def post_training_quantization(cfg, use_gpu: bool = True):
    """Quantization entry

    Args:
        cfg (dict): quantization configuration.
        use_gpu (bool, optional): whether to use gpu during quantization. Defaults to True.
    """
    logger = get_logger("paddlevideo")

    place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace()

    # get defined params
    batch_size = cfg.DATASET.get('batch_size', 1)
    num_workers = cfg.DATASET.get('num_workers', 0)
    inference_file_name = cfg.get('model_name', 'inference')
    inference_model_dir = cfg.get('inference_model_dir',
                                  f'./inference/{inference_file_name}')
    quant_output_dir = cfg.get('quant_output_dir',
                               osp.join(inference_model_dir, 'quant_model'))
    batch_nums = cfg.get('batch_nums', 10)

    # build dataloader for quantization, lite data is enough
    slim_dataset = build_dataset((cfg.DATASET.quant, cfg.PIPELINE.quant))
    slim_dataloader_setting = dict(batch_size=batch_size,
                                   num_workers=num_workers,
                                   places=place,
                                   drop_last=False,
                                   shuffle=False)
    slim_loader = build_dataloader(slim_dataset, **slim_dataloader_setting)

    logger.info("Build slim_loader finished")

    def sample_generator(loader):
        def __reader__():
            for data in loader:
                # must return np.ndarray, not paddle.Tensor
                videos = np.array(data[0])
                yield videos

        return __reader__

    # execute quantization in static graph mode
    paddle.enable_static()

    exe = paddle.static.Executor(place)

    logger.info("Starting Post-Training Quantization...")

    quant_post_static(executor=exe,
                      model_dir=inference_model_dir,
                      quantize_model_path=quant_output_dir,
                      sample_generator=sample_generator(slim_loader),
                      model_filename=f'{inference_file_name}.pdmodel',
                      params_filename=f'{inference_file_name}.pdiparams',
                      batch_size=batch_size,
                      batch_nums=batch_nums,
                      algo='KL')

    logger.info("Post-Training Quantization finished...")


if __name__ == '__main__':
    args = parse_args()
    cfg = get_config(args.config, overrides=args.override)
    post_training_quantization(cfg, args.use_gpu)
```
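The script selects `algo='KL'`, PaddleSlim's KL-divergence calibration option. Conceptually, KL calibration picks, for each activation tensor, the clipping threshold whose quantized histogram stays closest (in KL divergence) to the original FP32 histogram. A simplified sketch of the idea, not PaddleSlim's actual implementation:

```python
# Simplified sketch of KL calibration: choose the clipping threshold whose
# int8-resolution histogram best matches (lowest KL divergence) the FP32 one.
import numpy as np

def kl_divergence(p, q, eps=1e-10):
    p = p / (p.sum() + eps)
    q = q / (q.sum() + eps)
    return float(np.sum(p * np.log((p + eps) / (q + eps))))

def pick_threshold(activations, num_bins=2048, num_quant_bins=128):
    hist, edges = np.histogram(np.abs(activations), bins=num_bins)
    best_t, best_kl = edges[-1], float("inf")
    # candidate thresholds at whole multiples of the int8 bin count
    for i in range(num_quant_bins, num_bins + 1, num_quant_bins):
        p = hist[:i].astype(np.float64)
        p[-1] += hist[i:].sum()           # fold the clipped tail into the last bin
        factor = i // num_quant_bins
        q = p.reshape(num_quant_bins, factor).sum(axis=1)  # merge to int8 resolution
        q = np.repeat(q / factor, factor)                  # expand back for comparison
        kl = kl_divergence(p, q)
        if kl < best_kl:
            best_kl, best_t = kl, edges[i]
    return best_t

acts = np.random.randn(100_000)  # stand-in for activations collected from the calibration set
print(pick_threshold(acts))
```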

deploy/slim/readme.md

Lines changed: 133 additions & 0 deletions
## Introduction to Slim

Complex models help improve performance, but they also introduce a certain amount of redundancy. This module provides tools for slimming models, in two parts: model quantization (quantization-aware training and post-training quantization) and model pruning.

Model quantization removes this redundancy by reducing full-precision values to fixed-point numbers, lowering the model's computational complexity and improving inference performance.
It converts FP32 model parameters to Int8 precision with essentially no loss of accuracy, shrinking the parameter size and speeding up computation, which gives quantized models a clear speed advantage when deployed on mobile and other devices.
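As a quick illustration of the FP32-to-Int8 mapping (a minimal sketch of symmetric linear quantization, not PaddleSlim's internal code):

```python
# Minimal sketch of symmetric linear quantization: map FP32 values onto
# the Int8 range [-127, 127] with a single scale factor.
import numpy as np

x = np.random.randn(8).astype(np.float32)      # stand-in for FP32 weights
scale = np.abs(x).max() / 127.0                # one scale for the whole tensor
q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
x_hat = q.astype(np.float32) * scale           # dequantized approximation

print(np.abs(x - x_hat).max())                 # quantization error is small
```

Post-training quantization essentially amounts to choosing good `scale` values per tensor from a small calibration set.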

Model pruning trims away unimportant convolution kernels in the CNN, reducing the parameter count and thus the model's computational complexity.

This tutorial explains how to compress PaddleVideo models with PaddleSlim, the Paddle model compression library.
[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) integrates a range of widely used, industry-leading compression techniques, including model pruning, quantization (both quantization-aware training and post-training quantization), distillation, and neural architecture search; follow the project if you are interested.

Before starting this tutorial, it is recommended to read up on [training PaddleVideo models](../../docs/zh-CN/usage.md) and [PaddleSlim](https://paddleslim.readthedocs.io/zh_CN/latest/index.html).

## Quick Start

After training a model, if you want to further compress its size and speed up prediction, you can compress it with quantization or pruning.

Model compression involves five main steps:
1. Install PaddleSlim
2. Prepare a trained model
3. Compress the model
4. Export the quantized inference model
5. Deploy the quantized model for prediction

### 1. Install PaddleSlim

* PaddleSlim can be installed with pip:

```bash
python3.7 -m pip install paddleslim -i https://pypi.tuna.tsinghua.edu.cn/simple
```

* To get the latest PaddleSlim features, install from source:

```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python3.7 setup.py install
```

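To verify the installation, a quick import check (this assumes the installed `paddleslim` package exposes a `__version__` attribute, as recent releases do):

```python
# Sanity check that PaddleSlim is importable after installation.
import paddleslim
print(paddleslim.__version__)
```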
### 2. Prepare a trained model

PaddleVideo provides a series of trained [models](../../docs/zh-CN/model_zoo/README.md). If the model you want to quantize is not among them, train one following the [standard training procedure](../../docs/zh-CN/usage.md).

### 3. Compress the model

Go to the PaddleVideo root directory:

```bash
cd PaddleVideo
```

The post-training quantization code lives in `deploy/slim/quant_post_static.py`.

#### 3.1 Model quantization

Quantization comes in two forms: post-training quantization and quantization-aware training (TODO). Quantization-aware training gives better accuracy; it loads a pretrained model and quantizes it once the quantization strategy has been defined.

##### 3.1.1 Quantization-aware training
TODO

##### 3.1.2 Post-training quantization

**Note**: Post-training quantization currently requires an `inference model` exported from an already trained model; see the [tutorial](../../docs/zh-CN/usage.md#5-模型推理) for how to export an `inference model`.

In general, post-training quantization costs more model accuracy than quantization-aware training.

Taking the PP-TSM model as an example, after generating the `inference model`, run post-training quantization as follows:

```bash
# Download and extract a small amount of data for calibration
pushd ./data/k400
wget -nc https://videotag.bj.bcebos.com/Data/k400_rawframes_small.tar
tar -xf k400_rawframes_small.tar
popd

# Enter the deploy/slim directory
cd deploy/slim

# Run post-training quantization
python3.7 quant_post_static.py \
    -c ../../configs/recognition/pptsm/pptsm_k400_frames_uniform_quantization.yaml \
    --use_gpu=True
```

Except for `use_gpu`, all quantization parameters are configured in the `pptsm_k400_frames_uniform_quantization.yaml` file,
where `inference_model_dir` is the directory of the `inference model` exported in the previous step and `quant_output_dir` is the output directory for the quantized model.

On success, a `__model__` file and a `__params__` file are generated in the `quant_output_dir` directory; together they store the post-training quantized model.
They can be used directly for prediction and deployment, in the same way as an `inference model`, with no need to re-export the model:

```bash
# Predict with the post-training quantized PP-TSM model
# Return to the PaddleVideo root directory
cd ../../

# Run prediction with the quantized model
python3.7 tools/predict.py \
    --input_file data/example.avi \
    --config configs/recognition/pptsm/pptsm_k400_frames_uniform.yaml \
    --model_file ./inference/ppTSM/quant_model/__model__ \
    --params_file ./inference/ppTSM/quant_model/__params__ \
    --use_gpu=True \
    --use_tensorrt=False
```

The output looks like:
```bash
Current video file: data/example.avi
    top-1 class: 5
    top-1 score: 0.9997928738594055
```
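Beyond `tools/predict.py`, the quantized `__model__`/`__params__` pair can also be loaded directly with the Paddle Inference API. A minimal sketch (illustrative, not part of this PR; the input shape assumes PP-TSM's uniform sampling of 8 segments at 224×224):

```python
# Illustrative: load the quantized model with the Paddle Inference API.
import numpy as np
from paddle.inference import Config, create_predictor

config = Config("./inference/ppTSM/quant_model/__model__",
                "./inference/ppTSM/quant_model/__params__")
config.enable_use_gpu(2000, 0)  # 2000 MB initial GPU memory pool, device 0

predictor = create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])

# Dummy clip: batch 1, 8 segments, 3 channels, 224x224 (assumed PP-TSM layout)
clip = np.random.rand(1, 8, 3, 224, 224).astype("float32")
input_handle.copy_from_cpu(clip)

predictor.run()
scores = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()
print(scores.argmax(), scores.max())  # top-1 class and score
```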
#### 3.2 Model pruning
TODO


### 4. Export the model
TODO


### 5. Deploy the model

The model exported in the steps above can be converted with PaddleLite's opt model conversion tool.
For model deployment, see:
[Python Serving deployment](../python_serving/readme.md)
[C++ Serving deployment](../cpp_serving/readme.md)


## Suggested training hyperparameters

* For quantization-aware training, it is recommended to load a pretrained model obtained from standard training, to speed up the convergence of quantization training.
* For quantization-aware training, it is recommended to set the initial learning rate to `1/20~1/10` of the standard training value and the number of epochs to `1/5~1/2` (for example, if standard training used a learning rate of 0.01 for 80 epochs, quantization-aware training might use 0.0005~0.001 for 16~40 epochs); add Warmup to the learning rate schedule and leave the other settings unchanged.
