Skip to content

[CINN] Adapt to sub_iter_space in TileTransposeTactic #71976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 194 additions & 55 deletions paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"

PD_DECLARE_bool(cinn_enable_tile_transpose);
Expand Down Expand Up @@ -123,8 +124,8 @@ class TileTransposeTactic final : public ScheduleTactic {
ScheduleContext* context_;
bool can_apply_;

// The common permutation of all transposes in the graph.
std::vector<int> common_perm_;
// The sub iter space apart from the main iter space.
std::vector<int> sub_iter_space_;

// Groups of axis as illustrated in the above graph.
std::vector<int> high_axis_;
Expand Down Expand Up @@ -156,21 +157,6 @@ class TileTransposeTactic final : public ScheduleTactic {
std::unordered_set<ir::Expr, LoadHash> unconditional_loads_;
};

// Derives a permutation from the indices of a load.
//
// Every index must be a plain var whose name is `ir::analyzer::kLoopVar`
// followed by a decimal loop index. For the index at position `i` that refers
// to loop `loop_index`, the result satisfies `perm[loop_index] == i`.
//
// Returns an empty vector when:
//   * the number of indices differs from `data_rank`,
//   * any index is not a plain var, or
//   * a parsed loop index falls outside `[0, data_rank)`.
//
// NOTE(review): the var name is assumed to start with the kLoopVar prefix;
// `substr`/`std::stoi` throw if it is shorter than the prefix or has no
// numeric suffix -- confirm that callers only pass canonicalized loop vars.
std::vector<int> GetTransposePerm(const std::vector<ir::Expr>& indices,
                                  int data_rank) {
  if (data_rank < 0 || indices.size() != static_cast<size_t>(data_rank)) {
    return {};
  }
  std::vector<int> perm(data_rank);
  for (int i = 0; i < data_rank; ++i) {
    if (!indices[i].is_var()) return {};
    auto* loop_var = indices[i].as_var();
    // Strip the prefix "loop_var_" to get the loop_index.
    int loop_index =
        std::stoi(loop_var->name.substr(strlen(ir::analyzer::kLoopVar)));
    // The index comes from a name, so validate it before using it as a
    // subscript; an out-of-range value would write past the end of `perm`.
    if (loop_index < 0 || loop_index >= data_rank) return {};
    perm[loop_index] = i;
  }
  return perm;
}

std::vector<int> OffsetVec(const std::vector<int>& vec, int offset) {
std::vector<int> new_vec = vec;
for (auto& e : new_vec) e += offset;
Expand All @@ -194,6 +180,162 @@ int64_t GetLoopRangeProduct(const std::vector<ir::Expr>& loops,
return prod;
}

/**
* Get the relative iter space of the load according to the loops.
*
* This class currently supports the following cases:
* 1) var[i, k, m, j] (index mapping)
* iter space: [i, k, m, j]
* 2) var[i, k % 32, k % 32, j] (simple splitting)
* iter space: [i, k, j]
* 3) var[i, k * 32 + m, j] (simple fusion)
* iter space: [i, k, m, j]
* 4) var[i, k + 128, j] (simple offsetting)
* iter space: [i, k, j]
*
* The result is translated to the corresponding loop_index instead of returning
* loop_vars directly.
*/
struct IterSpaceGetter {
IterSpaceGetter(const ir::Load* load, const std::vector<ir::Expr>& loops)
: load_(load), loops_(loops), indices_vars_(load->indices.size()) {
for (int i = 0; i < load_->indices.size(); ++i) {
ir::ir_utils::CollectIRNodes(load_->indices[i], [&](const ir::Expr* x) {
if (x->is_var() && !x->as_var()->is_symbolic_constant) {
indices_vars_[i].insert(x->as_var_ref());
}
return false;
});
}
}

std::vector<int> operator()() {
// Try to arrange the iter vars in the order of the iter space
std::vector<ir::Var> iter_space_vars;
for (int i = 0; i < load_->indices.size(); ++i) {
// Case 1. constant
if (indices_vars_[i].size() == 0) {
continue;
}

// Case 2. single variable
if (indices_vars_[i].size() == 1) {
int cover_range = CheckSingleVar(i);
if (cover_range < 0) return {};
iter_space_vars.push_back(*indices_vars_[i].begin());
i += cover_range - 1;
continue;
}

// Case 3. no more than 3 variables
if (indices_vars_[i].size() <= 3) {
std::vector<ir::Var> arranged_vars = CheckMultipleVars(i);
if (arranged_vars.empty()) return {};
iter_space_vars.insert(
iter_space_vars.end(), arranged_vars.begin(), arranged_vars.end());
continue;
}

return {};
}

// Construct the iter space
std::vector<int> iter_space;
for (auto& var : iter_space_vars) {
int loop_index =
std::stoi(var->name.substr(std::strlen(analyzer::kLoopVar)));
iter_space.push_back(loop_index);
}
return iter_space;
}

private:
int CheckSingleVar(int begin) {
ir::Var var = *indices_vars_[begin].begin();

// Check that var exclusively covers a continuous range, such as:
// [ ..., i / 32, i % 32, ... ]
// The following cases are not supported:
// [ ..., i / 32, (i % 32) * 4 + j, ... ] # not exclusive
// [ ..., i / 32, ..., i % 32, ... ] # not continuous
int end;
for (end = begin + 1; end < indices_vars_.size(); ++end) {
if (indices_vars_[end].count(var) == 0) break;
if (indices_vars_[end].size() > 1) return -1;
}
for (int i = end + 1; i < indices_vars_.size(); ++i) {
if (indices_vars_[i].count(var) > 0) return -1;
}

// Try to fuse the indices that contain `var` into one expression
ir::Expr fused_index;
if (end - begin == 1) {
fused_index = optim::ArithSimplify(load_->indices[begin]);
} else {
auto shape_it = load_->tensor.as_tensor()->shape.begin();
auto indices_it = load_->indices.begin();
std::vector<ir::Expr> sub_shape(shape_it + begin, shape_it + end);
std::vector<ir::Expr> sub_indices(indices_it + begin, indices_it + end);
fused_index = common::IndiceToAbsOffset(sub_shape, sub_indices);
}

// Check that fused_index is either a single `var` or `var + offset`
if (fused_index != ir::Expr(var)) {
auto* add_node = fused_index.As<ir::Add>();
if (!add_node || add_node->a() != ir::Expr(var)) return -1;
}

return end - begin;
}

std::vector<ir::Var> CheckMultipleVars(int pos) {
// Check that vars at this pos only appear at this pos, such as:
// [ ..., i * 32 + j, ... ]
// The following case is not supported:
// [ ..., (i * 32 + j) / 8, j % 8, ... ]
// because j appears at multiple positions.
for (int i = 0; i < indices_vars_.size(); ++i) {
if (i == pos) continue;
for (auto& var : indices_vars_[i]) {
if (indices_vars_[pos].count(var) > 0) return {};
}
}

// Collect vars in this index in ast order
std::vector<ir::Var> vars_in_ast_order;
ir::ir_utils::CollectIRNodes(load_->indices[pos], [&](const ir::Expr* x) {
if (x->is_var() && !x->as_var()->is_symbolic_constant) {
vars_in_ast_order.push_back(x->as_var_ref());
}
return false;
});

// Re-construct the index using the vars in ast order
std::vector<ir::Expr> sub_shape;
std::vector<ir::Expr> sub_indices;
for (auto& var : vars_in_ast_order) {
int loop_index =
std::stoi(var->name.substr(std::strlen(analyzer::kLoopVar)));
sub_shape.push_back(loops_[loop_index].As<ir::For>()->extent);
sub_indices.push_back(var);
}
ir::Expr sub_index = common::IndiceToAbsOffset(sub_shape, sub_indices);

// Compare the re-constructed index with the actual index
if (sub_index == load_->indices[pos]) {
return vars_in_ast_order;
}
return {};
}

private:
const ir::Load* load_;
const std::vector<ir::Expr>& loops_;

// iter vars in each of the load's indices
std::vector<std::set<ir::Var>> indices_vars_;
};

void TileTransposeTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
context_ = context;
can_apply_ = false;
Expand All @@ -213,8 +355,8 @@ void TileTransposeTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
InitUnconditionalLoads(sch);
InitCandidates(sch);

VLOG(4) << "Common permutation: " << utils::Join(common_perm_, ", ");
if (common_perm_.empty()) return;
VLOG(4) << "sub_iter_space: " << utils::Join(sub_iter_space_, ", ");
if (sub_iter_space_.empty()) return;

can_apply_ = true;
root_node->attrs[kTileMethod] = TacticName();
Expand Down Expand Up @@ -251,7 +393,7 @@ void TileTransposeTactic::InitUnconditionalLoads(ir::IRSchedule* sch) {
}

void TileTransposeTactic::InitCandidates(ir::IRSchedule* sch) {
common_perm_.clear();
sub_iter_space_.clear();
load2candidates_.clear();
block2candidates_.clear();
processed_loads_.clear();
Expand Down Expand Up @@ -289,25 +431,24 @@ void TileTransposeTactic::InitCandidates(ir::IRSchedule* sch) {
auto* tensor = load.As<ir::Load>()->tensor.as_tensor();
if (sch->HasBlock(tensor->name)) continue;

std::vector<int> perm =
GetTransposePerm(load.As<ir::Load>()->indices, loops.size());
IterSpaceGetter iter_space_getter(load.As<ir::Load>(), loops);
std::vector<int> iter_space = iter_space_getter();
Comment on lines +434 to +435
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对于超过1个副迭代空间的情况,前面融合的时候已经拆开了吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

融合没法保证,所以我下面做了判断,如果有多于1个sub_iter_space,就不match这个tactic


// 4. This is a critical transpose, including:
// 1) its dim size equals to the loop size (not a broadcast).
// 2) its last dim is changed in permutation (incurs discrete access).
// 3) both the src/dst_low_axis are non-unit (not a squeeze/unsqueeze).
if (perm.size() != loops.size()) continue;
if (perm.back() == perm.size() - 1) continue;
if (GetLoopRangeProduct(loops, GetSrcLowAxis(perm)) == 1) continue;
if (GetLoopRangeProduct(loops, GetDstLowAxis(perm)) == 1) continue;

// 5. All transposes in this graph should have the same permutation.
// Otherwise, it would be too complex to ensure the correctness and
// performance. The violating cases should be rare.
if (common_perm_.empty()) {
common_perm_ = perm;
} else if (common_perm_ != perm) {
common_perm_.clear();
if (iter_space.size() != loops.size()) continue;
if (iter_space.back() == iter_space.size() - 1) continue;
if (GetLoopRangeProduct(loops, GetSrcLowAxis(iter_space)) == 1) continue;
if (GetLoopRangeProduct(loops, GetDstLowAxis(iter_space)) == 1) continue;

// 5. All transposes in this graph should be in the same sub iter space,
// because we only support the alignment of two iter spaces.
if (sub_iter_space_.empty()) {
sub_iter_space_ = iter_space;
} else if (sub_iter_space_ != iter_space) {
sub_iter_space_.clear();
return;
}

Expand All @@ -319,37 +460,38 @@ void TileTransposeTactic::InitCandidates(ir::IRSchedule* sch) {
}

void TileTransposeTactic::InitAxisInfo() {
src_low_axis_ = GetSrcLowAxis(common_perm_);
dst_low_axis_ = GetDstLowAxis(common_perm_);
src_low_axis_ = GetSrcLowAxis(sub_iter_space_);
dst_low_axis_ = GetDstLowAxis(sub_iter_space_);

std::set<int> high_axis;
for (int i = 0; i < common_perm_.size(); ++i) high_axis.insert(i);
for (int i = 0; i < sub_iter_space_.size(); ++i) high_axis.insert(i);
for (auto i : src_low_axis_) high_axis.erase(i);
for (auto i : dst_low_axis_) high_axis.erase(i);
high_axis_.assign(high_axis.begin(), high_axis.end());
}

std::vector<int> TileTransposeTactic::GetSrcLowAxis(
const std::vector<int>& perm) {
std::set<int> src_low_axis;
for (int i = 0; i < perm.size(); ++i) {
if (perm[i] == perm.size() - 1) {
src_low_axis.insert(i);
for (int j = i - 1; j >= 0; j--) {
if (perm[j] + 1 != perm[j + 1]) break;
src_low_axis.insert(j);
}
}
const std::vector<int>& iter_space) {
std::set<int> src_low_axis{iter_space.back()};
for (int i = iter_space.size() - 2; i >= 0; --i) {
if (iter_space[i] + 1 != iter_space[i + 1]) break;
src_low_axis.insert(iter_space[i]);
}
return {src_low_axis.begin(), src_low_axis.end()};
}

std::vector<int> TileTransposeTactic::GetDstLowAxis(
const std::vector<int>& perm) {
std::set<int> dst_low_axis{perm.size() - 1};
for (int i = perm.size() - 2; i >= 0; --i) {
if (perm[i] + 1 != perm[i + 1]) break;
dst_low_axis.insert(i);
const std::vector<int>& iter_space) {
std::set<int> dst_low_axis;
auto it =
std::find(iter_space.begin(), iter_space.end(), iter_space.size() - 1);
if (it != iter_space.end()) {
dst_low_axis.insert(*it);
while (it != iter_space.begin()) {
if (*(it - 1) != *it - 1) break;
--it;
dst_low_axis.insert(*it);
}
}
return {dst_low_axis.begin(), dst_low_axis.end()};
}
Expand Down Expand Up @@ -392,9 +534,6 @@ std::string TileTransposeTactic::CreateCacheBlock(
std::string cache_block_id = ir::analyzer::GetBlockName(cache_block);
context_->output_names.insert(cache_block_id);

// Note: the CacheRead primitive de-transposes the input, so we need to apply
// the transpose permutation again on the cache block.
sch->Reorder(cache_block_id, common_perm_);
return cache_block_id;
}

Expand Down
Loading