fla/ops/linear_attn/fused_chunk.py — 2 additions, 1 deletion

@@ -276,6 +276,7 @@ def fused_chunk_linear_attn(
     v: torch.Tensor,
     scale: Optional[float] = None,
     initial_state: torch.Tensor = None,
+    cum_k: torch.Tensor = None,
     output_final_state: bool = False,
     normalize: bool = True,
     head_first: bool = True
@@ -312,7 +313,7 @@ def fused_chunk_linear_attn(
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
    o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
    if normalize:
-        o = normalize_output(q * scale, k, o)
+        o = normalize_output(q * scale, k, o, cum_k)
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state
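
For readers of the diff, a minimal usage sketch (not part of the PR) of how the new cum_k argument could keep normalization consistent when a long sequence is processed in segments. The segment loop, shapes, and the caller-side update of cum_k are assumptions about intended usage, since the op itself does not return an updated key sum.

import torch
from fla.ops.linear_attn import fused_chunk_linear_attn

# Hypothetical example: split a sequence of length T into two segments and
# carry the running key sum so the normalization matches a single full call.
B, H, T, K, V = 2, 4, 128, 64, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')

state, cum_k, outs = None, None, []
for q_seg, k_seg, v_seg in zip(q.chunk(2, 2), k.chunk(2, 2), v.chunk(2, 2)):
    o_seg, state = fused_chunk_linear_attn(
        q_seg, k_seg, v_seg,
        initial_state=state,
        cum_k=cum_k,                 # sum of keys from earlier segments
        output_final_state=True,
        normalize=True,
        head_first=True,
    )
    outs.append(o_seg)
    # Caller-side bookkeeping (assumed): accumulate this segment's keys.
    k_sum = k_seg.sum(-2, keepdim=True)
    cum_k = k_sum if cum_k is None else cum_k + k_sum
o = torch.cat(outs, dim=2)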
fla/ops/linear_attn/fused_recurrent.py — 2 additions, 1 deletion

@@ -235,6 +235,7 @@ def fused_recurrent_linear_attn(
     v: torch.Tensor,
     scale: Optional[float] = None,
     initial_state: torch.Tensor = None,
+    cum_k: torch.Tensor = None,

Review comment (Member): @yiyousong I think we could make initial_state a Tuple if normalize is True. What do you think?

     output_final_state: bool = False,
     normalize: bool = False,

Review comment (Member): @yiyousong Could you please add some docstrings BTW.

     head_first: bool = True
@@ -245,7 +246,7 @@ def fused_recurrent_linear_attn(
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
    o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
    if normalize:
-        o = normalize_output(q * scale, k, o)
+        o = normalize_output(q * scale, k, o, cum_k)
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state
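
Similarly, a hedged sketch (not from the PR) of token-by-token decoding with fused_recurrent_linear_attn, where the caller threads both the recurrent state and cum_k through the loop; updating cum_k outside the kernel is an assumption about how the new argument is meant to be used.

import torch
from fla.ops.linear_attn import fused_recurrent_linear_attn

# Hypothetical decoding loop: one token per step; the caller maintains the
# recurrent state and the running key sum used for normalization.
B, H, K, V = 2, 4, 64, 64
state, cum_k = None, None
for _ in range(16):
    q_t = torch.randn(B, H, 1, K, device='cuda')
    k_t = torch.randn(B, H, 1, K, device='cuda')
    v_t = torch.randn(B, H, 1, V, device='cuda')
    o_t, state = fused_recurrent_linear_attn(
        q_t, k_t, v_t,
        initial_state=state,
        cum_k=cum_k,                  # keys seen before this step
        output_final_state=True,
        normalize=True,
        head_first=True,
    )
    cum_k = k_t if cum_k is None else cum_k + k_t  # update after the call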
fla/ops/linear_attn/utils.py — 5 additions, 2 deletions

@@ -1,10 +1,13 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

 import torch


-@torch.jit.script
-def normalize_output(q, k, o):
+@torch.compile
+def normalize_output(q, k, o, cum_k=None):

Review comment (Member): @yiyousong Maybe we could pass initial_state as an arg with cum_k included for API consistency.

Reply (Author): I only use the ops directly, so I prefer passing them in separately. That said, it's your code and your choice; I believe it doesn't matter as long as you don't merge them into one tensor.

    k = k.cumsum(-2)
+    if cum_k is not None:
+        k = k + cum_k
    z = (q * k).sum(-1, keepdim=True)
    return o / (z + 1e-10)
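
To make the effect of the new argument concrete, here is a small self-contained check (an illustration, not part of the PR) using a plain-PyTorch copy of the patched helper: normalizing the second half of a sequence with cum_k set to the first half's summed keys reproduces the full-sequence normalization.

import torch

def normalize_output_ref(q, k, o, cum_k=None):
    # Plain-PyTorch copy of the patched helper, for the check below.
    k = k.cumsum(-2)
    if cum_k is not None:
        k = k + cum_k
    z = (q * k).sum(-1, keepdim=True)
    return o / (z + 1e-10)

B, H, T, K, V = 1, 2, 8, 4, 4
q = torch.rand(B, H, T, K)   # non-negative, as after a positive feature map
k = torch.rand(B, H, T, K)
o = torch.randn(B, H, T, V)

full = normalize_output_ref(q, k, o)

t = T // 2
first = normalize_output_ref(q[:, :, :t], k[:, :, :t], o[:, :, :t])
second = normalize_output_ref(
    q[:, :, t:], k[:, :, t:], o[:, :, t:],
    cum_k=k[:, :, :t].sum(-2, keepdim=True),
)
assert torch.allclose(torch.cat([first, second], dim=2), full, atol=1e-5)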