@@ -98,7 +98,7 @@ def save_predict_result(task: str, preds_list: list, output_save_dir: str):
98
98
save_rec_res (rec_res , img_list , save_path = os .path .join (output_save_dir , "rec_results.txt" ))
99
99
100
100
101
- def predict_single_step (cfg ):
101
+ def predict_single_step (cfg , save_res = True ):
102
102
"""Run predict for det task or rec task"""
103
103
# 1. Set the environment information.
104
104
set_context (mode = cfg .system .mode )
@@ -162,6 +162,12 @@ def predict_single_step(cfg):
162
162
"The program has switched to amp_level O2 automatically."
163
163
)
164
164
amp_level = "O2"
165
+ cfg .model .backbone .pretrained = False
166
+ if cfg .predict .ckpt_load_path is None :
167
+ logger .warning (
168
+ f"No ckpt is available for { cfg .model .task } , "
169
+ "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
170
+ )
165
171
network = build_model (cfg .model , ckpt_load_path = cfg .predict .ckpt_load_path , amp_level = amp_level )
166
172
network .set_train (False )
167
173
@@ -219,7 +225,8 @@ def predict_single_step(cfg):
219
225
preds_list .append (preds )
220
226
221
227
# 7. Save the prediction results to the specified directory
222
- save_predict_result (cfg .model .type , preds_list , output_save_dir )
228
+ if save_res is True :
229
+ save_predict_result (cfg .model .type , preds_list , output_save_dir )
223
230
return preds_list
224
231
225
232
@@ -231,7 +238,7 @@ def predict_system(args, det_cfg, rec_cfg):
231
238
output_save_dir = det_cfg .predict .output_save_dir or "./output"
232
239
233
240
# get det result from predict
234
- preds_list = predict_single_step (det_cfg )
241
+ preds_list = predict_single_step (det_cfg , save_res = False )
235
242
236
243
# set amp level
237
244
amp_level = det_cfg .system .get ("amp_level_infer" , "O0" )
@@ -247,6 +254,12 @@ def predict_system(args, det_cfg, rec_cfg):
247
254
postprocessor = build_postprocess (rec_cfg .postprocess )
248
255
249
256
# build rec model from yaml
257
+ rec_cfg .model .backbone .pretrained = False
258
+ if rec_cfg .predict .ckpt_load_path is None :
259
+ logger .warning (
260
+ f"No ckpt is available for { rec_cfg .model .type } , "
261
+ "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
262
+ )
250
263
rec_network = build_model (rec_cfg .model , ckpt_load_path = rec_cfg .predict .ckpt_load_path , amp_level = amp_level )
251
264
252
265
# start rec task
0 commit comments