
Commit e3e62b2

Fix
1 parent 6d4cb7e commit e3e62b2

2 files changed: +26, -9 lines

python/paddle/distributed/auto_parallel/static/cost/base_cost.py

Lines changed: 25 additions & 8 deletions
@@ -30,6 +30,7 @@
     "broadcast",
     "all_gather",
     "c_allreduce_sum",
+    "all_reduce",
     "c_identity",
 ]
 NON_COMP_TYPE = ["while", *COMM_OP_TYPE]
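
For orientation, a quick sketch (not from the repo; only the entries visible in this hunk are listed) of what the added entry implies for op classification:

# Sketch: with "all_reduce" added to COMM_OP_TYPE it also lands in
# NON_COMP_TYPE, so the cost model treats it as a communication op
# rather than a computation op. The real lists contain more entries.
COMM_OP_TYPE = ["broadcast", "all_gather", "c_allreduce_sum", "all_reduce", "c_identity"]
NON_COMP_TYPE = ["while", *COMM_OP_TYPE]
assert "all_reduce" in COMM_OP_TYPE and "all_reduce" in NON_COMP_TYPE
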
@@ -311,7 +312,10 @@ def build_comm_desc_from_dist_op(
         input_list.append((var.dtype, shape))

     # NOTE: The input_name of comm ops used usually is X.
-    desc["inputs"] = {"X": input_list}
+    if op_type == "all_reduce":
+        desc["inputs"] = {"x": input_list}
+    else:
+        desc["inputs"] = {"X": input_list}

     # Get comm group by parallel_axis or the given group_ranks.
     if parallel_axis is not None:
@@ -349,7 +353,10 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
     desc = {}
     desc["op"] = op_type
     desc["group_ranks"] = group_ranks
-    desc["inputs"] = {"X": [(dtype, shape)]}
+    if op_type == "all_reduce":
+        desc["inputs"] = {"x": [(dtype, shape)]}
+    else:
+        desc["inputs"] = {"X": [(dtype, shape)]}
     desc["attrs"] = attrs
     return desc
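
For illustration, a rough usage sketch (not part of the commit) of the descriptors the two branches above now produce; the import path is assumed from the file's location, and the dtype/shape values are invented:

import paddle
from paddle.distributed.auto_parallel.static.cost.base_cost import build_comm_desc

legacy = build_comm_desc("c_allreduce_sum", [0, 1], paddle.float32, [8, 1024])
# {"op": "c_allreduce_sum", "group_ranks": [0, 1],
#  "inputs": {"X": [(paddle.float32, [8, 1024])]}, "attrs": None}

new_style = build_comm_desc("all_reduce", [0, 1], paddle.float32, [8, 1024])
# {"op": "all_reduce", "group_ranks": [0, 1],
#  "inputs": {"x": [(paddle.float32, [8, 1024])]}, "attrs": None}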

@@ -416,7 +423,7 @@ def build_dp_costs(
     if not has_found:
         return

-    c_allreduce_sum_descs = build_comm_desc_from_dist_op(
+    all_reduce_sum_descs = build_comm_desc_from_dist_op(
         "c_allreduce_sum",
         dist_op,
         ctx,
@@ -428,7 +435,7 @@ def build_dp_costs(
         _g_op_cost_factory["c_allreduce_sum"],
         ctx,
         processes,
-        c_allreduce_sum_descs,
+        all_reduce_sum_descs,
         cluster,
         is_dp=True,
     )
@@ -787,17 +794,27 @@ def comm_count(self):
             vars = self.op.block.vars
             # NOTE: The tensor communicated input_name is "X" in default. Otherwise, this function should be overridden
             try:
-                var_name = self.op.input("X")[0]
+                if self.op.type != "all_reduce":
+                    var_name = self.op.input("X")[0]
+                else:
+                    var_name = self.op.input("x")[0]
             except:
-                var_name = self.op.output("Out")[0]
+                if self.op.type != "all_reduce":
+                    var_name = self.op.output("Out")[0]
+                else:
+                    var_name = self.op.output("out")[0]
             var = get_var_with_recursion(
                 var_name, self.op.block, self.op.block.program
             )
             dtype = var.dtype
             shape = var.shape
         elif self.op_desc is not None:
-            dtype = self.op_desc["inputs"]["X"][0][0]
-            shape = self.op_desc["inputs"]["X"][0][1]
+            if "op" in self.op_desc and self.op_desc["op"] == "all_reduce":
+                dtype = self.op_desc["inputs"]["x"][0][0]
+                shape = self.op_desc["inputs"]["x"][0][1]
+            else:
+                dtype = self.op_desc["inputs"]["X"][0][0]
+                shape = self.op_desc["inputs"]["X"][0][1]

         factor = None
         if dtype == paddle.float32 or dtype == paddle.int32:
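
The comm_count hunk above repeats the same branching twice; a condensed sketch of the lookup it performs (drawn only from this diff: the legacy c_allreduce_* ops expose their tensor under "X"/"Out", while the newer all_reduce op uses lowercase "x"/"out"):

def _comm_var_name(op):
    # Hypothetical helper, not in the repo: choose the slot names by op.type,
    # mirroring the branches added to comm_count above.
    in_key, out_key = ("x", "out") if op.type == "all_reduce" else ("X", "Out")
    try:
        return op.input(in_key)[0]
    except Exception:
        return op.output(out_key)[0]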

python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py

Lines changed: 1 addition & 1 deletion
@@ -280,7 +280,6 @@ def _comms_overlap_calc(self):
         # comm wait calc to finish
         for idx, op in reversed(list(enumerate(block.ops))):
             if is_data_parallel_reduce_op(op):
-                assert op.has_attr('use_calc_stream')
                 assert op.has_attr('ring_id')

                 op._set_attr('use_calc_stream', False)
@@ -492,6 +491,7 @@ def _update_program(self, grad_groups):

             allreduce_op = block.ops[group.allreduce_op_idx]
             assert allreduce_op.type in [
+                'all_reduce',
                 'c_allreduce_avg',
                 'c_allreduce_sum',
             ], f"should found c_allreduce_avg or c_allreduce_sum op but found {allreduce_op}"
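
A small sketch of the membership test this assert now encodes (the constant and helper names are invented for the example; the accepted op types are taken from the hunk above):

# Hypothetical mirror of the assert in _update_program.
GROUPED_ALLREDUCE_TYPES = ('all_reduce', 'c_allreduce_avg', 'c_allreduce_sum')

def is_grouped_allreduce(op):
    # True for the legacy c_allreduce_* ops and the newer all_reduce op alike.
    return op.type in GROUPED_ALLREDUCE_TYPES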
