@@ -50,7 +50,12 @@ def __init__(self, clip, hcg):
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
-        sum_square_list = []
+
+        sum_square_dist_fp16 = []
+        sum_square_dist_fp32 = []
+        sum_square_not_dist_fp16 = []
+        sum_square_not_dist_fp32 = []
+
         for p, g in params_grads:
             if g is None:
                 continue
@@ -62,32 +67,98 @@ def _dygraph_clip(self, params_grads):
                 merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
             square = layers.square(merge_grad)
             sum_square = layers.reduce_sum(square)
-            sum_square_list.append(sum_square)

-        # all parameters have been filterd out
-        if len(sum_square_list) == 0:
-            return params_grads
-
-        global_norm_var = layers.concat(sum_square_list)
-        global_norm_var = layers.reduce_sum(global_norm_var)
-        # add all reduce to get global norm in world size
-        paddle.distributed.all_reduce(global_norm_var,
-                                      self._hcg.get_check_parallel_group())
-        global_norm_var = layers.sqrt(global_norm_var)
+            not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
+                hasattr(p, 'is_firstly_shared') and
+                getattr(p, 'is_firstly_shared', True))
+
+            if not_shared_enable:
+                if p.is_distributed:
+                    if p.dtype == paddle.float16:
+                        sum_square_dist_fp16.append(sum_square)
+                    elif p.dtype == paddle.float32:
+                        sum_square_dist_fp32.append(sum_square)
+                else:
+                    if p.dtype == paddle.float16:
+                        sum_square_not_dist_fp16.append(sum_square)
+                    elif p.dtype == paddle.float32:
+                        sum_square_not_dist_fp32.append(sum_square)
+
+        # global norm of distributed FP16 params_and_grads
+        if len(sum_square_dist_fp16) == 0:
+            global_norm_dist_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
+        else:
+            global_norm_dist_fp16 = layers.concat(sum_square_dist_fp16)
+            global_norm_dist_fp16 = layers.reduce_sum(global_norm_dist_fp16)
+            global_norm_dist_fp16 = paddle.cast(
+                global_norm_dist_fp16, dtype=paddle.float32)
+
+        # global norm of non-distributed FP16 params_and_grads
+        if len(sum_square_not_dist_fp16) == 0:
+            global_norm_not_dist_fp16 = paddle.to_tensor(
+                [0.], dtype=paddle.float32)
+        else:
+            global_norm_not_dist_fp16 = layers.concat(sum_square_not_dist_fp16)
+            global_norm_not_dist_fp16 = layers.reduce_sum(
+                global_norm_not_dist_fp16)
+            global_norm_not_dist_fp16 = paddle.cast(
+                global_norm_not_dist_fp16, dtype=paddle.float32)
+
+        # global norm of distributed FP32 params_and_grads
+        global_norm_dist_fp32 = layers.concat(sum_square_dist_fp32) if len(
+            sum_square_dist_fp32) != 0 else paddle.to_tensor(
+                [0.], dtype=paddle.float32)
+        global_norm_dist_fp32 = layers.reduce_sum(global_norm_dist_fp32)
+
+        # global norm of non-distributed FP32 params_and_grads
+        global_norm_not_dist_fp32 = layers.concat(
+            sum_square_not_dist_fp32) if len(
+                sum_square_not_dist_fp32) != 0 else paddle.to_tensor(
+                    [0.], dtype=paddle.float32)
+        global_norm_not_dist_fp32 = layers.reduce_sum(global_norm_not_dist_fp32)
+
+        global_norm_var_dist = global_norm_dist_fp16 + global_norm_dist_fp32
+        global_norm_var_not_dist = global_norm_not_dist_fp16 + global_norm_not_dist_fp32
+
+        # add all reduce to get global norm of distributed params_and_grads
+        if self._hcg.get_model_parallel_world_size() > 1:
+            paddle.distributed.all_reduce(
+                global_norm_var_dist,
+                group=self._hcg.get_check_parallel_group())
+
+        # add all reduce to get global norm of non-distributed params_and_grads in the pipeline parallel group
+        if self._hcg.get_pipe_parallel_world_size() > 1:
+            paddle.distributed.all_reduce(
+                global_norm_var_not_dist,
+                group=self._hcg.get_pipe_parallel_group())
+
+        # In sharding mode, params and grads are mapped to different ranks in the optimizer,
+        # so ClipGradByGlobalNorm needs an allreduce to get the global norm.
+        if self._hcg.get_sharding_parallel_world_size() > 1:
+            paddle.distributed.all_reduce(
+                global_norm_var_not_dist,
+                group=self._hcg.get_sharding_parallel_group())
+
+        global_norm_var_fp32 = layers.sqrt(global_norm_var_dist +
+                                           global_norm_var_not_dist)

         max_global_norm = layers.fill_constant(
-            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
+            shape=[1], dtype=global_norm_var_fp32.dtype, value=self.clip_norm)
         clip_var = layers.elementwise_div(
             x=max_global_norm,
             y=layers.elementwise_max(
-                x=global_norm_var, y=max_global_norm))
+                x=global_norm_var_fp32, y=max_global_norm))
+        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
         for p, g in params_grads:
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
                 params_and_grads.append((p, g))
                 continue
-            new_grad = layers.elementwise_mul(x=g, y=clip_var)
+            if p.dtype == paddle.float16:
+                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
+            else:
+                new_grad = layers.elementwise_mul(x=g, y=clip_var)
             params_and_grads.append((p, new_grad))

         return params_and_grads
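
The arithmetic behind this hunk reduces to a single scale factor once every bucket's squared sums have been cast to FP32 and allreduced over the relevant groups. Below is a minimal sketch of that relationship in plain Python; `clip_scale` is a hypothetical helper for illustration, not part of the Paddle API or this commit.

```python
import math

def clip_scale(sq_sums_dist, sq_sums_not_dist, clip_norm):
    # Combine the per-bucket squared sums (assumed already allreduced across
    # the distributed / pipeline / sharding groups) into one global norm.
    global_norm = math.sqrt(sum(sq_sums_dist) + sum(sq_sums_not_dist))
    # Same formula as clip_var above: scale = clip_norm / max(global_norm, clip_norm),
    # so gradients are left untouched when the norm is already within the limit.
    return clip_norm / max(global_norm, clip_norm)

# Example: squared sums totalling 25 give a global norm of 5.0, so with
# clip_norm=1.0 every gradient is multiplied by 0.2.
print(clip_scale([9.0, 7.0], [4.0, 5.0], clip_norm=1.0))  # 0.2
```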
@@ -96,7 +167,7 @@ def __getattr__(self, item):
         return getattr(self._clip, item)

     def __call__(self, params_grads):
-        return self._clip(params_grads)
+        return self._dygraph_clip(params_grads)


 class HybridParallelOptimizer:
@@ -112,19 +183,24 @@ def __init__(self, optimizer, hcg, strategy):
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

         # NOTE(shenliang03): Because of the pure DataParallel mode, the gradient synchronization
-        # is achieved through reducer, so there is no need to call fuse_allreduce in oprimizer.
+        # is achieved through reducer, so there is no need to call fuse_allreduce in optimizer.
         self._dp_enable = not self._use_dp_mode and self._need_dp

         self._sharding_enable = (
             self._hcg.get_sharding_parallel_world_size() > 1)

         if isinstance(self._inner_opt._grad_clip,
                       ClipGradByGlobalNorm) and not self._use_dp_mode:
-            logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
-                           "optmizer'grad clip will be changed.")
-
-            self._inner_opt._grad_clip = HybridParallelClipGrad(
-                self._inner_opt._grad_clip, hcg)
+            logger.warning("While using ClipGradByGlobalNorm in TensorParallel, PipelineParallel " \
+                           "or Sharding, the grad clip of the original optimizer will be changed.")
+
+            if self._sharding_enable:
+                # change the sharding inner_optimizer's _grad_clip
+                self._inner_opt._inner_optimizer._grad_clip = HybridParallelClipGrad(
+                    self._inner_opt._grad_clip, hcg)
+            else:
+                self._inner_opt._grad_clip = HybridParallelClipGrad(
+                    self._inner_opt._grad_clip, hcg)

     @imperative_base.no_grad
     @framework.dygraph_only
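
For context on the wiring change in this last hunk: with sharding enabled, the clip that actually runs during the optimizer step belongs to the inner optimizer held by the sharding wrapper, so that is the attribute that must be patched. A minimal sketch of the branching, with `wrap_grad_clip`, `inner_opt`, and `make_hybrid_clip` as illustrative stand-ins (not the Fleet API); `make_hybrid_clip` plays the role of `HybridParallelClipGrad(clip, hcg)`.

```python
def wrap_grad_clip(inner_opt, hcg, sharding_enable, make_hybrid_clip):
    # Build the hybrid clip from the optimizer's existing ClipGradByGlobalNorm.
    hybrid_clip = make_hybrid_clip(inner_opt._grad_clip, hcg)
    if sharding_enable:
        # With sharding, the effective clip lives on the wrapped inner
        # optimizer, so replace it there.
        inner_opt._inner_optimizer._grad_clip = hybrid_clip
    else:
        inner_opt._grad_clip = hybrid_clip
    return inner_opt
```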