Skip to content

[CINN] Fix inplace order change in build_cinn_pass #72426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 112 additions & 56 deletions paddle/fluid/pir/transforms/sub_graph_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,16 +253,23 @@ struct SubGraph : public std::enable_shared_from_this<SubGraph> {
return std::string("Subgraph_") + std::to_string(id);
}

struct compare {
struct CompareById {
bool operator()(const SubGraphPtr& lhs, const SubGraphPtr& rhs) const {
// sort by reverse order of topo id
return lhs->id > rhs->id;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

id和topo_index有什么区别?

Copy link
Member Author

@huangjiyi huangjiyi Apr 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

id 只是第一次排序给每个子图加的一个唯一标识(不保证顺序),topo_index 则是后面每一次子图合并后都能保证顺序正确,之前按 Id 进行排序只是为了避免随机性,现在拓扑序重排则需要用 topo_index 排序了

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那id还有用吗?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前按 id 排序主要还是方便调试,因为 topo_index 是实时变化的,有些地方换成 topo_index 会比较乱

}
};

struct CompareByTopo {
bool operator()(const SubGraphPtr& lhs, const SubGraphPtr& rhs) const {
// sort by topo index
return lhs->topo_index > rhs->topo_index;
}
};

std::vector<pir::Operation*> ops;
std::set<SubGraphPtr, compare> upstreams;
std::set<SubGraphPtr, compare> downstreams;
std::set<SubGraphPtr, CompareById> upstreams;
std::set<SubGraphPtr, CompareById> downstreams;

bool substitute; // whether this subgraph can be merged
int topo_index;
Expand Down Expand Up @@ -357,8 +364,8 @@ bool CanFuseUpstream2Downstream(const SubGraphPtr& upstream,

std::optional<std::string> DetectCirclesInSubgraphs(
const std::vector<SubGraphPtr>& subgraph_list) {
std::set<SubGraphPtr, SubGraph::compare> subgraph_set(subgraph_list.begin(),
subgraph_list.end());
std::set<SubGraphPtr, SubGraph::CompareById> subgraph_set(
subgraph_list.begin(), subgraph_list.end());
std::unordered_map<SubGraphPtr, size_t> in_degree;
std::unordered_map<SubGraphPtr, size_t> out_degree;
for (const auto& subgraph : subgraph_set) {
Expand Down Expand Up @@ -411,7 +418,7 @@ class SubgraphDetector {
private:
void ReorderIndexOfSubgraphs();

void MergeSource2Target(const SubGraphPtr& source, const SubGraphPtr& target);
bool MergeSource2Target(const SubGraphPtr& source, const SubGraphPtr& target);

void FallbackSubGraphFusion(const SubGraphPtr& source,
const SubGraphPtr& target,
Expand Down Expand Up @@ -470,7 +477,10 @@ void SubgraphDetector::ReorderIndexOfSubgraphs() {
// After merging subgraphs with direct relation, brother subgraphs with
// indirect relation may not be detected by index order. So we need to
// reorder the index of subgraphs.
std::queue<SubGraphPtr> queue;
using SubGraphQueue = std::priority_queue<SubGraphPtr,
std::vector<SubGraphPtr>,
SubGraph::CompareByTopo>;
SubGraphQueue queue; // min heap
std::unordered_map<SubGraphPtr, int> in_degree;
for (auto it = sort_ops_.rbegin(); it != sort_ops_.rend(); ++it) {
auto subgraph = GetOpSubgraph(*it);
Expand All @@ -481,7 +491,7 @@ void SubgraphDetector::ReorderIndexOfSubgraphs() {
subgraph_index_set_.clear();
int index = 0;
while (!queue.empty()) {
auto subgraph = queue.front();
auto subgraph = queue.top();
queue.pop();
subgraph->topo_index = index++;
subgraph_index_set_.insert(subgraph->topo_index);
Expand All @@ -492,49 +502,62 @@ void SubgraphDetector::ReorderIndexOfSubgraphs() {
}
}

void SubgraphDetector::MergeSource2Target(const SubGraphPtr& source,
bool SubgraphDetector::MergeSource2Target(const SubGraphPtr& source,
const SubGraphPtr& target) {
VLOG(6) << "Merge source: " << source->DebugStr();
VLOG(6) << "Merge target: " << target->DebugStr();
SubGraph source_back = *source;
SubGraph target_back = *target;
target->Merge(source);
for (const auto& op : source->ops) {
op2subgraph_[op] = target;
}
int max_index = std::max(source->topo_index, target->topo_index);
int min_index = std::min(source->topo_index, target->topo_index);
auto merged = target;
// Check if merged subgraph and its related subgraphs
// satisfy the topological order condition.
int upstream_max_index = -1, downstream_min_index = INT_MAX;
for (const auto& upstream : merged->upstreams) {
upstream_max_index = std::max(upstream->topo_index, upstream_max_index);
}
for (const auto& downstream : merged->downstreams) {
downstream_min_index =
std::min(downstream->topo_index, downstream_min_index);
}
// 1. If satisfy the topological order after merging, just set max_index
VLOG(6) << "Check if satisfy the topological order after merging";
if (min_index > upstream_max_index && max_index < downstream_min_index) {
merged->topo_index = max_index;
subgraph_index_set_.erase(min_index);
return;
}
// 2. If not satisfy the order, find a index between upstream_max_index
// and downstream_min_index while not in subgraph_index_set_.
VLOG(6) << "Try to find a valid index not in subgraph_index_set_";
for (int i = upstream_max_index + 1; i < downstream_min_index; ++i) {
if (!subgraph_index_set_.count(i)) {
merged->topo_index = i;
const auto& update_topo_index = [&]() -> void {
int max_index = std::max(source->topo_index, target->topo_index);
int min_index = std::min(source->topo_index, target->topo_index);
auto merged = target;
// Check if merged subgraph and its related subgraphs
// satisfy the topological order condition.
int upstream_max_index = -1, downstream_min_index = INT_MAX;
for (const auto& upstream : merged->upstreams) {
upstream_max_index = std::max(upstream->topo_index, upstream_max_index);
}
for (const auto& downstream : merged->downstreams) {
downstream_min_index =
std::min(downstream->topo_index, downstream_min_index);
}
// 1. If satisfy the topological order after merging, just set max_index
VLOG(6) << "Check if satisfy the topological order after merging";
if (min_index > upstream_max_index && max_index < downstream_min_index) {
merged->topo_index = max_index;
subgraph_index_set_.erase(min_index);
subgraph_index_set_.erase(max_index);
subgraph_index_set_.insert(i);
return;
}
// 2. If not satisfy the order, find a index between upstream_max_index
// and downstream_min_index while not in subgraph_index_set_.
VLOG(6) << "Try to find a valid index not in subgraph_index_set_";
for (int i = upstream_max_index + 1; i < downstream_min_index; ++i) {
if (!subgraph_index_set_.count(i)) {
merged->topo_index = i;
subgraph_index_set_.erase(min_index);
subgraph_index_set_.erase(max_index);
subgraph_index_set_.insert(i);
return;
}
}
// 3. If can not find a valid index, reorder topo index of all subgraphs.
VLOG(6) << "Reorder topo index of all subgraphs";
merged->topo_index = max_index;
ReorderIndexOfSubgraphs();
};
update_topo_index();
if (CheckSideEffectOpsOrder()) {
VLOG(6) << "Merged subgraph: " << target->DebugStr();
return true;
} else {
FallbackSubGraphFusion(source, target, source_back, target_back);
return false;
}
// 3. If can not find a valid index, reorder topo index of all subgraphs.
VLOG(6) << "Reorder topo index of all subgraphs";
ReorderIndexOfSubgraphs();
}

void SubgraphDetector::FallbackSubGraphFusion(const SubGraphPtr& source,
Expand Down Expand Up @@ -612,25 +635,63 @@ void SubgraphDetector::InitInplaceOpsOrder(
}
inplace_values_sets.push_back(inplace_input_values);
}
std::set<pir::Value> shared_inplace_values;
std::set<pir::Value> shared_inplace_values_set;
for (size_t i = 0; i < inplace_values_sets.size(); ++i) {
for (size_t j = i + 1; j < inplace_values_sets.size(); ++j) {
std::set_intersection(
inplace_values_sets[i].begin(),
inplace_values_sets[i].end(),
inplace_values_sets[j].begin(),
inplace_values_sets[j].end(),
std::inserter(shared_inplace_values, shared_inplace_values.begin()));
std::set_intersection(inplace_values_sets[i].begin(),
inplace_values_sets[i].end(),
inplace_values_sets[j].begin(),
inplace_values_sets[j].end(),
std::inserter(shared_inplace_values_set,
shared_inplace_values_set.begin()));
}
}
for (const auto& shared_value : shared_inplace_values) {
inplace_ops_order_.emplace_back();
std::vector<std::vector<pir::Operation*>> inplace_ops_order;
std::vector<pir::Value> shared_inplace_values;
for (const auto& shared_value : shared_inplace_values_set) {
inplace_ops_order.emplace_back();
shared_inplace_values.push_back(shared_value);
for (size_t i = 0; i < inplace_values_sets.size(); ++i) {
if (inplace_values_sets[i].count(shared_value)) {
inplace_ops_order_.back().push_back(inplace_ops[i]);
inplace_ops_order.back().push_back(inplace_ops[i]);
}
}
}
// If a value is inplaced by multiple ops, the order of ops which use this
// value after different inplace op also needs to be considered together.
for (size_t i = 0; i < inplace_ops_order.size(); ++i) {
auto only_inplace_ops = inplace_ops_order[i];
auto inplace_root_value = shared_inplace_values[i];
std::unordered_set<pir::Operation*> ordered_ops_set(
only_inplace_ops.begin(), only_inplace_ops.end());
for (const auto& inplace_op : only_inplace_ops) {
pir::Value output_inplace_value;
auto output_input_values = GetInplaceValues(inplace_op);
for (const auto& [output_value, _unused] : output_input_values) {
if (get_inplace_root_value(output_value) == inplace_root_value) {
output_inplace_value = output_value;
break;
}
}
if (output_inplace_value.use_empty()) continue;
for (auto use_iter = output_inplace_value.use_begin();
use_iter != output_inplace_value.use_end();
++use_iter) {
auto user_op = use_iter.owner();
if (ordered_ops_set.count(user_op)) continue;
ordered_ops_set.insert(user_op);
}
}
// Sort by origin order in blocks
std::vector<pir::Operation*> ordered_ops(ordered_ops_set.begin(),
ordered_ops_set.end());
std::sort(ordered_ops.begin(),
ordered_ops.end(),
[this](const auto& lhs, const auto& rhs) {
return this->op2index_.at(lhs) < this->op2index_.at(rhs);
});
this->inplace_ops_order_.push_back(ordered_ops);
}
}

SubgraphDetector::SubgraphDetector(pir::Block* block,
Expand Down Expand Up @@ -698,7 +759,6 @@ void SubgraphDetector::SubgraphFusion() {
if (upstream == downstream || !upstream->substitute) continue;
if (CanFuseUpstream2Downstream(upstream, downstream)) {
MergeSource2Target(upstream, downstream);
VLOG(6) << "Merged subgraph: " << downstream->DebugStr();
}
}
}
Expand All @@ -714,7 +774,6 @@ void SubgraphDetector::SubgraphFusion() {
if (brother == subgraph || !brother->substitute) continue;
if (!HasRoute(subgraph, brother) && !HasRoute(brother, subgraph)) {
MergeSource2Target(brother, subgraph);
VLOG(6) << "Merged subgraph: " << subgraph->DebugStr();
}
}
}
Expand All @@ -734,12 +793,9 @@ void SubgraphDetector::SubgraphFusion() {
}
SubGraph lhs_back = *lhs;
SubGraph rhs_back = *rhs;
MergeSource2Target(rhs, lhs);
if (CheckSideEffectOpsOrder()) {
if (MergeSource2Target(rhs, lhs)) {
subgraph_list.erase(subgraph_list.begin() + j);
VLOG(6) << "Merged subgraph: " << lhs->DebugStr();
} else {
FallbackSubGraphFusion(rhs, lhs, rhs_back, lhs_back);
++j;
}
}
Expand Down
4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,8 @@ def test_loop_change_value_to_int(self):
def loop_update_iter_inner_normal(x):
y = x + 1
out = 0
for item in y:
y[0] = paddle.full([], 1, dtype="int64")
for i in range(len(y)):
y[0] = paddle.full([], 1, dtype="int64") + i
out += y
return out

Expand Down
Loading