Skip to content

Commit 0a31620

Browse files
committed
fix_randomness_in_fusion*
1 parent b3cb80e commit 0a31620

File tree

5 files changed

+48
-11
lines changed

5 files changed

+48
-11
lines changed

paddle/cinn/operator_fusion/graph_transformer/search_algorithm.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ struct SearchAlgorithm<NodePattern, GraphMatcher, GraphOperation> {
4242
if (GraphMatcher()(*graph_, iter_node) &&
4343
!visited_nodes.count(iter_node)) {
4444
visited_nodes.insert(iter_node);
45-
VLOG(4) << "Find Matched Node: " << iter_node;
45+
VLOG(4) << "Find Matched Node: " << iter_node->id() << "(" << iter_node
46+
<< ")";
4647
return iter_node;
4748
}
4849
}
@@ -76,7 +77,8 @@ struct SearchAlgorithm<NodePairPattern, GraphMatcher, GraphOperation> {
7677
const auto& pair = std::make_pair(i, j);
7778
if (GraphMatcher()(*graph_, i, j) && !visited_node_pair.count(pair)) {
7879
visited_node_pair.insert(pair);
79-
VLOG(4) << "Find Matched Node Pair: (" << i << ", " << j << ")";
80+
VLOG(4) << "Find Matched Node Pair: (" << i->id() << ", " << j->id()
81+
<< ")";
8082
return pair;
8183
}
8284
}

paddle/cinn/operator_fusion/pattern.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@
2525

