
Commit 3789d83

sharding stage1 V1 support Broadcast overlap Forward (PaddlePaddle#63945)
* sharding v1 overlap
* delete pybind
* add txt
* add b.txt
* delete test file
* add pybind
* add test case for stage1 v1 overlap
* add test case for stage1 v1 overlap
* update test case
* delete print optimizer
* update
* update models to layers
* update mlp1
1 parent c949b51 commit 3789d83

File tree

3 files changed: +145 -17 lines
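
For context, the feature added by this commit is switched on through the new _set_broadcast_overlap method of the fleet-wrapped optimizer. Below is a rough usage sketch based on the test case in this commit; the model, the strategy degrees, and the two-GPU paddle.distributed.launch setup are illustrative assumptions, not part of the commit itself.

import paddle
from paddle.distributed import fleet

# Assumed hybrid setup: stage1 sharding across 2 ranks; adjust degrees to your topology.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Sequential(paddle.nn.Linear(64, 64), paddle.nn.Linear(64, 1))
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)  # wrapped as HybridParallelOptimizer

# New in this commit: overlap the stage1 V1 parameter broadcast with the next forward pass.
optimizer._set_broadcast_overlap(True, layers=model)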

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

+91 -17
@@ -116,6 +116,20 @@ def __init__(self, optimizer, hcg):
         self._rank2params = self._partition_parameters()
         self._param2rank = self._map_param_to_rank()
 
+        self._broadcast_overlap = False
+        self._forward_pre_hook_remove_helper = []
+
+        try:
+            # The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
+            # Sort the params so that they follow the order in which they are used in forward.
+            self._broadcast_order_params = sorted(
+                self._parameter_list,
+                key=lambda x: int(x.name.split('.')[0].split('_')[-1]),
+            )
+
+        except ValueError:
+            self._broadcast_order_params = None
+
         if not self.tensor_fusion and not self.comm_overlap:
             local_params = self._rank2params[self._sharding_rank]
             self._set_inner_opt_attr('_parameter_list', local_params)
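
As an aside, the sort key above orders parameters by the numeric suffix of the layer part of their name, which tracks layer creation order. A small illustration follows; the parameter names are typical dygraph defaults, used here only as an example.

# Order parameter names the way the optimizer does: by the trailing index of the
# layer name (e.g. "linear_0" -> 0), which usually matches forward order.
names = ["layer_norm_2.b_0", "linear_1.w_0", "linear_0.w_0", "linear_0.b_0"]
key = lambda n: int(n.split('.')[0].split('_')[-1])
print(sorted(names, key=key))
# -> ['linear_0.w_0', 'linear_0.b_0', 'linear_1.w_0', 'layer_norm_2.b_0']

# A custom name without a trailing index (e.g. "my_weight") makes int() raise
# ValueError, which is why _broadcast_order_params falls back to None above.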
@@ -318,6 +332,13 @@ def reduce_gradients(self, parameter_list, hcg):
                         sync_op=True,
                     )
 
+    def _forward_pre_hook_function(self, tasks):
+        def __impl__(x, y):
+            for task in tasks:
+                task.wait()
+
+        return __impl__
+
     def _sharding_sync_parameters(self):
         """
         Synchronize parameter across sharding group efficiently.
@@ -334,27 +355,57 @@ def _sharding_sync_parameters(self):
         sharding_group_ranks = self._hcg.get_sharding_parallel_group().ranks
 
         broadcast_tasks = []
-        for rank, params in valid_rank_to_params.items():
-            # Compute the global source rank only once per each rank's set of parameters
-            src_rank = sharding_group_ranks[rank]
-
-            for param in params:
-                # NOTE: We should check if the parameter is trainable, because some parameters
-                # (e.g., freeze the parameters for training) are not trainable and should
-                # not be broadcasted.
-                g_var = self._get_param_grad(param)
-                if g_var is not None:
+        if self._broadcast_overlap:
+            param2task = {}
+
+            group = self._hcg.get_sharding_parallel_group()
+            for param in self._broadcast_order_params:
+                if param.trainable:
                     task = paddle.distributed.broadcast(
-                        param,
-                        src=src_rank,
-                        group=self._hcg.get_sharding_parallel_group(),
+                        tensor=param,
+                        src=group.ranks[self._param2rank[param.name]],
+                        group=group,
                         sync_op=False,
                     )
-                    broadcast_tasks.append(task)
+                    assert param.name not in param2task
+                    param2task[param.name] = task
+
+            for layer in self._layers.sublayers():
+                if len(layer.sublayers()) == 0:
+                    # Register a forward pre-hook for leaf layers. This gives the best performance.
+                    tasks = []
+                    for param in layer.parameters():
+                        if param.trainable:
+                            if param.name in param2task:
+                                tasks.append(param2task[param.name])
+                    self._forward_pre_hook_remove_helper.append(
+                        layer.register_forward_pre_hook(
+                            self._forward_pre_hook_function(tasks)
+                        )
+                    )
 
-        # Wait for all async broadcast tasks to complete
-        for task in broadcast_tasks:
-            task.wait()
+        else:
+            for rank, params in valid_rank_to_params.items():
+                # Compute the global source rank only once per each rank's set of parameters
+                src_rank = sharding_group_ranks[rank]
+
+                for param in params:
+                    # NOTE: We should check if the parameter is trainable, because some parameters
+                    # (e.g., freeze the parameters for training) are not trainable and should
+                    # not be broadcasted.
+                    g_var = self._get_param_grad(param)
+                    if g_var is not None:
+                        task = paddle.distributed.broadcast(
+                            param,
+                            src=src_rank,
+                            group=self._hcg.get_sharding_parallel_group(),
+                            sync_op=False,
+                        )
+                        broadcast_tasks.append(task)
+
+            # Wait for all async broadcast tasks to complete
+            for task in broadcast_tasks:
+                task.wait()
 
     def _update_trainable(self):
         """
@@ -384,10 +435,33 @@ def minimize(
 
         return result
 
+    def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
+        self._broadcast_overlap = broadcast_overlap
+        if self._broadcast_overlap:
+            assert (
+                layers is not None
+            ), "To enable stage1 optimizer broadcast overlap forward, layers cannot be None"
+            self._layers = layers
+            warnings.warn(
+                r"Setting broadcast overlap means `paddle.device.cuda.synchronize()` must be called manually before `paddle.save()` and before running inference."
+            )
+
+            if self._broadcast_order_params is None:
+                warnings.warn(
+                    r"The param name passed to the optimizer doesn't follow the .+_[0-9]+\..+ pattern, "
+                    "so overlap broadcast may harm the performance."
+                )
+                self._broadcast_order_params = self._parameter_list
+
     @imperative_base.no_grad
     @framework.dygraph_only
     def step(self):
         # TODO Check whether the model trainable param changed and update state accordingly
+        if self._broadcast_overlap:
+            # Clear the pre-forward hooks in the optimizer step.
+            for hook_remove in self._forward_pre_hook_remove_helper:
+                hook_remove.remove()
+            self._forward_pre_hook_remove_helper = []
 
         target_param_list = (
             self._origin_parameter_list
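
The warning added above amounts to a save/inference pattern like the following; this is a hypothetical snippet continuing the usage sketch near the top, not code from this commit.

# With broadcast overlap enabled, parameter broadcasts may still be in flight,
# so flush the device before snapshotting or evaluating the model.
paddle.device.cuda.synchronize()
paddle.save(model.state_dict(), "model.pdparams")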

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

+8
@@ -331,6 +331,14 @@ def __init__(self, optimizer, hcg, strategy):
                 inner_opt._grad_clip, hcg
             )
 
+    def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
+        self._broadcast_overlap = broadcast_overlap
+        if self._broadcast_overlap:
+            self._layers = layers
+            self._inner_opt._set_broadcast_overlap(
+                self._broadcast_overlap, self._layers
+            )
+
     def _insert_sync(self, sync_var, src, mp_group, sync_mode):
         if sync_mode == "broadcast":
             paddle.distributed.broadcast(

test/collective/fleet/dygraph_group_sharded_stage1_bf16.py

+46
@@ -37,6 +37,7 @@ def train_mlp(
     acc_steps=1,
     use_main_grad=False,
     test_scaler=False,
+    broadcast_overlap=False,
 ):
     logging.info(
         f"-- Train Info: use_pure_bf16={use_pure_bf16}, use_main_grad={use_main_grad}, acc_steps={acc_steps}"
@@ -86,6 +87,10 @@ def train_mlp(
 
     if sharding_stage == 1:
         optimizer = fleet.distributed_optimizer(optimizer)
+        if broadcast_overlap:
+            optimizer._set_broadcast_overlap(
+                broadcast_overlap=broadcast_overlap, layers=model
+            )
 
     if sharding_stage == 1:
         model.to(device="gpu")
@@ -191,6 +196,7 @@ def _compare_bf16_o1_vs_o2(acc_steps=1):
             train_loader=train_loader,
             use_pure_bf16=False,
             acc_steps=acc_steps,
+            broadcast_overlap=False,
         )
         o2_losses, model_param_dict_o2, optimizer_state_dict_o2 = train_mlp(
             mlp2,
@@ -199,17 +205,57 @@ def _compare_bf16_o1_vs_o2(acc_steps=1):
             use_pure_bf16=True,
             use_main_grad=True,
             acc_steps=acc_steps,
+            broadcast_overlap=False,
         )
         np.testing.assert_array_equal(o2_losses, o1_losses)
         compare_state_dict(
             model_param_dict_o1, model_param_dict_o2, optimizer_state_dict_o2
         )
 
+    def _compare_bf16_broadcast_overlap(acc_steps=1):
+        mlp1 = MLP()
+        mlp2 = MLP()
+        mlp1.set_state_dict(state_dict)
+        mlp2.set_state_dict(state_dict)
+        (
+            o1_losses_overlap,
+            model_param_dict_o1_overlap,
+            optimizer_state_dict_o1_overlap,
+        ) = train_mlp(
+            mlp1,
+            sharding_stage=1,
+            train_loader=train_loader,
+            use_pure_bf16=False,
+            acc_steps=acc_steps,
+            broadcast_overlap=True,
+        )
+        mlp1.set_state_dict(state_dict)
+        (
+            o1_losses_no_overlap,
+            model_param_dict_o1_no_overlap,
+            optimizer_state_dict_o1_no_overlap,
+        ) = train_mlp(
+            mlp1,
+            sharding_stage=1,
+            train_loader=train_loader,
+            use_pure_bf16=False,
+            acc_steps=acc_steps,
+            broadcast_overlap=False,
+        )
+
+        np.testing.assert_array_equal(o1_losses_overlap, o1_losses_no_overlap)
+        np.testing.assert_array_equal(
+            model_param_dict_o1_overlap, model_param_dict_o1_no_overlap
+        )
+
     # no gradient accumulation
     _compare_bf16_o1_vs_o2(acc_steps=1)
     # gradient accumulation
     _compare_bf16_o1_vs_o2(acc_steps=2)
 
+    _compare_bf16_broadcast_overlap(acc_steps=1)
+    _compare_bf16_broadcast_overlap(acc_steps=2)
+
     # stage1 scaler test with main_grad
     mlp3 = MLP()
     mlp3.set_state_dict(state_dict)
