
Conversation

@yzh119 (Collaborator) commented on Jul 18, 2025

📌 Description

Simplify and unify the interface for trtllm-gen decode/prefill/mla kernels, and add support for shared-kv (in MLA, #1273).

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@gemini-code-assist bot left a comment

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on refactoring the integration of TRTLLM-GEN attention kernels within the FlashInfer library. The primary goal is to streamline the interface between Python and C++/CUDA code by transitioning from at::Tensor objects to raw pointers and explicit dimension parameters in the low-level kernel launchers. This change aims to improve performance by reducing PyTorch tensor overhead and to enable more direct control over kernel parameters, such as the Streaming Multiprocessor (SM) count, for better hardware utilization. Additionally, the C++ code has been cleaned up with improved data type dispatching and parameter structure initialization.
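
The summary above mentions passing the Streaming Multiprocessor (SM) count directly to the launchers. As a purely illustrative sketch (the PR retrieves this value through a Python-side utility, not the helper shown here), such a count can be obtained from the CUDA runtime API:

```cpp
#include <cuda_runtime.h>

// Illustrative helper, not part of the PR: query the number of streaming
// multiprocessors (SMs) for a device. The PR obtains this value on the Python
// frontend and passes it down to the C++ launchers as sm_count.
int query_sm_count(int device_id) {
  int sm_count = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id);
  return sm_count;
}
```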

Highlights

  • C++ Kernel Interface Refactoring: Refactored C++ kernel launcher signatures (trtllm_paged_attention_decode_launcher, trtllm_paged_attention_context_launcher, trtllm_paged_attention_mla_launcher) to accept raw pointers and explicit dimension/count parameters (e.g., batch_size, head_dim, sm_count) instead of at::Tensor objects. This aims to reduce PyTorch tensor overhead and provide more direct control over kernel execution.
  • Data Type Dispatching Improvement: Replaced verbose if-else if blocks for data type dispatching in C++ attention functions (trtllm_paged_attention_decode, trtllm_paged_attention_context) with a cleaner DISPATCH_PYTORCH_DTYPE_TO_CTYPE macro, improving code readability and maintainability. A sketch of this dispatch pattern follows this list.
  • Hardware-Aware Optimization: Introduced sm_count (Streaming Multiprocessor count) as a direct parameter to CUDA kernel launchers, allowing the kernels to potentially make more informed decisions based on the target GPU's architecture.
  • Parameter Structure Initialization: Added a default constructor to TllmGenFmhaRunnerParams that performs zero-initialization using memset, eliminating the need for explicit memset calls in the kernel launcher functions.
  • Python Frontend Updates: Updated Python frontend calls in flashinfer/decode.py, flashinfer/prefill.py, and flashinfer/jit/attention/pytorch.py to retrieve the device's sm_count using a new utility function and pass it down to the refactored C++ kernels.
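
The dispatch bullet above references a sketch. The macro name below comes from the PR, but the expansion and the usage shown are assumptions for illustration, not FlashInfer's actual definition:

```cpp
#include <ATen/ATen.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Minimal sketch of the dtype-dispatch pattern; the real macro may cover more
// dtypes and use a different expansion.
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(scalar_type, c_type, ...) \
  [&]() -> bool {                                                 \
    switch (scalar_type) {                                        \
      case at::ScalarType::Half: {                                \
        using c_type = __half;                                    \
        return __VA_ARGS__();                                     \
      }                                                           \
      case at::ScalarType::BFloat16: {                            \
        using c_type = __nv_bfloat16;                             \
        return __VA_ARGS__();                                     \
      }                                                           \
      default:                                                    \
        return false;                                             \
    }                                                             \
  }()

// Hypothetical usage in trtllm_paged_attention_decode: the lambda runs with
// c_type bound to the matching CUDA type, replacing the old if / else-if chain.
//   bool ok = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(query.scalar_type(), c_type, [&] {
//     trtllm_paged_attention_decode_launcher<c_type, c_type, c_type>(/* ... */);
//     return true;
//   });
//   TORCH_CHECK(ok, "Unsupported query dtype");
```
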
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment

Code Review

This pull request refactors the TensorRT-LLM attention kernel integration, primarily by decoupling the CUDA/C++ code from PyTorch tensors and improving parameter handling. The changes are a good step towards better modularity and clarity.

I've identified a critical bug related to incorrect batch_size handling that needs to be addressed, along with a few other medium to high severity issues concerning potential regressions and code correctness. Please see the detailed comments for specifics.

using DTypeKV = DTypeQ;
using DTypeO = DTypeQ;

int batch_size_ = query.size(0);
Contributor

critical

The local variable batch_size_ is initialized with query.size(0), which corresponds to the total number of query tokens (sum_seq_q), not the batch size. This variable shadows the batch_size function parameter, which holds the correct number of sequences. This is a critical bug.

Please remove this line and use the batch_size function parameter in the call to trtllm_paged_attention_context_launcher on line 260.

auto device = query.device();
const auto stream = at::cuda::getCurrentCUDAStream(device.index());

trtllm_paged_attention_context_launcher<DTypeQ, DTypeKV, DTypeO>(
Contributor

critical

The batch_size_ variable used here is incorrect as it holds the total number of query tokens instead of the batch size. Please use the batch_size function parameter instead.

After this change, please also review the initialization of runner_params.mMaxSeqLenQ inside trtllm_paged_attention_context_launcher. It is currently set to batch_size, which would become the number of sequences after the fix above. This is likely incorrect, as it should be the maximum query sequence length in the batch.

        batch_size, max_seq_len, num_qo_heads, num_kv_heads, head_dim, page_size,

Comment on lines 119 to 122
// Fix: key_value_cache is a pointer, not a tensor, so cannot call .size(0)
// runner_params.mNumPagesInMemPool = key_value_cache.size(0) * 2;
// Instead, set to 0 or a reasonable default if not available
runner_params.mNumPagesInMemPool = 0;
Contributor

high

The parameter runner_params.mNumPagesInMemPool is hardcoded to 0. The previous implementation calculated this value based on the total device memory. Setting it to 0 could cause issues if the underlying kernel relies on this value for resource allocation or performance tuning.

Since key_value_cache is now a raw pointer, its size is not available here. Consider passing the total number of pages in the cache pool as an argument from the Python side to correctly calculate this value. For example, in csrc/trtllm_mla_kernel_launcher.cu, this was calculated as key_value_cache.size(0) * 2.
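
A small sketch of that suggestion, with hypothetical names: the PyTorch-facing wrapper (which still sees the kv-cache as an at::Tensor) computes the pool size and passes it down, mirroring the key_value_cache.size(0) * 2 expression quoted above, instead of the pointer-based launcher hardcoding 0.

```cpp
#include <cstdint>

// Hypothetical helper for the tensor-facing wrapper; the "* 2" simply mirrors the
// previous key_value_cache.size(0) * 2 computation quoted in the review comment.
inline int64_t compute_num_pages_in_mem_pool(int64_t kv_cache_dim0) {
  return kv_cache_dim0 * 2;
}
```

The pointer-based launcher would then take the result as an explicit parameter and assign it to runner_params.mNumPagesInMemPool instead of 0.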

Comment on lines 227 to 229
// Fix: key_value_cache is a pointer, not a tensor, so cannot call .size(0)
// runner_params.mNumPagesInMemPool = key_value_cache.size(0) * 2;
runner_params.mNumPagesInMemPool = 0;
Contributor

high

Similar to the decode launcher, runner_params.mNumPagesInMemPool is hardcoded to 0. This could be problematic. Please consider passing the total number of pages from the caller to compute a more meaningful value for this parameter.

Comment on lines 32 to 35
int64_t qk_nope_head_dim, int64_t kv_lora_rank, int64_t qk_rope_head_dim, double bmm1_scale,
double bmm2_scale, std::optional<at::Tensor> bmm1_scale_tensor,
std::optional<at::Tensor> bmm2_scale_tensor, std::optional<int64_t> max_attention_window_size,
std::optional<int64_t> cyclic_attention_window_size) {
int const num_seqs = query.size(0);
int const batch_size = num_seqs;
int const acc_q_len = query.size(1);
std::optional<int64_t> cyclic_attention_window_size, int64_t sm_count) {
Contributor

medium

The function signature of trtllm_paged_attention_mla_launcher still uses at::Tensor arguments, while trtllm_paged_attention_decode_launcher and trtllm_paged_attention_context_launcher in trtllm_fmha_kernel_launcher.cu have been refactored to use raw pointers.

For consistency and to decouple the CUDA kernels from PyTorch's at::Tensor, consider applying the same refactoring to this function as well.
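
As a sketch of the direction being suggested (all names and parameters below are illustrative assumptions, not the actual interface), the tensor-facing wrapper would extract raw pointers, explicit dimensions, and the stream once, and the launcher itself would stay free of PyTorch types:

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

// Hypothetical pointer-based MLA launcher: no at::Tensor in the signature.
void mla_launcher_sketch(void* out, const void* query, const void* kv_cache,
                         int64_t batch_size, int64_t num_qo_heads, int64_t kv_lora_rank,
                         int64_t sm_count, cudaStream_t stream) {
  // ... fill the runner params from plain scalars/pointers and launch ...
}

// Hypothetical tensor-facing wrapper: does the PyTorch-specific work once.
void mla_wrapper_sketch(at::Tensor out, at::Tensor query, at::Tensor kv_cache,
                        int64_t num_qo_heads, int64_t kv_lora_rank, int64_t sm_count) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(query.device().index());
  mla_launcher_sketch(out.data_ptr(), query.data_ptr(), kv_cache.data_ptr(),
                      /*batch_size=*/query.size(0), num_qo_heads, kv_lora_rank,
                      sm_count, stream);
}
```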

Comment on lines 291 to 294
TllmGenFmhaRunnerParams() {
// NOTE(Zihao): all fields are POD types, so we can use memset to initialize them to zero
memset(this, 0, sizeof(TllmGenFmhaRunnerParams));
}
Contributor

medium

Adding a default constructor to memset the struct is a good improvement for ensuring all members are zero-initialized. This follows RAII principles and makes the code safer and cleaner.
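
As a side note on the pattern being praised, here is a generic sketch (placeholder struct and fields, not the real TllmGenFmhaRunnerParams) showing how the "all fields are POD types" assumption noted in the code comment can be documented with a static_assert:

```cpp
#include <cstring>
#include <type_traits>

// Placeholder struct illustrating zero-initialization in the default constructor.
struct RunnerParamsSketch {
  int mMaxSeqLenQ;
  int mAttentionWindowSize;
  const void* qPtr;

  RunnerParamsSketch() { std::memset(this, 0, sizeof(RunnerParamsSketch)); }
};

// Guards the memset pattern: it is only safe while the struct keeps a trivially
// copyable layout (adding, e.g., a std::string member would break it).
static_assert(std::is_trivially_copyable_v<RunnerParamsSketch>,
              "memset-based zero-initialization assumes a trivially copyable layout");
```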

@yzh119 force-pushed the refactor-trtllm-gen branch from bbf2db0 to b759e9e on July 18, 2025 at 18:03
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t max_num_blocks_per_seq,
double bmm1_scale, double bmm2_scale, int64_t window_left, int64_t sum_seq_q, int64_t sm_count,
cudaStream_t stream, int* cum_seq_lens_q = nullptr, int* cum_seq_lens_kv = nullptr) {
if (num_qo_heads % num_kv_heads != 0) {
Contributor

If there are other restrictions on the group size (i.e., num_qo_heads // num_kv_heads), we should throw an error for those as well.
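
For illustration, a self-contained sketch of the kind of extra validation being asked for; the upper bound used here is a placeholder, since the actual kernel restriction (if any) is not stated in this thread:

```cpp
#include <stdexcept>
#include <string>

// Placeholder limit, not a documented trtllm-gen restriction.
constexpr int kMaxGroupSize = 8;

void check_head_group(int num_qo_heads, int num_kv_heads) {
  if (num_kv_heads <= 0 || num_qo_heads % num_kv_heads != 0) {
    throw std::invalid_argument("num_qo_heads must be a positive multiple of num_kv_heads");
  }
  const int group_size = num_qo_heads / num_kv_heads;
  if (group_size > kMaxGroupSize) {
    throw std::invalid_argument("unsupported group size " + std::to_string(group_size) +
                                " (placeholder limit: " + std::to_string(kMaxGroupSize) + ")");
  }
}
```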

runner_params.outputScale = bmm2_scale;
runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
runner_params.mChunkedAttentionSize = INT_MAX;
runner_params.mAttentionWindowSize = window_left == -1 ? INT_MAX : window_left + 1;
Contributor

It is increased by one because FlashInfer assumes that sliding-window attention should consider the extra token during masking, right? Probably we can add a comment here.

Contributor

Suggested change

    runner_params.mAttentionWindowSize = window_left == -1 ? INT_MAX : window_left + 1;

to

    // Add one to include the extra token during masking.
    runner_params.mAttentionWindowSize = window_left == -1 ? INT_MAX : window_left + 1;
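
A tiny worked illustration of the convention being discussed: window_left counts only the tokens to the left of the current position, so the kernel-side window size adds one for the current token, and -1 disables sliding-window attention entirely.

```cpp
#include <climits>

// window_left = -1 means "no sliding window"; otherwise the attention window
// spans the current token plus window_left tokens to its left.
inline int to_attention_window_size(int window_left) {
  return window_left == -1 ? INT_MAX : window_left + 1;
}
// Example: to_attention_window_size(127) == 128 (the current token plus 127 previous tokens).
```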

use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
runner_params.mMultiCtasKvMode = use_multi_block;

size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);
Contributor

Is the rounding for 16B alignment? And what about the workspace_buffer?

Contributor

Suggested change

    size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);

to

    // Round up num_semaphores to a multiple of 8 since `multiCtasKvScratchPtr` requires 16B alignment.
    size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);

Contributor

@PerkzZheng workspace_buffer is allocated by users. We should ask users to make sure workspace_buffer is 16B aligned. I have added comments for this.

Contributor

thanks.
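
To make the alignment reasoning in this thread concrete, here is a small sketch that assumes each semaphore is a 4-byte counter (an assumption for illustration): rounding the semaphore count up to a multiple of 8 makes the semaphore region a multiple of 32 bytes, so the scratch area carved out after it stays 16-byte aligned, provided the workspace base pointer itself is 16-byte aligned as the PR asks callers to guarantee.

```cpp
#include <cstddef>
#include <cstdint>

constexpr std::size_t round_up(std::size_t x, std::size_t multiple) {
  return (x + multiple - 1) / multiple * multiple;
}

// Assumes 4-byte (uint32_t) semaphores; 8 * 4 = 32 bytes, so the offset past the
// semaphore region is always a multiple of 16 bytes.
inline void* scratch_after_semaphores(void* workspace_16B_aligned, std::size_t batch_size,
                                      std::size_t num_qo_heads) {
  std::size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);
  return static_cast<std::uint32_t*>(workspace_16B_aligned) + num_semaphores;
}
```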

Context,
ForGen,
};

Contributor

@nvpohanh commented on Jul 21, 2025

nit (P1): add documentation:

//! \brief Helper function to launch a trtllm paged attention kernel.
//! \note This function should not be called directly from another file. Use `trtllm_paged_attention_decode` and `trtllm_paged_attention_context` instead.
//!
//! \param out Device pointer to the output tensor.
//! \param query Device pointer to the input query tensor.
//! \param key_cache Device pointer to the input paged key cache tensor. The strides can be set with \p kv_stride_0 \p kv_stride_1 and \p kv_stride_2 .
//! \param value_cache Device pointer to the input paged value cache tensor. The strides can be set with \p kv_stride_0 \p kv_stride_1 and \p kv_stride_2 .
//! \param workspace_buffer Device pointer to the workspace. Must be at least 16-byte aligned. Recommended to allocate at least 128MB for workspace.
//! \param block_tables Device pointer to the block tables. The table shape is [batch_size, max_num_blocks_per_seq].
//! \param seq_lens Device pointer to the sequence lengths. The shape is [batch_size].
//! \param batch_size Batch size, i.e. the number of sequences in the batch.
//! \param max_q_len Maximum number of query tokens per sequence in the batch.
//! \param max_kv_len Maximum number of key/value tokens per sequence in the batch.
//! \param num_pages Maximum number of pages of the kv-cache.
//! \param num_qo_heads Number of query heads.
//! \param num_kv_heads Number of key/value heads.
//! \param head_dim_qk Head dimension of query/key.
//! \param head_dim_vo Head dimension of value/output.
//! \param page_size Number of tokens per page.
//! \param kv_stride_0 Stride of the "page_size" dimension of kv-cache with shape [num_pages, 2, num_kv_heads, page_size, head_dim].
//! \param kv_stride_1 Stride of the "num_kv_heads" dimension of kv-cache with shape [num_pages, 2, num_kv_heads, page_size, head_dim].
//! \param kv_stride_2 Stride of the "2" dimension of kv-cache with shape [num_pages, 2, num_kv_heads, page_size, head_dim].
//! \param max_num_blocks_per_seq Maximum number of blocks that can be allocated for a sequence.
//! \param bmm1_scale The scaling factor applied between BMM1 and Softmax, not including the LOG2E factor.
//! \param bmm2_scale The scaling factor applied after BMM2, not including the scaling factor for Softmax output.
//! \param window_left The window size for sliding window attention. Set to -1 to disable sliding window attention.
//! \param sum_seq_q Total number of query tokens within the batch.
//! \param sm_count The number of SMs on the GPU.
//! \param stream The cuda stream to launch the kernel on.
//! \param cum_seq_lens_q Device pointer to the tensor storing the accumulated sequence lengths of the query tokens in the batch. Not used in ForGen mode.
//! \param cum_seq_lens_kv Device pointer to the tensor storing the accumulated sequence lengths of the key/value tokens in the batch. Not used in ForGen mode.

@yzh119 changed the title from "[WIP] refactor: refactor trtllm-gen attention kernel integration code" to "refactor: refactor trtllm-gen attention kernel integration code" on Jul 22, 2025
@yzh119 marked this pull request as ready for review on July 22, 2025 at 06:06
} else {
auto runner = std::make_shared<TllmGenFmhaRunner>(q_data_type, kv_data_type, o_data_type);
cache.emplace(key, runner);
return runner;
Contributor

Since std::unordered_map never invalidates references and cache outlives everything, we can do

static std::unordered_map<Key, TllmGenFmhaRunner, KeyHash> cache;

and return TllmGenFmhaRunner& instead

      auto runner = TllmGenFmhaRunner(q_data_type, kv_data_type, o_data_type);
      auto [it, ok] = cache.emplace(key, std::move(runner));
      return it->second;
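
A self-contained sketch of the caching pattern proposed here, using a stand-in Runner type and key (the real code would cache TllmGenFmhaRunner keyed on the q/kv/o data types): because std::unordered_map never invalidates references to its elements and a function-local static cache outlives its callers, values can be stored directly and handed out by reference.

```cpp
#include <cstddef>
#include <functional>
#include <tuple>
#include <unordered_map>

// Stand-in for TllmGenFmhaRunner.
struct Runner {
  int q_dtype, kv_dtype, o_dtype;
  Runner(int q, int kv, int o) : q_dtype(q), kv_dtype(kv), o_dtype(o) {}
};

using Key = std::tuple<int, int, int>;
struct KeyHash {
  std::size_t operator()(const Key& k) const {
    auto [a, b, c] = k;
    return (std::hash<int>()(a) * 31 + std::hash<int>()(b)) * 31 + std::hash<int>()(c);
  }
};

// Values are stored by value and returned by reference; unordered_map guarantees
// that references to elements remain valid across rehashes and later insertions.
Runner& get_runner(int q_dtype, int kv_dtype, int o_dtype) {
  static std::unordered_map<Key, Runner, KeyHash> cache;
  auto [it, inserted] = cache.try_emplace(Key{q_dtype, kv_dtype, o_dtype},
                                          q_dtype, kv_dtype, o_dtype);
  return it->second;
}
```

If runners can be created from multiple threads, the lookup-and-insert would additionally need a mutex, as it would in the original snippet.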

@yzh119 yzh119 merged commit 74f5dcc into flashinfer-ai:main Jul 22, 2025
2 checks passed