@@ -655,6 +655,33 @@ def amp_guard(
         and not amp_global_state().already_register_final_backward_hook
     ):
 
+        def _dtensor_from_local(
+            local_tensor, mesh, placements, local_tensor_shape=None
+        ):
+            # assume each rank has the same tensor shape for now; just use the local shape to calculate the global shape
+            global_dims = list(local_tensor.shape)
+            if local_tensor_shape is not None:
+                global_dims = local_tensor_shape
+            for idx, placement in enumerate(placements):
+                if placement.is_shard():
+                    shard_dim = placement.get_dim()
+                    local_dim_size = global_dims[shard_dim]
+                    global_dims[shard_dim] = (
+                        local_dim_size * mesh.shape[idx]
+                    )
+
+            if paddle.in_dynamic_mode():
+                place = paddle.framework._current_expected_place()
+                place = paddle.framework._get_paddle_place(place)
+
+            return paddle.Tensor(
+                local_tensor,
+                dims=global_dims,
+                process_mesh=mesh,
+                placements=placements,
+                place=place,
+            )
+
         def master_grad_hook():
             # NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
             # classify the params of model into different classes according to their process_mesh.
@@ -677,7 +704,27 @@ def master_grad_hook():
 
             if len(amp_global_state().mesh2params):
                 for _, params in amp_global_state().mesh2params.items():
-                    core.eager.set_master_grads(params)
+                    for param in params:
+                        tmp_grad = param._grad_ivar()
+                        if param.main_grad is None:
+                            tmp = core.eager.Tensor(
+                                value=tmp_grad._local_value()
+                                .cast(paddle.float32)
+                                .value(),
+                                place=tmp_grad.place,
+                                name="main_grad@" + param.name,
+                            )
+                            param.main_grad = _dtensor_from_local(
+                                tmp,
+                                tmp_grad.process_mesh,
+                                tmp_grad.placements,
+                            )
+                        else:
+                            param.main_grad._local_value().add_(
+                                tmp_grad._local_value()
+                            )
+
+                    # core.eager.set_master_grads(params)
             else:
                 core.eager.set_master_grads(
                     amp_global_state().model_parameters
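
For reference, the global-shape arithmetic in _dtensor_from_local can be checked in isolation: every dim that is sharded over a mesh axis is multiplied by that axis's size, assuming each rank holds an equally sized shard. The sketch below reproduces only that arithmetic with hypothetical FakeMesh/FakeShard/FakeReplicate stand-ins (not Paddle APIs), so it runs without a distributed launch.

# Minimal, self-contained sketch of the shape math in _dtensor_from_local.
# FakeMesh/FakeShard/FakeReplicate are hypothetical stand-ins, not Paddle APIs;
# only the arithmetic mirrors the PR.
class FakeMesh:
    def __init__(self, shape):
        self.shape = shape  # e.g. [4] for a 1-D mesh of 4 ranks

class FakeShard:
    def __init__(self, dim):
        self.dim = dim

    def is_shard(self):
        return True

    def get_dim(self):
        return self.dim

class FakeReplicate:
    def is_shard(self):
        return False

def global_shape_from_local(local_shape, mesh, placements):
    # A dim sharded over mesh axis idx is scaled by mesh.shape[idx],
    # assuming every rank holds an equally sized shard.
    global_dims = list(local_shape)
    for idx, placement in enumerate(placements):
        if placement.is_shard():
            shard_dim = placement.get_dim()
            global_dims[shard_dim] *= mesh.shape[idx]
    return global_dims

if __name__ == "__main__":
    mesh = FakeMesh([4])
    print(global_shape_from_local([2, 8], mesh, [FakeShard(0)]))      # [8, 8]
    print(global_shape_from_local([2, 8], mesh, [FakeReplicate()]))   # [2, 8]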