Skip to content

Commit ce7158d

Browse files
authored
Refine fuse_allreduce_split_pass (PaddlePaddle#64807)
* add * add * add * add * add
1 parent 40680e7 commit ce7158d

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

paddle/fluid/pir/dialect/distributed/transforms/fuse_allreduce_split_to_reducescatter_pass.cc

+10-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase {
3535
const auto &c_allreduce_sum_ =
3636
pat.Op(paddle::dialect::CAllreduceSum_Op::name(),
3737
{{"ring_id", pat.Attr("ring_id")},
38-
{"use_calc_stream", pat.Attr("use_calc_stream")}});
38+
{"use_calc_stream", pat.Attr("use_calc_stream")},
39+
{"execution_stream", pat.Attr("execution_stream")},
40+
{"force_record_event", pat.Attr("force_record_event")},
41+
{"event_to_record", pat.Attr("event_to_record")},
42+
{"events_to_wait", pat.Attr("events_to_wait")}});
3943
const auto &assign = pat.Op(paddle::dialect::AssignOp::name());
4044
const auto &full = pat.Op(paddle::dialect::FullOp::name());
4145
const auto &split_with_num = pat.Op(paddle::dialect::SplitWithNumOp::name(),
@@ -74,7 +78,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase {
7478
res.Op(paddle::dialect::CReducescatterOp::name(),
7579
{{"ring_id", pat.Attr("ring_id")},
7680
{"nranks", pat.Attr("num")},
77-
{"use_calc_stream", pat.Attr("use_calc_stream")}});
81+
{"use_calc_stream", pat.Attr("use_calc_stream")}},
82+
{{"execution_stream", pat.Attr("execution_stream")},
83+
{"force_record_event", pat.Attr("force_record_event")},
84+
{"event_to_record", pat.Attr("event_to_record")},
85+
{"events_to_wait", pat.Attr("events_to_wait")}});
7886

7987
c_reducescatter({&res.Tensor("input_grad_partial")}, {&res.Tensor("out")});
8088
}

test/distributed_passes/test_fuse_allreduce_split_to_reducescatter_pass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
(%38) = "pd_op.data" () {dtype:(pd_op.DataType)bfloat16,name:"linear_0.tmp_0",persistable:[false],place:(pd_op.Place)Place(gpu:0),shape:(pd_op.IntArray)[4096,1,28672],stop_gradient:[false]} : () -> builtin.tensor<4096x1x28672xbf16>
2323
(%48) = "pd_op.data" () {dtype:(pd_op.DataType)bfloat16,name:"input",persistable:[false],place:(pd_op.Place)Place(gpu:0),shape:(pd_op.IntArray)[4096,1,28672],stop_gradient:[false]} : () -> builtin.tensor<4096x1x28672xbf16>
2424
(%50) = "pd_op.matmul" (%48, %2) {persistable:[false],stop_gradient:[false],transpose_x:false,transpose_y:true} : (builtin.tensor<4096x1x28672xbf16>, builtin.tensor<8192x28672xbf16>) -> builtin.tensor<4096x1x8192xbf16>
25-
(%57) = "pd_op.c_allreduce_sum_" (%50) {persistable:[false],ring_id:(Int32)36,stop_gradient:[false],use_calc_stream:true,use_model_parallel:true} : (builtin.tensor<4096x1x8192xbf16>) -> builtin.tensor<4096x1x8192xbf16>
25+
(%57) = "pd_op.c_allreduce_sum_" (%50) {event_to_record:"event_7989",events_to_wait:[],execution_stream:"auto_parallel_mp",force_record_event:false,persistable:[false],ring_id:(Int32)36,stop_gradient:[false],use_calc_stream:true,use_model_parallel:true} : (builtin.tensor<4096x1x8192xbf16>) -> builtin.tensor<4096x1x8192xbf16>
2626
(%63) = "pd_op.assign" (%57) {persistable:[false],stop_gradient:[false]} : (builtin.tensor<4096x1x8192xbf16>) -> builtin.tensor<4096x1x8192xbf16>
2727
(%64) = "pd_op.full" () {dtype:(pd_op.DataType)int32,place:(pd_op.Place)Place(cpu),shape:(pd_op.IntArray)[1],stop_gradient:[true],value:(Float)0} : () -> builtin.tensor<1xi32>
2828
(%65) = "pd_op.split_with_num" (%63, %64) {num:(Int32)2,persistable:[false],stop_gradient:[false]} : (builtin.tensor<4096x1x8192xbf16>, builtin.tensor<1xi32>) -> vec[builtin.tensor<2048x1x8192xbf16>,builtin.tensor<2048x1x8192xbf16>]

0 commit comments

Comments
 (0)