Commit dd057f3

make sure to mask out padding tokens in mlp attention, at the attention processing stage
1 parent 5dfc465

File tree

2 files changed (+18, -4 lines changed)
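
Both attention paths touched below follow the same masking pattern: when a self-interaction slot is prepended to the keys/values, the boolean neighbor mask is one column short, so it is left-padded with True (the self slot is always valid) before the attention logits are filled with a large negative value and softmaxed. A minimal sketch of that idea, with illustrative shapes rather than the library's actual tensors:

import torch
import torch.nn.functional as F

b, i, j = 2, 4, 3                          # batch, queries, neighbors (illustrative)
neighbor_mask = torch.rand(b, i, j) < 0.8  # bool; False marks padded neighbors

attend_self = True
if attend_self:
    # prepend a True column so the self-interaction slot is never masked out
    neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)  # (b, i, j + 1)

sim = torch.randn(b, i, j + 1)             # attention logits
sim = sim.masked_fill(~neighbor_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)               # padded slots receive ~zero weight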

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 17 additions & 3 deletions
@@ -610,6 +610,9 @@ def forward(
         if exists(neighbor_mask):
             neighbor_mask = rearrange(neighbor_mask, 'b i j -> b 1 i j')
 
+            if self.attend_self:
+                neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)
+
         features = self.prenorm(features)
 
         queries = self.to_q(features)
@@ -651,9 +654,6 @@ def forward(
         if not is_degree_zero:
             sim = sim.sum(dim = -1)
 
-        if exists(neighbor_mask):
-            left_pad_needed = int(self.attend_self)
-            padded_neighbor_mask = F.pad(neighbor_mask, (left_pad_needed, 0), value = True)
             sim = sim.masked_fill(~padded_neighbor_mask, -torch.finfo(sim.dtype).max)
 
         attn = sim.softmax(dim = -1)
@@ -698,6 +698,8 @@ def __init__(
         self.single_headed_kv = single_headed_kv
         value_hidden_fiber = hidden_fiber if not single_headed_kv else dim_head
 
+        self.attend_self = attend_self
+
         self.scale = tuple(dim ** -0.5 for dim in dim_head)
         self.heads = heads
 
@@ -766,6 +768,14 @@ def forward(
     ):
         one_headed_kv = self.single_headed_kv
 
+        _, neighbor_mask, _ = edge_info
+
+        if exists(neighbor_mask):
+            if self.attend_self:
+                neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)
+
+            neighbor_mask = rearrange(neighbor_mask, '... -> ... 1')
+
         features = self.prenorm(features)
 
         intermediate = self.to_attn_and_v(
@@ -788,6 +798,10 @@ def forward(
             attn_intermediate = rearrange(attn_intermediate, '... 1 -> ...')
             attn_logits = fn(attn_intermediate)
             attn_logits = attn_logits * scale
+
+            if exists(neighbor_mask):
+                attn_logits = attn_logits.masked_fill(~neighbor_mask, -torch.finfo(attn_logits.dtype).max)
+
             attn = attn_logits.softmax(dim = -2) # (batch, source, target, heads)
             attentions.append(attn)
 
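
The MLP attention hunks above additionally give the mask a trailing singleton axis, so that one (batch, source, target) mask broadcasts across every head's logits, with the fill applied before the softmax over the target axis. A sketch of that broadcasting, again with illustrative shapes:

import torch
import torch.nn.functional as F
from einops import rearrange

b, i, j, heads = 2, 4, 3, 8
neighbor_mask = torch.rand(b, i, j) < 0.8      # False marks padded neighbors

# left-pad for the self-interaction slot, as above, then add a trailing axis for heads
neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)  # (b, i, j + 1)
neighbor_mask = rearrange(neighbor_mask, '... -> ... 1')    # (b, i, j + 1, 1)

attn_logits = torch.randn(b, i, j + 1, heads)  # (batch, source, target, heads)
attn_logits = attn_logits.masked_fill(~neighbor_mask, -torch.finfo(attn_logits.dtype).max)
attn = attn_logits.softmax(dim = -2)           # normalize over the target axis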

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.3.7'
+__version__ = '0.3.8'
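
For context, a usage sketch in the spirit of the project README (constructor arguments here are recalled from the README and may differ by version; `l2_dist_attention = False` selects the MLP attention path this commit masks):

import torch
from equiformer_pytorch import Equiformer

model = Equiformer(
    dim = (4, 4, 2),               # dimensions per degree (type-0, type-1, type-2)
    num_degrees = 3,
    depth = 2,
    attend_self = True,
    l2_dist_attention = False      # use MLP attention
)

feats = torch.randn(1, 32, 4)
coors = torch.randn(1, 32, 3)
mask  = torch.ones(1, 32).bool()   # False entries are padding, now excluded from attention

out = model(feats, coors, mask = mask)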
