Skip to content

Commit 92bc9e0

Browse files
committed
add infer (fp32/fp16/int8) benchmark
1 parent 946fe27 commit 92bc9e0

File tree

5 files changed

+163
-50
lines changed

5 files changed

+163
-50
lines changed

deploy/python/infer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ def __init__(self, args):
7171
if not args.print_detail:
7272
pred_cfg.disable_glog_info()
7373
pred_cfg.enable_memory_optim()
74+
pred_cfg.switch_ir_optim(True)
7475

7576
if args.device == 'gpu':
7677
# set GPU configs accordingly
7778
# such as intialize the gpu memory, enable tensorrt
7879
logger.info("Use GPU")
7980
pred_cfg.enable_use_gpu(100, 0)
80-
pred_cfg.switch_ir_optim(True)
8181
precision_map = {
8282
"fp16": PrecisionType.Half,
8383
"fp32": PrecisionType.Float32,
@@ -96,7 +96,7 @@ def __init__(self, args):
9696
use_calib_mode=False)
9797
min_input_shape = {"x": [1, 3, 100, 100]}
9898
max_input_shape = {"x": [1, 3, 2000, 3000]}
99-
opt_input_shape = {"x": [1, 3, 192, 192]}
99+
opt_input_shape = {"x": [1, 3, 512, 1024]}
100100
pred_cfg.set_trt_dynamic_shape_info(
101101
min_input_shape, max_input_shape, opt_input_shape)
102102
else:
@@ -105,6 +105,7 @@ def __init__(self, args):
105105
logger.info("Use CPU")
106106
pred_cfg.disable_gpu()
107107
if args.enable_mkldnn:
108+
logger.info("Use MKLDNN")
108109
# cache 10 different shapes for mkldnn to avoid memory leak
109110
pred_cfg.set_mkldnn_cache_capacity(10)
110111
pred_cfg.enable_mkldnn()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# 推理Benchmark
2+
3+
测试环境:
4+
* GPU: V100 32G
5+
* CPU: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
6+
* CUDA: 10.1
7+
* cuDNN: 7.6
8+
* TensorRT: 6.0.1.5
9+
* Paddle: 2.1.1
10+
11+
12+
GPU上分割模型的测试方法:
13+
1. 使用cityspcaes的全量验证数据集(1024x2048)进行测试
14+
2. 单GPU,Batchsize为1
15+
3. 运行耗时为纯模型预测时间
16+
4. 使用Paddle Inference的[Python API](./python_inference.md)测试,通过use_trt参数设置是否使用TRT,使用precision参数设置预测类型
17+
18+
19+
| 模型 | 使用TRT | 预测类型 | mIoU | 耗时(s/img) |
20+
| - | :-: | :-: | :-: | :-: |
21+
| ANN_ResNet50_OS8 | N | FP32 | 0.7909 | 0.274 |
22+
| ANN_ResNet50_OS8 | Y | FP32 | 0.7909 | 0.281 |
23+
| ANN_ResNet50_OS8 | Y | FP16 | 0.7909 | 0.168 |
24+
| ANN_ResNet50_OS8 | Y | INT8 | 0.7906 | 0.195 |
25+
| DANet_ResNet50_OS8 | N | FP32 | 0.8027 | 0.371 |
26+
| DANet_ResNet50_OS8 | Y | FP32 | 0.8027 | 0.330 |
27+
| DANet_ResNet50_OS8 | Y | FP16 | 0.8027 | 0.183 |
28+
| DANet_ResNet50_OS8 | Y | INT8 | 0.8039 | 0.266 |
29+
| DeepLabV3P_ResNet50_OS8 | N | FP32 | 0.8036 | 0.165 |
30+
| DeepLabV3P_ResNet50_OS8 | Y | FP32 | 0.8036 | 0.206 |
31+
| DeepLabV3P_ResNet50_OS8 | Y | FP16 | 0.8036 | 0.196 |
32+
| DeepLabV3P_ResNet50_OS8 | Y | INT8 | 0.8044 | 0.083 |
33+
| DNLNet_ResNet50_OS8 | N | FP32 | 0.7995 | 0.381 |
34+
| DNLNet_ResNet50_OS8 | Y | FP32 | 0.7995 | 0.360 |
35+
| DNLNet_ResNet50_OS8 | Y | FP16 | 0.7995 | 0.230 |
36+
| DNLNet_ResNet50_OS8 | Y | INT8 | 0.7989 | 0.236 |
37+
| EMANet_ResNet50_OS8 | N | FP32 | 0.7905 | 0.208 |
38+
| EMANet_ResNet50_OS8 | Y | FP32 | 0.7905 | 0.186 |
39+
| EMANet_ResNet50_OS8 | Y | FP16 | 0.7904 | 0.062 |
40+
| EMANet_ResNet50_OS8 | Y | INT8 | 0.7939 | 0.106 |
41+
| GCNet_ResNet50_OS8 | N | FP32 | 0.7950 | 0.247 |
42+
| GCNet_ResNet50_OS8 | Y | FP32 | 0.7950 | 0.228 |
43+
| GCNet_ResNet50_OS8 | Y | FP16 | 0.7950 | 0.100 |
44+
| GCNet_ResNet50_OS8 | Y | INT8 | 0.7959 | 0.144 |
45+
| PSPNet_ResNet50_OS8 | N | FP32 | 0.7883 | 0.327 |
46+
| PSPNet_ResNet50_OS8 | Y | FP32 | 0.7883 | 0.324 |
47+
| PSPNet_ResNet50_OS8 | Y | FP16 | 0.7883 | 0.218 |
48+
| PSPNet_ResNet50_OS8 | Y | INT8 | 0.7915 | 0.223 |
49+
| UNet | N | FP32 | 0.6500 | 0.071 |
50+
| UNet | Y | FP32 | 0.6500 | 0.099 |
51+
| UNet | Y | FP16 | 0.6500 | 0.099 |
52+
| UNet | Y | INT8 | 0.6503 | 0.099 |

