3 changes: 2 additions & 1 deletion evals/harness.py
@@ -2,11 +2,12 @@

from __future__ import annotations

-import fla # noqa
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
+
+import fla # noqa


@register_model('fla')
class FlashLinearAttentionLMWrapper(HFLM):
8 changes: 4 additions & 4 deletions fla/ops/nsa/parallel.py
@@ -542,7 +542,7 @@ def parallel_nsa_fwd(
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
-HQ = q.shape[2]
+_, T_q, HQ, _ = q.shape
G = HQ // H
BS = block_size
if check_shared_mem('hopper', q.device.index):
@@ -555,9 +555,9 @@ def parallel_nsa_fwd(
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension can not be larger than 256"

-grid = (T, NV, B * H)
-o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
-lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
+grid = (T_q, NV, B * H)
+o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device)
+lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device)

Comment on lines +558 to 561

💡 Verification agent

🧩 Analysis chain

Critical: grid/output use T_q but kernel indexing still uses T (KV length) → potential misaddressing/OOB for B>1

You correctly switched grid/o/lse to T_q, but parallel_nsa_fwd_kernel still receives only T and derives bos/eos and all pointer offsets from it. In the non-varlen path, this makes q/lse/block_indices indexing depend on KV length, which becomes incorrect when T_q != T (e.g., decoding) and is unsafe for B>1 (wrong row strides, possible OOB).

Actionable fix: pass T_q to the kernel and decouple q- vs k/v-based offsets inside the kernel.

  • Forward call site change (add T_q):
# around Lines 562-583: pass T_k and T_q distinctly
parallel_nsa_fwd_kernel[grid](
    q=q, k=k, v=v, o=o, lse=lse, scale=scale,
    block_indices=block_indices, block_counts=block_counts,
    cu_seqlens=cu_seqlens, token_indices=token_indices,
    T=T,              # keys length (unchanged)
    T_q=T_q,          # NEW: queries length
    H=H, HQ=HQ, G=G, K=K, V=V, S=S, BS=BS, BK=BK, BV=BV,
)
  • Kernel signature and internal offsets (decouple bos for q vs k/v):
@triton.jit
def parallel_nsa_fwd_kernel(
    q, k, v, o, lse, scale, block_indices, block_counts,
    cu_seqlens, token_indices,
    T,               # keys length
    T_q,             # NEW: queries length
    H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
    S: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
    IS_VARLEN: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr
):
    ...
    if IS_VARLEN:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos_k = tl.load(cu_seqlens + i_n).to(tl.int32)
        bos_q = bos_k  # varlen paths have matching lengths
    else:
        bos_k = i_b * T
        bos_q = i_b * T_q

    # Offsets using bos_k for k/v, bos_q for q/o/lse/block_indices
    k += (bos_k * H + i_h) * K
    v += (bos_k * H + i_h) * V
    block_indices += (bos_q + i_t) * H * S + i_h * S
    p_q  = tl.make_block_ptr(q  + (bos_q + i_t) * HQ*K, ...)
    p_o  = tl.make_block_ptr(o  + (bos_q + i_t) * HQ*V, ...)
    p_lse = lse + (bos_q + i_t) * HQ + i_h * G + tl.arange(0, G)
    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos_q + i_t) * H + i_h)
    else:
        NS = S
    ...
    # Keep p_k/p_v constructed with (K, T) / (T, V) layouts (T is key length)

Strongly recommend mirroring this decoupling in the backward kernels as well:

  • parallel_nsa_bwd_kernel_dq: add T_q; use bos_q for q/do/lse/delta/dq and bos_k for k/v; change launch grid to (T_q, NV, B*H).
  • parallel_nsa_bwd_kernel_dkv: add T_q if needed for q-side offsets used inside the per-token loop; ensure block_mask and its T dimension are consistent with the iteration bound.

Short-term mitigation (until kernels are updated): keep the guard suggested above to prevent B>1 decode runs when T_q != T.

Also consider updating the docstring of parallel_nsa to clarify that output seqlen matches q (T_q).

I can draft the full patch set for fwd/bwd kernels if you’d like.

To validate, please run a decode pass with B=2, T_q << T (e.g., T_q=1, T=128) and confirm no memory errors and correct o.shape == (B, T_q, HQ, V).


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "---- repo root ----"
pwd
echo

echo "---- show file existence and size ----"
ls -l fla/ops/nsa/parallel.py || true
echo

echo "---- print file with line numbers (1..800) ----"
nl -ba fla/ops/nsa/parallel.py | sed -n '1,800p'
echo

echo "---- search for kernel call sites and definitions ----"
rg "parallel_nsa_fwd_kernel" -n || true
rg "parallel_nsa_bwd_kernel" -n || true
rg "T_q" -n || true
rg "bos_q" -n || true
rg "bos_k" -n || true

Length of output: 550


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "---- repo root ----"
pwd
echo

echo "---- file info ----"
ls -l fla/ops/nsa/parallel.py || true
echo

echo "---- print lines 1..220 ----"
sed -n '1,220p' fla/ops/nsa/parallel.py || true
echo

echo "---- print lines 221..480 ----"
sed -n '221,480p' fla/ops/nsa/parallel.py || true
echo

echo "---- print lines 481..760 ----"
sed -n '481,760p' fla/ops/nsa/parallel.py || true
echo

echo "---- search for kernel defs/calls and T_q usage ----"
rg -n "parallel_nsa|parallel_nsa_fwd|parallel_nsa_fwd_kernel|parallel_nsa_bwd|parallel_nsa_bwd_kernel|T_q|make_block_ptr|triton.jit" fla/ops/nsa/parallel.py || true

Length of output: 27435


Critical: q-indexing uses T (KV length) while kernel/grid use T_q — will misaddress / OOB when T_q != T (e.g. decoding with B>1)

Short verification: I inspected fla/ops/nsa/parallel.py — forward allocates o/lse with T_q and launches the fwd kernel with grid=(T_q, NV, B*H) (lines ~533–562), but parallel_nsa_fwd_kernel (def at ~173) computes bos = i_b * T in the non-varlen path and uses (bos + i_t) to index q/o/lse/block_indices. The same pattern occurs in the backward kernels: parallel_nsa_bwd_kernel_dq (def ~299) and parallel_nsa_bwd_kernel_dkv (def ~407) derive offsets from T but callers / buffers use T_q. This will produce wrong row strides and can be OOB when T_q != T and B>1.

Affected locations (file: fla/ops/nsa/parallel.py)

  • parallel_nsa_fwd (caller): grid, o, lse allocation ~lines 533–562 (uses T_q)
  • parallel_nsa_fwd_kernel (kernel): def ~173 — computes bos using T and indexes q/o/lse with (bos + i_t)
  • parallel_nsa_bwd (caller): dq allocation and kernel launch ~lines 614–646 (dq uses q.shape / T_q, launch uses T)
  • parallel_nsa_bwd_kernel_dq (kernel): def ~299 — uses all = B * T and bos = i_b * T for q/dq offsets
  • parallel_nsa_bwd_kernel_dkv (kernel): def ~407 — loop bounds and q pointers use T/bos

Actionable fix (minimal, safe design)

  • Pass queries length (T_q) explicitly to all kernels that index/iterate over queries.
  • Inside kernels decouple query vs key/value base offsets:
    • compute bos_k = i_b * T_k and bos_q = i_b * T_q (for varlen keep bos_k = bos_q = bos).
    • Use bos_k for k/v offsets; use bos_q for q/o/lse/block_indices/block_counts offsets and for any loop bounds over query tokens.
  • Update dq allocation / launch:
    • allocate dq as now (NV, q.shape) but launch dq kernel with grid = (T_q, NV, BH) and compute dq pointer arithmetic using all_q = B * T_q (not B * T_k).
  • Mirror the same decoupling in parallel_nsa_bwd_kernel_dq and parallel_nsa_bwd_kernel_dkv (loop bounds, block_mask indexing, and any "all" computations must use T_q for q-side, T_k for k/v-side).

Example (concise snippets to apply)

  • Forward call site: pass T_q
parallel_nsa_fwd_kernel[grid](
    q=q, k=k, v=v, o=o, lse=lse, scale=scale,
    block_indices=block_indices, block_counts=block_counts,
    cu_seqlens=cu_seqlens, token_indices=token_indices,
    T=T,              # keys length (unchanged)
    T_q=T_q,          # NEW: queries length
    H=H, HQ=HQ, G=G, K=K, V=V, S=S, BS=BS, BK=BK, BV=BV,
)
  • Kernel signature + internal decoupling (illustrative)
def parallel_nsa_fwd_kernel(..., T, T_q, H:tl.constexpr, HQ:tl.constexpr, ...):
    ...
    if IS_VARLEN:
        i_n, i_t = ...
        bos, eos = ...
        bos_k = bos_q = bos
        # T set from eos-bos for varlen
    else:
        bos_k = i_b * T         # keys base
        bos_q = i_b * T_q       # queries base

    k += (bos_k * H + i_h) * K
    v += (bos_k * H + i_h) * V
    block_indices += (bos_q + i_t) * H*S + i_h * S

    p_q = tl.make_block_ptr(q + (bos_q + i_t) * HQ*K, ...)
    p_o = tl.make_block_ptr(o + (bos_q + i_t) * HQ*V, ...)
    p_lse = lse + (bos_q + i_t) * HQ + i_h * G + tl.arange(0, G)
    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos_q + i_t) * H + i_h)
    else:
        NS = S
    ...

Required backward changes (summary)

  • parallel_nsa_bwd: allocate dq as before, but launch parallel_nsa_bwd_kernel_dq with grid = (T_q, NV, B*H) and pass T_q (a hedged launch sketch follows this list).
  • parallel_nsa_bwd_kernel_dq: add T_q param, use all_q = B * T_q for dq pointer arithmetic and bos_q for q/do/lse/dq offsets; use bos_k for k/v offsets.
  • parallel_nsa_bwd_kernel_dkv: add T_q (if q-side offsets/loop bounds used), use bos_q for q pointers and loop bounds over queries; use bos_k / all_k for dk/dv offsets.
  • Ensure parallel_nsa_block_mask / NS computation and any other T-derived sizes use the same T_q where appropriate (block_indices shape must match queries length).
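
For concreteness, a hedged sketch of what the dq-side launch could look like; the keyword arguments mirror the forward snippet above and are assumptions, not the verified signature of parallel_nsa_bwd_kernel_dq:

```python
# Hedged sketch only: argument names are assumed from the snippets above,
# not verified against fla/ops/nsa/parallel.py.
grid = (T_q, NV, B * H)   # iterate over query tokens, not key tokens
parallel_nsa_bwd_kernel_dq[grid](
    q=q, k=k, v=v, lse=lse, delta=delta, do=do, dq=dq, scale=scale,
    block_indices=block_indices, block_counts=block_counts,
    cu_seqlens=cu_seqlens, token_indices=token_indices,
    T=T,        # key/value length, used for k/v offsets
    T_q=T_q,    # NEW: query length, used for q/do/lse/delta/dq offsets and all_q = B * T_q
    H=H, HQ=HQ, G=G, K=K, V=V, S=S, BS=BS, BK=BK, BV=BV,
)
```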

Short-term mitigation

  • Until kernels are updated, add a guard preventing runs with B>1 when T_q != T (decoding case) to avoid OOB.
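
A minimal sketch of such a guard in parallel_nsa_fwd (exact wording and placement are illustrative assumptions):

```python
# Temporary guard until the kernels take T_q explicitly; remove once the
# q-side and k/v-side offsets are decoupled. Placement inside parallel_nsa_fwd is assumed.
if T_q != T and B > 1:
    raise NotImplementedError(
        f"query length T_q={T_q} != key length T={T} is not yet supported for batch size B={B} > 1; "
        "decode-style calls with B > 1 would misaddress memory in the current kernels."
    )
```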

Docs/tests

  • Update parallel_nsa docstring to state output seqlen == queries length T_q.
  • Add a small unit test: B=2, T_q=1, T=128 decode case to validate shapes and no OOB.
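
A hedged sketch of such a test; the parallel_nsa import path, call signature, and tensor layouts below are assumptions based on the shapes discussed in this PR and may need adjusting:

```python
# Decode-shaped smoke test (B=2, T_q=1, T=128). The signature of parallel_nsa
# is assumed from the shapes used in this PR, not taken from the actual API.
import torch
from fla.ops.nsa.parallel import parallel_nsa

def test_parallel_nsa_decode_shapes():
    B, T_q, T, H, HQ, K, V, S, BS = 2, 1, 128, 2, 4, 64, 64, 4, 32
    q = torch.randn(B, T_q, HQ, K, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
    # S selected key blocks per (batch, query token, kv head)
    block_indices = torch.randint(0, T // BS, (B, T_q, H, S), device='cuda')
    o = parallel_nsa(q, k, v, block_indices=block_indices, block_size=BS)
    assert o.shape == (B, T_q, HQ, V)
    assert not torch.isnan(o.float()).any()
```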

I found definitive evidence in the file that the current implementation mixes T and T_q; this needs code changes as described above. I recommend applying the kernel-signature + offset-decoupling changes and updating the backward launches (grid dimensions and "all" computations) so q-indexing always uses T_q.

parallel_nsa_fwd_kernel[grid](
q=q,
8 changes: 4 additions & 4 deletions fla/ops/simple_gla/README.md
@@ -1,10 +1,10 @@
# Simple GLA

-Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet).
+Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet).

-Compared to GLA, the gating is head-wise instead of elementwise.
-As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability.
-It is faster than GLA but has less expressive power.
+Compared to GLA, the gating is head-wise instead of elementwise.
+As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability.
+It is faster than GLA but has less expressive power.
I will use it as a baseline for the GLA.

$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
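
Not part of the diff, just for intuition: a naive per-step reference of the recurrence above. The (B, H, T, d) tensor layout and the readout $o_t = S_t^{\top} q_t$ are assumptions for illustration, not the kernel's actual interface.

```python
import torch

def simple_gla_reference(q, k, v, g):
    """Naive loop for S_t = g_t * S_{t-1} + k_t v_t^T with a head-wise scalar gate g_t.

    Assumed shapes: q, k: (B, H, T, K); v: (B, H, T, V); g: (B, H, T) with 0 < g < 1.
    The fused kernels use chunked matmuls instead of this O(T) recurrence.
    """
    B, H, T, K = k.shape
    S = k.new_zeros(B, H, K, v.shape[-1])   # running state S_t per (batch, head)
    o = torch.empty_like(v)
    for t in range(T):
        # decay the state head-wise, then add the rank-1 update k_t v_t^T
        S = g[:, :, t, None, None] * S + k[:, :, t, :, None] * v[:, :, t, None, :]
        # read out with the current query: o_t = S_t^T q_t
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', q[:, :, t], S)
    return o
```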
20 changes: 10 additions & 10 deletions legacy/training/README.md
@@ -7,14 +7,14 @@
> [!IMPORTANT]
> The `flame` project has been migrated to a new project built on torchtitan.
> Please visit the [new repository](https://github.com/fla-org/flame) for details and updates.
->
+>
> The code here is now **archived as legacy**, and no future updates will be synchronized here.

A minimal framework for training FLA models, whether from scratch or through finetuning.

Built on the robust infrastructure of 🤗, `flame` enables you to train large language models with just a few lines of code:
we use `datasets` for data processing, `transformers` for model definitions, and `accelerate`[^1] for seamless distributed training.

In this README, we will guide you through the process of using `flame` to train GLA models.

## Setup
@@ -25,7 +25,7 @@ Clone the `fla` repository and install the necessary packages as follows:

```bash
git clone https://github.com/sustcsonglin/flash-linear-attention.git
-pip install .
+pip install .
pip install accelerate
```
Comment on lines 26 to 30

⚠️ Potential issue

Add missing cd step before pip install . (and consider installing tokenizers explicitly).

Without changing directory, pip install . will not install the cloned repo. Also, since you call out a minimum tokenizer version below, it’s helpful to show the explicit install here.

 ```bash
 git clone https://github.com/sustcsonglin/flash-linear-attention.git
+cd flash-linear-attention
 pip install .
 pip install accelerate
+pip install 'tokenizers>=0.20.4'
+# If you plan to use DeepSpeed:
+# pip install 'accelerate[deepspeed]' deepspeed
 ```

<details>
<summary>🤖 Prompt for AI Agents</summary>

In legacy/training/README.md around lines 26 to 30, the instructions run pip
install . immediately after git clone which fails because the working directory
wasn't changed; update the steps to cd into the cloned flash-linear-attention
directory before running pip install ., and add an explicit installation of a
compatible tokenizers package (e.g. tokenizers>=0.20.4); optionally note
DeepSpeed install extras (accelerate[deepspeed] and deepspeed) as an additional
step for users planning to use DeepSpeed.


</details>


@@ -35,8 +35,8 @@ pip install accelerate

## Preprocessing

-Before training, you need to download and pre-tokenize your dataset.
-We provide a straightforward script for this.
+Before training, you need to download and pre-tokenize your dataset.
+We provide a straightforward script for this.
For instance, to tokenize a 10B sample of the `fineweb-edu` dataset, run:

```bash
@@ -103,15 +103,15 @@ Other scheduler types like WSD (`warmup_stable_decay`)[^2] are also supported.

The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as
`batch_size × gradient_accumulation_steps × context_length × num_gpus_per_node × num_nodes`.
-For instance, in the 340M model example, the `global_batch_size` calculates to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens).
+For instance, in the 340M model example, the `global_batch_size` calculates to $32 \times 1 \times 2048 \times 8 \times 1 = 524,288$ (0.5M tokens).

The `warmup_steps` parameter indicates the number of steps for the learning rate warmup phase, while `max_steps` represents the maximum number of training steps.
-Each step processes `global_batch_size` tokens.
+Each step processes `global_batch_size` tokens.
Consequently, `512` and `20480` correspond to processing 0.5B and 10B tokens, respectively.
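
Not part of the diff: a quick check of the per-step token count described above.

```python
# Per-step token count for the 340M example, using the values from the paragraph above.
batch_size, grad_accum, context_len, gpus_per_node, num_nodes = 32, 1, 2048, 8, 1
global_batch_size = batch_size * grad_accum * context_len * gpus_per_node * num_nodes
print(global_batch_size)  # 524288 tokens (~0.5M) processed per training step
```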

:warning: Monitor the value of `global_batch_size`, `warmup_steps`, and `max_steps` carefully when modifying any of the hyperparameters!!

-`flame` also supports resuming interrupted training by specifying the checkpoint path.
+`flame` also supports resuming interrupted training by specifying the checkpoint path.
Simply use the following command:

```bash
@@ -141,7 +141,7 @@ You can also use `wandb` to monitor your training process effectively.
## Continual Pretraining

`flame` supports continual training from a pretrained checkpoint.
-Below, we provide an example of how to finetune Mistral-7B to GLA.
+Below, we provide an example of how to finetune Mistral-7B to GLA.
You can follow similar steps to reproduce the results in the [GSA paper](https://arxiv.org/abs/2409.07146):

1. Initialize a brand-new GLA-7B model from the config and copy the matching pretrained weights from Mistral-7B:
@@ -171,7 +171,7 @@ bash train.sh \
cache=data/SlimPajama-627B/train
```

-Please be aware that finetuning on a single node may not be the most efficient approach.
+Please be aware that finetuning on a single node may not be the most efficient approach.
If available, consider leveraging multi-node GPUs for optimal performance.
You can find guidance on how to launch a multi-node job in the [accelerate tutorial](https://github.com/huggingface/accelerate/blob/main/examples/slurm/submit_multinode.sh).

2 changes: 1 addition & 1 deletion legacy/training/configs/gla_1B.json
@@ -22,4 +22,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
-}
+}
2 changes: 1 addition & 1 deletion legacy/training/configs/gla_340M.json
@@ -21,4 +21,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
-}
+}
2 changes: 1 addition & 1 deletion legacy/training/configs/gla_7B.json
@@ -25,4 +25,4 @@
"use_gk": true,
"use_gv": false,
"vocab_size": 32000
}
}
2 changes: 1 addition & 1 deletion legacy/training/configs/transformer_340M.json
@@ -15,4 +15,4 @@
"tie_word_embeddings": true,
"use_cache": true,
"vocab_size": 32000
}
}
3 changes: 1 addition & 2 deletions legacy/training/flame/logging.py
@@ -6,8 +6,7 @@
import sys
import time

-from transformers.trainer_callback import (ExportableState, TrainerCallback,
-TrainerControl, TrainerState)
+from transformers.trainer_callback import ExportableState, TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments


3 changes: 1 addition & 2 deletions legacy/training/flame/parser.py
@@ -6,9 +6,8 @@
from typing import Optional

import transformers
-from transformers import HfArgumentParser, TrainingArguments
-
from flame.logging import get_logger
+from transformers import HfArgumentParser, TrainingArguments

logger = get_logger(__name__)

7 changes: 3 additions & 4 deletions legacy/training/run.py
@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-

from datasets import load_from_disk
-from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
-Trainer)
-
-import fla # noqa
from flame.data import DataCollatorForLanguageModeling
from flame.logging import LogCallback, get_logger
from flame.parser import get_train_args
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer
+
+import fla # noqa

logger = get_logger(__name__)
