refactor: refactor trtllm-gen attention kernel integration code #1289
Conversation
Summary of Changes
Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request focuses on refactoring the integration of TRTLLM-GEN attention kernels within the FlashInfer library. The primary goal is to streamline the interface between Python and C++/CUDA code by transitioning from at::Tensor objects to raw pointers and explicit dimension parameters in the low-level kernel launchers. This change aims to improve performance by reducing PyTorch tensor overheads and enables more direct control over kernel parameters, such as the streaming multiprocessor count, for better hardware utilization. Additionally, the C++ code has been cleaned up with improved data type dispatching and parameter structure initialization.
Highlights
- C++ Kernel Interface Refactoring: Refactored the C++ kernel launcher signatures (trtllm_paged_attention_decode_launcher, trtllm_paged_attention_context_launcher, trtllm_paged_attention_mla_launcher) to accept raw pointers and explicit dimension/count parameters (e.g., batch_size, head_dim, sm_count) instead of at::Tensor objects. This aims to reduce PyTorch tensor overhead and provide more direct control over kernel execution.
- Data Type Dispatching Improvement: Replaced verbose if-else if blocks for data type dispatching in the C++ attention functions (trtllm_paged_attention_decode, trtllm_paged_attention_context) with a cleaner DISPATCH_PYTORCH_DTYPE_TO_CTYPE macro, improving code readability and maintainability.
- Hardware-Aware Optimization: Introduced sm_count (streaming multiprocessor count) as a direct parameter to the CUDA kernel launchers, allowing the kernels to potentially make more informed decisions based on the target GPU's architecture (see the sketch after this list).
- Parameter Structure Initialization: Added a default constructor to TllmGenFmhaRunnerParams that performs zero-initialization using memset, eliminating the need for explicit memset calls in the kernel launcher functions.
- Python Frontend Updates: Updated the Python frontend calls in flashinfer/decode.py, flashinfer/prefill.py, and flashinfer/jit/attention/pytorch.py to retrieve the device's sm_count using a new utility function and pass it down to the refactored C++ kernels.
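As an illustration of the hardware-aware piece, the SM count can be queried once per device through the CUDA runtime and forwarded to the launcher. A minimal C++ sketch (the helper name and the commented launcher call are illustrative, not the exact FlashInfer API):

#include <cuda_runtime.h>

// Query the number of streaming multiprocessors (SMs) on a device.
// cudaDevAttrMultiProcessorCount is a standard CUDA device attribute.
inline int get_sm_count(int device_id) {
  int sm_count = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id);
  return sm_count;
}

// The value can then be forwarded to a raw-pointer launcher, e.g. (illustrative):
// trtllm_paged_attention_decode_launcher<...>(..., /*sm_count=*/get_sm_count(dev), stream);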
Code Review
This pull request refactors the TensorRT-LLM attention kernel integration, primarily by decoupling the CUDA/C++ code from PyTorch tensors and improving parameter handling. The changes are a good step towards better modularity and clarity.
I've identified a critical bug related to incorrect batch_size handling that needs to be addressed, along with a few other medium- to high-severity issues concerning potential regressions and code correctness. Please see the detailed comments for specifics.
csrc/trtllm_fmha_kernel_launcher.cu
Outdated
using DTypeKV = DTypeQ;
using DTypeO = DTypeQ;

int batch_size_ = query.size(0);
The local variable batch_size_ is initialized with query.size(0), which corresponds to the total number of query tokens (sum_seq_q), not the batch size. This variable shadows the batch_size function parameter, which holds the correct number of sequences. This is a critical bug.
Please remove this line and use the batch_size function parameter in the call to trtllm_paged_attention_context_launcher on line 260.
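To make the shadowing concrete, a simplified sketch of the flagged pattern and the suggested fix (the function and parameter names follow the discussion above, not the exact file contents):

#include <ATen/ATen.h>

// Simplified illustration of the bug: for a packed query of shape
// [sum_seq_q, num_qo_heads, head_dim], query.size(0) is the total number of
// query tokens, not the number of sequences.
void example_context_entry(const at::Tensor& query, int64_t batch_size /* number of sequences */) {
  int64_t batch_size_ = query.size(0);  // shadows the correct parameter with sum_seq_q
  (void)batch_size_;
  // Fix: drop the local and forward the `batch_size` parameter to
  // trtllm_paged_attention_context_launcher(..., batch_size, ...);
}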
csrc/trtllm_fmha_kernel_launcher.cu
Outdated
auto device = query.device();
const auto stream = at::cuda::getCurrentCUDAStream(device.index());

trtllm_paged_attention_context_launcher<DTypeQ, DTypeKV, DTypeO>(
The batch_size_ variable used here is incorrect, as it holds the total number of query tokens instead of the batch size. Please use the batch_size function parameter instead.
After this change, please also review the initialization of runner_params.mMaxSeqLenQ inside trtllm_paged_attention_context_launcher. It is currently set to batch_size, which would become the number of sequences after the fix above. This is likely incorrect, as it should be the maximum query sequence length in the batch.
batch_size, max_seq_len, num_qo_heads, num_kv_heads, head_dim, page_size,
csrc/trtllm_fmha_kernel_launcher.cu
Outdated
// Fix: key_value_cache is a pointer, not a tensor, so cannot call .size(0)
// runner_params.mNumPagesInMemPool = key_value_cache.size(0) * 2;
// Instead, set to 0 or a reasonable default if not available
runner_params.mNumPagesInMemPool = 0;
The parameter runner_params.mNumPagesInMemPool is hardcoded to 0. The previous implementation calculated this value based on the total device memory. Setting it to 0 could cause issues if the underlying kernel relies on this value for resource allocation or performance tuning.
Since key_value_cache is now a raw pointer, its size is not available here. Consider passing the total number of pages in the cache pool as an argument from the Python side to correctly calculate this value. For example, in csrc/trtllm_mla_kernel_launcher.cu, this was calculated as key_value_cache.size(0) * 2.
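A sketch of the suggested alternative: read the page count at the PyTorch binding layer, where the cache is still a tensor, and forward it to the launcher (the num_pages parameter name is illustrative):

#include <ATen/ATen.h>

// The paged KV cache has shape [num_pages, 2, num_kv_heads, page_size, head_dim]
// at the binding layer, so its page count is available there.
void example_binding(const at::Tensor& key_value_cache /* , ... */) {
  int64_t num_pages = key_value_cache.size(0);
  // Forward num_pages to the raw-pointer launcher, which could then set
  // runner_params.mNumPagesInMemPool = num_pages * 2;  // mirroring trtllm_mla_kernel_launcher.cu
  (void)num_pages;
}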
csrc/trtllm_mla_kernel_launcher.cu
Outdated
int64_t qk_nope_head_dim, int64_t kv_lora_rank, int64_t qk_rope_head_dim, double bmm1_scale,
double bmm2_scale, std::optional<at::Tensor> bmm1_scale_tensor,
std::optional<at::Tensor> bmm2_scale_tensor, std::optional<int64_t> max_attention_window_size,
std::optional<int64_t> cyclic_attention_window_size) {
int const num_seqs = query.size(0);
int const batch_size = num_seqs;
int const acc_q_len = query.size(1);
std::optional<int64_t> cyclic_attention_window_size, int64_t sm_count) {
The function signature of trtllm_paged_attention_mla_launcher still uses at::Tensor arguments, while trtllm_paged_attention_decode_launcher and trtllm_paged_attention_context_launcher in trtllm_fmha_kernel_launcher.cu have been refactored to use raw pointers.
For consistency, and to decouple the CUDA kernels from PyTorch's at::Tensor, consider applying the same refactoring to this function as well.
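For reference, a minimal sketch of the tensor-to-pointer split that the decode/context launchers now follow and that could be mirrored here (names and parameter lists are simplified, not the exact FlashInfer signatures):

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

// Raw-pointer launcher: no at::Tensor in the signature; dimensions are explicit.
void example_mla_launcher(void* out, void* query, void* key_value_cache,
                          int64_t batch_size, int64_t acc_q_len, int64_t kv_lora_rank,
                          int64_t sm_count, cudaStream_t stream) {
  // ... fill TllmGenFmhaRunnerParams and launch the kernel ...
}

// Thin PyTorch wrapper: extract pointers, sizes, and the stream, then call the launcher.
void example_mla_binding(at::Tensor out, at::Tensor query, at::Tensor key_value_cache,
                         int64_t kv_lora_rank, int64_t sm_count) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(query.device().index());
  example_mla_launcher(out.data_ptr(), query.data_ptr(), key_value_cache.data_ptr(),
                       /*batch_size=*/query.size(0), /*acc_q_len=*/query.size(1),
                       kv_lora_rank, sm_count, stream);
}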
TllmGenFmhaRunnerParams() {
  // NOTE(Zihao): all fields are POD types, so we can use memset to initialize them to zero
  memset(this, 0, sizeof(TllmGenFmhaRunnerParams));
}
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t max_num_blocks_per_seq,
double bmm1_scale, double bmm2_scale, int64_t window_left, int64_t sum_seq_q, int64_t sm_count,
cudaStream_t stream, int* cum_seq_lens_q = nullptr, int* cum_seq_lens_kv = nullptr) {
if (num_qo_heads % num_kv_heads != 0) {
If there are other restrictions on the group size (i.e., num_qo_heads // num_kv_heads), we should throw errors for those as well.
…v into refactor-trtllm-gen
csrc/trtllm_fmha_kernel_launcher.cu
Outdated
runner_params.outputScale = bmm2_scale;
runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
runner_params.mChunkedAttentionSize = INT_MAX;
runner_params.mAttentionWindowSize = window_left == -1 ? INT_MAX : window_left + 1;
It is incremented by one because flashinfer assumes that sliding window attention should consider the extra token during masking, right? Probably we can add a comment here.
Suggested change:
// Add one to include the extra token during masking.
runner_params.mAttentionWindowSize = window_left == -1 ? INT_MAX : window_left + 1;
csrc/trtllm_fmha_kernel_launcher.cu
Outdated
use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
runner_params.mMultiCtasKvMode = use_multi_block;

size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);
Is the rounding for 16B alignment? Then what about the workspace_buffer?
Suggested change:
// Round up num_semaphores to a multiple of 8 since `multiCtasKvScratchPtr` requires 16B alignment.
size_t num_semaphores = round_up(batch_size * num_qo_heads, 8);
@PerkzZheng workspace_buffer is allocated by users. We should ask users to make sure workspace_buffer is 16B aligned. I have added comments for this.
thanks.
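For illustration, the kind of alignment check a caller (or the binding layer) could perform on the user-allocated workspace is sketched below; this is a hypothetical helper, as the PR itself only documents the requirement:

#include <cstdint>
#include <ATen/ATen.h>

// Verify that the user-provided workspace buffer is 16-byte aligned before it
// is handed to the kernel launcher.
inline void check_workspace_alignment(const at::Tensor& workspace_buffer) {
  auto addr = reinterpret_cast<std::uintptr_t>(workspace_buffer.data_ptr());
  TORCH_CHECK(addr % 16 == 0, "workspace_buffer must be 16-byte aligned");
}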
Context,
ForGen,
};
nit (P1): add documentation:
//! \brief Helper function to launch a trtllm paged attention kernel.
//! \note This function should not be called directly from another file. Use `trtllm_paged_attention_decode` and `trtllm_paged_attention_context` instead.
//!
//! \param out Device pointer to the output tensor.
//! \param query Device pointer to the input query tensor.
//! \param key_cache Device pointer to the input paged key cache tensor. The strides can be set with \p kv_stride_0 \p kv_stride_1 and \p kv_stride_2 .
//! \param value_cache Device pointer to the input paged value cache tensor. The strides can be set with \p kv_stride_0 \p kv_stride_1 and \p kv_stride_2 .
//! \param workspace_buffer Device pointer to the workspace. Must be at least 16-byte aligned. Recommended to allocate at least 128MB for workspace.
//! \param block_tables Device pointer to the block tables. The table shape is [batch_size, max_num_blocks_per_seq].
//! \param seq_lens Device pointer to the sequence lengths. The shape is [batch_size].
//! \param batch_size Batch size, i.e. the number of sequences in the batch.
//! \param max_q_len Maximum number of query tokens per sequence in the batch.
//! \param max_kv_len Maximum number of key/value tokens per sequence in the batch.
//! \param num_pages Maximum number of pages of the kv-cache.
//! \param num_qo_heads Number of query heads.
//! \param num_kv_heads Number of key/value heads.
//! \param head_dim_qk Head dimension of query/key.
//! \param head_dim_vo Head dimension of value/output.
//! \param page_size Number of tokens per page.
//! \param kv_stride_0 Stride of the "page_size" dimension of kv-cache with shape [num_pages, 2, num_kv_heads, page_size, head_dim].
//! \param kv_stride_1 Stride of the "num_kv_heads" dimension of kv-cache with shape [num_pages, 2, num_kv_heads, page_size, head_dim].
//! \param kv_stride_2 Stride of the "2" dimension of kv-cache with shape [num_pages, 2, num_kv_heads, page_size, head_dim].
//! \param max_num_blocks_per_seq Maximum number of blocks that can be allocated for a sequence.
//! \param bmm1_scale The scaling factor applied between BMM1 and Softmax, not including the LOG2E factor.
//! \param bmm2_scale The scaling factor applied after BMM2, not including the scaling factor for Softmax output.
//! \param window_left The window size for sliding window attention. Set to -1 to disable sliding window attention.
//! \param sum_seq_q Total number of query tokens within the batch.
//! \param sm_count The number of SMs on the GPU.
//! \param stream The cuda stream to launch the kernel on.
//! \param cum_seq_lens_q Device pointer to the tensor storing the accumulated sequence lengths of the query tokens in the batch. Not used in ForGen mode.
//! \param cum_seq_lens_kv Device pointer to the tensor storing the accumulated sequence lengths of the key/value tokens in the batch. Not used in ForGen mode.
…v into refactor-trtllm-gen
} else {
  auto runner = std::make_shared<TllmGenFmhaRunner>(q_data_type, kv_data_type, o_data_type);
  cache.emplace(key, runner);
  return runner;
Since std::unordered_map never invalidates references and cache outlives everything, we can do
static std::unordered_map<Key, TllmGenFmhaRunner, KeyHash> cache;
and return TllmGenFmhaRunner& instead:
auto runner = TllmGenFmhaRunner(q_data_type, kv_data_type, o_data_type);
auto [it, ok] = cache.emplace(key, std::move(runner));
return it->second;
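A self-contained sketch of the by-value caching pattern suggested here, with placeholder types (the real cache is keyed on the query/kv/output data types):

#include <string>
#include <unordered_map>
#include <utility>

struct Runner {
  explicit Runner(std::string cfg) : config(std::move(cfg)) {}
  std::string config;
};

// Store runners by value. std::unordered_map never invalidates references to
// existing elements on insertion/rehash, and the static map outlives callers,
// so returning a reference into it is safe.
Runner& get_runner(const std::string& key) {
  static std::unordered_map<std::string, Runner> cache;
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;
  auto [new_it, inserted] = cache.emplace(key, Runner(key));
  return new_it->second;
}

Note that a plain function-local static map is not synchronized; if runners can be created from multiple threads, the cache would still need a mutex around lookup and insertion.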
📌 Description
Simplify and unify the interface for trtllm-gen decode/prefill/mla kernels, and add support for shared-kv (in MLA, #1273).
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks before submitting the PR with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests are passing (unittest, etc.).
Reviewer Notes