
Commit 6312429

Merge pull request #225 from MRXLT/general-server-bert-v1
updated readme && benchmark
2 parents dfb78ed + 2af01cf commit 6312429

File tree

10 files changed: +113 -78 lines changed

doc/bert-benchmark-batch-size-1.png (24.2 KB)

doc/imdb-benchmark-server-16.png (24.3 KB)

python/examples/bert/README.md

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
## Semantic Understanding Prediction Service

This example uses a BERT model for semantic understanding prediction: input text is represented as vectors, which can be used for further analysis and prediction.

### Getting the Model

The example uses the [Chinese BERT model](https://www.paddlepaddle.org.cn/hubdetail?name=bert_chinese_L-12_H-768_A-12&en_category=SemanticModel) from [PaddleHub](https://github.com/PaddlePaddle/PaddleHub).

Run
```
python prepare_model.py
```
to generate the server-side configuration and model files (saved in the serving_server_model folder) and the client-side configuration file (saved in the serving_client_conf folder).
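For reference, a script along these lines could produce those two folders. This is a hedged sketch, not the shipped prepare_model.py: the PaddleHub calls mirror the ones used in bert_client.py, while the paddle_serving_client.io.save_model call and the exact feed/fetch variable names are assumptions.

```
# Hypothetical sketch of the model-export step (not the shipped prepare_model.py).
import paddlehub as hub
import paddle_serving_client.io as serving_io  # save_model API assumed

# Pull the Chinese BERT module and get its inference program, as bert_client.py does.
module = hub.Module(name="bert_chinese_L-12_H-768_A-12")
inputs, outputs, program = module.context(trainable=False, max_seq_len=20)

# Map serving feed/fetch names to program variables (names assumed).
feed_dict = {
    "input_ids": inputs["input_ids"],
    "position_ids": inputs["position_ids"],
    "segment_ids": inputs["segment_ids"],
    "input_mask": inputs["input_mask"],
}
fetch_dict = {"pooled_output": outputs["pooled_output"]}

# Write the server-side model/config and the client-side config.
serving_io.save_model("serving_server_model", "serving_client_conf",
                      feed_dict, fetch_dict, program)
```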

### Starting the Prediction Service

Run
```
python bert_server.py serving_server_model 9292  # start the CPU prediction service
```
or
```
python bert_gpu_server.py serving_server_model 9292 0  # start the GPU prediction service on GPU 0
```
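For orientation, a minimal CPU server entry point could look like the sketch below. It reuses only the Server calls visible in the bert_gpu_server.py diff further down (load_model_config, prepare_server, run_server); the import path and the op-sequence setup that the real bert_server.py performs are assumptions and are omitted here.

```
# Minimal sketch of a CPU serving entry point; not the full bert_server.py.
import sys
from paddle_serving_server import Server  # import path assumed

server = Server()
# The shipped script also configures the op sequence before this point (omitted).
server.load_model_config(sys.argv[1])   # e.g. serving_server_model
port = int(sys.argv[2])                 # e.g. 9292
server.prepare_server(workdir="work_dir1", port=port, device="cpu")
server.run_server()
```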

### Running Prediction

Run
```
sh get_data.sh
```
to download the Chinese sample data.

Then run
```
head data-c.txt | python bert_client.py
```
to predict on the first ten samples of the data and print their vector representations to standard output.
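To make the pipeline concrete, a client along these lines would read text from standard input and query the service. It only uses names visible in bert_client.py below (BertService, load_client, run_general, the pooled_output fetch); how each input line must be wrapped for the reader is an assumption.

```
# Hedged sketch of a stdin-driven client, modeled on bert_client.py's test().
import sys
from bert_client import BertService

bc = BertService(
    model_name="bert_chinese_L-12_H-768_A-12",
    max_seq_len=20,
    show_ids=False,
    do_lower_case=True)
bc.load_client("./serving_client_conf/serving_client_conf.prototxt",
               ["127.0.0.1:9292"])

fetch = ["pooled_output"]
for line in sys.stdin:
    # Wrapping each line as [[text]] is assumed to match the ClassifyReader convention.
    result = bc.run_general([[line.strip()]], fetch)
    print(result["pooled_output"])
```
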
### Benchmark

Model: bert_chinese_L-12_H-768_A-12

Device: 1 x GPU V100

Environment: CUDA 9.2, cuDNN 7.1.4

For the test, the 10,000 samples in the sample data are replicated to 100,000 samples; each client thread sends 1/(number of threads) of the samples; batch size is 1, max_seq_len is 20, and all times are in seconds.

With 4 client threads the prediction speed reaches 432 samples per second.
Because a single GPU can only compute serially, adding client threads merely reduces GPU idle time, so beyond 4 threads more threads bring no further speedup.

| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
| ------------------ | ------ | ------------ | ----- | ------ | ---- | ------- | ------ |
| 1 | 3.05 | 290.54 | 0.37 | 239.15 | 6.43 | 0.71 | 365.63 |
| 4 | 0.85 | 213.66 | 0.091 | 200.39 | 1.62 | 0.2 | 231.45 |
| 8 | 0.42 | 223.12 | 0.043 | 110.99 | 0.8 | 0.098 | 232.05 |
| 12 | 0.32 | 225.26 | 0.029 | 73.87 | 0.53 | 0.078 | 231.45 |
| 16 | 0.23 | 227.26 | 0.022 | 55.61 | 0.4 | 0.056 | 231.9 |

The total time cost varies as follows:

![bert benchmark](../../../doc/bert-benchmark-batch-size-1.png)

python/examples/bert/benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from paddle_serving_client.metric import auc
 from paddle_serving_client.utils import MultiThreadRunner
 import time
-from test_bert_client import BertService
+from bert_client import BertService
 
 
 def predict(thr_id, resource):
@@ -55,7 +55,7 @@ def predict(thr_id, resource):
 thread_num = sys.argv[3]
 resource = {}
 resource["conf_file"] = conf_file
-resource["server_endpoint"] = ["127.0.0.1:9293"]
+resource["server_endpoint"] = ["127.0.0.1:9292"]
 resource["filelist"] = [data_file]
 resource["thread_num"] = int(thread_num)
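For orientation, the resource dictionary above is what each benchmark thread receives. The hedged sketch below shows how MultiThreadRunner presumably drives the predict function; the run() signature and the thread body are assumptions, not a copy of benchmark.py.

```
# Hedged sketch of the benchmark driver; the run() signature is an assumption.
from paddle_serving_client.utils import MultiThreadRunner

def predict(thr_id, resource):
    # Each thread connects to resource["server_endpoint"], sends its share of
    # resource["filelist"], and returns its per-stage timings (details omitted).
    return []

resource = {
    "conf_file": "serving_client_conf/serving_client_conf.prototxt",
    "server_endpoint": ["127.0.0.1:9292"],
    "filelist": ["data-c.txt"],
    "thread_num": 4,
}
result = MultiThreadRunner().run(predict, resource["thread_num"], resource)
```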

python/examples/bert/benchmark_batch.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from paddle_serving_client.metric import auc
 from paddle_serving_client.utils import MultiThreadRunner
 import time
-from test_bert_client import BertService
+from bert_client import BertService
 
 
 def predict(thr_id, resource, batch_size):

python/examples/bert/test_bert_client.py renamed to python/examples/bert/bert_client.py

Lines changed: 20 additions & 15 deletions
@@ -1,9 +1,11 @@
 # coding:utf-8
+import os
 import sys
 import numpy as np
 import paddlehub as hub
 import ujson
 import random
+import time
 from paddlehub.common.logger import logger
 import socket
 from paddle_serving_client import Client
@@ -20,29 +22,22 @@
 
 class BertService():
 def __init__(self,
-profile=False,
 max_seq_len=128,
 model_name="bert_uncased_L-12_H-768_A-12",
 show_ids=False,
 do_lower_case=True,
 process_id=0,
-retry=3,
-load_balance='round_robin'):
+retry=3):
 self.process_id = process_id
 self.reader_flag = False
 self.batch_size = 0
 self.max_seq_len = max_seq_len
-self.profile = profile
 self.model_name = model_name
 self.show_ids = show_ids
 self.do_lower_case = do_lower_case
-self.con_list = []
-self.con_index = 0
-self.load_balance = load_balance
-self.server_list = []
-self.serving_list = []
-self.feed_var_names = ''
 self.retry = retry
+self.profile = True if ("FLAGS_profile_client" in os.environ and
+os.environ["FLAGS_profile_client"]) else False
 
 module = hub.Module(name=self.model_name)
 inputs, outputs, program = module.context(
@@ -51,7 +46,6 @@ def __init__(self,
 position_ids = inputs["position_ids"]
 segment_ids = inputs["segment_ids"]
 input_mask = inputs["input_mask"]
-self.feed_var_names = input_ids.name + ';' + position_ids.name + ';' + segment_ids.name + ';' + input_mask.name
 self.reader = hub.reader.ClassifyReader(
 vocab_path=module.get_vocab_path(),
 dataset=None,
@@ -69,6 +63,7 @@ def run_general(self, text, fetch):
 data_generator = self.reader.data_generator(
 batch_size=self.batch_size, phase='predict', data=text)
 result = []
+prepro_start = time.time()
 for run_step, batch in enumerate(data_generator(), start=1):
 token_list = batch[0][0].reshape(-1).tolist()
 pos_list = batch[0][1].reshape(-1).tolist()
@@ -81,6 +76,11 @@ def run_general(self, text, fetch):
 "segment_ids": sent_list,
 "input_mask": mask_list
 }
+prepro_end = time.time()
+if self.profile:
+print("PROFILE\tbert_pre_0:{} bert_pre_1:{}".format(
+int(round(prepro_start * 1000000)),
+int(round(prepro_end * 1000000))))
 fetch_map = self.client.predict(feed=feed, fetch=fetch)
 
 return fetch_map
@@ -90,6 +90,7 @@ def run_batch_general(self, text, fetch):
 data_generator = self.reader.data_generator(
 batch_size=self.batch_size, phase='predict', data=text)
 result = []
+prepro_start = time.time()
 for run_step, batch in enumerate(data_generator(), start=1):
 token_list = batch[0][0].reshape(-1).tolist()
 pos_list = batch[0][1].reshape(-1).tolist()
@@ -108,6 +109,11 @@ def run_batch_general(self, text, fetch):
 mask_list[si * self.max_seq_len:(si + 1) * self.max_seq_len]
 }
 feed_batch.append(feed)
+prepro_end = time.time()
+if self.profile:
+print("PROFILE\tbert_pre_0:{} bert_pre_1:{}".format(
+int(round(prepro_start * 1000000)),
+int(round(prepro_end * 1000000))))
 fetch_map_batch = self.client.batch_predict(
 feed_batch=feed_batch, fetch=fetch)
 return fetch_map_batch
@@ -116,11 +122,11 @@ def run_batch_general(self, text, fetch):
 def test():
 
 bc = BertService(
-model_name='bert_uncased_L-12_H-768_A-12',
+model_name='bert_chinese_L-12_H-768_A-12',
 max_seq_len=20,
 show_ids=False,
 do_lower_case=True)
-server_addr = ["127.0.0.1:9293"]
+server_addr = ["127.0.0.1:9292"]
 config_file = './serving_client_conf/serving_client_conf.prototxt'
 fetch = ["pooled_output"]
 bc.load_client(config_file, server_addr)
@@ -133,8 +139,7 @@ def test():
 result = bc.run_batch_general(batch, fetch)
 batch = []
 for r in result:
-for e in r["pooled_output"]:
-print(e)
+print(r)
 
 
 if __name__ == '__main__':
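The new self.profile flag above is driven purely by an environment variable, so client-side timing can be toggled without code changes. A small usage sketch (the values shown are illustrative):

```
# Any non-empty value of FLAGS_profile_client enables the PROFILE log lines;
# it must be set before BertService is constructed.
import os
os.environ["FLAGS_profile_client"] = "1"

from bert_client import BertService

bc = BertService(model_name="bert_chinese_L-12_H-768_A-12", max_seq_len=20)
# ... load_client / run_general as usual; PROFILE lines go to stdout and can be
# redirected to a file for the timeline tools described in python/examples/util.
```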

python/examples/bert/test_gpu_server.py renamed to python/examples/bert/bert_gpu_server.py

Lines changed: 2 additions & 0 deletions
@@ -36,5 +36,7 @@
 
 server.load_model_config(sys.argv[1])
 port = int(sys.argv[2])
+gpuid = sys.argv[3]
+server.set_gpuid(gpuid)
 server.prepare_server(workdir="work_dir1", port=port, device="gpu")
 server.run_server()

python/examples/imdb/README.md

Lines changed: 4 additions & 60 deletions
@@ -21,46 +21,10 @@ cat test.data | python test_client_batch.py inference.conf 4 > result
 
 Model: IMDB-CNN
 
-In the test, the client sends 2,500 test samples in total; the numbers shown are the per-thread time cost, in seconds
-
-server thread num :4
-
-| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
-| ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
-| 1 | 0.99 | 27.39 | 0.085 | 19.92 | 0.046 | 0.032 | 29.84 |
-| 4 | 0.22 | 7.66 | 0.021 | 4.93 | 0.011 | 0.0082 | 8.28 |
-| 8 | 0.1 | 6.66 | 0.01 | 2.42 | 0.0038 | 0.0046 | 6.95 |
-| 12 | 0.074 | 6.87 | 0.0069 | 1.61 | 0.0059 | 0.0032 | 7.07 |
-| 16 | 0.056 | 7.01 | 0.0053 | 1.23 | 0.0029 | 0.0026 | 7.17 |
-| 20 | 0.045 | 7.02 | 0.0042 | 0.97 | 0.0023 | 0.002 | 7.15 |
-| 24 | 0.039 | 7.012 | 0.0034 | 0.8 | 0.0019 | 0.0016 | 7.12 |
-
-server thread num : 8
-
-| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
-| ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
-| 1 | 1.02 | 28.9 | 0.096 | 20.64 | 0.047 | 0.036 | 31.51 |
-| 4 | 0.22 | 7.83 | 0.021 | 5.08 | 0.012 | 0.01 | 8.45 |
-| 8 | 0.11 | 4.44 | 0.01 | 2.5 | 0.0059 | 0.0051 | 4.73 |
-| 12 | 0.074 | 4.11 | 0.0069 | 1.65 | 0.0039 | 0.0029 | 4.31 |
-| 16 | 0.057 | 4.2 | 0.0052 | 1.24 | 0.0029 | 0.0024 | 4.35 |
-| 20 | 0.046 | 4.05 | 0.0043 | 1.01 | 0.0024 | 0.0021 | 4.18 |
-| 24 | 0.038 | 4.02 | 0.0034 | 0.81 | 0.0019 | 0.0015 | 4.13 |
-
-server thread num : 12
-
-| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
-| ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
-| 1 | 1.02 | 29.47 | 0.098 | 20.95 | 0.048 | 0.038 | 31.96 |
-| 4 | 0.21 | 7.36 | 0.022 | 5.01 | 0.011 | 0.0081 | 7.95 |
-| 8 | 0.11 | 4.52 | 0.011 | 2.58 | 0.0061 | 0.0051 | 4.83 |
-| 12 | 0.072 | 3.25 | 0.0076 | 1.72 | 0.0042 | 0.0038 | 3.45 |
-| 16 | 0.059 | 3.93 | 0.0055 | 1.26 | 0.0029 | 0.0023 | 4.1 |
-| 20 | 0.047 | 3.79 | 0.0044 | 1.01 | 0.0024 | 0.0021 | 3.92 |
-| 24 | 0.041 | 3.76 | 0.0036 | 0.83 | 0.0019 | 0.0017 | 3.87 |
-
 server thread num : 16
 
+In the test, the client sends 25,000 test samples in total; the numbers shown are the per-thread time cost, in seconds. Multi-threaded clients are clearly faster than a single thread: with 16 client threads the prediction speed is 8.7x that of a single thread.
+
 | client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
 | ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
 | 1 | 1.09 | 28.79 | 0.094 | 20.59 | 0.047 | 0.034 | 31.41 |
@@ -71,26 +35,6 @@ server thread num : 16
 | 20 | 0.049 | 3.77 | 0.0047 | 1.03 | 0.0025 | 0.0022 | 3.91 |
 | 24 | 0.041 | 3.86 | 0.0039 | 0.85 | 0.002 | 0.0017 | 3.98 |
 
-server thread num : 20
-
-| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
-| ------------------ | ------ | ------------ | ------ | ----- | ------ | ------- | ----- |
-| 1 | 1.03 | 28.42 | 0.085 | 20.47 | 0.046 | 0.037 | 30.98 |
-| 4 | 0.22 | 7.94 | 0.022 | 5.33 | 0.012 | 0.011 | 8.53 |
-| 8 | 0.11 | 4.54 | 0.01 | 2.58 | 0.006 | 0.0046 | 4.84 |
-| 12 | 0.079 | 4.54 | 0.0076 | 1.78 | 0.0042 | 0.0039 | 4.76 |
-| 16 | 0.059 | 3.41 | 0.0057 | 1.33 | 0.0032 | 0.0027 | 3.58 |
-| 20 | 0.051 | 4.33 | 0.0047 | 1.06 | 0.0025 | 0.0023 | 4.48 |
-| 24 | 0.043 | 4.51 | 0.004 | 0.88 | 0.0021 | 0.0018 | 4.63 |
-
-server thread num :24
+The total prediction time cost varies as follows:
 
-| client thread num | prepro | client infer | op0 | op1 | op2 | postpro | total |
-| ------------------ | ------ | ------------ | ------ | ---- | ------ | ------- | ----- |
-| 1 | 0.93 | 29.28 | 0.099 | 20.5 | 0.048 | 0.028 | 31.61 |
-| 4 | 0.22 | 7.72 | 0.023 | 4.98 | 0.011 | 0.0095 | 8.33 |
-| 8 | 0.11 | 4.77 | 0.012 | 2.65 | 0.0062 | 0.0049 | 5.09 |
-| 12 | 0.081 | 4.22 | 0.0078 | 1.77 | 0.0042 | 0.0033 | 4.44 |
-| 16 | 0.062 | 4.21 | 0.0061 | 1.34 | 0.0032 | 0.0026 | 4.39 |
-| 20 | 0.5 | 3.58 | 0.005 | 1.07 | 0.0026 | 0.0023 | 3.72 |
-| 24 | 0.043 | 4.27 | 0.0042 | 0.89 | 0.0022 | 0.0018 | 4.4 |
+![total cost](../../../doc/imdb-benchmark-server-16.png)

python/examples/util/README.md

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
## Using the Timeline Tool

The serving framework has built-in timestamping for each stage of the prediction service; environment variables control whether it is enabled.
```
export FLAGS_profile_client=1 # enable per-stage timestamping on the client side
export FLAGS_profile_server=1 # enable per-stage timestamping on the server side
```
With this enabled, the client prints the corresponding log lines to standard output during prediction.

To show the per-stage time cost more intuitively, scripts are provided to further process the log file.

First save the client output to a file, named profile in this example.
```
python show_profile.py profile ${thread_num}
```
This script sums the time cost of each stage, averages it over the number of threads, and prints the result to standard output.

```
python timeline_trace.py profile trace
```
This script converts the timestamp information in the log into JSON and saves it to the trace file, which can be visualized with the Chrome browser's tracing feature.

To do so, open Chrome, enter chrome://tracing/ in the address bar, click the load button on the tracing page, and open the saved trace file to visualize the timing of every stage of the prediction service.
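The PROFILE lines emitted by bert_client.py have the form `PROFILE\tbert_pre_0:<start_us> bert_pre_1:<end_us>`. As a rough illustration of the idea behind timeline_trace.py (this sketch is not the shipped script), such lines can be converted into Chrome trace events like this:

```
# Hedged sketch: turn PROFILE log lines into a Chrome tracing JSON file.
import json
import sys

def profile_to_trace(profile_path, trace_path, pid=0, tid=0):
    events = []
    with open(profile_path) as f:
        for line in f:
            if not line.startswith("PROFILE"):
                continue
            # e.g. "PROFILE\tbert_pre_0:1580000000000 bert_pre_1:1580000000123"
            for field in line.strip().split("\t")[1].split(" "):
                name, ts = field.rsplit(":", 1)
                stage, idx = name.rsplit("_", 1)
                events.append({
                    "name": stage,
                    "ph": "B" if idx == "0" else "E",  # begin / end of the stage
                    "ts": int(ts),                     # already in microseconds
                    "pid": pid,
                    "tid": tid,
                })
    with open(trace_path, "w") as f:
        json.dump(events, f)

if __name__ == "__main__":
    profile_to_trace(sys.argv[1], sys.argv[2])
```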
