Skip to content

Commit a7d0d88

Browse files
authored
CHERRY_PICK 20720: fix ts_sort's bug, test=develop (#20726)
test=release/1.6
1 parent 92f4a52 commit a7d0d88

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

paddle/fluid/framework/ir/graph_traits.cc

+13-18
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "paddle/fluid/framework/ir/graph_traits.h"
1616

1717
#include <set>
18+
#include <utility>
1819
#include <vector>
1920

2021
namespace paddle {
@@ -79,29 +80,23 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
7980
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
8081
}
8182

82-
std::unordered_set<Node *> visited;
8383
std::set<Node *> to_visit{source.begin(), source.end()};
84-
85-
std::vector<Node *> inlink_visited;
84+
std::vector<Node *> inlink_sorted;
8685
while (!to_visit.empty()) {
8786
std::vector<Node *> queue(to_visit.begin(), to_visit.end());
8887
for (auto *p : queue) {
89-
inlink_visited.clear();
90-
91-
std::copy_if(p->inputs.begin(), p->inputs.end(),
92-
std::back_inserter(inlink_visited),
93-
[&](Node *x) -> bool { return visited.count(x) != 0; });
94-
95-
if (inlink_visited.size() == p->inputs.size()) {
96-
sorted_.push_back(p);
97-
for (auto *_ : p->outputs) {
98-
if (!visited.count(_)) {
99-
to_visit.insert(_);
100-
}
88+
to_visit.erase(p);
89+
sorted_.push_back(p);
90+
for (auto *out : p->outputs) {
91+
inlink_sorted.clear();
92+
std::copy_if(out->inputs.begin(), out->inputs.end(),
93+
std::back_inserter(inlink_sorted), [&](Node *x) -> bool {
94+
return std::find(sorted_.begin(), sorted_.end(), x) !=
95+
sorted_.end();
96+
});
97+
if (inlink_sorted.size() == out->inputs.size()) {
98+
to_visit.insert(out);
10199
}
102-
103-
to_visit.erase(p);
104-
visited.insert(p);
105100
}
106101
}
107102
}

0 commit comments

Comments
 (0)