Conversation

@mutiann mutiann commented Aug 22, 2025

I'm trying to update the NSA implementation, including the kernels, to support the cached inference scenario where Tq != Tkv, so that #417 can hopefully be resolved. Corresponding tests for decoding (and chunked prefilling) are also added.

At the moment the outputs look correct for cached generation with HF, but the varlen mode / attention-mask path is not fixed yet; hopefully I can get it done in a few days. Update: the varlen mode is now fixed as well.

Summary by CodeRabbit

  • New Features

    • Full support for distinct per-sample query vs key/value lengths and richer variable-length attention.
    • Modular NSA modes: selection, compression, top‑k, sliding‑window with gating; optional Flash Attention fast path with safe fallback.
    • Improved attention‑mask handling with padded/unpadded flows and outputs shaped to query length.
  • Refactor

    • Public APIs revised to accept per‑side sequence info; block‑count usage made optional; cache update/offset handling streamlined.
  • Tests

    • Expanded deterministic tests covering all NSA variants, var‑len cases, and gradients.

Contributor

coderabbitai bot commented Aug 22, 2025

Walkthrough

Adds dual-length (TQ/TK) and varlen support across NSA: layers now unpad/pad masked inputs (tuple-aware), ops (parallel/compression/naive) accept separate cu_seqlens_q/cu_seqlens_k and token_indices_q, caching offsets use q_len, and tests expanded for new APIs and varlen scenarios.
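
To make the dual-length convention concrete, here is a small illustrative sketch (not code from the PR): during a chunked prefill step each sequence contributes `TQ` new query tokens while its keys/values span the full cached length `TK`, so the two cumulative-length vectors differ.

```python
import torch

# Hypothetical example: two sequences with 16 and 12 cached tokens,
# now being extended by chunks of 4 and 2 query tokens respectively.
q_lens = [4, 2]            # new query tokens per sequence (TQ side)
k_lens = [16 + 4, 12 + 2]  # cached + new tokens per sequence (TK side)

cu_seqlens_q = torch.tensor([0, *q_lens], dtype=torch.int32).cumsum(0)
cu_seqlens_k = torch.tensor([0, *k_lens], dtype=torch.int32).cumsum(0)
# cu_seqlens_q -> tensor([0, 4, 6]); cu_seqlens_k -> tensor([0, 20, 34])
# Passing the pair (cu_seqlens_q, cu_seqlens_k) lets the ops treat the queries
# as the trailing TQ tokens of each TK-long key/value stream.
```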

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Layers: NSA mask-aware I/O<br>`fla/layers/nsa.py` | Forward treats the hidden dim as `q_len`; branches on `attention_mask` to unpad/pad `(q, k, v, g)` via `unpad_input`/`pad_input` when masked, computes gates per branch, runs NSA on unpadded or original tensors, reshapes output to `(B, q_len, -1)`, and updates cache-offset semantics to use `q_len`/mask-aware seqlen. |
| Utils: unpad_input tuple support<br>`fla/layers/utils.py` | `unpad_input` signature accepts `Union[torch.Tensor, Tuple[torch.Tensor, ...]]`; normalizes the input to a tuple, performs per-element unpad/keepdim handling, returns the same container type as the input (tensor or tuple), and updates docstrings. |
| NSA ops: compression dual-length<br>`fla/ops/nsa/compression.py` | Adds `TQ`/`TK`, `cu_seqlens_q`/`cu_seqlens_k`, `token_indices_q`; kernels/wrappers operate over query length `TQ`, adjust `NC`/causal math and indexing for Q-aware blocks, extend the autograd ctx to store Q/K metadata, and public API signatures include `TK` and varlen inputs. |
| NSA ops: naive modularization + varlen<br>`fla/ops/nsa/naive.py` | Introduces `naive_nsa_sel`, `naive_nsa_cmp`, `naive_nsa_topk`; refactors `naive_nsa` to combine compression/selection/sliding via gates; supports tuple `cu_seqlens` (Q/K) and an optional FlashAttention fallback, and returns shapes using `TQ` semantics. |
| NSA ops: parallel dual-length + APIs<br>`fla/ops/nsa/parallel.py` | Kernels and wrappers accept `cu_seqlens_q`/`cu_seqlens_k`, `token_indices_q`, and `TQ`/`TK`; top-k and fwd kernels are reindexed for Q-aware traversal; public signatures for `parallel_nsa_topk`, `parallel_nsa_fwd`, and `parallel_nsa` are updated (`block_counts` default → `None`). |
| Tests: expanded NSA coverage<br>`tests/ops/test_nsa.py` | Expands the test matrix for parallel/naive/compression/topk/fwd and varlen flows, adds helpers (`prepare_token_indices`, `prepare_chunk_offsets`, block builders), enables the deterministic FP32 Triton flag, and validates gradients/outputs across the new APIs. |

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant NSA_Layer as fla.layers.nsa
  participant Utils as fla.layers.utils
  participant Ops as fla.ops.nsa

  Caller->>NSA_Layer: forward(hidden_states, attention_mask?, past_kv?)
  alt attention_mask present
    NSA_Layer->>Utils: unpad_input((q,k,v,g), states, attention_mask, q_len)
    Utils-->>NSA_Layer: q_u, states_u, indices_q, cu_seqlens_q, cu_seqlens_k
    NSA_Layer->>Ops: parallel/naive/compression(q_u,k_u,v_u, g_parts, cu_seqlens_q/k, token_indices_q, TQ, TK)
    Ops-->>NSA_Layer: o_u
    NSA_Layer->>Utils: pad_input(o_u, indices_q, q_len)
    Utils-->>NSA_Layer: o
  else no mask
    NSA_Layer->>Ops: parallel/naive/compression(q,k,v, g_parts, cu_seqlens_q/k?, token_indices_q?, TQ, TK)
    Ops-->>NSA_Layer: o
  end
  NSA_Layer->>Caller: o, updated_past_kv

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • yzhangcs

Poem

I hop through TQ, I hop through TK,
I unpad the masks so attention may play.
Blocks hum softly, kernels align,
Q and K stitched in orderly line.
Rabbit gates open—focus finds its way. 🐇✨

@mutiann mutiann marked this pull request as ready for review September 1, 2025 01:26
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
fla/layers/nsa.py (2)

93-103: Bug: max() on a Tensor and mixed scalar/vector seqlen_offset.

When attention_mask is provided with cache, seqlen_offset becomes a tensor; using Python max() will error on CUDA tensors and is semantically wrong. Also pass the vector offsets to rotary.

-        seqlen_offset, max_seqlen = 0, q_len
+        seqlen_offset, max_seqlen = 0, q_len
@@
-            if attention_mask is not None:
-                # to deliminate the offsets of padding tokens
-                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-                max_seqlen = q.shape[1] + max(seqlen_offset)
+            if attention_mask is not None:
+                # account for left-padding: per-sequence offset = cached_len + (valid_len - seq_len)
+                lens = prepare_lens_from_mask(attention_mask)
+                seqlen_offset = seqlen_offset + (lens - attention_mask.shape[-1])
+                max_seqlen = q.shape[1] + (seqlen_offset.max().item() if torch.is_tensor(seqlen_offset) else int(seqlen_offset))
@@
-        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
+        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

154-157: attentions may be undefined when output_attentions=True.

Define it upfront (or raise NotImplementedError) to avoid UnboundLocalError.

-        if not output_attentions:
-            attentions = None
+        attentions = None  # TODO: expose block indices / masks if attention outputs are required
+        if output_attentions:
+            # Optionally: return block_indices or masks as a proxy; for now, keep None
+            pass
fla/ops/nsa/parallel.py (1)

795-859: Remove torch.compile from autograd.Function class (invalid target).

Decorating the class will break at runtime; compile the forward wrapper instead if needed.

-@torch.compile
 class ParallelNSAFunction(torch.autograd.Function):

If you want compilation, wrap parallel_nsa_fwd with torch.compile or set torch._dynamo.optimize on call sites.
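
For reference, a minimal toy sketch of that pattern (PyTorch 2.x; `ToyFunction` and `toy_forward` are illustrative stand-ins, not the PR's actual `ParallelNSAFunction`): keep the `autograd.Function` class undecorated and compile a thin wrapper around `.apply`.

```python
import torch

class ToyFunction(torch.autograd.Function):
    """Stand-in for ParallelNSAFunction; the class itself stays undecorated."""
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

@torch.compile  # compile a thin wrapper around .apply instead of the Function class
def toy_forward(x):
    return ToyFunction.apply(x)

x = torch.randn(4, requires_grad=True)
toy_forward(x).sum().backward()
print(x.grad)  # tensor of 2s
```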

🧹 Nitpick comments (15)
fla/layers/utils.py (2)

102-107: Fix type hints for tuples (current annotations imply single-element tuples).

Use variadic tuples to reflect that q/states can contain multiple tensors.

-def unpad_input(
-    q: Union[torch.Tensor, Tuple[torch.Tensor]],
-    states: Tuple[torch.Tensor],
+def unpad_input(
+    q: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
+    states: Tuple[torch.Tensor, ...],

136-143: Docstring types are inaccurate for cu_seqlens and max_seqlen tuples.

They’re torch.LongTensor and ints respectively, and cu_seqlens is a pair when Tq != Tk. Update for clarity.

-        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
+        (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.LongTensor, torch.LongTensor]`):
@@
-        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
+        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int, int]`):
fla/ops/nsa/naive.py (2)

14-22: Add stacklevel to warnings for actionable call sites.

Helps users locate import-site issues quickly.

-    warnings.warn(
+    warnings.warn(
         "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
-        category=ImportWarning
+        category=ImportWarning,
+        stacklevel=2
     )

324-336: Use of runtime asserts.

Replace asserts with explicit exceptions for library code paths that may run in production.

-    assert block_counts is not None, "block counts must be provided for selection"
+    if block_counts is None:
+        raise ValueError("block counts must be provided for selection")
fla/ops/nsa/compression.py (1)

519-537: Backward path still uses a single cu_seqlens.

Training with TQ != TK would be incorrect; today that’s inference-only, but please guard or document the constraint.

-        ctx.cu_seqlens = cu_seqlens_q
+        # Only q cu_seqlens is used in bwd; training assumes TQ == TK
+        ctx.cu_seqlens = cu_seqlens_q

Consider raising if cu_seqlens_q and cu_seqlens_k differ during backward.
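
A possible guard, as a hedged sketch (the helper name `check_backward_supported` is hypothetical and not part of the PR):

```python
import torch

def check_backward_supported(cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor) -> None:
    # Dual-length (TQ != TK) is an inference-only path; refuse to set up backward for it.
    if not torch.equal(cu_seqlens_q, cu_seqlens_k):
        raise ValueError(
            "Backward assumes TQ == TK; got different cu_seqlens_q and cu_seqlens_k "
            "(the dual-length path is intended for cached inference only)."
        )
```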

fla/ops/nsa/parallel.py (3)

19-26: Add stacklevel to FlashAttention import warning.

Improves debuggability.

-    warnings.warn(
+    warnings.warn(
         "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
-        category=ImportWarning
+        category=ImportWarning,
+        stacklevel=2
     )

971-976: Guard windowed path when FlashAttention is unavailable.

flash_attn_func may be None; raise clearer error or fallback to naive.

-            o_swa = flash_attn_func(
+            if flash_attn_func is None:
+                raise RuntimeError("window_size>0 requires FlashAttention; please install flash-attn")
+            o_swa = flash_attn_func(
                 q, k, v,
                 causal=True,
                 window_size=(window_size-1, 0)
             )

953-954: Add stacklevel to warnings.warn.

-            warnings.warn("`block_indices` computed from compression is overridden")
+            warnings.warn("`block_indices` computed from compression is overridden", stacklevel=2)
tests/ops/test_nsa.py (7)

21-33: Avoid sorting -1 paddings to the front in block indices.

Sorting ascending puts -1 (padding) first, causing unnecessary gathers of masked entries. Keep valid indices ahead of paddings for cheaper reference runs.

-    block_indices = block_indices.sort(-1)[0]
+    # Place valid indices before paddings (-1) to reduce useless work in reference paths
+    block_indices = block_indices.sort(-1, descending=True)[0]

430-431: Add stacklevel to warnings for accurate source locations.

-        warnings.warn(f"Block indices mismatch: {len(indices)}/{block_indices.numel()} "
-                      f"({len(indices) / free_block_indices_naive.numel():.2f}), seemingly due to numerical issues.")
+        warnings.warn(
+            f"Block indices mismatch: {len(indices)}/{block_indices.numel()} "
+            f"({len(indices) / free_block_indices_naive.numel():.2f}), seemingly due to numerical issues.",
+            stacklevel=2,
+        )

597-599: Use int32 and spread literal for cu_seqlens_q.

Unify types with other cu_seqlens (int32) and address RUF005.

-    cu_seqlens_q = torch.cumsum(torch.tensor([0] + q_lens), dim=0).to(device)
+    cu_seqlens_q = torch.tensor([0, *q_lens], device=device, dtype=torch.int32).cumsum(0)

Make cu_seqlens_q int32 and use `[0, *q_lens]`.

-    cu_seqlens_q = torch.cumsum(torch.tensor([0] + q_lens), dim=0).to(device)
+    cu_seqlens_q = torch.tensor([0, *q_lens], device=device, dtype=torch.int32).cumsum(0)

823-824: Add stacklevel to warnings.warn.

-        warnings.warn(f"Block indices mismatch: {len(indices)}/{block_indices.numel()} "
-                      f"({len(indices) / free_block_indices_naive.numel():.2f}), seemingly due to numerical issues.")
+        warnings.warn(
+            f"Block indices mismatch: {len(indices)}/{block_indices.numel()} "
+            f"({len(indices) / free_block_indices_naive.numel():.2f}), seemingly due to numerical issues.",
+            stacklevel=2,
+        )

Use int32 and `[0, *q_lens]` for cu_seqlens_q.

-    cu_seqlens_q = torch.cumsum(torch.tensor([0] + q_lens), dim=0).to(device)
+    cu_seqlens_q = torch.tensor([0, *q_lens], device=device, dtype=torch.int32).cumsum(0)

932-933: Standardize cu_seqlens_q construction.

Equivalent to current code but keeps dtype/device explicit and resolves RUF005 consistently.

-    cu_seqlens_q = torch.cumsum(torch.tensor([0] + q_lens), dim=0).int().to(device)
+    cu_seqlens_q = torch.tensor([0, *q_lens], device=device, dtype=torch.int32).cumsum(0)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 77bff45 and ab84906.

📒 Files selected for processing (6)
  • fla/layers/nsa.py (3 hunks)
  • fla/layers/utils.py (5 hunks)
  • fla/ops/nsa/compression.py (14 hunks)
  • fla/ops/nsa/naive.py (2 hunks)
  • fla/ops/nsa/parallel.py (19 hunks)
  • tests/ops/test_nsa.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
fla/layers/nsa.py (2)
fla/layers/utils.py (2)
  • pad_input (185-206)
  • unpad_input (101-182)
fla/ops/nsa/parallel.py (1)
  • parallel_nsa (862-977)
tests/ops/test_nsa.py (6)
fla/ops/nsa/naive.py (4)
  • naive_nsa (269-386)
  • naive_nsa_sel (24-130)
  • naive_nsa_cmp (133-166)
  • naive_nsa_topk (169-266)
fla/ops/nsa/parallel.py (3)
  • parallel_nsa (862-977)
  • parallel_nsa_fwd (616-670)
  • parallel_nsa_topk (553-612)
fla/ops/nsa/compression.py (1)
  • parallel_nsa_compression (539-558)
fla/ops/utils/index.py (2)
  • prepare_token_indices (94-96)
  • prepare_chunk_offsets (109-113)
fla/utils.py (1)
  • assert_close (77-88)
fla/ops/utils/pooling.py (1)
  • mean_pooling (190-207)
fla/ops/nsa/naive.py (2)
fla/ops/utils/pooling.py (1)
  • mean_pooling (190-207)
fla/ops/utils/index.py (2)
  • prepare_token_indices (94-96)
  • prepare_chunk_offsets (109-113)
fla/ops/nsa/parallel.py (3)
fla/ops/utils/index.py (2)
  • prepare_token_indices (94-96)
  • prepare_chunk_offsets (109-113)
fla/ops/utils/pooling.py (1)
  • mean_pooling (190-207)
fla/ops/nsa/compression.py (1)
  • parallel_nsa_compression (539-558)
fla/ops/nsa/compression.py (1)
fla/ops/utils/index.py (2)
  • prepare_chunk_offsets (109-113)
  • prepare_token_indices (94-96)
🪛 GitHub Actions: lint
fla/layers/nsa.py

[error] 1-1: isort: Import sorting issues detected. File was auto-fixed by the isort hook during pre-commit.

tests/ops/test_nsa.py

[error] 1-1: isort: Import sorting issues detected. File was auto-fixed by the isort hook during pre-commit.

fla/ops/nsa/naive.py

[error] 1-1: isort: Import sorting issues detected. File was auto-fixed by the isort hook during pre-commit.

fla/ops/nsa/parallel.py

[error] 1-1: isort: Import sorting issues detected. File was auto-fixed by the isort hook during pre-commit.

🪛 Ruff (0.12.2)
tests/ops/test_nsa.py

409-409: Use of assert detected

(S101)


430-430: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


446-446: Use of assert detected

(S101)


448-448: Use of assert detected

(S101)


597-597: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)


696-696: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)


796-796: Use of assert detected

(S101)


823-823: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


827-827: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)


847-847: Use of assert detected

(S101)


849-849: Use of assert detected

(S101)


932-932: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)

fla/ops/nsa/naive.py

17-17: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


142-142: Unused function argument: b

(ARG001)


142-142: Unused function argument: h

(ARG001)


146-146: Unused function argument: b

(ARG001)


146-146: Unused function argument: h

(ARG001)


152-152: Unused function argument: b

(ARG001)


152-152: Unused function argument: h

(ARG001)


195-195: Use of assert detected

(S101)


204-204: Use of assert detected

(S101)


238-238: Use of assert detected

(S101)


241-241: Use of assert detected

(S101)


244-244: Avoid specifying long messages outside the exception class

(TRY003)


324-324: Use of assert detected

(S101)


328-328: Use of assert detected

(S101)


329-329: Use of assert detected

(S101)


351-351: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

fla/ops/nsa/parallel.py

560-560: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


566-566: Use of assert detected

(S101)


567-567: Use of assert detected

(S101)


953-953: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

🔇 Additional comments (7)
fla/layers/utils.py (1)

151-156: Edge-case: preserve dtype/device when wrapping q into a tuple.

Implementation is fine. Calling out that this cast/back behavior is correct and useful for callers passing a single tensor.

fla/layers/nsa.py (1)

119-136: Mask path looks correct and neatly re-pads outputs.

Good use of unpad_input/pad_input and passing tuple cu_seqlens into parallel_nsa.

fla/ops/nsa/naive.py (1)

133-166: Include per-sequence offsets in shifted_varlen_mask for variable-length attention
The comparison q_i >= (kv_i + 1) * block_size - 1 omits each sequence’s start offset (derived from cu_seqlens), so for padded or truncated final blocks (where TQ ≠ TKV) valid query–key pairs will be masked out. Incorporate a per-sequence Q_OFFSET from cu_seqlens into the mask to align token indices correctly.
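
A hedged reference sketch of the offset-aware masking described above (the helper below is illustrative, not the actual `shifted_varlen_mask` in `fla/ops/nsa/naive.py`; the block-count rounding is an assumption):

```python
import torch

def shifted_varlen_mask(cu_seqlens_q, cu_seqlens_k, block_size):
    """Per-sequence visibility of compressed KV blocks for the trailing TQ queries."""
    masks = []
    for i in range(len(cu_seqlens_q) - 1):
        tq = int(cu_seqlens_q[i + 1] - cu_seqlens_q[i])
        tk = int(cu_seqlens_k[i + 1] - cu_seqlens_k[i])
        q_off = tk - tq                                # queries are the last tq tokens of the kv stream
        nc = (tk + block_size - 1) // block_size       # number of compressed blocks (assumed ceil)
        q_i = torch.arange(tq).unsqueeze(1) + q_off    # absolute query positions
        kv_i = torch.arange(nc).unsqueeze(0)           # compressed block indices
        # block kv_i becomes visible once its last covered token is at or before q_i
        masks.append(q_i >= (kv_i + 1) * block_size - 1)
    return masks

cu_q = torch.tensor([0, 4, 6])
cu_k = torch.tensor([0, 20, 34])
print([m.shape for m in shifted_varlen_mask(cu_q, cu_k, block_size=8)])
```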

fla/ops/nsa/compression.py (2)

54-66: Nice: proper per-Q/K varlen handling and TQ/TK separation in the kernel.

The Q_OFFSET and NC logic cleanly enforce causality across mismatched lengths.


350-355: Grid and output shapes updated to TQ—good alignment with API.

This matches the higher-level wrappers and avoids padding waste.

fla/ops/nsa/parallel.py (2)

552-563: API changes in parallel_nsa_topk look good; minor type/doc polish.

Docstrings and assertions align with TQ/TK separation.


922-941: Correct propagation of cu_seqlens_q/k and TK into compression/topk.

End-to-end TQ != TK path is wired properly.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (8)
fla/layers/utils.py (1)

101-108: Fix tuple typing for states and enforce Long dtype for indices_q (avoids runtime index errors).

  • states: Tuple[torch.Tensor] denotes a 1‑tuple; use var-length Tuple[torch.Tensor, ...].
  • In the q_len == 1 path, indices_q is int32; pad_input's IndexPutFirstAxis uses advanced indexing (y[indices] = x), which requires a LongTensor. This will crash on CUDA with int32.

Apply:

-def unpad_input(
-    q: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
-    states: Tuple[torch.Tensor],
+def unpad_input(
+    q: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
+    states: Tuple[torch.Tensor, ...],
     attention_mask: torch.Tensor,
     q_len: int,
     keepdim: bool = False,
 ):
@@
     elif q_len == 1:
         max_seqlen_in_batch_q = 1
         cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q[0].device)
-        indices_q = cu_seqlens_q[:-1]
+        indices_q = cu_seqlens_q[:-1].to(torch.long)
         q = tuple(q_.squeeze(1) for q_ in q)

Also applies to: 163-167

fla/ops/nsa/parallel.py (2)

795-802: Remove @torch.compile decorator from torch.autograd.Function class.

Decorating the class replaces it with a callable object lacking .apply, breaking ParallelNSAFunction.apply(...) at Line 955.

Apply:

-@torch.compile
 class ParallelNSAFunction(torch.autograd.Function):

916-923: Replace assert with explicit exceptions in public API.

assert can be stripped with -O and is flagged by linters. Raise ValueError instead.

-    assert block_counts is not None, "block counts must be provided for selection"
+    if block_counts is None:
+        raise ValueError("block counts must be provided for selection")
@@
-    if cu_seqlens is not None:
-        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
-    assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
+    if cu_seqlens is not None and q.shape[0] != 1:
+        raise ValueError("batch size must be 1 when cu_seqlens are provided")
+    if q.shape[2] % (k.shape[2] * 16) != 0:
+        raise ValueError("Group size must be a multiple of 16 in NSA")
fla/layers/nsa.py (1)

98-102: Compute max_seqlen robustly when attention_mask is present.

max(seqlen_offset) on a Tensor routes through Python iteration and returns a Tensor; use Tensor ops and .item().

-                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-                max_seqlen = q.shape[1] + max(seqlen_offset)
+                offs = prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
+                seqlen_offset = seqlen_offset + offs
+                max_seqlen = q.shape[1] + int((seqlen_offset.max() if torch.is_tensor(seqlen_offset) else seqlen_offset))
fla/ops/nsa/compression.py (4)

124-159: BWD heuristic still keyed on cu_seqlens (old API).

Forward uses cu_seqlens_q/cu_seqlens_k; bwd heuristic should mirror that to avoid mis-detection in varlen.

-@triton.heuristics({
-    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
-})
+@triton.heuristics({
+    'IS_VARLEN': lambda args: args.get('cu_seqlens_q', None) is not None,
+})

134-159: dq kernel ignores TK != TQ (missing Q_OFFSET in mask and loop extent). Gradients wrong in cached decode.

NC is computed as (i_t+1)//BS; for TQ!=TK it must use i_t+Q_OFFSET. Also kernel receives only T and single cu_seqlens, so it cannot derive TK. This produces under-attention in dq for cached prefills.

Minimal fix: pass TK and either cu_seqlens_q/cu_seqlens_k or Q_OFFSET, and incorporate into NC.

-@triton.jit(do_not_specialize=['T'])
-def parallel_nsa_compression_bwd_kernel_dq(
+@triton.jit(do_not_specialize=['TQ'])
+def parallel_nsa_compression_bwd_kernel_dq(
     q, k, v, lse, delta, do, dq, scale,
-    cu_seqlens, token_indices, chunk_offsets, T,
+    cu_seqlens_q, token_indices_q, chunk_offsets, TQ,
+    TK: tl.constexpr,
@@
-    if IS_VARLEN:
-        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
-        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
-        T = eos - bos
+    if IS_VARLEN:
+        i_n, i_t = tl.load(token_indices_q + i_t * 2).to(tl.int32), tl.load(token_indices_q + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens_q + i_n).to(tl.int32), tl.load(cu_seqlens_q + i_n + 1).to(tl.int32)
+        TQ = eos - bos
@@
-    all = B * T
+    all = B * TQ
@@
-    TC = tl.cdiv(T, BS)
-    NC = (i_t + 1) // BS
+    TC = tl.cdiv(TK, BS)
+    Q_OFFSET = TK - TQ
+    n_raw = (i_t + Q_OFFSET + 1) // BS
+    NC = tl.where(n_raw > 0, n_raw, 0)

I can wire the Python launcher changes too if you want.

Also applies to: 173-195


237-266: dv pointer base uses i_v “plane” on a tensor without NV leading dim — memory corruption when NV > 1.

dv is allocated without NV, but p_dv offsets by i_v * all*H. This will write past the buffer when V > 128.

-    dv,
+    dv,
@@
-    p_dv = tl.make_block_ptr(dv + (i_v * all*H + boc * H + i_h) * V, (TC, V), (H*V, 1),
+    p_dv = tl.make_block_ptr(dv + (boc * H + i_h) * V, (TC, V), (H*V, 1),
                               (i_c * BC, i_v * BV), (BC, BV), (1, 0))

Optionally mirror dk’s NV reduction pattern: materialize dv with NV, accumulate per-slice, then sum on host.

Also applies to: 280-284, 292-321


471-519: ctx only stores cu_seqlens_q; bwd needs TK/Q_OFFSET to be correct when TQ != TK.

Backprop currently cannot reconstruct TK and will compute wrong masks. Save TK (and cu_seqlens_k if needed) in ctx.

-        ctx.save_for_backward(q, k, v, o, lse)
+        ctx.save_for_backward(q, k, v, o, lse)
         # Use cu_seqlens of q in backward, as cu_seqlens for q & k are different only for inference
-        ctx.cu_seqlens = cu_seqlens_q
-        ctx.token_indices = token_indices_q
+        ctx.cu_seqlens_q = cu_seqlens_q
+        ctx.token_indices_q = token_indices_q
+        ctx.TK = TK
@@
-            cu_seqlens=ctx.cu_seqlens,
-            token_indices=ctx.token_indices
+            cu_seqlens=ctx.cu_seqlens_q,
+            token_indices=ctx.token_indices_q,
+            TK=ctx.TK,

Also applies to: 523-538

♻️ Duplicate comments (1)
fla/ops/nsa/naive.py (1)

188-194: Device mismatch: cu_q/cu_k created on CPU when cu_seqlens is None.

These are later combined with CUDA tensors.

-        cu_q = torch.cat([
-            torch.arange(0, B * Tq, Tq), torch.tensor([B * Tq])
-        ])
-        cu_k = torch.cat([
-            torch.arange(0, B * Tc, Tc), torch.tensor([B * Tc])
-        ])
+        cu_q = torch.cat([
+            torch.arange(0, B * Tq, Tq, device=device),
+            torch.tensor([B * Tq], device=device),
+        ])
+        cu_k = torch.cat([
+            torch.arange(0, B * Tc, Tc, device=device),
+            torch.tensor([B * Tc], device=device),
+        ])
🧹 Nitpick comments (11)
fla/ops/nsa/parallel.py (2)

943-955: Set stacklevel for override warning.

Improves traceability to the caller; aligns with B028.

-        else:
-            warnings.warn("`block_indices` computed from compression is overridden")
+        else:
+            warnings.warn("`block_indices` computed from compression is overridden", stacklevel=2)

886-891: Docstring typo.

“sliding attentionof” → “sliding attention of”.

-        g_swa (torch.Tensor):
-            Gate score for sliding attentionof shape `[B, TQ, HQ]`.
+        g_swa (torch.Tensor):
+            Gate score for sliding attention of shape `[B, TQ, HQ]`.
fla/ops/nsa/compression.py (2)

325-355: Guard BK/BV selection for Hopper vs. others is fine; assert NK==1 is hard-coded.

If K can be 320/512 in some configs, this assert will hard fail. Consider documenting this limitation in the public API and validating earlier.

Also applies to: 356-379
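
One way to validate earlier, as a hedged sketch (the function name and the `max_supported` value are placeholders, not the kernel's verified limit):

```python
def validate_compression_head_dim(K: int, max_supported: int = 256) -> None:
    # The compression kernels assume the full head dim fits in a single BK block (NK == 1).
    # `max_supported` is a placeholder; substitute the real limit implied by the kernel config.
    if K > max_supported:
        raise ValueError(
            f"parallel_nsa_compression supports head_dim <= {max_supported}, got {K}"
        )
```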


539-559: Type hint: avoid implicit Optional in Union signature.

PEP 484 prefers Union[X, Y] | None (or Optional[X]) over Union[None, X, Y].

-    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None
+    cu_seqlens: Tuple[torch.LongTensor, torch.LongTensor] | torch.LongTensor | None = None
tests/ops/test_nsa.py (3)

431-431: warnings.warn without stacklevel obscures call site in test failures.

Add stacklevel=2.

-        warnings.warn(f"Block indices mismatch: {len(indices)}/{block_indices.numel()} "
-                      f"({len(indices) / free_block_indices_naive.numel():.2f}), seemingly due to numerical issues.")
+        warnings.warn(
+            f"Block indices mismatch: {len(indices)}/{block_indices.numel()} "
+            f"({len(indices) / free_block_indices_naive.numel():.2f}), seemingly due to numerical issues.",
+            stacklevel=2,
+        )

Also applies to: 824-824


598-599: Tiny style: build cu_seqlens_q with `[0, *q_lens]`.

Cleaner and avoids Python list concatenation.

-    cu_seqlens_q = torch.cumsum(torch.tensor([0] + q_lens), dim=0).to(device)
+    cu_seqlens_q = torch.cumsum(torch.tensor([0, *q_lens]), dim=0).to(device)

(Apply similarly to other occurrences.)

Also applies to: 697-698, 828-829, 933-934


311-316: lse comparison special-case is embedded in tests; consider encoding in helper.

You repeatedly replace -inf with 0 for compressive paths. Factor into a small helper to keep assertions focused.

Also applies to: 690-692
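
A minimal helper sketch for that refactor (the name `lse_for_compare` is an assumption, not an existing test utility):

```python
import torch

def lse_for_compare(lse: torch.Tensor) -> torch.Tensor:
    # Compressive paths emit -inf LSE for rows with no visible blocks; zero them out
    # so assert_close only compares the meaningful entries.
    return lse.masked_fill(lse == float('-inf'), 0.0)
```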

fla/ops/nsa/naive.py (4)

14-22: Add stacklevel=2 to ImportWarning.

Improves traceback pointing to import site.

-    warnings.warn(
+    warnings.warn(
         "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
-        category=ImportWarning
+        category=ImportWarning,
+        stacklevel=2,
     )

110-114: naive_nsa_sel explicitly disallows TQ != TK. Align mask or keep assert with clear message.

Given the PR adds cached TQ!=TK paths, either:

  • implement the offset-aware causal mask here, or
  • keep the assert but clarify it’s only used as a reference when TQ == TK.

Also applies to: 123-129


246-266: Top-k selection: enforce allow mask and quota — good. Add early-exit when S==0.

Micro-opt to skip work when no blocks requested.

-        _, topi = torch.topk(scores, k=min(S, Tc), dim=-1)
+        if S == 0:
+            out = torch.full((Tq, Hkv, 0), -1, device=device, dtype=torch.long)
+        else:
+            _, topi = torch.topk(scores, k=min(S, Tc), dim=-1)

324-337: Group-size assert: message could be clearer.

Spell out “HQ must be a multiple of 16*H”.

-    assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
+    assert q.shape[2] % (k.shape[2] * 16) == 0, "HQ must be a multiple of 16 * H (GQA group size in NSA)"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between ab84906 and 18017f6.

📒 Files selected for processing (6)
  • fla/layers/nsa.py (3 hunks)
  • fla/layers/utils.py (5 hunks)
  • fla/ops/nsa/compression.py (14 hunks)
  • fla/ops/nsa/naive.py (2 hunks)
  • fla/ops/nsa/parallel.py (20 hunks)
  • tests/ops/test_nsa.py (7 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
PR: fla-org/flash-linear-attention#544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.

Applied to files:

  • tests/ops/test_nsa.py
🧬 Code graph analysis (5)
tests/ops/test_nsa.py (6)
fla/ops/nsa/compression.py (1)
  • parallel_nsa_compression (540-559)
fla/ops/nsa/naive.py (4)
  • naive_nsa (269-386)
  • naive_nsa_cmp (133-166)
  • naive_nsa_sel (24-130)
  • naive_nsa_topk (169-266)
fla/ops/nsa/parallel.py (3)
  • parallel_nsa (862-978)
  • parallel_nsa_fwd (616-670)
  • parallel_nsa_topk (553-612)
fla/ops/utils/index.py (2)
  • prepare_chunk_offsets (109-113)
  • prepare_token_indices (94-96)
fla/ops/utils/pooling.py (1)
  • mean_pooling (190-207)
fla/utils.py (1)
  • assert_close (77-88)
fla/ops/nsa/parallel.py (3)
fla/ops/utils/index.py (2)
  • prepare_token_indices (94-96)
  • prepare_chunk_offsets (109-113)
fla/ops/utils/pooling.py (1)
  • mean_pooling (190-207)
fla/ops/nsa/compression.py (1)
  • parallel_nsa_compression (540-559)
fla/layers/nsa.py (2)
fla/layers/utils.py (2)
  • pad_input (185-206)
  • unpad_input (101-182)
fla/ops/nsa/parallel.py (1)
  • parallel_nsa (862-978)
fla/ops/nsa/compression.py (1)
fla/ops/utils/index.py (2)
  • prepare_chunk_offsets (109-113)
  • prepare_token_indices (94-96)
fla/ops/nsa/naive.py (2)
fla/ops/utils/index.py (2)
  • prepare_chunk_offsets (109-113)
  • prepare_token_indices (94-96)
fla/ops/utils/pooling.py (1)
  • mean_pooling (190-207)
🪛 Ruff (0.12.2)
tests/ops/test_nsa.py

410-410: Use of assert detected

(S101)


431-431: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


447-447: Use of assert detected

(S101)


449-449: Use of assert detected

(S101)


598-598: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)


697-697: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)


797-797: Use of assert detected

(S101)


824-824: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


828-828: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)


848-848: Use of assert detected

(S101)


850-850: Use of assert detected

(S101)


933-933: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)

fla/ops/nsa/parallel.py

560-560: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


566-566: Use of assert detected

(S101)


567-567: Use of assert detected

(S101)


954-954: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

fla/ops/nsa/compression.py

546-546: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

fla/ops/nsa/naive.py

17-17: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


110-110: Use of assert detected

(S101)


142-142: Unused function argument: b

(ARG001)


142-142: Unused function argument: h

(ARG001)


146-146: Unused function argument: b

(ARG001)


146-146: Unused function argument: h

(ARG001)


152-152: Unused function argument: b

(ARG001)


152-152: Unused function argument: h

(ARG001)


195-195: Use of assert detected

(S101)


204-204: Use of assert detected

(S101)


238-238: Use of assert detected

(S101)


241-241: Use of assert detected

(S101)


244-244: Avoid specifying long messages outside the exception class

(TRY003)


324-324: Use of assert detected

(S101)


328-328: Use of assert detected

(S101)


329-329: Use of assert detected

(S101)


351-351: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

🪛 GitHub Actions: lint
tests/ops/test_nsa.py

[error] 8-8: flake8: E402 module level import not at top of file.


[error] 10-10: flake8: E402 module level import not at top of file.


[error] 11-11: flake8: E402 module level import not at top of file.


[error] 12-12: flake8: E402 module level import not at top of file.


[error] 14-14: flake8: E402 module level import not at top of file.


[error] 15-15: flake8: E402 module level import not at top of file.


[error] 16-16: flake8: E402 module level import not at top of file.


[error] 17-17: flake8: E402 module level import not at top of file.


[error] 18-18: flake8: E402 module level import not at top of file.


[error] 19-19: flake8: E402 module level import not at top of file.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (4)
fla/ops/nsa/parallel.py (1)

234-236: Document and verify the Q_OFFSET = TK - TQ assumption.

Kernels assume query tokens are the last TQ tokens of the KV stream (both in LSE and forward paths). This is correct for decoding/chunked prefilling, but will be wrong if queries are not aligned to the tail.

Please confirm tests cover varlen cases where TQ != TK and Q is exactly the tail. If not guaranteed, guard with an explicit check or expose an offset parameter.

Also applies to: 282-296
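
If an explicit check is preferred, a hedged host-side sketch (the helper name is hypothetical):

```python
import torch

def assert_queries_are_kv_tail(cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor) -> None:
    # Kernels derive Q_OFFSET = TK - TQ per sequence, i.e. queries must be the
    # trailing TQ tokens of each TK-long key/value stream.
    q_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    k_lens = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    if bool((q_lens > k_lens).any()):
        raise ValueError("Each sequence must satisfy TQ <= TK (queries aligned to the KV tail).")
```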

fla/layers/nsa.py (3)

109-113: Good fix: cache offset now uses q_len.

This aligns KV cache updates with the actual number of appended queries.


119-136: Varlen path looks correct; ensure indices_q is Long.

pad_input relies on Long indices; with the proposed utils.unpad_input fix, this path remains safe.

After applying the utils.py fix, please run a decoding varlen test to confirm no dtype mismatch on CUDA.


151-153: Shape finalize is fine.

o reshaped back to [B, q_len, -1] before projection is consistent with dual-length flow.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
fla/layers/nsa.py (1)

98-103: Bug: using Python max() on a torch.Tensor; breaks with attention_mask (vector offsets).

seqlen_offset becomes a tensor when attention_mask is provided; Python max(seqlen_offset) fails. Compute lens once, use tensor.max(), and cast to int.

Apply:

-            if attention_mask is not None:
-                # to deliminate the offsets of padding tokens
-                seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1]
-                max_seqlen = q.shape[1] + max(seqlen_offset)
+            if attention_mask is not None:
+                # delineate offsets of padding tokens
+                lens = prepare_lens_from_mask(attention_mask)
+                seqlen_offset = seqlen_offset + lens - attention_mask.shape[-1]
+                max_seqlen = q.shape[1] + int(lens.max().item())
fla/ops/nsa/compression.py (1)

74-88: Clamp NC to non-negative
The raw computation (i_t + Q_OFFSET + 1) // BS can produce negative NC when TK < TQ, leading to invalid loop bounds and NaNs; clamp NC ≥ 0.

-    NC = (i_t + Q_OFFSET + 1) // BS
+    n_raw = (i_t + Q_OFFSET + 1) // BS
+    NC = tl.where(n_raw > 0, n_raw, 0)
fla/ops/nsa/parallel.py (3)

842-858: Fix backward: gradient tuple length mismatch (11 returned, 9 expected).

ParallelNSAFunction.backward must return exactly one grad per forward input (9 total). Returning 11 will raise at runtime.

Apply this diff:

-        return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
+        # grads for: q, k, v, block_indices, block_counts, block_size, scale, cu_seqlens
+        return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None

316-320: Avoid divide-by-zero when no blocks contribute.

If no valid KV blocks pass the causal checks, b_acc can be 0, causing NaNs/Inf.

-    b_o = b_o / b_acc[:, None]
-    b_m += log(b_acc)
+    eps = 1e-20
+    b_o = b_o / tl.maximum(b_acc, eps)[:, None]
+    b_m += log(tl.maximum(b_acc, eps))

961-971: Use per-sequence max lengths for varlen FlashAttention
The current call passes max_seqlen_q=q.shape[1] and max_seqlen_k=k.shape[1] (total tokens), but FlashAttention expects the maximum individual sequence length from cu_seqlens. Update to:

         if cu_seqlens is not None:
-            o_swa = flash_attn_varlen_func(
-                q.squeeze(0), k.squeeze(0), v.squeeze(0),
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=q.shape[1],
-                max_seqlen_k=k.shape[1],
-                causal=True,
-                window_size=(window_size-1, 0)
-            ).unsqueeze(0)
+            max_seqlen_q = prepare_lens(cu_seqlens_q).max().item()
+            max_seqlen_k = prepare_lens(cu_seqlens_k).max().item()
+            o_swa = flash_attn_varlen_func(
+                q.squeeze(0), k.squeeze(0), v.squeeze(0),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                causal=True,
+                window_size=(window_size-1, 0)
+            ).unsqueeze(0)
🧹 Nitpick comments (12)
fla/layers/utils.py (1)

162-169: Clarify q_len handling and error message.

The NotImplementedError assumes only q_len == seq_len or q_len == 1. Consider including both values in the message for faster diagnosis and add a short comment explaining the decode assumption.

Apply:

-    else:
-        raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")
+    else:
+        # Decode path assumes one query token per sequence; prefill assumes q_len == seq_len.
+        raise NotImplementedError(
+            f"Unsupported q_len={q_len}; expected q_len == seq_len ({seq_len}) or q_len == 1."
+        )
fla/ops/nsa/naive.py (3)

14-21: warnings.warn without stacklevel; clearer ImportWarning.

Add stacklevel=2 so warnings point to callers, and include the install hint once.

Apply:

-    warnings.warn(
-        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
-        category=ImportWarning
-    )
+    warnings.warn(
+        "Flash Attention is not installed. Install with: pip install flash-attn --no-build-isolation",
+        category=ImportWarning,
+        stacklevel=2,
+    )

65-69: Do not raise DeprecationWarning as an exception.

Prefer a deprecation warning and continue (or remove the flag entirely).

Apply:

-    if head_first:
-        raise DeprecationWarning(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
+    if head_first:
+        warnings.warn(
+            "head_first is deprecated and will be removed in a future version; use head_first=False.",
+            category=DeprecationWarning,
+            stacklevel=2,
+        )

123-129: Causal mask under-selects when TQ != TK (decode).

Masking uses i_i > i_q assuming aligned Q/K lengths. For cached decode (queries are the last TQ tokens), your own assertion forbids TQ != TK, but if you lift it later this mask must include an offset (TK - TQ). Add a TODO to keep naive path consistent with parallel_nsa.

Proposed (guarded by an offset):

-            attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill(
-                torch.logical_or(i_i > i_q, i_i < 0), float('-inf')).softmax(0)
+            # If TQ != TK (cached decode), allow keys up to i_q + (TK - TQ).
+            causal_limit = i_q + max(Tk - Tq, 0)
+            attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill(
+                torch.logical_or(i_i > causal_limit, i_i < 0), float('-inf')
+            ).softmax(0)
tests/ops/test_nsa.py (3)

11-19: Remove unused noqa flags; add CUDA-only guard for NSA tests.

Ruff flags RUF100 for unused E402 suppressions. Also, these tests require CUDA kernels—skip module on CPU.

Apply:

-import torch  # noqa: E402
-import triton  # noqa: E402
+import torch
+import triton
@@
-from fla.ops.nsa.compression import parallel_nsa_compression  # noqa: E402
-from fla.ops.nsa.naive import naive_nsa, naive_nsa_cmp, naive_nsa_sel, naive_nsa_topk  # noqa: E402
-from fla.ops.nsa.parallel import parallel_nsa, parallel_nsa_fwd, parallel_nsa_topk  # noqa: E402
-from fla.ops.utils import prepare_chunk_offsets, prepare_token_indices  # noqa: E402
-from fla.ops.utils.pooling import mean_pooling  # noqa: E402
-from fla.utils import assert_close, device  # noqa: E402
+from fla.ops.nsa.compression import parallel_nsa_compression
+from fla.ops.nsa.naive import naive_nsa, naive_nsa_cmp, naive_nsa_sel, naive_nsa_topk
+from fla.ops.nsa.parallel import parallel_nsa, parallel_nsa_fwd, parallel_nsa_topk
+from fla.ops.utils import prepare_chunk_offsets, prepare_token_indices
+from fla.ops.utils.pooling import mean_pooling
+from fla.utils import assert_close, device
+
+# Skip on CPU-only runners
+pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for NSA kernels")

431-433: warnings.warn without stacklevel in tests.

Set stacklevel=2 so failure messages point at the test call sites.

Apply:

-        warnings.warn(f"Block indices mismatch: {len(indices)}/{block_indices.numel()} "
-                      f"({len(indices) / free_block_indices_naive.numel():.2f}), seemingly due to numerical issues.")
+        warnings.warn(
+            f"Block indices mismatch: {len(indices)}/{block_indices.numel()} "
+            f"({len(indices) / free_block_indices_naive.numel():.2f}), seemingly due to numerical issues.",
+            stacklevel=2,
+        )

Also applies to: 824-825


598-599: Tiny style: build cu_seqlens_q lists idiomatically.

Use [0, *q_lens] instead of concatenation before torch.cumsum.

Apply:

-    cu_seqlens_q = torch.cumsum(torch.tensor([0] + q_lens), dim=0).to(device)
+    cu_seqlens_q = torch.cumsum(torch.tensor([0, *q_lens]), dim=0).to(device)

(Repeat similarly for other occurrences.)

Also applies to: 697-698, 828-829, 933-934

fla/ops/nsa/compression.py (1)

546-548: Type hint cleanup (RUF013).

Optional: prefer Union[X, Y] | None style or explicit Optional[...] consistently across the repo.

fla/ops/nsa/parallel.py (4)

560-568: Make parallel_nsa_topk robust when scale=None and clean up typing.

  • Provide a default scale like other paths.
  • Address implicit Optional typing (RUF013).
-def parallel_nsa_topk(
+def parallel_nsa_topk(
     q: torch.Tensor,
     k: torch.Tensor,
     TK: int,
-    lse: Optional[torch.Tensor],
+    lse: Optional[torch.Tensor],
     block_counts: Union[torch.LongTensor, int],
     block_size: int = 64,
-    scale: float = None,
-    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
+    scale: float | None = None,
+    cu_seqlens: torch.LongTensor | Tuple[torch.LongTensor, torch.LongTensor] | None = None,
 ) -> torch.LongTensor:
@@
-    BK = max(triton.next_power_of_2(K), 16)
+    BK = max(triton.next_power_of_2(K), 16)
+    if scale is None:
+        scale = K ** -0.5

952-955: Clarify override warning and add stacklevel=2.

Message is inverted and static analysis flags missing stacklevel.

-        else:
-            warnings.warn("`block_indices` computed from compression is overridden")
+        else:
+            warnings.warn(
+                "Provided `block_indices` override internally computed indices; skipping NSA top-k.",
+                stacklevel=2
+            )

912-915: Docstring: fix output shape to [B, TQ, HQ, V].

Current text says [B, T, HQ, V].

-            Outputs of shape `[B, T, HQ, V]`.
+            Outputs of shape `[B, TQ, HQ, V]`.

673-678: Type hint: cu_seqlens can be None.

Signature should reflect actual usage below.

-def parallel_nsa_block_mask(
+def parallel_nsa_block_mask(
     block_indices: torch.LongTensor,
     block_counts: Union[torch.LongTensor, int],
-    cu_seqlens: torch.LongTensor,
+    cu_seqlens: torch.LongTensor | None,
     block_size: int,
 ):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 18017f6 and bb2e32e.

📒 Files selected for processing (6)
  • fla/layers/nsa.py (3 hunks)
  • fla/layers/utils.py (5 hunks)
  • fla/ops/nsa/compression.py (14 hunks)
  • fla/ops/nsa/naive.py (2 hunks)
  • fla/ops/nsa/parallel.py (20 hunks)
  • tests/ops/test_nsa.py (7 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
PR: fla-org/flash-linear-attention#544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.

Applied to files:

  • tests/ops/test_nsa.py
🧬 Code graph analysis (5)
fla/layers/nsa.py (2)
fla/layers/utils.py (2)
  • pad_input (185-206)
  • unpad_input (101-182)
fla/ops/nsa/parallel.py (1)
  • parallel_nsa (862-978)
tests/ops/test_nsa.py (6)
fla/ops/nsa/compression.py (2)
  • parallel_nsa_compression (540-559)
  • backward (523-537)
fla/ops/nsa/naive.py (4)
  • naive_nsa (269-386)
  • naive_nsa_cmp (133-166)
  • naive_nsa_sel (24-130)
  • naive_nsa_topk (169-266)
fla/ops/nsa/parallel.py (4)
  • parallel_nsa (862-978)
  • parallel_nsa_fwd (616-670)
  • parallel_nsa_topk (553-612)
  • backward (842-858)
fla/ops/utils/index.py (2)
  • prepare_chunk_offsets (109-113)
  • prepare_token_indices (94-96)
fla/ops/utils/pooling.py (2)
  • mean_pooling (190-207)
  • backward (179-187)
fla/utils.py (1)
  • assert_close (77-88)
fla/ops/nsa/parallel.py (3)
fla/ops/utils/index.py (2)
  • prepare_token_indices (94-96)
  • prepare_chunk_offsets (109-113)
fla/ops/utils/pooling.py (1)
  • mean_pooling (190-207)
fla/ops/nsa/compression.py (1)
  • parallel_nsa_compression (540-559)
fla/ops/nsa/compression.py (1)
fla/ops/utils/index.py (2)
  • prepare_chunk_offsets (109-113)
  • prepare_token_indices (94-96)
fla/ops/nsa/naive.py (2)
fla/ops/utils/index.py (2)
  • prepare_chunk_offsets (109-113)
  • prepare_token_indices (94-96)
fla/ops/utils/pooling.py (1)
  • mean_pooling (190-207)
🪛 Ruff (0.12.2)
tests/ops/test_nsa.py

11-11: Unused noqa directive (non-enabled: E402)

Remove unused noqa directive

(RUF100)


12-12: Unused noqa directive (non-enabled: E402)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: E402)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: E402)

Remove unused noqa directive

(RUF100)


16-16: Unused noqa directive (non-enabled: E402)

Remove unused noqa directive

(RUF100)


17-17: Unused noqa directive (non-enabled: E402)

Remove unused noqa directive

(RUF100)


18-18: Unused noqa directive (non-enabled: E402)

Remove unused noqa directive

(RUF100)


19-19: Unused noqa directive (non-enabled: E402)

Remove unused noqa directive

(RUF100)


410-410: Use of assert detected

(S101)


431-431: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


447-447: Use of assert detected

(S101)


449-449: Use of assert detected

(S101)


598-598: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)


697-697: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)


797-797: Use of assert detected

(S101)


824-824: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


828-828: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)


848-848: Use of assert detected

(S101)


850-850: Use of assert detected

(S101)


933-933: Consider [0, *q_lens] instead of concatenation

Replace with [0, *q_lens]

(RUF005)

fla/ops/nsa/parallel.py

560-560: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


566-566: Use of assert detected

(S101)


567-567: Use of assert detected

(S101)


954-954: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

fla/ops/nsa/compression.py

546-546: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

fla/ops/nsa/naive.py

17-17: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


110-110: Use of assert detected

(S101)


142-142: Unused function argument: b

(ARG001)


142-142: Unused function argument: h

(ARG001)


146-146: Unused function argument: b

(ARG001)


146-146: Unused function argument: h

(ARG001)


152-152: Unused function argument: b

(ARG001)


152-152: Unused function argument: h

(ARG001)


195-195: Use of assert detected

(S101)


204-204: Use of assert detected

(S101)


238-238: Use of assert detected

(S101)


241-241: Use of assert detected

(S101)


244-244: Avoid specifying long messages outside the exception class

(TRY003)


324-324: Use of assert detected

(S101)


328-328: Use of assert detected

(S101)


329-329: Use of assert detected

(S101)


351-351: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (4)
fla/layers/nsa.py (1)

119-136: Varlen path: pass tuple cu_seqlens from unpad_input; good. Ensure rotary offsets match.

Looks correct: rotary is applied before unpadding; unpad_input returns (cu_seqlens_q, cu_seqlens_k) which flow into parallel_nsa. Verify RotaryEmbedding supports vector seqlen_offset (per-batch) when cu_seqlens is None; otherwise consider passing cu_seqlens there as well.

fla/ops/nsa/compression.py (2)

350-355: Consistency: varlen pooling should use cu_seqlens_k.

Good: mean_pooling(k, block_size, cu_seqlens_k) and mean_pooling(v, block_size, cu_seqlens_k) parallel the dual-length design.


492-500: Backward uses cu_seqlens_q only; confirm training-time invariants.

ctx.cu_seqlens and token_indices are taken from queries only, which is fine for inference (TQ != TK). Please confirm training path only uses TQ == TK; otherwise dkv gradients may need K-side cu_seqlens.

fla/ops/nsa/parallel.py (1)

79-90: Generated verification scripts to inspect the wrapper functions and kernel usage of Q_OFFSET; reviewing the outputs to verify whether the proposed host-side guards are needed.

Comment on lines +151 to 167
    if isinstance(q, torch.Tensor):
        q = (q,)
        cast_tuple = True
    else:
        cast_tuple = False

    if q_len == seq_len:
-        q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
+        q = tuple(index_first_axis(rearrange(q_, "b s ... -> (b s) ..."), indices_k) for q_ in q)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif q_len == 1:
        max_seqlen_in_batch_q = 1
-        cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
+        cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q[0].device)
        indices_q = cu_seqlens_q[:-1]
-        q = q.squeeze(1)
+        q = tuple(q_.squeeze(1) for q_ in q)
    else:

⚠️ Potential issue

Device mismatch risk when attention_mask is on CPU and q/states are on CUDA.

indices_k and cu_seqlens_k can end up on a different device than q/states, breaking downstream varlen paths. Move them to q’s device after normalizing q to a tuple.

Apply:

@@
-    if isinstance(q, torch.Tensor):
+    if isinstance(q, torch.Tensor):
         q = (q,)
         cast_tuple = True
     else:
         cast_tuple = False
+    # Ensure indexing tensors follow q's device
+    indices_k = indices_k.to(device=q[0].device)
+    cu_seqlens_k = cu_seqlens_k.to(device=q[0].device)
📝 Committable suggestion


Suggested change
    if isinstance(q, torch.Tensor):
        q = (q,)
        cast_tuple = True
    else:
        cast_tuple = False
    # Ensure indexing tensors follow q's device
    indices_k = indices_k.to(device=q[0].device)
    cu_seqlens_k = cu_seqlens_k.to(device=q[0].device)

    if q_len == seq_len:
        q = tuple(
            index_first_axis(
                rearrange(q_, "b s ... -> (b s) ..."),
                indices_k
            )
            for q_ in q
        )
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif q_len == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1,
            dtype=torch.int32,
            device=q[0].device
        )
        indices_q = cu_seqlens_q[:-1]
        q = tuple(q_.squeeze(1) for q_ in q)
    else:
        ...
🤖 Prompt for AI Agents
In fla/layers/utils.py around lines 151 to 167, after normalizing q into a tuple
(and determining cast_tuple), ensure indices_k and cu_seqlens_k are moved to q's
device to avoid CPU/CUDA mismatch when attention_mask lives on CPU;
specifically, right after q is converted to a tuple determine target_device =
q[0].device and call .to(target_device, non_blocking=True) on indices_k and
cu_seqlens_k (and any derived indices like indices_q) so downstream varlen code
uses the same device as q/states.

Comment on lines +106 to +114
            Tq = cu_q[i+1] - cu_q[i]
            Tk = cu_k[i+1] - cu_k[i]
            q_b, k_b, v_b, i_b = (q[0][cu_q[i]:cu_q[i+1]], k[0][cu_k[i]:cu_k[i+1]],
                                  v[0][cu_k[i]:cu_k[i+1]], block_indices[0][cu_q[i]:cu_q[i+1]])
        assert Tq == Tk, "TQ != TK case is not supported in naive_nsa_sel"
        i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
        # [T, S*BS, HQ]
-        i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
-        for i_q in range(T):
+        i_b = i_b.view(Tq, block_indices.shape[2], -1).transpose(1, 2)
+        for i_q in range(Tq):

⚠️ Potential issue

Tensors used where Python ints are required (varlen path).

Tq and Tk are torch tensors; they’re used in range() and view(), which require Python ints. This will error on CUDA tensors.

Apply:

-            Tq = cu_q[i+1] - cu_q[i]
-            Tk = cu_k[i+1] - cu_k[i]
+            Tq = int((cu_q[i + 1] - cu_q[i]).item())
+            Tk = int((cu_k[i + 1] - cu_k[i]).item())
@@
-        i_b = i_b.view(Tq, block_indices.shape[2], -1).transpose(1, 2)
+        i_b = i_b.view(Tq, block_indices.shape[2], -1).transpose(1, 2)
📝 Committable suggestion


Suggested change
            # compute lengths as Python ints, not tensor scalars
            Tq = int((cu_q[i + 1] - cu_q[i]).item())
            Tk = int((cu_k[i + 1] - cu_k[i]).item())
            q_b, k_b, v_b, i_b = (q[0][cu_q[i]:cu_q[i+1]], k[0][cu_k[i]:cu_k[i+1]],
                                  v[0][cu_k[i]:cu_k[i+1]], block_indices[0][cu_q[i]:cu_q[i+1]])
        assert Tq == Tk, "TQ != TK case is not supported in naive_nsa_sel"
        i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
        # [T, S*BS, HQ]
        i_b = i_b.view(Tq, block_indices.shape[2], -1).transpose(1, 2)
        for i_q in range(Tq):
🧰 Tools
🪛 Ruff (0.12.2)

110-110: Use of assert detected

(S101)

🤖 Prompt for AI Agents
fla/ops/nsa/naive.py lines 106-114: Tq and Tk are torch tensors but are being
used where Python ints are required (range() and view sizes), which will fail on
CUDA tensors; convert Tq and Tk to Python integers (e.g., Tq = int(Tq.item()) or
Tq = int(Tq.cpu())) before using them in range() and view(), and ensure any size
arguments passed to view/reshape are Python ints; also replace
i_b.new_tensor(range(BS)) with a device/dtype-safe arange (e.g.,
torch.arange(BS, device=i_b.device)) so the arithmetic and broadcasting use
matching devices and dtypes.
