Skip to content

Commit 8f826ed

Browse files
committed
update for infer config
1 parent d1583d4 commit 8f826ed

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

tools/infer/utility.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -215,22 +215,29 @@ def create_predictor(args, mode, logger):
215215
else:
216216
file_names = ["model", "inference"]
217217
for file_name in file_names:
218-
model_file_path = "{}/{}.pdmodel".format(model_dir, file_name)
219218
params_file_path = "{}/{}.pdiparams".format(model_dir, file_name)
220-
if os.path.exists(model_file_path) and os.path.exists(params_file_path):
219+
if os.path.exists(params_file_path):
221220
break
222-
if not os.path.exists(model_file_path):
221+
222+
if not os.path.exists(params_file_path):
223223
raise ValueError(
224-
"not find model.pdmodel or inference.pdmodel in {}".format(model_dir)
224+
f"not find {file_name}.pdiparams or {file_name}.pdiparams in {model_dir}"
225225
)
226-
if not os.path.exists(params_file_path):
226+
227+
if not os.path.exists(
228+
"{}/{}.pdmodel".format(model_dir, file_name)
229+
) or not os.path.exists("{}/{}.json".format(model_dir, file_name)):
227230
raise ValueError(
228-
"not find model.pdiparams or inference.pdiparams in {}".format(
229-
model_dir
230-
)
231+
f"not find {file_name}.json or {file_name}.pdmodel in {model_dir}"
231232
)
232233

233-
config = inference.Config(model_file_path, params_file_path)
234+
if paddle.__version__ == "0.0.0" or paddle.__version__ >= "3.0.0":
235+
model_path = model_dir
236+
model_prefix = file_name
237+
config = inference.Config(model_path, model_prefix)
238+
else:
239+
model_file_path = "{}/{}.pdmodel".format(model_dir, file_name)
240+
config = inference.Config(model_file_path, params_file_path)
234241

235242
if hasattr(args, "precision"):
236243
if args.precision == "fp16" and args.use_tensorrt:

0 commit comments

Comments
 (0)