Skip to content

【CINN】Add boundary simplify #72000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/common/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
optim::SimplifyCast(&indice_cast);
res = RampRelatedAdd(RampRelatedMul(res, shape[i]), indice_cast);
if (res.is_index()) {
res = res.as_index().Normalize(ir::IndexExpr::OptLevel::Level2);
res = res.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2);
} else {
VLOG(8) << "**** expr is not index ****: " << res;
}
Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/ir/ir_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -526,9 +526,15 @@ IndexExpr Simplify(const IndexExpr &expr, IndexExpr::OptLevel level) {
auto rhs = Simplify(expr.operand(1), level);
auto res =
optim::ConstructIndexExprByNodeType(expr.node_type(), lhs, rhs);
if (level == IndexExpr::OptLevel::Level2 &&
expr.node_type() == ir::IrNodeTy::Add)
if (level >= IndexExpr::OptLevel::kLevel2 &&
expr.node_type() == ir::IrNodeTy::Add) {
res = common::MergeMulMod(res);
}
if (level == IndexExpr::OptLevel::kLevel3 &&
(expr.node_type() == ir::IrNodeTy::Div ||
expr.node_type() == ir::IrNodeTy::Mod)) {
res = optim::BoundSimplify(res);
}
return res;
}
default:
Expand Down
12 changes: 8 additions & 4 deletions paddle/cinn/ir/ir_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -554,14 +554,18 @@ struct IndexExpr : public IrNodeRef {
* Level2: Each factor in the expression is attempted to be simplified with
* the other factors
* e.g. x / 2 * 2 + y / 2 + 5 + x % 2 ==> y / 2 + x + 5
* Level3: Simplify with boundary.
* e.g. x % S0 ==> x if x < S0
* x / S0 ==> 0 if x < S0
*
* Note: Because IndexExpr is generated in order, Short operand is at the
* end of the expression, so Level1 is usually used.
*/
enum class OptLevel {
Level0 = 0, // TODO(liujinnan): Only constant folding is performed
Level1 = 1, // Constant folding and sequential simplification are performed
Level2 = 2 // Top level, simplify
kLevel0 = 0, // TODO(liujinnan): Only constant folding is performed
kLevel1 = 1,
kLevel2 = 2,
kLevel3 = 3 // Top level, simplify
Comment on lines +565 to +568
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加些注释说明

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image 注释在枚举类上方

};

enum class IndexType {
Expand All @@ -571,7 +575,7 @@ struct IndexExpr : public IrNodeRef {
kCast = 3 // exist cast
};

IndexExpr Normalize(OptLevel level = OptLevel::Level1) const;
IndexExpr Normalize(OptLevel level = OptLevel::kLevel1) const;

bool IsDynamic() const;

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/optim/if_fold_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ LogicalResult IfFoldPass::Run(StmtRef stmt) {
VLOG(6) << "-------------cond_vec end----------------";

// Normalize expr to simplify the expr after Mul and Sum.
expr = expr.Normalize(ir::IndexExpr::OptLevel::Level2);
expr = expr.Normalize(ir::IndexExpr::OptLevel::kLevel2);

if (expr != ir::IndexExpr(0) && expr.length() < min_len &&
inner_op.defined()) {
Expand Down
24 changes: 24 additions & 0 deletions paddle/cinn/optim/simplify_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,5 +662,29 @@ std::optional<std::unordered_map<std::string, ir::IndexExpr>> MatchPattern(

return std::nullopt;
}

ir::IndexExpr BoundSimplify(const ir::IndexExpr &expr) {
// return expr if expr is not a division or modulo
if (expr.node_type() != ir::IrNodeTy::Div &&
expr.node_type() != ir::IrNodeTy::Mod)
return expr;

common::cas_intervals_t var_intervals =
common::CollectVarIntervalsOfExprs({expr});
common::SymbolicExprAnalyzer ana(var_intervals);
// Because the SymbolicExprAnalyzer bound result is [lower, upper), `ProveLE`
// is used here instead of `ProveLT`.
auto canBeSimplified =
ana.ProveLE(ana.UpperBound(expr.operand(0)), expr.operand(1));

if (canBeSimplified.value_or(false)) {
if (expr.node_type() == ir::IrNodeTy::Div) {
return ir::IndexExpr(0);
} else if (expr.node_type() == ir::IrNodeTy::Mod) {
return expr.operand(0);
}
}
return expr;
}
} // namespace optim
} // namespace cinn
11 changes: 11 additions & 0 deletions paddle/cinn/optim/simplify_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,5 +313,16 @@ std::optional<std::unordered_map<std::string, ir::IndexExpr>> MatchPattern(
const std::function<bool(
const std::unordered_map<std::string, ir::IndexExpr> &)> &condition =
nullptr);

/*!
* \brief Simplify IndexExpr with bound information.
* For example:
* x % S0 ==> x if x < S0
* x / S0 ==> 0 if x < S0
*
* \param expr The `IndexExpr` to be simplified.
* \return `IndexExpr` after simplification.
*/
ir::IndexExpr BoundSimplify(const ir::IndexExpr &expr);
} // namespace optim
} // namespace cinn
60 changes: 50 additions & 10 deletions test/cpp/pir/cinn/adt/index_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ TEST_F(TestIndexExpr, IndexExpr_3) {
EXPECT_EQ(q14.as_index().Normalize(), ir::IndexExpr(S4 + S5));
EXPECT_EQ(q15.as_index().Normalize(),
ir::IndexExpr((S4 * 256 + S5 + S6 * 1024)) % 25088);
EXPECT_EQ(q16.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
EXPECT_EQ(q16.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
ir::IndexExpr(S4 * 256 + S5));
EXPECT_EQ(q17.as_index().Normalize(), ir::IndexExpr(S4 / S5));
EXPECT_EQ(q18.as_index().Normalize(),
Expand Down Expand Up @@ -301,18 +301,19 @@ TEST_F(TestIndexExpr, Test_dynamic) {
(f % (S5 * S6)));

EXPECT_EQ(
q.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
q.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
((((((((S7 * 1024) + S8) + (S9 * 4096)) / ((S5 * S6) * 640)) * S5) * S6) *
S4) +
(((((S7 * 1024) + S8) + (S9 * 4096)) % ((S5 * S6) * 640)) %
((S5 * S6) * S4))));
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
((f % ((S5 * S6) * 640)) % ((S5 * S6) * S4)));
EXPECT_EQ(q2.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
EXPECT_EQ(q2.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
f % (S5 * S6));
EXPECT_EQ(q3.as_index().Normalize(ir::IndexExpr::OptLevel::Level2), Expr(S4));
EXPECT_EQ(q4.as_index().Normalize(ir::IndexExpr::OptLevel::Level2), Expr(0));
EXPECT_EQ(q5.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
EXPECT_EQ(q3.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
Expr(S4));
EXPECT_EQ(q4.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2), Expr(0));
EXPECT_EQ(q5.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
(((((f / ((S5 * S6) * 640)) * S4) * S5) * S6) +
((f % ((S5 * S6) * 640)) % ((S5 * S6) * S4))));
}
Expand Down Expand Up @@ -492,12 +493,12 @@ TEST_F(TestIndexExpr, CommonFactor) {
((S2 * S3) * S13)) +
((S2 * S3) * S1))));

