
This repository provides the official implementation of Flash Sparse Attention (FSA), which includes a novel kernel design that enables efficient Native Sparse Attention (NSA) across a wide range of popular LLMs on modern GPUs.
- $\texttt{[2025-09, upcoming]}$: Online profiling module, which seamlessly transitions between NSA and FSA, will be released soon.
- $\texttt{[2025-08]}$: Our arXiv paper is released.
- $\texttt{[2025-08]}$: Beta version of one-step decoding is released; check the code residing in fsa_preview.
- $\texttt{[2025-08]}$: Open sourced Flash-Sparse-Attention, offering an optimized implementation for NSA and broadening the applicability of this novel, natively trainable sparse attention technique.
For the NSA selected attention module, the major system bottleneck, NSA loops over query tokens in the outer loop and over KV blocks in the inner loop. To improve efficiency, NSA batches query heads that share the same key-value head so they can be computed together. However, when the GQA group size is not sufficiently large, the NSA selected attention kernel must pad data to satisfy hardware requirements on matrix-multiplication tile dimensions. Specifically, on NVIDIA GPUs, the warp-level matrix multiply-accumulate instructions require each dimension of a matrix tile executed on a warp to meet a minimum size (at least 8 on Hopper GPUs); in Triton, each dimension of a matrix tile executed on a thread block must additionally be at least 16.
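To make the padding overhead concrete, here is a small, illustrative calculation (not part of the FSA codebase) of tile utilization when the query heads of one GQA group are padded up to Triton's minimum tile size of 16:

```python
# Illustrative only: fraction of useful rows in a tile when the query heads of
# one GQA group are padded up to Triton's minimum tile dimension of 16.
TRITON_MIN_TILE = 16

for gqa_group_size in (1, 2, 4, 8, 16):
    padded = max(gqa_group_size, TRITON_MIN_TILE)
    utilization = gqa_group_size / padded
    print(f"GQA group size {gqa_group_size:2d}: padded to {padded}, "
          f"utilization {utilization:.0%}")
```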
In contrast, FSA exchanges the kernel loop order of the original NSA kernel design: FSA loops over KV blocks in the outer loop and over query tokens in the inner loop. To optimize performance, FSA decouples the computation into three major kernels: (i) the main kernel batches the query tokens that attend to the same KV block and stores the partial results to a buffer, (ii) the reduction kernel accumulates the attention results for each query token, and (iii) the online softmax kernel computes the online softmax statistics. The key insight behind this arrangement is to effectively reduce unnecessary memory accesses and computation on padded data, while avoiding atomic additions for accumulating attention results for each query token across KV blocks.
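The loop-order exchange can be sketched with a small, framework-free toy; the top-k table below is made up for illustration and does not reflect FSA's actual data layout:

```python
# Toy illustration of the loop-order exchange between NSA and the FSA main kernel.
topk_blocks = {        # query token -> indices of its selected (top-k) KV blocks
    0: [0, 2],
    1: [0, 1],
    2: [1, 2],
}

# NSA-style traversal: query tokens in the outer loop, selected KV blocks inside.
nsa_pairs = [(q, b) for q in topk_blocks for b in topk_blocks[q]]

# FSA-style traversal: KV blocks in the outer loop, query tokens batched inside.
queries_per_block = {}
for q, blocks in topk_blocks.items():
    for b in blocks:
        queries_per_block.setdefault(b, []).append(q)
fsa_pairs = [(q, b) for b in sorted(queries_per_block) for q in queries_per_block[b]]

# Both traversals cover exactly the same (query token, KV block) pairs; FSA's
# main kernel batches the queries that attend to one KV block and writes partial
# results to a buffer, the reduction kernel accumulates them per query, and the
# online softmax kernel maintains the softmax statistics.
assert sorted(nsa_pairs) == sorted(fsa_pairs)
print(queries_per_block)  # e.g. {0: [0, 1], 2: [0, 2], 1: [1, 2]}
```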
The concrete computation processes of NSA (left) and the FSA main kernel (right) can be compared visually as follows:
The speedup of FSA originates from significantly lowered kernel-level memory access volume and computation, observed under varied GQA group sizes and NSA hyperparameters (block size and top-k).
FSA provides an optimized kernel implementation for the NSA selected attention module. Without modifying the NSA algorithm, FSA provides an efficient Triton-based implementation for GQA group sizes smaller than 8, which are more common in state-of-the-art large language models (LLMs), on modern high-performance NVIDIA GPUs. For GQA group sizes larger than or equal to 8, FSA usually falls back to the original NSA implementation for better performance.
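As a quick, hypothetical illustration (the head counts below are made up and not tied to specific models), the GQA group size is the number of query heads per KV head, and it determines which path applies under the policy above:

```python
# Illustrative only: derive the GQA group size and report which path applies.
example_configs = {
    "config A": (32, 8),   # 4 query heads per KV head
    "config B": (28, 4),   # 7 query heads per KV head
    "config C": (64, 8),   # 8 query heads per KV head
}

for name, (num_q_heads, num_kv_heads) in example_configs.items():
    group_size = num_q_heads // num_kv_heads
    path = "FSA optimized kernels" if group_size < 8 else "fallback to original NSA"
    print(f"{name}: GQA group size {group_size} -> {path}")
```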
FSA is currently well tested with:
- NVIDIA Ampere or Hopper GPUs (e.g., A100 SXM, H20, H100 PCIe, H100 NVL, H100 SXM, H200 SXM);
- Data types of fp16 and bf16;
- The same head dimension (less than or equal to 256) across query, key, and value;
- Varied GQA group sizes, ranging from 1 to 16;
- Training and inference (prefill).
The following requirements should be satisfied:
- PyTorch >= 2.4
- Triton >= 3.0
- transformers >= 4.45.0
- datasets >= 3.3.0
- accelerate >= 1.9.0
- flash-attn == 2.6.3
You can install dependencies for FSA with:
pip install -r requirements.txt
We provide the FlashSparseAttention class, which can be used as in the following example:
```python
import torch
from fsa.module.fsa import FlashSparseAttention, RopeConfig

FSA = (
    FlashSparseAttention(
        hidden_size=4096,
        num_q_heads=4,
        num_kv_heads=4,
        head_dim=128,
        kernel_size=32,
        kernel_stride=16,
        block_size=64,
        topk=16,
        init_blocks=1,
        local_blocks=2,
        window_size=512,
        rope_config=RopeConfig(
            max_position_embeddings=131072,
            head_dim=128,
            rope_theta=500000,
            rope_scaling={
                "factor": 8.0,
                "high_freq_factor": 4.0,
                "low_freq_factor": 1.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            },
        ),
    )
    .cuda()
    .to(torch.bfloat16)
)

# random input
seqlens = torch.LongTensor([65536, 32768]).int().cuda()
cu_seqlens = torch.cat(
    [
        torch.zeros(1, dtype=torch.int32, device="cuda"),
        torch.cumsum(seqlens, dim=0),
    ],
    dim=0,
).to(torch.int32)

x = torch.randn(cu_seqlens[-1], 4096, device="cuda", dtype=torch.bfloat16)
y = FSA(x, cu_seqlens)

loss = (y * torch.randn_like(y)).sum(-1).mean()
loss.backward()
```
Under the hood, the FSATopkSparseAttention class is called, providing the optimized kernels that accelerate the NSA selected attention module.
Training with FSA can be easily achieved by replacing the attention module. The only things you need to handle are instantiating the FSA module and computing the cu_seqlens for FSA. We provide an example of how to integrate FSA into an LLM in SparseLlamaAttention.
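For the cu_seqlens part, a minimal sketch (mirroring the usage example above; the helper name build_cu_seqlens is ours, not part of the repository) could look like this:

```python
import torch

def build_cu_seqlens(seqlens_list, device="cuda"):
    """Cumulative sequence lengths [0, len_0, len_0 + len_1, ...] as int32."""
    seqlens = torch.tensor(seqlens_list, dtype=torch.int32, device=device)
    zero = torch.zeros(1, dtype=torch.int32, device=device)
    return torch.cat([zero, torch.cumsum(seqlens, dim=0).to(torch.int32)], dim=0)

# Two packed sequences of 65536 and 32768 tokens, as in the example above.
cu_seqlens = build_cu_seqlens([65536, 32768])
print(cu_seqlens)  # tensor([    0, 65536, 98304], device='cuda:0', dtype=torch.int32)
```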
We provide detailed commands in scripts/run_unit_test.sh for convenient benchmarking of the FSA module. The benchmark covers correctness comparison of forward and backward outputs, as well as performance and memory usage comparisons. The optimized NSA selected attention module, which is the major system bottleneck, can be benchmarked through the commands in scripts/run_unit_test_sel_attn.sh.
Tip
Try varied gqa, seqlen, block_size, and topk arguments in the provided scripts for more comprehensive benchmarking on your machine! Compared to benchmarking the full FSA attention module, benchmarking the FSA selected attention module usually yields a higher speedup.
Performance comparison of Triton-based FSA, NSA, and Full Attention (enabled by Flash Attention) kernels under various configurations. The tuple ($64$, $16$) / ($128$, $8$) denotes the block size $BK$ and the top-k value $Topk$, respectively. For FSA and NSA, the execution latency is composed of compressed, selected, and sliding attention; for Full Attention, the execution latency is the Flash Attention kernel execution latency.

End-to-end training (right) and prefill (left) latency of state-of-the-art LLMs with FSA, NSA, or Full Attention.

```bibtex
@article{yan2025flashsparseattentionalternative,
  title={Flash Sparse Attention: More Efficient Natively Trainable Sparse Attention},
  author={Yan, Ran and Jiang, Youhe and Yuan, Binhang},
  journal={arXiv preprint arXiv:2508.18224},
  year={2025}
}
```
- NSA paper: Native Sparse Attention
- NSA reference implementation: Native Sparse Attention Triton