Skip to content

Commit 0516cca

Browse files
authored
Added option for predict_single_step whether to save intermediate results. (#654)
1 parent c03f28b commit 0516cca

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

tools/infer/text/predict_from_yaml.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def save_predict_result(task: str, preds_list: list, output_save_dir: str):
9898
save_rec_res(rec_res, img_list, save_path=os.path.join(output_save_dir, "rec_results.txt"))
9999

100100

101-
def predict_single_step(cfg):
101+
def predict_single_step(cfg, save_res=True):
102102
"""Run predict for det task or rec task"""
103103
# 1. Set the environment information.
104104
set_context(mode=cfg.system.mode)
@@ -162,6 +162,12 @@ def predict_single_step(cfg):
162162
"The program has switched to amp_level O2 automatically."
163163
)
164164
amp_level = "O2"
165+
cfg.model.backbone.pretrained = False
166+
if cfg.predict.ckpt_load_path is None:
167+
logger.warning(
168+
f"No ckpt is available for {cfg.model.task}, "
169+
"please check your configuration of 'predict.ckpt_load_path' in the yaml file."
170+
)
165171
network = build_model(cfg.model, ckpt_load_path=cfg.predict.ckpt_load_path, amp_level=amp_level)
166172
network.set_train(False)
167173

@@ -219,7 +225,8 @@ def predict_single_step(cfg):
219225
preds_list.append(preds)
220226

221227
# 7. Save the prediction results to the specified directory
222-
save_predict_result(cfg.model.type, preds_list, output_save_dir)
228+
if save_res is True:
229+
save_predict_result(cfg.model.type, preds_list, output_save_dir)
223230
return preds_list
224231

225232

@@ -231,7 +238,7 @@ def predict_system(args, det_cfg, rec_cfg):
231238
output_save_dir = det_cfg.predict.output_save_dir or "./output"
232239

233240
# get det result from predict
234-
preds_list = predict_single_step(det_cfg)
241+
preds_list = predict_single_step(det_cfg, save_res=False)
235242

236243
# set amp level
237244
amp_level = det_cfg.system.get("amp_level_infer", "O0")
@@ -247,6 +254,12 @@ def predict_system(args, det_cfg, rec_cfg):
247254
postprocessor = build_postprocess(rec_cfg.postprocess)
248255

249256
# build rec model from yaml
257+
rec_cfg.model.backbone.pretrained = False
258+
if rec_cfg.predict.ckpt_load_path is None:
259+
logger.warning(
260+
f"No ckpt is available for {rec_cfg.model.type}, "
261+
"please check your configuration of 'predict.ckpt_load_path' in the yaml file."
262+
)
250263
rec_network = build_model(rec_cfg.model, ckpt_load_path=rec_cfg.predict.ckpt_load_path, amp_level=amp_level)
251264

252265
# start rec task

0 commit comments

Comments
 (0)