[CINN] Optimize compute inline conditions #72281

Open · wants to merge 1 commit into base: develop
166 changes: 77 additions & 89 deletions paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc
@@ -13,17 +13,12 @@
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"

#include <string>
#include <vector>

#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"

namespace cinn {
namespace ir {
namespace {

/**
* The types of the AutoInline
@@ -51,6 +46,12 @@ class ComputeInlineTactic final : public ScheduleTactic {
bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr,
ir::IRSchedule* ir_sch) const;

// Check whether all consumers of the block have their load indices aligned
// with the block (i.e. no cross-thread access).
bool CheckAllConsumersAligned(const Expr& block,
const std::vector<ir::Expr>& consumers,
ir::IRSchedule* ir_sch) const;

std::unordered_set<std::string> output_names_;
cinn::common::Target target_;
};
@@ -60,108 +61,94 @@ void ComputeInlineTactic::Init(ScheduleContext* context) {
target_ = context->target;
}

bool ComputeInlineTactic::CanInlineIntoConsumer(
const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const {
const ir::ScheduleBlockRealize* sche_block_realize =
sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
const ir::ScheduleBlock* sche_block =
sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
ir::Expr compute_body = sche_block->body;
ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr);

// Check the schedule block to be inlined is not a reduce tensor.
for (const ir::Var& iter_var : sche_block->iter_vars) {
if (iter_var->is_reduce_axis) {
return false;
}
}
std::vector<ir::Expr> find_store = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); });
if (find_store.size() != 1UL) {
return false;
}

ir::Expr tensor_expr = (*find_store.begin()).As<ir::Store>()->tensor;
ir::Tensor tensor = tensor_expr.as_tensor_ref();
if (tensor->is_reduce_tensor()) {
return false;
}

// LoweredFunc output can be tensor name or tensor buffer name
if (output_names_.find(tensor->name) != output_names_.end() ||
output_names_.find(tensor->buffer->name) != output_names_.end()) {
return false;
int64_t GetSerialLoopExtent(const std::vector<ir::Expr>& loops) {
int64_t extent = 1;
for (auto& loop : loops) {
auto* node = loop.As<ir::For>();
if (node->is_binded()) continue;
if (!node->extent.is_constant()) return -1;
extent *= node->extent.as_int64();
}
return extent;
}

// the xxx_reduce_init block cannot be inlined.
if (ir::IsReduceInitTensorName(tensor->name)) {
return false;
bool ComputeInlineTactic::CheckAllConsumersAligned(
const Expr& block,
const std::vector<ir::Expr>& consumers,
ir::IRSchedule* ir_sch) const {
ir::Expr store = ir::analyzer::GetStoreOfSBlock(block);
auto* tensor = store.As<ir::Store>()->tensor.as_tensor();
std::vector<ir::Expr> loops = ir_sch->GetLoops(block);

std::vector<ir::Expr> store_indices;
for (ir::Expr index : store.As<ir::Store>()->indices) {
index = ir::analyzer::ExpandIterVar(index, block);
index = ir::analyzer::CanonicalizeLoopVar(index, loops);
store_indices.push_back(index);
}

// Skip external calls
std::vector<ir::Expr> consumers =
ir::GetConsumers(sche_block_realize_expr, root);
for (const ir::Expr& consumer : consumers) {
std::vector<ir::Expr> find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
consumer.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body,
[&](const ir::Expr* x) {
return x->As<ir::Load>() &&
x->As<ir::Load>()->tensor.as_tensor_ref()->name ==
tensor->name;
});
if (find_load.empty()) {
const auto CheckLoadsAligned = [&](const ir::Expr& expr) {
bool aligned = true;
ir::ir_utils::CollectIRNodesInOrder(expr, [&](const ir::Expr* x) {
auto* node = x->As<ir::Load>();
if (node && node->tensor.as_tensor()->name == tensor->name) {
if (node->indices != store_indices) {
aligned = false;
}
}
return false;
}
});
return aligned;
};

for (auto& consumer_block : consumers) {
ir::Expr consumer_store = ir::analyzer::GetStoreOfSBlock(consumer_block);
std::vector<ir::Expr> consumer_loops = ir_sch->GetLoops(consumer_block);
ir::Expr value = consumer_store.As<ir::Store>()->value;
value = ir::analyzer::ExpandIterVar(value, consumer_block);
value = ir::analyzer::CanonicalizeLoopVar(value, consumer_loops);
if (!CheckLoadsAligned(value)) return false;
}

// write_buffers.size() = 1 and read_buffers is empty, means const
// we can inline to consumer
if (sche_block->read_buffers.empty()) {
return true;
}
return true;
}

// Check this schedule block is the only writer of the tensor.
find_store =
ir::ir_utils::CollectIRNodesWithoutTensor(root, [&](const Expr* x) {
return x->As<ir::Store>() &&
(x->As<ir::Store>()->tensor).as_tensor_ref()->name ==
tensor->name;
});
if (find_store.size() != 1UL) {
return false;
}
// Check there is no overlap between the buffers the schedule block reads and
// writes.
std::vector<ir::Expr> find_load = ir::ir_utils::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) {
return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr;
});
if (!find_load.empty()) {
bool ComputeInlineTactic::CanInlineIntoConsumer(const Expr& block,
ir::IRSchedule* ir_sch) const {
ir::Expr root = ir_sch->GetRootBlock(block);
ir::Expr store = ir::analyzer::GetStoreOfSBlock(block);
auto* tensor = store.As<ir::Store>()->tensor.as_tensor();

// 1. It is neither a reduce block nor a reduce_init block.
if (ir::analyzer::IsReductionSBlock(block) || tensor->is_reduce_tensor() ||
ir::IsReduceInitTensorName(tensor->name)) {
return false;
}

ir::Expr store = *(find_store.begin());

ir::ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(),
store);
if (!inliner.BodyPatternAllowInline()) {
// 2. It is not an output node.
if (output_names_.count(tensor->name) > 0) {
return false;
}

ir::LeafBlockRemovalPlan remove_plan(
sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt);
remove_plan(&root);
if (!inliner.src_stmt.defined() || !inliner.tgt_stmt.defined()) {
// 3. For a block with multiple consumers, we prefer to buffer the intermediate
// result instead of inlining it, in order to avoid redundant computation,
// provided the following conditions are also satisfied:
// 1) The serial loop extent is <= 8; otherwise the intermediate result is too
// large to buffer.
// 2) Its consumers are all aligned with it; otherwise buffering would incur
// cross-thread access, which is not possible with a local buffer.
std::vector<ir::Expr> consumers = ir::GetConsumers(block, root);
int64_t loop_extent = GetSerialLoopExtent(ir_sch->GetLoops(block));
bool is_small_loop = loop_extent <= 8 && loop_extent != -1;
if (consumers.size() > 1 && is_small_loop &&
CheckAllConsumersAligned(block, consumers, ir_sch)) {
return false;
}

VLOG(6) << "Found store Expr " << store << ", which CanInlineIntoConsumer";
return true;
}

namespace {
bool ContainsNodeType(ir::Expr expr,
const std::unordered_set<ir::IrNodeTy>& node_types) {
std::vector<ir::Expr> collection =
@@ -211,7 +198,6 @@ void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) {
return false;
});
}
} // namespace

AutoInlineType ComputeInlineTactic::AnalyzeInlineType(
const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const {
@@ -293,6 +279,8 @@ void ComputeInlineTactic::Apply(ir::IRSchedule* sch,
<< sch->GetModule().GetExprs().front();
}

} // namespace

std::unique_ptr<ScheduleTactic> CreateComputeInlineTactic() {
return std::make_unique<ComputeInlineTactic>();
}
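Taken together, the new helpers implement a simple decision rule: compute the product of the non-thread-bound loop extents; if the block has more than one consumer, that serial extent is a known constant of at most 8, and every consumer loads exactly the indices the block stores, skip inlining so the intermediate value can be kept in a small local buffer and computed only once. The snippet below is a minimal standalone sketch of that rule, not CINN code: LoopInfo, SerialLoopExtent, and PreferBufferOverInline are illustrative stand-ins for the IR types and helpers used in the diff, and the <= 8 threshold is taken directly from it.

// Standalone sketch (not CINN code) of the "buffer instead of inline" rule.
#include <cstdint>
#include <iostream>
#include <vector>

struct LoopInfo {
  bool bound_to_thread;   // plays the role of For::is_binded() in the diff
  bool constant_extent;   // whether the extent is a compile-time constant
  int64_t extent;         // valid only when constant_extent is true
};

// Mirrors GetSerialLoopExtent: product of extents of loops not bound to a
// GPU thread/block axis; -1 means "unknown" (some extent is not constant).
int64_t SerialLoopExtent(const std::vector<LoopInfo>& loops) {
  int64_t extent = 1;
  for (const auto& loop : loops) {
    if (loop.bound_to_thread) continue;
    if (!loop.constant_extent) return -1;
    extent *= loop.extent;
  }
  return extent;
}

// Mirrors the new condition in CanInlineIntoConsumer: with more than one
// consumer, a small known serial extent (<= 8), and all consumer loads
// aligned to the producer's store indices, prefer a local buffer over
// inlining so the value is computed once rather than once per consumer.
bool PreferBufferOverInline(int num_consumers,
                            int64_t serial_extent,
                            bool all_consumers_aligned) {
  const bool small_loop = serial_extent != -1 && serial_extent <= 8;
  return num_consumers > 1 && small_loop && all_consumers_aligned;
}

int main() {
  // Two serial loops of extent 4 and 2 (product 8), two aligned consumers:
  // the intermediate is buffered rather than inlined.
  std::vector<LoopInfo> loops = {{false, true, 4}, {false, true, 2}};
  std::cout << std::boolalpha
            << PreferBufferOverInline(2, SerialLoopExtent(loops), true)
            << "\n";  // true
  // A single consumer (or a non-constant extent) falls back to inlining.
  std::cout << PreferBufferOverInline(1, SerialLoopExtent(loops), true)
            << "\n";  // false
  return 0;
}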
4 changes: 3 additions & 1 deletion test/dygraph_to_static/test_mnist_pure_fp16.py
@@ -16,10 +16,12 @@
from time import time

import numpy as np
from test_mnist import MNIST, SEED, TestMNIST
from test_mnist import MNIST, TestMNIST

import paddle

SEED = 2025

if paddle.base.is_compiled_with_cuda():
paddle.base.set_flags({'FLAGS_cudnn_deterministic': True})
