Skip to content

Commit cf5c243

Browse files
committed
fix bert demo
1 parent 6312429 commit cf5c243

File tree

4 files changed

+29
-21
lines changed

4 files changed

+29
-21
lines changed

core/general-client/src/general_model.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,12 @@ int PredictorClient::create_predictor() {
132132
_api.thrd_initialize();
133133
}
134134

135-
int PredictorClient::predict(
136-
const std::vector<std::vector<float>>& float_feed,
137-
const std::vector<std::string>& float_feed_name,
138-
const std::vector<std::vector<int64_t>>& int_feed,
139-
const std::vector<std::string>& int_feed_name,
140-
const std::vector<std::string>& fetch_name,
141-
PredictorRes & predict_res) { // NOLINT
135+
int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
136+
const std::vector<std::string> &float_feed_name,
137+
const std::vector<std::vector<int64_t>> &int_feed,
138+
const std::vector<std::string> &int_feed_name,
139+
const std::vector<std::string> &fetch_name,
140+
PredictorRes &predict_res) { // NOLINT
142141
predict_res._int64_map.clear();
143142
predict_res._float_map.clear();
144143
Timer timeline;
@@ -218,6 +217,7 @@ int PredictorClient::predict(
218217
VLOG(2) << "fetch name: " << name;
219218
if (_fetch_name_to_type[name] == 0) {
220219
int len = res.insts(0).tensor_array(idx).int64_data_size();
220+
VLOG(2) << "fetch tensor : " << name << " type: int64 len : " << len;
221221
predict_res._int64_map[name].resize(1);
222222
predict_res._int64_map[name][0].resize(len);
223223
for (int i = 0; i < len; ++i) {
@@ -226,6 +226,7 @@ int PredictorClient::predict(
226226
}
227227
} else if (_fetch_name_to_type[name] == 1) {
228228
int len = res.insts(0).tensor_array(idx).float_data_size();
229+
VLOG(2) << "fetch tensor : " << name << " type: float32 len : " << len;
229230
predict_res._float_map[name].resize(1);
230231
predict_res._float_map[name][0].resize(len);
231232
for (int i = 0; i < len; ++i) {
@@ -244,18 +245,18 @@ int PredictorClient::predict(
244245
<< "prepro_1:" << preprocess_end << " "
245246
<< "client_infer_0:" << client_infer_start << " "
246247
<< "client_infer_1:" << client_infer_end << " ";
247-
248+
248249
if (FLAGS_profile_server) {
249250
int op_num = res.profile_time_size() / 2;
250251
for (int i = 0; i < op_num; ++i) {
251252
oss << "op" << i << "_0:" << res.profile_time(i * 2) << " ";
252253
oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
253254
}
254255
}
255-
256+
256257
oss << "postpro_0:" << postprocess_start << " ";
257258
oss << "postpro_1:" << postprocess_end;
258-
259+
259260
fprintf(stderr, "%s\n", oss.str().c_str());
260261
}
261262
return 0;
@@ -342,7 +343,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
342343
}
343344

344345
VLOG(2) << "batch [" << bi << "] "
345-
<< "itn feed value prepared";
346+
<< "int feed value prepared";
346347
}
347348

348349
int64_t preprocess_end = timeline.TimeStampUS();

python/examples/bert/bert_client.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def run_batch_general(self, text, fetch):
120120

121121

122122
def test():
123-
124123
bc = BertService(
125124
model_name='bert_chinese_L-12_H-768_A-12',
126125
max_seq_len=20,
@@ -130,16 +129,25 @@ def test():
130129
config_file = './serving_client_conf/serving_client_conf.prototxt'
131130
fetch = ["pooled_output"]
132131
bc.load_client(config_file, server_addr)
133-
batch_size = 4
132+
batch_size = 1
134133
batch = []
135134
for line in sys.stdin:
136-
if len(batch) < batch_size:
137-
batch.append([line.strip()])
135+
if batch_size == 1:
136+
result = bc.run_general([[line.strip()]], fetch)
137+
print(result)
138138
else:
139-
result = bc.run_batch_general(batch, fetch)
140-
batch = []
141-
for r in result:
142-
print(r)
139+
if len(batch) < batch_size:
140+
batch.append([line.strip()])
141+
else:
142+
result = bc.run_batch_general(batch, fetch)
143+
batch = []
144+
for r in result:
145+
print(r)
146+
if len(batch) > 0:
147+
result = bc.run_batch_general(batch, fetch)
148+
batch = []
149+
for r in result:
150+
print(r)
143151

144152

145153
if __name__ == '__main__':

python/examples/bert/bert_server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
server = Server()
3232
server.set_op_sequence(op_seq_maker.get_op_sequence())
3333
server.set_num_threads(4)
34-
server.set_local_bin(
35-
"~/github/Serving/build_server/core/general-server/serving")
3634

3735
server.load_model_config(sys.argv[1])
3836
port = int(sys.argv[2])

python/examples/bert/get_data.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
wget https://paddle-serving.bj.bcebos.com/bert_example/data-c.txt --no-check-certificate

0 commit comments

Comments
 (0)