@@ -100,10 +100,6 @@ class TileBroadcastTactic final : public ScheduleTactic {
100
100
const std::string& block_id,
101
101
int block_size,
102
102
int vetorize_factor);
103
- std::vector<std::string> TileVectorizeNHWC (ir::IRSchedule* sch,
104
- const std::string& block_id,
105
- int block_size,
106
- int vetorize_factor);
107
103
108
104
private:
109
105
ScheduleContext* context_;
@@ -549,60 +545,15 @@ std::vector<std::string> TileBroadcastTactic::TileVectorizeNCHW(
549
545
}
550
546
}
551
547
552
- std::vector<std::string> TileBroadcastTactic::TileVectorizeNHWC (
553
- ir::IRSchedule* sch,
554
- const std::string& block_id,
555
- int block_size,
556
- int vectorize_factor) {
557
- // NHWC layout will have 2 fused loops, so we start with (blockIdx.x,
558
- // threadIdx.x)
559
- VLOG (4 ) << " TileBroadcastTactic using NHWC layout, block size = "
560
- << block_size << " , broadcast_size_ = " << broadcast_size_
561
- << " , preserved_size_ = " << preserved_size_;
562
- int vectorize_p_size = preserved_size_ / vectorize_factor;
563
- if (broadcast_size_ <= 64 ) {
564
- /* *
565
- * if the broadcast size is smaller than 64
566
- * this means we need more blocks to increase the occupancy
567
- * so no thread coarsening anyway
568
- */
569
- sch->Split (block_id, 1 , {-1 , block_size, vectorize_factor});
570
- sch->Fuse (block_id, {0 , 1 });
571
- return {" blockIdx.x" , " threadIdx.x" , " " };
572
- } else {
573
- if (vectorize_p_size == block_size) {
574
- sch->Split (block_id, 1 , {-1 , vectorize_factor});
575
- return {" blockIdx.x" , " threadIdx.x" , " " };
576
- } else if (vectorize_p_size < block_size) {
577
- // block size is larger (deliberately, to have enough threads)
578
- // than preserved size
579
- sch->Split (block_id, 1 , {-1 , vectorize_factor});
580
- sch->Fuse (block_id, {0 , 1 });
581
- sch->Split (
582
- block_id, 0 , {-1 , block_size / vectorize_p_size, vectorize_p_size});
583
- return {" blockIdx.x" , " threadIdx.y" , " threadIdx.x" , " " };
584
- } else {
585
- /* *
586
- * block size is not enough to cover the preserved size
587
- * make the load index invariant to inner loop
588
- * (-1, v_p_size / block_size, block_size, vectorize_factor)
589
- */
590
- block_size = 128 ;
591
- sch->Split (block_id, 1 , {-1 , block_size, vectorize_factor});
592
- sch->Fuse (block_id, {0 , 1 });
593
- sch->Split (block_id, 0 , {-1 , vectorize_p_size / block_size});
594
- return {" blockIdx.x" , " blockIdx.y" , " threadIdx.x" , " " };
595
- }
596
- }
597
- }
598
-
599
548
void TileBroadcastTactic::Apply (ir::IRSchedule* sch,
600
549
const std::string& block_id) {
601
- if (applied_layout_ == BroadcastLayout::NCHWLayout) {
602
- if (ScheduleBlockEnableVectorize (context_->config , block_id)) {
603
- ApplyVectorize (sch, block_id);
604
- return ;
605
- }
550
+ if (applied_layout_ == BroadcastLayout::NCHWLayout &&
551
+ ScheduleBlockEnableVectorize (context_->config , block_id) {
552
+ // TODO(baoqiwen): Due to register overflow issues, NHWC currently has
553
+ // performance problems. The current vectorization only supports NCHW, and
554
+ // future support for NHWC is needed.
555
+ ApplyVectorize (sch, block_id);
556
+ return ;
606
557
}
607
558
608
559
if (applied_layout_ == BroadcastLayout::Invalid) return ;
@@ -714,7 +665,7 @@ void TileBroadcastTactic::ApplyVectorize(ir::IRSchedule* sch,
714
665
axis_bind = TileVectorizeNCHW (sch, block_id, block_size, vectorize_factor);
715
666
} else {
716
667
// [B, P] (for NHWC layout)
717
- axis_bind = TileVectorizeNHWC (sch, block_id, block_size, vectorize_factor);
668
+ // TODO(baoqiwen): support TileVectorizeNHWC
718
669
}
719
670
720
671
// set vectorize schedule primitives
0 commit comments