@@ -195,8 +195,14 @@ def create_predictor(args, mode, logger):
195
195
if not os .path .exists (model_file_path ):
196
196
raise ValueError ("not find model file path {}" .format (model_file_path ))
197
197
198
+ sess_options = args .onnx_sess_options or None
199
+
198
200
if args .onnx_providers and len (args .onnx_providers ) > 0 :
199
- sess = ort .InferenceSession (model_file_path , providers = args .onnx_providers )
201
+ sess = ort .InferenceSession (
202
+ model_file_path ,
203
+ providers = args .onnx_providers ,
204
+ sess_options = sess_options ,
205
+ )
200
206
elif args .use_gpu :
201
207
sess = ort .InferenceSession (
202
208
model_file_path ,
@@ -206,10 +212,13 @@ def create_predictor(args, mode, logger):
206
212
{"device_id" : args .gpu_id , "cudnn_conv_algo_search" : "DEFAULT" },
207
213
)
208
214
],
215
+ sess_options = sess_options ,
209
216
)
210
217
else :
211
218
sess = ort .InferenceSession (
212
- model_file_path , providers = ["CPUExecutionProvider" ]
219
+ model_file_path ,
220
+ providers = ["CPUExecutionProvider" ],
221
+ sess_options = sess_options ,
213
222
)
214
223
return sess , sess .get_inputs ()[0 ], None , None
215
224
0 commit comments