
Commit e758692

Fix
1 parent dd457cc commit e758692

File tree: 1 file changed (+35 −10 lines)

python/paddle/distributed/passes/auto_parallel_sharding.py

Lines changed: 35 additions & 10 deletions
@@ -405,13 +405,13 @@ def _shard_gradient_clip(self, main_block):
         for i, sharding_info in enumerate(self.sharding_infos):
             new_op = main_block._insert_op(
                 idx + i + 1,
-                type='c_allreduce_sum',
-                inputs={'X': [sum_op_output]},
-                outputs={'Out': [sum_op_output]},
+                type='all_reduce',
+                inputs={'x': [sum_op_output]},
+                outputs={'out': [sum_op_output]},
                 attrs={
                     'ring_id': sharding_info.group.id,
                     'op_namescope': "/gradient_clip_model_parallelism",
-                    'use_calc_stream': True,
+                    'reduce_type': paddle.distributed.ReduceOp.SUM,
                     OP_ROLE_KEY: OpRole.Optimize,
                 },
             )
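
For context, the inserted 'all_reduce' op corresponds to Paddle's user-facing collective paddle.distributed.all_reduce. A minimal sketch of that API (not part of this commit; assumes a multi-process run launched via python -m paddle.distributed.launch):

    import paddle
    import paddle.distributed as dist

    # Each rank contributes its own value; after the call every rank
    # holds the global sum, matching reduce_type=ReduceOp.SUM above.
    dist.init_parallel_env()
    x = paddle.to_tensor([float(dist.get_rank() + 1)])
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    print(x)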
@@ -535,9 +535,16 @@ def _shard_gradient_synchronization(self, main_block):
         dp_ring_ids = [group.id for group in self.dp_groups]
         for idx, op in reversed(list(enumerate(main_block.ops))):
             if _is_param_grad_allreduce_op(op, main_block):
-                if op.type == "c_allreduce_sum" or (
-                    op.type == "reduce"
-                    and op.attr("reduce_type") == dist.ReduceOp.SUM
+                if (
+                    op.type == "c_allreduce_sum"
+                    or (
+                        op.type == "all_reduce"
+                        and op.attr("reduce_type") == dist.ReduceOp.SUM
+                    )
+                    or (
+                        op.type == "reduce"
+                        and op.attr("reduce_type") == dist.ReduceOp.SUM
+                    )
                 ):
                     reduce_op_type = "reduce"
                     reduce_type = dist.ReduceOp.SUM
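
The remaining hunks repeat the same op-matching pattern. A hedged consolidation sketch of the predicate they share (_is_sum_allreduce is a hypothetical helper, not part of this commit):

    import paddle.distributed as dist

    def _is_sum_allreduce(op):
        # Hypothetical helper: matches the legacy c_allreduce_sum op as
        # well as the new all_reduce / reduce ops whose reduce_type
        # attribute is ReduceOp.SUM.
        if op.type == "c_allreduce_sum":
            return True
        if op.type in ("all_reduce", "reduce"):
            return op.attr("reduce_type") == dist.ReduceOp.SUM
        return False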
@@ -1036,7 +1043,13 @@ def op_depend_on_group(op, group):
                 cur_group.is_in_local_shard = True
                 assert ops[i + 1].type in [
                     "c_allreduce_sum",
-                ], "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel"
+                ] or (
+                    ops[i + 1].type == 'all_reduce'
+                    and ops[i + 1].attr('reduce_type')
+                    in [
+                        paddle.distributed.ReduceOp.SUM,
+                    ]
+                ), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel"
                 assert (
                     ops[i + 1].output_arg_names[0] == grad_name
                 ), "Hybrid Sharding with Data-Parallel should sync same gradient var"
@@ -1236,7 +1249,13 @@ def _overlap_grad_comm(
         grad_comm_op_to_stream_idx = {}
         for idx, op in enumerate(ops):
             if is_data_parallel_reduce_op(op):
-                if op.type in ["c_allreduce_sum"]:
+                if op.type in ["c_allreduce_sum"] or (
+                    op.type == 'all_reduce'
+                    and op.attr('reduce_type')
+                    in [
+                        paddle.distributed.ReduceOp.SUM,
+                    ]
+                ):
                     continue
                 stream_idx = reduce_op_count % self.grad_comm_stream_num
                 grad_comm_op_to_stream_idx[op] = stream_idx
@@ -1291,7 +1310,13 @@ def _overlap_grad_comm(
                 next_op = ops[idx + 1]
                 assert next_op.type in [
                     "c_allreduce_sum",
-                ]
+                ] or (
+                    next_op.type == 'all_reduce'
+                    and next_op.attr('reduce_type')
+                    in [
+                        paddle.distributed.ReduceOp.SUM,
+                    ]
+                )
                 assert next_op.output("Out")[0] == reduce_varname
                 # FIXME hybrid sharding-dp support multi comm & stream in feature
                 # next_op._set_attr("ring_id", comm_group.id)
