From cfb5bc7dc514fc2bfb107f52a2d33196e203f685 Mon Sep 17 00:00:00 2001 From: chang-wenbin Date: Tue, 6 May 2025 12:42:16 +0800 Subject: [PATCH] update moe ffn tune config --- .../fusion/cutlass/cutlass_extensions/ft_gemm_configs.h | 4 +++- .../fusion/cutlass/cutlass_kernels/cutlass_heuristic.h | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h index 27393f6f53402c..8bac846ce25175 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h @@ -48,16 +48,18 @@ enum class CutlassTileConfig { // Warp configs for M=16 CtaShape16x128x64_WarpShape16x32x64, CtaShape16x256x64_WarpShape16x64x64, - + CtaShape16x256x64_WarpShape64x16x128, // Warp configs for M=32 CtaShape32x128x64_WarpShape32x32x64, // Warp configs for M=64 + CtaShape64x64x64_WarpShape32x32x64, CtaShape64x128x64_WarpShape32x64x64, CtaShape64x128x64_WarpShape64x32x64, CtaShape64x128x64_WarpShape64x64x64, // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, CtaShape128x128x64_WarpShape64x32x64, CtaShape128x128x64_WarpShape64x64x64, CtaShape128x128x64_WarpShape128x32x64, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h index 4d007891c54c0c..332367063cae79 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h @@ -60,9 +60,14 @@ static std::vector get_candidate_tiles( }; std::vector quant_B_configs_sm80{ CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x64x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, }; if (is_moe) {