@@ -215,22 +215,29 @@ def create_predictor(args, mode, logger):
215
215
else :
216
216
file_names = ["model" , "inference" ]
217
217
for file_name in file_names :
218
- model_file_path = "{}/{}.pdmodel" .format (model_dir , file_name )
219
218
params_file_path = "{}/{}.pdiparams" .format (model_dir , file_name )
220
- if os .path .exists (model_file_path ) and os . path . exists ( params_file_path ):
219
+ if os .path .exists (params_file_path ):
221
220
break
222
- if not os .path .exists (model_file_path ):
221
+
222
+ if not os .path .exists (params_file_path ):
223
223
raise ValueError (
224
- "not find model.pdmodel or inference.pdmodel in {}" . format ( model_dir )
224
+ f "not find { file_name } .pdiparams or { file_name } .pdiparams in { model_dir } "
225
225
)
226
- if not os .path .exists (params_file_path ):
226
+
227
+ if not os .path .exists (
228
+ "{}/{}.pdmodel" .format (model_dir , file_name )
229
+ ) or not os .path .exists ("{}/{}.json" .format (model_dir , file_name )):
227
230
raise ValueError (
228
- "not find model.pdiparams or inference.pdiparams in {}" .format (
229
- model_dir
230
- )
231
+ f"not find { file_name } .json or { file_name } .pdmodel in { model_dir } "
231
232
)
232
233
233
- config = inference .Config (model_file_path , params_file_path )
234
+ if paddle .__version__ == "0.0.0" or paddle .__version__ >= "3.0.0" :
235
+ model_path = model_dir
236
+ model_prefix = file_name
237
+ config = inference .Config (model_path , model_prefix )
238
+ else :
239
+ model_file_path = "{}/{}.pdmodel" .format (model_dir , file_name )
240
+ config = inference .Config (model_file_path , params_file_path )
234
241
235
242
if hasattr (args , "precision" ):
236
243
if args .precision == "fp16" and args .use_tensorrt :
0 commit comments