Skip to content

[CINN] TileBroadcastTactic NHWC layout broadcast #71464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 13, 2025
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 208 additions & 60 deletions paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,22 @@ class TileBroadcastTactic final : public ScheduleTactic {
void InitBroadcastAxisInfo(ir::IRSchedule* sch);
void InitBroadcastSizeInfo();
void FuseAxisGroups(ir::IRSchedule* sch, const std::string& block_id);
// NHWC layout: calculate number of warps per block
static int CalcNumWarps(int64_t preserved_size);

private:
ScheduleContext* context_;
bool can_apply_;

// list of broadcast axis in ascending order
enum class TacticExtension : uint8_t {
Invalid = 0x0,
NCHWExt,
NHWCExt,
NumSupportedExt
};

TacticExtension applied_ext_; // applied tactic extension

// list of broadcast axis in ascending order, NHWC layout will use this list
std::vector<int> broadcast_axis_;
// one-hot representation of broadcast_axis
std::vector<bool> is_broadcast_axis_;
Expand All @@ -107,6 +117,7 @@ class TileBroadcastTactic final : public ScheduleTactic {
// ^ ^ ^ ^
// | | low_broadcast_axis
// preserved_axis

std::vector<int> high_broadcast_axis_;
std::vector<int> preserved_axis_;
std::vector<int> low_broadcast_axis_;
Expand All @@ -115,6 +126,8 @@ class TileBroadcastTactic final : public ScheduleTactic {
int64_t broadcast_size_;
// product of the low broadcast axis's dim sizes
int64_t low_broadcast_size_;
// product of the preserved axis's dim sizes
int64_t preserved_size_;
};

std::unordered_set<ir::Var> CollectIterVars(
Expand Down Expand Up @@ -273,7 +286,7 @@ bool ScheduleBlockEnableVectorize(const ScheduleConfig& config,

void TileBroadcastTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
context_ = context;
can_apply_ = false;
applied_ext_ = TacticExtension::Invalid;
if (!FLAGS_cinn_enable_tile_broadcast) {
return;
}
Expand All @@ -290,27 +303,48 @@ void TileBroadcastTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
if (broadcast_axis_.empty()) {
return;
}
// 3. It is an NCHW broadcast. We check this by checking that he last axis is
// a broadcast axis, and all the 3 groups of axis exist.
if (!is_broadcast_axis_.back()) {
return;
}
if (high_broadcast_axis_.empty() || preserved_axis_.empty() ||
low_broadcast_axis_.empty()) {
return;
}
InitBroadcastSizeInfo();
// 4. The low_broadcast_size should be a multiple of 32 (the CUDA warp size).
// Otherwise, memory access will not be fully coalesced, leading to
// performance degradation.
// TODO(liangshuhao): we may allow aligning to 16 if further optimizations
// can compensate for the cost of non-coalesced access.
if (low_broadcast_size_ % 32 != 0) {
return;
}

if (is_broadcast_axis_.back()) {
// 3. It is an NCHW broadcast. We check this by checking that he last axis
// is
// a broadcast axis, and all the 3 groups of axis exist.
if (high_broadcast_axis_.empty() || low_broadcast_axis_.empty() ||
preserved_axis_.empty()) {
return;
}
InitBroadcastSizeInfo();
// 4. The low_broadcast_size should be a multiple of 32 (the CUDA warp
// size).
// Otherwise, memory access will not be fully coalesced, leading to
// performance degradation.
// TODO(liangshuhao): we may allow aligning to 16 if further optimizations
// can compensate for the cost of non-coalesced access.
if (low_broadcast_size_ % 32 != 0) {
return;
}
applied_ext_ = TacticExtension::NCHWExt;
VLOG(4) << "TileBroadcastTactic::Init: select extension: NCHWExt\n";
} else {
// 3. For NHWC layout, we need valid preserved axis and broadcast axis
// also, to be more stable, we only handles NHW(C) broadcast
if (preserved_axis_.empty() || broadcast_axis_.empty() ||
preserved_axis_.size() > 1) {
return;
}
InitBroadcastSizeInfo();
// 4. compatible size is the multiple of 32
if (broadcast_size_ % 32 != 0 || preserved_size_ % 32 != 0) {
return;
}
int num_warps = TileBroadcastTactic::CalcNumWarps(preserved_size_ >> 5);
// 5. check whether NHWC extension can be successfully applied
if (num_warps < 0) {
return;
}
applied_ext_ = TacticExtension::NHWCExt;
VLOG(4) << "TileBroadcastTactic::Init: select extension: NHWCExt\n";
}
// Now we can apply this tactic
can_apply_ = true;
ir::Expr module_root = sch->GetModule().GetExprs().front();
ir::Expr root_block = ir::analyzer::GetRootSBlock(module_root);
auto* root_node = root_block.As<ir::ScheduleBlockRealize>()
Expand Down Expand Up @@ -363,6 +397,48 @@ void TileBroadcastTactic::InitBroadcastSizeInfo() {
for (int axis : low_broadcast_axis_) {
low_broadcast_size_ = MulDimSize(low_broadcast_size_, loop_ranges[axis]);
}

preserved_size_ = 1;
for (int axis : preserved_axis_) {
preserved_size_ = MulDimSize(preserved_size_, loop_ranges[axis]);
}
}

int TileBroadcastTactic::CalcNumWarps(int64_t num_warps) {
constexpr int MAX_WARP_BLOCK = 32;
// the largest preserved size is 1024, for size bigger than 1024
// TODO(heqianyue): the code should be revised to be a DP version, see (6)
if (num_warps > 1024) {
return -1;
}
// several rules to decide the thread block size
// (1) the preserved size (channel) is too small
// we need more threads in a block
if (num_warps <= 3) {
return std::max(4l, num_warps * 2);
}
// (2) if the num_warps is power of 2, use thread block size 256
if ((num_warps & (num_warps - 1)) == 0) {
return 8;
}
// (3) if num_warps is smaller than 32, use the num_warps * 32 as thread block
// size
if (num_warps <= MAX_WARP_BLOCK) {
return num_warps;
}
// (4) otherwise, find the largest divisor `best` of num_warps that is smaller
// than 32 and use `best` * 32 as the block size
int best = -1;
for (int x = MAX_WARP_BLOCK; x >= 4; x--) {
if (num_warps % x == 0) {
best = x;
break;
}
}
// (6) actually, the full problem is a variant to the backpack problem, DP
// should be used here But since the problem is not large enough, we can use a
// simple greedy algorithm. TODO(heqianyue): use DP to solve bigger psize
return best;
}

void TileBroadcastTactic::Apply(ir::IRSchedule* sch,
Expand All @@ -372,36 +448,91 @@ void TileBroadcastTactic::Apply(ir::IRSchedule* sch,
return;
}

if (!can_apply_) return;
int block_size = 256;
// check the number of warps here, if not a applicable
// preserved_size, func will return later
if (applied_ext_ == TacticExtension::NHWCExt) {
block_size = TileBroadcastTactic::CalcNumWarps(preserved_size_ >> 5);
if (block_size == -1) {
applied_ext_ = TacticExtension::Invalid;
}
block_size = std::clamp(block_size << 5, 128, 1024);
}
if (applied_ext_ == TacticExtension::Invalid) return;

// Cluster and fuse axis of the same type to get exactly 3 loops.
// [B, P, B, P, ..., B, B] => [B, P, B]
FuseAxisGroups(sch, block_id);

// Do tiling.
// To achieve best performance, we apply different tiling templates based on
// low_broadcast_size. The key is which axis to allocate inner loop:
// 1. For small size:
// [B, P, B<=256]
// => [(blockY, loop), blockX, threadX].
// 2. For medium size:
// [B, P, 256<B<=2048],
// => [blockY, blockX, (loop, threadX)].
// 3. For large size:
// [B, P, B>2048]
// => [blockX', blockY, (blockX, loop, threadX)].
std::vector<std::string> axis_bind;
if (low_broadcast_size_ <= 256) {
sch->Split(block_id, 0, {-1, 4});
axis_bind = {"blockIdx.y", "", "blockIdx.x", "threadIdx.x"};
} else if (low_broadcast_size_ <= 2048) {
sch->Split(block_id, 2, {-1, 256});
axis_bind = {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"};
} else {
sch->Reorder(block_id, {1, 0});
sch->Fuse(block_id, {1, 2});
sch->Split(block_id, 1, {-1, 4, 256});
axis_bind = {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"};
if (applied_ext_ == TacticExtension::NCHWExt) {
// Do tiling.
// To achieve best performance, we apply different tiling templates based on
// low_broadcast_size. The key is which axis to allocate inner loop:
// 1. For small size:
// [B, P, B<=256]
// => [(blockY, loop), blockX, threadX].
// 2. For medium size:
// [B, P, 256<B<=2048],
// => [blockY, blockX, (loop, threadX)].
// 3. For large size:
// [B, P, B>2048]
// => [blockX', blockY, (blockX, loop, threadX)].
if (low_broadcast_size_ <= 256) {
sch->Split(block_id, 0, {-1, 4});
axis_bind = {"blockIdx.y", "", "blockIdx.x", "threadIdx.x"};
} else if (low_broadcast_size_ <= 2048) {
sch->Split(block_id, 2, {-1, block_size});
axis_bind = {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"};
} else {
sch->Reorder(block_id, {1, 0});
sch->Fuse(block_id, {1, 2});
sch->Split(block_id, 1, {-1, 4, block_size});
axis_bind = {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"};
}
VLOG(4) << "TileBroadcastTactic using original NCHW layout extension\n";
} else if (applied_ext_ == TacticExtension::NHWCExt) {
// NHWC layout will have 2 fused loops, so we start with (blockIdx.x,
// threadIdx.x)
if (broadcast_size_ <= 64) {
/**
* if the broadcast size is smaller than 64
* this means we need more blocks to increase the occupancy
* so no thread coarsening anyway
*/
sch->Split(block_id, 1, {-1, block_size});
sch->Fuse(block_id, {0, 1});
axis_bind = {"blockIdx.x", "threadIdx.x"};
} else {
if (preserved_size_ == block_size) {
// block size is enough to cover
sch->Split(block_id, 1, {-1, block_size});
sch->Fuse(block_id, {0, 1});
sch->Split(block_id, 0, {-1, 4});
axis_bind = {"blockIdx.x", "", "threadIdx.x"};
} else if (preserved_size_ < block_size) {
// block size is larger (deliberately, to have enough threads)
// than preserved size
sch->Fuse(block_id, {0, 1}); // single fused loop
sch->Split(block_id,
0,
{-1, 4, block_size / preserved_size_, preserved_size_});
axis_bind = {"blockIdx.x", "", "threadIdx.y", "threadIdx.x"};
} else {
/**
* block size is not enough to cover the preserved size
* make the load index invariant to inner loop
* (-1, 4, p_size / block_size, block_size)
*/
sch->Split(block_id, 1, {-1, block_size});
sch->Fuse(block_id, {0, 1});
sch->Split(block_id, 0, {-1, 4, preserved_size_ / block_size});
axis_bind = {"blockIdx.x", "", "blockIdx.y", "threadIdx.x"};
}
}

VLOG(4) << "TileBroadcastTactic using NHWC layout extension, block size: "
<< block_size << "\n";
}

// Do binding.
Expand All @@ -426,32 +557,49 @@ void TileBroadcastTactic::FuseAxisGroups(ir::IRSchedule* sch,
const std::string& block_id) {
// Reorder high-dim axis to cluster axis of the same type.
// [B, P, B, P, ..., B, B] => [B, B, ..., P, P, ..., B, B]
std::vector<int> high_axis_perm = high_broadcast_axis_;
high_axis_perm.insert(
high_axis_perm.end(), preserved_axis_.begin(), preserved_axis_.end());
sch->Reorder(block_id, high_axis_perm);

// Fuse continuous axis of the same type.
// [B, B, ..., P, P, ..., B, B] => [B, P, B]
// [B, B, ..., P, P, ..., B, B] => [B, P, B] (for NCHW layout)
// [B, B, P, B, ..., P] => [B, P] (for NHWC layout)
const auto FuseRange = [&](int start, int count) {
if (count > 1) {
std::vector<int> loops_index(count);
std::iota(loops_index.begin(), loops_index.end(), start);
sch->Fuse(block_id, loops_index);
}
};
int high_axis_num = high_broadcast_axis_.size();
int mid_axis_num = preserved_axis_.size();
int low_axis_num = low_broadcast_axis_.size();

FuseRange(high_axis_num + mid_axis_num, low_axis_num);
FuseRange(high_axis_num, mid_axis_num);
FuseRange(0, high_axis_num);
if (applied_ext_ == TacticExtension::NCHWExt) {
std::vector<int> high_axis_perm = high_broadcast_axis_;
high_axis_perm.insert(
high_axis_perm.end(), preserved_axis_.begin(), preserved_axis_.end());
sch->Reorder(block_id, high_axis_perm);

int high_axis_num = high_broadcast_axis_.size();
int mid_axis_num = preserved_axis_.size();
int low_axis_num = low_broadcast_axis_.size();

FuseRange(high_axis_num + mid_axis_num, low_axis_num);
FuseRange(high_axis_num, mid_axis_num);
FuseRange(0, high_axis_num);
} else if (applied_ext_ == TacticExtension::NHWCExt) {
std::vector<int> broadcast_axis_perm = broadcast_axis_;
broadcast_axis_perm.insert(broadcast_axis_perm.end(),
preserved_axis_.begin(),
preserved_axis_.end());
sch->Reorder(block_id, broadcast_axis_perm);
int broadcast_num = broadcast_axis_.size();
int preserved_num = preserved_axis_.size();
FuseRange(broadcast_num, preserved_num);
FuseRange(0, broadcast_num);
} else {
std::cerr << "Unknown tactic extension: " << static_cast<int>(applied_ext_)
<< std::endl;
throw std::runtime_error("Unsupported tactic extension");
}
}

void TileBroadcastTactic::ApplyVectorize(ir::IRSchedule* sch,
const std::string& block_id) {
if (!can_apply_) return;
if (applied_ext_ == TacticExtension::Invalid) return;
const auto vectorize_factor =
static_cast<int>(context_->config.tile_config.vectorize_factor);

Expand Down
Loading