diff --git a/python/paddle/distributed/auto_parallel/static/operators/common.py b/python/paddle/distributed/auto_parallel/static/operators/common.py index 4fd7893f02d0e5..2adf7ac8df6023 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/common.py +++ b/python/paddle/distributed/auto_parallel/static/operators/common.py @@ -670,9 +670,11 @@ def is_data_parallel_scale_op(op): def is_data_parallel_reduce_op(op): - is_allreduce_op = op.type in [ - "c_allreduce_sum", - "c_allreduce_avg", + is_allreduce_op = op.type == "all_reduce" and op.desc.attr( + "reduce_type" + ) in [ + dist.ReduceOp.SUM, + dist.ReduceOp.AVG, ] is_reduce_op = op.type == "reduce" and op.desc.attr("reduce_type") in [ dist.ReduceOp.SUM,