@@ -391,6 +391,18 @@ class SubgraphDetector {
391
391
return op2subgraph_.at (op);
392
392
}
393
393
394
+ std::vector<SubGraphPtr> GetSubgraphList () {
395
+ std::unordered_set<SubGraphPtr> subgraph_set;
396
+ std::vector<SubGraphPtr> subgraph_list;
397
+ for (const auto & op : sort_ops_) {
398
+ SubGraphPtr subgraph = GetOpSubgraph (op);
399
+ if (subgraph_set.count (subgraph)) continue ;
400
+ subgraph_set.insert (subgraph);
401
+ subgraph_list.push_back (subgraph);
402
+ }
403
+ return subgraph_list;
404
+ }
405
+
394
406
std::unordered_map<pir::Operation*, int > op2index_;
395
407
std::vector<pir::Operation*> sort_ops_;
396
408
std::unordered_map<pir::Operation*, SubGraphPtr> op2subgraph_;
@@ -543,18 +555,29 @@ void SubgraphDetector::SubgraphFusion() {
543
555
}
544
556
}
545
557
}
558
+
559
+ VLOG (4 ) << " Merge non-related subgraphs" ;
560
+ auto subgraph_list = GetSubgraphList ();
561
+ for (size_t i = 0 ; i < subgraph_list.size (); ++i) {
562
+ auto lhs = subgraph_list[i];
563
+ if (!lhs->substitute ) continue ;
564
+ for (size_t j = i + 1 ; j < subgraph_list.size ();) {
565
+ auto rhs = subgraph_list[j];
566
+ if (lhs == rhs || !rhs->substitute || HasRoute (lhs, rhs) ||
567
+ HasRoute (rhs, lhs)) {
568
+ ++j;
569
+ continue ;
570
+ }
571
+ MergeSource2Target (rhs, lhs);
572
+ subgraph_list.erase (subgraph_list.begin () + j);
573
+ VLOG (6 ) << " Merged subgraph: " << lhs->DebugStr ();
574
+ }
575
+ }
546
576
}
547
577
548
578
std::vector<GroupOpsVec> SubgraphDetector::BuildGroups () {
549
579
// 1. Get subgraph list in topo order
550
- std::unordered_set<SubGraphPtr> subgraph_set;
551
- std::vector<SubGraphPtr> subgraph_list;
552
- for (const auto & op : sort_ops_) {
553
- SubGraphPtr subgraph = GetOpSubgraph (op);
554
- if (subgraph_set.count (subgraph)) continue ;
555
- subgraph_set.insert (subgraph);
556
- subgraph_list.push_back (subgraph);
557
- }
580
+ auto subgraph_list = GetSubgraphList ();
558
581
std::reverse (subgraph_list.begin (), subgraph_list.end ());
559
582
VLOG (6 ) << " Subgraphs after building groups: " ;
560
583
for (const auto & subgraph : subgraph_list) {
0 commit comments