-
Notifications
You must be signed in to change notification settings - Fork 256
[Linear Attention] Update fused_recurrent.py for inference with normalization=True #268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
287ffae
30d5606
8ec615a
3cb3c2a
d8965f2
1c6ea0c
b5d64ba
884a597
4c4c68c
20adb41
f050482
246f17c
c6ce801
ed6e92c
3ee20dd
8806758
565bbb8
1db56cb
104a2e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -235,6 +235,7 @@ def fused_recurrent_linear_attn( | |
v: torch.Tensor, | ||
scale: Optional[float] = None, | ||
initial_state: torch.Tensor = None, | ||
cum_k: torch.Tensor = None, | ||
output_final_state: bool = False, | ||
normalize: bool = False, | ||
|
||
head_first: bool = True | ||
|
@@ -245,7 +246,7 @@ def fused_recurrent_linear_attn( | |
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) | ||
o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) | ||
if normalize: | ||
o = normalize_output(q * scale, k, o) | ||
o = normalize_output(q * scale, k, o, cum_k) | ||
if not head_first: | ||
o = o.transpose(1, 2) | ||
return o, final_state |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,13 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang | ||
|
||
import torch | ||
|
||
|
||
@torch.jit.script | ||
def normalize_output(q, k, o): | ||
@torch.compile | ||
def normalize_output(q, k, o, cum_k=None): | ||
|
||
k = k.cumsum(-2) | ||
yzhangcs marked this conversation as resolved.
Show resolved
Hide resolved
yzhangcs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if cum_k is not None: | ||
k = k + cum_k | ||
z = (q * k).sum(-1, keepdim=True) | ||
return o / (z + 1e-10) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yiyousong I think we could make initial_state a Tuple if normalize is True. What do you think?