
Conversation


@Espere-1119-Song commented Aug 14, 2025

Currently, when generating new tokens in the decoding stage, the selection branch of NSA produces an output with the same length as the k/v sequence, which is incompatible with the compression branch and the sliding-window branch.
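For illustration, a minimal sketch of the mismatch (hypothetical shapes, not code from this repo):

```python
import torch

# Hypothetical decode step: one new query token attending to a 128-token k/v cache.
B, HQ, H, K, V = 1, 16, 1, 64, 64
q_len, kv_len = 1, 128

q = torch.randn(B, q_len, HQ, K)
k = torch.randn(B, kv_len, H, K)
v = torch.randn(B, kv_len, H, V)

# The compression and sliding-window branches return outputs of shape (B, q_len, HQ, V).
# Before this fix, the selection branch allocated its output against the k/v length,
# i.e. (B, kv_len, HQ, V), so the three branch outputs could not be combined.
expected = (B, q_len, HQ, V)    # what decoding needs
buggy = (B, kv_len, HQ, V)      # what the selection branch previously produced
assert expected != buggy
```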

Summary by CodeRabbit

  • New Features
    • None
  • Bug Fixes
    • Corrected sequence-length handling in parallel attention to align outputs with the query length, improving shape consistency and stability.
  • Documentation
    • Marked legacy training as archived with guidance to the new repo.
    • Added tokenizer memory-leak notice and recommended minimum version.
    • Included steps for fetching SlimPajama data via Git LFS.
    • Fixed typos and improved formatting for readability.
  • Chores/Style
    • Cleaned up import order/formatting and minor README/EOF formatting; no behavioral changes.

coderabbitai bot (Contributor) commented Aug 14, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

The update adjusts sequence dimension handling in fla/ops/nsa/parallel.py to derive T_q from q.shape and align output/lse shapes and Triton grid with T_q, while keeping kernel arg T. Other changes are import reordering/formatting, README edits, and trailing-newline additions in legacy configs.

Changes

| Cohort / File(s) | Summary of changes |
| --- | --- |
| NSA forward path shape/grid update: fla/ops/nsa/parallel.py | Derive T_q from q.shape; switch grid and output/lse shapes to T_q; keep kernel parameter T; no public API changes. |
| Evaluation harness import order: evals/harness.py | Reordered an import to occur after the HFLM import; no logic changes. |
| Legacy training import formatting/order: legacy/training/flame/logging.py, legacy/training/flame/parser.py, legacy/training/run.py | Reformatted and reordered imports; no functional changes. |
| Legacy training README update: legacy/training/README.md | Added archival notice, tokenizer note, SlimPajama LFS steps, footnotes, typo fix, and formatting adjustments. |
| Simple GLA README formatting: fla/ops/simple_gla/README.md | Presentation-only reflow; no content changes. |
| Legacy config EOF newlines: legacy/training/configs/gla_1B.json, .../gla_340M.json, .../gla_7B.json, .../transformer_340M.json | Added trailing newlines; no config content changes. |
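
For reference, a sketch of the changed allocation logic in fla/ops/nsa/parallel.py, paraphrased from the diff hunks quoted later in this thread and wrapped in a small helper for illustration (NV is taken as a parameter here; in the real function it comes from the surrounding code):

```python
import torch

def nsa_fwd_shapes(q, k, v, block_indices, NV):
    """Sketch of the changed grid/output allocation in parallel_nsa_fwd."""
    B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
    _, T_q, HQ, _ = q.shape                        # query length now derived from q, not from k/v
    grid = (T_q, NV, B * H)                        # one Triton program per query position
    o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device)
    # The kernel itself still receives the key/value length as T.
    return grid, o, lse
```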

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant parallel_nsa_fwd
  participant TritonKernel

  Caller->>parallel_nsa_fwd: q, k, v, ...
  parallel_nsa_fwd->>parallel_nsa_fwd: Extract T_q from q.shape
  parallel_nsa_fwd->>parallel_nsa_fwd: Allocate o(B,T_q,HQ,V), lse(B,T_q,HQ)
  parallel_nsa_fwd->>TritonKernel: launch(grid=T_q, ..., T=key_len)
  TritonKernel-->>parallel_nsa_fwd: write o, lse
  parallel_nsa_fwd-->>Caller: o, lse

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

I hop through tensors, T to T_q I see,
Grids realigned, like rows of clover for me.
Kernels hum softly, keys keep their way,
Readmes get tidied—neat stacks of hay.
With a twitch of my nose, I approve this PR—yay! 🐰✨


@gemini-code-assist bot left a comment


Summary of Changes

Hello @Espere-1119-Song, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an output-shape inconsistency in the Native Sparse Attention (NSA) selection branch during the decoding stage. Previously, the output length was tied to the key/value length, which made it incompatible with the other NSA branches, such as compression and sliding window. This change ensures the output shape reflects the query sequence length, making it compatible with generating new tokens.

Highlights

  • Output Shape Correction: The output shape of the NSA selection branch in parallel_nsa_fwd is now correctly aligned with the query sequence length (T_q) instead of the key/value length (T).
  • Triton Kernel Grid Adjustment: The execution grid for the parallel_nsa_fwd_kernel in Triton has been updated to use the query sequence length (T_q) for its temporal dimension, ensuring proper parallelization for the corrected output shape.
  • Tensor Initialization Update: Output tensors (o and lse) are now initialized with T_q as their sequence dimension, reflecting the intended output shape for new token generation.

@gemini-code-assist bot left a comment


Code Review

This pull request correctly adjusts the output shape of parallel_nsa_fwd for decoding scenarios where query and key/value sequences have different lengths. However, a critical issue remains in the parallel_nsa_fwd_kernel. The kernel calculates the memory offset for the query tensor q using the sequence length T from the key/value tensors. This will result in out-of-bounds memory access when q and k/v have different sequence lengths, as is common in decoding. To fully support decoding, the kernel must be updated to handle the query sequence length (T_q) and key/value sequence length (T) independently when calculating memory offsets.
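
To make the offset concern concrete, a small numeric illustration (hypothetical sizes, not code from the kernel):

```python
# Decode step: 2 sequences, 1 query token each, 128 cached key/value tokens.
B, T, T_q = 2, 128, 1

# q is laid out with T_q rows per batch element, so valid per-batch base offsets are:
correct_bos = [i_b * T_q for i_b in range(B)]   # [0, 1]

# The kernel currently derives the base offset from the key/value length T:
buggy_bos = [i_b * T for i_b in range(B)]       # [0, 128]

# For i_b = 1 the kernel would address row 128 of a q tensor that only holds
# B * T_q = 2 rows, i.e. out of bounds. With B == 1 the offset is 0, which masks the bug.
```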

):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
_, T_q, HQ, _ = q.shape


Severity: medium

Using _ to ignore dimensions can be concise, but it's less explicit and misses an opportunity to validate tensor shapes. For better robustness and readability, it's recommended to explicitly unpack all dimensions and add assertions to ensure that the batch size and key/value dimensions of q and k are compatible.

Suggested change
_, T_q, HQ, _ = q.shape
B_q, T_q, HQ, K_q = q.shape
assert B == B_q, f"q and k must have the same batch size, but got {B_q} and {B}"
assert K == K_q, f"q and k must have the same key dimension, but got {K_q} and {K}"

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🔭 Outside diff range comments (3)
evals/harness.py (1)

14-18: __init__ should return None, not the class type

Annotating __init__ with the class type is incorrect and may break type checks (e.g., mypy). __init__ must return None.

Apply this diff:

-    def __init__(self, **kwargs) -> FlashLinearAttentionLMWrapper:
+    def __init__(self, **kwargs) -> None:
legacy/training/README.md (2)

77-77: Fix trailing spaces after backslash; they break line continuation in bash.

In bash, the backslash must be the last character on the line. Trailing spaces after it will prevent continuation and break the command.

-  gpus=8 \  
+  gpus=8 \

126-126: Fix trailing spaces after backslash in the resume example as well.

Same issue as above; this will break the multi-line command when copy-pasted.

-  gpus=8 \  
+  gpus=8 \
♻️ Duplicate comments (1)
fla/ops/nsa/parallel.py (1)

545-547: Don’t discard batch dim; assert compatibility and guard decode edge-case

Using _ for the batch dim forfeits a cheap correctness check. Also, when decoding (T_q != T) without cu_seqlens, indexing assumes B==1; otherwise, offsets into q/lse/block_indices will be wrong.

Apply this minimal guard and explicit unpacking:

-    _, T_q, HQ, _ = q.shape
-    G = HQ // H
-    BS = block_size
+    B_q, T_q, HQ, _ = q.shape
+    assert B_q == B, "q and k must have the same batch size"
+    # Guard: when decoding (T_q != T) without cu_seqlens, only B==1 is currently supported
+    if cu_seqlens is None and T_q != T:
+        assert B == 1, "B>1 with T_q != T is not supported in parallel_nsa_fwd"
+    G = HQ // H
+    BS = block_size

Please verify that decode is only exercised with B==1 when T_q != T in your pipelines.

🧹 Nitpick comments (5)
evals/harness.py (1)

9-10: Make the side-effect import rationale explicit

The reorder looks fine. To prevent future confusion or auto-removal by linters, be explicit about the side-effect nature of this import and narrow the noqa code.

Apply this diff:

-import fla  # noqa
+import fla  # noqa: F401  # side-effect: register FLA kernels and model components
legacy/training/README.md (4)

10-10: Remove stray blockquote marker to avoid an empty box in the IMPORTANT note.

The single “>” line renders as an empty paragraph inside the admonition.

->

32-35: Good call-out on tokenizers memory leak; consider adding an explicit install command for quick copy-paste.

You already mention the constraint; the install line in Setup (previous comment) will make this actionable for users.


144-146: Hyphenate “fine-tune” for consistency with common usage.

-Below, we provide an example of how to finetune Mistral-7B to GLA.
+Below, we provide an example of how to fine-tune Mistral-7B to GLA.

147-147: Typo: “mathced” → “matched”.

-1. Initialize a brand-new GLA-7B model from the config and copy the mathced pretrained weights from Mistral-7B:
+1. Initialize a brand-new GLA-7B model from the config and copy the matched pretrained weights from Mistral-7B:
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8a5496c and 1b3483c.

📒 Files selected for processing (11)
  • evals/harness.py (1 hunks)
  • fla/ops/nsa/parallel.py (2 hunks)
  • fla/ops/simple_gla/README.md (1 hunks)
  • legacy/training/README.md (6 hunks)
  • legacy/training/configs/gla_1B.json (1 hunks)
  • legacy/training/configs/gla_340M.json (1 hunks)
  • legacy/training/configs/gla_7B.json (1 hunks)
  • legacy/training/configs/transformer_340M.json (1 hunks)
  • legacy/training/flame/logging.py (1 hunks)
  • legacy/training/flame/parser.py (1 hunks)
  • legacy/training/run.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (14)
legacy/training/configs/gla_340M.json (1)

24-24: No-op change: trailing newline added

No semantic change to the JSON. Adding a trailing newline is POSIX-friendly and helps some tooling.

legacy/training/configs/gla_7B.json (1)

28-28: No-op change: trailing newline added

No semantic change to the JSON. Consistent end-of-file newlines are good hygiene.

legacy/training/configs/transformer_340M.json (1)

18-18: No-op change: trailing newline added

No functional impact. Thanks for keeping config files consistent.

legacy/training/flame/logging.py (1)

9-9: LGTM on import formatting

Consolidating the transformers callback imports into a single line maintains semantics and reduces noise.

legacy/training/flame/parser.py (1)

10-10: Import reordering only — LGTM

Moving the transformers imports after flame.logging is fine and does not affect runtime behavior here.

legacy/training/configs/gla_1B.json (1)

25-25: EOF newline — LGTM

Adding the trailing newline is a benign formatting improvement; no content changes.

legacy/training/run.py (1)

7-9: Import consolidation/reordering — LGTM

Collapsing the transformers imports and moving import fla below them is fine. No functional impact expected.

fla/ops/simple_gla/README.md (1)

3-7: Formatting/readability tweaks — LGTM

The reflow improves readability; content remains unchanged.

legacy/training/README.md (6)

17-17: LGTM: paragraph spacing improves readability.


38-40: LGTM: clearer preprocessing intro.


106-106: LGTM: global_batch_size calculation and example are correct.


109-111: LGTM: clear explanation of tokens-per-step and totals.


114-116: LGTM: resuming guidance is helpful.


174-177: LGTM: good guidance about single-node inefficiency and multi-node references.

Comment on lines +558 to 561
grid = (T_q, NV, B * H)
o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device)
lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device)


💡 Verification agent

🧩 Analysis chain

Critical: grid/output use T_q but kernel indexing still uses T (KV length) → potential misaddressing/OOB for B>1

You correctly switched grid/o/lse to T_q, but parallel_nsa_fwd_kernel still receives only T and derives bos/eos and all pointer offsets from it. In the non-varlen path, this makes q/lse/block_indices indexing depend on KV length, which becomes incorrect when T_q != T (e.g., decoding) and is unsafe for B>1 (wrong row strides, possible OOB).

Actionable fix: pass T_q to the kernel and decouple q- vs k/v-based offsets inside the kernel.

  • Forward call site change (add T_q):
# around Lines 562-583: pass T_k and T_q distinctly
parallel_nsa_fwd_kernel[grid](
    q=q, k=k, v=v, o=o, lse=lse, scale=scale,
    block_indices=block_indices, block_counts=block_counts,
    cu_seqlens=cu_seqlens, token_indices=token_indices,
    T=T,              # keys length (unchanged)
    T_q=T_q,          # NEW: queries length
    H=H, HQ=HQ, G=G, K=K, V=V, S=S, BS=BS, BK=BK, BV=BV,
)
  • Kernel signature and internal offsets (decouple bos for q vs k/v):
@triton.jit
def parallel_nsa_fwd_kernel(
    q, k, v, o, lse, scale, block_indices, block_counts,
    cu_seqlens, token_indices,
    T,               # keys length
    T_q,             # NEW: queries length
    H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
    S: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
    IS_VARLEN: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr
):
    ...
    if IS_VARLEN:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos_k = tl.load(cu_seqlens + i_n).to(tl.int32)
        bos_q = bos_k  # varlen paths have matching lengths
    else:
        bos_k = i_b * T
        bos_q = i_b * T_q

    # Offsets using bos_k for k/v, bos_q for q/o/lse/block_indices
    k += (bos_k * H + i_h) * K
    v += (bos_k * H + i_h) * V
    block_indices += (bos_q + i_t) * H * S + i_h * S
    p_q  = tl.make_block_ptr(q  + (bos_q + i_t) * HQ*K, ...)
    p_o  = tl.make_block_ptr(o  + (bos_q + i_t) * HQ*V, ...)
    p_lse = lse + (bos_q + i_t) * HQ + i_h * G + tl.arange(0, G)
    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos_q + i_t) * H + i_h)
    else:
        NS = S
    ...
    # Keep p_k/p_v constructed with (K, T) / (T, V) layouts (T is key length)

Strongly recommend mirroring this decoupling in the backward kernels as well:

  • parallel_nsa_bwd_kernel_dq: add T_q; use bos_q for q/do/lse/delta/dq and bos_k for k/v; change launch grid to (T_q, NV, B*H).
  • parallel_nsa_bwd_kernel_dkv: add T_q if needed for q-side offsets used inside the per-token loop; ensure block_mask and its T dimension are consistent with the iteration bound.

Short-term mitigation (until kernels are updated): keep the guard suggested above to prevent B>1 decode runs when T_q != T.

Also consider updating the docstring of parallel_nsa to clarify that output seqlen matches q (T_q).

I can draft the full patch set for fwd/bwd kernels if you’d like.

To validate, please run a decode pass with B=2, T_q << T (e.g., T_q=1, T=128) and confirm no memory errors and correct o.shape == (B, T_q, HQ, V).


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "---- repo root ----"
pwd
echo

echo "---- show file existence and size ----"
ls -l fla/ops/nsa/parallel.py || true
echo

echo "---- print file with line numbers (1..800) ----"
nl -ba fla/ops/nsa/parallel.py | sed -n '1,800p'
echo

echo "---- search for kernel call sites and definitions ----"
rg "parallel_nsa_fwd_kernel" -n || true
rg "parallel_nsa_bwd_kernel" -n || true
rg "T_q" -n || true
rg "bos_q" -n || true
rg "bos_k" -n || true

Length of output: 550


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "---- repo root ----"
pwd
echo

echo "---- file info ----"
ls -l fla/ops/nsa/parallel.py || true
echo

echo "---- print lines 1..220 ----"
sed -n '1,220p' fla/ops/nsa/parallel.py || true
echo

echo "---- print lines 221..480 ----"
sed -n '221,480p' fla/ops/nsa/parallel.py || true
echo

echo "---- print lines 481..760 ----"
sed -n '481,760p' fla/ops/nsa/parallel.py || true
echo

echo "---- search for kernel defs/calls and T_q usage ----"
rg -n "parallel_nsa|parallel_nsa_fwd|parallel_nsa_fwd_kernel|parallel_nsa_bwd|parallel_nsa_bwd_kernel|T_q|make_block_ptr|triton.jit" fla/ops/nsa/parallel.py || true

Length of output: 27435


Critical: q-indexing uses T (KV length) while kernel/grid use T_q — will misaddress / OOB when T_q != T (e.g. decoding with B>1)

Short verification: I inspected fla/ops/nsa/parallel.py — forward allocates o/lse with T_q and launches the fwd kernel with grid=(T_q, NV, B*H) (lines ~533–562), but parallel_nsa_fwd_kernel (def at ~173) computes bos = i_b * T in the non-varlen path and uses (bos + i_t) to index q/o/lse/block_indices. The same pattern occurs in the backward kernels: parallel_nsa_bwd_kernel_dq (def ~299) and parallel_nsa_bwd_kernel_dkv (def ~407) derive offsets from T but callers / buffers use T_q. This will produce wrong row strides and can be OOB when T_q != T and B>1.

Affected locations (file: fla/ops/nsa/parallel.py)

  • parallel_nsa_fwd (caller): grid, o, lse allocation ~lines 533–562 (uses T_q)
  • parallel_nsa_fwd_kernel (kernel): def ~173 — computes bos using T and indexes q/o/lse with (bos + i_t)
  • parallel_nsa_bwd (caller): dq allocation and kernel launch ~lines 614–646 (dq uses q.shape / T_q, launch uses T)
  • parallel_nsa_bwd_kernel_dq (kernel): def ~299 — uses all = B * T and bos = i_b * T for q/dq offsets
  • parallel_nsa_bwd_kernel_dkv (kernel): def ~407 — loop bounds and q pointers use T/bos

Actionable fix (minimal, safe design)

  • Pass queries length (T_q) explicitly to all kernels that index/iterate over queries.
  • Inside kernels decouple query vs key/value base offsets:
    • compute bos_k = i_b * T_k and bos_q = i_b * T_q (for varlen keep bos_k = bos_q = bos).
    • Use bos_k for k/v offsets; use bos_q for q/o/lse/block_indices/block_counts offsets and for any loop bounds over query tokens.
  • Update dq allocation / launch:
    • allocate dq as now (NV, q.shape) but launch dq kernel with grid = (T_q, NV, B*H) and compute dq pointer arithmetic using all_q = B * T_q (not B * T_k).
  • Mirror the same decoupling in parallel_nsa_bwd_kernel_dq and parallel_nsa_bwd_kernel_dkv (loop bounds, block_mask indexing, and any "all" computations must use T_q for q-side, T_k for k/v-side).

Example (concise snippets to apply)

  • Forward call site: pass T_q
parallel_nsa_fwd_kernel[grid](
    q=q, k=k, v=v, o=o, lse=lse, scale=scale,
    block_indices=block_indices, block_counts=block_counts,
    cu_seqlens=cu_seqlens, token_indices=token_indices,
    T=T,              # keys length (unchanged)
    T_q=T_q,          # NEW: queries length
    H=H, HQ=HQ, G=G, K=K, V=V, S=S, BS=BS, BK=BK, BV=BV,
)
  • Kernel signature + internal decoupling (illustrative)
def parallel_nsa_fwd_kernel(..., T, T_q, H:tl.constexpr, HQ:tl.constexpr, ...):
    ...
    if IS_VARLEN:
        i_n, i_t = ...
        bos, eos = ...
        bos_k = bos_q = bos
        # T set from eos-bos for varlen
    else:
        bos_k = i_b * T         # keys base
        bos_q = i_b * T_q       # queries base

    k += (bos_k * H + i_h) * K
    v += (bos_k * H + i_h) * V
    block_indices += (bos_q + i_t) * H*S + i_h * S

    p_q = tl.make_block_ptr(q + (bos_q + i_t) * HQ*K, ...)
    p_o = tl.make_block_ptr(o + (bos_q + i_t) * HQ*V, ...)
    p_lse = lse + (bos_q + i_t) * HQ + i_h * G + tl.arange(0, G)
    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos_q + i_t) * H + i_h)
    else:
        NS = S
    ...

Required backward changes (summary)

  • parallel_nsa_bwd: allocate dq as before, but launch parallel_nsa_bwd_kernel_dq with grid = (T_q, NV, B*H) and pass T_q (see the launch-site sketch after this list).
  • parallel_nsa_bwd_kernel_dq: add T_q param, use all_q = B * T_q for dq pointer arithmetic and bos_q for q/do/lse/dq offsets; use bos_k for k/v offsets.
  • parallel_nsa_bwd_kernel_dkv: add T_q (if q-side offsets/loop bounds used), use bos_q for q pointers and loop bounds over queries; use bos_k / all_k for dk/dv offsets.
  • Ensure parallel_nsa_block_mask / NS computation and any other T-derived sizes use the same T_q where appropriate (block_indices shape must match queries length).
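
A self-contained sketch of the dq-side shapes described above (the float32 dtype for dq is an assumption; the real allocation and kernel launch carry many more arguments):

```python
import torch

def nsa_bwd_dq_launch_shapes(q, k, NV):
    """Sketch: iterate the dq kernel over the query length, not the key/value length."""
    B, T, H, K = k.shape
    _, T_q, HQ, _ = q.shape
    dq = torch.empty(NV, B, T_q, HQ, K, dtype=torch.float, device=q.device)  # allocated against q.shape
    grid = (T_q, NV, B * H)  # launch over query positions; the kernel also needs T_q for q-side offsets
    return dq, grid
```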

Short-term mitigation

  • Until kernels are updated, add a guard preventing runs with B>1 when T_q != T (decoding case) to avoid OOB.

Docs/tests

  • Update parallel_nsa docstring to state output seqlen == queries length T_q.
  • Add a small unit test: B=2, T_q=1, T=128 decode case to validate shapes and no OOB (a sketch follows below).
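
A minimal sketch of such a test, assuming a parallel_nsa(q, k, v, block_indices, block_counts, block_size) entry point; the actual import path and argument names in tests/ops/test_nsa.py may differ:

```python
import torch

from fla.ops.nsa.parallel import parallel_nsa  # assumed import path


def test_parallel_nsa_decode_output_shape():
    B, T, T_q, H, HQ, K, V, S, block_size = 2, 128, 1, 1, 16, 32, 32, 4, 32
    q = torch.randn(B, T_q, HQ, K, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
    block_indices = torch.randint(0, T // block_size, (B, T_q, H, S), device='cuda')
    block_counts = torch.full((B, T_q, H), S, dtype=torch.long, device='cuda')

    o = parallel_nsa(q, k, v, block_indices, block_counts, block_size=block_size)
    assert o.shape == (B, T_q, HQ, V)
```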

I found definitive evidence in the file that the current implementation mixes T and T_q; this needs code changes as described above. I recommend applying the kernel-signature + offset-decoupling changes and updating the backward launches (grid dimensions and "all" computations) so q-indexing always uses T_q.

Comment on lines 26 to 30
```bash
git clone https://github.com/sustcsonglin/flash-linear-attention.git
pip install .
pip install accelerate
```

⚠️ Potential issue

Add missing cd step before pip install . (and consider installing tokenizers explicitly).

Without changing directory, pip install . will not install the cloned repo. Also, since you call out a minimum tokenizer version below, it’s helpful to show the explicit install here.

 ```bash
 git clone https://github.com/sustcsonglin/flash-linear-attention.git
+cd flash-linear-attention
 pip install .
 pip install accelerate
+pip install 'tokenizers>=0.20.4'
+# If you plan to use DeepSpeed:
+# pip install 'accelerate[deepspeed]' deepspeed
```

<details>
<summary>🤖 Prompt for AI Agents</summary>

In legacy/training/README.md around lines 26 to 30, the instructions run pip
install . immediately after git clone which fails because the working directory
wasn't changed; update the steps to cd into the cloned flash-linear-attention
directory before running pip install ., and add an explicit installation of a
compatible tokenizers package (e.g. tokenizers>=0.20.4); optionally note
DeepSpeed install extras (accelerate[deepspeed] and deepspeed) as an additional
step for users planning to use DeepSpeed.


</details>


@zhiyuan1i (Collaborator) commented

@Espere-1119-Song Thank you for your thoughtful feedback.

Could you please avoid modifying any files that aren't directly related to this PR? I noticed a few changes that appear unrelated.
Additionally, could you include tests to validate your modifications? I didn't see any decoding-related tests in tests/ops/test_nsa.py.

@Espere-1119-Song (Author) commented

I am sorry for adding unrelated edits. I will remove them and keep this PR focused. I ran into this while applying NSA to a Qwen2.5 LLM: training works, but during inference, when decoding a new token, I hit a shape mismatch because the selection branch returns an attention output with length equal to kv_len instead of q_len. Decoding-related tests are also missing from tests/ops/test_nsa.py; I will add them along with checks to validate the changes.
