 from __future__ import annotations
 
 import copy
+import os
 import warnings
 from typing import (
     TYPE_CHECKING,
@@ -655,6 +656,24 @@ def amp_guard(
         and not amp_global_state().already_register_final_backward_hook
     ):
 
+        def _dtensor_from_local(local_tensor, mesh, placements):
+            global_dims = list(local_tensor.shape)
+            for idx, placement in enumerate(placements):
+                if placement.is_shard():
+                    global_dims[placement.get_dim()] = (
+                        global_dims[placement.get_dim()] * mesh.shape[idx]
+                    )
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
+
+            return paddle.Tensor(
+                local_tensor,
+                dims=global_dims,
+                process_mesh=mesh,
+                placements=placements,
+                place=place,
+            )
+
         def master_grad_hook():
             # NOTE(lizhiyu): To support semi-auto of dygraph mode, we must
             # classify the params of model into different classes according to their process_mesh.
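Note on the hunk above: _dtensor_from_local recovers a distributed tensor's global shape from a local shard by walking the placements and, for every mesh axis that shards a tensor dimension, multiplying that dimension's local size by the mesh size along that axis; the local tensor is then rewrapped with the global dims, process mesh, and placements. A minimal sketch of just the shape arithmetic, using hypothetical FakeShard/FakeReplicate stand-ins rather than Paddle's real placement types:

from dataclasses import dataclass


@dataclass
class FakeShard:
    # Hypothetical stand-in for a Shard placement on one tensor dim.
    dim: int

    def is_shard(self):
        return True

    def get_dim(self):
        return self.dim


class FakeReplicate:
    # Hypothetical stand-in for a Replicate placement.
    def is_shard(self):
        return False


def global_dims_from_local(local_shape, mesh_shape, placements):
    # Same loop as in _dtensor_from_local: each mesh axis that shards
    # tensor dim d scales the global size of dim d by that axis' size.
    dims = list(local_shape)
    for axis, placement in enumerate(placements):
        if placement.is_shard():
            dims[placement.get_dim()] *= mesh_shape[axis]
    return dims


# 4 x 2 mesh, dim 0 sharded over mesh axis 0, replicated along axis 1:
print(global_dims_from_local((8, 16), (4, 2), [FakeShard(0), FakeReplicate()]))
# -> [32, 16]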
@@ -674,17 +693,52 @@ def master_grad_hook():
                                 param.process_mesh
                             ].append(param)
                 amp_global_state().already_classify_params_meshes = True
-
-            if len(amp_global_state().mesh2params):
-                for _, params in amp_global_state().mesh2params.items():
-                    core.eager.set_master_grads(params)
-            else:
-                core.eager.set_master_grads(
-                    amp_global_state().model_parameters
-                )
+            if os.getenv("FLAGS_enable_tensor_fusion") not in [
+                "True",
+                "true",
+                "1",
+            ]:
+                if len(amp_global_state().mesh2params):
+                    for _, params in amp_global_state().mesh2params.items():
+                        core.eager.set_master_grads(params)
+                else:
+                    core.eager.set_master_grads(
+                        amp_global_state().model_parameters
+                    )
 
             amp_global_state().already_register_final_backward_hook = False
 
+            def _update_main_grad_hook(param):
+                @paddle.autograd.no_grad()
+                def param_hook(tmp_grad):
+                    if tmp_grad is not None and tmp_grad._is_initialized():
+                        if param.main_grad is None:
+                            tmp = core.eager.Tensor(
+                                value=tmp_grad._local_value()
+                                .cast(paddle.float32)
+                                .value(),
+                                place=tmp_grad.place,
+                                name="main_grad@" + param.name,
+                            )
+                            param.main_grad = _dtensor_from_local(
+                                tmp,
+                                tmp_grad.process_mesh,
+                                tmp_grad.placements,
+                            )
+                        else:
+                            param.main_grad._local_value().add_(
+                                tmp_grad._local_value()
+                            )
+                        tmp_grad._clear_data()
+
+                return param_hook
+
+            if os.getenv("FLAGS_enable_tensor_fusion") in ["True", "true", "1"]:
+                for param in amp_global_state().model_parameters:
+                    if not hasattr(param, "main_grad"):
+                        param.main_grad = None
+                    param._register_grad_hook(_update_main_grad_hook(param))
+
         core.eager._add_backward_final_hook(master_grad_hook)
         amp_global_state().already_register_final_backward_hook = True
 
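Taken together, this hunk turns FLAGS_enable_tensor_fusion into a switch between two master-grad strategies: when the flag is unset, the original core.eager.set_master_grads path still runs; when it is set (for example launching with FLAGS_enable_tensor_fusion=1), every model parameter instead gets a grad hook that lazily creates an FP32 main_grad from the first low-precision gradient, accumulates later gradients into it in place, and clears the temporary gradient each step. A rough, framework-free sketch of that accumulation pattern, using NumPy arrays as stand-ins for Paddle tensors (FakeParam and make_main_grad_hook are illustrative names, not Paddle API):

import numpy as np


class FakeParam:
    # Hypothetical parameter object carrying an FP32 main_grad.
    def __init__(self):
        self.main_grad = None


def make_main_grad_hook(param):
    # Mirrors the intent of _update_main_grad_hook/param_hook: the first
    # gradient is cast to float32, later gradients accumulate in place.
    def hook(tmp_grad):
        if tmp_grad is not None:
            if param.main_grad is None:
                param.main_grad = tmp_grad.astype(np.float32)
            else:
                param.main_grad += tmp_grad
        # The low-precision grad itself is dropped afterwards (cf. _clear_data()).

    return hook


p = FakeParam()
hook = make_main_grad_hook(p)
hook(np.full(4, 0.5, dtype=np.float16))
hook(np.full(4, 0.25, dtype=np.float16))
print(p.main_grad.dtype, p.main_grad)  # float32 [0.75 0.75 0.75 0.75]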