Skip to content

Commit c45a0b3

Browse files
committed
Fix Split/Fuse being called with a loop index that is not less than the total number of loops.
1 parent 6d4cb7e commit c45a0b3

File tree

1 file changed

+110
-24
lines changed

1 file changed

+110
-24
lines changed

paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc

Lines changed: 110 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,14 @@ class TileBroadcastTactic final : public ScheduleTactic {
9696
std::vector<std::string> TileNHWC(ir::IRSchedule* sch,
9797
const std::string& block_id,
9898
int block_size);
99+
std::vector<std::string> TileVectorizeNCHW(ir::IRSchedule* sch,
100+
const std::string& block_id,
101+
int block_size,
102+
int vetorize_factor);
103+
std::vector<std::string> TileVectorizeNHWC(ir::IRSchedule* sch,
104+
const std::string& block_id,
105+
int block_size,
106+
int vetorize_factor);
99107

100108
private:
101109
ScheduleContext* context_;
@@ -507,11 +515,94 @@ std::vector<std::string> TileBroadcastTactic::TileNHWC(
507515
}
508516
}
509517

518+
/**
 * Tiles the loops of `block_id` for the vectorized NCHW ([B, P, B]) case and
 * reports the axis binding for each resulting loop.
 *
 * The tiling strategy is picked from `low_broadcast_size_`:
 *   - small  (B <= 256):         [B, P, B] => [blockY, blockX, (threadX, loop)]
 *   - medium (256 < B <= 2048):  [B, P, B] => [blockY, blockX, (threadX, loop)]
 *   - large  (B > 2048):         [B, P, B] => [blockX', blockY,
 *                                              (blockX, threadX, loop)]
 *
 * @param sch              Schedule the Split/Fuse/Reorder primitives are
 *                         applied to.
 * @param block_id         Name of the schedule block to tile.
 * @param block_size       Factor used when splitting in the medium and large
 *                         cases; unused in the small case.
 * @param vectorize_factor Innermost split factor in every case.
 * @return Four axis-binding names, one per loop left after tiling; the
 *         trailing empty string is the innermost loop (presumably the one
 *         vectorized by the caller — confirm against ApplyVectorize).
 */
std::vector<std::string> TileBroadcastTactic::TileVectorizeNCHW(
    ir::IRSchedule* sch,
    const std::string& block_id,
    int block_size,
    int vectorize_factor) {
  // Upper bounds (inclusive) of the "small" and "medium" size classes.
  constexpr int kSmallUpperBound = 256;
  constexpr int kMediumUpperBound = 2048;

  VLOG(4) << "TileBroadcastTactic using original NCHW layout, "
             "low_broadcast_size_ = "
          << low_broadcast_size_;

  const bool is_small = low_broadcast_size_ <= kSmallUpperBound;
  const bool is_medium = !is_small && low_broadcast_size_ <= kMediumUpperBound;

  if (is_small) {
    // Only peel the innermost chunk off the low broadcast axis.
    sch->Split(block_id, 2, {-1, vectorize_factor});
  } else if (is_medium) {
    // Split the low broadcast axis three ways, then merge its outermost
    // part into the preserved axis so four loops remain.
    sch->Split(block_id, 2, {-1, block_size, vectorize_factor});
    sch->Fuse(block_id, {1, 2});
  } else {
    // Swap the two outer loops, fuse the two trailing loops into one, and
    // split the fused loop three ways.
    sch->Reorder(block_id, {1, 0});
    sch->Fuse(block_id, {1, 2});
    sch->Split(block_id, 1, {-1, block_size, vectorize_factor});
  }

  // Every size class ends with four loops that share the same binding.
  return {"blockIdx.y", "blockIdx.x", "threadIdx.x", ""};
}
551+
552+
/**
 * Tiles the loops of `block_id` for the vectorized NHWC ([B, P]) case and
 * reports the axis binding for each resulting loop.
 *
 * NHWC layout has two fused loops on entry, so binding starts from
 * (blockIdx.x, threadIdx.x).
 *
 * @param sch              Schedule the Split/Fuse primitives are applied to.
 * @param block_id         Name of the schedule block to tile.
 * @param block_size       Requested thread-block size. It is ignored (128 is
 *                         used instead) when the vectorized preserved size
 *                         exceeds it.
 * @param vectorize_factor Innermost split factor.
 * @return Axis-binding names, one per loop left after tiling (three or four
 *         entries); the trailing empty string is the innermost loop.
 */
