Skip to content

Commit 1e64f15

Browse files
authored
improve streamk load balance (PaddlePaddle#743)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
1 parent 78b30d3 commit 1e64f15

File tree

2 files changed

+33
-33
lines changed

2 files changed

+33
-33
lines changed

include/cutlass/gemm/device/gemm_universal_base.h

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ class GemmUniversalBase {
107107
/// Kernel SM occupancy (in thread blocks)
108108
thread_local static int sm_occupancy_;
109109

110+
/// Kernel dynamic shared memory allocation requirement
111+
thread_local static int smem_size_;
110112

111113
/// Initialize static thread-local members for the thread's current device,
112114
/// if necessary.
@@ -138,15 +140,15 @@ class GemmUniversalBase {
138140
}
139141

140142
// Update the kernel function's shared memory configuration for the current device
141-
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
142-
if (smem_size >= (48 << 10))
143-
{
144-
// Requires more than 48KB: configure for extended, dynamic shared memory
143+
smem_size_ = int(sizeof(typename GemmKernel::SharedStorage));
145144

145+
// If requires more than 48KB: configure for extended, dynamic shared memory
146+
if (smem_size_ >= (48 << 10))
147+
{
146148
cudart_result = cudaFuncSetAttribute(
147149
Kernel2<GemmKernel>,
148150
cudaFuncAttributeMaxDynamicSharedMemorySize,
149-
smem_size);
151+
smem_size_);
150152
if (cudart_result != cudaSuccess) {
151153
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
152154
return Status::kErrorInternal;
@@ -166,7 +168,7 @@ class GemmUniversalBase {
166168
&sm_occupancy_,
167169
Kernel2<GemmKernel>,
168170
GemmKernel::kThreadCount,
169-
int(sizeof(typename GemmKernel::SharedStorage)),
171+
smem_size_,
170172
cudaOccupancyDisableCachingOverride);
171173
if (cudart_result != cudaSuccess) {
172174
CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
@@ -179,7 +181,9 @@ class GemmUniversalBase {
179181
CUTLASS_TRACE_HOST(" "
180182
"device_ordinal: (" << device_ordinal_ << "), "
181183
"device_sms: (" << device_sms_ << "), "
182-
"sm_occupancy: (" << sm_occupancy_ << ")");
184+
"sm_occupancy: (" << sm_occupancy_ << ") "
185+
"smem_size: (" << smem_size_ << ") "
186+
"GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");
183187

184188
return Status::kSuccess;
185189
}
@@ -335,17 +339,16 @@ class GemmUniversalBase {
335339
CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
336340

337341
// Configure grid and block dimensions
338-
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
339342
dim3 block(GemmKernel::kThreadCount, 1, 1);
340343
dim3 grid = params_.get_grid_dims();
341344

342345
// Launch kernel
343346
CUTLASS_TRACE_HOST(" "
344347
"grid: (" << grid << "), "
345348
"block: (" << block << "), "
346-
"SMEM: (" << smem_size << ")");
349+
"SMEM: (" << smem_size_ << ")");
347350

348-
Kernel2<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
351+
Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
349352

350353
// Query for errors
351354
cudaError_t result = cudaGetLastError();
@@ -398,6 +401,11 @@ thread_local int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
398401
template <typename GemmKernel_>
399402
thread_local int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;
400403

404+
/// Kernel dynamic shared memory allocation requirement
405+
template <typename GemmKernel_>
406+
thread_local int GemmUniversalBase<GemmKernel_>::smem_size_ = -1;
407+
408+
401409

402410
/////////////////////////////////////////////////////////////////////////////////////////////////
403411

include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,6 @@ struct ThreadblockSwizzleStreamK {
158158
FastDivmod sk_iters_per_big_block;
159159
FastDivmod sk_iters_per_region;
160160
FastDivmod sk_blocks_per_region;
161-
FastDivmod sm_occupancy;
162-
163161
} div_mod;
164162

165163

@@ -188,6 +186,7 @@ struct ThreadblockSwizzleStreamK {
188186
", dp_blocks: " << dp_blocks <<
189187
", sk_blocks_per_region: " << sk_blocks_per_region <<
190188
", sk_regions: " << sk_regions <<
189+
", sk_waves: " << sk_waves <<
191190
", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
192191
", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
193192
", dp_first_wave_tiles: " << dp_first_wave_tiles <<
@@ -200,6 +199,7 @@ struct ThreadblockSwizzleStreamK {
200199
", sm_occupancy: " << sm_occupancy <<
201200
", avail_sms: " << avail_sms <<
202201
", cohort_raster: " << cohort_raster <<
202+
", num_blocks: " << get_num_blocks() <<
203203
"\n\n";
204204
#endif
205205
}
@@ -316,9 +316,10 @@ struct ThreadblockSwizzleStreamK {
316316

317317
// We're at (or greater) than GPU occupancy
318318

319-
if (full_waves % sm_occupancy == sm_occupancy - 1)
319+
if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1))
320320
{
321-
// Form the SK wave from the partial wave to get us to full GPU occupancy
321+
// If occupancy is more than one CTA per SM, form the SK wave from the partial
322+
// wave to get us to full GPU occupancy
322323
int max_sk_occupancy = 1;
323324

324325
dp_tiles = full_wave_tiles;
@@ -533,15 +534,13 @@ struct ThreadblockSwizzleStreamK {
533534
dp_first_wave_tiles += waveset_excess;
534535
dp_blocks -= (waveset_excess * avail_sms);
535536
}
536-
537537
}
538538

539539
// Setup fast-div/mod for device-side usage
540540
div_mod.tiled_shape_m = FastDivmod(tiled_shape.m());
541541
div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
542542
div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
543543
div_mod.iters_per_tile = FastDivmod(iters_per_tile);
544-
div_mod.sm_occupancy = FastDivmod(sm_occupancy);
545544
}
546545

547546

@@ -602,21 +601,14 @@ struct ThreadblockSwizzleStreamK {
602601
/// Obtains number of threadblocks per GEMM
603602
int get_num_blocks() const
604603
{
605-
// int reduction_waves = (reduction_blocks + avail_sms - 1) / avail_sms;
606-
// return ((sk_waves + reduction_waves) * avail_sms) + dp_blocks;
607-
608-
609604
int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
610605

611-
if (work_blocks < avail_sms)
606+
if (work_blocks <= avail_sms * 2)
612607
{
613608
return work_blocks;
614609
}
615610

616-
int gpu_occupancy = sm_occupancy * avail_sms;
617-
int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy;
618-
return gpu_wavesets * gpu_occupancy;
619-
611+
return fast_max(work_blocks, avail_sms * 4);
620612
}
621613

622614

@@ -686,18 +678,18 @@ struct ThreadblockSwizzleStreamK {
686678
CUTLASS_DEVICE
687679
int get_block_idx() const
688680
{
689-
int block_idx = RematerializeBlockIdxX();
681+
// Remap the block indices for the first two waves of thread blocks if
682+
// we have multi-occupancy and the grid constitutes four or more waves
690683

691-
int gpu_occupancy = avail_sms * sm_occupancy;
684+
int block_idx = RematerializeBlockIdxX();
692685
int num_blocks = device_num_blocks();
693-
int dest_sm, dest_wave;
694-
695-
div_mod.sm_occupancy(dest_sm, dest_wave, block_idx);
696-
686+
int dest_sm = block_idx / 2;
687+
int dest_wave = block_idx % 2;
697688
int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
698689

699-
// remapping the first gpu_occupancy blocks
700-
if ((block_idx < gpu_occupancy) && (num_blocks > gpu_occupancy))
690+
if ((sm_occupancy > 1) &&
691+
(num_blocks >= avail_sms * 4) &&
692+
(block_idx < avail_sms * 2))
701693
{
702694
block_idx = remapped_block_idx;
703695
}

0 commit comments

Comments
 (0)