@@ -41,6 +41,39 @@ def _transforms(dataset):
41
41
return transforms
42
42
43
43
44
def find_images_with_bounding_size(eval_dataset: paddle.io.Dataset):
    """Select the samples whose images have the extreme sizes in a dataset.

    Iterates over ``eval_dataset`` once. For each sample, the last two
    dimensions of ``data['img']`` are read as ``(h, w)``; the printed
    messages refer to ``h`` as "length" and ``w`` as "width". The index of
    the first sample reaching each of the four extremes (largest and
    smallest ``h``, largest and smallest ``w``) is recorded.

    Args:
        eval_dataset: Iterable dataset whose samples are mappings with an
            ``'img'`` entry that converts to an array with at least two
            dimensions.

    Returns:
        A ``paddle.io.Subset`` of ``eval_dataset`` containing the max-width,
        max-length, min-width and min-length samples, in that order. An
        index that holds more than one extreme appears only once.

    Raises:
        ValueError: If ``eval_dataset`` yields no samples, since there is
            then no valid index to build the subset from.
    """
    max_length = max_width = float('-inf')
    min_length = min_width = float('inf')
    max_length_index = max_width_index = -1
    min_length_index = min_width_index = -1

    for idx, data in enumerate(eval_dataset):
        # Only the shape is needed, so avoid the copy np.array() would make.
        h, w = np.asarray(data['img']).shape[-2:]
        if h > max_length:
            max_length = h
            max_length_index = idx
        if w > max_width:
            max_width = w
            max_width_index = idx
        if h < min_length:
            min_length = h
            min_length_index = idx
        if w < min_width:
            min_width = w
            min_width_index = idx

    # The first sample always beats the +/-inf sentinels, so an index still
    # at -1 means the loop body never ran.
    if max_length_index < 0:
        raise ValueError(
            "eval_dataset is empty; cannot select images for dynamic shape "
            "collection.")

    print(f"Found max image length: {max_length} , index: {max_length_index} ")
    print(f"Found max image width: {max_width} , index: {max_width_index} ")
    print(f"Found min image length: {min_length} , index: {min_length_index} ")
    print(f"Found min image width: {min_width} , index: {min_width_index} ")

    # One image may hold several extremes (e.g. a single-image dataset);
    # dict.fromkeys drops the repeats while keeping the original order so the
    # same sample is not run through the predictor more than once.
    indices = list(
        dict.fromkeys([
            max_width_index, max_length_index, min_width_index,
            min_length_index
        ]))
    return paddle.io.Subset(eval_dataset, indices)
75
+
76
+
44
77
def load_predictor (args ):
45
78
"""
46
79
load predictor func
@@ -109,7 +142,7 @@ def predict_image(args):
109
142
data = transform ({'img' : args .image_file })
110
143
data = data ['img' ][np .newaxis , :]
111
144
112
- # Step2: Prepare prdictor
145
+ # Step2: Prepare predictor
113
146
predictor , rerun_flag = load_predictor (args )
114
147
115
148
# Step3: Inference
@@ -167,6 +200,15 @@ def eval(args):
167
200
168
201
eval_dataset = builder .val_dataset
169
202
203
+ predictor , rerun_flag = load_predictor (args )
204
+
205
+ if rerun_flag and args .use_multi_img_for_dynamic_shape_collect :
206
+ print (
207
+ "***** Try to find the images with the largest and smallest length and width respectively in the ADE20K "
208
+ "dataset for collecting dynamic shape. *****"
209
+ )
210
+ eval_dataset = find_images_with_bounding_size (eval_dataset )
211
+
170
212
batch_sampler = paddle .io .BatchSampler (
171
213
eval_dataset , batch_size = 1 , shuffle = False , drop_last = False )
172
214
loader = paddle .io .DataLoader (
@@ -175,8 +217,6 @@ def eval(args):
175
217
num_workers = 0 ,
176
218
return_list = True )
177
219
178
- predictor , rerun_flag = load_predictor (args )
179
-
180
220
intersect_area_all = 0
181
221
pred_area_all = 0
182
222
label_area_all = 0
@@ -207,14 +247,22 @@ def eval(args):
207
247
time_max = max (time_max , timed )
208
248
predict_time += timed
209
249
if rerun_flag :
210
- print (
211
- "***** Collect dynamic shape done, Please rerun the program to get correct results. *****"
212
- )
213
- return
214
-
250
+ if args .use_multi_img_for_dynamic_shape_collect :
251
+ if batch_id == sample_nums - 1 :
252
+ print (
253
+ "***** Collect dynamic shape done, Please rerun the program to get correct results. *****"
254
+ )
255
+ return
256
+ else :
257
+ continue
258
+ else :
259
+ print (
260
+ "***** Collect dynamic shape done, Please rerun the program to get correct results. *****"
261
+ )
262
+ return
215
263
logit = reverse_transform (
216
- paddle .to_tensor (results ), data ['trans_info' ], mode = "bilinear" )
217
- pred = paddle .to_tensor (logit )
264
+ paddle .to_tensor (results ). unsqueeze ( 0 ) , data ['trans_info' ], mode = "bilinear" )
265
+ pred = paddle .to_tensor (logit ). squeeze ( 0 )
218
266
if len (
219
267
pred .shape
220
268
) == 4 : # for humanseg model whose prediction is distribution but not class id
@@ -314,6 +362,12 @@ def eval(args):
314
362
help = "Whether use mkldnn or not." )
315
363
parser .add_argument (
316
364
"--cpu_threads" , type = int , default = 1 , help = "Num of cpu threads." )
365
+ parser .add_argument (
366
+ "--use_multi_img_for_dynamic_shape_collect" ,
367
+ type = bool ,
368
+ default = False ,
369
+ help = "Whether it is necessary to use multiple images to collect shape infomation,\
370
+ When the image sizes in the data set are different, it needs to be set to True." )
317
371
args = parser .parse_args ()
318
372
if args .image_file :
319
373
predict_image (args )
0 commit comments