
Commit 8cd1800

feat: Add weight layout option for trtllm-gen fused moe (#1297)
## πŸ“Œ Description

Expose weight layout for BlockMajorK usage.

## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
1 parent b152d41 commit 8cd1800

9 files changed, +9108 -1195 lines changed


csrc/trtllm_batched_gemm_runner.cu

Lines changed: 2 additions & 1 deletion
@@ -98,7 +98,8 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
         (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct &&
         options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch &&
         tileSize == mOptions.tileSize &&
-        options.mUseShuffledMatrixA == mOptions.useShuffledMatrixA) {
+        options.mUseShuffledMatrixA == mOptions.useShuffledMatrixA &&
+        options.mLayoutA == mOptions.weightLayout) {
       if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) {
         mPassingConfigIndices.push_back(i);
       }
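For illustration, here is a minimal Python sketch of the selection rule this hunk extends: a candidate kernel config is now accepted only if its A-matrix layout also matches the requested weight layout. The `CandidateConfig` fields below are hypothetical stand-ins for the C++ config metadata, not FlashInfer's actual API; only the layout values (MajorK = 0, MajorMn = 1, BlockMajorK = 2) come from this commit.

```python
from dataclasses import dataclass
from enum import IntEnum


class MatrixLayout(IntEnum):
    MajorK = 0       # weights stored as [Mn, K]
    MajorMn = 1      # weights stored as [K, Mn]
    BlockMajorK = 2  # weights stored as [K / blockK, Mn, blockK]


@dataclass
class CandidateConfig:
    # Hypothetical stand-in for one trtllm-gen batched-GEMM kernel config.
    use_shuffled_matrix_a: bool
    layout_a: MatrixLayout
    tile_size: int


def passes_filter(cfg: CandidateConfig, *, use_shuffled_a: bool,
                  weight_layout: MatrixLayout, tile_size: int) -> bool:
    # Mirrors the extended predicate: the shuffled-A flag, the tile size,
    # and (new in this commit) the weight layout must all agree before the
    # config index is added to the passing list.
    return (cfg.use_shuffled_matrix_a == use_shuffled_a
            and cfg.tile_size == tile_size
            and cfg.layout_a == weight_layout)
```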

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 46 additions & 13 deletions
@@ -482,12 +482,26 @@ at::Tensor trtllm_fp8_block_scale_moe_launcher(
               "hidden_states_scale dim1 must match num_tokens.");
   TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::Float8_e4m3fn,
               "gemm1_weights must be fp8.");
-  TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D.");
-  TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even.");
-  TORCH_CHECK(intermediate_size == gemm1_weights.sizes()[1] / 2,
-              "intermediate_size has incorrect shape.");
-  TORCH_CHECK(gemm1_weights.sizes()[2] == hidden_states.sizes()[1],
-              "the third dimension of weights must be equal to hidden_size.");
+
+  TORCH_CHECK(gemm1_weights.dim() == 3 || gemm1_weights.dim() == 4,
+              "gemm1_weights must be 3D or 4D.");
+  {
+    int64_t Mn = 0, K = 0;
+    if (gemm1_weights.dim() == 3) {
+      // MajorK [num_experts, M, K]
+      Mn = gemm1_weights.sizes()[1];
+      K = gemm1_weights.sizes()[2];
+    } else if (gemm1_weights.dim() == 4) {
+      // BlockMajorK [num_experts, K/block_k, M, block_k]
+      Mn = gemm1_weights.sizes()[2];
+      int64_t block_k = gemm1_weights.sizes()[3];
+      K = gemm1_weights.sizes()[1] * block_k;
+    }
+    TORCH_CHECK(Mn % 2 == 0, "the second dimension of weights must be even.");
+    TORCH_CHECK(intermediate_size == Mn / 2, "intermediate_size has incorrect shape.");
+    TORCH_CHECK(K == hidden_states.sizes()[1],
+                "the third dimension of weights must be equal to hidden_size.");
+  }
   TORCH_CHECK(gemm1_weights_scale.scalar_type() == at::ScalarType::Float,
               "gemm1_weights_scale must be float.");
   TORCH_CHECK(gemm1_weights_scale.dim() == 3, "gemm1_weights_scale must be 3D.");
@@ -502,9 +516,22 @@ at::Tensor trtllm_fp8_block_scale_moe_launcher(
               "gemm1_weights_scale has incorrect shape.");
   TORCH_CHECK(gemm2_weights.scalar_type() == at::ScalarType::Float8_e4m3fn,
               "gemm2_weights must be fp8.");
-  TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
-  TORCH_CHECK(gemm2_weights.sizes()[2] == intermediate_size,
-              "the third dimension of weights must be equal to intermediate_size.");
+
+  TORCH_CHECK(gemm2_weights.dim() == 3 || gemm2_weights.dim() == 4,
+              "gemm2_weights must be 3D or 4D.");
+  {
+    int64_t K = 0;
+    if (gemm2_weights.dim() == 3) {
+      // MajorK [num_experts, M, K]
+      K = gemm2_weights.sizes()[2];
+    } else if (gemm2_weights.dim() == 4) {
+      // BlockMajorK [num_experts, K/block_k, M, block_k]
+      int64_t block_k = gemm2_weights.sizes()[3];
+      K = gemm2_weights.sizes()[1] * block_k;
+    }
+    TORCH_CHECK(K == intermediate_size,
+                "the third dimension of weights must be equal to intermediate_size.");
+  }
   TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float,
               "gemm2_weights_scale must be float.");
   TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
@@ -568,7 +595,8 @@ at::Tensor trtllm_fp8_block_scale_moe(
     at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale, int64_t num_experts,
     int64_t top_k, int64_t n_group, int64_t topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor,
-    int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight) {
+    int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
+    int64_t weight_layout) {
   auto dtype = hidden_states.dtype();
   if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16 ||
       dtype == at::ScalarType::Float8_e4m3fn) {
@@ -578,9 +606,13 @@ at::Tensor trtllm_fp8_block_scale_moe(
         batchedGemm::trtllm::gen::Dtype::E4m3};  // FP8 runner so hard-coded
     bool mUseDeepSeekFp8{true};  // Always true for BlockScaleMoe

+    TORCH_CHECK(0 <= weight_layout && weight_layout <= 2,
+                "the value of weight_layout is not recognized");
+
     // Properly initialize the runner using make_unique like in the original code
-    auto mRunner = std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, tile_tokens_dim,
-                                                use_shuffled_weight);
+    auto mRunner = std::make_unique<RunnerType>(
+        mDtypeElt, mUseDeepSeekFp8, tile_tokens_dim, use_shuffled_weight,
+        static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));

     // Always use fallback config (equivalent to moeConfigIndex == -1 case from original code)
     auto const num_tokens = hidden_states.sizes()[0];
@@ -929,7 +961,8 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(

     // Properly initialize the runner using make_unique like in the original code
     auto mRunner = std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, tile_tokens_dim,
-                                                /*useShuffledMatrixA*/ true);
+                                                /*useShuffledMatrixA*/ true,
+                                                batchedGemm::gemm::MatrixLayout::MajorK);

     auto const num_tokens = hidden_states.sizes()[0];
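To make the accepted weight shapes concrete, here is a short Python sketch of the same shape bookkeeping the launcher now performs; the sizes are illustrative, and only the 3D/4D arithmetic mirrors the TORCH_CHECKs above.

```python
import torch

num_experts, intermediate_size, hidden_size, block_k = 8, 2048, 4096, 128

# MajorK: gemm1 weights are 3D, [num_experts, 2 * intermediate_size, hidden_size].
w1_major_k = torch.empty(num_experts, 2 * intermediate_size, hidden_size)

# BlockMajorK: the K dimension is blocked, giving a 4D tensor
# [num_experts, hidden_size // block_k, 2 * intermediate_size, block_k].
w1_block_major_k = torch.empty(
    num_experts, hidden_size // block_k, 2 * intermediate_size, block_k
)


def recover_mn_k(w: torch.Tensor) -> tuple[int, int]:
    """Recover (Mn, K) from gemm1 weights the way the launcher does."""
    if w.dim() == 3:   # MajorK:      [num_experts, Mn, K]
        return w.shape[1], w.shape[2]
    if w.dim() == 4:   # BlockMajorK: [num_experts, K / block_k, Mn, block_k]
        return w.shape[2], w.shape[1] * w.shape[3]
    raise ValueError("gemm1_weights must be 3D or 4D")


for w in (w1_major_k, w1_block_major_k):
    mn, k = recover_mn_k(w)
    # Same invariants the launcher enforces for gemm1 weights.
    assert mn % 2 == 0 and mn // 2 == intermediate_size and k == hidden_size
```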

csrc/trtllm_fused_moe_runner.cu

Lines changed: 22 additions & 21 deletions
@@ -174,10 +174,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3

 namespace PermuteGemm1 {

-tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(btg::Dtype dtypeElt,
-                                                                    int32_t tileTokensDim,
-                                                                    bool useDeepSeekFp8,
-                                                                    bool useShuffledMatrixA) {
+tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(
+    btg::Dtype dtypeElt, int32_t tileTokensDim, bool useDeepSeekFp8, bool useShuffledMatrixA,
+    batchedGemm::gemm::MatrixLayout weightLayout) {
   tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = {
       .eltType = dtypeElt,
       .outputType = dtypeElt,
@@ -188,15 +187,17 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(btg::Dtype d
       .transposeMmaOutput = true,
       .tileSize = tileTokensDim,
       .epilogueTileM = useDeepSeekFp8 ? 64 : 128,
-      .useShuffledMatrixA = useShuffledMatrixA};
+      .useShuffledMatrixA = useShuffledMatrixA,
+      .weightLayout = weightLayout};
   return options;
 }

-Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int tileTokensDim, bool useShuffledMatrixA)
+Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int tileTokensDim, bool useShuffledMatrixA,
+               batchedGemm::gemm::MatrixLayout weightLayout)
     : mDtypeElt(dtypeElt),
       mTileTokensDim(tileTokensDim),
-      mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner(
-          getOptions(mDtypeElt, mTileTokensDim, useDeepSeekFp8, useShuffledMatrixA))) {}
+      mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner(getOptions(
+          mDtypeElt, mTileTokensDim, useDeepSeekFp8, useShuffledMatrixA, weightLayout))) {}

 void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* weightsScale,
                  void* expertWeights, float* outputScalesScalar, float* outputScalesGateScalar,
@@ -253,11 +254,9 @@ std::vector<int64_t> Runner::getPassingConfigIndices() const {
 }  // namespace PermuteGemm1

 namespace Gemm2 {
-tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(btg::Dtype dtypeElt,
-                                                                    btg::Dtype dtypeOut,
-                                                                    int32_t tileTokensDim,
-                                                                    bool useDeepSeekFp8,
-                                                                    bool useShuffledMatrixA) {
+tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(
+    btg::Dtype dtypeElt, btg::Dtype dtypeOut, int32_t tileTokensDim, bool useDeepSeekFp8,
+    bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) {
   tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = {
       .eltType = dtypeElt,
       .outputType = dtypeOut,
@@ -268,17 +267,19 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions(btg::Dtype d
       .transposeMmaOutput = true,
       .tileSize = tileTokensDim,
       .epilogueTileM = useDeepSeekFp8 ? 64 : 128,
-      .useShuffledMatrixA = useShuffledMatrixA};
+      .useShuffledMatrixA = useShuffledMatrixA,
+      .weightLayout = weightLayout};
   return options;
 }

 Runner::Runner(btg::Dtype dtypeElt, btg::Dtype outputDtype, bool useDeepSeekFp8, int tileTokensDim,
-               bool useShuffledMatrixA)
+               bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout)
     : mDtypeElt(dtypeElt),
       mOutputDtype(outputDtype),
       mTileTokensDim(tileTokensDim),
-      mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner(getOptions(
-          mDtypeElt, mOutputDtype, mTileTokensDim, useDeepSeekFp8, useShuffledMatrixA))) {}
+      mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner(
+          getOptions(mDtypeElt, mOutputDtype, mTileTokensDim, useDeepSeekFp8, useShuffledMatrixA,
+                     weightLayout))) {}

 void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void* weights,
                  void* weightsScale, float* outputScalesScalar, void* output, void* outputScale,
@@ -336,11 +337,11 @@ std::vector<int64_t> Runner::getPassingConfigIndices() const {

 namespace MoE {
 Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int32_t tileTokensDim,
-               bool useShuffledMatrixA)
-    : mPermuteGemm1(
-          PermuteGemm1::Runner(dtypeElt, useDeepSeekFp8, tileTokensDim, useShuffledMatrixA)),
+               bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout)
+    : mPermuteGemm1(PermuteGemm1::Runner(dtypeElt, useDeepSeekFp8, tileTokensDim,
+                                         useShuffledMatrixA, weightLayout)),
       mGemm2(Gemm2::Runner(dtypeElt, btg::Dtype::Bfloat16, useDeepSeekFp8, tileTokensDim,
-                           useShuffledMatrixA)) {
+                           useShuffledMatrixA, weightLayout)) {
   auto const& gemm1PassingIndices = mPermuteGemm1.getPassingConfigIndices();
   auto const& gemm2PassingIndices = mGemm2.getPassingConfigIndices();
346347

flashinfer/fused_moe.py

Lines changed: 23 additions & 1 deletion
@@ -59,6 +59,17 @@ class RoutingMethodType(IntEnum):
     Unspecified = 5


+# See MatrixLayout from include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h
+class WeightLayout(IntEnum):
+    # K-major layout (default). [Mn, K]
+    MajorK = 0
+    # M-major for A and N-major for B. [K, Mn]
+    MajorMn = 1
+    # Layout is blocked along the K dimension. [K / blockK, Mn, blockK]
+    # where blockK is fixed at 128B
+    BlockMajorK = 2
+
+
 def get_reorder_rows_for_gated_act_gemm_row_indices(x) -> torch.Tensor:
     """
     Reorders rows in the gemm/MOE_gemm weight matrix for min-latency
@@ -224,6 +235,12 @@ def shuffle_matrix_sf_a(
     return nvfp4_block_scale_interleave(w_shuffled)


+def convert_to_block_layout(input_tensor: torch.Tensor, blockK: int) -> torch.Tensor:
+    M, K = input_tensor.shape
+    assert K % blockK == 0, "K must be divisible by blockK"
+    return input_tensor.view(M, K // blockK, blockK).permute(1, 0, 2).contiguous()
+
+
 def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
     return gen_jit_spec(
         "fused_moe_sm100",
@@ -884,6 +901,7 @@ def trtllm_fp8_block_scale_moe_op(
     tile_tokens_dim: int,
     routing_method_type: int,
     use_shuffled_weight: bool = False,
+    weight_layout: int = 0,
 ) -> torch.Tensor:

     # Call the C++ function for block scale MoE
@@ -907,6 +925,7 @@
         tile_tokens_dim,
         routing_method_type,
         use_shuffled_weight,
+        weight_layout,
     )

     return output
@@ -932,6 +951,7 @@ def _fake_trtllm_fp8_block_scale_moe(
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
     use_shuffled_weight: bool = False,
+    weight_layout: int = 0,
 ):
     seq_len = hidden_states.shape[0]
     hidden_size = hidden_states.shape[1]
@@ -1121,7 +1141,8 @@ def trtllm_fp8_block_scale_moe(
     routed_scaling_factor: float,
     tile_tokens_dim: int = 8,
     routing_method_type: int = 0,
-    use_shuffled_weight: bool = True,
+    use_shuffled_weight: bool = False,
+    weight_layout: int = 0,
 ) -> torch.Tensor:
     """FP8 block scale MoE operation.

@@ -1168,6 +1189,7 @@
         tile_tokens_dim,
         routing_method_type,
         use_shuffled_weight,
+        weight_layout,
     )
11731195

include/flashinfer/trtllm/batched_gemm/KernelRunner.h

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 #include <vector>

 // #include "flashinfer/trtllm/common/Dtype.h"
+#include "trtllmGen_bmm_export/Enums.h"
 #include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"

 namespace tensorrt_llm {
@@ -39,6 +40,7 @@ struct TrtllmGenBatchedGemmRunnerOptions {
   int32_t tileSize{8};
   int32_t epilogueTileM{128};
   bool useShuffledMatrixA{false};
+  batchedGemm::gemm::MatrixLayout weightLayout{batchedGemm::gemm::MatrixLayout::MajorK};
 };

 class TrtllmGenBatchedGemmRunner {

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 1 addition & 1 deletion
@@ -645,7 +645,7 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa

   auto fiModuleLoadData = [&](CUmodule* module) {
     const std::string sha256 = config.mHash ? config.mHash : "";
-    const std::string pipeline_hash = "39b7e49bfedde88ea29bfdc2547cbba659f2b236";
+    const std::string pipeline_hash = "991e7438224199de85ef08a2730ce18c12b4e0aa";
     const std::string cubin_path = pipeline_hash + "/" + std::string("batched_gemm-") +
                                    TLLM_GEN_COMMIT + "-" + TLLM_GEN_BATCHED_GEMM_CONFIG_HASH + "/";
     std::string fname_cubin = config.mFunctionName;
