[CINN] TileBroadcastTactic NHWC layout broadcast #71464

Merged: 4 commits, Mar 13, 2025
241 changes: 192 additions & 49 deletions paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc
@@ -90,10 +90,19 @@ class TileBroadcastTactic final : public ScheduleTactic {
void InitBroadcastAxisInfo(ir::IRSchedule* sch);
void InitBroadcastSizeInfo();
void FuseAxisGroups(ir::IRSchedule* sch, const std::string& block_id);
std::vector<std::string> TileNCHW(ir::IRSchedule* sch,
const std::string& block_id,
int block_size);
std::vector<std::string> TileNHWC(ir::IRSchedule* sch,
const std::string& block_id,
int block_size);

private:
ScheduleContext* context_;
bool can_apply_;

enum class BroadcastLayout { Invalid = 0x0, NCHWLayout, NHWCLayout };

BroadcastLayout applied_layout_; // applied tactic

// list of broadcast axes in ascending order
std::vector<int> broadcast_axis_;
@@ -107,6 +116,7 @@ class TileBroadcastTactic final : public ScheduleTactic {
// ^ ^ ^ ^
// | | low_broadcast_axis
// preserved_axis

std::vector<int> high_broadcast_axis_;
std::vector<int> preserved_axis_;
std::vector<int> low_broadcast_axis_;
@@ -115,6 +125,8 @@ class TileBroadcastTactic final : public ScheduleTactic {
int64_t broadcast_size_;
// product of the low broadcast axes' dim sizes
int64_t low_broadcast_size_;
// product of the preserved axes' dim sizes
int64_t preserved_size_;
};

std::unordered_set<ir::Var> CollectIterVars(
@@ -271,9 +283,43 @@ bool ScheduleBlockEnableVectorize(const ScheduleConfig& config,
return true;
}

static int CalcNumWarps(int64_t num_warps) {
// NHWC layout: calculate the number of warps per thread block.
constexpr int MAX_WARP_BLOCK = 32;
// The largest supported size is 1024; for anything bigger, give up.
// TODO(heqianyue): the code should be revised to be a DP version
if (num_warps > 1024) {
return -1;
}
// Several rules decide the thread block size:
// (1) If the preserved size (channel) is too small,
// we need more threads per block.
if (num_warps <= 3) {
return std::max(4l, num_warps * 2);
}
// (2) If num_warps is a power of 2, use a thread block size of 256 (8 warps).
if ((num_warps & (num_warps - 1)) == 0) {
return 8;
}
// (3) If num_warps <= 32, use num_warps * 32 as the thread block size.
if (num_warps <= MAX_WARP_BLOCK) {
return num_warps;
}
// (4) Otherwise, find the largest divisor `best` of num_warps that is no
// larger than 32 (and at least 4), and use `best` * 32 as the block size.
int best = -1;
for (int x = MAX_WARP_BLOCK; x >= 4; x--) {
if (num_warps % x == 0) {
best = x;
break;
}
}
return best;
}
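// Worked examples (editor's note, not part of the original patch; values
// assume num_warps = preserved_size / 32, as used by the caller below):
//   num_warps =  2 (preserved 64)   -> rule (1) -> 4 warps  (128 threads)
//   num_warps =  8 (preserved 256)  -> rule (2) -> 8 warps  (256 threads)
//   num_warps = 10 (preserved 320)  -> rule (3) -> 10 warps (320 threads)
//   num_warps = 48 (preserved 1536) -> rule (4) -> 24 warps (768 threads)
//   num_warps = 37 (preserved 1184) -> rule (4) finds no divisor in [4, 32],
//   so -1 is returned and the NHWC tactic is not applied.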

void TileBroadcastTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
context_ = context;
can_apply_ = false;
applied_layout_ = BroadcastLayout::Invalid;
if (!FLAGS_cinn_enable_tile_broadcast) {
return;
}
@@ -290,27 +336,47 @@ void TileBroadcastTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
if (broadcast_axis_.empty()) {
return;
}
// 3. It is an NCHW broadcast. We check this by checking that the last axis is
// a broadcast axis, and all the 3 groups of axes exist.
if (!is_broadcast_axis_.back()) {
return;
}
if (high_broadcast_axis_.empty() || preserved_axis_.empty() ||
low_broadcast_axis_.empty()) {
return;
}
InitBroadcastSizeInfo();
// 4. The low_broadcast_size should be a multiple of 32 (the CUDA warp size).
// Otherwise, memory access will not be fully coalesced, leading to
// performance degradation.
// TODO(liangshuhao): we may allow aligning to 16 if further optimizations
// can compensate for the cost of non-coalesced access.
if (low_broadcast_size_ % 32 != 0) {
return;
}

if (is_broadcast_axis_.back()) {
// 3. It is an NCHW broadcast. We verify this by checking that the last
// axis is a broadcast axis and that all three groups of axes exist.
if (high_broadcast_axis_.empty() || low_broadcast_axis_.empty() ||
preserved_axis_.empty()) {
return;
}
InitBroadcastSizeInfo();
// 4. The low_broadcast_size should be a multiple of 32 (the CUDA warp
// size).
// Otherwise, memory access will not be fully coalesced, leading to
// performance degradation.
// TODO(liangshuhao): we may allow aligning to 16 if further optimizations
// can compensate for the cost of non-coalesced access.
if (low_broadcast_size_ % 32 != 0) {
return;
}
applied_layout_ = BroadcastLayout::NCHWLayout;
VLOG(4) << "TileBroadcastTactic::Init: matched NCHWLayout\n";
} else {
// 3. For the NHWC layout, we need valid preserved and broadcast axes.
// Also, to keep things stable, we only handle NHW(C) broadcasts, i.e.
// broadcasts with a single preserved (channel) axis.
if (preserved_axis_.empty() || broadcast_axis_.empty() ||
preserved_axis_.size() > 1) {
return;
}
InitBroadcastSizeInfo();
// 4. A compatible channel size must be a multiple of 32 (the warp size).
if (preserved_size_ % 32 != 0) {
return;
}
int num_warps = CalcNumWarps(preserved_size_ >> 5);
// 5. Check whether the NHWC layout can be applied successfully.
if (num_warps < 0) {
return;
}
applied_layout_ = BroadcastLayout::NHWCLayout;
VLOG(4) << "TileBroadcastTactic::Init: matched NHWCLayout\n";
}
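// Illustrative examples of the two layouts (editor's note, not part of the
// original patch; the shapes are hypothetical):
//   NCHW-like: broadcasting [1, C, 1, 1] -> [N, C, H, W]. The trailing H and W
//   axes are low broadcast axes, C is preserved, and N is a high broadcast
//   axis, so all three groups exist and the last axis is a broadcast axis.
//   NHWC-like: broadcasting [1, 1, 1, C] -> [N, H, W, C]. The last axis C is
//   preserved (not broadcast), there is exactly one preserved axis, and the
//   NHWC branch applies when C is a multiple of 32.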
// Now we can apply this tactic
can_apply_ = true;
ir::Expr module_root = sch->GetModule().GetExprs().front();
ir::Expr root_block = ir::analyzer::GetRootSBlock(module_root);
auto* root_node = root_block.As<ir::ScheduleBlockRealize>()
@@ -363,22 +429,15 @@ void TileBroadcastTactic::InitBroadcastSizeInfo() {
for (int axis : low_broadcast_axis_) {
low_broadcast_size_ = MulDimSize(low_broadcast_size_, loop_ranges[axis]);
}
}

void TileBroadcastTactic::Apply(ir::IRSchedule* sch,
const std::string& block_id) {
if (ScheduleBlockEnableVectorize(context_->config, block_id)) {
ApplyVectorize(sch, block_id);
return;
preserved_size_ = 1;
for (int axis : preserved_axis_) {
preserved_size_ = MulDimSize(preserved_size_, loop_ranges[axis]);
}
}

if (!can_apply_) return;

// Cluster and fuse axis of the same type to get exactly 3 loops.
// [B, P, B, P, ..., B, B] => [B, P, B]
FuseAxisGroups(sch, block_id);

// Do tiling.
std::vector<std::string> TileBroadcastTactic::TileNCHW(
ir::IRSchedule* sch, const std::string& block_id, int block_size) {
// To achieve the best performance, we apply different tiling templates based
// on low_broadcast_size. The key is which axis the inner loop is allocated to:
// 1. For small size:
@@ -390,18 +449,93 @@ void TileBroadcastTactic::Apply(ir::IRSchedule* sch,
// 3. For large size:
// [B, P, B>2048]
// => [blockX', blockY, (blockX, loop, threadX)].
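// Worked example (editor's note, not part of the original patch; the shape is
// hypothetical): with fused loops [B, P, B_low] and low_broadcast_size_ = 1024
// (<= 2048), the middle branch below does Split(block_id, 2, {-1, block_size})
// with the default block_size of 256, giving [B, P, 4, 256], which is bound as
// (blockIdx.y, blockIdx.x, serial loop, threadIdx.x).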
std::vector<std::string> axis_bind;
VLOG(4) << "TileBroadcastTactic using original NCHW layout\n";
if (low_broadcast_size_ <= 256) {
sch->Split(block_id, 0, {-1, 4});
axis_bind = {"blockIdx.y", "", "blockIdx.x", "threadIdx.x"};
return {"blockIdx.y", "", "blockIdx.x", "threadIdx.x"};
} else if (low_broadcast_size_ <= 2048) {
sch->Split(block_id, 2, {-1, 256});
axis_bind = {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"};
sch->Split(block_id, 2, {-1, block_size});
return {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"};
} else {
sch->Reorder(block_id, {1, 0});
sch->Fuse(block_id, {1, 2});
sch->Split(block_id, 1, {-1, 4, 256});
axis_bind = {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"};
sch->Split(block_id, 1, {-1, 4, block_size});
return {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"};
}
}

std::vector<std::string> TileBroadcastTactic::TileNHWC(
ir::IRSchedule* sch, const std::string& block_id, int block_size) {
// The NHWC layout yields 2 fused loops, so we start from (blockIdx.x,
// threadIdx.x).
VLOG(4) << "TileBroadcastTactic using NHWC layout, block size: " << block_size
<< "\n";
if (broadcast_size_ <= 64) {
/**
 * If the broadcast size is no larger than 64, we need more blocks to
 * increase occupancy, so no thread coarsening is applied.
 */
sch->Split(block_id, 1, {-1, block_size});
sch->Fuse(block_id, {0, 1});
return {"blockIdx.x", "threadIdx.x"};
} else {
if (preserved_size_ == block_size) {
// the block size exactly covers the preserved (channel) size
sch->Split(block_id, 1, {-1, block_size});
sch->Fuse(block_id, {0, 1});
sch->Split(block_id, 0, {-1, 4});
return {"blockIdx.x", "", "threadIdx.x"};
} else if (preserved_size_ < block_size) {
// the block size is (deliberately) larger than the preserved size,
// so that there are enough threads per block
sch->Fuse(block_id, {0, 1}); // single fused loop
sch->Split(
block_id, 0, {-1, 4, block_size / preserved_size_, preserved_size_});
return {"blockIdx.x", "", "threadIdx.y", "threadIdx.x"};
} else {
/**
 * The block size is not enough to cover the preserved size, so make the
 * load index invariant to the inner loop; the resulting split is
 * (-1, 4, preserved_size / block_size, block_size).
 */
sch->Split(block_id, 1, {-1, block_size});
sch->Fuse(block_id, {0, 1});
sch->Split(block_id, 0, {-1, 4, preserved_size_ / block_size});
return {"blockIdx.x", "", "blockIdx.y", "threadIdx.x"};
}
}
}
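// Worked example (editor's note, not part of the original patch; the shape is
// hypothetical): with preserved_size_ = 256, broadcast_size_ = 4096 and
// block_size = 256, the preserved_size_ == block_size branch applies:
// Split(1, {-1, 256}) gives [4096, 1, 256], Fuse({0, 1}) gives [4096, 256],
// and Split(0, {-1, 4}) gives [1024, 4, 256], bound as
// (blockIdx.x, serial loop, threadIdx.x).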

void TileBroadcastTactic::Apply(ir::IRSchedule* sch,
const std::string& block_id) {
if (ScheduleBlockEnableVectorize(context_->config, block_id)) {
ApplyVectorize(sch, block_id);
return;
}

if (applied_layout_ == BroadcastLayout::Invalid) return;
int block_size = 256;
// Re-check the number of warps; if the preserved_size is not applicable,
// the layout is invalidated (Init should already have rejected this case).
if (applied_layout_ == BroadcastLayout::NHWCLayout) {
block_size = CalcNumWarps(preserved_size_ >> 5);
if (block_size == -1) {
applied_layout_ = BroadcastLayout::Invalid;
}
block_size = std::clamp(block_size << 5, 128, 1024);
}
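// Editor's note (not part of the original patch): e.g. preserved_size_ = 1536
// gives CalcNumWarps(48) = 24 and block_size = clamp(24 * 32, 128, 1024) = 768.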

// Cluster and fuse axes of the same type to get exactly 3 loops (NCHW
// layout) or 2 loops (NHWC layout).
// [B, P, B, P, ..., B, B] => [B, P, B]
FuseAxisGroups(sch, block_id);

// Do tiling.
std::vector<std::string> axis_bind;
if (applied_layout_ == BroadcastLayout::NCHWLayout) {
axis_bind = TileNCHW(sch, block_id, block_size);
} else {
axis_bind = TileNHWC(sch, block_id, block_size);
}

// Do binding.
@@ -426,23 +560,32 @@ void TileBroadcastTactic::FuseAxisGroups(ir::IRSchedule* sch,
const std::string& block_id) {
// Reorder high-dim axes to cluster axes of the same type.
// [B, P, B, P, ..., B, B] => [B, B, ..., P, P, ..., B, B]
std::vector<int> high_axis_perm = high_broadcast_axis_;
high_axis_perm.insert(
high_axis_perm.end(), preserved_axis_.begin(), preserved_axis_.end());
sch->Reorder(block_id, high_axis_perm);

// Fuse contiguous axes of the same type.
// [B, B, ..., P, P, ..., B, B] => [B, P, B]
// [B, B, ..., P, P, ..., B, B] => [B, P, B] (for NCHW layout)
// [B, B, P, B, ..., P] => [B, P] (for NHWC layout)
if (applied_layout_ == BroadcastLayout::Invalid) return;
const auto FuseRange = [&](int start, int count) {
if (count > 1) {
std::vector<int> loops_index(count);
std::iota(loops_index.begin(), loops_index.end(), start);
sch->Fuse(block_id, loops_index);
}
};
int high_axis_num = high_broadcast_axis_.size();
std::vector<int> axis_perm;
int high_axis_num = 0;
int low_axis_num = 0;
int mid_axis_num = preserved_axis_.size();
int low_axis_num = low_broadcast_axis_.size();
if (applied_layout_ == BroadcastLayout::NCHWLayout) {
axis_perm = high_broadcast_axis_;
high_axis_num = high_broadcast_axis_.size();
low_axis_num = low_broadcast_axis_.size();
} else {
axis_perm = broadcast_axis_;
high_axis_num = broadcast_axis_.size();
}
axis_perm.insert(
axis_perm.end(), preserved_axis_.begin(), preserved_axis_.end());
sch->Reorder(block_id, axis_perm);

FuseRange(high_axis_num + mid_axis_num, low_axis_num);
FuseRange(high_axis_num, mid_axis_num);
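// Illustrative trace (editor's note, not part of the original patch): for the
// NCHW layout with loops [B0, P0, P1, B1, B2] (high = {B0}, preserved =
// {P0, P1}, low = {B1, B2}), the reorder above is a no-op, the first
// FuseRange(3, 2) fuses B1*B2 and the second FuseRange(1, 2) fuses P0*P1,
// leaving [B0, P0*P1, B1*B2].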
@@ -451,7 +594,7 @@ void TileBroadcastTactic::FuseAxisGroups(ir::IRSchedule* sch,

void TileBroadcastTactic::ApplyVectorize(ir::IRSchedule* sch,
const std::string& block_id) {
if (!can_apply_) return;
if (applied_layout_ == BroadcastLayout::Invalid) return;
const auto vectorize_factor =
static_cast<int>(context_->config.tile_config.vectorize_factor);
