Skip to content

Commit 31116de

Browse files
committed
fix hang
1 parent af28a45 commit 31116de

File tree

1 file changed

+13
-1
lines changed
  • python/paddle/distributed/auto_parallel/static

1 file changed

+13
-1
lines changed

python/paddle/distributed/auto_parallel/static/pir_pass.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,18 @@ def reshard_op_pass(dist_program, global_params_grads=None, block=None):
367367
op.result(0).type(),
368368
)
369369

370+
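# Ops with the Optimize role must have their chunk_id set to -1: starting from the op that defines out_value, walk backward through operand 0 and reset chunk_id on each Optimize-role op, stopping at an op with no operands or when the reshard op's input value is reached. NOTE(review): out_value is dereferenced here before the `if out_value is not None` check below — confirm it can never be None at this point.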
371+
reshard_value = op.operand_source(0)
372+
out_op = out_value.get_defining_op()
373+
while out_op.op_role == int(OpRole.Optimize):
374+
out_op.set_int_attr("chunk_id", -1)
375+
if out_op.num_operands() == 0:
376+
break
377+
in_value = out_op.operand_source(0)
378+
if in_value.is_same(reshard_value):
379+
break
380+
out_op = in_value.get_defining_op()
381+
370382
if out_value is not None:
371383
op.result(0).replace_all_uses_with(out_value)
372384

@@ -712,7 +724,7 @@ def remove_unuseful_comm_op_pass(program):
712724
if op.name() in comm_ops or (
713725
op.name() == "pd_op.all_reduce"
714726
and op.int_attr("reduce_type")
715-
in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]
727+
in [dist.ReduceOp.SUM, dist.ReduceOp.MAX, dist.ReduceOp.AVG]
716728
):
717729
ring_id = op.int_attr("ring_id")
718730
process_group = get_process_group(ring_id)

0 commit comments

Comments
 (0)