Commit 25bace5

sharding stage 1 check diff lr and use param decay fn (#59537)
1 parent e3caa7c commit 25bace5

3 files changed: 40 additions & 6 deletions

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 20 additions & 4 deletions
@@ -14,6 +14,7 @@

 ######
 import os
+import warnings
 from distutils.util import strtobool
 from functools import reduce

@@ -115,6 +116,23 @@ def __init__(self, optimizer, hcg):
             self._set_inner_opt_attr('_parameter_list', local_params)
             self._set_inner_opt_attr('_param_groups', local_params)
         else:
+            if self.fuse_optimizer:
+                lr = None
+                for param in self._origin_parameter_list:
+                    if hasattr(param, "optimize_attr"):
+                        param_lr = param.optimize_attr['learning_rate']
+                        if lr is None:
+                            lr = param_lr
+                        elif lr != param_lr:
+                            warnings.warn(
+                                "Parameters have different learning rate, "
+                                "won't do fusion on the optimizer."
+                            )
+                            self.fuse_optimizer = False
+                            break
+            self.origin_decay_param_fun = getattr(
+                self._inner_opt, '_apply_decay_param_fun', None
+            )
             self._tensor_fusion()

             decay_params = [
@@ -138,10 +156,7 @@ def __init__(self, optimizer, hcg):
             # Without comm overlap, all grads will be communicated after check_finite,
             # which means each sharding rank should do check_finite to all grads.
             self._local_parameter_list = local_fused_params
-            origin_decay_param_fun = getattr(
-                self._inner_opt, '_apply_decay_param_fun', None
-            )
-            if origin_decay_param_fun is not None:
+            if self.origin_decay_param_fun is not None:
                 self._set_inner_opt_attr(
                     '_apply_decay_param_fun', apply_decay_param_fun
                 )
@@ -191,6 +206,7 @@ def _tensor_fusion(self):
                 dst=dst,
                 acc_step=self.accumulate_steps,
                 scale_after_comm=False,
+                apply_decay_param_fun=self.origin_decay_param_fun,
             )
             if self.comm_overlap:
                 self._comm_buffers += all_buffer
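
The guard added to __init__ keeps optimizer fusion only when every parameter reports the same per-parameter learning rate. A minimal standalone sketch of that check; Param and can_fuse_optimizer are hypothetical stand-ins for illustration, not Paddle APIs:

import warnings

# Hypothetical stand-in for a Paddle parameter carrying a per-parameter
# optimize_attr, as the test file below sets up.
class Param:
    def __init__(self, name, lr):
        self.name = name
        self.optimize_attr = {"learning_rate": lr}


def can_fuse_optimizer(params):
    # Fusion stays enabled only if every parameter that carries optimize_attr
    # reports the same learning rate; otherwise warn and disable it.
    lr = None
    for param in params:
        if hasattr(param, "optimize_attr"):
            param_lr = param.optimize_attr["learning_rate"]
            if lr is None:
                lr = param_lr
            elif lr != param_lr:
                warnings.warn(
                    "Parameters have different learning rate, "
                    "won't do fusion on the optimizer."
                )
                return False
    return True


print(can_fuse_optimizer([Param("w_0", 1.0), Param("b_0", 1.0)]))  # True
print(can_fuse_optimizer([Param("w_0", 1.0), Param("b_0", 2.0)]))  # False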

python/paddle/distributed/fleet/utils/tensor_fusion_helper.py

Lines changed: 8 additions & 1 deletion
@@ -550,6 +550,7 @@ def _fused_parameters_impl(
     dst=-1,
     acc_step=1,
     scale_after_comm=False,
+    apply_decay_param_fun=None,
 ):
     param_groups = []
     attrs = []
@@ -579,7 +580,9 @@ def _fused_parameters_impl(
         other_params = []

         for param in params:
-            if not any(nd in param.name for nd in ["bias", "norm", "b_0"]):
+            if apply_decay_param_fun is not None and apply_decay_param_fun(
+                param.name
+            ):
                 decay_params.append(param)
             else:
                 other_params.append(param)
@@ -632,6 +635,7 @@ def fused_parameters(
     acc_step=1,
     scale_after_comm=False,
     group_params=False,
+    apply_decay_param_fun=None,
 ):
     """
     Fuse gradients. Fuse parameters if be enabled. Prepare for comm overlap if be enabled.
@@ -645,6 +649,7 @@ def fused_parameters(
     :param fuse_param: fuse param or not
     :param scale_after_comm: if enable comm overlap, specify the location of grad scale
     :param group_params: the format of the input parameters is param group
+    :param apply_decay_param_fun: the funtion to filter decay param
     :return: param storage if fused, comm buffers if comm overlap, param groups if use group params
     """
     if act is None:
@@ -690,6 +695,7 @@ def fused_parameters(
                dst=dst,
                acc_step=acc_step,
                scale_after_comm=scale_after_comm,
+               apply_decay_param_fun=apply_decay_param_fun,
            )
            if comm_overlap:
                comm_buffers.extend(group_all_buffers)
@@ -709,6 +715,7 @@ def fused_parameters(
        dst=dst,
        acc_step=acc_step,
        scale_after_comm=scale_after_comm,
+       apply_decay_param_fun=apply_decay_param_fun,
    )

    return decay_fused, all_fused, all_buffers
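
To make the change in _fused_parameters_impl concrete: instead of hard-coding the ["bias", "norm", "b_0"] name heuristic, the decay/no-decay split now defers to whatever decay filter the inner optimizer was configured with. A minimal standalone sketch of the same split, assuming parameters are represented by their names; split_decay_params and decay_filter are hypothetical helpers, not Paddle APIs:

def split_decay_params(param_names, apply_decay_param_fun=None):
    decay_params, other_params = [], []
    for name in param_names:
        # Weight decay is applied only to names the supplied filter accepts;
        # with no filter, every parameter falls into the no-decay bucket.
        if apply_decay_param_fun is not None and apply_decay_param_fun(name):
            decay_params.append(name)
        else:
            other_params.append(name)
    return decay_params, other_params


# The old name-based heuristic, reproduced as an explicit filter.
def decay_filter(name):
    return not any(nd in name for nd in ["bias", "norm", "b_0"])


print(split_decay_params(
    ["linear_0.w_0", "linear_0.b_0", "layer_norm_0.w_0"], decay_filter
))
# (['linear_0.w_0'], ['linear_0.b_0', 'layer_norm_0.w_0'])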

test/collective/fleet/hybrid_parallel_sharding_model_with_fusion_amp.py

Lines changed: 12 additions & 1 deletion
@@ -85,12 +85,18 @@ def build_optimizer(self, model):
         )
         return optimizer

-    def build_model_optimizer(self):
+    def build_model_optimizer(self, diff_lr=False):
         model = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size)
         optimizer = self.build_optimizer(model)
         model, optimizer = paddle.amp.decorate(
             model, optimizers=optimizer, level="O2", dtype="float16"
         )
+        if diff_lr:
+            for param in model.parameters():
+                if 'w' in param.name:
+                    param.optimize_attr = {"learning_rate": 1.0}
+                else:
+                    param.optimize_attr = {"learning_rate": 2.0}
         scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
         scaler = fleet.distributed_scaler(scaler)
         model = fleet.distributed_model(model)
@@ -109,8 +115,13 @@ def sharding_model(self):
             scaler.update()
             optimizer.clear_grad()

+    def sharding_different_lr(self):
+        model, optimizer, scaler = self.build_model_optimizer(diff_lr=True)
+        assert optimizer._inner_opt.fuse_optimizer is False
+
     def test_sharding_adam(self):
         self.sharding_model()
+        self.sharding_different_lr()


 if __name__ == "__main__":
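
The test forces the fallback by assigning different optimize_attr learning rates. In normal use, the decay filter that the sharding optimizer now forwards comes from the inner optimizer, for example paddle.optimizer.AdamW's apply_decay_param_fun argument. A hedged, non-distributed sketch of that wiring; the tiny model and the apply_decay function are illustrative only:

import paddle

# Small local model; in the real test this is SimpleDPNet wrapped by fleet.
model = paddle.nn.Sequential(
    paddle.nn.Linear(16, 16),
    paddle.nn.LayerNorm(16),
)

# The same rule the fusion helper used to hard-code, expressed as a filter.
def apply_decay(name):
    return not any(nd in name for nd in ["bias", "norm", "b_0"])

# AdamW stores this callable as _apply_decay_param_fun; after this commit the
# sharding optimizer reads it and passes it to fused_parameters(), so the
# fused decay/no-decay grouping matches the optimizer's own decay rule.
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=model.parameters(),
    weight_decay=0.01,
    apply_decay_param_fun=apply_decay,
)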
