Skip to content

Commit cce726a

Browse files
authored
【CINN】Add boundary simplify (PaddlePaddle#72000)
* add boundary simplify * fix ci bug
1 parent 83a02da commit cce726a

File tree

7 files changed

+103
-18
lines changed

7 files changed

+103
-18
lines changed

paddle/cinn/common/ir_util.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
203203
optim::SimplifyCast(&indice_cast);
204204
res = RampRelatedAdd(RampRelatedMul(res, shape[i]), indice_cast);
205205
if (res.is_index()) {
206-
res = res.as_index().Normalize(ir::IndexExpr::OptLevel::Level2);
206+
res = res.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2);
207207
} else {
208208
VLOG(8) << "**** expr is not index ****: " << res;
209209
}

paddle/cinn/ir/ir_base.cc

+8-2
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,15 @@ IndexExpr Simplify(const IndexExpr &expr, IndexExpr::OptLevel level) {
526526
auto rhs = Simplify(expr.operand(1), level);
527527
auto res =
528528
optim::ConstructIndexExprByNodeType(expr.node_type(), lhs, rhs);
529-
if (level == IndexExpr::OptLevel::Level2 &&
530-
expr.node_type() == ir::IrNodeTy::Add)
529+
if (level >= IndexExpr::OptLevel::kLevel2 &&
530+
expr.node_type() == ir::IrNodeTy::Add) {
531531
res = common::MergeMulMod(res);
532+
}
533+
if (level == IndexExpr::OptLevel::kLevel3 &&
534+
(expr.node_type() == ir::IrNodeTy::Div ||
535+
expr.node_type() == ir::IrNodeTy::Mod)) {
536+
res = optim::BoundSimplify(res);
537+
}
532538
return res;
533539
}
534540
default:

paddle/cinn/ir/ir_base.h

+8-4
Original file line numberDiff line numberDiff line change
@@ -554,14 +554,18 @@ struct IndexExpr : public IrNodeRef {
554554
* Level2: Each factor in the expression is attempted to be simplified with
555555
* the other factors
556556
* e.g. x / 2 * 2 + y / 2 + 5 + x % 2 ==> y / 2 + x + 5
557+
* Level3: Simplify with boundary.
558+
* e.g. x % S0 ==> x if x < S0
559+
* x / S0 ==> 0 if x < S0
557560
*
558561
* Note: Because IndexExpr is generated in order, Short operand is at the
559562
* end of the expression, so Level1 is usually used.
560563
*/
561564
enum class OptLevel {
562-
Level0 = 0, // TODO(liujinnan): Only constant folding is performed
563-
Level1 = 1, // Constant folding and sequential simplification are performed
564-
Level2 = 2 // Top level, simplify
565+
kLevel0 = 0, // TODO(liujinnan): Only constant folding is performed
566+
kLevel1 = 1,
567+
kLevel2 = 2,
568+
kLevel3 = 3 // Top level, simplify
565569
};
566570

567571
enum class IndexType {
@@ -571,7 +575,7 @@ struct IndexExpr : public IrNodeRef {
571575
kCast = 3 // exist cast
572576
};
573577

574-
IndexExpr Normalize(OptLevel level = OptLevel::Level1) const;
578+
IndexExpr Normalize(OptLevel level = OptLevel::kLevel1) const;
575579

576580
bool IsDynamic() const;
577581

paddle/cinn/optim/if_fold_pass.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ LogicalResult IfFoldPass::Run(StmtRef stmt) {
121121
VLOG(6) << "-------------cond_vec end----------------";
122122

123123
// Normalize expr to simplify the expr after Mul and Sum.
124-
expr = expr.Normalize(ir::IndexExpr::OptLevel::Level2);
124+
expr = expr.Normalize(ir::IndexExpr::OptLevel::kLevel2);
125125

126126
if (expr != ir::IndexExpr(0) && expr.length() < min_len &&
127127
inner_op.defined()) {

paddle/cinn/optim/simplify_util.cc

+24
Original file line numberDiff line numberDiff line change
@@ -662,5 +662,29 @@ std::optional<std::unordered_map<std::string, ir::IndexExpr>> MatchPattern(
662662

663663
return std::nullopt;
664664
}
665+
666+
ir::IndexExpr BoundSimplify(const ir::IndexExpr &expr) {
667+
// return expr if expr is not a division or modulo
668+
if (expr.node_type() != ir::IrNodeTy::Div &&
669+
expr.node_type() != ir::IrNodeTy::Mod)
670+
return expr;
671+
672+
common::cas_intervals_t var_intervals =
673+
common::CollectVarIntervalsOfExprs({expr});
674+
common::SymbolicExprAnalyzer ana(var_intervals);
675+
// Because the SymbolicExprAnalyzer bound result is [lower, upper), `ProveLE`
676+
// is used here instead of `ProveLT`.
677+
auto canBeSimplified =
678+
ana.ProveLE(ana.UpperBound(expr.operand(0)), expr.operand(1));
679+
680+
if (canBeSimplified.value_or(false)) {
681+
if (expr.node_type() == ir::IrNodeTy::Div) {
682+
return ir::IndexExpr(0);
683+
} else if (expr.node_type() == ir::IrNodeTy::Mod) {
684+
return expr.operand(0);
685+
}
686+
}
687+
return expr;
688+
}
665689
} // namespace optim
666690
} // namespace cinn

paddle/cinn/optim/simplify_util.h

+11
Original file line numberDiff line numberDiff line change
@@ -313,5 +313,16 @@ std::optional<std::unordered_map<std::string, ir::IndexExpr>> MatchPattern(
313313
const std::function<bool(
314314
const std::unordered_map<std::string, ir::IndexExpr> &)> &condition =
315315
nullptr);
316+
317+
/*!
318+
* \brief Simplify IndexExpr with bound information.
319+
* For example:
320+
* x % S0 ==> x if x < S0
321+
* x / S0 ==> 0 if x < S0
322+
*
323+
* \param expr The `IndexExpr` to be simplified.
324+
* \return `IndexExpr` after simplification.
325+
*/
326+
ir::IndexExpr BoundSimplify(const ir::IndexExpr &expr);
316327
} // namespace optim
317328
} // namespace cinn

