
Commit bf0ae2d

[CINN] TileTransposeTactic turns to use the sub_iter_space concept
1 parent a27df21 commit bf0ae2d

2 files changed: +231 −72 lines

paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.cc

+194 −55
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.h"
+#include "paddle/cinn/common/ir_util.h"
 #include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
 
 PD_DECLARE_bool(cinn_enable_tile_transpose);
@@ -123,8 +124,8 @@ class TileTransposeTactic final : public ScheduleTactic {
   ScheduleContext* context_;
   bool can_apply_;
 
-  // The common permutation of all transposes in the graph.
-  std::vector<int> common_perm_;
+  // The sub iter space apart from the main iter space.
+  std::vector<int> sub_iter_space_;
 
   // Groups of axis as illustrated in the above graph.
   std::vector<int> high_axis_;
@@ -156,21 +157,6 @@ class TileTransposeTactic final : public ScheduleTactic {
   std::unordered_set<ir::Expr, LoadHash> unconditional_loads_;
 };
 
-std::vector<int> GetTransposePerm(const std::vector<ir::Expr>& indices,
-                                  int data_rank) {
-  if (indices.size() != data_rank) return {};
-  std::vector<int> perm(data_rank);
-  for (int i = 0; i < data_rank; ++i) {
-    if (!indices[i].is_var()) return {};
-    auto* loop_var = indices[i].as_var();
-    // Strip the prefix "loop_var_" to get the loop_index.
-    int loop_index =
-        std::stoi(loop_var->name.substr(strlen(ir::analyzer::kLoopVar)));
-    perm[loop_index] = i;
-  }
-  return perm;
-}
-
 std::vector<int> OffsetVec(const std::vector<int>& vec, int offset) {
   std::vector<int> new_vec = vec;
   for (auto& e : new_vec) e += offset;
@@ -194,6 +180,162 @@ int64_t GetLoopRangeProduct(const std::vector<ir::Expr>& loops,
   return prod;
 }
 
+/**
+ * Get the relative iter space of the load according to the loops.
+ *
+ * This class currently supports the following cases:
+ *  1) var[i, k, m, j]            (index mapping)
+ *     iter space: [i, k, m, j]
+ *  2) var[i, k / 32, k % 32, j]  (simple splitting)
+ *     iter space: [i, k, j]
+ *  3) var[i, k * 32 + m, j]      (simple fusion)
+ *     iter space: [i, k, m, j]
+ *  4) var[i, k + 128, j]         (simple offsetting)
+ *     iter space: [i, k, j]
+ *
+ * The result is translated to the corresponding loop_index instead of
+ * returning loop_vars directly.
+ */
+struct IterSpaceGetter {
+  IterSpaceGetter(const ir::Load* load, const std::vector<ir::Expr>& loops)
+      : load_(load), loops_(loops), indices_vars_(load->indices.size()) {
+    for (int i = 0; i < load_->indices.size(); ++i) {
+      ir::ir_utils::CollectIRNodes(load_->indices[i], [&](const ir::Expr* x) {
+        if (x->is_var() && !x->as_var()->is_symbolic_constant) {
+          indices_vars_[i].insert(x->as_var_ref());
+        }
+        return false;
+      });
+    }
+  }
+
+  std::vector<int> operator()() {
+    // Try to arrange the iter vars in the order of the iter space
+    std::vector<ir::Var> iter_space_vars;
+    for (int i = 0; i < load_->indices.size(); ++i) {
+      // Case 1. constant
+      if (indices_vars_[i].size() == 0) {
+        continue;
+      }
+
+      // Case 2. single variable
+      if (indices_vars_[i].size() == 1) {
+        int cover_range = CheckSingleVar(i);
+        if (cover_range < 0) return {};
+        iter_space_vars.push_back(*indices_vars_[i].begin());
+        i += cover_range - 1;
+        continue;
+      }
+
+      // Case 3. no more than 3 variables
+      if (indices_vars_[i].size() <= 3) {
+        std::vector<ir::Var> arranged_vars = CheckMultipleVars(i);
+        if (arranged_vars.empty()) return {};
+        iter_space_vars.insert(
+            iter_space_vars.end(), arranged_vars.begin(), arranged_vars.end());
+        continue;
+      }
+
+      return {};
+    }
+
+    // Construct the iter space
+    std::vector<int> iter_space;
+    for (auto& var : iter_space_vars) {
+      int loop_index =
+          std::stoi(var->name.substr(std::strlen(analyzer::kLoopVar)));
+      iter_space.push_back(loop_index);
+    }
+    return iter_space;
+  }
+
+ private:
+  int CheckSingleVar(int begin) {
+    ir::Var var = *indices_vars_[begin].begin();
+
+    // Check that var exclusively covers a continuous range, such as:
+    //   [ ..., i / 32, i % 32, ... ]
+    // The following cases are not supported:
+    //   [ ..., i / 32, (i % 32) * 4 + j, ... ]  # not exclusive
+    //   [ ..., i / 32, ..., i % 32, ... ]       # not continuous
+    int end;
+    for (end = begin + 1; end < indices_vars_.size(); ++end) {
+      if (indices_vars_[end].count(var) == 0) break;
+      if (indices_vars_[end].size() > 1) return -1;
+    }
+    for (int i = end + 1; i < indices_vars_.size(); ++i) {
+      if (indices_vars_[i].count(var) > 0) return -1;
+    }
+
+    // Try to fuse the indices that contain `var` into one expression
+    ir::Expr fused_index;
+    if (end - begin == 1) {
+      fused_index = optim::ArithSimplify(load_->indices[begin]);
+    } else {
+      auto shape_it = load_->tensor.as_tensor()->shape.begin();
+      auto indices_it = load_->indices.begin();
+      std::vector<ir::Expr> sub_shape(shape_it + begin, shape_it + end);
+      std::vector<ir::Expr> sub_indices(indices_it + begin, indices_it + end);
+      fused_index = common::IndiceToAbsOffset(sub_shape, sub_indices);
+    }
+
+    // Check that fused_index is either a single `var` or `var + offset`
+    if (fused_index != ir::Expr(var)) {
+      auto* add_node = fused_index.As<ir::Add>();
+      if (!add_node || add_node->a() != ir::Expr(var)) return -1;
+    }
+
+    return end - begin;
+  }
+
+  std::vector<ir::Var> CheckMultipleVars(int pos) {
+    // Check that vars at this pos only appear at this pos, such as:
+    //   [ ..., i * 32 + j, ... ]
+    // The following case is not supported:
+    //   [ ..., (i * 32 + j) / 8, j % 8, ... ]
+    // because j appears at multiple positions.
+    for (int i = 0; i < indices_vars_.size(); ++i) {
+      if (i == pos) continue;
+      for (auto& var : indices_vars_[i]) {
+        if (indices_vars_[pos].count(var) > 0) return {};
+      }
+    }
+
+    // Collect vars in this index in ast order
+    std::vector<ir::Var> vars_in_ast_order;
+    ir::ir_utils::CollectIRNodes(load_->indices[pos], [&](const ir::Expr* x) {
+      if (x->is_var() && !x->as_var()->is_symbolic_constant) {
+        vars_in_ast_order.push_back(x->as_var_ref());
+      }
+      return false;
+    });
+
+    // Re-construct the index using the vars in ast order
+    std::vector<ir::Expr> sub_shape;
+    std::vector<ir::Expr> sub_indices;
+    for (auto& var : vars_in_ast_order) {
+      int loop_index =
+          std::stoi(var->name.substr(std::strlen(analyzer::kLoopVar)));
+      sub_shape.push_back(loops_[loop_index].As<ir::For>()->extent);
+      sub_indices.push_back(var);
+    }
+    ir::Expr sub_index = common::IndiceToAbsOffset(sub_shape, sub_indices);
+
+    // Compare the re-constructed index with the actual index
+    if (sub_index == load_->indices[pos]) {
+      return vars_in_ast_order;
+    }
+    return {};
+  }
+
+ private:
+  const ir::Load* load_;
+  const std::vector<ir::Expr>& loops_;
+
+  // iter vars in each of the load's indices
+  std::vector<std::set<ir::Var>> indices_vars_;
+};
+
 void TileTransposeTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
   context_ = context;
   can_apply_ = false;
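To make the iter-space cases above concrete: the key property that CheckSingleVar verifies symbolically (via common::IndiceToAbsOffset) is that a run of adjacent indices produced by splitting one loop variable fuses back to exactly that variable. Below is a minimal standalone sketch, not part of this commit, using a plain integer AbsOffset as an illustrative stand-in for the CINN helper.

// Standalone sketch: the "simple splitting" case collapses to one axis
// because row-major re-fusion of [k / 32, k % 32] over shape [N, 32]
// recovers exactly k. AbsOffset is an illustrative analogue, not CINN code.
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major linearization of indices over a (sub-)shape.
int64_t AbsOffset(const std::vector<int64_t>& shape,
                  const std::vector<int64_t>& indices) {
  int64_t offset = 0;
  for (size_t i = 0; i < shape.size(); ++i) {
    offset = offset * shape[i] + indices[i];
  }
  return offset;
}

int main() {
  const int64_t N = 8;
  for (int64_t k = 0; k < N * 32; ++k) {
    // var[i, k / 32, k % 32, j]  ->  iter space [i, k, j]
    assert(AbsOffset({N, 32}, {k / 32, k % 32}) == k);
  }
  return 0;
}

The real CheckSingleVar additionally accepts a constant offset on top of the fused variable (the "simple offsetting" case), which the sketch leaves out.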
@@ -213,8 +355,8 @@ void TileTransposeTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
   InitUnconditionalLoads(sch);
   InitCandidates(sch);
 
-  VLOG(4) << "Common permutation: " << utils::Join(common_perm_, ", ");
-  if (common_perm_.empty()) return;
+  VLOG(4) << "sub_iter_space: " << utils::Join(sub_iter_space_, ", ");
+  if (sub_iter_space_.empty()) return;
 
   can_apply_ = true;
   root_node->attrs[kTileMethod] = TacticName();
@@ -251,7 +393,7 @@ void TileTransposeTactic::InitUnconditionalLoads(ir::IRSchedule* sch) {
 }
 
 void TileTransposeTactic::InitCandidates(ir::IRSchedule* sch) {
-  common_perm_.clear();
+  sub_iter_space_.clear();
   load2candidates_.clear();
   block2candidates_.clear();
   processed_loads_.clear();
@@ -289,25 +431,24 @@ void TileTransposeTactic::InitCandidates(ir::IRSchedule* sch) {
     auto* tensor = load.As<ir::Load>()->tensor.as_tensor();
     if (sch->HasBlock(tensor->name)) continue;
 
-    std::vector<int> perm =
-        GetTransposePerm(load.As<ir::Load>()->indices, loops.size());
+    IterSpaceGetter iter_space_getter(load.As<ir::Load>(), loops);
+    std::vector<int> iter_space = iter_space_getter();
 
     // 4. This is a critical transpose, including:
     //    1) its dim size equals to the loop size (not a broadcast).
     //    2) its last dim is changed in permutation (incurs discrete access).
     //    3) both the src/dst_low_axis are non-unit (not a squeeze/unsqueeze).
-    if (perm.size() != loops.size()) continue;
-    if (perm.back() == perm.size() - 1) continue;
-    if (GetLoopRangeProduct(loops, GetSrcLowAxis(perm)) == 1) continue;
-    if (GetLoopRangeProduct(loops, GetDstLowAxis(perm)) == 1) continue;
-
-    // 5. All transposes in this graph should have the same permutation.
-    //    Otherwise, it would be too complex to ensure the correctness and
-    //    performance. The violating cases should be rare.
-    if (common_perm_.empty()) {
-      common_perm_ = perm;
-    } else if (common_perm_ != perm) {
-      common_perm_.clear();
+    if (iter_space.size() != loops.size()) continue;
+    if (iter_space.back() == iter_space.size() - 1) continue;
+    if (GetLoopRangeProduct(loops, GetSrcLowAxis(iter_space)) == 1) continue;
+    if (GetLoopRangeProduct(loops, GetDstLowAxis(iter_space)) == 1) continue;
+
+    // 5. All transposes in this graph should be in the same sub iter space,
+    //    because we only support the alignment of two iter spaces.
+    if (sub_iter_space_.empty()) {
+      sub_iter_space_ = iter_space;
+    } else if (sub_iter_space_ != iter_space) {
+      sub_iter_space_.clear();
       return;
     }
@@ -319,37 +460,38 @@ void TileTransposeTactic::InitCandidates(ir::IRSchedule* sch) {
 }
 
 void TileTransposeTactic::InitAxisInfo() {
-  src_low_axis_ = GetSrcLowAxis(common_perm_);
-  dst_low_axis_ = GetDstLowAxis(common_perm_);
+  src_low_axis_ = GetSrcLowAxis(sub_iter_space_);
+  dst_low_axis_ = GetDstLowAxis(sub_iter_space_);
 
   std::set<int> high_axis;
-  for (int i = 0; i < common_perm_.size(); ++i) high_axis.insert(i);
+  for (int i = 0; i < sub_iter_space_.size(); ++i) high_axis.insert(i);
   for (auto i : src_low_axis_) high_axis.erase(i);
   for (auto i : dst_low_axis_) high_axis.erase(i);
   high_axis_.assign(high_axis.begin(), high_axis.end());
 }
 
 std::vector<int> TileTransposeTactic::GetSrcLowAxis(
-    const std::vector<int>& perm) {
-  std::set<int> src_low_axis;
-  for (int i = 0; i < perm.size(); ++i) {
-    if (perm[i] == perm.size() - 1) {
-      src_low_axis.insert(i);
-      for (int j = i - 1; j >= 0; j--) {
-        if (perm[j] + 1 != perm[j + 1]) break;
-        src_low_axis.insert(j);
-      }
-    }
+    const std::vector<int>& iter_space) {
+  std::set<int> src_low_axis{iter_space.back()};
+  for (int i = iter_space.size() - 2; i >= 0; --i) {
+    if (iter_space[i] + 1 != iter_space[i + 1]) break;
+    src_low_axis.insert(iter_space[i]);
   }
   return {src_low_axis.begin(), src_low_axis.end()};
 }
 
 std::vector<int> TileTransposeTactic::GetDstLowAxis(
-    const std::vector<int>& perm) {
-  std::set<int> dst_low_axis{perm.size() - 1};
-  for (int i = perm.size() - 2; i >= 0; --i) {
-    if (perm[i] + 1 != perm[i + 1]) break;
-    dst_low_axis.insert(i);
+    const std::vector<int>& iter_space) {
+  std::set<int> dst_low_axis;
+  auto it =
+      std::find(iter_space.begin(), iter_space.end(), iter_space.size() - 1);
+  if (it != iter_space.end()) {
+    dst_low_axis.insert(*it);
+    while (it != iter_space.begin()) {
+      if (*(it - 1) != *it - 1) break;
+      --it;
+      dst_low_axis.insert(*it);
+    }
   }
   return {dst_low_axis.begin(), dst_low_axis.end()};
 }
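For a concrete feel of the reworked axis grouping, this standalone sketch, not part of this commit, mirrors the new GetSrcLowAxis/GetDstLowAxis logic on plain integers: for a sub iter space of [2, 3, 0, 1], the source low axes come out as {0, 1} and the destination low axes as {2, 3}, leaving no high axes.

// Standalone sketch mirroring the new low-axis grouping on plain ints.
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

std::set<int> SrcLowAxis(const std::vector<int>& iter_space) {
  // Run of consecutive loop indices ending in iter_space.back().
  std::set<int> axes{iter_space.back()};
  for (int i = static_cast<int>(iter_space.size()) - 2; i >= 0; --i) {
    if (iter_space[i] + 1 != iter_space[i + 1]) break;
    axes.insert(iter_space[i]);
  }
  return axes;
}

std::set<int> DstLowAxis(const std::vector<int>& iter_space) {
  // Run of consecutive loop indices ending in the highest index
  // (iter_space.size() - 1), wherever it appears in the space.
  std::set<int> axes;
  auto it = std::find(iter_space.begin(), iter_space.end(),
                      static_cast<int>(iter_space.size()) - 1);
  if (it != iter_space.end()) {
    axes.insert(*it);
    while (it != iter_space.begin()) {
      if (*(it - 1) != *it - 1) break;
      --it;
      axes.insert(*it);
    }
  }
  return axes;
}

int main() {
  std::vector<int> sub_iter_space = {2, 3, 0, 1};
  assert(SrcLowAxis(sub_iter_space) == (std::set<int>{0, 1}));
  assert(DstLowAxis(sub_iter_space) == (std::set<int>{2, 3}));
  return 0;
}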
@@ -392,9 +534,6 @@ std::string TileTransposeTactic::CreateCacheBlock(
   std::string cache_block_id = ir::analyzer::GetBlockName(cache_block);
   context_->output_names.insert(cache_block_id);
 
-  // Note: the CacheRead primitive de-transposes the input, so we need to apply
-  // the transpose permutation again on the cache block.
-  sch->Reorder(cache_block_id, common_perm_);
   return cache_block_id;
 }
