@@ -35,7 +35,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase {
35
35
const auto &c_allreduce_sum_ =
36
36
pat.Op (paddle::dialect::CAllreduceSum_Op::name (),
37
37
{{" ring_id" , pat.Attr (" ring_id" )},
38
- {" use_calc_stream" , pat.Attr (" use_calc_stream" )}});
38
+ {" use_calc_stream" , pat.Attr (" use_calc_stream" )},
39
+ {" execution_stream" , pat.Attr (" execution_stream" )},
40
+ {" force_record_event" , pat.Attr (" force_record_event" )},
41
+ {" event_to_record" , pat.Attr (" event_to_record" )},
42
+ {" events_to_wait" , pat.Attr (" events_to_wait" )}});
39
43
const auto &assign = pat.Op (paddle::dialect::AssignOp::name ());
40
44
const auto &full = pat.Op (paddle::dialect::FullOp::name ());
41
45
const auto &split_with_num = pat.Op (paddle::dialect::SplitWithNumOp::name (),
@@ -74,7 +78,11 @@ class FusedAllReduceSplitPattern : public paddle::drr::DrrPatternBase {
74
78
res.Op (paddle::dialect::CReducescatterOp::name (),
75
79
{{" ring_id" , pat.Attr (" ring_id" )},
76
80
{" nranks" , pat.Attr (" num" )},
77
- {" use_calc_stream" , pat.Attr (" use_calc_stream" )}});
81
+ {" use_calc_stream" , pat.Attr (" use_calc_stream" )}},
82
+ {{" execution_stream" , pat.Attr (" execution_stream" )},
83
+ {" force_record_event" , pat.Attr (" force_record_event" )},
84
+ {" event_to_record" , pat.Attr (" event_to_record" )},
85
+ {" events_to_wait" , pat.Attr (" events_to_wait" )}});
78
86
79
87
c_reducescatter ({&res.Tensor (" input_grad_partial" )}, {&res.Tensor (" out" )});
80
88
}
0 commit comments