Commit 8dd012b

ZSL98 authored and houqi committed
[doc] update tuning_guide
1 parent 00fe2a6 commit 8dd012b

1 file changed: +34 -0 lines changed
docs/tuning_guide.md

@@ -4,6 +4,7 @@ You may need to tune a kernel before it achieving the best performance. In this
---

### MoE layer 0 (AllGather + Scatter + GroupGEMM)
To enable tuning for the test demo, you only need to set the `--tune` flag.

```bash
@@ -37,3 +38,36 @@ static int config_ag_scatter_sm90 = []() {
```

The search space for tuning is defined in `src/generator`. For the MoE layer 0 kernel, it is defined in `src/generator/gen_moe_ag_scatter.cc`. For example, the search space for the GEMM tile size is defined as `cute::make_tuple(Shape<Auto, _256, Auto>{}, Shape<Auto, _128, Auto>{})` at #L88. To enlarge the search space, modify this code and recompile Flux.

### MoE layer 1 (GroupGEMM + Gather + Topk-reduce + ReduceScatter)

Tune the MoE layer 1 kernel as follows:
```bash
./launch.sh test/python/moe_gather_rs/test_moe_gather_rs.py --tune
```
The tuning run then prints profiling results such as:

```c++
====== Profiling Results =======
GemmMeta(dtype=GemmDTypeConfig(a=BF16,b=BF16,c=BF16,d=BF16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=GatherRS,gemm_layout=RCC,impl=GemmGroupedV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=GatherRSMeta(topk=4))
RuntimeConfig(m=8192,n=1024,k=1024,comm_spec=None)
 * TopK=1 (1.87 ms): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,1,1),kernel_schedule=Cooperative),comm_spec=GatherRSHParams(gather_rs_ctas=26,n_dim=8192),tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
 * TopK=2 (1.88 ms): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,1,1),kernel_schedule=Cooperative),comm_spec=GatherRSHParams(gather_rs_ctas=28,n_dim=8192),tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
 * TopK=3 (1.91 ms): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,1,1),kernel_schedule=Cooperative),comm_spec=GatherRSHParams(gather_rs_ctas=30,n_dim=8192),tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)

====== Generated Config Code =======
// clang-format off
#include "flux/op_registry.h"
namespace bytedance::flux {
using namespace cute;

static int config_gather_rs_sm90 = []() {
  auto &inst = TuningConfigRegistry::instance();
  inst.add(make_gemm_meta(make_gemm_dtype_config(_BF16{}(),_BF16{}(),_BF16{}(),_BF16{}(),_FP32{}(),_FP32{}()),_Sm90{}(),_GatherRS{}(),_RCC{}(),_GemmGroupedV3{}(),make_gemm_v3_meta(false,false),make_gather_rs_meta(4)),make_runtime_config(8192,1024,1024,None{}),make_gemm_hparams(make_gemm_v3_hparams(cute::make_tuple(1l,1l,1l),_Cooperative{}()),make_gather_rs_hparams(26,8192),cute::make_tuple(128l,256l,64l),_GemmDefault{}(),0,_RasterHeuristic{}()));
  return 0;
}();
}
// clang-format on
```
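
If you want to compare candidates programmatically, the `* TopK=...` lines can be parsed with a small script. The following is a minimal sketch, not part of Flux: the helper name `parse_profiling_results` is hypothetical, and the regular expression assumes the output format is exactly the one shown in the sample above.

```python
import re

# Pattern for one result line. It is derived only from the sample output
# above and may need adjusting if the real format differs.
RESULT_RE = re.compile(
    r"\* TopK=(?P<rank>\d+) \((?P<ms>[\d.]+) ms\):.*?"
    r"gather_rs_ctas=(?P<ctas>\d+),n_dim=(?P<n_dim>\d+)\).*?"
    r"tile_shape=\((?P<tile>[\d,]+)\)"
)


def parse_profiling_results(text):
    """Return one dict per '* TopK=...' line, in the order printed."""
    results = []
    for line in text.splitlines():
        match = RESULT_RE.search(line)
        if match is None:
            continue  # skip headers, GemmMeta, RuntimeConfig, generated code
        results.append({
            "rank": int(match["rank"]),
            "ms": float(match["ms"]),
            "gather_rs_ctas": int(match["ctas"]),
            "n_dim": int(match["n_dim"]),
            "tile_shape": tuple(int(v) for v in match["tile"].split(",")),
        })
    return results


# The three result lines copied from the sample output above.
SAMPLE = """\
====== Profiling Results =======
RuntimeConfig(m=8192,n=1024,k=1024,comm_spec=None)
 * TopK=1 (1.87 ms): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,1,1),kernel_schedule=Cooperative),comm_spec=GatherRSHParams(gather_rs_ctas=26,n_dim=8192),tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
 * TopK=2 (1.88 ms): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,1,1),kernel_schedule=Cooperative),comm_spec=GatherRSHParams(gather_rs_ctas=28,n_dim=8192),tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
 * TopK=3 (1.91 ms): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,1,1),kernel_schedule=Cooperative),comm_spec=GatherRSHParams(gather_rs_ctas=30,n_dim=8192),tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
"""

best = min(parse_profiling_results(SAMPLE), key=lambda r: r["ms"])
print(best["gather_rs_ctas"], best["n_dim"], best["tile_shape"])
# prints: 26 8192 (128, 256, 64)
```

In this sample, the fastest entry carries the same values that appear in the generated config code: `make_gather_rs_hparams(26,8192)` and the tile shape `cute::make_tuple(128l,256l,64l)`.
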

You can find the configuration space defined in `src/generator/gen_moe_gather_rs.cc`. You may notice that only three kernels were profiled in the case above. This is because only three kernels in the search space qualify for the configuration used in the test demo, as defined in #L90-92 of `src/generator/gen_moe_gather_rs.cc`. The first argument of `make_gather_rs_hparams` is the number of thread blocks specialized for communication, and the second is the size of the hidden dimension. Make sure that at least one set of hparams is registered there for the MoE layer 1 shape you want to run.
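
The requirement that at least one set of hparams matches your shape can be pictured with a toy model. This is only a sketch under stated assumptions, not Flux's actual selection logic: the `GatherRSHParams` dataclass and `require_candidates` helper are hypothetical, the rule "a candidate qualifies when its `n_dim` equals the hidden size" is an assumption, and the 4096 entry is invented for illustration. The three 8192 entries mirror the sample output above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GatherRSHParams:
    gather_rs_ctas: int  # thread blocks specialized for communication
    n_dim: int           # size of the hidden dimension


# Hypothetical search space. The 8192 entries mirror the sample output;
# the 4096 entry is made up to show a non-matching candidate.
SEARCH_SPACE = [
    GatherRSHParams(26, 8192),
    GatherRSHParams(28, 8192),
    GatherRSHParams(30, 8192),
    GatherRSHParams(32, 4096),
]


def require_candidates(candidates, hidden_size):
    """Return candidates whose n_dim equals hidden_size (assumed rule)."""
    matches = [c for c in candidates if c.n_dim == hidden_size]
    if not matches:
        # Mirrors the guide's warning: nothing registered for this shape.
        raise LookupError(f"no hparams registered for hidden size {hidden_size}")
    return matches


print(len(require_candidates(SEARCH_SPACE, 8192)))  # prints: 3
```

Under this model, asking for a hidden size that has no registered entry (for example 2048 here) raises an error, which is the situation the guide tells you to avoid by adding an entry to the generator.
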
