@@ -116,6 +116,20 @@ def __init__(self, optimizer, hcg):
         self._rank2params = self._partition_parameters()
         self._param2rank = self._map_param_to_rank()
 
+        self._broadcast_overlap = False
+        self._forward_pre_hook_remove_helper = []
+
+        try:
+            # The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
+            # The params have to be sorted so that they are broadcast in the order they are used in the forward pass.
+            self._broadcast_order_params = sorted(
+                self._parameter_list,
+                key=lambda x: int(x.name.split('.')[0].split('_')[-1]),
+            )
+
+        except ValueError:
+            self._broadcast_order_params = None
+
         if not self.tensor_fusion and not self.comm_overlap:
             local_params = self._rank2params[self._sharding_rank]
             self._set_inner_opt_attr('_parameter_list', local_params)
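A quick, hedged illustration of the sort key added above: it takes the part of the parameter name before the first '.', then parses the integer after the last '_', so parameters named in the usual `<layer>_<index>.<suffix>` style end up ordered by the index of the layer that created them. The names below are invented for illustration; any name without a trailing integer index triggers the `ValueError` fallback.

    # Invented parameter names in the `<layer>_<index>.<suffix>` style the key assumes.
    names = ["linear_12.w_0", "linear_3.b_0", "layer_norm_0.w_0", "embedding_1.w_0"]

    def sort_key(name):
        # Mirrors the lambda above, applied directly to the name string.
        return int(name.split('.')[0].split('_')[-1])

    print(sorted(names, key=sort_key))
    # -> ['layer_norm_0.w_0', 'embedding_1.w_0', 'linear_3.b_0', 'linear_12.w_0']

    try:
        sort_key("custom_weight.w_0")  # no trailing integer before the '.'
    except ValueError:
        print("falls back to _broadcast_order_params = None")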
@@ -318,6 +332,13 @@ def reduce_gradients(self, parameter_list, hcg):
                         sync_op=True,
                     )
 
+    def _forward_pre_hook_function(self, tasks):
+        def __impl__(x, y):
+            for task in tasks:
+                task.wait()
+
+        return __impl__
+
     def _sharding_sync_parameters(self):
         """
         Synchronize parameter across sharding group efficiently.
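The `__impl__(x, y)` closure above matches the signature Paddle passes to forward pre-hooks: the hook is called with the layer and its input just before `forward` runs, and returning None leaves the input untouched. A minimal, self-contained sketch of that mechanism (the toy layer and shapes are arbitrary, not taken from this PR):

    import paddle

    linear = paddle.nn.Linear(4, 4)

    def pre_hook(layer, inputs):
        # Runs just before layer.forward(*inputs); returning None keeps the input as-is.
        print("about to run", type(layer).__name__)

    remove_helper = linear.register_forward_pre_hook(pre_hook)
    linear(paddle.randn([2, 4]))  # prints "about to run Linear", then executes forward
    remove_helper.remove()        # the same removal pattern _forward_pre_hook_remove_helper relies on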
@@ -334,27 +355,57 @@ def _sharding_sync_parameters(self):
         sharding_group_ranks = self._hcg.get_sharding_parallel_group().ranks
 
         broadcast_tasks = []
-        for rank, params in valid_rank_to_params.items():
-            # Compute the global source rank only once per each rank's set of parameters
-            src_rank = sharding_group_ranks[rank]
-
-            for param in params:
-                # NOTE: We should check if the parameter is trainable, because some parameters
-                # (e.g., freeze the parameters for training) are not trainable and should
-                # not be broadcasted.
-                g_var = self._get_param_grad(param)
-                if g_var is not None:
+        if self._broadcast_overlap:
+            param2task = {}
+
+            group = self._hcg.get_sharding_parallel_group()
+            for param in self._broadcast_order_params:
+                if param.trainable:
                     task = paddle.distributed.broadcast(
-                        param,
-                        src=src_rank,
-                        group=self._hcg.get_sharding_parallel_group(),
+                        tensor=param,
+                        src=group.ranks[self._param2rank[param.name]],
+                        group=group,
                         sync_op=False,
                     )
-                    broadcast_tasks.append(task)
+                    assert param.name not in param2task
+                    param2task[param.name] = task
+
+            for layer in self._layers.sublayers():
+                if len(layer.sublayers()) == 0:
+                    # Register a forward pre-hook for leaf layers. This gives the best performance.
+                    tasks = []
+                    for param in layer.parameters():
+                        if param.trainable:
+                            if param.name in param2task:
+                                tasks.append(param2task[param.name])
+                    self._forward_pre_hook_remove_helper.append(
+                        layer.register_forward_pre_hook(
+                            self._forward_pre_hook_function(tasks)
+                        )
+                    )
 
-        # Wait for all async broadcast tasks to complete
-        for task in broadcast_tasks:
-            task.wait()
+        else:
+            for rank, params in valid_rank_to_params.items():
+                # Compute the global source rank only once per each rank's set of parameters
+                src_rank = sharding_group_ranks[rank]
+
+                for param in params:
+                    # NOTE: We should check if the parameter is trainable, because some parameters
+                    # (e.g., freeze the parameters for training) are not trainable and should
+                    # not be broadcasted.
+                    g_var = self._get_param_grad(param)
+                    if g_var is not None:
+                        task = paddle.distributed.broadcast(
+                            param,
+                            src=src_rank,
+                            group=self._hcg.get_sharding_parallel_group(),
+                            sync_op=False,
+                        )
+                        broadcast_tasks.append(task)
+
+            # Wait for all async broadcast tasks to complete
+            for task in broadcast_tasks:
+                task.wait()
 
     def _update_trainable(self):
         """
@@ -384,10 +435,33 @@ def minimize(
 
         return result
 
+    def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
+        self._broadcast_overlap = broadcast_overlap
+        if self._broadcast_overlap:
+            assert (
+                layers is not None
+            ), "To enable stage1 optimizer broadcast overlap forward, layers cannot be None"
+            self._layers = layers
+            warnings.warn(
+                r"Setting overlap broadcast implies that `paddle.device.cuda.synchronize()` must be manually invoked before calling `paddle.save()` and prior to inference"
+            )
+
+            if self._broadcast_order_params is None:
+                warnings.warn(
+                    r"The param name passed to the optimizer doesn't follow the .+_[0-9]+\..+ pattern, "
+                    "so overlap broadcast may harm the performance."
+                )
+                self._broadcast_order_params = self._parameter_list
+
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
         # TODO Check whether the model trainable param changed and update state accordingly
+        if self._broadcast_overlap:
+            # Clear the forward pre-hooks in the optimizer step.
+            for hook_remove in self._forward_pre_hook_remove_helper:
+                hook_remove.remove()
+            self._forward_pre_hook_remove_helper = []
 
         target_param_list = (
             self._origin_parameter_list
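The `_set_broadcast_overlap` hunk above is the switch for this feature: in practice one would call something like `opt._set_broadcast_overlap(True, layers=model)` on the sharding optimizer after it is built (`opt` and `model` are placeholders here), and, per the first warning, invoke `paddle.device.cuda.synchronize()` before `paddle.save()` or inference. The second warning refers to the parameter-name pattern that keeps the sorted broadcast order; a runnable check of that pattern:

    import re

    # The naming pattern quoted in the warning above; when the optimizer cannot derive
    # this ordering it falls back to plain self._parameter_list order for the broadcasts.
    pattern = re.compile(r".+_[0-9]+\..+")

    print(bool(pattern.match("linear_12.w_0")))    # True  -> sorted broadcast order applies
    print(bool(pattern.match("my_custom_param")))  # False -> fallback order, possible slowdown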