diff --git a/fla/ops/delta_rule/wy_fast.py b/fla/ops/delta_rule/wy_fast.py
index 401ca842c..aae5ff3f4 100644
--- a/fla/ops/delta_rule/wy_fast.py
+++ b/fla/ops/delta_rule/wy_fast.py
@@ -292,3 +292,46 @@ def prepare_wy_repr_bwd(
 bwd_prepare_wy_repr = prepare_wy_repr_bwd
 
 fwd_recompute_w_u = recompute_w_u_fwd
+
+
+def fwd_prepare_T(
+    k: torch.Tensor,
+    beta: torch.Tensor,
+    chunk_size: int,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+) -> torch.Tensor:
+    """
+    Prepare the transformation matrix T (A) for delta rule computation.
+
+    This function computes, blockwise within each chunk, the matrix
+    A = (I - tril(beta * K * K^T))^{-1} used by the parallel delta rule algorithm.
+
+    Args:
+        k: Key tensor of shape [B, H, T, K] (head-first format)
+        beta: Beta weights of shape [B, H, T] (head-first format)
+        chunk_size: Size of chunks for processing
+        cu_seqlens: Optional cumulative sequence lengths for variable-length sequences
+
+    Returns:
+        A: Transformation matrix of shape [B, H, T, chunk_size]
+    """
+    # Convert from head-first [B, H, T, K] to seq-first [B, T, H, K]
+    k_seq_first = k.transpose(1, 2).contiguous()
+    beta_seq_first = beta.transpose(1, 2).contiguous()
+
+    A = chunk_scaled_dot_kkt_fwd(
+        k=k_seq_first,
+        beta=beta_seq_first,
+        cu_seqlens=cu_seqlens,
+        chunk_size=chunk_size,
+        output_dtype=torch.float32,
+    )
+    A = solve_tril(
+        A=A,
+        cu_seqlens=cu_seqlens,
+        output_dtype=k.dtype
+    )
+
+    # Convert back from [B, T, H, chunk_size] to [B, H, T, chunk_size]
+    A = A.transpose(1, 2).contiguous()
+    return A
diff --git a/tests/ops/test_parallel_delta.py b/tests/ops/test_parallel_delta.py
new file mode 100644
index 000000000..8666b9f31
--- /dev/null
+++ b/tests/ops/test_parallel_delta.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from fla.ops.delta_rule.parallel import naive_delta_rule_parallel, parallel_delta_rule
+from fla.ops.delta_rule.wy_fast import fwd_prepare_T
+from fla.utils import assert_close, device, device_platform
+
+# IMPORTANT NOTE ON TENSOR FORMATS:
+# While the documentation for some functions states inputs should be in [B, T, H, K] format,
+# the actual implementation expects [B, H, T, K] format (head-first).
+# All tests in this file use the head-first format to match the actual implementation.
+
+# NOTE ON TEST IMPLEMENTATION:
+# The main test compares parallel_delta_rule against naive_delta_rule_parallel with a
+# relaxed tolerance and additionally checks the output and attention-matrix shapes.
+# The naive reference has been observed to produce NaN values; this will be addressed
+# in a future update.
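+
+# For reference, converting between the two layouts is just a transpose of dims 1 and 2:
+# a head-first tensor of shape [B, H, T, K] becomes seq-first [B, T, H, K] via
+# `x.transpose(1, 2).contiguous()`, which is what fwd_prepare_T does internally before
+# dispatching to the seq-first kernels.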
+
+
+@pytest.mark.parametrize(
+    ('B', 'H', 'T', 'K', 'dtype'),
+    [
+        pytest.param(*test, id="B{}-H{}-T{}-K{}-{}".format(*test))
+        for test in [
+            (1, 2, 128, 64, torch.float16),
+            (2, 4, 128, 32, torch.float16),
+        ]
+    ]
+)
+@pytest.mark.skipif(
+    device_platform == 'intel',
+    reason='Intel Triton Failure'
+)
+def test_parallel_delta_rule(
+    B: int,
+    H: int,
+    T: int,
+    K: int,
+    dtype: torch.dtype,
+):
+    """Test parallel_delta_rule against the naive implementation."""
+    torch.manual_seed(42)
+
+    # Generate test data
+    q = torch.randn(B, H, T, K, dtype=dtype, device=device)
+    k = F.normalize(torch.randn(B, H, T, K, dtype=dtype, device=device), p=2, dim=-1).to(dtype)
+    v = torch.randn(B, H, T, K, dtype=dtype, device=device)
+    beta = torch.randn(B, H, T, dtype=dtype, device=device).sigmoid()
+    scale = 1.0 / (K ** 0.5)
+
+    # Define whether to output attention matrices
+    output_attentions = True
+
+    # Test forward pass
+    o_parallel, attn_parallel = parallel_delta_rule(
+        q=q.clone(),
+        k=k.clone(),
+        v=v.clone(),
+        beta=beta.clone(),
+        scale=scale,
+        output_attentions=output_attentions
+    )
+
+    # Output should have the same shape as input v
+    assert o_parallel.shape == v.shape, f"Expected shape {v.shape}, got {o_parallel.shape}"
+
+    # Check that the attention matrix is produced if requested
+    if output_attentions:
+        assert attn_parallel is not None
+        assert attn_parallel.shape == (B, H, T, T), f"Expected shape {(B, H, T, T)}, got {attn_parallel.shape}"
+    else:
+        assert attn_parallel is None
+
+    o_naive, attn_naive = naive_delta_rule_parallel(q.clone(), k.clone(), v.clone(), beta.clone())
+
+    assert_close(' o', o_naive, o_parallel, 0.01)
+    assert_close('attn', attn_naive, attn_parallel, 0.01)
+
+
+@pytest.mark.skipif(
+    device_platform == 'intel',
+    reason='Intel Triton Failure'
+)
+def test_fwd_prepare_T():
+    """Test that fwd_prepare_T runs without error and returns the expected shape."""
+    torch.manual_seed(42)
+
+    # Using head-first format [B, H, T, K] to match other functions
+    B, H, T, K = 2, 4, 128, 64
+    k = torch.randn(B, H, T, K, device=device)
+    beta = torch.randn(B, H, T, device=device).sigmoid()
+    chunk_size = 32
+
+    # Test the function runs without error
+    A = fwd_prepare_T(k, beta, chunk_size)
+
+    # Check output shape
+    # fwd_prepare_T returns [B, H, T, chunk_size] (head-first format)
+    expected_shape = (B, H, T, chunk_size)
+    assert A.shape == expected_shape, f"Expected shape {expected_shape}, got {A.shape}"
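+
+    # Additional structural sanity check: if, as in the standard WY/UT construction,
+    # the beta-scaled K * K^T term is strictly lower triangular within each chunk, then
+    # every [chunk_size, chunk_size] block of A is the inverse of a unit-diagonal
+    # lower-triangular matrix, so its own diagonal should be all ones up to float error.
+    # This assumes T is divisible by chunk_size (true for the sizes above) and that the
+    # unit diagonal is stored explicitly rather than kept implicit.
+    num_chunks = T // chunk_size
+    blocks = A.view(B, H, num_chunks, chunk_size, chunk_size)
+    diag = torch.diagonal(blocks, dim1=-2, dim2=-1)
+    assert torch.allclose(diag, torch.ones_like(diag), atol=1e-3), \
+        "expected a unit diagonal in every per-chunk block of A"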