
Commit e106f05

Deploy (#528)

* add serving tensorrt
* modify msvsr infer since the second output is the result
* update wav2lip model path
1 parent e19a889 commit e106f05

File tree

8 files changed: +357 −7 lines changed


deploy/TENSOR_RT.md

Lines changed: 61 additions & 0 deletions
# TensorRT Inference Deployment Tutorial

TensorRT is NVIDIA's acceleration library for unified model deployment. It runs on hardware such as the V100 and Jetson Xavier and can greatly speed up inference. For the Paddle TensorRT tutorial, see [Inference with the Paddle-TensorRT library](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html#).

## 1. Install the Paddle Inference library

- Python package: download and install a TensorRT-enabled wheel from [here](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-release).

- C++ inference library: download a build compiled with TensorRT from [here](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/05_inference_deployment/inference/build_and_install_lib_cn.html).

- If no prebuilt Python package or C++ library is available, compile one yourself following [building from source](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/linux-compile.html).

**Note:**
- The TensorRT version on your machine must match the TensorRT version used to build the inference library.
- Inference deployment in PaddleGAN requires TensorRT > 7.0.

## 2. Export the model

For details on exporting models, see the [PaddleGAN model export tutorial](../EXPORT_MODEL.md).

## 3. Enable TensorRT acceleration

### 3.1 Configure TensorRT

When building the predictor `config` with the Paddle inference library, simply enable the TensorRT engine:
```cpp
config->EnableUseGpu(100, 0); // initialize 100 MB of GPU memory, use GPU 0
config->GpuDeviceId();        // returns the GPU ID currently in use
// Enable TensorRT inference. This improves GPU inference performance and
// requires an inference library built with TensorRT.
config->EnableTensorRtEngine(1 << 20 /*workspace_size*/,
                             batch_size /*max_batch_size*/,
                             3 /*min_subgraph_size*/,
                             AnalysisConfig::Precision::kFloat32 /*precision*/,
                             false /*use_static*/,
                             false /*use_calib_mode*/);
```
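If you configure inference from Python instead of C++, a roughly equivalent setup looks like the sketch below. This is a minimal illustration only: the model and params paths are placeholders, not files shipped with this commit.

```python
# Minimal sketch of the equivalent Python-side configuration.
from paddle.inference import Config, PrecisionType, create_predictor

config = Config("model.pdmodel", "model.pdiparams")  # placeholder paths
config.enable_use_gpu(100, 0)  # initialize 100 MB of GPU memory, use GPU 0
config.enable_tensorrt_engine(
    workspace_size=1 << 20,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False)
predictor = create_predictor(config)
```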
### 3.2 TensorRT fixed-size inference

Taking `msvsr` as an example, run inference with a fixed input size:

```bash
python tools/inference.py --model_path=/root/to/model --config-file /root/to/config --run_mode trt_fp32 --min_subgraph_size 20 --model_type msvsr
```
## 4. FAQ

**Q:** I get an error saying there is no `tensorrt_op`.<br>
**A:** Check that you are using a TensorRT-enabled Paddle Python package or inference library.

**Q:** I get an `op out of memory` error.<br>
**A:** Check whether the GPU is also being used by someone else; try a free GPU.

**Q:** I get a `some trt inputs dynamic shape info not set` error.<br>
**A:** TensorRT splits the network into multiple subgraphs; we only set the dynamic shape of the input data, not of the other subgraphs' inputs. There are two fixes:

- Option 1: increase `min_subgraph_size` so those subgraphs are skipped during optimization. Following the error message, set `min_subgraph_size` larger than the number of OPs in the subgraphs whose inputs have no dynamic shape set. `min_subgraph_size` means that, when the TensorRT engine is loaded, only runs of more than `min_subgraph_size` consecutive OPs that TensorRT can optimize are optimized.

- Option 2: find those subgraph inputs and set their dynamic shapes in the same way; see the sketch after this list.
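Through the Python API, registering dynamic shape info on the `config` from section 3.1 looks roughly like this. The tensor name `lqs` and the shapes are illustrative values borrowed from the msvsr example, not a fixed recipe.

```python
# Illustrative only: register min/max/opt shapes per dynamic input tensor.
min_input_shape = {"lqs": [1, 2, 3, 180, 320]}
max_input_shape = {"lqs": [1, 2, 3, 720, 1280]}
opt_input_shape = {"lqs": [1, 2, 3, 360, 640]}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                  opt_input_shape)
```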
**Q:** How do I turn on logging?<br>
**A:** The inference library logs by default; if `config.disable_glog_info()` is present, comment it out to re-enable logging.

**Q:** With TensorRT enabled, inference reports `Slice on batch axis is not supported in TensorRT`.<br>
**A:** Try using dynamic-size input.
deploy/serving/README.md

Lines changed: 101 additions & 0 deletions
# Server-Side Inference Deployment

Models trained with `PaddleGAN` can be deployed on a server using [Serving](https://github.com/PaddlePaddle/Serving).
This tutorial deploys a model trained on the REDS dataset with the `configs/msvsr_reds.yaml` algorithm.
The pretrained weights file is [PP-MSVSR_reds_x4.pdparams](https://paddlegan.bj.bcebos.com/models/PP-MSVSR_reds_x4.pdparams).

## 1. Install Paddle Serving

Install by following the [PaddleServing](https://github.com/PaddlePaddle/Serving/tree/v0.6.0) installation guide (version >= 0.6.0).

## 2. Export the model

A PaddleGAN training checkpoint contains the network's forward parameters as well as optimizer state; for deployment we only need the forward parameters. For details, see [exporting a model](https://github.com/PaddlePaddle/PaddleGAN/blob/develop/deploy/EXPORT_MODEL.md).

```bash
python tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --load /path/to/model --export_serving_model True \
    --output_dir /path/to/output
```
The command above generates a `multistagevsrmodel_generator` folder under `/path/to/output`:

```
output
├── multistagevsrmodel_generator
│   ├── multistagevsrmodel_generator.pdiparams
│   ├── multistagevsrmodel_generator.pdiparams.info
│   ├── multistagevsrmodel_generator.pdmodel
│   ├── serving_client
│   │   ├── serving_client_conf.prototxt
│   │   └── serving_client_conf.stream.prototxt
│   └── serving_server
│       ├── __model__
│       ├── __params__
│       ├── serving_server_conf.prototxt
│       ├── serving_server_conf.stream.prototxt
│       └── ...
```
`serving_client_conf.prototxt` in the `serving_client` folder describes the model's inputs and outputs in detail.
The content of `serving_client_conf.prototxt` is:

```
feed_var {
  name: "lqs"
  alias_name: "lqs"
  is_lod_tensor: false
  feed_type: 1
  shape: 1
  shape: 2
  shape: 3
  shape: 180
  shape: 320
}
fetch_var {
  name: "stack_18.tmp_0"
  alias_name: "stack_18.tmp_0"
  is_lod_tensor: false
  fetch_type: 1
  shape: 1
  shape: 2
  shape: 3
  shape: 720
  shape: 1280
}
fetch_var {
  name: "stack_19.tmp_0"
  alias_name: "stack_19.tmp_0"
  is_lod_tensor: false
  fetch_type: 1
  shape: 1
  shape: 3
  shape: 720
  shape: 1280
}
```
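The `feed_var`/`fetch_var` names above are exactly what a client passes to `predict`. As a minimal sanity check, a sketch assuming a server is already listening on 127.0.0.1:9393 (the zero tensor is a dummy input matching the declared 2x3x180x320 per-sample shape):

```python
# Sanity-check sketch: names and shapes come from serving_client_conf.prototxt.
import numpy as np
from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9393"])

dummy = np.zeros((2, 3, 180, 320), dtype=np.float32)  # dummy input
fetch_map = client.predict(
    feed={"lqs": dummy}, fetch=["stack_19.tmp_0"], batch=False)
print(fetch_map["stack_19.tmp_0"].shape)
```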
## 3. Start the PaddleServing service

```bash
cd output_dir/multistagevsrmodel_generator/

# GPU
python -m paddle_serving_server.serve --model serving_server --port 9393 --gpu_ids 0

# CPU
python -m paddle_serving_server.serve --model serving_server --port 9393
```
## 4. Test the deployed service

```bash
# Enter the exported model folder
cd output/multistagevsrmodel_generator/
```

Set the `prototxt` file path to `serving_client/serving_client_conf.prototxt`.
Set `fetch` to `fetch=["stack_19.tmp_0"]`.

Run the test:

```bash
# The test script test_client.py automatically creates an output folder
# and writes an output.mp4 file under it
python ../../deploy/serving/test_client.py input_video frame_num
```

deploy/serving/test_client.py

Lines changed: 78 additions & 0 deletions
```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

import cv2
import imageio
import numpy as np
from paddle_serving_app.reader import *
from paddle_serving_client import Client


def get_img(pred):
    """Convert a CHW float prediction in [0, 1] to an HWC uint8 image."""
    pred = pred.squeeze()
    pred = np.clip(pred, a_min=0., a_max=1.0)
    pred = pred * 255
    pred = pred.round()
    pred = pred.astype('uint8')
    pred = np.transpose(pred, (1, 2, 0))  # chw -> hwc
    return pred


# BGR frame -> RGB, resize to 320x180, scale to [0, 1], HWC -> CHW.
preprocess = Sequential(
    [BGR2RGB(), Resize((320, 180)), Div(255.0), Transpose((2, 0, 1))])

client = Client()
client.load_client_config("serving_client/serving_client_conf.prototxt")
client.connect(['127.0.0.1:9393'])

# Number of frames per request; must match the temporal dimension the
# model was exported with (2 for --inputs_size="1,2,3,180,320").
frame_num = int(sys.argv[2])

cap = cv2.VideoCapture(sys.argv[1])
fps = cap.get(cv2.CAP_PROP_FPS)
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
success, frame = cap.read()
read_end = False
res_frames = []
output_dir = "./output"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

while success:
    # Collect one batch of frame_num preprocessed frames.
    frames = []
    for i in range(frame_num):
        if success:
            frames.append(preprocess(frame))
            success, frame = cap.read()
        else:
            read_end = True
    # Drop a trailing partial batch: the model expects exactly frame_num frames.
    if read_end: break

    frames = np.stack(frames, axis=0)
    fetch_map = client.predict(
        feed={
            "lqs": frames,
        },
        fetch=["stack_19.tmp_0"],
        batch=False)
    # The second model output (stack_19.tmp_0) holds the restored frames.
    res_frames.extend(
        [fetch_map["stack_19.tmp_0"][0][i] for i in range(frame_num)])

imageio.mimsave("output/output.mp4",
                [get_img(frame) for frame in res_frames],
                fps=fps)
```
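For example, with the service from the README above listening on 127.0.0.1:9393, an invocation such as `python ../../deploy/serving/test_client.py input.mp4 2` should produce `output/output.mp4`; here `input.mp4` is a placeholder video and `2` matches the temporal dimension of the `--inputs_size="1,2,3,180,320"` used at export time.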

docs/zh_CN/tutorials/wav2lip.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -74,7 +74,7 @@ python -m paddle.distributed.launch \
 ### 2.3 模型
 Model|Dataset|BatchSize|Inference speed|Download
 ---|:--:|:--:|:--:|:--:
-wa2lip_hq|LRS2| 1 | 0.2853s/image (GPU:P40) | [model](https://paddlegan.bj.bcebos.com/models/psgan_weight.pdparam://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams)
+wa2lip_hq|LRS2| 1 | 0.2853s/image (GPU:P40) | [model](https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams)

 ## 3. 结果展示
```

ppgan/models/base_model.py

Lines changed: 14 additions & 1 deletion
```diff
@@ -183,7 +183,7 @@ def set_requires_grad(self, nets, requires_grad=False):
                 for param in net.parameters():
                     param.trainable = requires_grad

-    def export_model(self, export_model, output_dir=None, inputs_size=[]):
+    def export_model(self, export_model, output_dir=None, inputs_size=[], export_serving_model=False):
         inputs_num = 0
         for net in export_model:
             input_spec = [
@@ -201,3 +201,16 @@ def export_model(self, export_model, output_dir=None, inputs_size=[]):
                 os.path.join(
                     output_dir, '{}_{}'.format(self.__class__.__name__.lower(),
                                                net["name"])))
+            if export_serving_model:
+                from paddle_serving_client.io import inference_model_to_serving
+                model_name = '{}_{}'.format(self.__class__.__name__.lower(),
+                                            net["name"])
+
+                inference_model_to_serving(
+                    dirname=output_dir,
+                    serving_server="{}/{}/serving_server".format(output_dir,
+                                                                 model_name),
+                    serving_client="{}/{}/serving_client".format(output_dir,
+                                                                 model_name),
+                    model_filename="{}.pdmodel".format(model_name),
+                    params_filename="{}.pdiparams".format(model_name))
```
Lines changed: 2 additions & 2 deletions
```diff
@@ -1,2 +1,2 @@
-Metric psnr: 24.3250
-Metric ssim: 0.6497
+Metric psnr: 27.2885
+Metric ssim: 0.7969
```

tools/export_model.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -51,6 +51,12 @@ def parse_args():
         type=str,
         help="The path prefix of inference model to be used.",
     )
+    parser.add_argument(
+        "--export_serving_model",
+        default=False,
+        type=bool,
+        help="export serving model.",
+    )
     args = parser.parse_args()
     return args
```

```diff
@@ -64,7 +70,7 @@ def main(args, cfg):
     for net_name, net in model.nets.items():
         if net_name in state_dicts:
             net.set_state_dict(state_dicts[net_name])
-    model.export_model(cfg.export_model, args.output_dir, inputs_size)
+    model.export_model(cfg.export_model, args.output_dir, inputs_size, args.export_serving_model)


 if __name__ == "__main__":
```
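A note on the flag added above: argparse's `type=bool` converts any non-empty string to `True`, so `--export_serving_model False` would still enable export. A sketch of a stricter alternative, assuming one wanted a switch that cannot be enabled by accident (this is not what the commit uses, and it would change the CLI to a bare `--export_serving_model` with no value):

```python
# Sketch of a stricter boolean switch (assumption: not part of this commit).
parser.add_argument(
    "--export_serving_model",
    action="store_true",
    help="Also export a Paddle Serving model.",
)
```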
