Skip to content

Commit 6ea3df1

Browse files
committed
Fix the loop index in Split/Fuse not being less than the total number of loops.
1 parent c45a0b3 commit 6ea3df1

File tree

1 file changed

+8
-57
lines changed

1 file changed

+8
-57
lines changed

paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc

Lines changed: 8 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,6 @@ class TileBroadcastTactic final : public ScheduleTactic {
100100
const std::string& block_id,
101101
int block_size,
102102
int vectorize_factor);
103-
std::vector<std::string> TileVectorizeNHWC(ir::IRSchedule* sch,
104-
const std::string& block_id,
105-
int block_size,
106-
int vetorize_factor);
107103

108104
private:
109105
ScheduleContext* context_;
@@ -549,60 +545,15 @@ std::vector<std::string> TileBroadcastTactic::TileVectorizeNCHW(
549545
}
550546
}
551547

552-
std::vector<std::string> TileBroadcastTactic::TileVectorizeNHWC(
553-
ir::IRSchedule* sch,
554-
const std::string& block_id,
555-
int block_size,
556-
int vectorize_factor) {
557-
// NHWC layout will have 2 fused loops, so we start with (blockIdx.x,
558-
// threadIdx.x)
559-
VLOG(4) << "TileBroadcastTactic using NHWC layout, block size = "
560-
<< block_size << ", broadcast_size_ = " << broadcast_size_
561-
<< ", preserved_size_ = " << preserved_size_;
562-
int vectorize_p_size = preserved_size_ / vectorize_factor;
563-
if (broadcast_size_ <= 64) {
564-
/**
565-
* if the broadcast size is smaller than 64
566-
* this means we need more blocks to increase the occupancy
567-
* so no thread coarsening anyway
568-
*/
569-
sch->Split(block_id, 1, {-1, block_size, vectorize_factor});
570-
sch->Fuse(block_id, {0, 1});
571-
return {"blockIdx.x", "threadIdx.x", ""};
572-
} else {
573-
if (vectorize_p_size == block_size) {
574-
sch->Split(block_id, 1, {-1, vectorize_factor});
575-
return {"blockIdx.x", "threadIdx.x", ""};
576-
} else if (vectorize_p_size < block_size) {
577-
// block size is larger (deliberately, to have enough threads)
578-
// than preserved size
579-
sch->Split(block_id, 1, {-1, vectorize_factor});
580-
sch->Fuse(block_id, {0, 1});
581-
sch->Split(
582-
block_id, 0, {-1, block_size / vectorize_p_size, vectorize_p_size});
583-
return {"blockIdx.x", "threadIdx.y", "threadIdx.x", ""};
584-
} else {
585-
/**
586-
* block size is not enough to cover the preserved size
587-
* make the load index invariant to inner loop
588-
* (-1, v_p_size / block_size, block_size, vectorize_factor)
589-
*/
590-
block_size = 128;
591-
sch->Split(block_id, 1, {-1, block_size, vectorize_factor});
592-
sch->Fuse(block_id, {0, 1});
593-
sch->Split(block_id, 0, {-1, vectorize_p_size / block_size});
594-
return {"blockIdx.x", "blockIdx.y", "threadIdx.x", ""};
595-
}
596-
}
597-
}
598-
599548
void TileBroadcastTactic::Apply(ir::IRSchedule* sch,
600549
const std::string& block_id) {
601-
if (applied_layout_ == BroadcastLayout::NCHWLayout) {
602-
if (ScheduleBlockEnableVectorize(context_->config, block_id)) {
603-
ApplyVectorize(sch, block_id);
604-
return;
605-
}
550+
if (applied_layout_ == BroadcastLayout::NCHWLayout &&
551+
ScheduleBlockEnableVectorize(context_->config, block_id)) {
552+
// TODO(baoqiwen): Due to register overflow issues, NHWC currently has
553+
// performance problems. The current vectorization only supports NCHW, and
554+
// future support for NHWC is needed.
555+
ApplyVectorize(sch, block_id);
556+
return;
606557
}
607558

608559
if (applied_layout_ == BroadcastLayout::Invalid) return;
@@ -714,7 +665,7 @@ void TileBroadcastTactic::ApplyVectorize(ir::IRSchedule* sch,
714665
axis_bind = TileVectorizeNCHW(sch, block_id, block_size, vectorize_factor);
715666
} else {
716667
// [B, P] (for NHWC layout)
717-
axis_bind = TileVectorizeNHWC(sch, block_id, block_size, vectorize_factor);
668+
// TODO(baoqiwen): support TileVectorizeNHWC; until then axis_bind is not assigned in this branch.
718669
}
719670

720671
// set vectorize schedule primitives

0 commit comments

Comments
 (0)