 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

 import warnings
-from typing import Optional, Union, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import triton
@@ -130,8 +130,8 @@ def parallel_nsa_kernel_topk(
     o_i = tl.zeros([BC], dtype=tl.int32)
     m_i = tl.arange(0, BC) < BC // 2

-    IC = (i_t + Q_OFFSET) // BS # Idx of the current query block
-    for i_c in range(0, IC + 1, BC): # +1, because the current block might be also included
+    IC = (i_t + Q_OFFSET) // BS  # Idx of the current query block
+    for i_c in range(0, IC + 1, BC):  # +1, because the current block might be also included
         o_c = i_c + tl.arange(0, BC)
         # Recall k: [B, TC, H, K], boc = i_b * TC
         # we first shift to k[i_b, 0, i_h], and read a block of transposed keys from k[i_b, i_c, i_h]
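For orientation, below is a minimal PyTorch reference of what the top-k selection kernel touched in this hunk computes: each query is scored against the compressed (one-per-block) keys, non-causal blocks are masked, and the `S` best block indices are kept per query and KV head. All names are illustrative; GQA group heads are pooled by a plain sum, and the real kernel's forced inclusion of the current/initial blocks and its tie handling are omitted.

import torch

def topk_blocks_ref(q, k_cmp, S, BS, q_offset=0, scale=None):
    # q:     [B, TQ, HQ, K]  query tokens (assumed to be the last TQ tokens of the sequence)
    # k_cmp: [B, TC, H, K]   one compressed key per KV block (TC = number of blocks)
    B, TQ, HQ, K = q.shape
    _, TC, H, _ = k_cmp.shape
    G = HQ // H                                        # GQA group size; selection is shared per group
    scale = scale if scale is not None else K ** -0.5
    # pool the query heads of each group, then score against every compressed block key
    q_grp = q.view(B, TQ, H, G, K).sum(3) * scale      # [B, TQ, H, K]
    s = torch.einsum('bqhk,bchk->bqhc', q_grp, k_cmp)  # [B, TQ, H, TC]
    # block-level causal mask: the query at absolute position t may only see blocks <= t // BS
    t = torch.arange(TQ, device=q.device) + q_offset   # absolute positions, cf. Q_OFFSET above
    c = torch.arange(TC, device=q.device)
    s = s.masked_fill(c.view(1, 1, 1, TC) > (t // BS).view(1, TQ, 1, 1), float('-inf'))
    # indices of the S highest-scoring blocks; picks falling in the masked (-inf) region are
    # non-causal and get dropped by the forward kernel's causal check anyway
    return s.topk(min(S, TC), dim=-1).indices          # [B, TQ, H, min(S, TC)]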
@@ -207,7 +207,7 @@ def parallel_nsa_fwd_kernel(
     IS_VARLEN: tl.constexpr,
     USE_BLOCK_COUNTS: tl.constexpr
 ):
-    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) # i_t: token, i_v: value dim, i_bh: batch * kv head
+    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)  # i_t: token, i_v: value dim, i_bh: batch * kv head
     i_b, i_h = i_bh // H, i_bh % H
     # k: [B, TK, H, K], v: [B, TK, H, V], q: [B, TQ, HQ, K]
     # block_indices: [B, TQ, H, S]
@@ -259,7 +259,7 @@ def parallel_nsa_fwd_kernel(
     # p_q then reads the BK dimensions at the last dimension
     # the Q block is kept in the shared memory throughout the whole kernel
     # [G, BK]
-    b_q = tl.load(p_q, boundary_check=(0, 1)) # note that BK >= K, but there is boundary check
+    b_q = tl.load(p_q, boundary_check=(0, 1))  # note that BK >= K, but there is boundary check
     b_q = (b_q * scale).to(b_q.dtype)

     p_o = tl.make_block_ptr(
@@ -275,10 +275,10 @@ def parallel_nsa_fwd_kernel(
     # [G, BV]
     b_o = tl.zeros([G, BV], dtype=tl.float32)

-    b_m = tl.full([G], float('-inf'), dtype=tl.float32) # running maximum
-    b_acc = tl.zeros([G], dtype=tl.float32) # sumexp
-    for i in range(NS): # number of blocks
-        i_s = tl.load(block_indices + i).to(tl.int32) * BS # i_s is the start token index of the current KV block
+    b_m = tl.full([G], float('-inf'), dtype=tl.float32)  # running maximum
+    b_acc = tl.zeros([G], dtype=tl.float32)  # sumexp
+    for i in range(NS):  # number of blocks
+        i_s = tl.load(block_indices + i).to(tl.int32) * BS  # i_s is the start token index of the current KV block
         # Here we assume that q tokens are last TQ tokens
         if i_s <= Q_OFFSET + i_t and i_s >= 0:
             # Recall: k ([B, T, H, K]) already shifted to the start of the current sequence at head i_h, i.e. k[i_b, 0, i_h]
@@ -306,11 +306,10 @@ def parallel_nsa_fwd_kernel(
             # [G, BS]
             b_p = exp(b_s - b_m[:, None])
             # [G]
-            b_acc = b_acc * b_r + tl.sum(b_p, 1) # summed over T dimension
+            b_acc = b_acc * b_r + tl.sum(b_p, 1)  # summed over T dimension
             # [G, BV]; note that b_p is fp32, while b_q may not
             b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

-
     # o = o_n / a_n
     # lse = log( exp(m_n) * a_n )

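The rescaling above is the standard online-softmax recurrence. As a reading aid, here is a minimal PyTorch sketch for a single (batch, KV head, query token), with names mirroring the kernel's `b_m` (running maximum), `b_acc` (sum of exponentials) and `b_o` (output accumulator); it is not the kernel itself, just the math it streams over the selected KV blocks.

import torch

def online_softmax_attend(q, kv_blocks, scale):
    # q: [G, K]; kv_blocks: list of (k_blk [BS, K], v_blk [BS, V]) pairs, one per selected block
    G = q.shape[0]
    V = kv_blocks[0][1].shape[-1]
    m = torch.full((G,), float('-inf'))      # running maximum of scores          (b_m)
    a = torch.zeros(G)                       # running sum of exp(score - m)      (b_acc)
    o = torch.zeros(G, V)                    # unnormalized output accumulator    (b_o)
    for k_blk, v_blk in kv_blocks:
        s = (q * scale) @ k_blk.T            # [G, BS] scores for this block
        m_new = torch.maximum(m, s.max(-1).values)
        r = torch.exp(m - m_new)             # rescale factor for old accumulators (b_r)
        p = torch.exp(s - m_new[:, None])    # [G, BS]                             (b_p)
        a = a * r + p.sum(-1)
        o = o * r[:, None] + p @ v_blk
        m = m_new
    o = o / a[:, None]                       # o = o_n / a_n
    lse = m + torch.log(a)                   # lse = log(exp(m_n) * a_n)
    return o, lse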
@@ -319,6 +318,7 @@ def parallel_nsa_fwd_kernel(
     tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
     tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))

+
 @triton.heuristics({
     'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
 })
@@ -548,6 +548,7 @@ def parallel_nsa_bwd_kernel_dkv(
     tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
     tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

+
 @contiguous
 def parallel_nsa_topk(
     q: torch.Tensor,
@@ -557,7 +558,7 @@ def parallel_nsa_topk(
     block_counts: Union[torch.LongTensor, int],
     block_size: int = 64,
     scale: float = None,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
 ) -> torch.LongTensor:
     B, TQ, HQ, K = q.shape
     _, TC, H, _ = k.shape
@@ -610,6 +611,7 @@ def parallel_nsa_topk(
     )
     return block_indices

+
 @contiguous
 def parallel_nsa_fwd(
     q: torch.Tensor,
@@ -655,7 +657,7 @@ def parallel_nsa_fwd(
         token_indices_q=token_indices_q,
         TQ=T_q,
         TK=T_kv,
-        H = H,
+        H=H,
         HQ=HQ,
         G=G,
         K=K,
@@ -855,6 +857,7 @@ def backward(ctx, do):
         )
         return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None

+
 @contiguous
 def parallel_nsa(
     q: torch.Tensor,
@@ -868,7 +871,7 @@ def parallel_nsa(
     block_size: int = 64,
     window_size: int = 0,
     scale: Optional[float] = None,
-    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor]] = None,
+    cu_seqlens: Union[None, torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]] = None,
 ) -> torch.Tensor:
     r"""
     Args:
@@ -888,7 +891,7 @@ def parallel_nsa(
         block_indices (torch.LongTensor):
             Block indices of shape `[B, TQ, H, S]`.
             `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
-            If `g_cmp` is provided, the passed `block_indices` will be ignored.
+            Will override the computed block indices from compression if provided.
         block_counts (Optional[Union[torch.LongTensor, int]]):
             Number of selected blocks for each query.
             If a tensor is provided, with shape `[B, TQ, H]`,
@@ -901,9 +904,10 @@ def parallel_nsa(
         scale (Optional[float]):
             Scale factor for attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor] or None):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
+            When a tuple is provided, it should contain two tensors: `(cu_seqlens_q, cu_seqlens_k)`.

     Returns:
         o (torch.Tensor):
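Because the docstring is truncated here, a hypothetical varlen call may help illustrate the two accepted `cu_seqlens` forms. It assumes the keyword names documented above, the FlashAttention-style packed layout (batch dimension of 1 with sequences concatenated along the token axis), an illustrative import path, and random `block_indices`; any gating or other required arguments not shown in this diff are omitted.

import torch
from fla.ops.nsa import parallel_nsa  # illustrative import path; adjust to this repo's layout

B, HQ, H, K, V, S, BS = 1, 16, 4, 64, 64, 16, 64
cu_seqlens = torch.tensor([0, 512, 1536], dtype=torch.long, device='cuda')  # two sequences: 512 and 1024 tokens
T = int(cu_seqlens[-1])

q = torch.randn(B, T, HQ, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
# normally produced by the compression + top-k step; random indices only for illustration
block_indices = torch.randint(0, T // BS, (B, T, H, S), device='cuda')

o = parallel_nsa(q, k, v, block_indices=block_indices, block_counts=S,
                 block_size=BS, cu_seqlens=cu_seqlens)

# When query and key sequences differ in length (e.g. the queries are only the suffix of a
# cached sequence), pass the new tuple form instead:
#   o = parallel_nsa(..., cu_seqlens=(cu_seqlens_q, cu_seqlens_k))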