std::vector<std::string> TileBroadcastTactic::TileVectorizeNHWC(
    ir::IRSchedule* sch,
    const std::string& block_id,
    int block_size,
    int vectorize_factor) {
  VLOG(4) << "TileBroadcastTactic using NHWC layout, block size = "
          << block_size << ", broadcast_size_ = " << broadcast_size_
          << ", preserved_size_ = " << preserved_size_;

  // Preserved size measured in vectors rather than in elements.
  const int vectorized_preserved = preserved_size_ / vectorize_factor;

  // A broadcast size of at most 64 means more blocks are needed to increase
  // the occupancy, so there is no thread coarsening in this case.
  if (broadcast_size_ <= 64) {
    sch->Split(block_id, 1, {-1, block_size, vectorize_factor});
    sch->Fuse(block_id, {0, 1});
    return {"blockIdx.x", "threadIdx.x", ""};
  }

  // One block covers the vectorized preserved axis exactly.
  if (vectorized_preserved == block_size) {
    sch->Split(block_id, 1, {-1, vectorize_factor});
    return {"blockIdx.x", "threadIdx.x", ""};
  }

  // The block size is (deliberately, to have enough threads) larger than the
  // preserved size, so the threads are arranged as (threadIdx.y, threadIdx.x).
  if (vectorized_preserved < block_size) {
    sch->Split(block_id, 1, {-1, vectorize_factor});
    sch->Fuse(block_id, {0, 1});
    // NOTE(review): divides by vectorized_preserved — assumes
    // preserved_size_ >= vectorize_factor; confirm with the caller.
    const int thread_rows = block_size / vectorized_preserved;
    sch->Split(block_id, 0, {-1, thread_rows, vectorized_preserved});
    return {"blockIdx.x", "threadIdx.y", "threadIdx.x", ""};
  }

  // The block size is not enough to cover the preserved size. Make the load
  // index invariant to the inner loop by tiling as
  // (-1, v_p_size / block_size, block_size, vectorize_factor) with a fixed
  // block size of 128.
  constexpr int kFixedBlockSize = 128;
  sch->Split(block_id, 1, {-1, kFixedBlockSize, vectorize_factor});
  sch->Fuse(block_id, {0, 1});
  const int blocks_per_row = vectorized_preserved / kFixedBlockSize;
  sch->Split(block_id, 0, {-1, blocks_per_row});
  return {"blockIdx.x", "blockIdx.y", "threadIdx.x", ""};
}
598+
510599
void TileBroadcastTactic::Apply(ir::IRSchedule* sch,
511600
const std::string& block_id) {
512-
if (ScheduleBlockEnableVectorize(context_->config, block_id)) {
513-
ApplyVectorize(sch, block_id);
514-
return;
601+
if (applied_layout_ == BroadcastLayout::NCHWLayout) {
602+
if (ScheduleBlockEnableVectorize(context_->config, block_id)) {
603+
ApplyVectorize(sch, block_id);
604+
return;
605+
}
515606
}
516607

517608
if (applied_layout_ == BroadcastLayout::Invalid) return;
@@ -598,6 +689,17 @@ void TileBroadcastTactic::ApplyVectorize(ir::IRSchedule* sch,
598689
const auto vectorize_factor =
599690
static_cast<int>(context_->config.tile_config.vectorize_factor);
600691

692+
int block_size = 256;
693+
// check the number of warps here, if not an applicable
694+
// preserved_size, func will return later
695+
if (applied_layout_ == BroadcastLayout::NHWCLayout) {
696+
block_size = CalcNumWarps(preserved_size_ >> 5);
697+
if (block_size == -1) {
698+
applied_layout_ = BroadcastLayout::Invalid;
699+
}
700+
block_size = std::clamp(block_size << 5, 128, 1024);
701+
}
702+
601703
FuseAxisGroups(sch, block_id);
602704

603705
const auto ApplyVectorization = [&](const std::string& block_id, int factor) {
@@ -606,29 +708,13 @@ void TileBroadcastTactic::ApplyVectorize(ir::IRSchedule* sch,
606708
sch->Vectorize(loops[vectorize_axis], factor);
607709
};
608710

609-
// Do tiling with vectorize.
610-
// 1. For small size:
611-
// [B, P, B<=256]
612-
// => [(blockY, loop), blockX, threadX, vectorize].
613-
// 2. For medium size:
614-
// [B, P, 256<B<=2048],
615-
// => [blockY, blockX, (loop, threadX, vectorize)].
616-
// 3. For large size:
617-
// [B, P, B>2048]
618-
// => [blockX', blockY, (blockX, loop, threadX, vectorize)].
619711
std::vector<std::string> axis_bind;
620-
if (low_broadcast_size_ <= 256) {
621-
sch->Split(block_id, 0, {-1, 4});
622-
sch->Split(block_id, 3, {-1, vectorize_factor});
623-
axis_bind = {"blockIdx.y", "", "blockIdx.x", "threadIdx.x", ""};
624-
} else if (low_broadcast_size_ <= 2048) {
625-
sch->Split(block_id, 2, {-1, 256, vectorize_factor});
626-
axis_bind = {"blockIdx.y", "blockIdx.x", "", "threadIdx.x", ""};
712+
if (applied_layout_ == BroadcastLayout::NCHWLayout) {
713+
// [B, P, B] (for NCHW layout)
714+
axis_bind = TileVectorizeNCHW(sch, block_id, block_size, vectorize_factor);
627715
} else {
628-
sch->Reorder(block_id, {1, 0});
629-
sch->Fuse(block_id, {1, 2});
630-
sch->Split(block_id, 1, {-1, 256, vectorize_factor});
631-
axis_bind = {"blockIdx.y", "blockIdx.x", "threadIdx.x", ""};
716+
// [B, P] (for NHWC layout)
717+
axis_bind = TileVectorizeNHWC(sch, block_id, block_size, vectorize_factor);
632718
}
633719

634720
// set vectorize schedule primitives

0 commit comments

Comments
 (0)