Modify output shape in nsa for decoding #565
base: main
@@ -1,10 +1,10 @@
# Simple GLA

Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet).

Compared to GLA, the gating is head-wise instead of elementwise.
As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability.
It is faster than GLA but has less expressive power.
I will use it as a baseline for the GLA.

$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
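A minimal PyTorch sketch of this recurrence for reference (illustrative only, not the repo's fused Triton kernel; the `(B, H, T, ...)` layout and the `o_t = q_t S_t` readout are assumptions):

```python
import torch

def simple_gla_naive(q, k, v, g):
    # q, k: (B, H, T, K); v: (B, H, T, V); g: (B, H, T) head-wise scalar gate in (0, 1)
    B, H, T, K = k.shape
    V = v.shape[-1]
    S = k.new_zeros(B, H, K, V)                      # recurrent state S_t
    o = torch.empty(B, H, T, V, dtype=v.dtype, device=v.device)
    for t in range(T):
        # S_{t+1} = g_{t+1} * S_t + k_{t+1} v_{t+1}^T  (g is a per-head scalar)
        S = g[:, :, t, None, None] * S + k[:, :, t, :, None] * v[:, :, t, None, :]
        # readout: o_t = q_t S_t
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', q[:, :, t], S)
    return o
```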
@@ -7,14 +7,14 @@
> [!IMPORTANT]
> The `flame` project has been migrated to a new project built on torchtitan.
> Please visit the [new repository](https://github.com/fla-org/flame) for details and updates.
>
> The code here is now **archived as legacy**, and no future updates will be synchronized here.

A minimal framework for training FLA models, whether from scratch or through finetuning.

Built on the robust infrastructure of 🤗, `flame` enables you to train large language models with just a few lines of code:
we use `datasets` for data processing, `transformers` for model definitions, and `accelerate`[^1] for seamless distributed training.

In this README, we will guide you through the process of using `flame` to train GLA models.

## Setup

@@ -25,7 +25,7 @@ Clone the `fla` repository and install the necessary packages as follows:

```bash
git clone https://github.com/sustcsonglin/flash-linear-attention.git
pip install .
pip install accelerate
```
Comment on lines 26 to 30

Add missing `cd` step before `pip install .` (and consider installing tokenizers explicitly). Without changing directory, `pip install .` would not install the freshly cloned package:

```bash
git clone https://github.com/sustcsonglin/flash-linear-attention.git
+cd flash-linear-attention
pip install .
pip install accelerate
+pip install 'tokenizers>=0.20.4'
+# If you plan to use DeepSpeed:
+# pip install 'accelerate[deepspeed]' deepspeed
```

In legacy/training/README.md around lines 26 to 30, the instructions run pip install without first changing into the cloned flash-linear-attention directory; add the `cd` step (and optional extra installs) shown in the suggestion above.
@@ -35,8 +35,8 @@ pip install accelerate

## Preprocessing

Before training, you need to download and pre-tokenize your dataset.
We provide a straightforward script for this.
For instance, to tokenize a 10B sample of the `fineweb-edu` dataset, run:

```bash
@@ -103,15 +103,15 @@ Other scheduler types like WSD (`warmup_stable_decay`)[^2] are also supported.

The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as
`batch_size × gradient_accumulation_steps × context_length × num_gpus_per_node × num_nodes`.
For instance, in the 340M model example, the `global_batch_size` calculates to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens).

The `warmup_steps` parameter indicates the number of steps for the learning rate warmup phase, while `max_steps` represents the maximum number of training steps.
Each step processes `global_batch_size` tokens.
Consequently, `512` and `20480` correspond to processing 0.5B and 10B tokens, respectively.
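As a quick sanity check of this arithmetic, a minimal sketch using the 340M example values quoted above:

```python
# Values from the 340M example; adjust to your own configuration.
batch_size = 32
gradient_accumulation_steps = 1
context_length = 2048
num_gpus_per_node = 8
num_nodes = 1

global_batch_size = (batch_size * gradient_accumulation_steps * context_length
                     * num_gpus_per_node * num_nodes)
print(f"{global_batch_size:,} tokens per step")              # 524,288 (~0.5M)

max_steps = 20480
print(f"{max_steps * global_batch_size / 1e9:.1f}B tokens")  # ~10.7B, i.e. the ~10B run
```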

:warning: Monitor the value of `global_batch_size`, `warmup_steps`, and `max_steps` carefully when modifying any of the hyperparameters!!

`flame` also supports resuming interrupted training by specifying the checkpoint path.
Simply use the following command:

```bash
@@ -141,7 +141,7 @@ You can also use `wandb` to monitor your training process effectively.

## Continual Pretraining

`flame` supports continual training from a pretrained checkpoint.
Below, we provide an example of how to finetune Mistral-7B to GLA.
You can follow similar steps to reproduce the results in the [GSA paper](https://arxiv.org/abs/2409.07146):

1. Initialize a brand-new GLA-7B model from the config and copy the matched pretrained weights from Mistral-7B (a rough sketch of this step follows below):
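A hedged sketch of this initialization step (not the repo's own conversion script; the `fla.models` import, the `GLAConfig` fields, and the output path are assumptions to check against your installed `fla` version):

```python
import torch
from transformers import AutoModelForCausalLM
from fla.models import GLAConfig, GLAForCausalLM  # assumed import path

# Load the pretrained teacher and build a fresh, roughly size-matched GLA model.
teacher = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mistral-7B-v0.1', torch_dtype=torch.bfloat16)
config = GLAConfig(hidden_size=4096, num_hidden_layers=32,
                   num_heads=32, vocab_size=32000)   # field names are assumptions
student = GLAForCausalLM(config).to(torch.bfloat16)

# Copy only tensors whose names and shapes match (embeddings, MLPs, norms);
# GLA-specific attention weights keep their fresh initialization.
teacher_sd = teacher.state_dict()
student_sd = student.state_dict()
matched = {k: v for k, v in teacher_sd.items()
           if k in student_sd and v.shape == student_sd[k].shape}
student.load_state_dict(matched, strict=False)
print(f"copied {len(matched)}/{len(student_sd)} tensors from Mistral-7B")
student.save_pretrained('exp/gla-7B-init')           # hypothetical output path
```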
@@ -171,7 +171,7 @@ bash train.sh \
cache=data/SlimPajama-627B/train
```

Please be aware that finetuning on a single node may not be the most efficient approach.
If available, consider leveraging multi-node GPUs for optimal performance.
You can find guidance on how to launch a multi-node job in the [accelerate tutorial](https://github.com/huggingface/accelerate/blob/main/examples/slurm/submit_multinode.sh).
@@ -22,4 +22,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
}

@@ -21,4 +21,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
}

@@ -25,4 +25,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
}

@@ -15,4 +15,4 @@
"tie_word_embeddings": true,
"use_cache": true,
"vocab_size": 32000
}
💡 Verification agent
🧩 Analysis chain
Critical: grid/output use T_q but kernel indexing still uses T (KV length) → potential misaddressing/OOB for B>1

You correctly switched grid/o/lse to T_q, but `parallel_nsa_fwd_kernel` still receives only `T` and derives `bos`/`eos` and all pointer offsets from it. In the non-varlen path, this makes q/lse/block_indices indexing depend on KV length, which becomes incorrect when T_q != T (e.g., decoding) and is unsafe for B>1 (wrong row strides, possible OOB).

Actionable fix: pass T_q to the kernel and decouple q- vs k/v-based offsets inside the kernel.

Strongly recommend mirroring this decoupling in the backward kernels as well:

- `parallel_nsa_bwd_kernel_dq`: add T_q; use bos_q for q/do/lse/delta/dq and bos_k for k/v; change the launch grid to (T_q, NV, B*H).
- `parallel_nsa_bwd_kernel_dkv`: add T_q if needed for q-side offsets used inside the per-token loop; ensure block_mask and its T dimension are consistent with the iteration bound.

Short-term mitigation (until kernels are updated): keep the guard suggested above to prevent B>1 decode runs when T_q != T.

Also consider updating the docstring of `parallel_nsa` to clarify that the output seqlen matches q (T_q). I can draft the full patch set for fwd/bwd kernels if you'd like.

To validate, please run a decode pass with B=2, T_q << T (e.g., T_q=1, T=128) and confirm no memory errors and correct o.shape == (B, T_q, HQ, V).
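A minimal sketch of that short-term guard (placement and the `varlen` flag are assumptions; in practice the check would sit in the Python wrapper before the kernel launch):

```python
def guard_nsa_decode(B: int, T: int, T_q: int, varlen: bool) -> None:
    """Short-term guard: reject configurations the current kernels index
    incorrectly (non-varlen path, B > 1, query length != KV length)."""
    if not varlen and B > 1 and T_q != T:
        raise NotImplementedError(
            "NSA decode with T_q != T is unsafe for B > 1 until the kernels "
            "decouple q-side (T_q) and k/v-side (T) offsets."
        )

# The decode case from this review would be rejected for now:
guard_nsa_decode(B=2, T=128, T_q=1, varlen=False)  # raises NotImplementedError
```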
Critical: q-indexing uses T (KV length) while kernel/grid use T_q — will misaddress / OOB when T_q != T (e.g. decoding with B>1)
Short verification: I inspected fla/ops/nsa/parallel.py — forward allocates o/lse with T_q and launches the fwd kernel with grid=(T_q, NV, B*H) (lines ~533–562), but parallel_nsa_fwd_kernel (def at ~173) computes bos = i_b * T in the non-varlen path and uses (bos + i_t) to index q/o/lse/block_indices. The same pattern occurs in the backward kernels: parallel_nsa_bwd_kernel_dq (def ~299) and parallel_nsa_bwd_kernel_dkv (def ~407) derive offsets from T but callers / buffers use T_q. This will produce wrong row strides and can be OOB when T_q != T and B>1.
- Affected locations (file: fla/ops/nsa/parallel.py)
- Actionable fix (minimal, safe design)
- Example (concise snippets to apply)
- Required backward changes (summary)
- Short-term mitigation
- Docs/tests
I found definitive evidence in the file that the current implementation mixes T and T_q; this needs code changes as described above. I recommend applying the kernel-signature + offset-decoupling changes and updating the backward launches (grid dimensions and "all" computations) so q-indexing always uses T_q.
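A plain-Python sketch of the offset decoupling described above (illustrative only, not the actual Triton kernel; `i_b`, `i_t`, `bos_q`, `bos_k` follow the names used in this review):

```python
def nsa_fwd_row_offsets(i_b: int, i_t: int, T: int, T_q: int):
    """q/o/lse/block_indices rows should be addressed with the query length T_q,
    while k/v rows keep using the KV length T."""
    # Current non-varlen path (buggy): everything derived from T.
    q_row_buggy = i_b * T + i_t          # wrong stride when T_q != T and B > 1

    # Proposed fix: separate base offsets for q-side and k/v-side tensors.
    bos_q = i_b * T_q                    # base row for q / o / lse / block_indices
    bos_k = i_b * T                      # base row for k / v
    return q_row_buggy, bos_q + i_t, bos_k

# Decode example from the review: B=2, T_q=1, T=128. For batch index 1 the q row
# must be 1, not 128 (out of bounds for a buffer with only B * T_q = 2 rows).
print(nsa_fwd_row_offsets(i_b=1, i_t=0, T=128, T_q=1))   # (128, 1, 128)
```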