@@ -234,6 +234,7 @@ def __init__(self):
234
234
super (MPISymetricRoleMaker , self ).__init__ ()
235
235
self ._node_type = None
236
236
self ._proc_per_node = 2
237
+ self ._pserver_rand_port = 0
237
238
238
239
def _check_role_generation (self ):
239
240
if not self ._role_is_generated :
@@ -248,6 +249,20 @@ def is_first_worker(self):
248
249
return self .is_worker () and 0 == self .worker_index ()
249
250
return False
250
251
252
def get_pserver_endpoints(self):
    """
    return the list of pserver endpoints as "<server endpoint>:<port>" strings

    The port is chosen once, on the first call, and cached in
    self._pserver_rand_port; later calls reuse the cached value.

    Returns:
        list of str: one entry per element of self._server_endpoints
        (presumably host names / IPs -- confirm against generate_role).
    """
    if self._pserver_rand_port <= 0:
        import random
        # The port is drawn from 60001 to 64000 (randint includes both ends).
        # The seed is the server num so that all nodes will get the same
        # port. A private Random instance is used instead of random.seed()
        # so the process-wide generator is not reseeded as a side effect;
        # for the same seed it produces the same value.
        port_rng = random.Random(self._server_num())
        self._pserver_rand_port = port_rng.randint(60001, 64000)
    port_suffix = ":" + str(self._pserver_rand_port)
    return [host + port_suffix for host in self._server_endpoints]
265
+
251
266
def worker_num(self):
    """
    return the number of workers, as reported by self._worker_num()
    """
    count = self._worker_num()
    return count
253
268
@@ -273,33 +288,38 @@ def _worker_num(self):
273
288
"""
274
289
if self ._check_role_generation ():
275
290
if self .is_worker ():
276
- return self ._get_size () / 2
291
+ return self ._get_size () / self . _proc_per_node
277
292
return 0
278
293
279
294
def _server_num (self ):
280
295
"""
281
296
return the current number of server
282
297
"""
283
298
if self ._check_role_generation ():
284
- if self .is_server ():
285
- return self ._get_size () / 2
286
- return 0
299
+ return self ._get_size () / self ._proc_per_node
300
+ else :
301
+ self .generate_role ()
302
+ return self ._get_size () / self ._proc_per_node
287
303
288
304
def worker_index(self):
    """
    return the index of worker

    If the role has not been generated yet, generate_role() is called
    first (assumes generate_role() populates self._rank -- confirm).

    Returns:
        int: self._rank // self._proc_per_node
    """
    if not self._check_role_generation():
        self.generate_role()
    # Both paths must yield the rank-derived index; the fallback used to
    # return the world size divided by a hard-coded 2, which is not an
    # index. Floor division keeps the result an integer on Python 3.
    return self._rank // self._proc_per_node
295
313
296
314
def server_index(self):
    """
    return the index of server

    If the role has not been generated yet, generate_role() is called
    first (assumes generate_role() populates self._rank -- confirm).

    Returns:
        int: self._rank // self._proc_per_node
    """
    if not self._check_role_generation():
        self.generate_role()
    # Both paths must yield the rank-derived index; the fallback used to
    # return size / proc_per_node, which is the server count rather than
    # this process's index. Floor division keeps the result an integer
    # on Python 3.
    return self._rank // self._proc_per_node
303
323
304
324
def _barrier_worker (self ):
305
325
"""
@@ -308,6 +328,8 @@ def _barrier_worker(self):
308
328
if self ._check_role_generation ():
309
329
if self .is_worker ():
310
330
self ._node_type_comm .barrier ()
331
+ else :
332
+ raise Exception ("You should check role generation first" )
311
333
312
334
def _barrier_server (self ):
313
335
"""
@@ -316,6 +338,8 @@ def _barrier_server(self):
316
338
if self ._check_role_generation ():
317
339
if self .is_server ():
318
340
self ._node_type_comm .barrier ()
341
+ else :
342
+ raise Exception ("You should check role generation first" )
319
343
320
344
def generate_role (self ):
321
345
"""
@@ -332,6 +356,8 @@ def generate_role(self):
332
356
self ._node_type = 1
333
357
self ._node_type_comm = self ._comm .Split (self ._node_type )
334
358
self ._role_is_generated = True
359
+ else :
360
+ raise Exception ("You should check role generation first" )
335
361
336
362
337
363
class PaddleCloudRoleMaker (RoleMakerBase ):
0 commit comments