Skip to content

Commit c3eeb61

Browse files
authored
support multy device set device id in inference (#3928)
1 parent e3e6203 commit c3eeb61

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

paddlex/inference/models/common/static_infer.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def _create(
373373
logging.debug("`device_id` has been set to None")
374374

375375
if (
376-
self._option.device_type in ("gpu", "dcu")
376+
self._option.device_type in ("gpu", "dcu", "npu", "mlu", "gcu", "xpu")
377377
and self._option.device_id is None
378378
):
379379
self._option.device_id = 0
@@ -417,12 +417,14 @@ def _create(
417417
if hasattr(config, "enable_new_executor"):
418418
config.enable_new_executor()
419419
elif self._option.device_type == "xpu":
420+
config.enable_xpu()
421+
config.set_xpu_device_id(self._option.device_id)
420422
if hasattr(config, "enable_new_ir"):
421423
config.enable_new_ir(self._option.enable_new_ir)
422424
if hasattr(config, "enable_new_executor"):
423425
config.enable_new_executor()
424426
elif self._option.device_type == "mlu":
425-
config.enable_custom_device("mlu")
427+
config.enable_custom_device("mlu", self._option.device_id)
426428
if hasattr(config, "enable_new_ir"):
427429
config.enable_new_ir(self._option.enable_new_ir)
428430
if hasattr(config, "enable_new_executor"):
@@ -431,7 +433,7 @@ def _create(
431433
from paddle_custom_device.gcu import passes as gcu_passes
432434

433435
gcu_passes.setUp()
434-
config.enable_custom_device("gcu")
436+
config.enable_custom_device("gcu", self._option.device_id)
435437
if hasattr(config, "enable_new_ir"):
436438
config.enable_new_ir()
437439
if hasattr(config, "enable_new_executor"):

0 commit comments

Comments
 (0)