docs/deployment/inference/python_inference.md

+12-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,18 @@ python deploy/python/infer.py \
5454

5555
**注意**
5656

57-
1. 使用TensorRT需要使用支持TRT功能的Paddle库,请参考[附录](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-release)下载带有trt的PaddlePaddle安装包,或者参考[源码编译](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/fromsource.html)自行编译。
57+
1. 如果使用TensorRT预测,需要安装支持TRT功能的Paddle库。Paddle支持`cuda10.1+cudnn7+trt6.0.1.5``cuda10.2+cudnn8.1+trt7.1.3.4`两种版本,大家可以根据实际情况选择,通过如下链接进行下载。
58+
```
59+
https://paddle-inference-dist.bj.bcebos.com/tensorrt_test/cuda10.1-cudnn7.6-trt6.0.tar
60+
https://paddle-inference-dist.bj.bcebos.com/tensorrt_test/cuda10.2-cudnn8.0-trt7.1.tgz
61+
```
62+
63+
* 配置安装cuda和cudnn。
64+
* 下载TRT,设置`export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<tensorrt_path>`
65+
* 参考[附录](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-release)下载带有trt的PaddlePaddle安装包或者参考[源码编译](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/fromsource.html)自行编译。
66+
* 安装PaddlePaddle。
67+
* 部署模型。
5868

59-
2. 当使用量化模型在GPU上预测时,需要设置device=gpu、use_trt=True、precision=int8
69+
2. 当使用量化模型在GPU上预测时,需要设置device=gpu、use_trt=True、precision=int8
6070

