
Commit 3011ad3

using negative euclidean distances is far better than the previous dot product or MLP attention
1 parent 6343697

File tree

3 files changed, +10 -8 lines changed

README.md

Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,8 @@ Update: The choice of the norm or gating (still need to ablate to figure out whi
 
 Update: Nevermind, MLP attention seems to be working, but about the same as dot product attention.
 
+Update: By using the negative of the euclidean distance, in place of the dot product, for higher types in dot product attention, I now see results that are far better than both the previous dot product attention and MLP attention. My conclusion is that the choice of norm and gating contributes far more to the results in the paper than MLP attention does.
+
 <a href="https://wandb.ai/lucidrains/equiformer/reports/equiformer-and-mlp-attention---VmlldzozMDQwMTY3?accessToken=xmj0a1c80m8hehylrmbr0hndka8kk1vxmdrmvtmy7r1qgphtnuhq1643cb76zgfo">Running experiment, denoising residue positions in protein sequence</a>
 
 ## Install
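
To make the update above concrete, here is a minimal standalone sketch (toy tensor shapes, not the library's API) contrasting the usual dot-product similarity with the negative-euclidean-distance similarity used for higher-type features:

```python
import torch

# toy shapes: batch, heads, queries i, neighbors j, head dim d, representation index m = 2 * degree + 1
b, h, i, j, d, m = 1, 2, 4, 6, 8, 3

q = torch.randn(b, h, i, d, m)
k = torch.randn(b, h, i, j, d, m)   # keys already gathered per query neighborhood
scale = d ** -0.5

# standard dot-product similarity (still used for type-0 / degree-0 features)
sim_dot = torch.einsum('bhidm,bhijdm->bhij', q, k) * scale

# negative euclidean distance similarity for higher degrees:
# distance over the representation axis m, then summed over the head dimension d
q_b = q.unsqueeze(3)                                               # b h i 1 d m, broadcasts against the keys
dist = ((q_b - k) ** 2).sum(dim = -1).clamp(min = 1e-5).sqrt()     # b h i j d
sim_cdist = -dist.sum(dim = -1) * scale                            # b h i j

print(sim_dot.shape, sim_cdist.shape)   # both torch.Size([1, 2, 4, 6])
```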

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 7 additions & 7 deletions

@@ -87,6 +87,9 @@ def feature_shapes(feature):
 def feature_fiber(feature):
     return tuple(v.shape[-2] for v in feature.values())
 
+def cdist(a, b, dim = -1, eps = 1e-5):
+    return ((a - b) ** 2).sum(dim = dim).clamp(min = eps).sqrt()
+
 # classes
 
 class Residual(nn.Module):
@@ -437,7 +440,6 @@ def __init__(
         edge_dim = None,
         single_headed_kv = False,
         radial_hidden_dim = 64,
-        use_cdist_sim = True,
         splits = 4
     ):
         super().__init__()
@@ -453,7 +455,6 @@ def __init__(
 
         self.single_headed_kv = single_headed_kv
         self.attend_self = attend_self
-        self.use_cdist_sim = use_cdist_sim
 
         kv_hidden_fiber = hidden_fiber if not single_headed_kv else dim_head
         kv_hidden_fiber = tuple(dim * 2 for dim in kv_hidden_fiber)
@@ -509,14 +510,14 @@ def forward(
 
             k, v = kv.chunk(2, dim = -2)
 
-            if degree == 0 or not self.use_cdist_sim:
+            if degree == 0:
                 sim = einsum(f'b h i d m, {kv_einsum_eq} -> b h i j', q, k) * scale
             else:
                 if one_head_kv:
                     k = repeat(k, 'b i j d m -> b h i j d m', h = h)
 
                 q = rearrange(q, 'b h i d m -> b h i 1 d m')
-                sim = -((q - k) ** 2).sum(dim = -1).clamp(min = 1e-5).sqrt().sum(dim = -1) * scale
+                sim = -cdist(q, k).sum(dim = -1) * scale
 
             if exists(neighbor_mask):
                 left_pad_needed = int(self.attend_self)
@@ -688,8 +689,7 @@ def __init__(
         embedding_grad_frac = 0.5,
         single_headed_kv = False, # whether to do single headed key/values for dot product attention, to save on memory and compute
         ff_include_htype_norms = False, # whether for type0 projection to also involve norms of all higher types, in feedforward first projection. this allows for all higher types to be gated by other type norms
-        dot_product_attention = True,
-        dot_product_attention_use_cdist_sim = True,
+        dot_product_attention = True, # turn to False to use MLP attention as proposed in paper, but dot product attention with -cdist similarity is still far better, and i haven't even rotated distances (rotary embeddings) into the type 0 features yet
         **kwargs
     ):
         super().__init__()
@@ -761,7 +761,7 @@ def __init__(
 
         self.layers = nn.ModuleList([])
 
-        attention_klass = partial(DotProductAttention, use_cdist_sim = dot_product_attention_use_cdist_sim) if dot_product_attention else MLPAttention
+        attention_klass = DotProductAttention if dot_product_attention else MLPAttention
 
         for ind in range(depth):
             self.layers.append(nn.ModuleList([
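
The diff above factors the repeated distance expression into the new cdist helper; the eps clamp presumably keeps the gradient of the square root finite when a query and key coincide. Below is a small sketch of how the helper is applied, with illustrative shapes only (not taken from the module):

```python
import torch

def cdist(a, b, dim = -1, eps = 1e-5):
    # euclidean distance along `dim`, clamped away from zero so sqrt has a finite gradient at zero distance
    return ((a - b) ** 2).sum(dim = dim).clamp(min = eps).sqrt()

q = torch.randn(2, 8, 16, 1, 4, 3)    # b h i 1 d m (query broadcast against the j key positions)
k = torch.randn(2, 8, 16, 16, 4, 3)   # b h i j d m

sim = -cdist(q, k).sum(dim = -1)      # b h i j attention logits, before scaling and neighbor masking
print(sim.shape)                      # torch.Size([2, 8, 16, 16])
```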

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
-__version__ = '0.0.30'
+__version__ = '0.0.31'
 
 __cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'
