Skip to content

Commit 1f1cc22

Browse files
add random port (#18504)
* add random port
1 parent fe32879 commit 1f1cc22

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

python/paddle/fluid/incubate/fleet/base/role_maker.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def __init__(self):
234234
super(MPISymetricRoleMaker, self).__init__()
235235
self._node_type = None
236236
self._proc_per_node = 2
237+
self._pserver_rand_port = 0
237238

238239
def _check_role_generation(self):
239240
if not self._role_is_generated:
@@ -248,6 +249,20 @@ def is_first_worker(self):
248249
return self.is_worker() and 0 == self.worker_index()
249250
return False
250251

252+
def get_pserver_endpoints(self):
    """
    return the parameter server endpoints as "host:port" strings

    The port is picked once and cached in ``self._pserver_rand_port``.
    It is drawn from a generator seeded with the server number so that
    all nodes will get the same port.

    Returns:
        list of str: one "<endpoint>:<port>" entry for every item in
        ``self._server_endpoints``.
    """
    if self._pserver_rand_port <= 0:
        import random
        # Use a private generator instead of reseeding the module-level
        # one, so the caller's global random state is left untouched.
        # Seeding a Random instance with the same value produces the
        # same first randint as random.seed(...) + random.randint(...).
        rng = random.Random(self._server_num())
        # randint is inclusive on both ends: port is in [60001, 64000]
        self._pserver_rand_port = rng.randint(60001, 64000)
    port = str(self._pserver_rand_port)
    return [host + ":" + port for host in self._server_endpoints]
265+
251266
def worker_num(self):
252267
return self._worker_num()
253268

@@ -273,33 +288,38 @@ def _worker_num(self):
273288
"""
274289
if self._check_role_generation():
275290
if self.is_worker():
276-
return self._get_size() / 2
291+
return self._get_size() / self._proc_per_node
277292
return 0
278293

279294
def _server_num(self):
280295
"""
281296
return the current number of server
282297
"""
283298
if self._check_role_generation():
284-
if self.is_server():
285-
return self._get_size() / 2
286-
return 0
299+
return self._get_size() / self._proc_per_node
300+
else:
301+
self.generate_role()
302+
return self._get_size() / self._proc_per_node
287303

288304
def worker_index(self):
    """
    return the index of worker

    The index is ``self._rank`` divided by ``self._proc_per_node``.
    If the role has not been generated yet, ``generate_role()`` is
    called first and the index is then computed the same way.
    """
    if not self._check_role_generation():
        self.generate_role()
    # The fallback path previously returned self._get_size() / 2, which
    # is a process count rather than an index; both paths now return the
    # index. Floor division keeps it an integer on Python 3.
    return self._rank // self._proc_per_node
295313

296314
def server_index(self):
    """
    return the index of server

    The index is ``self._rank`` divided by ``self._proc_per_node``.
    If the role has not been generated yet, ``generate_role()`` is
    called first and the index is then computed the same way.
    """
    if not self._check_role_generation():
        self.generate_role()
    # The fallback path previously returned
    # self._get_size() / self._proc_per_node, which is a count rather
    # than an index; both paths now return the index. Floor division
    # keeps it an integer on Python 3.
    return self._rank // self._proc_per_node
303323

304324
def _barrier_worker(self):
305325
"""
@@ -308,6 +328,8 @@ def _barrier_worker(self):
308328
if self._check_role_generation():
309329
if self.is_worker():
310330
self._node_type_comm.barrier()
331+
else:
332+
raise Exception("You should check role generation first")
311333

312334
def _barrier_server(self):
313335
"""
@@ -316,6 +338,8 @@ def _barrier_server(self):
316338
if self._check_role_generation():
317339
if self.is_server():
318340
self._node_type_comm.barrier()
341+
else:
342+
raise Exception("You should check role generation first")
319343

320344
def generate_role(self):
321345
"""
@@ -332,6 +356,8 @@ def generate_role(self):
332356
self._node_type = 1
333357
self._node_type_comm = self._comm.Split(self._node_type)
334358
self._role_is_generated = True
359+
else:
360+
raise Exception("You should check role generation first")
335361

336362

337363
class PaddleCloudRoleMaker(RoleMakerBase):

0 commit comments

Comments
 (0)