22 commits
4b42665
Fixing block_indices generation: for the first token in a block, its …
mutiann Aug 19, 2025
9fd04a2
Adapt selective attention branch (parallel_nsa_fwd) to the cached inf…
mutiann Aug 19, 2025
17f7f7c
Adapt compressive attention branch (parallel_nsa_compression_fwd) to …
mutiann Aug 19, 2025
b81b0b8
Adapt top-k selection (parallel_nsa_topk) to the cached inference sce…
mutiann Aug 20, 2025
45a9444
Add tests on the complete NSA forward function
mutiann Aug 20, 2025
d32f335
Merge remote-tracking branch 'origin/main'
mutiann Aug 20, 2025
9c52df6
Merge remote-tracking branch 'origin/main'
mutiann Aug 22, 2025
7856f48
Fix NSA layer forward for inference
mutiann Aug 22, 2025
52a6632
Merge branch 'fla-org:main' into main
mutiann Aug 22, 2025
fd32c4c
Fix varlen mode for NSA selective attention; add tests
mutiann Aug 25, 2025
2dafac7
Cleanup test code
mutiann Aug 25, 2025
2aa3851
Fix backward pass of compressive NSA & add tests
mutiann Aug 28, 2025
ebec03a
Fix varlen mode for NSA compressive attention; add tests
mutiann Aug 28, 2025
ccbd727
Fix varlen mode for top-k kernel & naive impl.; add tests
mutiann Aug 29, 2025
72c7114
Clean up contiguity guards
mutiann Aug 29, 2025
9ad97b7
Fix block indices generation to match kernel outputs & fix boundary c…
mutiann Aug 31, 2025
79655c7
Impl. full naive path; add tests for full forward & backward func, va…
mutiann Aug 31, 2025
002987f
Merge branch 'fla-org:main' into main
mutiann Aug 31, 2025
368d47e
Merge remote-tracking branch 'origin/main'
mutiann Aug 31, 2025
31e1806
Skip redundant tests on each individual op given tests on the whole N…
mutiann Sep 1, 2025
ff69e59
Varlen mode for NSA layer
mutiann Sep 1, 2025
bb2e32e
Fix lint, hints, and comments
mutiann Sep 1, 2025
55 changes: 37 additions & 18 deletions fla/layers/nsa.py
@@ -10,6 +10,7 @@
from einops import rearrange
from transformers.utils import logging

from fla.layers.utils import pad_input, unpad_input
from fla.modules import RotaryEmbedding
from fla.ops.nsa.parallel import parallel_nsa
from fla.ops.utils.index import prepare_lens_from_mask
@@ -80,17 +81,16 @@ def forward(
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
)

batch_size, seq_len, _ = hidden_states.size()
batch_size, q_len, _ = hidden_states.size()

q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)

cu_seqlens = kwargs.get('cu_seqlens', None)

seqlen_offset, max_seqlen = 0, seq_len
seqlen_offset, max_seqlen = 0, q_len
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
max_seqlen = q.shape[1] + seqlen_offset
@@ -109,27 +109,46 @@
k_cached, v_cached = past_key_values.update(
attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
layer_idx=self.layer_idx,
offset=seq_len,
cache_kwargs=dict(window_size=self.window_size)
offset=q_len,
)['attn_state']
if cache_has_content:
k, v = k_cached, v_cached
k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

o = parallel_nsa(
q=q,
k=k,
v=v,
g_cmp=g_cmp,
g_slc=g_slc,
g_swa=g_swa,
block_size=self.block_size,
block_counts=self.block_counts,
window_size=self.window_size,
cu_seqlens=cu_seqlens,
)
o = o.reshape(batch_size, seq_len, -1)
if attention_mask is not None:
(q, g), (k, v), indices_q, cu_seqlens, max_seq_lens = unpad_input(
(q, g), (k, v), attention_mask, q_len, keepdim=True)
g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
o = parallel_nsa(
q=q,
k=k,
v=v,
g_cmp=g_cmp,
g_slc=g_slc,
g_swa=g_swa,
block_size=self.block_size,
block_counts=self.block_counts,
window_size=self.window_size,
cu_seqlens=cu_seqlens,
).squeeze(0)
o = pad_input(o, indices_q, batch_size, q_len)
else:
g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
o = parallel_nsa(
q=q,
k=k,
v=v,
g_cmp=g_cmp,
g_slc=g_slc,
g_swa=g_swa,
block_size=self.block_size,
block_counts=self.block_counts,
window_size=self.window_size,
cu_seqlens=cu_seqlens,
)

o = o.reshape(batch_size, q_len, -1)
o = self.o_proj(o)

if not output_attentions:
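The rewritten `forward` above dispatches on `attention_mask`: when padding is present, it unpads `q`/`g` and `k`/`v` into a single packed sequence described by `cu_seqlens`, runs `parallel_nsa` in varlen mode, and scatters the result back to `[batch_size, q_len, ...]` with `pad_input`. The sketch below reproduces that round trip in plain PyTorch, with ordinary causal attention standing in for the NSA kernel; `toy_varlen_attention` and `forward_with_padding` are illustrative names and are not part of this PR.

import torch
import torch.nn.functional as F


def toy_varlen_attention(q, k, v, cu_seqlens):
    # Stand-in for the NSA kernel: plain causal attention applied to each
    # packed sequence. q, k, v: [1, total_tokens, num_heads, head_dim].
    out = torch.zeros_like(q)
    for i in range(cu_seqlens.numel() - 1):
        s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        qi, ki, vi = (t[0, s:e].transpose(0, 1) for t in (q, k, v))  # [heads, len, dim]
        oi = F.scaled_dot_product_attention(qi, ki, vi, is_causal=True)
        out[0, s:e] = oi.transpose(0, 1)
    return out


def forward_with_padding(q, k, v, attention_mask):
    # q, k, v: [batch, seq_len, heads, dim]; attention_mask: [batch, seq_len] of 0/1.
    indices = torch.nonzero(attention_mask.flatten().bool(), as_tuple=False).flatten()
    seqlens = attention_mask.sum(-1, dtype=torch.int32)
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))

    def pack(t):
        # "Unpad": flatten (batch, seq), keep valid tokens, add a dummy batch dim.
        return t.flatten(0, 1)[indices].unsqueeze(0)

    o = toy_varlen_attention(pack(q), pack(k), pack(v), cu_seqlens)
    # "Pad": scatter packed outputs back to the original [batch, seq_len, heads, dim].
    o_full = torch.zeros_like(q)
    o_full.flatten(0, 1)[indices] = o[0]
    return o_full


q, k, v = (torch.randn(2, 5, 4, 16) for _ in range(3))
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(forward_with_padding(q, k, v, mask).shape)  # torch.Size([2, 5, 4, 16])

In the layer itself the same roles are played by `unpad_input`, `parallel_nsa(..., cu_seqlens=cu_seqlens)`, and `pad_input`.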
29 changes: 19 additions & 10 deletions fla/layers/utils.py
@@ -3,7 +3,7 @@

# Code is adapted from flash-attn.bert_padding.py

from typing import Tuple
from typing import Tuple, Union

import torch
from einops import rearrange, repeat
@@ -99,7 +99,7 @@ def get_unpad_data(


def unpad_input(
q: torch.Tensor,
q: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
states: Tuple[torch.Tensor],
attention_mask: torch.Tensor,
q_len: int,
@@ -111,8 +111,9 @@


Arguments:
q (`torch.Tensor`):
q (`torch.Tensor` or `Tuple[torch.Tensor]`):
Query state with padding. Shape: [batch_size, q_len, ...].
When it is a tuple, do unpadding for each tensor in the tuple.
states (`Tuple[torch.Tensor]`):
Attention state with padding. Shape: [batch_size, seq_len, ...].
attention_mask (`torch.Tensor`):
@@ -123,19 +124,20 @@
Whether to keep the batch dimension. Default: `False`.

Return:
q (`torch.Tensor`):
q (`torch.Tensor` or `Tuple[torch.Tensor]`):
Query state without padding.
Shape: [1, total_target_length, ...] if `keepdim=True` else [total_target_length, ...].
When the `q` passed in is a tuple, return a tuple of such unpadded tensors.
states (`Tuple[torch.Tensor]`):
Attention state without padding.
Shape: [1, total_source_length, ...] if `keepdim=True` else [total_source_length, ...].
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
(cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.LongTensor, torch.LongTensor]`):
The cumulative sequence lengths for the target (query) and source (key, value),
used to index into ragged (unpadded) tensors.
`cu_seqlens` shape is [batch_size + 1].
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int, int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence
i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
@@ -146,23 +148,30 @@
index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k)
for s in states
)
if isinstance(q, torch.Tensor):
q = (q,)
cast_tuple = True
else:
cast_tuple = False

if q_len == seq_len:
q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
q = tuple(index_first_axis(rearrange(q_, "b s ... -> (b s) ..."), indices_k) for q_ in q)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif q_len == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q[0].device)
indices_q = cu_seqlens_q[:-1]
q = q.squeeze(1)
q = tuple(q_.squeeze(1) for q_ in q)
else:
Comment on lines +151 to 167

⚠️ Potential issue

Device mismatch risk when attention_mask is on CPU and q/states are on CUDA.

indices_k and cu_seqlens_k can end up on a different device than q/states, breaking downstream varlen paths. Move them to q’s device after normalizing q to a tuple.

Apply:

@@
-    if isinstance(q, torch.Tensor):
+    if isinstance(q, torch.Tensor):
         q = (q,)
         cast_tuple = True
     else:
         cast_tuple = False
+    # Ensure indexing tensors follow q's device
+    indices_k = indices_k.to(device=q[0].device)
+    cu_seqlens_k = cu_seqlens_k.to(device=q[0].device)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if isinstance(q, torch.Tensor):
q = (q,)
cast_tuple = True
else:
cast_tuple = False
if q_len == seq_len:
q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
q = tuple(index_first_axis(rearrange(q_, "b s ... -> (b s) ..."), indices_k) for q_ in q)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif q_len == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q[0].device)
indices_q = cu_seqlens_q[:-1]
q = q.squeeze(1)
q = tuple(q_.squeeze(1) for q_ in q)
else:
if isinstance(q, torch.Tensor):
q = (q,)
cast_tuple = True
else:
cast_tuple = False
# Ensure indexing tensors follow q's device
indices_k = indices_k.to(device=q[0].device)
cu_seqlens_k = cu_seqlens_k.to(device=q[0].device)
if q_len == seq_len:
q = tuple(
index_first_axis(
rearrange(q_, "b s ... -> (b s) ..."),
indices_k
)
for q_ in q
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif q_len == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1,
dtype=torch.int32,
device=q[0].device
)
indices_q = cu_seqlens_q[:-1]
q = tuple(q_.squeeze(1) for q_ in q)
else:
...
🤖 Prompt for AI Agents
In fla/layers/utils.py around lines 151 to 167, after normalizing q into a tuple
(and determining cast_tuple), ensure indices_k and cu_seqlens_k are moved to q's
device to avoid CPU/CUDA mismatch when attention_mask lives on CPU;
specifically, right after q is converted to a tuple determine target_device =
q[0].device and call .to(target_device, non_blocking=True) on indices_k and
cu_seqlens_k (and any derived indices like indices_q) so downstream varlen code
uses the same device as q/states.

raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")

if keepdim:
q = q.unsqueeze(0)
q = tuple(q_.unsqueeze(0) for q_ in q)
state = tuple(s.unsqueeze(0) for s in state)
if cast_tuple:
q = q[0]

return (
q,
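The updated `unpad_input` now accepts either a single query tensor or a tuple (here `q` together with the gate tensor `g`), normalizes both cases to a tuple, indexes each member, and unwraps again only when a bare tensor was passed in. Below is a minimal sketch of that normalize-process-unwrap idiom, folding in the device move proposed in the review above; `gather_tokens` is an illustrative helper, not the library's API.

from typing import Tuple, Union

import torch


def gather_tokens(
    x: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
    indices: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    # Select the rows `indices` from inputs of shape [batch, seq, ...],
    # flattened over (batch, seq). Accepts a single tensor or a tuple and
    # returns the same structure, mirroring how the patched unpad_input
    # treats its `q` argument.
    if isinstance(x, torch.Tensor):
        x = (x,)
        cast_tuple = True   # remember to unwrap on the way out
    else:
        cast_tuple = False
    # Keep indexing tensors on the same device as the data (the review's point).
    indices = indices.to(device=x[0].device)
    out = tuple(t.flatten(0, 1)[indices] for t in x)
    return out[0] if cast_tuple else out


q = torch.randn(2, 4, 8)
g = torch.randn(2, 4, 3)
idx = torch.tensor([0, 1, 4, 5, 6, 7])        # positions of non-padded tokens
q_packed, g_packed = gather_tokens((q, g), idx)
print(q_packed.shape, g_packed.shape)         # torch.Size([6, 8]) torch.Size([6, 3])
print(gather_tokens(q, idx).shape)            # bare tensor in, bare tensor out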