Skip to content

Commit 1d68cd4

Browse files
authored
[CINN] Add second AnchorFusion after SplitRecompute (#71812)
1 parent 9434777 commit 1d68cd4

File tree

1 file changed

+7
-12
lines changed

1 file changed

+7
-12
lines changed

paddle/cinn/operator_fusion/pattern_graph.cc

+7-12
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,6 @@ std::vector<PatternNodePtr> PatternGraph::ClusterOps() {
4949
VLOG(4) << "[Group Cluster] After ReduceTree_Trivial_Fusion: ";
5050
PrintGraphInfo();
5151

52-
// All -> AnchorPattern
53-
VLOG(4) << "[Group Cluster] Start LiftToAnchorPattern";
54-
LiftToAnchorPattern();
55-
VLOG(4) << "[Group Cluster] After LiftToAnchorPattern: ";
56-
PrintGraphInfo();
57-
5852
// AnchorPattern x AnchorPattern Fusion
5953
VLOG(4) << "[Group Cluster] Start AnchorFusion";
6054
AnchorFusion();
@@ -67,6 +61,12 @@ std::vector<PatternNodePtr> PatternGraph::ClusterOps() {
6761
VLOG(4) << "[Group Cluster] After SplitRecomputePattern: ";
6862
PrintGraphInfo();
6963

64+
// Second AnchorFusion after split recompute
65+
VLOG(4) << "[Group Cluster] Start Second AnchorFusion";
66+
AnchorFusion();
67+
VLOG(4) << "[Group Cluster] After AnchorFusion: ";
68+
PrintGraphInfo();
69+
7070
// Horizontal fusion.
7171
VLOG(4) << "[Group Cluster] Start HorizontalFusion";
7272
HorizontalFusion();
@@ -192,16 +192,14 @@ void PatternGraph::ReduceTree_Trivial_Fusion() {
192192
MergeReduceTreeAndTrivialOperation>(this);
193193
}
194194

195-
void PatternGraph::LiftToAnchorPattern() {
195+
void PatternGraph::AnchorFusion() {
196196
GraphTransformer<NodePattern,
197197
Or<StmtPatternGraphMatcher<TrivialPattern>,
198198
StmtPatternGraphMatcher<ReduceTreePlusTrivialPattern>,
199199
StmtPatternGraphMatcher<ReducePattern>,
200200
StmtPatternGraphMatcher<ReduceTreePattern>>,
201201
LiftToAnchorPatternOperation>(this);
202-
}
203202

204-
void PatternGraph::AnchorFusion() {
205203
GraphTransformer<ReverseTopoNodePairPattern,
206204
And<CanAnchorFusionMatcher, InputOutputMaximumConstrain>,
207205
AnchorFusionOperation>(this);
@@ -239,9 +237,6 @@ void PatternGraph::ItersPermutationFusion() {
239237
void PatternGraph::SplitRecomputePattern() {
240238
GraphTransformer<NodePattern, RecomputeNodeMatcher, SplitRecomputeOperation>(
241239
this);
242-
GraphTransformer<NodePattern,
243-
StmtPatternGraphMatcher<TrivialPattern>,
244-
LiftToAnchorPatternOperation>(this);
245240
}
246241

247242
PatternGraph::PatternGraph(const std::vector<PatternContent>& contents,

0 commit comments

Comments
 (0)