Skip to content

Commit 85af92c

Browse files
Tune the AG performance for the llama-8b (#21)
1 parent 322710d commit 85af92c

File tree

2 files changed

+10 -1 lines changed

2 files changed

+10 -1 lines changed

src/all_gather/gemm_v2_ag_kernel.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -182,8 +182,8 @@ struct GemmV2AGKernel_Space : OpSpaceBase<GemmV2AGKernel_Space> {
182182
make_gemm_v2_hparams(Shape<_64, _64, _32>{}, Shape<_16, _8, _16>{}, _StreamkDP{})),
183183
cute::make_tuple(Auto{}),
184184
cute::make_tuple(
185-
Shape<_128, _128, _64>{},
186185
Shape<_128, _128, _32>{},
186+
Shape<_128, _128, _64>{},
187187
Shape<_64, _128, _32>{},
188188
Shape<_64, _128, _64>{},
189189
Shape<_64, _256, _32>{},

src/all_gather/tuning_config/config_ag_gemm_kernel_sm80_tp4_nnodes1.cu

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,15 @@ static int config_ag_gemm_kernel_sm80_tp4_nnodes1 = []() {
8888
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_FP16{}(),_FP16{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}()),make_runtime_config(8192,12288,12288,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(256l,128l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
8989
/// NVLink
9090
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}()),make_runtime_config(8192,12288,12288,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(128l,256l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
91+
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,1792,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(64l,128l,64l),_GemmStreamK{}(),3,_RasterAlongM{}()));
92+
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_Void{}(),_FP16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,1792,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(64l,128l,64l),_GemmStreamK{}(),3,_RasterAlongM{}()));
93+
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,1792,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(64l,128l,64l),_GemmStreamK{}(),3,_RasterAlongM{}()));
94+
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_Void{}(),_FP16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,1792,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(64l,128l,64l),_GemmStreamK{}(),4,_RasterAlongM{}()));
95+
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,7168,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(128l,128l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
96+
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_Void{}(),_FP16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RRR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,7168,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(128l,128l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
97+
inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_Void{}(),_BF16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,7168,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkDP{}()),None{},cute::make_tuple(128l,128l,32l),_GemmStreamK{}(),3,_RasterAlongM{}()));
98+
inst.add(make_gemm_meta(make_gemm_dtype_config(_FP16{}(),_FP16{}(),_Void{}(),_FP16{}(),_FP32{}()),_Sm80{}(),_AGKernel{}(),_RCR{}(),_GemmV2{}(),make_gemm_v2_meta(false),None{}),make_runtime_config(2048,7168,4096,make_all_gather_runtime_config(4,1,0)),make_gemm_hparams(make_gemm_v2_hparams(cute::make_tuple(64l,64l,32l),cute::make_tuple(16l,8l,16l),_StreamkSK{}()),None{},cute::make_tuple(256l,128l,32l),_GemmStreamK{}(),3,_RasterAlongN{}()));
99+
91100
return 0;
92101
}();
93102
}

0 commit comments

Comments
 (0)