|
15 | 15 | #include "paddle/fluid/framework/ir/graph_traits.h"
|
16 | 16 |
|
17 | 17 | #include <set>
|
| 18 | +#include <utility> |
18 | 19 | #include <vector>
|
19 | 20 |
|
20 | 21 | namespace paddle {
|
@@ -79,29 +80,23 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
|
79 | 80 | PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
|
80 | 81 | }
|
81 | 82 |
|
82 |
| - std::unordered_set<Node *> visited; |
83 | 83 | std::set<Node *> to_visit{source.begin(), source.end()};
|
84 |
| - |
85 |
| - std::vector<Node *> inlink_visited; |
| 84 | + std::vector<Node *> inlink_sorted; |
86 | 85 | while (!to_visit.empty()) {
|
87 | 86 | std::vector<Node *> queue(to_visit.begin(), to_visit.end());
|
88 | 87 | for (auto *p : queue) {
|
89 |
| - inlink_visited.clear(); |
90 |
| - |
91 |
| - std::copy_if(p->inputs.begin(), p->inputs.end(), |
92 |
| - std::back_inserter(inlink_visited), |
93 |
| - [&](Node *x) -> bool { return visited.count(x) != 0; }); |
94 |
| - |
95 |
| - if (inlink_visited.size() == p->inputs.size()) { |
96 |
| - sorted_.push_back(p); |
97 |
| - for (auto *_ : p->outputs) { |
98 |
| - if (!visited.count(_)) { |
99 |
| - to_visit.insert(_); |
100 |
| - } |
| 88 | + to_visit.erase(p); |
| 89 | + sorted_.push_back(p); |
| 90 | + for (auto *out : p->outputs) { |
| 91 | + inlink_sorted.clear(); |
| 92 | + std::copy_if(out->inputs.begin(), out->inputs.end(), |
| 93 | + std::back_inserter(inlink_sorted), [&](Node *x) -> bool { |
| 94 | + return std::find(sorted_.begin(), sorted_.end(), x) != |
| 95 | + sorted_.end(); |
| 96 | + }); |
| 97 | + if (inlink_sorted.size() == out->inputs.size()) { |
| 98 | + to_visit.insert(out); |
101 | 99 | }
|
102 |
| - |
103 |
| - to_visit.erase(p); |
104 |
| - visited.insert(p); |
105 | 100 | }
|
106 | 101 | }
|
107 | 102 | }
|
|
0 commit comments