
Commit bb2e32e

Fix lint, hints, and comments
1 parent ff69e59 commit bb2e32e

6 files changed: +199 -169 lines changed


fla/layers/nsa.py

Lines changed: 1 addition & 1 deletion
@@ -10,10 +10,10 @@
 from einops import rearrange
 from transformers.utils import logging

+from fla.layers.utils import pad_input, unpad_input
 from fla.modules import RotaryEmbedding
 from fla.ops.nsa.parallel import parallel_nsa
 from fla.ops.utils.index import prepare_lens_from_mask
-from fla.layers.utils import pad_input, unpad_input

 if TYPE_CHECKING:
     from fla.models.utils import Cache

fla/layers/utils.py

Lines changed: 3 additions & 3 deletions
@@ -99,7 +99,7 @@ def get_unpad_data(


 def unpad_input(
-    q: Union[torch.Tensor, Tuple[torch.Tensor]],
+    q: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
     states: Tuple[torch.Tensor],
     attention_mask: torch.Tensor,
     q_len: int,
@@ -133,11 +133,11 @@ def unpad_input(
             Shape: [1, total_source_length, ...] if `keepdim=True` else [total_source_length, ...].
         indices_q (`torch.Tensor`):
             The indices of non-masked tokens from the flattened input target sequence.
-        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
+        (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.LongTensor, torch.LongTensor]`):
             The cumulative sequence lengths for the target (query) and source (key, value),
             used to index into ragged (unpadded) tensors.
             `cu_seqlens` shape is [batch_size + 1].
-        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
+        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int, int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence
            i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """

fla/ops/nsa/compression.py

Lines changed: 4 additions & 5 deletions
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
 import triton
@@ -84,9 +84,8 @@ def parallel_nsa_compression_fwd_kernel(
     # lse = log(acc) + m
     b_acc = tl.zeros([G], dtype=tl.float32)

-
     for i_c in range(0, NC, BC):
-        o_c = i_c + tl.arange(0, BC) # block idx
+        o_c = i_c + tl.arange(0, BC)  # block idx

         p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
         p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c, i_v * BV), (BC, BV), (1, 0))
@@ -380,7 +379,6 @@ def parallel_nsa_compression_fwd(
     return o, lse


-
 def parallel_nsa_compression_bwd(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -512,6 +510,7 @@ def forward(
             token_indices_q=token_indices_q
         )
         ctx.save_for_backward(q, k, v, o, lse)
+        # Use cu_seqlens of q in backward, as cu_seqlens for q & k are different only for inference
         ctx.cu_seqlens = cu_seqlens_q
         ctx.token_indices = token_indices_q
         ctx.block_size = block_size
@@ -545,7 +544,7 @@ def parallel_nsa_compression(
     TK: int,
     block_size: int = 64,
     scale: float = None,
-    cu_seqlens: Optional[torch.LongTensor] = None
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None
 ):
     if scale is None:
         scale = k.shape[-1] ** -0.5
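The widened cu_seqlens hint and the new backward comment reflect that the query and key prefix sums coincide during training and diverge only at inference, where each sequence contributes a single new query token but attends over its full cached keys. A small illustration with made-up lengths (not taken from the repository):

import torch
import torch.nn.functional as F

seqlens_k = torch.tensor([7, 5], dtype=torch.int32)   # cached key/value tokens per sequence
seqlens_q = torch.tensor([1, 1], dtype=torch.int32)   # one new query token per sequence while decoding
cu_seqlens_k = F.pad(seqlens_k.cumsum(0), (1, 0))     # tensor([ 0,  7, 12])
cu_seqlens_q = F.pad(seqlens_q.cumsum(0), (1, 0))     # tensor([0, 1, 2])
cu_seqlens = (cu_seqlens_q, cu_seqlens_k)             # tuple form now accepted by the updated hints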

fla/ops/nsa/naive.py

Lines changed: 22 additions & 17 deletions
@@ -2,14 +2,14 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

 import warnings
-from typing import Optional, Union, Tuple
+from typing import Optional, Tuple, Union

 import torch
 from einops import repeat
-from torch.nn.attention.flex_attention import create_block_mask, and_masks
-from torch.nn.attention.flex_attention import flex_attention
+from torch.nn.attention.flex_attention import and_masks, create_block_mask, flex_attention
+
+from fla.ops.utils import prepare_chunk_offsets, prepare_token_indices
 from fla.ops.utils.pooling import mean_pooling
-from fla.ops.utils import prepare_token_indices, prepare_chunk_offsets

 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -20,14 +20,15 @@
     )
     flash_attn_func = None

+
 def naive_nsa_sel(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     block_indices: torch.LongTensor,
     block_size: int = 64,
     scale: Optional[float] = None,
-    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor]] = None,
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
     head_first: bool = False
 ) -> torch.Tensor:
     r"""
@@ -47,7 +48,7 @@ def naive_nsa_sel(
         scale (Optional[float]):
             Scale factor for attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
-        cu_seqlens (torch.LongTensor or Tuple[torch.LongTensor]):
+        cu_seqlens (torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor] or None):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
             When a tuple is provided, it should contain two tensors: `(cu_seqlens_q, cu_seqlens_k)`.
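A hypothetical varlen call using the tuple form, based only on the signature and docstring shown above; the tensor shapes follow the [batch, tokens, heads, dim] layout described there and are assumptions, not a test case from the repository:

import torch
from fla.ops.nsa.naive import naive_nsa_sel

HQ, H, K, V, S = 4, 1, 32, 32, 2
total = 48                                         # two packed sequences of 16 and 32 tokens
q = torch.randn(1, total, HQ, K)
k = torch.randn(1, total, H, K)
v = torch.randn(1, total, H, V)
block_indices = torch.zeros(1, total, H, S, dtype=torch.long)
cu_seqlens_q = torch.tensor([0, 16, 48], dtype=torch.long)
cu_seqlens_k = torch.tensor([0, 16, 48], dtype=torch.long)
o = naive_nsa_sel(q, k, v, block_indices, block_size=64,
                  cu_seqlens=(cu_seqlens_q, cu_seqlens_k))  # o keeps q's packed layout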
@@ -88,10 +89,10 @@ def naive_nsa_sel(
         Tq = Tk = q.shape[1]
         cu_q = torch.cat([
             block_indices.new_tensor(range(0, B * Tq, Tq)), block_indices.new_tensor([B * Tq])
-        ])
+        ]).to(device=q.device)
         cu_k = torch.cat([
             block_indices.new_tensor(range(0, B * Tk, Tk)), block_indices.new_tensor([B * Tk])
-        ])
+        ]).to(device=q.device)
     else:
         if isinstance(cu_seqlens, tuple):
             cu_q, cu_k = cu_seqlens
@@ -104,8 +105,9 @@ def naive_nsa_sel(
         else:
             Tq = cu_q[i+1] - cu_q[i]
             Tk = cu_k[i+1] - cu_k[i]
-            q_b, k_b, v_b, i_b = q[0][cu_q[i]:cu_q[i+1]], k[0][cu_k[i]:cu_k[i+1]], v[0][cu_k[i]:cu_k[i+1]], block_indices[0][cu_q[i]:cu_q[i+1]]
-
+            q_b, k_b, v_b, i_b = (q[0][cu_q[i]:cu_q[i+1]], k[0][cu_k[i]:cu_k[i+1]],
+                                  v[0][cu_k[i]:cu_k[i+1]], block_indices[0][cu_q[i]:cu_q[i+1]])
+        assert Tq == Tk, "TQ != TK case is not supported in naive_nsa_sel"
         i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
         # [T, S*BS, HQ]
         i_b = i_b.view(Tq, block_indices.shape[2], -1).transpose(1, 2)
@@ -115,16 +117,19 @@ def naive_nsa_sel(
             # [S*BS, HQ]
             i_i = i_b[i_q]
             # [S*BS, HQ, -1]
-            k_i, v_i = map(lambda x: x.gather(0, i_i.clamp(0, Tk-1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
+            k_i, v_i = map(lambda x: x.gather(0, i_i.clamp(0, Tk-1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])),
+                           (k_b, v_b))
             # [S*BS, HQ]
-            attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill(torch.logical_or(i_i > i_q, i_i < 0), float('-inf')).softmax(0)
+            attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill(
+                torch.logical_or(i_i > i_q, i_i < 0), float('-inf')).softmax(0)
             if not varlen:
                 o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
             else:
                 o[0][cu_q[i] + i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)

     return o.to(dtype)

+
 def naive_nsa_cmp(q, k_cmp, v_cmp, block_size, scale, cu_seqlens=None):
     if cu_seqlens is not None:
         seq_indices = prepare_token_indices(cu_seqlens)
@@ -167,7 +172,7 @@ def naive_nsa_topk(
     block_counts: Union[int, torch.Tensor],  # int or [B, T_q, Hkv]
     block_size: int,
     scale: float,
-    cu_seqlens: Union[None, torch.Tensor, Tuple[torch.Tensor]] = None,
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
 ) -> torch.Tensor:
     B, Tq, Hq, _ = q.shape
     Hkv = k_cmp.shape[2]
@@ -214,12 +219,12 @@ def naive_nsa_topk(
     t = torch.arange(Tq, device=device).unsqueeze(1)
     s = torch.arange(Tc, device=device).unsqueeze(0)
     block_last_pos = (s + 1) * block_size - 1
-    base_allow = (block_last_pos <= t) # [Tq,Tc]
+    base_allow = (block_last_pos <= t)  # [Tq,Tc]

     i_qb = (t // block_size)  # [Tq,1]
     is_current_block = (s == i_qb) | (s == 0) | (s == i_qb - 1)  # [Tq,Tc]
     logits = logits.masked_fill(~base_allow[:, None, None, :], float("-inf"))
-    allow = base_allow | is_current_block # [Tq,Tc]
+    allow = base_allow | is_current_block  # [Tq,Tc]

     probs_q = torch.softmax(logits, dim=-1)  # [Tq, Hkv, G, Tc]
     probs_q = torch.nan_to_num(probs_q, nan=0.0)  # rows with no valid blocks -> 0
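The masking above encodes the selection rule: a compressed block is a regular candidate only if it lies entirely in the causal past, while the block containing the query, the previous block, and the first block stay eligible regardless. A standalone toy run of the same logic (sizes are illustrative only):

import torch

Tq, block_size = 8, 4
Tc = (Tq + block_size - 1) // block_size            # number of compressed blocks
t = torch.arange(Tq).unsqueeze(1)                   # query positions, [Tq, 1]
s = torch.arange(Tc).unsqueeze(0)                   # block indices,   [1, Tc]
block_last_pos = (s + 1) * block_size - 1
base_allow = block_last_pos <= t                    # block ends at or before the query position
i_qb = t // block_size
is_current_block = (s == i_qb) | (s == 0) | (s == i_qb - 1)
allow = base_allow | is_current_block               # candidates eligible for top-k selection
print(allow.int())                                  # [Tq, Tc] 0/1 mask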
@@ -273,7 +278,7 @@ def naive_nsa(
     block_size: int = 64,
     window_size: int = 0,
     scale: Optional[float] = None,
-    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor]] = None,
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
     return_block_indices: bool = False,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.LongTensor]]:
     r"""
@@ -307,7 +312,7 @@ def naive_nsa(
         scale (Optional[float]):
             Scale factor for attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
-        cu_seqlens (torch.LongTensor or Tuple[torch.LongTensor]):
+        cu_seqlens (torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor] or None):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
             When a tuple is provided, it should contain two tensors: `(cu_seqlens_q, cu_seqlens_k)`.

fla/ops/nsa/parallel.py

Lines changed: 20 additions & 16 deletions
@@ -2,7 +2,7 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

 import warnings
-from typing import Optional, Union, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import triton
@@ -130,8 +130,8 @@ def parallel_nsa_kernel_topk(
     o_i = tl.zeros([BC], dtype=tl.int32)
     m_i = tl.arange(0, BC) < BC//2

-    IC = (i_t + Q_OFFSET) // BS # Idx of the current query block
-    for i_c in range(0, IC + 1, BC): # +1, because the current block might be also included
+    IC = (i_t + Q_OFFSET) // BS  # Idx of the current query block
+    for i_c in range(0, IC + 1, BC):  # +1, because the current block might be also included
         o_c = i_c + tl.arange(0, BC)
         # Recall k: [B, TC, H, K], boc = i_b * TC
         # we first shift to k[i_b, 0, i_h], and read a block of transposed keys from k[i_b, i_c, i_h]
@@ -207,7 +207,7 @@ def parallel_nsa_fwd_kernel(
     IS_VARLEN: tl.constexpr,
     USE_BLOCK_COUNTS: tl.constexpr
 ):
-    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) # i_t: token, i_v: value dim, i_bh: batch * kv head
+    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)  # i_t: token, i_v: value dim, i_bh: batch * kv head
     i_b, i_h = i_bh // H, i_bh % H
     # k: [B, TK, H, K], v: [B, TK, H, V], q: [B, TQ, HQ, K]
     # block_indices: [B, TQ, H, S]
@@ -259,7 +259,7 @@ def parallel_nsa_fwd_kernel(
     # p_q then reads the BK dimensions at the last dimension
     # the Q block is kept in the shared memory throughout the whole kernel
     # [G, BK]
-    b_q = tl.load(p_q, boundary_check=(0, 1)) # note that BK >= K, but there is boundary check
+    b_q = tl.load(p_q, boundary_check=(0, 1))  # note that BK >= K, but there is boundary check
     b_q = (b_q * scale).to(b_q.dtype)

     p_o = tl.make_block_ptr(
@@ -275,10 +275,10 @@ def parallel_nsa_fwd_kernel(
     # [G, BV]
     b_o = tl.zeros([G, BV], dtype=tl.float32)

-    b_m = tl.full([G], float('-inf'), dtype=tl.float32) # running maximum
-    b_acc = tl.zeros([G], dtype=tl.float32) # sumexp
-    for i in range(NS): # number of blocks
-        i_s = tl.load(block_indices + i).to(tl.int32) * BS # i_s is the start token index of the current KV block
+    b_m = tl.full([G], float('-inf'), dtype=tl.float32)  # running maximum
+    b_acc = tl.zeros([G], dtype=tl.float32)  # sumexp
+    for i in range(NS):  # number of blocks
+        i_s = tl.load(block_indices + i).to(tl.int32) * BS  # i_s is the start token index of the current KV block
         # Here we assume that q tokens are last TQ tokens
         if i_s <= Q_OFFSET + i_t and i_s >= 0:
             # Recall: k ([B, T, H, K]) already shifted to the start of the current sequence at head i_h, i.e. k[i_b, 0, i_h]
@@ -306,11 +306,10 @@ def parallel_nsa_fwd_kernel(
             # [G, BS]
             b_p = exp(b_s - b_m[:, None])
             # [G]
-            b_acc = b_acc * b_r + tl.sum(b_p, 1) # summed over T dimension
+            b_acc = b_acc * b_r + tl.sum(b_p, 1)  # summed over T dimension
             # [G, BV]; note that b_p is fp32, while b_q may not
             b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

-
     # o = o_n / a_n
     # lse = log( exp(m_n) * a_n )

@@ -319,6 +318,7 @@
     tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
     tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))

+
 @triton.heuristics({
     'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
 })
@@ -548,6 +548,7 @@ def parallel_nsa_bwd_kernel_dkv(
     tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
     tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

+
 @contiguous
 def parallel_nsa_topk(
     q: torch.Tensor,
@@ -557,7 +558,7 @@
     block_counts: Union[torch.LongTensor, int],
     block_size: int = 64,
     scale: float = None,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
 ) -> torch.LongTensor:
     B, TQ, HQ, K = q.shape
     _, TC, H, _ = k.shape
@@ -610,6 +611,7 @@
     )
     return block_indices

+
 @contiguous
 def parallel_nsa_fwd(
     q: torch.Tensor,
@@ -655,7 +657,7 @@ def parallel_nsa_fwd(
         token_indices_q=token_indices_q,
         TQ=T_q,
         TK=T_kv,
-        H=H,
+        H=H,
         HQ=HQ,
         G=G,
         K=K,
@@ -855,6 +857,7 @@ def backward(ctx, do):
     )
     return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None

+
 @contiguous
 def parallel_nsa(
     q: torch.Tensor,
@@ -868,7 +871,7 @@
     block_size: int = 64,
     window_size: int = 0,
     scale: Optional[float] = None,
-    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor]] = None,
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
 ) -> torch.Tensor:
     r"""
     Args:
@@ -888,7 +891,7 @@
         block_indices (torch.LongTensor):
             Block indices of shape `[B, TQ, H, S]`.
             `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
-            If `g_cmp` is provided, the passed `block_indices` will be ignored.
+            Will override the computed block indices from compression if provided.
         block_counts (Optional[Union[torch.LongTensor, int]]):
             Number of selected blocks for each query.
             If a tensor is provided, with shape `[B, TQ, H]`,
@@ -901,9 +904,10 @@
         scale (Optional[float]):
             Scale factor for attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor] or None):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
+            When a tuple is provided, it should contain two tensors: `(cu_seqlens_q, cu_seqlens_k)`.

     Returns:
         o (torch.Tensor):
