18
18
import os
19
19
import base64
20
20
import logging
21
+ import shutil
21
22
22
23
cv_module_method = {
23
24
"vgg19_imagenet" : "predict_classification" ,
47
48
"faster_rcnn_coco2017" : "predict_object_detection" ,
48
49
"cyclegan_cityscapes" : "predict_gan" ,
49
50
"deeplabv3p_xception65_humanseg" : "predict_semantic_segmentation" ,
50
- "ace2p" : "predict_semantic_segmentation"
51
+ "ace2p" : "predict_semantic_segmentation" ,
52
+ "pyramidbox_lite_server_mask" : "predict_mask" ,
53
+ "pyramidbox_lite_mobile_mask" : "predict_mask"
51
54
}
52
55
53
56
@@ -132,6 +135,59 @@ def predict_gan(module, input_img, id, batch_size, extra={}):
132
135
return results_pack
133
136
134
137
138
+ def predict_mask (module , input_img , id , batch_size , extra = None , r_img = False ):
139
+ output_folder = "detection_result"
140
+ global use_gpu
141
+ method_name = module .desc .attr .map .data ['default_signature' ].s
142
+ predict_method = getattr (module , method_name )
143
+ try :
144
+ data = {}
145
+ if input_img is not None :
146
+ input_img = {"image" : input_img }
147
+ data .update (input_img )
148
+ if extra is not None :
149
+ data .update (extra )
150
+ r_img = True if "r_img" in extra .keys () else False
151
+ results = predict_method (
152
+ data = data , use_gpu = use_gpu , batch_size = batch_size )
153
+ results = utils .handle_mask_results (results )
154
+ except Exception as err :
155
+ curr = time .strftime ("%Y-%m-%d %H:%M:%S" , time .localtime (time .time ()))
156
+ print (curr , " - " , err )
157
+ return {"result" : "Please check data format!" }
158
+ finally :
159
+ base64_list = []
160
+ results_pack = []
161
+ if input_img is not None :
162
+ if r_img is False :
163
+ shutil .rmtree (output_folder )
164
+ for index in range (len (results )):
165
+ results [index ]["path" ] = ""
166
+ results_pack = results
167
+ else :
168
+ input_img = input_img .get ("image" , [])
169
+ for index in range (len (input_img )):
170
+ item = input_img [index ]
171
+ with open (os .path .join (output_folder , item ), "rb" ) as fp :
172
+ b_head = "data:image/" + item .split ("." )[- 1 ] + ";base64"
173
+ b_body = base64 .b64encode (fp .read ())
174
+ b_body = str (b_body ).replace ("b'" , "" ).replace ("'" , "" )
175
+ b_img = b_head + "," + b_body
176
+ base64_list .append (b_img )
177
+ results [index ]["path" ] = results [index ]["path" ].replace (
178
+ id + "_" , "" ) if results [index ]["path" ] != "" \
179
+ else ""
180
+
181
+ results [index ].update ({"base64" : b_img })
182
+ results_pack .append (results [index ])
183
+ os .remove (item )
184
+ os .remove (os .path .join (output_folder , item ))
185
+ else :
186
+ results_pack = results
187
+
188
+ return results_pack
189
+
190
+
135
191
def predict_object_detection (module , input_img , id , batch_size , extra = {}):
136
192
output_folder = "detection_result"
137
193
global use_gpu
@@ -253,14 +309,22 @@ def predict_image(module_name):
253
309
extra_info = {}
254
310
for item in list (request .form .keys ()):
255
311
extra_info .update ({item : request .form .getlist (item )})
312
+
313
+ for key in extra_info .keys ():
314
+ if isinstance (extra_info [key ], list ):
315
+ extra_info [key ] = utils .base64s_to_cvmats (
316
+ eval (extra_info [key ][0 ])["b64s" ]) if isinstance (
317
+ extra_info [key ][0 ], str
318
+ ) and "b64s" in extra_info [key ][0 ] else extra_info [key ]
319
+
256
320
file_name_list = []
257
321
if img_base64 != []:
258
322
for item in img_base64 :
259
323
ext = item .split (";" )[0 ].split ("/" )[- 1 ]
260
324
if ext not in ["jpeg" , "jpg" , "png" ]:
261
325
return {"result" : "Unrecognized file type" }
262
326
filename = req_id + "_" \
263
- + utils .md5 (str (time .time ())+ item [0 :20 ]) \
327
+ + utils .md5 (str (time .time ()) + item [0 :20 ]) \
264
328
+ "." \
265
329
+ ext
266
330
img_data = base64 .b64decode (item .split (',' )[- 1 ])
@@ -281,6 +345,10 @@ def predict_image(module_name):
281
345
module_type = module .type .split ("/" )[- 1 ].replace ("-" , "_" ).lower ()
282
346
predict_func = eval ("predict_" + module_type )
283
347
batch_size = batch_size_dict .get (module_name , 1 )
348
+ if file_name_list == []:
349
+ file_name_list = None
350
+ if extra_info == {}:
351
+ extra_info = None
284
352
results = predict_func (module , file_name_list , req_id , batch_size ,
285
353
extra_info )
286
354
r = {"results" : str (results )}
0 commit comments