@@ -96,6 +96,14 @@ class TileBroadcastTactic final : public ScheduleTactic {
96
96
std::vector<std::string> TileNHWC (ir::IRSchedule* sch,
97
97
const std::string& block_id,
98
98
int block_size);
99
+ std::vector<std::string> TileVectorizeNCHW (ir::IRSchedule* sch,
100
+ const std::string& block_id,
101
+ int block_size,
102
+ int vetorize_factor);
103
+ std::vector<std::string> TileVectorizeNHWC (ir::IRSchedule* sch,
104
+ const std::string& block_id,
105
+ int block_size,
106
+ int vetorize_factor);
99
107
100
108
private:
101
109
ScheduleContext* context_;
@@ -507,11 +515,94 @@ std::vector<std::string> TileBroadcastTactic::TileNHWC(
507
515
}
508
516
}
509
517
518
+ std::vector<std::string> TileBroadcastTactic::TileVectorizeNCHW (
519
+ ir::IRSchedule* sch,
520
+ const std::string& block_id,
521
+ int block_size,
522
+ int vectorize_factor) {
523
+ /* *
524
+ * 1. For small size:
525
+ * [B, P, B<=256]
526
+ * => [blockY, blockX, (threadX, loop)].
527
+ * 2. For medium size:
528
+ * [B, P, 256<B<=2048],
529
+ * => [blockY, blockX, (threadX, loop)].
530
+ * 3. For large size:
531
+ * [B, P, B>2048]
532
+ * => [blockX', blockY, (blockX, threadX, loop)].
533
+ */
534
+ VLOG (4 ) << " TileBroadcastTactic using original NCHW layout, "
535
+ " low_broadcast_size_ = "
536
+ << low_broadcast_size_;
537
+ if (low_broadcast_size_ <= 256 ) {
538
+ sch->Split (block_id, 2 , {-1 , vectorize_factor});
539
+ return {" blockIdx.y" , " blockIdx.x" , " threadIdx.x" , " " };
540
+ } else if (low_broadcast_size_ <= 2048 ) {
541
+ sch->Split (block_id, 2 , {-1 , block_size, vectorize_factor});
542
+ sch->Fuse (block_id, {1 , 2 });
543
+ return {" blockIdx.y" , " blockIdx.x" , " threadIdx.x" , " " };
544
+ } else {
545
+ sch->Reorder (block_id, {1 , 0 });
546
+ sch->Fuse (block_id, {1 , 2 });
547
+ sch->Split (block_id, 1 , {-1 , block_size, vectorize_factor});
548
+ return {" blockIdx.y" , " blockIdx.x" , " threadIdx.x" , " " };
549
+ }
550
+ }
551
+
552
+ std::vector<std::string> TileBroadcastTactic::TileVectorizeNHWC (
553
+ ir::IRSchedule* sch,
554
+ const std::string& block_id,
555
+ int block_size,
556
+ int vectorize_factor) {
557
+ // NHWC layout will have 2 fused loops, so we start with (blockIdx.x,
558
+ // threadIdx.x)
559
+ VLOG (4 ) << " TileBroadcastTactic using NHWC layout, block size = "
560
+ << block_size << " , broadcast_size_ = " << broadcast_size_
561
+ << " , preserved_size_ = " << preserved_size_;
562
+ int vectorize_p_size = preserved_size_ / vectorize_factor;
563
+ if (broadcast_size_ <= 64 ) {
564
+ /* *
565
+ * if the broadcast size is smaller than 64
566
+ * this means we need more blocks to increase the occupancy
567
+ * so no thread coarsening anyway
568
+ */
569
+ sch->Split (block_id, 1 , {-1 , block_size, vectorize_factor});
570
+ sch->Fuse (block_id, {0 , 1 });
571
+ return {" blockIdx.x" , " threadIdx.x" , " " };
572
+ } else {
573
+ if (vectorize_p_size == block_size) {
574
+ sch->Split (block_id, 1 , {-1 , vectorize_factor});
575
+ return {" blockIdx.x" , " threadIdx.x" , " " };
576
+ } else if (vectorize_p_size < block_size) {
577
+ // block size is larger (deliberately, to have enough threads)
578
+ // than preserved size
579
+ sch->Split (block_id, 1 , {-1 , vectorize_factor});
580
+ sch->Fuse (block_id, {0 , 1 });
581
+ sch->Split (
582
+ block_id, 0 , {-1 , block_size / vectorize_p_size, vectorize_p_size});
583
+ return {" blockIdx.x" , " threadIdx.y" , " threadIdx.x" , " " };
584
+ } else {
585
+ /* *
586
+ * block size is not enough to cover the preserved size
587
+ * make the load index invariant to inner loop
588
+ * (-1, v_p_size / block_size, block_size, vectorize_factor)
589
+ */
590
+ block_size = 128 ;
591
+ sch->Split (block_id, 1 , {-1 , block_size, vectorize_factor});
592
+ sch->Fuse (block_id, {0 , 1 });
593
+ sch->Split (block_id, 0 , {-1 , vectorize_p_size / block_size});
594
+ return {" blockIdx.x" , " blockIdx.y" , " threadIdx.x" , " " };
595
+ }
596
+ }
597
+ }
598
+
510
599
void TileBroadcastTactic::Apply (ir::IRSchedule* sch,
511
600
const std::string& block_id) {
512
- if (ScheduleBlockEnableVectorize (context_->config , block_id)) {
513
- ApplyVectorize (sch, block_id);
514
- return ;
601
+ if (applied_layout_ == BroadcastLayout::NCHWLayout) {
602
+ if (ScheduleBlockEnableVectorize (context_->config , block_id)) {
603
+ ApplyVectorize (sch, block_id);
604
+ return ;
605
+ }
515
606
}
516
607
517
608
if (applied_layout_ == BroadcastLayout::Invalid) return ;
@@ -598,6 +689,17 @@ void TileBroadcastTactic::ApplyVectorize(ir::IRSchedule* sch,
598
689
const auto vectorize_factor =
599
690
static_cast <int >(context_->config .tile_config .vectorize_factor );
600
691
692
+ int block_size = 256 ;
693
+ // check the number of warps here, if not a applicable
694
+ // preserved_size, func will return later
695
+ if (applied_layout_ == BroadcastLayout::NHWCLayout) {
696
+ block_size = CalcNumWarps (preserved_size_ >> 5 );
697
+ if (block_size == -1 ) {
698
+ applied_layout_ = BroadcastLayout::Invalid;
699
+ }
700
+ block_size = std::clamp (block_size << 5 , 128 , 1024 );
701
+ }
702
+
601
703
FuseAxisGroups (sch, block_id);
602
704
603
705
const auto ApplyVectorization = [&](const std::string& block_id, int factor) {
@@ -606,29 +708,13 @@ void TileBroadcastTactic::ApplyVectorize(ir::IRSchedule* sch,
606
708
sch->Vectorize (loops[vectorize_axis], factor);
607
709
};
608
710
609
- // Do tiling with vectorize.
610
- // 1. For small size:
611
- // [B, P, B<=256]
612
- // => [(blockY, loop), blockX, threadX, vectorize].
613
- // 2. For medium size:
614
- // [B, P, 256<B<=2048],
615
- // => [blockY, blockX, (loop, threadX, vectorize)].
616
- // 3. For large size:
617
- // [B, P, B>2048]
618
- // => [blockX', blockY, (blockX, loop, threadX, vectorize)].
619
711
std::vector<std::string> axis_bind;
620
- if (low_broadcast_size_ <= 256 ) {
621
- sch->Split (block_id, 0 , {-1 , 4 });
622
- sch->Split (block_id, 3 , {-1 , vectorize_factor});
623
- axis_bind = {" blockIdx.y" , " " , " blockIdx.x" , " threadIdx.x" , " " };
624
- } else if (low_broadcast_size_ <= 2048 ) {
625
- sch->Split (block_id, 2 , {-1 , 256 , vectorize_factor});
626
- axis_bind = {" blockIdx.y" , " blockIdx.x" , " " , " threadIdx.x" , " " };
712
+ if (applied_layout_ == BroadcastLayout::NCHWLayout) {
713
+ // [B, P, B] (for NCHW layout)
714
+ axis_bind = TileVectorizeNCHW (sch, block_id, block_size, vectorize_factor);
627
715
} else {
628
- sch->Reorder (block_id, {1 , 0 });
629
- sch->Fuse (block_id, {1 , 2 });
630
- sch->Split (block_id, 1 , {-1 , 256 , vectorize_factor});
631
- axis_bind = {" blockIdx.y" , " blockIdx.x" , " threadIdx.x" , " " };
716
+ // [B, P] (for NHWC layout)
717
+ axis_bind = TileVectorizeNHWC (sch, block_id, block_size, vectorize_factor);
632
718
}
633
719
634
720
// set vectorize schedule primitives
0 commit comments