43 changes: 43 additions & 0 deletions fla/ops/delta_rule/wy_fast.py
@@ -292,3 +292,46 @@ def prepare_wy_repr_bwd(
bwd_prepare_wy_repr = prepare_wy_repr_bwd

fwd_recompute_w_u = recompute_w_u_fwd


def fwd_prepare_T(
k: torch.Tensor,
beta: torch.Tensor,
chunk_size: int,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
"""
Prepare the transformation matrix T (A) for delta rule computation.

This function computes the matrix A = (I - tril(beta * K * K^T))^{-1}
which is used in the parallel delta rule algorithm.

Args:
k: Key tensor of shape [B, H, T, K] (head-first format)
beta: Beta weights of shape [B, H, T] (head-first format)
chunk_size: Size of chunks for processing
cu_seqlens: Optional cumulative sequence lengths for variable-length sequences

Returns:
A: Transformation matrix of shape [B, H, T, chunk_size]
"""
# Convert from head-first [B, H, T, K] to seq-first [B, T, H, K]
k_seq_first = k.transpose(1, 2).contiguous()
beta_seq_first = beta.transpose(1, 2).contiguous()

A = chunk_scaled_dot_kkt_fwd(
k=k_seq_first,
beta=beta_seq_first,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
output_dtype=torch.float32,
)
A = solve_tril(
A=A,
cu_seqlens=cu_seqlens,
output_dtype=k.dtype
)

# Convert back from [B, T, H, chunk_size] to [B, H, T, chunk_size]
A = A.transpose(1, 2).contiguous()
return A
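As a reading aid for the formula in the docstring above, a minimal pure-PyTorch sketch of the per-chunk computation, assuming the strictly lower-triangular convention usually used for this WY-style matrix; naive_prepare_T is a hypothetical name, and the fused chunk_scaled_dot_kkt_fwd/solve_tril path may differ in triangular convention, dtype handling, and variable-length (cu_seqlens) support:

```python
import torch

def naive_prepare_T(k: torch.Tensor, beta: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Chunk-wise A = (I - tril(beta * K * K^T, -1))^{-1}, plain PyTorch sketch."""
    B, H, T, K = k.shape
    assert T % chunk_size == 0, "sketch assumes T is a multiple of chunk_size"
    n_chunks = T // chunk_size
    # Split the sequence dimension into chunks of length `chunk_size`.
    k_c = k.view(B, H, n_chunks, chunk_size, K).float()
    beta_c = beta.view(B, H, n_chunks, chunk_size).float()
    # Scaled Gram matrix per chunk: entry (i, j) is beta_i * <k_i, k_j>.
    kkt = (k_c * beta_c.unsqueeze(-1)) @ k_c.transpose(-1, -2)
    # Keep only the strictly lower-triangular part, then invert (I - L).
    lower = torch.tril(kkt, diagonal=-1)
    eye = torch.eye(chunk_size, device=k.device, dtype=torch.float32)
    A = torch.linalg.inv(eye - lower)
    # Flatten chunks back to [B, H, T, chunk_size], matching the documented output.
    return A.reshape(B, H, T, chunk_size).to(k.dtype)
```

With T = 128 and chunk_size = 32, for instance, this stacks four 32x32 blocks along the sequence dimension, which is where the documented [B, H, T, chunk_size] output shape comes from.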
103 changes: 103 additions & 0 deletions tests/ops/test_parallel_delta.py
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-

import pytest
import torch
import torch.nn.functional as F

from fla.ops.delta_rule.parallel import naive_delta_rule_parallel, parallel_delta_rule
from fla.ops.delta_rule.wy_fast import fwd_prepare_T
from fla.utils import assert_close, device, device_platform

# IMPORTANT NOTE ON TENSOR FORMATS:
# While the documentation for some functions states inputs should be in [B, T, H, K] format,
# the actual implementation expects [B, H, T, K] format (head-first).
# All tests in this file use the head-first format to match the actual implementation.

Comment on lines +11 to +15
💡 Verification agent

🧩 Analysis chain

Clarify tensor-format note; align tests with the API to avoid silent mismatches.

parallel_delta_rule’s public API expects [B, T, H, K] (seq-first) per its docstring, and it emits a warning when passed head-first tensors while head_first=False. Feeding head-first tensors here risks masking format bugs behind warnings.

Recommend converting inputs to seq-first before calling, then converting outputs back to head-first for assertions. See suggested changes on Lines 57-66, 67-69, and 79-80.

Run this to list all call sites using parallel_delta_rule and check for any reliance on head-first shapes:


🏁 Script executed:

#!/bin/bash
rg -n -A 2 $'parallel_delta_rule\\('

Length of output: 579


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "==== tests/ops/test_parallel_delta.py (full file) ===="
sed -n '1,240p' tests/ops/test_parallel_delta.py || true

echo
echo "==== fla/ops/delta_rule/parallel.py (around function) ===="
sed -n '1,420p' fla/ops/delta_rule/parallel.py || true

Length of output: 18238


Fix tensor-format mismatch: the parallel_delta_rule docstring disagrees with the implementation (head-first).

Short: the public docstring claims seq-first ([B, T, H, ...]), but the kernels and the forward pass (ParallelDeltaRuleFunction) clearly operate on head-first tensors ([B, H, T, ...]). Tests pass head-first tensors and currently only trigger a warning, which can mask real format bugs. We should make the API, implementation, and tests consistent.

Files to change

  • fla/ops/delta_rule/parallel.py

    • Issue: docstring for parallel_delta_rule documents q/k/v/beta as [B, T, H, ...] but code (B, H, T, K = k.shape and all kernels) use head-first [B, H, T, ...].
    • Issue: wrapper currently raises a DeprecationWarning if head_first=True and only emits a warning (no conversion) when a head-first tensor is passed with head_first=False.
    • Recommended fix (pick one):
      • Preferred (safe, backwards-compatible): Auto-detect format and convert:
        • Auto-detect input format (e.g. head-first if q.shape[1] < q.shape[2], otherwise seq-first).
        • If seq-first, rearrange inputs to head-first before calling ParallelDeltaRuleFunction.apply and rearrange outputs back to seq-first before returning.
        • Keep head_first as an optional override (emit a DeprecationWarning but do not raise), but rely on auto-detection by default.
      • Alternative: If you want to keep the current internal layout, update the docstring to document head-first [B, H, T, ...] and remove/adjust the misleading warning.
• Minimal example (replace the current head_first/warning block with auto-detect + conversions; a consolidated sketch follows this list):
      • Before:
        • if head_first:
          raise DeprecationWarning(...)
        • if not head_first and q.shape[1] < q.shape[2]:
          warnings.warn(...)
        • o, attn = ParallelDeltaRuleFunction.apply(...)
      • After (sketch):
        • is_head_first = head_first or (q.shape[1] < q.shape[2])
        • if head_first:
          warnings.warn("head_first is deprecated", DeprecationWarning)
        • if not is_head_first:
          q_h = rearrange(q, 'b t h k -> b h t k'); ... (k_h, v_h, beta_h)
        • else:
          q_h, k_h, v_h, beta_h = q, k, v, beta
        • o_h, attn = ParallelDeltaRuleFunction.apply(q_h, k_h, v_h, beta_h, scale, output_attentions)
        • o = rearrange(o_h, 'b h t v -> b t h v') if not is_head_first else o_h
        • return o, attn
    • Also update the docstring to clearly state the accepted public input/output formats and the internal conversion behavior (or document that the function requires head-first if you choose the docstring-update alternative).
  • tests/ops/test_parallel_delta.py

    • Issue: tests currently construct head-first tensors and call parallel_delta_rule without conversion; that causes the wrapper warning.
    • If you implement the wrapper auto-conversion above, tests can remain as-is (no warning). If you instead change the API to require seq-first, update the test call-site to convert head-first -> seq-first before calling and convert outputs back for assertions. Example (if converting at test-level):
      • q_seq = rearrange(q, 'b h t k -> b t h k'); k_seq = rearrange(k, 'b h t k -> b t h k'); v_seq = rearrange(v, 'b h t v -> b t h v'); beta_seq = rearrange(beta, 'b h t -> b t h')
      • o_seq, attn = parallel_delta_rule(q_seq, k_seq, v_seq, beta_seq, ...)
      • o = rearrange(o_seq, 'b t h v -> b h t v')
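A consolidated sketch of the preferred option (auto-detect + convert), assuming ParallelDeltaRuleFunction keeps its current head-first apply signature; variable names and the detection heuristic follow the bullets above and are illustrative, not the shipped implementation:

```python
import warnings

from einops import rearrange

def parallel_delta_rule(q, k, v, beta, scale=None, output_attentions=False, head_first=False):
    # Sketch: auto-detect the layout instead of raising or only warning.
    # Heuristic from the review: head-first if the head dim (shape[1]) is smaller
    # than the sequence dim (shape[2]).
    if head_first:
        warnings.warn("head_first is deprecated; the layout is auto-detected", DeprecationWarning)
    is_head_first = head_first or (q.shape[1] < q.shape[2])
    if not is_head_first:
        # Convert seq-first [B, T, H, ...] inputs to the head-first layout the kernels expect.
        q = rearrange(q, 'b t h k -> b h t k')
        k = rearrange(k, 'b t h k -> b h t k')
        v = rearrange(v, 'b t h v -> b h t v')
        beta = rearrange(beta, 'b t h -> b h t')
    # ParallelDeltaRuleFunction is assumed to be the existing autograd.Function
    # in fla/ops/delta_rule/parallel.py.
    o, attn = ParallelDeltaRuleFunction.apply(q, k, v, beta, scale, output_attentions)
    if not is_head_first:
        # Hand the output back in the caller's original seq-first layout.
        o = rearrange(o, 'b h t v -> b t h v')
    return o, attn
```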

Why fix: leaving the mismatch and only warning risks silent bugs and confusion for users; tests should either exercise the public API shape or the public API should be documented/implemented to accept what callers expect.


🤖 Prompt for AI Agents
tests/ops/test_parallel_delta.py lines 11-15: the comment/docstring and wrapper
for parallel_delta_rule disagree on tensor layout (tests use head-first [B, H,
T, ...] while the public docstring claims seq-first [B, T, H, ...]); update
fla/ops/delta_rule/parallel.py to auto-detect and convert layouts: detect
head-first by comparing q.shape[1] and q.shape[2], if head_first param is True
emit a DeprecationWarning (do not raise), otherwise derive is_head_first =
head_first or (q.shape[1] < q.shape[2]); if inputs are seq-first, rearrange
q/k/v/beta to head-first before calling ParallelDeltaRuleFunction.apply and
rearrange outputs back to seq-first before returning; update the function
docstring to state it accepts either layout and auto-converts (or documents
deprecation of head_first), and leave tests/ops/test_parallel_delta.py unchanged
since tests supply head-first tensors and will now call the wrapper without
warnings.

# NOTE ON TEST IMPLEMENTATION:
# We currently skip comparing parallel_delta_rule against naive_delta_rule_parallel
# because the naive implementation produces NaN values. This will be addressed in a
# future update. For now, we only verify that parallel_delta_rule runs without errors
# and produces outputs with the expected shapes.


@pytest.mark.parametrize(
('B', 'H', 'T', 'K', 'dtype'),
[
pytest.param(*test, id="B{}-H{}-T{}-K{}-{}".format(*test))
for test in [
(1, 2, 128, 64, torch.float16),
(2, 4, 128, 32, torch.float16),
]
]
)
@pytest.mark.skipif(
device_platform == 'intel',
reason='Intel Triton Failure'
)
def test_parallel_delta_rule(
B: int,
H: int,
T: int,
K: int,
dtype: torch.dtype,
):
"""Test parallel_delta_rule against naive implementation."""
torch.manual_seed(42)

# Generate test data
q = torch.randn(B, H, T, K, dtype=dtype, device=device)
k = F.normalize(torch.randn(B, H, T, K, dtype=dtype, device=device), p=2, dim=-1).to(dtype)
v = torch.randn(B, H, T, K, dtype=dtype, device=device)
beta = torch.randn(B, H, T, dtype=dtype, device=device).sigmoid()
scale = 1.0 / (K ** 0.5)

# Define whether to output attention matrices
output_attentions = True

# Test forward pass
o_parallel, attn_parallel = parallel_delta_rule(
q=q.clone(),
k=k.clone(),
v=v.clone(),
beta=beta.clone(),
scale=scale,
output_attentions=output_attentions
)

# Output should have the same shape as input v
assert o_parallel.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel.shape}"

Comment on lines +67 to +69

🛠️ Refactor suggestion

Update shape assertion to use converted head-first output.

After calling with seq-first inputs, o is [B, T, H, K]. Convert to head-first before comparing with v.

Apply:

-    # Output should have the same shape as input v
-    assert o_parallel.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel.shape}"
+    # Output should have the same shape as input v (head-first)
+    assert o_parallel_hf.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel_hf.shape}"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Output should have the same shape as input v
assert o_parallel.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel.shape}"
# Output should have the same shape as input v (head-first)
assert o_parallel_hf.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel_hf.shape}"
🤖 Prompt for AI Agents
In tests/ops/test_parallel_delta.py around lines 67-69, the shape assertion
compares o_parallel in seq-first layout [B, T, H, K] to v which is head-first;
convert o_parallel to head-first before asserting. Rearrange o_parallel from [B,
T, H, K] to [B, H, T, K] (e.g., using .permute(0, 2, 1, 3) or
einops.rearrange(o_parallel, "b t h k -> b h t k")) and then assert the shape
equals v.shape.

# Check that attention matrix is produced if requested
if output_attentions:
assert attn_parallel is not None
assert attn_parallel.shape == (B, H, T, T), f"Expected shape {(B, H, T, T)}, got {attn_parallel.shape}"
else:
assert attn_parallel is None

o_naive, attn_naive = naive_delta_rule_parallel(q.clone(), k.clone(), v.clone(), beta.clone())

assert_close(' o', o_parallel, o_naive, 0.01)
assert_close('attn', attn_naive, attn_parallel, 0.01)
Comment on lines +77 to +80

🛠️ Refactor suggestion

Compare the correct tensor (head-first) against the naive implementation.

Ensure the numerical comparison is performed on head-first tensors.

Apply:

-    assert_close('   o', o_parallel, o_naive, 0.01)
+    assert_close('   o', o_parallel_hf, o_naive, 0.01)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
o_naive, attn_naive = naive_delta_rule_parallel(q.clone(), k.clone(), v.clone(), beta.clone())
assert_close(' o', o_parallel, o_naive, 0.01)
assert_close('attn', attn_naive, attn_parallel, 0.01)
o_naive, attn_naive = naive_delta_rule_parallel(q.clone(), k.clone(), v.clone(), beta.clone())
assert_close(' o', o_parallel_hf, o_naive, 0.01)
assert_close('attn', attn_naive, attn_parallel, 0.01)
🤖 Prompt for AI Agents
In tests/ops/test_parallel_delta.py around lines 77 to 80, the assertions
compare tensors in the wrong layout/order; convert both the naive and parallel
outputs to head-first layout (the same head-first dimension ordering used
elsewhere in tests) before asserting and ensure assert_close is called with the
expected (naive/head-first) tensor first and the actual (parallel/head-first)
tensor second; update both the 'o' and 'attn' comparisons to permute/transpose
to head-first and swap the argument order if necessary so the numerical
comparison is performed head-first and expected-before-actual.



@pytest.mark.skipif(
device_platform == 'intel',
reason='Intel Triton Failure'
)
def test_fwd_prepare_T():
"""Test that fwd_prepare_T can be imported and runs without error."""
torch.manual_seed(42)

# Using head-first format [B, H, T, K] to match other functions
B, H, T, K = 2, 4, 128, 64
k = torch.randn(B, H, T, K, device=device)
beta = torch.randn(B, H, T, device=device).sigmoid()
chunk_size = 32

# Test the function runs without error
A = fwd_prepare_T(k, beta, chunk_size)

# Check output shape
# After our fix, fwd_prepare_T returns [B, H, T, chunk_size] (head-first format)
expected_shape = (B, H, T, chunk_size)
assert A.shape == expected_shape, f"Expected shape {expected_shape}, got {A.shape}"