@@ -158,8 +158,6 @@ struct ThreadblockSwizzleStreamK {
     FastDivmod sk_iters_per_big_block;
     FastDivmod sk_iters_per_region;
     FastDivmod sk_blocks_per_region;
-    FastDivmod sm_occupancy;
-
   } div_mod;


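The `sm_occupancy` fast divider can be dropped from `div_mod` because the remapping hunk further down now divides by the literal constant `2`, which the compiler lowers to a shift on its own. For readers unfamiliar with the pattern, here is a minimal host-side sketch of the `FastDivmod` call style. It is not CUTLASS's implementation (the real class precomputes a magic multiplier and shift so device code avoids hardware integer division), and the divisor and dividend values are made up:

```cpp
#include <cstdio>

// Minimal stand-in for cutlass::FastDivmod, interface sketch only: the real
// class precomputes a magic multiplier and shift on the host so that device
// code can replace integer '/' and '%' (slow on GPUs) with a multiply and a
// shift. Here we simply store the divisor to illustrate the call style.
struct FastDivmodSketch {
  int divisor = 1;

  FastDivmodSketch() = default;
  explicit FastDivmodSketch(int d) : divisor(d) {}

  // Out-param call style matching CUTLASS usage, e.g.
  //   div_mod.sk_blocks_per_region(region_idx, block_in_region, block_idx);
  void operator()(int &quotient, int &remainder, int dividend) const {
    quotient = dividend / divisor;  // real impl: (dividend * magic) >> shift
    remainder = dividend - (quotient * divisor);
  }
};

int main() {
  FastDivmodSketch sk_blocks_per_region(5);  // divisor chosen arbitrarily
  int region_idx, block_in_region;
  sk_blocks_per_region(region_idx, block_in_region, 17);
  printf("region %d, block-in-region %d\n", region_idx, block_in_region);
  return 0;
}
```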
@@ -188,6 +186,7 @@ struct ThreadblockSwizzleStreamK {
       ", dp_blocks: " << dp_blocks <<
       ", sk_blocks_per_region: " << sk_blocks_per_region <<
       ", sk_regions: " << sk_regions <<
+      ", sk_waves: " << sk_waves <<
       ", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
       ", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
       ", dp_first_wave_tiles: " << dp_first_wave_tiles <<
@@ -200,6 +199,7 @@ struct ThreadblockSwizzleStreamK {
       ", sm_occupancy: " << sm_occupancy <<
       ", avail_sms: " << avail_sms <<
       ", cohort_raster: " << cohort_raster <<
+      ", num_blocks: " << get_num_blocks() <<
       "\n\n";
 #endif
   }
@@ -316,9 +316,10 @@ struct ThreadblockSwizzleStreamK {

     // We're at (or greater) than GPU occupancy

-    if (full_waves % sm_occupancy == sm_occupancy - 1)
+    if ((sm_occupancy > 1) && (full_waves % sm_occupancy == sm_occupancy - 1))
     {
-      // Form the SK wave from the partial wave to get us to full GPU occupancy
+      // If occupancy is more than one CTA per SM, form the SK wave from the partial
+      // wave to get us to full GPU occupancy
       int max_sk_occupancy = 1;

       dp_tiles = full_wave_tiles;
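To see why the new `(sm_occupancy > 1)` guard matters: with `sm_occupancy == 1`, `full_waves % 1` is always `0`, which equals `sm_occupancy - 1`, so the old test fired for every problem shape. A small sketch, with made-up loop bounds, that prints both predicates side by side:

```cpp
#include <cstdio>

int main() {
  // Hedged sketch with made-up values: compares the old and new predicates.
  int occupancies[] = {1, 2, 3};
  for (int sm_occupancy : occupancies) {
    for (int full_waves = 0; full_waves < 4; ++full_waves) {
      bool old_test = (full_waves % sm_occupancy == sm_occupancy - 1);
      bool new_test = (sm_occupancy > 1) && old_test;
      // With sm_occupancy == 1 the old test is always true (0 == 0);
      // the guard limits SK-wave formation to multi-occupancy kernels.
      printf("occupancy=%d full_waves=%d old=%d new=%d\n",
             sm_occupancy, full_waves, (int)old_test, (int)new_test);
    }
  }
  return 0;
}
```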
@@ -533,15 +534,13 @@ struct ThreadblockSwizzleStreamK {
         dp_first_wave_tiles += waveset_excess;
         dp_blocks -= (waveset_excess * avail_sms);
       }
-
     }

     // Setup fast-div/mod for device-side usage
     div_mod.tiled_shape_m = FastDivmod(tiled_shape.m());
     div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
     div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
     div_mod.iters_per_tile = FastDivmod(iters_per_tile);
-    div_mod.sm_occupancy = FastDivmod(sm_occupancy);
   }


@@ -602,21 +601,14 @@ struct ThreadblockSwizzleStreamK {
   /// Obtains number of threadblocks per GEMM
   int get_num_blocks() const
   {
-    // int reduction_waves = (reduction_blocks + avail_sms - 1) / avail_sms;
-    // return ((sk_waves + reduction_waves) * avail_sms) + dp_blocks;
-
-
     int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;

-    if (work_blocks < avail_sms)
+    if (work_blocks <= avail_sms * 2)
     {
       return work_blocks;
     }

-    int gpu_occupancy = sm_occupancy * avail_sms;
-    int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy;
-    return gpu_wavesets * gpu_occupancy;
-
+    return fast_max(work_blocks, avail_sms * 4);
   }


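The old rule rounded the grid up to a whole number of wavesets; the new rule passes small grids through unchanged and otherwise pads only up to four waves of `avail_sms`. A standalone comparison sketch (all values assumed; `std::max` stands in for CUTLASS's `fast_max` helper):

```cpp
#include <algorithm>
#include <cstdio>

// Old rule: round the grid up to a whole number of wavesets.
int old_num_blocks(int work_blocks, int avail_sms, int sm_occupancy) {
  if (work_blocks < avail_sms) return work_blocks;
  int gpu_occupancy = sm_occupancy * avail_sms;
  int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy;
  return gpu_wavesets * gpu_occupancy;
}

// New rule: pass small grids through, otherwise pad to at least four waves.
int new_num_blocks(int work_blocks, int avail_sms) {
  if (work_blocks <= avail_sms * 2) return work_blocks;
  return std::max(work_blocks, avail_sms * 4);
}

int main() {
  int avail_sms = 108, sm_occupancy = 2;  // assumed A100-like SM count
  int samples[] = {100, 300, 500, 900};
  for (int work_blocks : samples) {
    printf("work_blocks=%d old=%d new=%d\n", work_blocks,
           old_num_blocks(work_blocks, avail_sms, sm_occupancy),
           new_num_blocks(work_blocks, avail_sms));
  }
  return 0;
}
```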
@@ -686,18 +678,18 @@ struct ThreadblockSwizzleStreamK {
   CUTLASS_DEVICE
   int get_block_idx() const
   {
-    int block_idx = RematerializeBlockIdxX();
+    // Remap the block indices for the first two waves of thread blocks if
+    // we have multi-occupancy and the grid constitutes four or more waves

-    int gpu_occupancy = avail_sms * sm_occupancy;
+    int block_idx = RematerializeBlockIdxX();
     int num_blocks = device_num_blocks();
-    int dest_sm, dest_wave;
-
-    div_mod.sm_occupancy(dest_sm, dest_wave, block_idx);
-
+    int dest_sm = block_idx / 2;
+    int dest_wave = block_idx % 2;
     int remapped_block_idx = dest_sm + (dest_wave * avail_sms);

-    // remapping the first gpu_occupancy blocks
-    if ((block_idx < gpu_occupancy) && (num_blocks > gpu_occupancy))
+    if ((sm_occupancy > 1) &&
+        (num_blocks >= avail_sms * 4) &&
+        (block_idx < avail_sms * 2))
     {
       block_idx = remapped_block_idx;
     }
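The remap transposes the first `2 * avail_sms` block indices so that blocks `2k` and `2k+1` land in the two occupancy slots of SM `k`. A host-runnable sketch of that permutation, with toy sizes chosen purely for illustration:

```cpp
#include <cstdio>

// Mirrors the device-side logic above, runnable on the host.
int remap_block_idx(int block_idx, int avail_sms, int sm_occupancy,
                    int num_blocks) {
  int dest_sm = block_idx / 2;
  int dest_wave = block_idx % 2;
  int remapped_block_idx = dest_sm + (dest_wave * avail_sms);

  // Only the first two waves are permuted, and only when the kernel has
  // multi-occupancy and the grid spans at least four waves.
  if ((sm_occupancy > 1) &&
      (num_blocks >= avail_sms * 4) &&
      (block_idx < avail_sms * 2)) {
    return remapped_block_idx;
  }
  return block_idx;
}

int main() {
  int avail_sms = 4, sm_occupancy = 2, num_blocks = 16;  // toy sizes
  for (int b = 0; b < 2 * avail_sms + 2; ++b) {
    printf("block %d -> %d\n", b,
           remap_block_idx(b, avail_sms, sm_occupancy, num_blocks));
  }
  return 0;
}
```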