Skip to content

Commit 7d15312

Browse files
authored
Merge pull request #306 from Yancey1989/fix_trainer_count
Fix invalied specify trainer_count
2 parents 705d08a + 8531600 commit 7d15312

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

paddlecloud/paddlejob/paddle_job.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def get_env(self):
8787
envs.append({"name":"ENTRY", "value":self._entry})
8888
envs.append({"name":"TRAINER_PACKAGE", "value":self._job_package})
8989
envs.append({"name":"PADDLE_INIT_PORT", "value":str(DEFAULT_PADDLE_PORT)})
90-
envs.append({"name":"PADDLE_INIT_TRAINER_COUNT", "value":str(self._cpu)})
90+
if self._gpu > 0:
91+
envs.append({"name":"PADDLE_INIT_TRAINER_COUNT", "value":str(self._gpu)})
92+
else:
93+
envs.append({"name":"PADDLE_INIT_TRAINER_COUNT", "value":str(self._cpu)})
9194
envs.append({"name":"PADDLE_INIT_PORTS_NUM", "value":str(self._ports_num)})
9295
envs.append({"name":"PADDLE_INIT_PORTS_NUM_FOR_SPARSE", "value":str(self._ports_num_for_sparse)})
9396
envs.append({"name":"PADDLE_INIT_NUM_GRADIENT_SERVERS", "value":str(self._num_gradient_servers)})

0 commit comments

Comments
 (0)