test/cpp/pir/cinn/adt/index_expr_test.cc

+50-10
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ TEST_F(TestIndexExpr, IndexExpr_3) {
183183
EXPECT_EQ(q14.as_index().Normalize(), ir::IndexExpr(S4 + S5));
184184
EXPECT_EQ(q15.as_index().Normalize(),
185185
ir::IndexExpr((S4 * 256 + S5 + S6 * 1024)) % 25088);
186-
EXPECT_EQ(q16.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
186+
EXPECT_EQ(q16.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
187187
ir::IndexExpr(S4 * 256 + S5));
188188
EXPECT_EQ(q17.as_index().Normalize(), ir::IndexExpr(S4 / S5));
189189
EXPECT_EQ(q18.as_index().Normalize(),
@@ -301,18 +301,19 @@ TEST_F(TestIndexExpr, Test_dynamic) {
301301
(f % (S5 * S6)));
302302

303303
EXPECT_EQ(
304-
q.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
304+
q.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
305305
((((((((S7 * 1024) + S8) + (S9 * 4096)) / ((S5 * S6) * 640)) * S5) * S6) *
306306
S4) +
307307
(((((S7 * 1024) + S8) + (S9 * 4096)) % ((S5 * S6) * 640)) %
308308
((S5 * S6) * S4))));
309-
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
309+
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
310310
((f % ((S5 * S6) * 640)) % ((S5 * S6) * S4)));
311-
EXPECT_EQ(q2.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
311+
EXPECT_EQ(q2.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
312312
f % (S5 * S6));
313-
EXPECT_EQ(q3.as_index().Normalize(ir::IndexExpr::OptLevel::Level2), Expr(S4));
314-
EXPECT_EQ(q4.as_index().Normalize(ir::IndexExpr::OptLevel::Level2), Expr(0));
315-
EXPECT_EQ(q5.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
313+
EXPECT_EQ(q3.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
314+
Expr(S4));
315+
EXPECT_EQ(q4.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2), Expr(0));
316+
EXPECT_EQ(q5.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
316317
(((((f / ((S5 * S6) * 640)) * S4) * S5) * S6) +
317318
((f % ((S5 * S6) * 640)) % ((S5 * S6) * S4))));
318319
}
@@ -492,12 +493,12 @@ TEST_F(TestIndexExpr, CommonFactor) {
492493
((S2 * S3) * S13)) +
493494
((S2 * S3) * S1))));
494495

495-
EXPECT_EQ(q.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
496+
EXPECT_EQ(q.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
496497
(((((((S1 + S13) + S17) + S21) + S5) + S9) * S2) * S3));
497-
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
498+
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
498499
(((((((S5 + S9) + S21) + S17) + S13) + S1) * S2) * S3));
499500
EXPECT_EQ(
500-
q2.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
501+
q2.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel2),
501502
((((f * 1024) + tx) + (bx * 4096)) %
502503
((((((((((((((((S5 + S9) + S21) + S17) + S13) + S1) * S2) * S3) * S0) /
503504
4096) *
@@ -642,5 +643,44 @@ TEST_F(TestIndexExpr, MatchPattern) {
642643
EXPECT_EQ(result9->at("x"), x);
643644
EXPECT_EQ(result9->at("y"), y);
644645
}
646+
TEST_F(TestIndexExpr, BoundSimplify) {
647+
ir::Var S0 = ir::Var("S0");
648+
ir::Var i = ir::Var(ir::Expr(0), ir::Expr(5), "i");
649+
ir::Var j = ir::Var(ir::Expr(0), S0, "j");
650+
651+
ir::Expr q0 = i / Expr(5);
652+
ir::Expr q1 = i / Expr(4);
653+
ir::Expr q2 = i / Expr(6);
654+
ir::Expr q3 = j / S0;
655+
ir::Expr q4 = j / (S0 - 1);
656+
ir::Expr q5 = j / (S0 + 1);
657+
658+
ir::Expr q6 = i % Expr(5);
659+
ir::Expr q7 = i % Expr(4);
660+
ir::Expr q8 = i % Expr(6);
661+
ir::Expr q9 = j % S0;
662+
ir::Expr q10 = j % (S0 - 1);
663+
ir::Expr q11 = j % (S0 + 1);
664+
EXPECT_EQ(q0.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
665+
ir::Expr(0));
666+
EXPECT_EQ(q1.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
667+
i / Expr(4));
668+
EXPECT_EQ(q2.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
669+
ir::Expr(0));
670+
EXPECT_EQ(q3.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
671+
ir::Expr(0));
672+
EXPECT_EQ(q4.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
673+
j / (S0 + ir::Expr(-1)));
674+
EXPECT_EQ(q5.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
675+
ir::Expr(0));
676+
EXPECT_EQ(q6.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3), i);
677+
EXPECT_EQ(q7.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
678+
i % Expr(4));
679+
EXPECT_EQ(q8.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3), i);
680+
EXPECT_EQ(q9.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3), j);
681+
EXPECT_EQ(q10.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3),
682+
j % (S0 + ir::Expr(-1)));
683+
EXPECT_EQ(q11.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3), j);
684+
}
645685
} // namespace common
646686
} // namespace cinn

0 commit comments

Comments
 (0)