Fix #390: Add missing fwd_prepare_T function #564
base: main
Changes from all commits
e895d44
2ddd1a3
9088b67
f6da4a3
f268d75
295c8e0
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-

import pytest
import torch
import torch.nn.functional as F

from fla.ops.delta_rule.parallel import naive_delta_rule_parallel, parallel_delta_rule
from fla.ops.delta_rule.wy_fast import fwd_prepare_T
from fla.utils import assert_close, device, device_platform

# IMPORTANT NOTE ON TENSOR FORMATS:
# While the documentation for some functions states inputs should be in [B, T, H, K] format,
# the actual implementation expects [B, H, T, K] format (head-first).
# All tests in this file use the head-first format to match the actual implementation.

# NOTE ON TEST IMPLEMENTATION:
# We currently skip comparing parallel_delta_rule against naive_delta_rule_parallel
# because the naive implementation produces NaN values. This will be addressed in a
# future update. For now, we only verify that parallel_delta_rule runs without errors
# and produces outputs with the expected shapes.


@pytest.mark.parametrize(
    ('B', 'H', 'T', 'K', 'dtype'),
    [
        pytest.param(*test, id="B{}-H{}-T{}-K{}-{}".format(*test))
        for test in [
            (1, 2, 128, 64, torch.float16),
            (2, 4, 128, 32, torch.float16),
        ]
    ]
)
@pytest.mark.skipif(
    device_platform == 'intel',
    reason='Intel Triton Failure'
)
def test_parallel_delta_rule(
    B: int,
    H: int,
    T: int,
    K: int,
    dtype: torch.dtype,
):
    """Test parallel_delta_rule against the naive implementation."""
    torch.manual_seed(42)

    # Generate test data
    q = torch.randn(B, H, T, K, dtype=dtype, device=device)
    k = F.normalize(torch.randn(B, H, T, K, dtype=dtype, device=device), p=2, dim=-1).to(dtype)
    v = torch.randn(B, H, T, K, dtype=dtype, device=device)
    beta = torch.randn(B, H, T, dtype=dtype, device=device).sigmoid()
    scale = 1.0 / (K ** 0.5)

    # Define whether to output attention matrices
    output_attentions = True

    # Test forward pass
    o_parallel, attn_parallel = parallel_delta_rule(
        q=q.clone(),
        k=k.clone(),
        v=v.clone(),
        beta=beta.clone(),
        scale=scale,
        output_attentions=output_attentions
    )

    # Output should have the same shape as input v
    assert o_parallel.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel.shape}"
Comment on lines +67 to +69

🛠️ Refactor suggestion

Update the shape assertion to use the converted head-first output. After calling with seq-first inputs, o is [B, T, H, K]; convert it to head-first before comparing with v. Apply:

- # Output should have the same shape as input v
- assert o_parallel.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel.shape}"
+ # Output should have the same shape as input v (head-first)
+ assert o_parallel_hf.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel_hf.shape}"
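For context, a minimal sketch of what the suggested o_parallel_hf could look like inside the test, assuming the public API is seq-first as the docstring states; the einops.rearrange calls and the exact call pattern are illustrative, not part of this diff:

from einops import rearrange

# Inputs in this test are generated head-first as [B, H, T, K]; the public
# API is documented as seq-first [B, T, H, K], so transpose before the call.
o_parallel, attn_parallel = parallel_delta_rule(
    q=rearrange(q, 'b h t k -> b t h k'),
    k=rearrange(k, 'b h t k -> b t h k'),
    v=rearrange(v, 'b h t k -> b t h k'),
    beta=rearrange(beta, 'b h t -> b t h'),
    scale=scale,
    output_attentions=output_attentions,
)

# Convert the seq-first output back to head-first for the existing assertions.
o_parallel_hf = rearrange(o_parallel, 'b t h k -> b h t k')
assert o_parallel_hf.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel_hf.shape}"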

    # Check that attention matrix is produced if requested
    if output_attentions:
        assert attn_parallel is not None
        assert attn_parallel.shape == (B, H, T, T), f"Expected shape {(B, H, T, T)}, got {attn_parallel.shape}"
    else:
        assert attn_parallel is None

    o_naive, attn_naive = naive_delta_rule_parallel(q.clone(), k.clone(), v.clone(), beta.clone())

    assert_close(' o', o_parallel, o_naive, 0.01)
    assert_close('attn', attn_naive, attn_parallel, 0.01)
Comment on lines +77 to +80

🛠️ Refactor suggestion

Compare the correct (head-first) tensor against the naive implementation, so the numerical comparison is performed on head-first tensors. Apply:

- assert_close(' o', o_parallel, o_naive, 0.01)
+ assert_close(' o', o_parallel_hf, o_naive, 0.01)
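Given the in-file note that the naive implementation currently produces NaN values, it may also be worth guarding the comparison explicitly. A small illustrative check, not part of the diff (o_parallel_hf is the head-first conversion suggested above):

# Fail fast with a clear message if the naive reference degenerates to NaN,
# instead of letting assert_close report a confusing tolerance failure.
assert not torch.isnan(o_naive).any(), "naive_delta_rule_parallel produced NaN values"
assert_close(' o', o_parallel_hf, o_naive, 0.01)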

@pytest.mark.skipif(
    device_platform == 'intel',
    reason='Intel Triton Failure'
)
def test_fwd_prepare_T():
    """Test that fwd_prepare_T can be imported and runs without error."""
    torch.manual_seed(42)

    # Using head-first format [B, H, T, K] to match other functions
    B, H, T, K = 2, 4, 128, 64
    k = torch.randn(B, H, T, K, device=device)
    beta = torch.randn(B, H, T, device=device).sigmoid()
    chunk_size = 32

    # Test the function runs without error
    A = fwd_prepare_T(k, beta, chunk_size)

    # Check output shape
    # After our fix, fwd_prepare_T returns [B, H, T, chunk_size] (head-first format)
    expected_shape = (B, H, T, chunk_size)
    assert A.shape == expected_shape, f"Expected shape {expected_shape}, got {A.shape}"
💡 Verification agent

Clarify the tensor-format note; align the tests with the public API to avoid silent mismatches.

parallel_delta_rule's public API expects [B, T, H, K] (seq-first) per its docstring, and it emits a warning when passed head-first tensors while head_first=False. Feeding head-first tensors here risks masking format bugs behind warnings.

Recommend converting inputs to seq-first before calling, then converting the outputs back to head-first for the assertions. See the suggested changes on lines 57-66, 67-69, and 79-80.
Fix the tensor-format mismatch: the parallel_delta_rule docstring and behavior disagree with the implementation (head-first).

Short version: the public docstring claims seq-first ([B, T, H, ...]) inputs, but the kernels and the forward pass (ParallelDeltaRuleFunction) clearly operate on head-first tensors ([B, H, T, ...]). The tests pass head-first tensors and currently only trigger a warning, which can mask real format bugs. The API, implementation, and tests should be made consistent.

Files to change:

- fla/ops/delta_rule/parallel.py: the docstring says [B, T, H, ...], but the code (B, H, T, K = k.shape and all kernels) uses head-first [B, H, T, ...]. Either document the API as head-first and remove or adjust the misleading warning (e.g. replace raise DeprecationWarning(...) with warnings.warn("head_first is deprecated", DeprecationWarning)), or keep the seq-first API and convert internally: q_h = rearrange(q, 'b t h k -> b h t k') (and likewise for k_h, v_h, beta_h), with q_h, k_h, v_h, beta_h = q, k, v, beta when head_first=True.
- tests/ops/test_parallel_delta.py: exercise the public API shape rather than the internal one.

Why fix: leaving the mismatch and only warning risks silent bugs and confusion for users; the tests should either exercise the public API shape, or the public API should be documented and implemented to accept what callers expect. One possible shape for the conversion shim is sketched below.
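To make the recommendation concrete, a minimal sketch of such a shim, assuming the kernels stay head-first internally. ParallelDeltaRuleFunction is named in the review above, but its .apply signature, the head_first flag, and the einops usage here are assumptions for illustration, not the library's actual API:

import warnings

from einops import rearrange

def parallel_delta_rule(q, k, v, beta, scale, output_attentions=False, head_first=False):
    # Public API: seq-first [B, T, H, K]; kernels run head-first [B, H, T, K].
    if head_first:
        # Soft-deprecate instead of raising, so existing callers keep working.
        warnings.warn("head_first is deprecated", DeprecationWarning)
        q_h, k_h, v_h, beta_h = q, k, v, beta
    else:
        q_h = rearrange(q, 'b t h k -> b h t k')
        k_h = rearrange(k, 'b t h k -> b h t k')
        v_h = rearrange(v, 'b t h k -> b h t k')
        beta_h = rearrange(beta, 'b t h -> b h t')

    # Hypothetical autograd entry point; the real signature may differ.
    o, attn = ParallelDeltaRuleFunction.apply(q_h, k_h, v_h, beta_h, scale, output_attentions)

    # Hand back outputs in the same layout the caller used.
    if not head_first:
        o = rearrange(o, 'b h t k -> b t h k')
    return o, attn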