@@ -373,7 +373,7 @@ def _create(
373
373
logging .debug ("`device_id` has been set to None" )
374
374
375
375
if (
376
- self ._option .device_type in ("gpu" , "dcu" )
376
+ self ._option .device_type in ("gpu" , "dcu" , "npu" , "mlu" , "gcu" , "xpu" )
377
377
and self ._option .device_id is None
378
378
):
379
379
self ._option .device_id = 0
@@ -417,12 +417,14 @@ def _create(
417
417
if hasattr (config , "enable_new_executor" ):
418
418
config .enable_new_executor ()
419
419
elif self ._option .device_type == "xpu" :
420
+ config .enable_xpu ()
421
+ config .set_xpu_device_id (self ._option .device_id )
420
422
if hasattr (config , "enable_new_ir" ):
421
423
config .enable_new_ir (self ._option .enable_new_ir )
422
424
if hasattr (config , "enable_new_executor" ):
423
425
config .enable_new_executor ()
424
426
elif self ._option .device_type == "mlu" :
425
- config .enable_custom_device ("mlu" )
427
+ config .enable_custom_device ("mlu" , self . _option . device_id )
426
428
if hasattr (config , "enable_new_ir" ):
427
429
config .enable_new_ir (self ._option .enable_new_ir )
428
430
if hasattr (config , "enable_new_executor" ):
@@ -431,7 +433,7 @@ def _create(
431
433
from paddle_custom_device .gcu import passes as gcu_passes
432
434
433
435
gcu_passes .setUp ()
434
- config .enable_custom_device ("gcu" )
436
+ config .enable_custom_device ("gcu" , self . _option . device_id )
435
437
if hasattr (config , "enable_new_ir" ):
436
438
config .enable_new_ir ()
437
439
if hasattr (config , "enable_new_executor" ):
0 commit comments