
Commit 89284e2

upgrade linear attention

1 parent dbc16f4

3 files changed: 17 additions, 7 deletions

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 15 additions & 6 deletions
@@ -10,6 +10,8 @@
 from torch import nn, is_tensor, Tensor
 import torch.nn.functional as F
 
+from taylor_series_linear_attention import TaylorSeriesLinearAttn
+
 from opt_einsum import contract as opt_einsum
 
 from equiformer_pytorch.basis import (
@@ -550,6 +552,7 @@ def __init__(
         single_headed_kv = False,
         radial_hidden_dim = 64,
         splits = 4,
+        linear_attn_dim_head = 8,
         num_linear_attn_heads = 0,
         init_out_zero = True,
         gate_attn_head_outputs = True
@@ -585,8 +588,8 @@ def __init__(
 
         if self.has_linear_attn:
             degree_zero_dim = fiber[0]
-            self.linear_attn = LinearAttention(degree_zero_dim, dim_head = dim_head[0], heads = num_linear_attn_heads)
-            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + dim_head[0] * num_linear_attn_heads)
+            self.linear_attn = TaylorSeriesLinearAttn(degree_zero_dim, dim_head = linear_attn_dim_head, heads = num_linear_attn_heads, combine_heads = False)
+            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + linear_attn_dim_head * num_linear_attn_heads)
 
         # gating heads across all degree outputs
         # to allow for attending to nothing
@@ -691,7 +694,9 @@ def forward(
             outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')
 
         if self.has_linear_attn:
-            lin_attn_out = self.linear_attn(features[0], mask = mask)
+            linear_attn_input = rearrange(features[0], '... 1 -> ...')
+            lin_attn_out = self.linear_attn(linear_attn_input, mask = mask)
+            lin_attn_out = rearrange(lin_attn_out, '... -> ... 1')
             outputs[0] = torch.cat((outputs[0], lin_attn_out), dim = -2)
 
         return self.to_out(outputs)
@@ -710,6 +715,7 @@ def __init__(
         attn_leakyrelu_slope = 0.1,
         attn_hidden_dim_mult = 4,
         radial_hidden_dim = 16,
+        linear_attn_dim_head = 8,
         num_linear_attn_heads = 0,
         init_out_zero = True,
         gate_attn_head_outputs = True,
@@ -777,8 +783,8 @@ def __init__(
 
         if self.has_linear_attn:
             degree_zero_dim = fiber[0]
-            self.linear_attn = LinearAttention(degree_zero_dim, dim_head = dim_head[0], heads = num_linear_attn_heads)
-            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + dim_head[0] * num_linear_attn_heads)
+            self.linear_attn = TaylorSeriesLinearAttn(degree_zero_dim, dim_head = linear_attn_dim_head, heads = num_linear_attn_heads, combine_heads = False)
+            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + linear_attn_dim_head * num_linear_attn_heads)
 
         # gating heads across all degree outputs
         # to allow for attending to nothing
@@ -881,7 +887,10 @@ def forward(
         # linear attention
 
         if self.has_linear_attn:
-            lin_attn_out = self.linear_attn(features[0], mask = mask)
+            linear_attn_input = rearrange(features[0], '... 1 -> ...')
+            lin_attn_out = self.linear_attn(linear_attn_input, mask = mask)
+            lin_attn_out = rearrange(lin_attn_out, '... -> ... 1')
+
             outputs[0] = torch.cat((outputs[0], lin_attn_out), dim = -2)
 
         # combine heads out

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.4.0'
+__version__ = '0.5.0'

setup.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
     'einops>=0.6',
     'filelock',
     'opt-einsum',
+    'taylor-series-linear-attention>=0.0.11',
     'torch>=1.6',
   ],
   setup_requires=[
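The dependency pinned here provides the TaylorSeriesLinearAttn module used above. Not part of the diff, but a quick sketch of why the degree-0 hidden fiber is widened by linear_attn_dim_head * num_linear_attn_heads: with the library's default settings the heads are combined and projected back to the input dim, whereas this commit passes combine_heads = False, which (as inferred from that fiber bookkeeping) leaves the raw per-head channels in the output. Sizes below are illustrative.

import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn

dim, dim_head, heads = 64, 8, 4
x = torch.randn(1, 16, dim)

# default behavior: heads are combined and projected back to `dim`
attn_combined = TaylorSeriesLinearAttn(dim, dim_head = dim_head, heads = heads)
print(attn_combined(x).shape)      # torch.Size([1, 16, 64])

# as used in this commit: keep per-head outputs, so downstream bookkeeping
# adds dim_head * heads (= 32 here) channels to the degree-0 fiber
attn_per_head = TaylorSeriesLinearAttn(dim, dim_head = dim_head, heads = heads, combine_heads = False)
print(attn_per_head(x).shape)      # inferred: torch.Size([1, 16, 32])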
