Skip to content

Commit 92eca04

Browse files
authored
Revert "【CINN】Set index for IrNode (#70208)" (#70427)
* Revert "【CINN】Set index for IrNode (#70208)" This reverts commit 46b17db. * fix conflict * fix conflict
1 parent 421c69c commit 92eca04

24 files changed

+112
-411
lines changed

paddle/cinn/common/axis.cc

+2-3
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ std::string axis_name(int level) {
6868
std::vector<ir::Var> GenDefaultAxis(int naxis) {
6969
std::vector<ir::Var> axis;
7070
for (int i = 0; i < naxis; i++) {
71-
Var ax(cinn::common::axis_name(i));
72-
axis.emplace_back(ax.set_index(true));
71+
axis.emplace_back(cinn::common::axis_name(i));
7372
PADDLE_ENFORCE_EQ(axis.back()->type().valid(),
7473
true,
7574
::common::errors::InvalidArgument(
@@ -84,7 +83,7 @@ std::vector<ir::Expr> GenDefaultAxisAsExpr(int naxis) {
8483
auto vars = GenDefaultAxis(naxis);
8584
std::vector<Expr> res;
8685
for (auto& v : vars) {
87-
res.push_back(Expr(v.set_index(true)));
86+
res.push_back(Expr(v));
8887
}
8988
return res;
9089
}

paddle/cinn/common/cas.cc

+2-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ using namespace ir; // NOLINT
3636
Expr AutoSimplify(
3737
const Expr& u,
3838
const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
39-
bool is_index = u.is_index();
4039
VLOG(7) << "Begin AutoSimplify: " << u;
4140
Expr copied = ir::ir_utils::IRCopy(u);
4241
if (copied.type().is_float()) {
@@ -57,7 +56,7 @@ Expr AutoSimplify(
5756
copied = CasSimplify(copied, s_var_intervals);
5857
copied = detail::ConvertCasToCinn(copied);
5958
VLOG(7) << "End AutoSimplify " << copied;
60-
return is_index ? copied.set_index(true) : copied;
59+
return copied;
6160
}
6261

6362
int gcd(int a, int b) {
@@ -626,6 +625,7 @@ Expr CasSimplifyMutator::SimplifySum(Expr u) {
626625
if (!temp.As<Sum>()) return temp;
627626

628627
operands = temp.As<Sum>()->operands();
628+
629629
auto args = SimplifySumRec(operands);
630630
if (args.empty()) return make_const(u.type(), 0);
631631
if (args.size() == 1) return args[0];
@@ -1654,9 +1654,6 @@ Expr ConvertCinnToCAS(Expr expr) {
16541654
void operator()(Expr* expr) { Visit(expr); }
16551655
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
16561656

1657-
// Because indice of `Load` is simplify by IndexExpr, we just skip it.
1658-
void Visit(const Load* op, Expr* expr) override { return; }
1659-
16601657
private:
16611658
void Visit(const Add* op, Expr* expr) override {
16621659
auto a = op->a();

paddle/cinn/common/const_fold.h

-24
Original file line numberDiff line numberDiff line change
@@ -107,29 +107,5 @@ inline std::optional<ir::Expr> TryConstFold<ir::Mod>(ir::Expr a, ir::Expr b) {
107107
return std::nullopt;
108108
}
109109

110-
template <>
111-
inline std::optional<ir::Expr> TryConstFold<ir::Min>(ir::Expr a, ir::Expr b) {
112-
const ir::IntImm* pa = a.As<ir::IntImm>();
113-
const ir::IntImm* pb = b.As<ir::IntImm>();
114-
const auto& rtype = a.type();
115-
if (pa && pb) {
116-
int64_t res = std::min(pa->value, pb->value);
117-
return cinn::common::make_shared<ir::IntImm>(rtype, res);
118-
}
119-
return std::nullopt;
120-
}
121-
122-
template <>
123-
inline std::optional<ir::Expr> TryConstFold<ir::Max>(ir::Expr a, ir::Expr b) {
124-
const ir::IntImm* pa = a.As<ir::IntImm>();
125-
const ir::IntImm* pb = b.As<ir::IntImm>();
126-
const auto& rtype = a.type();
127-
if (pa && pb) {
128-
int64_t res = std::max(pa->value, pb->value);
129-
return cinn::common::make_shared<ir::IntImm>(rtype, res);
130-
}
131-
return std::nullopt;
132-
}
133-
134110
} // namespace common
135111
} // namespace cinn

paddle/cinn/common/dim_expr_converter.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct DimExprToIrExprVisitor {
3737
dim_expr,
3838
/* is_reduce = */ false,
3939
/* is_symbolic_constant = */ true);
40-
return x.set_index(true);
40+
return x;
4141
}
4242

4343
ir::Expr operator()(const Negative<DimExpr>& dim_expr) {

paddle/cinn/common/ir_util.cc

+1-14
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,6 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
197197
i,
198198
shape[i].type()));
199199

200-
// if(VerifyIndex(shape[i]))shape[i].set_index(true);
201-
// if(VerifyIndex(indices[i]))indices[i].set_index(true);
202-
203200
Expr indice_cast = indices[i];
204201
optim::SimplifyCast(&indice_cast);
205202
res = RampRelatedAdd(RampRelatedMul(res, shape[i]), indice_cast);
@@ -478,11 +475,10 @@ bool ComparePriority(const ir::IndexExpr &lhs, const ir::IndexExpr &rhs) {
478475
if (auto rhsVar = rhs.As<ir::_Var_>())
479476
return std::make_tuple(lhsVar->name.length(), lhsVar->name) <=
480477
std::make_tuple(rhsVar->name.length(), rhsVar->name);
481-
482478
auto lhsLen = lhs.length();
483479
auto rhsLen = rhs.length();
484480
if (lhsLen < rhsLen) return false;
485-
// Add < Mul < Div < Mod < Min < Max < Load.
481+
// Add < Mul < Div < Mod.
486482
else if (lhsLen == rhsLen)
487483
return lhs.node_type() <= rhs.node_type();
488484
else
@@ -513,10 +509,6 @@ bool IsSumPartialBySymbol(const ir::IndexExpr &expr,
513509
return IsSumPartialBySymbol(expr.operand(0), symbol);
514510
}
515511
case ir::IrNodeTy::Mod:
516-
case ir::IrNodeTy::Min:
517-
case ir::IrNodeTy::Max:
518-
case ir::IrNodeTy::Load:
519-
case ir::IrNodeTy::Cast:
520512
return false;
521513
default:
522514
PADDLE_THROW(::common::errors::InvalidArgument(
@@ -594,11 +586,6 @@ bool IsDivisiblieBySymbol(const ir::IndexExpr &expr,
594586
if (ty != expr.node_type()) return false;
595587
return IsDivisiblieBySymbol(expr.operand(0), symbol, expr.node_type());
596588
}
597-
case ir::IrNodeTy::Min:
598-
case ir::IrNodeTy::Max:
599-
case ir::IrNodeTy::Load:
600-
case ir::IrNodeTy::Cast:
601-
return false;
602589
default:
603590
PADDLE_THROW(::common::errors::InvalidArgument(
604591
"Unsupported type of expr in IsDivisiblieBySymbol which is: %s",

paddle/cinn/common/ir_util.h

-8
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,5 @@ bool ProveDivisible(const ir::IndexExpr &lhs, const ir::IndexExpr &rhs);
339339
bool IsNegatedIndexExpr(const ir::IndexExpr &candidate,
340340
ir::IndexExpr &expr); // NOLINT
341341

342-
/*!
343-
* \brief Judge type of `expr` is valid type of `IndexExpr` or not.
344-
* \param expr The expression to be checked.
345-
* \return A boolean value indicating whether the type of `expr` is valid
346-
* IndexExpr type.
347-
*/
348-
bool VerifyIndex(const ir::Expr &expr);
349-
350342
} // namespace common
351343
} // namespace cinn

paddle/cinn/common/iter_simplify.cc

+12-11
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ ir::IndexExpr IterMapToExprNormalizer::ConvertIterSplit(ir::IterSplit* expr) {
6161
Visit(&(mark->source), &(mark->source));
6262
source = mark->source;
6363
}
64+
6465
// quick branch
6566
if (IsZero(expr->scale) || IsOne(expr->extent))
6667
return ir::Zero(expr->extent.type());
@@ -88,7 +89,7 @@ void IterMapRewriter::Visit(const ir::_Var_* op, Expr* expr) {
8889
void IterMapRewriter::Visit(const ir::Add* op, Expr* expr) {
8990
auto a = op->a();
9091
auto b = op->b();
91-
VLOG(10) << "in visit add: " << a << " " << b;
92+
9293
Visit(&a);
9394
Visit(&b);
9495

@@ -102,6 +103,7 @@ void IterMapRewriter::Visit(const ir::Add* op, Expr* expr) {
102103

103104
Expr ret = ir::ir_utils::IRCopy(ToIterSum(a));
104105
ir::IterSum* ret_sum = ret.As<ir::IterSum>();
106+
105107
if (auto b_sum = b.As<ir::IterSum>()) {
106108
AddToLhs(ret_sum, *b_sum, 1);
107109
} else if (auto b_split = b.As<ir::IterSplit>()) {
@@ -110,13 +112,12 @@ void IterMapRewriter::Visit(const ir::Add* op, Expr* expr) {
110112
ret_sum->base = ret_sum->base + b.as_index();
111113
}
112114
*expr = ret;
113-
VLOG(10) << "out visit add";
114115
}
115116

116117
void IterMapRewriter::Visit(const ir::Sub* op, Expr* expr) {
117118
auto a = op->a();
118119
auto b = op->b();
119-
VLOG(10) << "in visit sub: " << a << " " << b;
120+
120121
Visit(&a);
121122
Visit(&b);
122123

@@ -138,13 +139,12 @@ void IterMapRewriter::Visit(const ir::Sub* op, Expr* expr) {
138139
}
139140

140141
*expr = ret;
141-
VLOG(10) << "out visit sub";
142142
}
143143

144144
void IterMapRewriter::Visit(const ir::Mul* op, Expr* expr) {
145145
auto a = op->a();
146146
auto b = op->b();
147-
VLOG(10) << "in visit mul: " << a << " " << b;
147+
148148
Visit(&a);
149149
Visit(&b);
150150

@@ -176,14 +176,12 @@ void IterMapRewriter::Visit(const ir::Mul* op, Expr* expr) {
176176
}
177177

178178
*expr = ret;
179-
VLOG(10) << "out visit mul";
180179
}
181180

182181
void IterMapRewriter::Visit(const ir::Div* op, Expr* expr) {
183182
auto a = op->a();
184183
auto b = op->b();
185184

186-
VLOG(10) << "in visit div: " << a << " " << b;
187185
Visit(&a);
188186
Visit(&b);
189187

@@ -199,19 +197,21 @@ void IterMapRewriter::Visit(const ir::Div* op, Expr* expr) {
199197
"Division of iter and iter is not supported"));
200198
return;
201199
}
200+
202201
auto ret = ir::ir_utils::IRCopy(a);
202+
203203
auto preprocessed = PreprocessDividend(ret);
204204
auto preprocessed_sum = preprocessed.As<ir::IterSum>();
205205

206206
ret = SplitDivConst(preprocessed_sum->args[0], preprocessed_sum->base, b);
207+
207208
*expr = ret;
208-
VLOG(10) << "out visit div";
209209
}
210210

211211
void IterMapRewriter::Visit(const ir::Mod* op, Expr* expr) {
212212
auto a = op->a();
213213
auto b = op->b();
214-
VLOG(10) << "in visit mod: " << a << " " << b;
214+
215215
Visit(&a);
216216
Visit(&b);
217217

@@ -236,7 +236,6 @@ void IterMapRewriter::Visit(const ir::Mod* op, Expr* expr) {
236236
ret = SplitModConst(preprocessed_sum->args[0], preprocessed_sum->base, b);
237237

238238
*expr = ret;
239-
VLOG(10) << "out visit mod";
240239
}
241240

242241
Expr IterMapRewriter::PreprocessDividend(const Expr& dividend) {
@@ -472,6 +471,7 @@ std::optional<Expr> IterMapRewriter::TryFuse(const Expr& expr) {
472471
return opt.value();
473472
}
474473
}
474+
475475
// Select iter with smallest scale as base iter.
476476
std::vector<bool> visited(iter_sum->args.size(), false);
477477
int base_index = FindBaseSplit(*iter_sum, visited, Expr(), -1);
@@ -484,6 +484,7 @@ std::optional<Expr> IterMapRewriter::TryFuse(const Expr& expr) {
484484
ir::IndexExpr expected_scale = base_scale;
485485
int first_possible_unit_extent_pos =
486486
FindFirstPossibleUnitExtentIndex(*iter_sum);
487+
487488
// Find iter with same scale as expected_scale and update expected_scale.
488489
// e.g. i * 32 + j * 8 + k * 1, Extent(i, j, k) = 2, 4, 8.
489490
// first base_index = 2, expected_scale = 1. means select k as base iter.
@@ -492,7 +493,7 @@ std::optional<Expr> IterMapRewriter::TryFuse(const Expr& expr) {
492493
// finally matched_pos = 0, expected_scale = 32 * 2 = 64. means match i.
493494
// if match failed, indicates that expr is illegal and cannot be merged.
494495
for (size_t i = 0; i < iter_sum->args.size(); ++i) {
495-
ir::IndexExpr matched_scale;
496+
ir::IndexExpr matched_scale{nullptr};
496497
int matched_pos =
497498
i == 0 ? base_index
498499
: FindSplitWithExactScale(*iter_sum,

paddle/cinn/hlir/framework/pir/trivial_op_util.cc

-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ void SubstitudeTargetExprWithDestExpr(const ir::Expr& source,
182182
VLOG(4) << "SubstitideExpr Start";
183183
VLOG(5) << "Substitide Body : " << *body;
184184
ir::Expr new_dest = dest;
185-
optim::Simplify(&new_dest);
186185
if (source.type() != dest.type()) {
187186
VLOG(4) << "Cast the dest" << dest << " to type" << source.type();
188187
new_dest = ir::Cast::Make(source.type(), dest);

paddle/cinn/ir/buffer.cc

-24
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,6 @@ Buffer _Buffer_::Make(Var data,
9494
node->offset_factor = offset_factor;
9595
node->target = target;
9696
node->dtype = dtype;
97-
98-
std::for_each(node->shape.begin(), node->shape.end(), [](Expr &indice) {
99-
indice = indice.set_index(true).as_index().Normalize();
100-
});
101-
std::for_each(node->strides.begin(), node->strides.end(), [](Expr &indice) {
102-
indice = indice.set_index(true).as_index().Normalize();
103-
});
104-
elem_offset.set_index(true);
105-
10697
return Buffer(node);
10798
}
10899

@@ -111,9 +102,6 @@ Buffer _Buffer_::Make(const std::string &name, const std::vector<Expr> &shape) {
111102
node->name = name;
112103
node->shape = shape;
113104
node->dtype = Void();
114-
std::for_each(node->shape.begin(), node->shape.end(), [](Expr &indice) {
115-
indice = indice.set_index(true).as_index().Normalize();
116-
});
117105
return Buffer(node);
118106
}
119107

@@ -137,9 +125,6 @@ void _Buffer_::BindTo(const _Tensor_ *tensor) {
137125
"tensor's shape is properly initialized and not empty."));
138126

139127
shape = tensor->shape;
140-
std::for_each(shape.begin(), shape.end(), [](Expr &indice) {
141-
indice = indice.set_index(true).as_index().Normalize();
142-
});
143128
binded_tensors_names_.insert(tensor->name);
144129
}
145130
void _Buffer_::Unbind(const _Tensor_ *tensor) {
@@ -207,9 +192,6 @@ Expr _BufferRange_::Make(const Expr &buffer, const std::vector<Var> &ranges) {
207192
auto node = make_shared<_BufferRange_>();
208193
node->buffer = buffer;
209194
node->ranges = ranges;
210-
std::for_each(node->ranges.begin(), node->ranges.end(), [](Var &v) {
211-
v.set_index(true);
212-
});
213195
return Expr(node);
214196
}
215197
void _BufferRange_::Verify() const {
@@ -223,9 +205,6 @@ Expr _BufferRange_::Copy() const {
223205
auto node = make_shared<_BufferRange_>();
224206
node->buffer = buffer;
225207
node->ranges = ranges;
226-
std::for_each(node->ranges.begin(), node->ranges.end(), [](Var &v) {
227-
v.set_index(true);
228-
});
229208
node->set_type(type());
230209
return Expr(node);
231210
}
@@ -265,9 +244,6 @@ BufferRange &BufferRange::operator=(const _BufferRange_ *x) {
265244
auto node = make_shared<_BufferRange_>();
266245
node->buffer = x->buffer;
267246
node->ranges = x->ranges;
268-
std::for_each(node->ranges.begin(), node->ranges.end(), [](Var &v) {
269-
v.set_index(true);
270-
});
271247
node->set_type(x->type());
272248
*this = BufferRange(node);
273249
return *this;

paddle/cinn/ir/dim.cc

-2
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ Dim _Dim_::Make(const std::string& name, const symbol::DimExpr& sym_dim) {
3333
n->name = name;
3434
n->sym_dim = sym_dim;
3535
n->dim_expr = common::DimExprConverter().ConvertToIrExpr(sym_dim);
36-
37-
n->dim_expr.set_index(true);
3836
n->set_type(n->dim_expr.type());
3937
return Dim(n);
4038
}

0 commit comments

Comments
 (0)