diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h index 27393f6f53402c..8bac846ce25175 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h @@ -48,16 +48,18 @@ enum class CutlassTileConfig { // Warp configs for M=16 CtaShape16x128x64_WarpShape16x32x64, CtaShape16x256x64_WarpShape16x64x64, - + CtaShape16x256x64_WarpShape64x16x128, // Warp configs for M=32 CtaShape32x128x64_WarpShape32x32x64, // Warp configs for M=64 + CtaShape64x64x64_WarpShape32x32x64, CtaShape64x128x64_WarpShape32x64x64, CtaShape64x128x64_WarpShape64x32x64, CtaShape64x128x64_WarpShape64x64x64, // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, CtaShape128x128x64_WarpShape64x32x64, CtaShape128x128x64_WarpShape64x64x64, CtaShape128x128x64_WarpShape128x32x64, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h index 4d007891c54c0c..332367063cae79 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h @@ -60,9 +60,14 @@ static std::vector get_candidate_tiles( }; std::vector quant_B_configs_sm80{ CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x64x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, }; if (is_moe) {