@@ -610,6 +610,9 @@ def forward(
610610 if exists (neighbor_mask ):
611611 neighbor_mask = rearrange (neighbor_mask , 'b i j -> b 1 i j' )
612612
613+ if self .attend_self :
614+ neighbor_mask = F .pad (neighbor_mask , (1 , 0 ), value = True )
615+
613616 features = self .prenorm (features )
614617
615618 queries = self .to_q (features )
@@ -651,9 +654,6 @@ def forward(
651654 if not is_degree_zero :
652655 sim = sim .sum (dim = - 1 )
653656
654- if exists (neighbor_mask ):
655- left_pad_needed = int (self .attend_self )
656- padded_neighbor_mask = F .pad (neighbor_mask , (left_pad_needed , 0 ), value = True )
657657 sim = sim .masked_fill (~ padded_neighbor_mask , - torch .finfo (sim .dtype ).max )
658658
659659 attn = sim .softmax (dim = - 1 )
@@ -698,6 +698,8 @@ def __init__(
698698 self .single_headed_kv = single_headed_kv
699699 value_hidden_fiber = hidden_fiber if not single_headed_kv else dim_head
700700
701+ self .attend_self = attend_self
702+
701703 self .scale = tuple (dim ** - 0.5 for dim in dim_head )
702704 self .heads = heads
703705
@@ -766,6 +768,14 @@ def forward(
766768 ):
767769 one_headed_kv = self .single_headed_kv
768770
771+ _ , neighbor_mask , _ = edge_info
772+
773+ if exists (neighbor_mask ):
774+ if self .attend_self :
775+ neighbor_mask = F .pad (neighbor_mask , (1 , 0 ), value = True )
776+
777+ neighbor_mask = rearrange (neighbor_mask , '... -> ... 1' )
778+
769779 features = self .prenorm (features )
770780
771781 intermediate = self .to_attn_and_v (
@@ -788,6 +798,10 @@ def forward(
788798 attn_intermediate = rearrange (attn_intermediate , '... 1 -> ...' )
789799 attn_logits = fn (attn_intermediate )
790800 attn_logits = attn_logits * scale
801+
802+ if exists (neighbor_mask ):
803+ attn_logits = attn_logits .masked_fill (~ neighbor_mask , - torch .finfo (attn_logits .dtype ).max )
804+
791805 attn = attn_logits .softmax (dim = - 2 ) # (batch, source, target, heads)
792806 attentions .append (attn )
793807
0 commit comments