@@ -82,7 +82,24 @@ def _current_expected_place():
82
82
83
83
84
84
def _cpu_num ():
85
- return int (os .environ .get ('CPU_NUM' , multiprocessing .cpu_count ()))
85
+ cpu_num = os .environ .get ('CPU_NUM' , None )
86
+ if cpu_num is None :
87
+ sys .stderr .write (
88
+ 'The CPU_NUM is not specified, you should set CPU_NUM in '
89
+ 'the environment variable list, i.e export CPU_NUM=1. CPU_NUM '
90
+ 'indicates that how many CPUPlace are used in the current task.\n '
91
+ '!!! The default number of CPUPlaces is 1.' )
92
+ os .environ ['CPU_NUM' ] = str (1 )
93
+ return int (cpu_num )
94
+
95
+
96
+ def _cuda_ids ():
97
+ gpus_env = os .getenv ("FLAGS_selected_gpus" )
98
+ if gpus_env :
99
+ device_ids = [int (s ) for s in gpus_env .split ("," )]
100
+ else :
101
+ device_ids = six .moves .range (core .get_cuda_device_count ())
102
+ return device_ids
86
103
87
104
88
105
def cuda_places (device_ids = None ):
@@ -116,11 +133,7 @@ def cuda_places(device_ids=None):
116
133
assert core .is_compiled_with_cuda (), \
117
134
"Not compiled with CUDA"
118
135
if device_ids is None :
119
- gpus_env = os .getenv ("FLAGS_selected_gpus" )
120
- if gpus_env :
121
- device_ids = [int (s ) for s in gpus_env .split ("," )]
122
- else :
123
- device_ids = six .moves .range (core .get_cuda_device_count ())
136
+ device_ids = _cuda_ids ()
124
137
elif not isinstance (device_ids , (list , tuple )):
125
138
device_ids = [device_ids ]
126
139
return [core .CUDAPlace (dev_id ) for dev_id in device_ids ]
0 commit comments