Skip to content

Commit 2f2ef64

Browse files
authored
[CINN] Merge non-related subgraphs in BuildCinnPass (PaddlePaddle#71824)
1 parent 6affc71 commit 2f2ef64

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

paddle/fluid/pir/transforms/sub_graph_detector.cc

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,18 @@ class SubgraphDetector {
391391
return op2subgraph_.at(op);
392392
}
393393

394+
std::vector<SubGraphPtr> GetSubgraphList() {
395+
std::unordered_set<SubGraphPtr> subgraph_set;
396+
std::vector<SubGraphPtr> subgraph_list;
397+
for (const auto& op : sort_ops_) {
398+
SubGraphPtr subgraph = GetOpSubgraph(op);
399+
if (subgraph_set.count(subgraph)) continue;
400+
subgraph_set.insert(subgraph);
401+
subgraph_list.push_back(subgraph);
402+
}
403+
return subgraph_list;
404+
}
405+
394406
std::unordered_map<pir::Operation*, int> op2index_;
395407
std::vector<pir::Operation*> sort_ops_;
396408
std::unordered_map<pir::Operation*, SubGraphPtr> op2subgraph_;
@@ -543,18 +555,29 @@ void SubgraphDetector::SubgraphFusion() {
543555
}
544556
}
545557
}
558+
559+
VLOG(4) << "Merge non-related subgraphs";
560+
auto subgraph_list = GetSubgraphList();
561+
for (size_t i = 0; i < subgraph_list.size(); ++i) {
562+
auto lhs = subgraph_list[i];
563+
if (!lhs->substitute) continue;
564+
for (size_t j = i + 1; j < subgraph_list.size();) {
565+
auto rhs = subgraph_list[j];
566+
if (lhs == rhs || !rhs->substitute || HasRoute(lhs, rhs) ||
567+
HasRoute(rhs, lhs)) {
568+
++j;
569+
continue;
570+
}
571+
MergeSource2Target(rhs, lhs);
572+
subgraph_list.erase(subgraph_list.begin() + j);
573+
VLOG(6) << "Merged subgraph: " << lhs->DebugStr();
574+
}
575+
}
546576
}
547577

548578
std::vector<GroupOpsVec> SubgraphDetector::BuildGroups() {
549579
// 1. Get subgraph list in topo order
550-
std::unordered_set<SubGraphPtr> subgraph_set;
551-
std::vector<SubGraphPtr> subgraph_list;
552-
for (const auto& op : sort_ops_) {
553-
SubGraphPtr subgraph = GetOpSubgraph(op);
554-
if (subgraph_set.count(subgraph)) continue;
555-
subgraph_set.insert(subgraph);
556-
subgraph_list.push_back(subgraph);
557-
}
580+
auto subgraph_list = GetSubgraphList();
558581
std::reverse(subgraph_list.begin(), subgraph_list.end());
559582
VLOG(6) << "Subgraphs after building groups: ";
560583
for (const auto& subgraph : subgraph_list) {

0 commit comments

Comments
 (0)