|
16 | 16 |
|
17 | 17 | __all__ = [
|
18 | 18 | 'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
|
19 |
| - 'UserDefinedCollectiveRoleMaker' |
| 19 | + 'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker' |
20 | 20 | ]
|
21 | 21 |
|
22 | 22 |
|
@@ -292,6 +292,50 @@ def generate_role(self):
|
292 | 292 | self._role_is_generated = True
|
293 | 293 |
|
294 | 294 |
|
| 295 | +class PaddleCloudRoleMaker(RoleMakerBase): |
| 296 | + def __init__(self): |
| 297 | + super(PaddleCloudRoleMaker, self).__init__() |
| 298 | + |
| 299 | + def generate_role(self): |
| 300 | + if not self._role_is_generated: |
| 301 | + self.port = os.getenv("PADDLE_PORT", "6174") |
| 302 | + self.pserver_ips = os.getenv("PADDLE_PSERVERS", "") |
| 303 | + eplist = [] |
| 304 | + for ip in pserver_ips.split(","): |
| 305 | + eplist.append(':'.join([ip, port])) |
| 306 | + self.endpoints = ",".join(eplist) |
| 307 | + self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) |
| 308 | + self.current_endpoint = os.getenv("POD_IP", |
| 309 | + "localhost") + ":" + port |
| 310 | + self.role = os.getenv("TRAINING_ROLE", "TRAINER") |
| 311 | + self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) |
| 312 | + self.eplist = eplist |
| 313 | + self.endpoints = self.endpoints.split(",") |
| 314 | + if self.role.upper() == "PSERVER": |
| 315 | + self.current_id = self.endpoints.index(self.current_endpoint) |
| 316 | + else: |
| 317 | + self.current_id = self.trainer_id |
| 318 | + self._role_is_generated = True |
| 319 | + |
| 320 | + def is_wokrer(self): |
| 321 | + return self._role == Role.WORKER |
| 322 | + |
| 323 | + def is_server(self): |
| 324 | + return self._role == Role.SERVER |
| 325 | + |
| 326 | + def is_first_worker(self): |
| 327 | + return self._role == Role.WORKER and self._current_id == 0 |
| 328 | + |
| 329 | + def worker_index(self): |
| 330 | + return self._current_id |
| 331 | + |
| 332 | + def server_index(self): |
| 333 | + return self._current_id |
| 334 | + |
| 335 | + def worker_num(self): |
| 336 | + return self._worker_num |
| 337 | + |
| 338 | + |
295 | 339 | class UserDefinedRoleMaker(RoleMakerBase):
|
296 | 340 | def __init__(self,
|
297 | 341 | current_id=0,
|
@@ -329,6 +373,9 @@ def __init__(self,
|
329 | 373 | else:
|
330 | 374 | self._server_endpoints = server_endpoints
|
331 | 375 |
|
| 376 | + def generate_role(self): |
| 377 | + self._role_is_generated = True |
| 378 | + |
332 | 379 | def is_worker(self):
|
333 | 380 | return self._role == Role.WORKER
|
334 | 381 |
|
@@ -369,6 +416,9 @@ def __init__(self, current_id=0, worker_endpoints=None):
|
369 | 416 | self._worker_endpoints = worker_endpoints
|
370 | 417 | self._worker_num = len(self._worker_endpoints)
|
371 | 418 |
|
| 419 | + def generate_role(self): |
| 420 | + self._role_is_generated = True |
| 421 | + |
372 | 422 | def is_worker(self):
|
373 | 423 | return True
|
374 | 424 |
|
|
0 commit comments