6171
3. 要开启`--benchmark`的话需要安装auto_log,请参考[安装方式](https://github.com/LDOUBLEV/AutoLog)

docs/slim/quant/quant.md

+91-45
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,54 @@
11
# 模型量化教程
22

3-
模型量化是使用整数替代浮点数进行存储和计算的方法。举例而言,模型量化可以将32bit浮点数转换成8bit整数,则模型存储空间可以减少4倍,同时整数运算替换浮点数运算,可以加快模型推理速度、降低计算内存。
3+
## 1 概述
4+
5+
模型量化是一种常见的模型压缩方法,是使用整数替代浮点数进行存储和计算。
6+
7+
比如,模型量化将32bit浮点数转换成8bit整数,则模型存储空间可以减少4倍,同时整数运算替换浮点数运算,可以加快模型推理速度、降低计算内存。
48

59
PaddleSeg基于PaddleSlim,集成了量化训练(QAT)方法,特点如下:
6-
* 概述:使用大量训练数据,在训练过程中更新权重,减小量化损失。
7-
* 注意事项:训练数据需要有Ground Truth。
10+
* 概述:使用训练数据,在训练过程中更新权重,减小量化损失。
811
* 优点:量化模型的精度高;使用该量化模型预测,可以减少计算量、降低计算内存、减小模型大小。
912
* 缺点:易用性稍差,需要一定时间产出量化模型
1013

11-
下面,本文以一个示例来介绍如何产出和部署量化模型。
14+
## 2 量化模型精度和性能
15+
16+
测试环境:
17+
* GPU: V100 32G
18+
* CPU: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
19+
* CUDA: 10.1
20+
* cuDNN: 7.6
21+
* TensorRT: 6.0.1.5
22+
* Paddle: 2.1.1
1223

13-
## 1 环境准备
24+
测试方法:
25+
1. 在GPU上使用TensorRT测试原始模型和量化模型
26+
2. 使用cityspcaes的全量验证数据集(1024x2048)进行测试
27+
3. 单GPU,Batchsize为1
28+
4. 运行耗时为纯模型预测时间
29+
5. 使用Paddle Inference的[Python API](../../depolyment/inference/python_inference.md)测试,通过use_trt参数设置是否使用TRT,使用precision参数设置预测类型。
30+
31+
| 模型 | 类型 | mIoU | 耗时(s/img) | 量化加速比 |
32+
| - | :-: | :-: | :-: | :-: |
33+
| ANN_ResNet50_OS8 | FP32 | 0.7909 | 0.281 | - |
34+
| ANN_ResNet50_OS8 | INT8 | 0.7906 | 0.195 | 30.6% |
35+
| DANet_ResNet50_OS8 | FP32 | 0.8027 | 0.330 | - |
36+
| DANet_ResNet50_OS8 | INT8 | 0.8039 | 0.266 | 19.4% |
37+
| DeepLabV3P_ResNet50_OS8 | FP32 | 0.8036 | 0.206 | - |
38+
| DeepLabV3P_ResNet50_OS8 | INT8 | 0.8044 | 0.083 | 59.7% |
39+
| DNLNet_ResNet50_OS8 | FP32 | 0.7995 | 0.360 | - |
40+
| DNLNet_ResNet50_OS8 | INT8 | 0.7989 | 0.236 | 52.5% |
41+
| EMANet_ResNet50_OS8 | FP32 | 0.7905 | 0.186 | - |
42+
| EMANet_ResNet50_OS8 | INT8 | 0.7939 | 0.106 | 43.0% |
43+
| GCNet_ResNet50_OS8 | FP32 | 0.7950 | 0.228 | - |
44+
| GCNet_ResNet50_OS8 | INT8 | 0.7959 | 0.144 | 36.8% |
45+
| PSPNet_ResNet50_OS8 | FP32 | 0.7883 | 0.324 | - |
46+
| PSPNet_ResNet50_OS8 | INT8 | 0.7915 | 0.223 | 32.1% |
47+
48+
## 3 示例
49+
50+
我们以一个示例来介绍如何产出和部署量化模型。
51+
### 3.1 环境准备
1452

1553
请参考[安装文档](../../install.md)准备好PaddleSeg的基础环境。由于量化功能要求最新的PaddlePaddle版本,所以请参考[文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)安装develop(Nightly build)版本。
1654

@@ -26,15 +64,15 @@ git reset --hard 15ef0c7dcee5a622787b7445f21ad9d1dea0a933
2664
python setup.py install
2765
```
2866

29-
## 2 产出量化模型
67+
### 3.2 产出量化模型
3068

31-
### 2.1 训练FP32模型
69+
#### 3.2.1 训练FP32模型
3270

3371
在产出量化模型之前,我们需要提前准备训练或者fintune好的FP32模型。
3472

35-
此处,我们选用视盘分割数据集和BiseNetV2模型,从头开始训练模型
73+
此处,我们选用视盘分割数据集和BiseNetV2模型,使用train.py从头开始训练模型。train.py输入参数的介绍,请参考[文档](../../train/train.md)
3674

37-
在PaddleSeg目录下,执行如下脚本,会自动下载数据集进行训练。训练结束后,精度最高的权重会保存到`output_fp32/best_model`目录下。
75+
在PaddleSeg目录下,执行如下脚本,会自动下载数据集进行训练。
3876

3977
```shell
4078
# 设置1张可用的GPU卡
@@ -50,11 +88,33 @@ python train.py \
5088
--save_dir output_fp32
5189
```
5290

53-
### 2.2 使用量化训练方法产出量化模型
91+
训练结束后,精度最高的权重会保存到`output_fp32/best_model`目录下。
5492

55-
**训练量化模型**
93+
#### 3.2.2 使用量化训练方法产出量化模型
5694

57-
基于2.1步骤中训练好的FP32模型权重,执行如下命令,使用`slim/quant/qat_train.py`脚本进行量化训练。
95+
**1)产出量化模型**
96+
97+
基于训练好的FP32模型权重,使用`slim/quant/qat_train.py`进行量化训练。
98+
99+
qat_train.py和train.py的输入参数基本相似(如下)。注意,量化训练的学习率需要调小,使用`model_path`参数指定FP32模型的权重。
100+
101+
| 参数名 | 用途 | 是否必选项 | 默认值 |
102+
| ------------------- | ------------------------------------------------------------ | ---------- | ---------------- |
103+
| config | FP32模型的配置文件 || - | - |
104+
| model_path | FP32模型的预训练权重 || - |
105+
| iters | 训练迭代次数 || 配置文件中指定值 |
106+
| batch_size | 单卡batch size || 配置文件中指定值 |
107+
| learning_rate | 初始学习率 || 配置文件中指定值 |
108+
| save_dir | 模型和visualdl日志文件的保存根路径 || output |
109+
| num_workers | 用于异步读取数据的进程数量, 大于等于1时开启子进程读取数据 || 0 |
110+
| use_vdl | 是否开启visualdl记录训练数据 |||
111+
| save_interval_iters | 模型保存的间隔步数 || 1000 |
112+
| do_eval | 是否在保存模型时启动评估, 启动时将会根据mIoU保存最佳模型至best_model |||
113+
| log_iters | 打印日志的间隔步数 || 10 |
114+
| resume_model | 恢复训练模型路径,如:`output/iter_1000` | 否 | None
115+
116+
117+
执行如下命令,进行量化训练。量化训练结束后,精度最高的量化模型权重保存在`output_quant/best_model`目录下。
58118

59119
```shell
60120
python slim/quant/qat_train.py \
@@ -67,23 +127,29 @@ python slim/quant/qat_train.py \
67127
--save_dir output_quant
68128
```
69129

70-
上述脚本的输入参数和常规训练相似,复用2.1步骤的config文件,使用`model_path`参数指定FP32模型的权重,初始学习率相应调小。
71-
72-
训练结束后,精度最高的量化模型权重会保存到`output_quant/best_model`目录下。
73-
74-
**测试量化模型**
130+
**2)测试量化模型**
75131

76-
执行如下命令,使用`slim/quant/qat_val.py`脚本加载量化模型的权重,测试模型量化的精度。
132+
如果需要,可以执行如下命令,使用`slim/quant/qat_val.py`脚本加载量化模型的权重,测试模型量化的精度。
77133

78134
```
79135
python slim/quant/qat_val.py \
80136
--config configs/quick_start/bisenet_optic_disc_512x512_1k.yml \
81137
--model_path output_quant/best_model/model.pdparams
82138
```
83139

84-
**导出量化预测模型**
140+
**3)导出量化预测模型**
85141

86-
基于此前训练好的量化模型权重,执行如下命令,使用`slim/quant/qat_export.py`导出预测量化模型,保存在`output_quant_infer`目录下。
142+
基于训练好的量化模型权重,使用`slim/quant/qat_export.py`导出预测量化模型,脚本输入参数如下。
143+
144+
|参数名|用途|是否必选项|默认值|
145+
|-|-|-|-|
146+
|config|模型配置文件||-|
147+
|save_dir|预测量化模型保存的文件夹||output|
148+
|model_path|量化模型的权重||配置文件中指定值|
149+
|with_softmax|在网络末端添加softmax算子。由于PaddleSeg组网默认返回logits,如果想要部署模型获取概率值,可以置为True||False|
150+
|without_argmax|是否不在网络末端添加argmax算子。由于PaddleSeg组网默认返回logits,为部署模型可以直接获取预测结果,我们默认在网络末端添加argmax算子||False|
151+
152+
执行如下命令,导出预测量化模型保存在`output_quant_infer`目录。
87153

88154
```
89155
python slim/quant/qat_export.py \
@@ -92,35 +158,15 @@ python slim/quant/qat_export.py \
92158
--save_dir output_quant_infer
93159
```
94160

95-
## 3 部署
161+
### 3.3 部署量化模型
162+
163+
得到量化预测模型后,我们可以进行部署应用,请参考如下教程。
96164

97-
得到量化预测模型后,我们可以进行部署应用。
98165
* [Paddle Inference Python部署](../../deployment/inference/python_inference.md)
99166
* [Paddle Inference C++部署](../../deployment/inference/cpp_inference.md)
100-
* [PaddleLite部署](../../deployment/lite/lite.md)
101-
102-
## 4 量化加速比
103-
104-
测试环境:
105-
* GPU: V100
106-
* CPU: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
107-
* CUDA: 10.2
108-
* cuDNN: 7.6
109-
* TensorRT: 6.0.1.5
110-
111-
测试方法:
112-
1. 运行耗时为纯模型预测时间,测试图片cityspcaes(1024x2048)
113-
2. 预测10次作为热启动,连续预测50次取平均得到预测时间
114-
3. 使用GPU + TensorRT测试
115-
116-
|模型|未量化运行耗时(ms)|量化运行耗时(ms)|加速比|
117-
|-|-|-|-|
118-
|deeplabv3_resnet50_os8|204.2|150.1|26.49%|
119-
|deeplabv3p_resnet50_os8|147.2|89.5|39.20%|
120-
|gcnet_resnet50_os8|201.8|126.1|37.51%|
121-
|pspnet_resnet50_os8|266.8|206.8|22.49%|
167+
* [PaddleLite部署](../../deployment/lite/lite.md)
122168

123-
## 5 参考资料
169+
### 3.4 参考资料
124170

125171
* [PaddleSlim Github](https://github.com/PaddlePaddle/PaddleSlim)
126172
* [PaddleSlim 文档](https://paddleslim.readthedocs.io/zh_CN/latest/)

slim/quant/qat_val.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525
from paddleseg.core import evaluate
2626
from paddleseg.utils import get_sys_env, logger, config_check, utils
2727
from qat_config import quant_config
28+
from qat_train import skip_quant
2829

2930
from paddleslim import QAT
3031

32+
3133
def get_test_config(cfg, args):
3234

3335
test_config = cfg.test_config
@@ -163,6 +165,7 @@ def main(args):
163165

164166
model = cfg.model
165167

168+
skip_quant(model)
166169
quantizer = QAT(config=quant_config)
167170
quant_model = quantizer.quantize(model)
168171
logger.info('Quantize the model successfully')
@@ -174,7 +177,8 @@ def main(args):
174177
test_config = get_test_config(cfg, args)
175178
config_check(cfg, val_dataset=val_dataset)
176179

177-
evaluate(quant_model, val_dataset, num_workers=args.num_workers, **test_config)
180+
evaluate(
181+
quant_model, val_dataset, num_workers=args.num_workers, **test_config)
178182

179183

180184
if __name__ == '__main__':

0 commit comments

Comments
 (0)