
Commit 1c9f875
Fix
1 parent 85e1c9a

6 files changed: +7, -18 lines


python/paddle/distributed/communication/reduce.py (+1, -12)
@@ -68,7 +68,7 @@ class ReduceOp:
     AVG: ClassVar[Literal[4]] = 4


-def _get_reduce_op(reduce_op, func_name):
+def _get_reduce_op(reduce_op, func_name=""):
     if framework.in_dynamic_mode():
         if reduce_op == ReduceOp.SUM:
             return framework.core.ReduceOp.SUM
@@ -80,17 +80,6 @@ def _get_reduce_op(reduce_op, func_name):
             return framework.core.ReduceOp.PRODUCT
         elif reduce_op == ReduceOp.AVG:
             return framework.core.ReduceOp.AVG
-    else:
-        if reduce_op == ReduceOp.SUM:
-            return f'c_{func_name}_sum'
-        elif reduce_op == ReduceOp.MAX:
-            return f'c_{func_name}_max'
-        elif reduce_op == ReduceOp.MIN:
-            return f'c_{func_name}_min'
-        elif reduce_op == ReduceOp.PROD:
-            return f'c_{func_name}_prod'
-        else:
-            return f'c_{func_name}'

     raise ValueError(f"Unknown reduce_op type for {func_name}.")
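This hunk removes the static-graph branch that mapped each reduce op to a legacy 'c_{func_name}_*' operator name. After that, func_name no longer participates in op resolution and only labels the error message, which is why it becomes an optional parameter. For reference, a sketch of the function after this commit, reconstructed from the two hunks above; the ReduceOp.MAX and ReduceOp.MIN branches sit between the hunks and are assumed by symmetry:

    def _get_reduce_op(reduce_op, func_name=""):
        # Dynamic (eager) mode: map the public ReduceOp enum onto the
        # framework's core reduce ops.
        if framework.in_dynamic_mode():
            if reduce_op == ReduceOp.SUM:
                return framework.core.ReduceOp.SUM
            elif reduce_op == ReduceOp.MAX:  # assumed by symmetry, not shown in the diff
                return framework.core.ReduceOp.MAX
            elif reduce_op == ReduceOp.MIN:  # assumed by symmetry, not shown in the diff
                return framework.core.ReduceOp.MIN
            elif reduce_op == ReduceOp.PROD:
                return framework.core.ReduceOp.PRODUCT
            elif reduce_op == ReduceOp.AVG:
                return framework.core.ReduceOp.AVG

        # func_name now only makes the error message more descriptive.
        raise ValueError(f"Unknown reduce_op type for {func_name}.")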

python/paddle/distributed/communication/stream/all_reduce.py (+1, -1)
@@ -43,7 +43,7 @@ def _all_reduce_in_dygraph(
     sync_op: bool,
     use_calc_stream: bool,
 ) -> task:
-    op_type = _get_reduce_op(op, "allreduce")
+    op_type = _get_reduce_op(op)

     if use_calc_stream:
         return group.process_group.all_reduce_on_calc_stream(tensor, op_type)
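The remaining call sites in this commit change the same way: with op-name construction gone, each passes only the reduce op. A minimal usage sketch of the public API this code path serves (assuming a script launched on two or more processes, e.g. via python -m paddle.distributed.launch; the tensor values are illustrative):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    # Each rank contributes its own values; after all_reduce every rank
    # holds the elementwise sum. The op is resolved internally via
    # _get_reduce_op(op), with no func_name needed.
    t = paddle.to_tensor([1.0, 2.0, 3.0]) * (dist.get_rank() + 1)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)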

python/paddle/distributed/communication/stream/reduce.py (+1, -1)
@@ -35,7 +35,7 @@
 def _reduce_in_dygraph(
     tensor, dst_rank_in_group, op, group, sync_op, use_calc_stream
 ):
-    op_type = _get_reduce_op(op, "reduce")
+    op_type = _get_reduce_op(op)
     if use_calc_stream:
         return group.process_group.reduce_on_calc_stream(
             tensor, dst_rank_in_group, op_type

python/paddle/distributed/communication/stream/reduce_scatter.py (+1, -1)
@@ -63,7 +63,7 @@ def _reduce_scatter_tensor_in_dygraph(
 def _reduce_scatter_in_dygraph(
     tensor, tensor_list, op, group, sync_op, use_calc_stream
 ):
-    op_type = _get_reduce_op(op, "reduce_scatter")
+    op_type = _get_reduce_op(op)

     if use_calc_stream:
         return group.process_group.reduce_scatter_on_calc_stream(

python/paddle/distributed/fleet/layers/mpu/mp_layers.py (+1, -1)
@@ -227,7 +227,7 @@ def backward(ctx, dy):
         dx = paddle.matmul(
             dy, paddle.cast(weight, dtype=dy.dtype), transpose_y=True
         )
-        op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
+        op_type = _get_reduce_op(ReduceOp.SUM)
         task = ctx.model_parallel_group.process_group.all_reduce(
             dx, op_type, sync_op=False
         )

python/paddle/distributed/fleet/layers/mpu/mp_ops.py (+2, -2)
@@ -49,7 +49,7 @@ def forward(ctx, tensor, group, skip_c_identity_dynamic):

     @staticmethod
     def backward(ctx, dy):
-        op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
+        op_type = _get_reduce_op(ReduceOp.SUM)
         ctx.group.process_group.all_reduce_on_calc_stream(dy, op_type)
         return dy

@@ -238,7 +238,7 @@ def forward(
         ctx.skip_c_identity_dynamic = skip_c_identity_dynamic

         if use_calc_stream:
-            op_type = _get_reduce_op(op, "_mp_allreduce")
+            op_type = _get_reduce_op(op)
             group.process_group.all_reduce_on_calc_stream(tensor, op_type)
             return tensor
         else:
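Both mp_ops hunks, like the mp_layers one above, follow the same model-parallel pattern: an identity-style forward whose backward all-reduces the incoming gradient with SUM across the group. A minimal sketch of that pattern; the class name and the use of the public dist.all_reduce here are illustrative, not the actual PaddlePaddle internals, which call the process group directly:

    import paddle
    import paddle.distributed as dist
    from paddle.autograd import PyLayer

    class _IdentityGradAllReduce(PyLayer):  # hypothetical name
        @staticmethod
        def forward(ctx, x, group):
            ctx.group = group
            return x  # pass-through in the forward direction

        @staticmethod
        def backward(ctx, dy):
            # Sum gradients across the model-parallel group, mirroring
            # _get_reduce_op(ReduceOp.SUM) in the hunks above.
            dist.all_reduce(dy, op=dist.ReduceOp.SUM, group=ctx.group)
            return dy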
