 from torch import nn, is_tensor, Tensor
 import torch.nn.functional as F
 
+from taylor_series_linear_attention import TaylorSeriesLinearAttn
+
 from opt_einsum import contract as opt_einsum
 
 from equiformer_pytorch.basis import (
@@ -550,6 +552,7 @@ def __init__(
         single_headed_kv = False,
         radial_hidden_dim = 64,
         splits = 4,
+        linear_attn_dim_head = 8,
         num_linear_attn_heads = 0,
         init_out_zero = True,
         gate_attn_head_outputs = True
@@ -585,8 +588,8 @@ def __init__(
 
         if self.has_linear_attn:
             degree_zero_dim = fiber[0]
-            self.linear_attn = LinearAttention(degree_zero_dim, dim_head = dim_head[0], heads = num_linear_attn_heads)
-            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + dim_head[0] * num_linear_attn_heads)
+            self.linear_attn = TaylorSeriesLinearAttn(degree_zero_dim, dim_head = linear_attn_dim_head, heads = num_linear_attn_heads, combine_heads = False)
+            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + linear_attn_dim_head * num_linear_attn_heads)
 
         # gating heads across all degree outputs
         # to allow for attending to nothing
@@ -691,7 +694,9 @@ def forward(
             outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')
 
         if self.has_linear_attn:
-            lin_attn_out = self.linear_attn(features[0], mask = mask)
+            linear_attn_input = rearrange(features[0], '... 1 -> ...')
+            lin_attn_out = self.linear_attn(linear_attn_input, mask = mask)
+            lin_attn_out = rearrange(lin_attn_out, '... -> ... 1')
             outputs[0] = torch.cat((outputs[0], lin_attn_out), dim = -2)
 
         return self.to_out(outputs)
@@ -710,6 +715,7 @@ def __init__(
         attn_leakyrelu_slope = 0.1,
         attn_hidden_dim_mult = 4,
         radial_hidden_dim = 16,
+        linear_attn_dim_head = 8,
         num_linear_attn_heads = 0,
         init_out_zero = True,
         gate_attn_head_outputs = True,
@@ -777,8 +783,8 @@ def __init__(
 
         if self.has_linear_attn:
             degree_zero_dim = fiber[0]
-            self.linear_attn = LinearAttention(degree_zero_dim, dim_head = dim_head[0], heads = num_linear_attn_heads)
-            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + dim_head[0] * num_linear_attn_heads)
+            self.linear_attn = TaylorSeriesLinearAttn(degree_zero_dim, dim_head = linear_attn_dim_head, heads = num_linear_attn_heads, combine_heads = False)
+            hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + linear_attn_dim_head * num_linear_attn_heads)
 
         # gating heads across all degree outputs
         # to allow for attending to nothing
@@ -881,7 +887,10 @@ def forward(
         # linear attention
 
         if self.has_linear_attn:
-            lin_attn_out = self.linear_attn(features[0], mask = mask)
+            linear_attn_input = rearrange(features[0], '... 1 -> ...')
+            lin_attn_out = self.linear_attn(linear_attn_input, mask = mask)
+            lin_attn_out = rearrange(lin_attn_out, '... -> ... 1')
+
             outputs[0] = torch.cat((outputs[0], lin_attn_out), dim = -2)
 
         # combine heads out
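
# Standalone sketch (not part of the commit) of the degree-0 linear attention
# path introduced above, assuming the taylor-series-linear-attention and
# einops packages are installed. Shapes mirror the diff: degree-0 features
# carry a trailing m = 1 axis that is dropped before attention and restored
# afterwards; with combine_heads = False the output feature dimension is
# heads * dim_head, which is what the hidden_fiber update accounts for.
# The dimensions below are illustrative, not library defaults.

import torch
from einops import rearrange
from taylor_series_linear_attention import TaylorSeriesLinearAttn

degree_zero_dim = 32
linear_attn_dim_head = 8
num_linear_attn_heads = 4

linear_attn = TaylorSeriesLinearAttn(
    degree_zero_dim,
    dim_head = linear_attn_dim_head,
    heads = num_linear_attn_heads,
    combine_heads = False
)

feats = torch.randn(2, 128, degree_zero_dim, 1)  # degree-0 features: (b, n, d, m = 1)
mask = torch.ones(2, 128).bool()                 # node mask: (b, n)

x = rearrange(feats, '... 1 -> ...')             # (b, n, d)
out = linear_attn(x, mask = mask)                # (b, n, heads * dim_head)
out = rearrange(out, '... -> ... 1')             # (b, n, heads * dim_head, 1), ready to concat on dim = -2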