@@ -61,6 +61,7 @@ ir::IndexExpr IterMapToExprNormalizer::ConvertIterSplit(ir::IterSplit* expr) {
61
61
Visit (&(mark->source ), &(mark->source ));
62
62
source = mark->source ;
63
63
}
64
+
64
65
// quick branch
65
66
if (IsZero (expr->scale ) || IsOne (expr->extent ))
66
67
return ir::Zero (expr->extent .type ());
@@ -88,7 +89,7 @@ void IterMapRewriter::Visit(const ir::_Var_* op, Expr* expr) {
88
89
void IterMapRewriter::Visit (const ir::Add* op, Expr* expr) {
89
90
auto a = op->a ();
90
91
auto b = op->b ();
91
- VLOG ( 10 ) << " in visit add: " << a << " " << b;
92
+
92
93
Visit (&a);
93
94
Visit (&b);
94
95
@@ -102,6 +103,7 @@ void IterMapRewriter::Visit(const ir::Add* op, Expr* expr) {
102
103
103
104
Expr ret = ir::ir_utils::IRCopy (ToIterSum (a));
104
105
ir::IterSum* ret_sum = ret.As <ir::IterSum>();
106
+
105
107
if (auto b_sum = b.As <ir::IterSum>()) {
106
108
AddToLhs (ret_sum, *b_sum, 1 );
107
109
} else if (auto b_split = b.As <ir::IterSplit>()) {
@@ -110,13 +112,12 @@ void IterMapRewriter::Visit(const ir::Add* op, Expr* expr) {
110
112
ret_sum->base = ret_sum->base + b.as_index ();
111
113
}
112
114
*expr = ret;
113
- VLOG (10 ) << " out visit add" ;
114
115
}
115
116
116
117
void IterMapRewriter::Visit (const ir::Sub* op, Expr* expr) {
117
118
auto a = op->a ();
118
119
auto b = op->b ();
119
- VLOG ( 10 ) << " in visit sub: " << a << " " << b;
120
+
120
121
Visit (&a);
121
122
Visit (&b);
122
123
@@ -138,13 +139,12 @@ void IterMapRewriter::Visit(const ir::Sub* op, Expr* expr) {
138
139
}
139
140
140
141
*expr = ret;
141
- VLOG (10 ) << " out visit sub" ;
142
142
}
143
143
144
144
void IterMapRewriter::Visit (const ir::Mul* op, Expr* expr) {
145
145
auto a = op->a ();
146
146
auto b = op->b ();
147
- VLOG ( 10 ) << " in visit mul: " << a << " " << b;
147
+
148
148
Visit (&a);
149
149
Visit (&b);
150
150
@@ -176,14 +176,12 @@ void IterMapRewriter::Visit(const ir::Mul* op, Expr* expr) {
176
176
}
177
177
178
178
*expr = ret;
179
- VLOG (10 ) << " out visit mul" ;
180
179
}
181
180
182
181
void IterMapRewriter::Visit (const ir::Div* op, Expr* expr) {
183
182
auto a = op->a ();
184
183
auto b = op->b ();
185
184
186
- VLOG (10 ) << " in visit div: " << a << " " << b;
187
185
Visit (&a);
188
186
Visit (&b);
189
187
@@ -199,19 +197,21 @@ void IterMapRewriter::Visit(const ir::Div* op, Expr* expr) {
199
197
" Division of iter and iter is not supported" ));
200
198
return ;
201
199
}
200
+
202
201
auto ret = ir::ir_utils::IRCopy (a);
202
+
203
203
auto preprocessed = PreprocessDividend (ret);
204
204
auto preprocessed_sum = preprocessed.As <ir::IterSum>();
205
205
206
206
ret = SplitDivConst (preprocessed_sum->args [0 ], preprocessed_sum->base , b);
207
+
207
208
*expr = ret;
208
- VLOG (10 ) << " out visit div" ;
209
209
}
210
210
211
211
void IterMapRewriter::Visit (const ir::Mod* op, Expr* expr) {
212
212
auto a = op->a ();
213
213
auto b = op->b ();
214
- VLOG ( 10 ) << " in visit mod: " << a << " " << b;
214
+
215
215
Visit (&a);
216
216
Visit (&b);
217
217
@@ -236,7 +236,6 @@ void IterMapRewriter::Visit(const ir::Mod* op, Expr* expr) {
236
236
ret = SplitModConst (preprocessed_sum->args [0 ], preprocessed_sum->base , b);
237
237
238
238
*expr = ret;
239
- VLOG (10 ) << " out visit mod" ;
240
239
}
241
240
242
241
Expr IterMapRewriter::PreprocessDividend (const Expr& dividend) {
@@ -472,6 +471,7 @@ std::optional<Expr> IterMapRewriter::TryFuse(const Expr& expr) {
472
471
return opt.value ();
473
472
}
474
473
}
474
+
475
475
// Select iter with smallest scale as base iter.
476
476
std::vector<bool > visited (iter_sum->args .size (), false );
477
477
int base_index = FindBaseSplit (*iter_sum, visited, Expr (), -1 );
@@ -484,6 +484,7 @@ std::optional<Expr> IterMapRewriter::TryFuse(const Expr& expr) {
484
484
ir::IndexExpr expected_scale = base_scale;
485
485
int first_possible_unit_extent_pos =
486
486
FindFirstPossibleUnitExtentIndex (*iter_sum);
487
+
487
488
// Find iter with same scale as expected_scale and update expected_scale.
488
489
// e.g. i * 32 + j * 8 + k * 1, Extent(i, j, k) = 2, 4, 8.
489
490
// first base_index = 2, expected_scale = 1. means select k as base iter.
@@ -492,7 +493,7 @@ std::optional<Expr> IterMapRewriter::TryFuse(const Expr& expr) {
492
493
// finally matched_pos = 0, expected_scale = 32 * 2 = 64. means match i.
493
494
// if match failed, indicates that expr is illegal and cannot be merged.
494
495
for (size_t i = 0 ; i < iter_sum->args .size (); ++i) {
495
- ir::IndexExpr matched_scale;
496
+ ir::IndexExpr matched_scale{ nullptr } ;
496
497
int matched_pos =
497
498
i == 0 ? base_index
498
499
: FindSplitWithExactScale (*iter_sum,
0 commit comments