EXPECT_EQ(q.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
EXPECT_EQ(q.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
(((((((S1 + S13) + S17) + S21) + S5) + S9) * S2) * S3));
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
(((((((S5 + S9) + S21) + S17) + S13) + S1) * S2) * S3));
EXPECT_EQ(
q2.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
q2.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
((((f * 1024) + tx) + (bx * 4096)) %
((((((((((((((((S5 + S9) + S21) + S17) + S13) + S1) * S2) * S3) * S0) /
4096) *
Expand Down Expand Up @@ -642,5 +643,44 @@ TEST_F(TestIndexExpr, MatchPattern) {
EXPECT_EQ(result9->at("x"), x);
EXPECT_EQ(result9->at("y"), y);
}
TEST_F(TestIndexExpr, BoundSimplify) {
ir::Var S0 = ir::Var("S0");
ir::Var i = ir::Var(ir::Expr(0), ir::Expr(5), "i");
ir::Var j = ir::Var(ir::Expr(0), S0, "j");

ir::Expr q0 = i / Expr(5);
ir::Expr q1 = i / Expr(4);
ir::Expr q2 = i / Expr(6);
ir::Expr q3 = j / S0;
ir::Expr q4 = j / (S0 - 1);
ir::Expr q5 = j / (S0 + 1);

ir::Expr q6 = i % Expr(5);
ir::Expr q7 = i % Expr(4);
ir::Expr q8 = i % Expr(6);
ir::Expr q9 = j % S0;
ir::Expr q10 = j % (S0 - 1);
ir::Expr q11 = j % (S0 + 1);
EXPECT_EQ(q0.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
ir::Expr(0));
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
i / Expr(4));
EXPECT_EQ(q2.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
ir::Expr(0));
EXPECT_EQ(q3.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
ir::Expr(0));
EXPECT_EQ(q4.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
j / (S0 + ir::Expr(-1)));
EXPECT_EQ(q5.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
ir::Expr(0));
EXPECT_EQ(q6.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3), i);
EXPECT_EQ(q7.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
i % Expr(4));
EXPECT_EQ(q8.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3), i);
EXPECT_EQ(q9.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3), j);
EXPECT_EQ(q10.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
j % (S0 + ir::Expr(-1)));
EXPECT_EQ(q11.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3), j);
}
} // namespace common
} // namespace cinn
Loading