2626
namespace cinn::fusion {
2727

28+
enum class PatternType {
29+
Trivial = 0,
30+
Reduce,
31+
ReduceTree,
32+
ReduceTreePlusTrivial,
33+
ItersPermutation,
34+
Horizontal,
35+
Unsupport,
36+
};
37+
2838
struct PatternContent {
2939
explicit PatternContent(pir::Operation* op) : op(op) {}
3040
pir::Operation* op;
@@ -43,6 +53,7 @@ struct TrivialPattern {
4353
std::vector<pir::Operation*> ops() const { return ops_; }
4454
pir::Operation* sink_op() const { return sink_op_; }
4555

56+
static PatternType type() { return PatternType::Trivial; }
4657
static std::string name() { return "Trivial"; }
4758

4859
static std::string UniqueId() {
@@ -67,6 +78,7 @@ struct ReducePattern {
6778
std::vector<pir::Operation*> ops() const { return ops_; }
6879
pir::Operation* GetReduceOp() const { return ops_.back(); }
6980

81+
static PatternType type() { return PatternType::Reduce; }
7082
static std::string name() { return "Reduce"; }
7183

7284
static std::string UniqueId() {
@@ -108,6 +120,7 @@ struct ReduceTreePattern {
108120
return result;
109121
}
110122

123+
static PatternType type() { return PatternType::ReduceTree; }
111124
static std::string name() { return "ReduceTree"; }
112125

113126
static std::string UniqueId() {
@@ -174,6 +187,7 @@ struct ReduceTreePlusTrivialPattern {
174187
}
175188
std::vector<size_t> fake_reduce_iter_idx;
176189

190+
static PatternType type() { return PatternType::ReduceTreePlusTrivial; }
177191
static std::string name() { return "ReduceTreePlusTrivial"; }
178192

179193
static std::string UniqueId() {
@@ -219,6 +233,7 @@ struct ItersPermutationPattern {
219233
std::vector<pir::Operation*> ops_;
220234
std::vector<pir::Operation*> ops() const { return ops_; }
221235

236+
static PatternType type() { return PatternType::ItersPermutation; }
222237
static std::string name() { return "ItersPermutation"; }
223238
static std::string UniqueId() {
224239
static std::atomic<int64_t> counter = 0;
@@ -246,6 +261,7 @@ struct HorizontalFusionPattern {
246261
std::vector<PaddingStmtPattern> padding_patterns_;
247262
inline std::vector<pir::Operation*> ops() const;
248263

264+
static PatternType type() { return PatternType::Horizontal; }
249265
static std::string name() { return "Horizontal"; }
250266

251267
static std::string UniqueId() {
@@ -268,6 +284,7 @@ struct UnsupportPattern {
268284
std::vector<pir::Operation*> ops_;
269285
std::vector<pir::Operation*> ops() const { return ops_; }
270286

287+
static PatternType type() { return PatternType::Unsupport; }
271288
static std::string name() { return "Unsupport"; }
272289

273290
static std::string UniqueId() {
@@ -332,6 +349,10 @@ static std::string StmtPatternDebugStr(const StmtPattern& stmt) {
332349
return ss.str();
333350
}
334351

352+
static PatternType GetPatternType(const StmtPattern& s) {
353+
return std::visit([](const auto& impl) { return impl.type(); }, s);
354+
}
355+
335356
static std::string GetPatternName(const StmtPattern& s) {
336357
return std::visit([](const auto& impl) { return impl.name(); }, s);
337358
}

paddle/cinn/operator_fusion/pattern_graph.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -283,16 +283,17 @@ PatternGraph::PatternGraph(const std::vector<PatternContent>& contents,
283283
}
284284

285285
void PatternGraph::RemoveNode(const PatternNodePtr& node) {
286-
VLOG(4) << "Start Remove: " << node;
287-
if (all_pattern_nodes_.find(node) != all_pattern_nodes_.end()) {
288-
VLOG(4) << "Removed! ";
289-
all_pattern_nodes_.erase(node);
286+
VLOG(4) << "Start Remove: " << node->id() << "(" << node << ")";
287+
for (const auto& n : all_pattern_nodes_) {
288+
if (n->id() == node->id()) {
289+
VLOG(4) << "Removed " << n->id();
290+
all_pattern_nodes_.erase(n);
291+
break;
292+
}
290293
}
291-
292294
for (const PatternNodePtr& upstream : node->upstream()) {
293295
upstream->RemoveNodeFromDownstream(node);
294296
}
295-
296297
for (const PatternNodePtr& downstream : node->downstream()) {
297298
downstream->RemoveNodeFromUpstream(node);
298299
}
@@ -302,10 +303,12 @@ void PatternGraph::AppendNode(const PatternNodePtr& node) {
302303
all_pattern_nodes_.emplace(node);
303304
}
304305

306+
// void PatternGraph
307+
305308
std::string PatternGraph::GraphInfo() const {
306309
std::stringstream ss;
307310
ss << "\n========= GraphInfo ===========";
308-
for (const auto& v : SortByTopoOrder()) {
311+
for (const auto& v : all_pattern_nodes_) {
309312
ss << "\n##############################";
310313
ss << "\n" << v->DebugStr();
311314
ss << " IsOutput: " << IsOutputNodeMatcher()(*this, v);

paddle/cinn/operator_fusion/pattern_graph.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020

2121
namespace cinn::fusion {
2222

23-
using PatternNodePtrSet = std::unordered_set<PatternNodePtr>;
24-
2523
using MergePatternFn =
2624
std::function<StmtPattern(const StmtPattern&, const StmtPattern&)>;
2725
class PatternGraph {

paddle/cinn/operator_fusion/pattern_node.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct PatternNode {
7070
void set_stmt_pattern(const StmtPattern& pattern) { stmt_pattern_ = pattern; }
7171
const std::vector<PatternNodePtr>& upstream() const { return upstream_; }
7272
const std::vector<PatternNodePtr>& downstream() const { return downstream_; }
73+
PatternType type() const { return GetPatternType(stmt_pattern_); }
7374
std::string name() const { return GetPatternName(stmt_pattern_); }
7475
std::string id() const { return GetPatternId(stmt_pattern_); }
7576
void set_return() const { SetReturnInstr(stmt_pattern_); }
@@ -105,4 +106,16 @@ struct PatternNode {
105106
};
106107

107108
using PatternNodePtr = std::shared_ptr<PatternNode>;
109+
110+
struct PatternNodeCompare {
111+
bool operator()(const PatternNodePtr& lhs, const PatternNodePtr& rhs) const {
112+
int lhs_id = std::stoi(
113+
lhs->id().substr(lhs->id().find_last_of('_') + 1, std::string::npos));
114+
int rhs_id = std::stoi(
115+
rhs->id().substr(rhs->id().find_last_of('_') + 1, std::string::npos));
116+
return lhs->type() == rhs->type() ? lhs_id < rhs_id
117+
: lhs->type() < rhs->type();
118+
}
119+
};
120+
using PatternNodePtrSet = std::set<PatternNodePtr, PatternNodeCompare>;
108121
} // namespace cinn::fusion

0 commit comments

Comments
 (0)