13
13
// limitations under the License.
14
14
15
15
#include " paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.h"
16
+ #include " paddle/cinn/common/ir_util.h"
16
17
#include " paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
17
18
18
19
PD_DECLARE_bool (cinn_enable_tile_transpose);
@@ -123,8 +124,8 @@ class TileTransposeTactic final : public ScheduleTactic {
123
124
ScheduleContext* context_;
124
125
bool can_apply_;
125
126
126
- // The common permutation of all transposes in the graph .
127
- std::vector<int > common_perm_ ;
127
+ // The sub iter space apart from the main iter space .
128
+ std::vector<int > sub_iter_space_ ;
128
129
129
130
// Groups of axis as illustrated in the above graph.
130
131
std::vector<int > high_axis_;
@@ -156,21 +157,6 @@ class TileTransposeTactic final : public ScheduleTactic {
156
157
std::unordered_set<ir::Expr, LoadHash> unconditional_loads_;
157
158
};
158
159
159
- std::vector<int > GetTransposePerm (const std::vector<ir::Expr>& indices,
160
- int data_rank) {
161
- if (indices.size () != data_rank) return {};
162
- std::vector<int > perm (data_rank);
163
- for (int i = 0 ; i < data_rank; ++i) {
164
- if (!indices[i].is_var ()) return {};
165
- auto * loop_var = indices[i].as_var ();
166
- // Strip the prefix "loop_var_" to get the loop_index.
167
- int loop_index =
168
- std::stoi (loop_var->name .substr (strlen (ir::analyzer::kLoopVar )));
169
- perm[loop_index] = i;
170
- }
171
- return perm;
172
- }
173
-
174
160
std::vector<int > OffsetVec (const std::vector<int >& vec, int offset) {
175
161
std::vector<int > new_vec = vec;
176
162
for (auto & e : new_vec) e += offset;
@@ -194,6 +180,162 @@ int64_t GetLoopRangeProduct(const std::vector<ir::Expr>& loops,
194
180
return prod;
195
181
}
196
182
183
+ /* *
184
+ * Get the relative iter space of the load according to the loops.
185
+ *
186
+ * This class currently supports the following cases:
187
+ * 1) var[i, k, m, j] (index mapping)
188
+ * iter space: [i, k, m, j]
189
+ * 2) var[i, k / 32, k % 32, j] (simple splitting)
190
+ * iter space: [i, k, j]
191
+ * 3) var[i, k * 32 + m, j] (simple fusion)
192
+ * iter space: [i, k, m, j]
193
+ * 4) var[i, k + 128, j] (simple offsetting)
194
+ * iter space: [i, k, j]
195
+ *
196
+ * The result is translated to the corresponding loop_index instead of returning
197
+ * loop_vars directly.
198
+ */
199
+ struct IterSpaceGetter {
200
+ IterSpaceGetter (const ir::Load* load, const std::vector<ir::Expr>& loops)
201
+ : load_(load), loops_(loops), indices_vars_(load->indices.size()) {
202
+ for (int i = 0 ; i < load_->indices .size (); ++i) {
203
+ ir::ir_utils::CollectIRNodes (load_->indices [i], [&](const ir::Expr* x) {
204
+ if (x->is_var () && !x->as_var ()->is_symbolic_constant ) {
205
+ indices_vars_[i].insert (x->as_var_ref ());
206
+ }
207
+ return false ;
208
+ });
209
+ }
210
+ }
211
+
212
+ std::vector<int > operator ()() {
213
+ // Try to arrange the iter vars in the order of the iter space
214
+ std::vector<ir::Var> iter_space_vars;
215
+ for (int i = 0 ; i < load_->indices .size (); ++i) {
216
+ // Case 1. constant
217
+ if (indices_vars_[i].size () == 0 ) {
218
+ continue ;
219
+ }
220
+
221
+ // Case 2. single variable
222
+ if (indices_vars_[i].size () == 1 ) {
223
+ int cover_range = CheckSingleVar (i);
224
+ if (cover_range < 0 ) return {};
225
+ iter_space_vars.push_back (*indices_vars_[i].begin ());
226
+ i += cover_range - 1 ;
227
+ continue ;
228
+ }
229
+
230
+ // Case 3. no more than 3 variables
231
+ if (indices_vars_[i].size () <= 3 ) {
232
+ std::vector<ir::Var> arranged_vars = CheckMultipleVars (i);
233
+ if (arranged_vars.empty ()) return {};
234
+ iter_space_vars.insert (
235
+ iter_space_vars.end (), arranged_vars.begin (), arranged_vars.end ());
236
+ continue ;
237
+ }
238
+
239
+ return {};
240
+ }
241
+
242
+ // Construct the iter space
243
+ std::vector<int > iter_space;
244
+ for (auto & var : iter_space_vars) {
245
+ int loop_index =
246
+ std::stoi (var->name .substr (std::strlen (analyzer::kLoopVar )));
247
+ iter_space.push_back (loop_index);
248
+ }
249
+ return iter_space;
250
+ }
251
+
252
+ private:
253
+ int CheckSingleVar (int begin) {
254
+ ir::Var var = *indices_vars_[begin].begin ();
255
+
256
+ // Check that var exclusively covers a continuous range, such as:
257
+ // [ ..., i / 32, i % 32, ... ]
258
+ // The following cases are not supported:
259
+ // [ ..., i / 32, (i % 32) * 4 + j, ... ] # not exclusive
260
+ // [ ..., i / 32, ..., i % 32, ... ] # not continuous
261
+ int end;
262
+ for (end = begin + 1 ; end < indices_vars_.size (); ++end) {
263
+ if (indices_vars_[end].count (var) == 0 ) break ;
264
+ if (indices_vars_[end].size () > 1 ) return -1 ;
265
+ }
266
+ for (int i = end + 1 ; i < indices_vars_.size (); ++i) {
267
+ if (indices_vars_[i].count (var) > 0 ) return -1 ;
268
+ }
269
+
270
+ // Try to fuse the indices that contain `var` into one expression
271
+ ir::Expr fused_index;
272
+ if (end - begin == 1 ) {
273
+ fused_index = optim::ArithSimplify (load_->indices [begin]);
274
+ } else {
275
+ auto shape_it = load_->tensor .as_tensor ()->shape .begin ();
276
+ auto indices_it = load_->indices .begin ();
277
+ std::vector<ir::Expr> sub_shape (shape_it + begin, shape_it + end);
278
+ std::vector<ir::Expr> sub_indices (indices_it + begin, indices_it + end);
279
+ fused_index = common::IndiceToAbsOffset (sub_shape, sub_indices);
280
+ }
281
+
282
+ // Check that fused_index is either a single `var` or `var + offset`
283
+ if (fused_index != ir::Expr (var)) {
284
+ auto * add_node = fused_index.As <ir::Add>();
285
+ if (!add_node || add_node->a () != ir::Expr (var)) return -1 ;
286
+ }
287
+
288
+ return end - begin;
289
+ }
290
+
291
+ std::vector<ir::Var> CheckMultipleVars (int pos) {
292
+ // Check that vars at this pos only appear at this pos, such as:
293
+ // [ ..., i * 32 + j, ... ]
294
+ // The following case is not supported:
295
+ // [ ..., (i * 32 + j) / 8, j % 8, ... ]
296
+ // because j appears at multiple positions.
297
+ for (int i = 0 ; i < indices_vars_.size (); ++i) {
298
+ if (i == pos) continue ;
299
+ for (auto & var : indices_vars_[i]) {
300
+ if (indices_vars_[pos].count (var) > 0 ) return {};
301
+ }
302
+ }
303
+
304
+ // Collect vars in this index in ast order
305
+ std::vector<ir::Var> vars_in_ast_order;
306
+ ir::ir_utils::CollectIRNodes (load_->indices [pos], [&](const ir::Expr* x) {
307
+ if (x->is_var () && !x->as_var ()->is_symbolic_constant ) {
308
+ vars_in_ast_order.push_back (x->as_var_ref ());
309
+ }
310
+ return false ;
311
+ });
312
+
313
+ // Re-construct the index using the vars in ast order
314
+ std::vector<ir::Expr> sub_shape;
315
+ std::vector<ir::Expr> sub_indices;
316
+ for (auto & var : vars_in_ast_order) {
317
+ int loop_index =
318
+ std::stoi (var->name .substr (std::strlen (analyzer::kLoopVar )));
319
+ sub_shape.push_back (loops_[loop_index].As <ir::For>()->extent );
320
+ sub_indices.push_back (var);
321
+ }
322
+ ir::Expr sub_index = common::IndiceToAbsOffset (sub_shape, sub_indices);
323
+
324
+ // Compare the re-constructed index with the actual index
325
+ if (sub_index == load_->indices [pos]) {
326
+ return vars_in_ast_order;
327
+ }
328
+ return {};
329
+ }
330
+
331
+ private:
332
+ const ir::Load* load_;
333
+ const std::vector<ir::Expr>& loops_;
334
+
335
+ // iter vars in each of the load's indices
336
+ std::vector<std::set<ir::Var>> indices_vars_;
337
+ };
338
+
197
339
void TileTransposeTactic::Init (ScheduleContext* context, ir::IRSchedule* sch) {
198
340
context_ = context;
199
341
can_apply_ = false ;
@@ -213,8 +355,8 @@ void TileTransposeTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
213
355
InitUnconditionalLoads (sch);
214
356
InitCandidates (sch);
215
357
216
- VLOG (4 ) << " Common permutation : " << utils::Join (common_perm_ , " , " );
217
- if (common_perm_ .empty ()) return ;
358
+ VLOG (4 ) << " sub_iter_space : " << utils::Join (sub_iter_space_ , " , " );
359
+ if (sub_iter_space_ .empty ()) return ;
218
360
219
361
can_apply_ = true ;
220
362
root_node->attrs [kTileMethod ] = TacticName ();
@@ -251,7 +393,7 @@ void TileTransposeTactic::InitUnconditionalLoads(ir::IRSchedule* sch) {
251
393
}
252
394
253
395
void TileTransposeTactic::InitCandidates (ir::IRSchedule* sch) {
254
- common_perm_ .clear ();
396
+ sub_iter_space_ .clear ();
255
397
load2candidates_.clear ();
256
398
block2candidates_.clear ();
257
399
processed_loads_.clear ();
@@ -289,25 +431,24 @@ void TileTransposeTactic::InitCandidates(ir::IRSchedule* sch) {
289
431
auto * tensor = load.As <ir::Load>()->tensor .as_tensor ();
290
432
if (sch->HasBlock (tensor->name )) continue ;
291
433
292
- std::vector< int > perm =
293
- GetTransposePerm (load. As <ir::Load>()-> indices , loops. size () );
434
+ IterSpaceGetter iter_space_getter (load. As <ir::Load>(), loops);
435
+ std::vector< int > iter_space = iter_space_getter ( );
294
436
295
437
// 4. This is a critical transpose, including:
296
438
// 1) its dim size equals to the loop size (not a broadcast).
297
439
// 2) its last dim is changed in permutation (incurs discrete access).
298
440
// 3) both the src/dst_low_axis are non-unit (not a squeeze/unsqueeze).
299
- if (perm.size () != loops.size ()) continue ;
300
- if (perm.back () == perm.size () - 1 ) continue ;
301
- if (GetLoopRangeProduct (loops, GetSrcLowAxis (perm)) == 1 ) continue ;
302
- if (GetLoopRangeProduct (loops, GetDstLowAxis (perm)) == 1 ) continue ;
303
-
304
- // 5. All transposes in this graph should have the same permutation.
305
- // Otherwise, it would be too complex to ensure the correctness and
306
- // performance. The violating cases should be rare.
307
- if (common_perm_.empty ()) {
308
- common_perm_ = perm;
309
- } else if (common_perm_ != perm) {
310
- common_perm_.clear ();
441
+ if (iter_space.size () != loops.size ()) continue ;
442
+ if (iter_space.back () == iter_space.size () - 1 ) continue ;
443
+ if (GetLoopRangeProduct (loops, GetSrcLowAxis (iter_space)) == 1 ) continue ;
444
+ if (GetLoopRangeProduct (loops, GetDstLowAxis (iter_space)) == 1 ) continue ;
445
+
446
+ // 5. All transposes in this graph should be in the same sub iter space,
447
+ // because we only support the alignment of two iter spaces.
448
+ if (sub_iter_space_.empty ()) {
449
+ sub_iter_space_ = iter_space;
450
+ } else if (sub_iter_space_ != iter_space) {
451
+ sub_iter_space_.clear ();
311
452
return ;
312
453
}
313
454
@@ -319,37 +460,38 @@ void TileTransposeTactic::InitCandidates(ir::IRSchedule* sch) {
319
460
}
320
461
321
462
void TileTransposeTactic::InitAxisInfo () {
322
- src_low_axis_ = GetSrcLowAxis (common_perm_ );
323
- dst_low_axis_ = GetDstLowAxis (common_perm_ );
463
+ src_low_axis_ = GetSrcLowAxis (sub_iter_space_ );
464
+ dst_low_axis_ = GetDstLowAxis (sub_iter_space_ );
324
465
325
466
std::set<int > high_axis;
326
- for (int i = 0 ; i < common_perm_ .size (); ++i) high_axis.insert (i);
467
+ for (int i = 0 ; i < sub_iter_space_ .size (); ++i) high_axis.insert (i);
327
468
for (auto i : src_low_axis_) high_axis.erase (i);
328
469
for (auto i : dst_low_axis_) high_axis.erase (i);
329
470
high_axis_.assign (high_axis.begin (), high_axis.end ());
330
471
}
331
472
332
473
// Returns the loop indices forming the lowest (innermost) axes on the source
// side: the last entry of `iter_space` together with the maximal run of
// entries directly before it whose values increase by exactly 1.
//
// The result is in ascending order. Returns an empty vector if `iter_space`
// is empty.
std::vector<int> TileTransposeTactic::GetSrcLowAxis(
    const std::vector<int>& iter_space) {
  // Calling back() on an empty vector is undefined behaviour.
  if (iter_space.empty()) return {};

  // Walk backwards while the values stay consecutive.
  auto first = iter_space.end() - 1;
  while (first != iter_space.begin() && *(first - 1) + 1 == *first) {
    --first;
  }
  // The run [first, end) is strictly increasing by 1, hence already sorted
  // and free of duplicates.
  return std::vector<int>(first, iter_space.end());
}
346
482
347
483
// Returns the loop indices forming the lowest (innermost) axes on the
// destination side: the entry of `iter_space` whose value is
// `iter_space.size() - 1`, together with the maximal run of entries directly
// before it whose values increase by exactly 1.
//
// The result is in ascending order. Returns an empty vector if no entry has
// the value `iter_space.size() - 1`.
std::vector<int> TileTransposeTactic::GetDstLowAxis(
    const std::vector<int>& iter_space) {
  const int last_loop_index = static_cast<int>(iter_space.size()) - 1;
  const auto run_last =
      std::find(iter_space.begin(), iter_space.end(), last_loop_index);
  if (run_last == iter_space.end()) return {};

  // Extend the run towards the front while the values stay consecutive.
  auto run_first = run_last;
  while (run_first != iter_space.begin() &&
         *(run_first - 1) == *run_first - 1) {
    --run_first;
  }
  // The run is strictly increasing by 1, so it is already sorted and unique.
  return std::vector<int>(run_first, run_last + 1);
}
@@ -392,9 +534,6 @@ std::string TileTransposeTactic::CreateCacheBlock(
392
534
std::string cache_block_id = ir::analyzer::GetBlockName (cache_block);
393
535
context_->output_names .insert (cache_block_id);
394
536
395
- // Note: the CacheRead primitive de-transposes the input, so we need to apply
396
- // the transpose permutation again on the cache block.
397
- sch->Reorder (cache_block_id, common_perm_);
398
537
return cache_block_id;
399
538
}
400
539
0 commit comments