@@ -551,7 +551,8 @@ def __init__(
         radial_hidden_dim = 64,
         splits = 4,
         num_linear_attn_heads = 0,
-        init_out_zero = True
+        init_out_zero = True,
+        gate_attn_head_outputs = True
     ):
         super().__init__()
         num_degrees = len(fiber)
@@ -587,6 +588,21 @@ def __init__(
             self.linear_attn = LinearAttention(degree_zero_dim, dim_head = dim_head[0], heads = num_linear_attn_heads)
             hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + dim_head[0] * num_linear_attn_heads)

+        # gating heads across all degree outputs
+        # to allow for attending to nothing
+
+        self.attn_head_gates = None
+
+        if gate_attn_head_outputs:
+            self.attn_head_gates = nn.Sequential(
+                Rearrange('... d 1 -> ... d'),
+                nn.Linear(fiber[0], sum(heads)),
+                nn.Sigmoid(),
+                Rearrange('... n h -> ... h n 1 1')
+            )
+
+        # combine heads
+
         self.to_out = Linear(hidden_fiber, fiber)

         if init_out_zero:
@@ -615,6 +631,8 @@ def forward(

         features = self.prenorm(features)

+        # generate queries, keys, values
+
         queries = self.to_q(features)

         keyvalues = self.to_kv(
@@ -625,11 +643,20 @@ def forward(
             D = D
         )

+        # create gates
+
+        gates = (None,) * len(self.heads)
+
+        if exists(self.attn_head_gates):
+            gates = self.attn_head_gates(features[0]).split(self.heads, dim = -4)
+
+        # single headed vs not
+
         kv_einsum_eq = 'b h i j d m' if not one_head_kv else 'b i j d m'

         outputs = {}

-        for degree, h, scale in zip(features.keys(), self.heads, self.scale):
+        for degree, gate, h, scale in zip(features.keys(), gates, self.heads, self.scale):
             is_degree_zero = degree == 0

             q, kv = map(lambda t: t[degree], (queries, keyvalues))
@@ -657,6 +684,10 @@ def forward(

             attn = sim.softmax(dim = -1)
             out = einsum(attn, v, f'b h i j, {kv_einsum_eq} -> b h i d m')
+
+            if exists(gate):
+                out = out * gate
+
             outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

         if self.has_linear_attn:
@@ -681,6 +712,7 @@ def __init__(
         radial_hidden_dim = 16,
         num_linear_attn_heads = 0,
         init_out_zero = True,
+        gate_attn_head_outputs = True,
         **kwargs
     ):
         super().__init__()
@@ -748,6 +780,19 @@ def __init__(
             self.linear_attn = LinearAttention(degree_zero_dim, dim_head = dim_head[0], heads = num_linear_attn_heads)
             hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + dim_head[0] * num_linear_attn_heads)

+        # gating heads across all degree outputs
+        # to allow for attending to nothing
+
+        self.attn_head_gates = None
+
+        if gate_attn_head_outputs:
+            self.attn_head_gates = nn.Sequential(
+                Rearrange('... d 1 -> ... d'),
+                nn.Linear(fiber[0], sum(heads)),
+                nn.Sigmoid(),
+                Rearrange('... h -> ... h 1 1')
+            )
+
         # combining heads and projection out

         self.to_out = Linear(hidden_fiber, fiber)
@@ -789,6 +834,13 @@ def forward(

         intermediate[0] = value_branch_type0

+        # create gates
+
+        gates = (None,) * len(self.heads)
+
+        if exists(self.attn_head_gates):
+            gates = self.attn_head_gates(features[0]).split(self.heads, dim = -3)
+
         # process the attention branch

         attentions = []
@@ -814,11 +866,15 @@ def forward(

         value_einsum_eq = 'b i j h d m' if not one_headed_kv else 'b i j d m'

-        for degree, (attn, value, h) in enumerate(zip(attentions, values.values(), self.heads)):
+        for degree, (attn, value, gate, h) in enumerate(zip(attentions, values.values(), gates, self.heads)):
             if not one_headed_kv:
                 value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)

             out = einsum(attn, value, f'b i j h, {value_einsum_eq} -> b i h d m')
+
+            if exists(gate):
+                out = out * gate
+
             out = rearrange(out, 'b i h d m -> b i (h d) m')
             outputs[degree] = out

@@ -863,6 +919,7 @@ def __init__(
         l2_dist_attention = True,          # turn to False to use MLP attention as proposed in paper, but dot product attention with -cdist similarity is still far better, and i haven't even rotated distances (rotary embeddings) into the type 0 features yet
         reversible = False,                # turns on reversible networks, to scale depth without incurring depth times memory cost
         attend_sparse_neighbors = False,   # ability to accept an adjacency matrix
+        gate_attn_head_outputs = True,     # gate each attention head output, to allow for attending to nothing
         num_adj_degrees_embed = None,
         adj_dim = 0,
         max_sparse_neighbors = float('inf'),
@@ -958,6 +1015,7 @@ def __init__(
                     edge_dim = edge_dim,
                     single_headed_kv = single_headed_kv,
                     radial_hidden_dim = radial_hidden_dim,
+                    gate_attn_head_outputs = gate_attn_head_outputs,
                     **kwargs
                 ),
                 FeedForward(self.dim, include_htype_norms = ff_include_htype_norms)
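
The core of the change, shown in isolation for readers skimming the diff: a minimal standalone sketch (made-up dimensions and helper names, not the library's API) of how a sigmoid gate predicted from the invariant degree-0 features scales each attention head's output, so a head whose gate falls toward zero effectively attends to nothing.

    # hypothetical toy example of per-head output gating
    import torch
    from torch import nn

    batch, seq, dim0, heads, dim_head = 2, 16, 32, 4, 8

    # degree-0 (invariant) node features, shaped (batch, seq, dim0, 1) as in the diff
    feats0 = torch.randn(batch, seq, dim0, 1)

    # per-head attention outputs for one degree: (batch, heads, seq, dim_head, m)
    out = torch.randn(batch, heads, seq, dim_head, 1)

    # gate: one scalar in (0, 1) per node per head, predicted from the invariant features
    to_gates = nn.Sequential(
        nn.Linear(dim0, heads),
        nn.Sigmoid()
    )

    gates = to_gates(feats0.squeeze(-1))             # (batch, seq, heads)
    gates = gates.permute(0, 2, 1)[..., None, None]  # (batch, heads, seq, 1, 1)

    gated_out = out * gates                          # heads with gate near 0 contribute nothing

The diff applies the same idea per degree, splitting one gate projection across all heads and broadcasting it over the feature and representation dimensions before the heads are combined.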