File tree 2 files changed +8
-1
lines changed
paddle/phi/kernels/fusion/cutlass
2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -48,16 +48,18 @@ enum class CutlassTileConfig {
48
48
// Warp configs for M=16
49
49
CtaShape16x128x64_WarpShape16x32x64,
50
50
CtaShape16x256x64_WarpShape16x64x64,
51
-
51
+ CtaShape16x256x64_WarpShape64x16x128,
52
52
// Warp configs for M=32
53
53
CtaShape32x128x64_WarpShape32x32x64,
54
54
55
55
// Warp configs for M=64
56
+ CtaShape64x64x64_WarpShape32x32x64,
56
57
CtaShape64x128x64_WarpShape32x64x64,
57
58
CtaShape64x128x64_WarpShape64x32x64,
58
59
CtaShape64x128x64_WarpShape64x64x64,
59
60
60
61
// Warp configs for M=128
62
+ CtaShape128x64x64_WarpShape64x32x64,
61
63
CtaShape128x128x64_WarpShape64x32x64,
62
64
CtaShape128x128x64_WarpShape64x64x64,
63
65
CtaShape128x128x64_WarpShape128x32x64,
Original file line number Diff line number Diff line change @@ -60,9 +60,14 @@ static std::vector<CutlassTileConfig> get_candidate_tiles(
60
60
};
61
61
std::vector<CutlassTileConfig> quant_B_configs_sm80{
62
62
CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64,
63
+ CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64,
63
64
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
65
+ CutlassTileConfig::CtaShape64x64x64_WarpShape32x32x64,
66
+ CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
64
67
CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64,
68
+ CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
65
69
CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64,
70
+ CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
66
71
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
67
72
};
68
73
if (is_moe) {
You can’t perform that action at this time.
0 commit comments