Skip to content

Commit 5502828

Browse files
authored
Bug fix: fix duplicate launch in POD (#1267)
<!-- .github/pull_request_template.md --> ## 📌 Description Mistakenly added a duplicate kernel launch last time (actually by cursor, but should've checked more closely😂) cc @yzh119 ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 7253d74 commit 5502828

File tree

1 file changed

+2
-4
lines changed
  • include/flashinfer/attention

1 file changed

+2
-4
lines changed

include/flashinfer/attention/pod.cuh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,6 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
438438
FLASHINFER_CUDA_CALL(
439439
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
440440
}
441-
FLASHINFER_CUDA_CALL(
442-
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
443441

444442
// Post-kernel stuff for split-kv prefill
445443
if (!(num_chunks <= 1 || tmp_p == nullptr)) {
@@ -457,11 +455,11 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
457455
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
458456
tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d,
459457
decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads,
460-
HEAD_DIM_VO, stream));
458+
HEAD_DIM_VO, enable_pdl, stream));
461459
} else {
462460
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
463461
tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows,
464-
decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream));
462+
decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream));
465463
}
466464
}
467465
}

0 commit comments

Comments
 (0)