diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc index 54a4faef8e1d1b..ff9f4d39c7b7d2 100644 --- a/paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc @@ -90,10 +90,19 @@ class TileBroadcastTactic final : public ScheduleTactic { void InitBroadcastAxisInfo(ir::IRSchedule* sch); void InitBroadcastSizeInfo(); void FuseAxisGroups(ir::IRSchedule* sch, const std::string& block_id); + std::vector TileNCHW(ir::IRSchedule* sch, + const std::string& block_id, + int block_size); + std::vector TileNHWC(ir::IRSchedule* sch, + const std::string& block_id, + int block_size); private: ScheduleContext* context_; - bool can_apply_; + + enum class BroadcastLayout { Invalid = 0x0, NCHWLayout, NHWCLayout }; + + BroadcastLayout applied_layout_; // applied tactic // list of broadcast axis in ascending order std::vector broadcast_axis_; @@ -107,6 +116,7 @@ class TileBroadcastTactic final : public ScheduleTactic { // ^ ^ ^ ^ // | | low_broadcast_axis // preserved_axis + std::vector high_broadcast_axis_; std::vector preserved_axis_; std::vector low_broadcast_axis_; @@ -115,6 +125,8 @@ class TileBroadcastTactic final : public ScheduleTactic { int64_t broadcast_size_; // product of the low broadcast axis's dim sizes int64_t low_broadcast_size_; + // product of the preserved axis's dim sizes + int64_t preserved_size_; }; std::unordered_set CollectIterVars( @@ -271,9 +283,43 @@ bool ScheduleBlockEnableVectorize(const ScheduleConfig& config, return true; } +static int CalcNumWarps(int64_t num_warps) { + // NHWC layout: pick the number of warps per block (-1 if unsupported) + constexpr int MAX_WARP_BLOCK = 32; + // the largest preserved size is 1024, for size bigger than 1024 + // TODO(heqianyue): the code should be revised to be a DP version + if (num_warps > 1024) { + return -1; + } + // several rules to decide the thread block size + // (1) the 
preserved size (channel) is too small + // we need more threads in a block + if (num_warps <= 3) { + return static_cast<int>(std::max<int64_t>(4, num_warps * 2)); + } + // (2) if the num_warps is power of 2, use thread block size 256 + if ((num_warps & (num_warps - 1)) == 0) { + return 8; + } + // (3) if num_warps <= 32, use the num_warps * 32 as thread block size + if (num_warps <= MAX_WARP_BLOCK) { + return num_warps; + } + // (4) otherwise, find the largest divisor `best` of num_warps in [4, 32] + // and use `best` * 32 as the block size (-1 if there is none) + int best = -1; + for (int x = MAX_WARP_BLOCK; x >= 4; x--) { + if (num_warps % x == 0) { + best = x; + break; + } + } + return best; +} + void TileBroadcastTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) { context_ = context; - can_apply_ = false; + applied_layout_ = BroadcastLayout::Invalid; if (!FLAGS_cinn_enable_tile_broadcast) { return; } @@ -290,27 +336,47 @@ void TileBroadcastTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) { if (broadcast_axis_.empty()) { return; } - // 3. It is an NCHW broadcast. We check this by checking that he last axis is - // a broadcast axis, and all the 3 groups of axis exist. - if (!is_broadcast_axis_.back()) { - return; - } - if (high_broadcast_axis_.empty() || preserved_axis_.empty() || - low_broadcast_axis_.empty()) { - return; - } - InitBroadcastSizeInfo(); - // 4. The low_broadcast_size should be a multiple of 32 (the CUDA warp size). - // Otherwise, memory access will not be fully coalesced, leading to - // performance degradation. - // TODO(liangshuhao): we may allow aligning to 16 if further optimizations - // can compensate for the cost of non-coalesced access. - if (low_broadcast_size_ % 32 != 0) { - return; - } + if (is_broadcast_axis_.back()) { + // 3. It is an NCHW broadcast. We check this by checking that the last axis + // is a broadcast axis, and all the 3 groups of axis exist. 
+ if (high_broadcast_axis_.empty() || low_broadcast_axis_.empty() || + preserved_axis_.empty()) { + return; + } + InitBroadcastSizeInfo(); + // 4. The low_broadcast_size should be a multiple of 32 (the CUDA warp + // size). + // Otherwise, memory access will not be fully coalesced, leading to + // performance degradation. + // TODO(liangshuhao): we may allow aligning to 16 if further optimizations + // can compensate for the cost of non-coalesced access. + if (low_broadcast_size_ % 32 != 0) { + return; + } + applied_layout_ = BroadcastLayout::NCHWLayout; + VLOG(4) << "TileBroadcastTactic::Init: matched NCHWLayout\n"; + } else { + // 3. For NHWC layout, we need valid preserved axis and broadcast axis + // also, to be more stable, we only handle NHW(C) broadcast + if (preserved_axis_.empty() || broadcast_axis_.empty() || + preserved_axis_.size() > 1) { + return; + } + InitBroadcastSizeInfo(); + // 4. compatible channel size is the multiple of 32 + if (preserved_size_ % 32 != 0) { + return; + } + int num_warps = CalcNumWarps(preserved_size_ >> 5); + // 5. 
check whether NHWC layout can be successfully applied + if (num_warps < 0) { + return; + } + applied_layout_ = BroadcastLayout::NHWCLayout; + VLOG(4) << "TileBroadcastTactic::Init: matched NHWCLayout\n"; + } // Now we can apply this tactic - can_apply_ = true; ir::Expr module_root = sch->GetModule().GetExprs().front(); ir::Expr root_block = ir::analyzer::GetRootSBlock(module_root); auto* root_node = root_block.As() @@ -363,22 +429,15 @@ void TileBroadcastTactic::InitBroadcastSizeInfo() { for (int axis : low_broadcast_axis_) { low_broadcast_size_ = MulDimSize(low_broadcast_size_, loop_ranges[axis]); } -} -void TileBroadcastTactic::Apply(ir::IRSchedule* sch, - const std::string& block_id) { - if (ScheduleBlockEnableVectorize(context_->config, block_id)) { - ApplyVectorize(sch, block_id); - return; + preserved_size_ = 1; + for (int axis : preserved_axis_) { + preserved_size_ = MulDimSize(preserved_size_, loop_ranges[axis]); } +} - if (!can_apply_) return; - - // Cluster and fuse axis of the same type to get exactly 3 loops. - // [B, P, B, P, ..., B, B] => [B, P, B] - FuseAxisGroups(sch, block_id); - - // Do tiling. +std::vector TileBroadcastTactic::TileNCHW( + ir::IRSchedule* sch, const std::string& block_id, int block_size) { // To achieve best performance, we apply different tiling templates based on // low_broadcast_size. The key is which axis to allocate inner loop: // 1. For small size: @@ -390,18 +449,93 @@ void TileBroadcastTactic::Apply(ir::IRSchedule* sch, // 3. For large size: // [B, P, B>2048] // => [blockX', blockY, (blockX, loop, threadX)]. 
- std::vector axis_bind; + VLOG(4) << "TileBroadcastTactic using original NCHW layout\n"; if (low_broadcast_size_ <= 256) { sch->Split(block_id, 0, {-1, 4}); - axis_bind = {"blockIdx.y", "", "blockIdx.x", "threadIdx.x"}; + return {"blockIdx.y", "", "blockIdx.x", "threadIdx.x"}; } else if (low_broadcast_size_ <= 2048) { - sch->Split(block_id, 2, {-1, 256}); - axis_bind = {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"}; + sch->Split(block_id, 2, {-1, block_size}); + return {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"}; } else { sch->Reorder(block_id, {1, 0}); sch->Fuse(block_id, {1, 2}); - sch->Split(block_id, 1, {-1, 4, 256}); - axis_bind = {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"}; + sch->Split(block_id, 1, {-1, 4, block_size}); + return {"blockIdx.y", "blockIdx.x", "", "threadIdx.x"}; + } +} + +std::vector TileBroadcastTactic::TileNHWC( + ir::IRSchedule* sch, const std::string& block_id, int block_size) { + // NHWC layout will have 2 fused loops, so we start with (blockIdx.x, + // threadIdx.x) + VLOG(4) << "TileBroadcastTactic using NHWC layout, block size: " << block_size + << "\n"; + if (broadcast_size_ <= 64) { + /** + * if the broadcast size is at most 64 + * this means we need more blocks to increase the occupancy + * so no thread coarsening anyway + */ + sch->Split(block_id, 1, {-1, block_size}); + sch->Fuse(block_id, {0, 1}); + return {"blockIdx.x", "threadIdx.x"}; + } else { + if (preserved_size_ == block_size) { + // block size is enough to cover + sch->Split(block_id, 1, {-1, block_size}); + sch->Fuse(block_id, {0, 1}); + sch->Split(block_id, 0, {-1, 4}); + return {"blockIdx.x", "", "threadIdx.x"}; + } else if (preserved_size_ < block_size) { + // block size is larger (deliberately, to have enough threads) + // than preserved size + sch->Fuse(block_id, {0, 1}); // single fused loop + sch->Split( + block_id, 0, {-1, 4, block_size / preserved_size_, preserved_size_}); + return {"blockIdx.x", "", "threadIdx.y", "threadIdx.x"}; + } else { + /** 
+ * block size is not enough to cover the preserved size + * make the load index invariant to inner loop + * (-1, 4, p_size / block_size, block_size) + */ + sch->Split(block_id, 1, {-1, block_size}); + sch->Fuse(block_id, {0, 1}); + sch->Split(block_id, 0, {-1, 4, preserved_size_ / block_size}); + return {"blockIdx.x", "", "blockIdx.y", "threadIdx.x"}; + } + } +} + +void TileBroadcastTactic::Apply(ir::IRSchedule* sch, + const std::string& block_id) { + if (ScheduleBlockEnableVectorize(context_->config, block_id)) { + ApplyVectorize(sch, block_id); + return; + } + + if (applied_layout_ == BroadcastLayout::Invalid) return; + int block_size = 256; + // Re-check the warp count; return early if preserved_size_ is unsupported. + if (applied_layout_ == BroadcastLayout::NHWCLayout) { + block_size = CalcNumWarps(preserved_size_ >> 5); + if (block_size == -1) { + applied_layout_ = BroadcastLayout::Invalid; + return; + } + block_size = std::clamp(block_size << 5, 128, 1024); + } + + // Cluster and fuse axis of the same type to get exactly 3 loops. + // [B, P, B, P, ..., B, B] => [B, P, B] + FuseAxisGroups(sch, block_id); + + // Do tiling. + std::vector axis_bind; + if (applied_layout_ == BroadcastLayout::NCHWLayout) { + axis_bind = TileNCHW(sch, block_id, block_size); + } else { + axis_bind = TileNHWC(sch, block_id, block_size); } // Do binding. @@ -426,13 +560,10 @@ void TileBroadcastTactic::FuseAxisGroups(ir::IRSchedule* sch, const std::string& block_id) { // Reorder high-dim axis to cluster axis of the same type. // [B, P, B, P, ..., B, B] => [B, B, ..., P, P, ..., B, B] - std::vector high_axis_perm = high_broadcast_axis_; - high_axis_perm.insert( - high_axis_perm.end(), preserved_axis_.begin(), preserved_axis_.end()); - sch->Reorder(block_id, high_axis_perm); - // Fuse continuous axis of the same type. 
- // [B, B, ..., P, P, ..., B, B] => [B, P, B] + // [B, B, ..., P, P, ..., B, B] => [B, P, B] (for NCHW layout) + // [B, B, P, B, ..., P] => [B, P] (for NHWC layout) + if (applied_layout_ == BroadcastLayout::Invalid) return; const auto FuseRange = [&](int start, int count) { if (count > 1) { std::vector loops_index(count); @@ -440,9 +571,21 @@ void TileBroadcastTactic::FuseAxisGroups(ir::IRSchedule* sch, sch->Fuse(block_id, loops_index); } }; - int high_axis_num = high_broadcast_axis_.size(); + std::vector axis_perm; + int high_axis_num = 0; + int low_axis_num = 0; int mid_axis_num = preserved_axis_.size(); - int low_axis_num = low_broadcast_axis_.size(); + if (applied_layout_ == BroadcastLayout::NCHWLayout) { + axis_perm = high_broadcast_axis_; + high_axis_num = high_broadcast_axis_.size(); + low_axis_num = low_broadcast_axis_.size(); + } else { + axis_perm = broadcast_axis_; + high_axis_num = broadcast_axis_.size(); + } + axis_perm.insert( + axis_perm.end(), preserved_axis_.begin(), preserved_axis_.end()); + sch->Reorder(block_id, axis_perm); FuseRange(high_axis_num + mid_axis_num, low_axis_num); FuseRange(high_axis_num, mid_axis_num); @@ -451,7 +594,7 @@ void TileBroadcastTactic::FuseAxisGroups(ir::IRSchedule* sch, void TileBroadcastTactic::ApplyVectorize(ir::IRSchedule* sch, const std::string& block_id) { - if (!can_apply_) return; + if (applied_layout_ == BroadcastLayout::Invalid) return; const auto vectorize_factor = static_cast(context_->config.tile_config.vectorize_factor);