Skip to content

Commit c6c48d4

Browse files
committed
added ability to provide onnxruntime SessionOptions
1 parent d28a039 commit c6c48d4

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

tools/infer/utility.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,14 @@ def create_predictor(args, mode, logger):
195195
if not os.path.exists(model_file_path):
196196
raise ValueError("not find model file path {}".format(model_file_path))
197197

198+
sess_options = args.onnx_sess_options or None
199+
198200
if args.onnx_providers and len(args.onnx_providers) > 0:
199-
sess = ort.InferenceSession(model_file_path, providers=args.onnx_providers)
201+
sess = ort.InferenceSession(
202+
model_file_path,
203+
providers=args.onnx_providers,
204+
sess_options=sess_options,
205+
)
200206
elif args.use_gpu:
201207
sess = ort.InferenceSession(
202208
model_file_path,
@@ -206,10 +212,13 @@ def create_predictor(args, mode, logger):
206212
{"device_id": args.gpu_id, "cudnn_conv_algo_search": "DEFAULT"},
207213
)
208214
],
215+
sess_options=sess_options,
209216
)
210217
else:
211218
sess = ort.InferenceSession(
212-
model_file_path, providers=["CPUExecutionProvider"]
219+
model_file_path,
220+
providers=["CPUExecutionProvider"],
221+
sess_options=sess_options,
213222
)
214223
return sess, sess.get_inputs()[0], None, None
215224

0 commit comments

Comments
 (0)