Commit d9e58e4

fix a bug with mlp attention (attending on wrong dimension), also make sure single headed values work with mlp attention. experiments still don't show a big gain over dot product attention
1 parent 967d71d commit d9e58e4

File tree

2 files changed: +23 -11 lines changed


equiformer_pytorch/equiformer_pytorch.py

Lines changed: 22 additions & 10 deletions
@@ -496,6 +496,7 @@ def __init__(
         attend_self = False,
         edge_dim = None,
         splits = 4,
+        single_headed_kv = False,
         attn_leakyrelu_slope = 0.1,
         attn_hidden_dim_mult = 4,
         **kwargs
@@ -511,6 +512,9 @@ def __init__(
 
         hidden_fiber = tuple(dim * head for dim, head in zip(dim_head, heads))
 
+        self.single_headed_kv = single_headed_kv
+        value_hidden_fiber = hidden_fiber if not single_headed_kv else dim_head
+
         self.scale = tuple(dim ** -0.5 for dim in dim_head)
         self.heads = heads
 
@@ -520,14 +524,14 @@ def __init__(
         # (1) gating the htypes on the values branch
         # (2) attention logits, with dimension equal to heads amount for starters
 
-        type0_dim = hidden_fiber[0]
-        htype_dims = sum(hidden_fiber[1:])
+        type0_dim = value_hidden_fiber[0]
+        htype_dims = sum(value_hidden_fiber[1:])
 
-        value_gate_fiber = tuple_set_at_index(hidden_fiber, 0, type0_dim + htype_dims)
+        value_gate_fiber = tuple_set_at_index(value_hidden_fiber, 0, type0_dim + htype_dims)
 
         attn_hidden_dims = tuple(head * attn_hidden_dim_mult for head in heads)
 
-        intermediate_fiber = tuple_set_at_index(hidden_fiber, 0, sum(attn_hidden_dims) + type0_dim + htype_dims)
+        intermediate_fiber = tuple_set_at_index(value_hidden_fiber, 0, sum(attn_hidden_dims) + type0_dim + htype_dims)
         self.intermediate_type0_split = [*attn_hidden_dims, type0_dim + htype_dims]
 
         # main branch tensor product
@@ -562,6 +566,8 @@ def forward(
         basis,
         mask = None
     ):
+        one_headed_kv = self.single_headed_kv
+
         features = self.prenorm(features)
 
         intermediate = self.to_attn_and_v(features, edge_info, rel_dist, basis)
@@ -578,7 +584,7 @@ def forward(
             attn_intermediate = rearrange(attn_intermediate, '... 1 -> ...')
             attn_logits = fn(attn_intermediate)
             attn_logits = attn_logits * scale
-            attn = attn_logits.softmax(dim = -1) # (batch, source, target, heads)
+            attn = attn_logits.softmax(dim = -2) # (batch, source, target, heads)
             attentions.append(attn)
 
         # process values branch
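
Below is a minimal standalone sketch (not part of this diff) of why the softmax axis changes in the hunk above: with MLP attention logits shaped (batch, queries, neighbors, heads), dim = -1 would normalize across heads, while dim = -2 normalizes across the neighbors axis, which is the axis the value aggregation later sums over.

# hedged sketch, assuming logits shaped (batch, queries, neighbors, heads) per the comment in the diff
import torch

b, i, j, h = 2, 4, 5, 3                        # batch, queries, neighbors, heads
attn_logits = torch.randn(b, i, j, h)

attn_wrong = attn_logits.softmax(dim = -1)     # normalizes over heads - not a distribution over neighbors
attn_fixed = attn_logits.softmax(dim = -2)     # normalizes over neighbors j, matching the later reduction over j

# each query now carries a proper attention distribution over its neighbors, per head
assert torch.allclose(attn_fixed.sum(dim = -2), torch.ones(b, i, h))
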
@@ -589,13 +595,18 @@ def forward(
 
         outputs = {}
 
+        value_einsum_eq = 'b i j h d m' if not one_headed_kv else 'b i j d m'
+
         for degree, (attn, value, h) in enumerate(zip(attentions, values.values(), self.heads)):
-            value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)
-            out = einsum('b i j h, b i j h d m -> b i h d m', attn, value)
-            out = rearrange(out, 'b i h d m -> b i (h d) m')
+            if not one_headed_kv:
+                value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)
 
+            out = einsum(f'b i j h, {value_einsum_eq} -> b i h d m', attn, value)
+            out = rearrange(out, 'b i h d m -> b i (h d) m')
             outputs[degree] = out
 
+        # combine heads out
+
         return self.to_out(outputs)
 
 # main class
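
A minimal sketch (not part of this diff) of how the f-string einsum above lets a single set of values serve all heads when single_headed_kv is enabled: with no head axis in the value tensor, 'b i j h, b i j d m -> b i h d m' broadcasts the shared values across every attention head before summing over neighbors j.

# hedged sketch; shapes follow the two einsum equations in the hunk above
import torch

b, i, j, h, d, m = 2, 4, 5, 3, 8, 1                   # batch, queries, neighbors, heads, dim_head, 2 * degree + 1

attn = torch.randn(b, i, j, h).softmax(dim = -2)

multi_headed_values  = torch.randn(b, i, j, h, d, m)  # one set of values per head
single_headed_values = torch.randn(b, i, j, d, m)     # shared values, no head axis

out_multi  = torch.einsum('b i j h, b i j h d m -> b i h d m', attn, multi_headed_values)
out_single = torch.einsum('b i j h, b i j d m -> b i h d m', attn, single_headed_values)

# both paths produce per-head outputs of the same shape, ready for 'b i h d m -> b i (h d) m'
assert out_multi.shape == out_single.shape == (b, i, h, d, m)
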
@@ -626,7 +637,8 @@ def __init__(
         embedding_grad_frac = 0.5,
         single_headed_kv = False,        # whether to do single headed key/values for dot product attention, to save on memory and compute
         ff_include_htype_norms = False,  # whether for type0 projection to also involve norms of all higher types, in feedforward first projection. this allows for all higher types to be gated by other type norms
-        dot_product_attention = True
+        dot_product_attention = True,
+        **kwargs
     ):
         super().__init__()
 
@@ -699,7 +711,7 @@ def __init__(
 
         for ind in range(depth):
             self.layers.append(nn.ModuleList([
-                attention_klass(self.dim, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, splits = splits, single_headed_kv = single_headed_kv),
+                attention_klass(self.dim, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, splits = splits, single_headed_kv = single_headed_kv, **kwargs),
                 FeedForward(self.dim, include_htype_norms = ff_include_htype_norms)
             ]))
 
equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-__version__ = '0.0.11'
+__version__ = '0.0.12'
 
 __cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'
