
Commit 5b357e0

Authored by haohongxiang, ForFishes, and zhaoyinglia
[cherry-pick] Support FP16 in HybridParallel and fix bugs in HybridOptimizer (#36707)
* fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer (#36237)
* fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer
* update
* update
* fix bugs in mp_layers, pp_layers and HybridParallelClipGrad (#36144)
* fix calling bug of HybridParallelClipGrad
* fix bugs of HybridParallelClipGrad
* add unittest of pp with HybridParallelClipGrad
* fix bugs in mp_layers.py
* update
* fix bugs in pp_layers.py
* update
* [HybridParallel] Rebuild code for pipeline (#36396)
* add no_sync for parameters sync
* add pipeline for moe
* [HybridParallel] Support fp16 in dygraph hybrid parallel (#36420)
* [HybridParallel] Support fp16 in dygraph hybrid parallel
* update
* update
* update for recompute
* add unittest of pp+fp16
* add unittest of recompute+fp16
* update
* modify ut
* modify ut of cond (#36475)
* fix bugs of ClipGradByGlobalNorm in HybridParallel (#36555)
* fix bugs of ClipGradByGlobalNorm
* add unittests
* add unittests
* [HybridParallel] fix bug of check_inf in fleet_base.py (#36651)
* fix bug of check_inf
* fix allreduce
* support ClipGradByGlobalNorm in sharding (#36012)
* support ClipGradByGlobalNorm in sharding
* support ClipGradByGlobalNorm in sharding
* test=allcase
* Update test_linalg_cond.py
* Update hybrid_parallel_util.py
* Update hybrid_parallel_util.py

Co-authored-by: ShenLiang <1422485404@qq.com>
Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
1 parent 77034fc · commit 5b357e0

19 files changed: +533 −94 lines
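For context, this is roughly the dygraph workflow the commit targets: FP16 (AMP) training under hybrid parallelism with ClipGradByGlobalNorm. The sketch below is illustrative only; MyModel, the loader, and the hybrid_configs degrees are placeholders, and the fleet.distributed_scaler wrapper is an assumption inferred from the unscale_method patch in fleet_base.py further down.

```python
# Hypothetical end-to-end sketch of FP16 hybrid-parallel training in dygraph.
# MyModel and loader are user-defined placeholders; the hybrid_configs degrees
# and fleet.distributed_scaler are assumptions, not taken from this diff.
import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

model = MyModel()  # hypothetical paddle.nn.Layer
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4, parameters=model.parameters(), grad_clip=clip)
scaler = paddle.amp.GradScaler(init_loss_scaling=2.0**15)

# Wrap model, optimizer, and scaler so the hybrid-parallel optimizer,
# clip, and unscale logic patched in this commit take effect.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
scaler = fleet.distributed_scaler(scaler)

for batch in loader:  # hypothetical data loader
    with paddle.amp.auto_cast():
        loss = model(batch)
    scaled = scaler.scale(loss)
    scaled.backward()
    scaler.minimize(optimizer, scaled)
    optimizer.clear_grad()
```

With this wiring, the patched scaler splits FP16 and FP32 gradients during unscaling, and the hybrid clip computes one consistent global norm across model-, pipeline-, and sharding-parallel groups.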

python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 34 additions & 6 deletions
```diff
@@ -35,6 +35,8 @@
 from ..meta_parallel import PipelineParallel, ShardingParallel
 from ..meta_optimizers import HybridParallelOptimizer
 from paddle import _C_ops
+from paddle.fluid import core
+from paddle.fluid.dygraph import to_variable
 
 __all__ = []
 
@@ -1547,26 +1549,52 @@ def unscale_method(self, optimizer):
         if getattr(optimizer, '_param_groups', None) and isinstance(
                 optimizer._param_groups[0], dict):
             param_grads = []
+            param_grads_fp16 = []
+            param_grads_fp32 = []
             for group in optimizer._param_groups:
                 for param in group['params']:
                     if param._grad_ivar() is not None:
                         param_grads.append(param._grad_ivar())
+                        if param._grad_ivar(
+                        ).dtype == core.VarDesc.VarType.FP16:
+                            param_grads_fp16.append(param._grad_ivar())
+                        else:
+                            param_grads_fp32.append(param._grad_ivar())
         else:
             param_grads = [
                 param._grad_ivar() for param in optimizer._parameter_list
                 if param._grad_ivar() is not None
             ]
-        _C_ops.check_finite_and_unscale(param_grads, self._scale,
-                                        param_grads, self._found_inf)
-
-        self._found_inf = paddle.cast(self._found_inf, dtype="int32")
+            param_grads_fp16 = [
+                param._grad_ivar() for param in optimizer._parameter_list
+                if (param._grad_ivar() is not None) and (param._grad_ivar(
+                ).dtype == core.VarDesc.VarType.FP16)
+            ]
+            param_grads_fp32 = [
+                param._grad_ivar() for param in optimizer._parameter_list
+                if (param._grad_ivar() is not None) and (param._grad_ivar(
+                ).dtype == core.VarDesc.VarType.FP32)
+            ]
+        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
+        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
+        if len(param_grads_fp16):
+            _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
+                                            param_grads_fp16,
+                                            temp_found_inf_fp16)
+        if len(param_grads_fp32):
+            _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
+                                            param_grads_fp32,
+                                            temp_found_inf_fp32)
+
+        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
+        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
 
         # TODO(shenliang03) Since dp allreduce in the optimizer is
         # after the gradscaler, check_finite needs to synchronize global
         # information. In the future, we should use check_group to speed.
         paddle.distributed.all_reduce(
-            self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
-        self._found_inf = paddle.cast(self._found_inf, dtype="bool")
+            is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
+        self._found_inf = is_found_inf.numpy()[0]
 
         # Only tensor_parallel and pipeline_parallel need to modify scaler
         if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
```
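The patched unscale_method splits gradients into FP16 and FP32 groups before calling check_finite_and_unscale, then turns the per-rank found-inf flags into a global decision with a MAX all-reduce. A minimal, framework-free NumPy sketch of that logic (illustrative only, not Paddle code):

```python
import numpy as np

def unscale_and_check(grads, scale):
    """Unscale one homogeneous-dtype group of gradients in place and report
    whether any value is non-finite (the role check_finite_and_unscale plays
    for each dtype group in the patched unscale_method)."""
    found_inf = False
    for g in grads:
        g /= scale                      # in-place, keeps the original dtype
        if not np.isfinite(g).all():
            found_inf = True
    return found_inf

# Split gradients by dtype first, as unscale_method now does, so each call
# operates on a homogeneous list.
grads = [np.ones(3, np.float16), np.array([np.inf], np.float32)]
fp16_grads = [g for g in grads if g.dtype == np.float16]
fp32_grads = [g for g in grads if g.dtype == np.float32]
found_inf = unscale_and_check(fp16_grads, 2.0) or unscale_and_check(fp32_grads, 2.0)

# Each rank contributes a local flag; a MAX all-reduce across ranks
# (simulated here with plain max) gives every rank the same decision.
flags_on_ranks = [int(found_inf), 0, 0]   # e.g. three data-parallel ranks
print(max(flags_on_ranks))                # 1 -> skip this optimizer step
```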

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 from .hybrid_parallel_optimizer import HybridParallelOptimizer
 from .hybrid_parallel_gradscaler import HybridParallelGradScaler
+from .dygraph_sharding_optimizer import DygraphShardingOptimizer
 
 __all__ = []
```

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 98 additions & 22 deletions
```diff
@@ -50,7 +50,12 @@ def __init__(self, clip, hcg):
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
-        sum_square_list = []
+
+        sum_square_dist_fp16 = []
+        sum_square_dist_fp32 = []
+        sum_square_not_dist_fp16 = []
+        sum_square_not_dist_fp32 = []
+
         for p, g in params_grads:
             if g is None:
                 continue
@@ -62,32 +67,98 @@ def _dygraph_clip(self, params_grads):
                 merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
             square = layers.square(merge_grad)
             sum_square = layers.reduce_sum(square)
-            sum_square_list.append(sum_square)
 
-        # all parameters have been filterd out
-        if len(sum_square_list) == 0:
-            return params_grads
-
-        global_norm_var = layers.concat(sum_square_list)
-        global_norm_var = layers.reduce_sum(global_norm_var)
-        # add all reduce to get global norm in world size
-        paddle.distributed.all_reduce(global_norm_var,
-                                      self._hcg.get_check_parallel_group())
-        global_norm_var = layers.sqrt(global_norm_var)
+            not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
+                hasattr(p, 'is_firstly_shared') and
+                getattr(p, 'is_firstly_shared', True))
+
+            if not_shared_enable:
+                if p.is_distributed:
+                    if p.dtype == paddle.float16:
+                        sum_square_dist_fp16.append(sum_square)
+                    elif p.dtype == paddle.float32:
+                        sum_square_dist_fp32.append(sum_square)
+                else:
+                    if p.dtype == paddle.float16:
+                        sum_square_not_dist_fp16.append(sum_square)
+                    elif p.dtype == paddle.float32:
+                        sum_square_not_dist_fp32.append(sum_square)
+
+        # global norm of distributed FP16 params_and_grads
+        if len(sum_square_dist_fp16) == 0:
+            global_norm_dist_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
+        else:
+            global_norm_dist_fp16 = layers.concat(sum_square_dist_fp16)
+            global_norm_dist_fp16 = layers.reduce_sum(global_norm_dist_fp16)
+            global_norm_dist_fp16 = paddle.cast(
+                global_norm_dist_fp16, dtype=paddle.float32)
+
+        # global norm of non-distributed FP16 params_and_grads
+        if len(sum_square_not_dist_fp16) == 0:
+            global_norm_not_dist_fp16 = paddle.to_tensor(
+                [0.], dtype=paddle.float32)
+        else:
+            global_norm_not_dist_fp16 = layers.concat(sum_square_not_dist_fp16)
+            global_norm_not_dist_fp16 = layers.reduce_sum(
+                global_norm_not_dist_fp16)
+            global_norm_not_dist_fp16 = paddle.cast(
+                global_norm_not_dist_fp16, dtype=paddle.float32)
+
+        # global norm of distributed FP32 params_and_grads
+        global_norm_dist_fp32 = layers.concat(sum_square_dist_fp32) if len(
+            sum_square_dist_fp32) != 0 else paddle.to_tensor(
+                [0.], dtype=paddle.float32)
+        global_norm_dist_fp32 = layers.reduce_sum(global_norm_dist_fp32)
+
+        # global norm of non-distributed FP32 params_and_grads
+        global_norm_not_dist_fp32 = layers.concat(
+            sum_square_not_dist_fp32) if len(
+                sum_square_not_dist_fp32) != 0 else paddle.to_tensor(
+                    [0.], dtype=paddle.float32)
+        global_norm_not_dist_fp32 = layers.reduce_sum(global_norm_not_dist_fp32)
+
+        global_norm_var_dist = global_norm_dist_fp16 + global_norm_dist_fp32
+        global_norm_var_not_dist = global_norm_not_dist_fp16 + global_norm_not_dist_fp32
+
+        # add all reduce to get global norm of distributed params_and_grads
+        if self._hcg.get_model_parallel_world_size() > 1:
+            paddle.distributed.all_reduce(
+                global_norm_var_dist,
+                group=self._hcg.get_check_parallel_group())
+
+        # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
+        if self._hcg.get_pipe_parallel_world_size() > 1:
+            paddle.distributed.all_reduce(
+                global_norm_var_not_dist,
+                group=self._hcg.get_pipe_parallel_group())
+
+        # In Sharding mode, param and grad is mapping different rank in optimizer.
+        # ClipGradByGlobalNorm need allreduce to get globol norm
+        if self._hcg.get_sharding_parallel_world_size() > 1:
+            paddle.distributed.all_reduce(
+                global_norm_var_not_dist,
+                group=self._hcg.get_sharding_parallel_group())
+
+        global_norm_var_fp32 = layers.sqrt(global_norm_var_dist +
+                                           global_norm_var_not_dist)
 
         max_global_norm = layers.fill_constant(
-            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
+            shape=[1], dtype=global_norm_var_fp32.dtype, value=self.clip_norm)
         clip_var = layers.elementwise_div(
             x=max_global_norm,
             y=layers.elementwise_max(
-                x=global_norm_var, y=max_global_norm))
+                x=global_norm_var_fp32, y=max_global_norm))
+        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
         for p, g in params_grads:
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
                 params_and_grads.append((p, g))
                 continue
-            new_grad = layers.elementwise_mul(x=g, y=clip_var)
+            if p.dtype == paddle.float16:
+                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
+            else:
+                new_grad = layers.elementwise_mul(x=g, y=clip_var)
             params_and_grads.append((p, new_grad))
 
         return params_and_grads
@@ -96,7 +167,7 @@ def __getattr__(self, item):
         return getattr(self._clip, item)
 
     def __call__(self, params_grads):
-        return self._clip(params_grads)
+        return self._dygraph_clip(params_grads)
 
 
 class HybridParallelOptimizer:
@@ -112,19 +183,24 @@ def __init__(self, optimizer, hcg, strategy):
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
 
         # NOTE(shenliang03): Because of the pure DataParallel mode, the gradient synchronization
-        # is achieved through reducer, so there is no need to call fuse_allreduce in oprimizer.
+        # is achieved through reducer, so there is no need to call fuse_allreduce in optimizer.
         self._dp_enable = not self._use_dp_mode and self._need_dp
 
         self._sharding_enable = (
             self._hcg.get_sharding_parallel_world_size() > 1)
 
         if isinstance(self._inner_opt._grad_clip,
                       ClipGradByGlobalNorm) and not self._use_dp_mode:
-            logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
-                "optmizer'grad clip will be changed.")
-
-            self._inner_opt._grad_clip = HybridParallelClipGrad(
-                self._inner_opt._grad_clip, hcg)
+            logger.warning("While using ClipGradByGlobalNorm in TensorParallel, PipelineParallel " \
+                "or Sharding, the grad clip of original optimizer will be changed.")
+
+            if self._sharding_enable:
+                # change sharding inner_optimizer's _grad_clip
+                self._inner_opt._inner_optimizer._grad_clip = HybridParallelClipGrad(
+                    self._inner_opt._grad_clip, hcg)
+            else:
+                self._inner_opt._grad_clip = HybridParallelClipGrad(
+                    self._inner_opt._grad_clip, hcg)
 
     @imperative_base.no_grad
     @framework.dygraph_only
```
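The reworked _dygraph_clip accumulates squared gradient norms in four buckets (distributed vs. non-distributed, FP16 vs. FP32), all-reduces the two partial sums over the appropriate communication groups, and only then takes the square root and builds the clip coefficient. A framework-free sketch of the final arithmetic, assuming the per-group all-reduces have already produced the two partial sums (illustrative NumPy, not the Paddle implementation):

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm, sum_sq_dist, sum_sq_not_dist):
    """Sketch of the clipping arithmetic used by HybridParallelClipGrad.

    sum_sq_dist / sum_sq_not_dist stand for the two partial sums of squared
    gradient norms *after* their respective all-reduces (model-parallel group
    for distributed params, pipeline/sharding groups for the rest).
    """
    global_norm = np.sqrt(sum_sq_dist + sum_sq_not_dist)
    # Same formula as the diff: clip_var = clip_norm / max(global_norm, clip_norm),
    # so gradients are left untouched when the norm is already small enough.
    clip_coef = clip_norm / max(global_norm, clip_norm)
    # FP16 gradients are scaled with an FP16 copy of the coefficient and FP32
    # gradients with the FP32 one, mirroring clip_var_fp16 above.
    return [g * np.asarray(clip_coef, dtype=g.dtype) for g in grads]

grads = [np.full(4, 3.0, np.float32), np.full(4, 3.0, np.float16)]
sum_sq = sum(float(np.sum(g.astype(np.float32) ** 2)) for g in grads)
clipped = clip_by_global_norm(grads, clip_norm=1.0,
                              sum_sq_dist=0.0, sum_sq_not_dist=sum_sq)
print(np.sqrt(sum(np.sum(g.astype(np.float32) ** 2) for g in clipped)))  # ~1.0
```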

python/paddle/distributed/fleet/meta_parallel/parallel_layers/mp_layers.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -70,7 +70,7 @@ def __init__(self,
             dtype=self._dtype,
             is_bias=False)
 
-        self.weight.is_distributed = True
+        self.weight.is_distributed = True if self.is_mp else False
 
     def forward(self, x):
         if self.is_mp:
@@ -135,7 +135,7 @@ def __init__(self,
             dtype=self._dtype,
             is_bias=False)
 
-        self.weight.is_distributed = True
+        self.weight.is_distributed = True if self.is_mp else False
 
         if has_bias:
             # initialize bias to zero like Megatron
@@ -144,7 +144,7 @@ def __init__(self,
                 attr=paddle.nn.initializer.Constant(value=0.0),
                 dtype=self._dtype,
                 is_bias=True)
-            self.bias.is_distributed = True
+            self.bias.is_distributed = True if self.is_mp else False
         else:
             self.bias = None
 
@@ -212,7 +212,7 @@ def __init__(self,
             dtype=self._dtype,
             is_bias=False)
 
-        self.weight.is_distributed = True
+        self.weight.is_distributed = True if self.is_mp else False
 
         if has_bias:
             self.bias = self.create_parameter(
```
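The reason for guarding is_distributed with is_mp: when model parallelism is off, every rank holds the full parameter, so marking it distributed would let the model-parallel all-reduce in HybridParallelClipGrad count its squared norm once per rank. A toy illustration of that double counting (illustrative only, not Paddle code):

```python
import numpy as np

# Two "ranks", each holding the *same* full gradient because mp_degree == 1.
full_weight_grad = np.ones(8, np.float32)
ranks = [full_weight_grad, full_weight_grad]

# If the parameter were wrongly flagged is_distributed, each rank would add
# its local sum of squares and the model-parallel all-reduce would then sum
# them again, double-counting the norm:
wrong = sum(np.sum(g ** 2) for g in ranks)   # 16.0 after the "all-reduce"
right = np.sum(full_weight_grad ** 2)        # 8.0, counted exactly once
print(wrong, right)
```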

python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -261,6 +261,10 @@ def _synchronize_shared_weights(self):
                     src=min(comm['ranks']),
                     group=comm['group'])
 
+            for param in comm['layer'].parameters():
+                if self.global_rank != min(comm['ranks']):
+                    setattr(param, 'is_firstly_shared', False)
+
     def allreduce_shared_weight_gradients(self):
         for key, comm in self.shared_comm.items():
             param = getattr(self.shared_layers[key], comm['weight_attr'])
@@ -316,6 +320,9 @@ def _build_layer(self):
                 self.shared_layers[layer.layer_name] = layer.build_layer()
                 self.shared_weight_attrs[
                     layer.layer_name] = layer.shared_weight_attr
+                for param in self.shared_layers[
+                        layer.layer_name].parameters():
+                    setattr(param, "is_firstly_shared", True)
 
             if layer.forward_func is None:
                 self.run_function.append(self.shared_layers[
```
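The new is_firstly_shared flag lets HybridParallelClipGrad count a weight shared across pipeline stages (for example, tied embeddings) only on the stage that owns it first; otherwise the pipeline-group all-reduce would add its norm once per stage. A toy sketch of the flagging rule (illustrative only):

```python
# Toy illustration: a weight shared by pipeline stages 0 and 3.
# Only the lowest-ranked owner keeps is_firstly_shared=True, so the shared
# gradient enters the (pipeline-group all-reduced) norm exactly once.
shared_ranks = [0, 3]
global_rank = 3
is_firstly_shared = (global_rank == min(shared_ranks))
print(is_firstly_shared)  # False -> this stage skips the weight when
                          # accumulating sum_square_not_dist_* above
```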
