Skip to content

Commit d214d49

Browse files
guru4elephantseiriosPlus
authored andcommitted
add paddle cloud role maker for customized usage, note this is only for industrial users that have cloud environment pre-configuration (#18121)
add paddle cloud role maker for specific cloud usage. This pr will simplifies user's configuration in distributed training.
1 parent ad72ceb commit d214d49

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

python/paddle/fluid/incubate/fleet/base/fleet_base.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,7 @@ def init(self, role_maker=None):
188188
if role_maker and not isinstance(role_maker, RoleMakerBase):
189189
raise ValueError("role_maker must be an instance of RoleMakerBase")
190190

191-
if isinstance(role_maker, MPISymetricRoleMaker):
192-
self._role_maker = role_maker
193-
self._role_maker.generate_role()
194-
195-
elif isinstance(role_maker, UserDefinedRoleMaker):
196-
self._role_maker = role_maker
197-
198-
else:
199-
raise ValueError(
200-
"role_maker must be an instance of UserDefinedRoleMaker/MPISymetricRoleMaker"
201-
)
191+
self._role_maker.generate_role()
202192

203193
self._is_initialized = True
204194

python/paddle/fluid/incubate/fleet/base/role_maker.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
__all__ = [
1818
'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
19-
'UserDefinedCollectiveRoleMaker'
19+
'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
2020
]
2121

2222

@@ -292,6 +292,50 @@ def generate_role(self):
292292
self._role_is_generated = True
293293

294294

295+
class PaddleCloudRoleMaker(RoleMakerBase):
296+
def __init__(self):
297+
super(PaddleCloudRoleMaker, self).__init__()
298+
299+
def generate_role(self):
300+
if not self._role_is_generated:
301+
self.port = os.getenv("PADDLE_PORT", "6174")
302+
self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
303+
eplist = []
304+
for ip in pserver_ips.split(","):
305+
eplist.append(':'.join([ip, port]))
306+
self.endpoints = ",".join(eplist)
307+
self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
308+
self.current_endpoint = os.getenv("POD_IP",
309+
"localhost") + ":" + port
310+
self.role = os.getenv("TRAINING_ROLE", "TRAINER")
311+
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
312+
self.eplist = eplist
313+
self.endpoints = self.endpoints.split(",")
314+
if self.role.upper() == "PSERVER":
315+
self.current_id = self.endpoints.index(self.current_endpoint)
316+
else:
317+
self.current_id = self.trainer_id
318+
self._role_is_generated = True
319+
320+
def is_wokrer(self):
321+
return self._role == Role.WORKER
322+
323+
def is_server(self):
324+
return self._role == Role.SERVER
325+
326+
def is_first_worker(self):
327+
return self._role == Role.WORKER and self._current_id == 0
328+
329+
def worker_index(self):
330+
return self._current_id
331+
332+
def server_index(self):
333+
return self._current_id
334+
335+
def worker_num(self):
336+
return self._worker_num
337+
338+
295339
class UserDefinedRoleMaker(RoleMakerBase):
296340
def __init__(self,
297341
current_id=0,
@@ -329,6 +373,9 @@ def __init__(self,
329373
else:
330374
self._server_endpoints = server_endpoints
331375

376+
def generate_role(self):
377+
self._role_is_generated = True
378+
332379
def is_worker(self):
333380
return self._role == Role.WORKER
334381

@@ -369,6 +416,9 @@ def __init__(self, current_id=0, worker_endpoints=None):
369416
self._worker_endpoints = worker_endpoints
370417
self._worker_num = len(self._worker_endpoints)
371418

419+
def generate_role(self):
420+
self._role_is_generated = True
421+
372422
def is_worker(self):
373423
return True
374424

0 commit comments

Comments
 (0)