
Commit 61fe956

ShenYuhan authored and nepeplwu committed
1.add config batch_size; 2.delete req_id_ for every file (#213)
1 parent 66fb66c commit 61fe956

File tree

1 file changed (+40 −19)


paddlehub/serving/app_single.py

Lines changed: 40 additions & 19 deletions
@@ -65,47 +65,53 @@
 }


-def predict_sentiment_analysis(module, input_text, extra=None):
+def predict_sentiment_analysis(module, input_text, batch_size, extra=None):
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         data = input_text[0]
         data.update(input_text[1])
-        results = predict_method(data=data, use_gpu=use_gpu)
+        results = predict_method(
+            data=data, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
         return {"result": "Please check data format!"}
     return results


-def predict_pretrained_model(module, input_text, extra=None):
+def predict_pretrained_model(module, input_text, batch_size, extra=None):
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         data = {"text": input_text}
-        results = predict_method(data=data, use_gpu=use_gpu)
+        results = predict_method(
+            data=data, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
         return {"result": "Please check data format!"}
     return results


-def predict_lexical_analysis(module, input_text, extra=[]):
+def predict_lexical_analysis(module, input_text, batch_size, extra=[]):
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     data = {"text": input_text}
     try:
         if extra == []:
-            results = predict_method(data=data, use_gpu=use_gpu)
+            results = predict_method(
+                data=data, use_gpu=use_gpu, batch_size=batch_size)
         else:
             user_dict = extra[0]
             results = predict_method(
-                data=data, user_dict=user_dict, use_gpu=use_gpu)
+                data=data,
+                user_dict=user_dict,
+                use_gpu=use_gpu,
+                batch_size=batch_size)
         for path in extra:
             os.remove(path)
     except Exception as err:
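
Note: every text-side wrapper in this hunk follows the same pattern after the change: the caller supplies a batch_size and the wrapper forwards it, together with data and use_gpu, to the module's predict method. For illustration only, a minimal runnable sketch of that pattern (FakeModule and its predict method are stand-ins, not part of PaddleHub):

class FakeModule:
    # Stand-in for a PaddleHub module; in the real code the predict method is
    # resolved from the module's default signature, as the context lines show.
    def predict(self, data, use_gpu=False, batch_size=1):
        return [{"text": t, "batch_size": batch_size} for t in data["text"]]


def predict_with_batch(module, input_text, batch_size, use_gpu=False):
    # Same shape as the wrappers above: forward batch_size into the module call.
    predict_method = getattr(module, "predict")
    try:
        data = {"text": input_text}
        return predict_method(data=data, use_gpu=use_gpu, batch_size=batch_size)
    except Exception as err:
        print(err)
        return {"result": "Please check data format!"}


print(predict_with_batch(FakeModule(), ["hello", "world"], batch_size=2))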
@@ -115,29 +121,31 @@ def predict_lexical_analysis(module, input_text, extra=[]):
     return results


-def predict_classification(module, input_img):
+def predict_classification(module, input_img, batch_size):
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         input_img = {"image": input_img}
-        results = predict_method(data=input_img, use_gpu=use_gpu)
+        results = predict_method(
+            data=input_img, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
         return {"result": "Please check data format!"}
     return results


-def predict_gan(module, input_img, extra={}):
+def predict_gan(module, input_img, id, batch_size, extra={}):
     # special
     output_folder = module.name.split("_")[0] + "_" + "output"
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         input_img = {"image": input_img}
-        results = predict_method(data=input_img, use_gpu=use_gpu)
+        results = predict_method(
+            data=input_img, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
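
Note: the CV wrappers in this hunk, like the NLP ones above, resolve the predict function dynamically: the method name is read from module.desc.attr.map.data['default_signature'].s and looked up with getattr. A hedged sketch of that dispatch using a plain Python object (DemoModule and its attributes are illustrative, not the real protobuf-backed descriptor):

class DemoModule:
    # default_signature stands in for module.desc.attr.map.data['default_signature'].s
    default_signature = "classify"

    def classify(self, data, use_gpu=False, batch_size=1):
        return [{"image": img, "label": "demo"} for img in data["image"]]


module = DemoModule()
method_name = module.default_signature         # read the default signature name
predict_method = getattr(module, method_name)  # same getattr dispatch as app_single.py
print(predict_method(data={"image": ["cat.jpg"]}, use_gpu=False, batch_size=1))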
@@ -155,6 +163,7 @@ def predict_gan(module, input_img, extra={}):
         b_body = str(b_body).replace("b'", "").replace("'", "")
         b_img = b_head + "," + b_body
         base64_list.append(b_img)
+        results[index] = results[index].replace(id + "_", "")
         results[index] = {"path": results[index]}
         results[index].update({"base64": b_img})
         results_pack.append(results[index])
@@ -163,14 +172,15 @@ def predict_gan(module, input_img, extra={}):
     return results_pack


-def predict_object_detection(module, input_img):
+def predict_object_detection(module, input_img, id, batch_size):
     output_folder = "output"
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         input_img = {"image": input_img}
-        results = predict_method(data=input_img, use_gpu=use_gpu)
+        results = predict_method(
+            data=input_img, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
@@ -186,22 +196,25 @@ def predict_object_detection(module, input_img):
         b_body = str(b_body).replace("b'", "").replace("'", "")
         b_img = b_head + "," + b_body
         base64_list.append(b_img)
+        results[index]["path"] = results[index]["path"].replace(
+            id + "_", "")
         results[index].update({"base64": b_img})
         results_pack.append(results[index])
         os.remove(item)
         os.remove(os.path.join(output_folder, item))
     return results_pack


-def predict_semantic_segmentation(module, input_img):
+def predict_semantic_segmentation(module, input_img, id, batch_size):
     # special
     output_folder = module.name.split("_")[-1] + "_" + "output"
     global use_gpu
     method_name = module.desc.attr.map.data['default_signature'].s
     predict_method = getattr(module, method_name)
     try:
         input_img = {"image": input_img}
-        results = predict_method(data=input_img, use_gpu=use_gpu)
+        results = predict_method(
+            data=input_img, use_gpu=use_gpu, batch_size=batch_size)
     except Exception as err:
         curr = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
         print(curr, " - ", err)
@@ -219,6 +232,10 @@ def predict_semantic_segmentation(module, input_img):
         b_body = str(b_body).replace("b'", "").replace("'", "")
         b_img = b_head + "," + b_body
         base64_list.append(b_img)
+        results[index]["origin"] = results[index]["origin"].replace(
+            id + "_", "")
+        results[index]["processed"] = results[index]["processed"].replace(
+            id + "_", "")
         results[index].update({"base64": b_img})
         results_pack.append(results[index])
         os.remove(item)
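
Note: this hunk and the two before it implement the second half of the commit message ("delete req_id_ for every file"). Uploaded files are saved under a req_id + "_" prefix (the predict_text hunk below shows the same naming scheme), so the paths a module reports back still carry that prefix; the added replace(id + "_", "") calls strip it before the response is built. A small illustration of the effect (the id and file name are made up):

req_id = "a1b2c3"                         # per-request id, as sent by the client
saved_name = req_id + "_" + "cat.jpg"     # how an upload is named on disk

# A result entry as a module might report it, still carrying the prefix.
result = {"processed": saved_name, "base64": "..."}

# The same replace(id + "_", "") the diff adds before returning results.
result["processed"] = result["processed"].replace(req_id + "_", "")
print(result["processed"])  # -> cat.jpg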
@@ -260,7 +277,7 @@ def get_modules_info():
     @app_instance.route("/predict/image/<module_name>", methods=["POST"])
     def predict_image(module_name):
         req_id = request.data.get("id")
-        global use_gpu
+        global use_gpu, batch_size_dict
         img_base64 = request.form.getlist("image")
         file_name_list = []
         if img_base64 != []:
@@ -289,7 +306,8 @@ def predict_image(module_name):
         else:
             module_type = module.type.split("/")[-1].replace("-", "_").lower()
             predict_func = eval("predict_" + module_type)
-        results = predict_func(module, file_name_list)
+        batch_size = batch_size_dict.get(module_name, 1)
+        results = predict_func(module, file_name_list, req_id, batch_size)
         r = {"results": str(results)}
         return r
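
Note: the image route now reads the module's batch size from batch_size_dict, falling back to 1 for modules that were not configured, and passes req_id along so the CV wrappers can strip the file-name prefix. A small sketch of the lookup (module names and sizes here are examples only):

batch_size_dict = {"lac": 20, "yolov3_coco2017": 4}  # filled by config_with_file below

print(batch_size_dict.get("yolov3_coco2017", 1))   # 4: configured value
print(batch_size_dict.get("some_other_module", 1))  # 1: default when not configured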

@@ -316,22 +334,25 @@ def predict_text(module_name):
             file_path = req_id + "_" + item.filename
             file_list.append(file_path)
             item.save(file_path)
-        results = predict_func(module, data, file_list)
+        batch_size = batch_size_dict.get(module_name, 1)
+        results = predict_func(module, data, batch_size, file_list)
         return {"results": results}

     return app_instance


 def config_with_file(configs):
-    global nlp_module, cv_module
+    global nlp_module, cv_module, batch_size_dict
     nlp_module = []
     cv_module = []
+    batch_size_dict = {}
     for item in configs:
         print(item)
         if item["category"] == "CV":
             cv_module.append(item["module"])
         elif item["category"] == "NLP":
             nlp_module.append(item["module"])
+        batch_size_dict.update({item["module"]: item["batch_size"]})


 def run(is_use_gpu=False, configs=None, port=8866, timeout=60):
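
Note: with this hunk, config_with_file expects every config entry to carry a batch_size field next to module and category, and it records it in batch_size_dict keyed by module name. A hedged sketch of a matching configs list and the dictionary it produces (module names and values are illustrative, not a shipped config):

configs = [
    {"module": "lac", "category": "NLP", "batch_size": 20},
    {"module": "yolov3_coco2017", "category": "CV", "batch_size": 4},
]

batch_size_dict = {}
for item in configs:
    # Same update the diff adds inside config_with_file.
    batch_size_dict.update({item["module"]: item["batch_size"]})

print(batch_size_dict)  # {'lac': 20, 'yolov3_coco2017': 4}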
