Commit 967d71d

first pass at MLP attention, seems to be training, but not better than dot product attention (slower convergence in the first few steps, experiments pending)
1 parent 1fc156d commit 967d71d

File tree

3 files changed: +149 -32 lines changed

README.md
equiformer_pytorch/equiformer_pytorch.py
equiformer_pytorch/version.py


README.md

Lines changed: 14 additions & 8 deletions
@@ -24,14 +24,14 @@ from equiformer_pytorch import Equiformer
 
 model = Equiformer(
     num_tokens = 24,
-    dim = (4, 4, 2),
-    dim_head = (4, 4, 4),
-    heads = (2, 2, 2),
-    num_degrees = 3,
-    depth = 4,
-    attend_self = True,
-    input_degrees = 1,
-    reduce_dim_out = True
+    dim = (4, 4, 2),               # dimensions per type, ascending, length must match number of degrees (num_degrees)
+    dim_head = (4, 4, 4),          # dimension per attention head
+    heads = (2, 2, 2),             # number of attention heads
+    num_degrees = 3,               # number of degrees
+    depth = 4,                     # depth of equivariant transformer
+    attend_self = True,            # attending to self or not
+    reduce_dim_out = True,         # whether to reduce out to dimension of 1, say for predicting new coordinates for type 1 features
+    dot_product_attention = False  # set to False to try out MLP attention
 ).cuda()
 
 feats = torch.randint(0, 24, (1, 128)).cuda()
@@ -56,6 +56,12 @@ Tests for spherical harmonics, network equivariance etc
 $ python setup.py test
 ```
 
+## Todo
+
+- [ ] figure out DTP heuristic
+- [ ] move self interacting key / value production into Conv, fix no pooling in conv with self interaction
+- [ ] start moving some spherical harmonic stuff to cpp or nim
+
 ## Citations
 
 ```bibtex
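
For orientation, the snippet below completes the README example with a forward pass under the new flag: `dot_product_attention = False` routes every transformer block through the new `MLPAttention` class instead of `DotProductAttention` (see the `attention_klass` switch in the diff below). This is a minimal sketch: the `coors` / `mask` argument names are assumptions modeled on the library's usual interface, and only the `type0` / `type1` output fields come from the `Return` namedtuple in this commit.

```python
import torch
from equiformer_pytorch import Equiformer

# same configuration as the README snippet above, run on CPU for the sketch
model = Equiformer(
    num_tokens = 24,
    dim = (4, 4, 2),
    dim_head = (4, 4, 4),
    heads = (2, 2, 2),
    num_degrees = 3,
    depth = 4,
    attend_self = True,
    reduce_dim_out = True,
    dot_product_attention = False   # use the new MLP attention
)

feats = torch.randint(0, 24, (1, 128))   # token ids
coors = torch.randn(1, 128, 3)           # assumed coordinate input, (batch, seq, 3)
mask  = torch.ones(1, 128).bool()        # assumed padding mask

out = model(feats, coors, mask = mask)

out.type0   # invariant (degree 0) outputs
out.type1   # equivariant (degree 1) outputs, e.g. coordinate updates when reduce_dim_out = True
```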

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 134 additions & 23 deletions
@@ -18,6 +18,8 @@
 
 Return = namedtuple('Return', ['type0', 'type1'])
 
+EdgeInfo = namedtuple('EdgeInfo', ['neighbor_indices', 'neighbor_mask', 'edges'])
+
 # biasless layernorm
 
 class LayerNorm(nn.Module):
@@ -74,6 +76,17 @@ def residual_fn(x, residual):
         out[degree] = out[degree] + residual[degree]
     return out
 
+def tuple_set_at_index(tup, index, value):
+    l = list(tup)
+    l[index] = value
+    return tuple(l)
+
+def feature_shapes(feature):
+    return tuple(v.shape for v in feature.values())
+
+def feature_fiber(feature):
+    return tuple(v.shape[-2] for v in feature.values())
+
 # classes
 
 @beartype
@@ -198,7 +211,7 @@ def __init__(
     def forward(
         self,
         inp,
-        edge_info,
+        edge_info: EdgeInfo,
         rel_dist = None,
         basis = None
     ):
@@ -328,11 +341,6 @@ def forward(self, feat, basis):
 
 # feed forwards
 
-def tuple_set_at_index(tup, index, value):
-    l = list(tup)
-    l[index] = value
-    return tuple(l)
-
 @beartype
 class FeedForward(nn.Module):
     def __init__(
@@ -402,27 +410,30 @@ def __init__(
 
         self.single_headed_kv = single_headed_kv
 
-        if not single_headed_kv:
-            kv_hidden_fiber = tuple(dim * 2 for dim in hidden_fiber)
-        else:
-            kv_hidden_fiber = tuple(dim * 2 for dim in dim_head)
+        kv_hidden_fiber = hidden_fiber if not single_headed_kv else dim_head
+        kv_hidden_fiber = tuple(dim * 2 for dim in kv_hidden_fiber)
 
         self.scale = tuple(dim ** -0.5 for dim in dim_head)
         self.heads = heads
 
         self.prenorm = Norm(fiber)
 
         self.to_q = Linear(fiber, hidden_fiber)
-        self.to_kv = Conv(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
-        self.to_out = Linear(hidden_fiber, fiber)
 
-        self.attend_self = attend_self
-        if attend_self:
-            self.to_self_kv = Linear(fiber, kv_hidden_fiber)
+        self.to_kv = Conv(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
+        self.to_self_kv = Linear(fiber, kv_hidden_fiber) if attend_self else None
 
+        self.to_out = Linear(hidden_fiber, fiber)
 
-    def forward(self, features, edge_info, rel_dist, basis, mask = None):
-        attend_self, one_head_kv = self.attend_self, self.single_headed_kv
+    def forward(
+        self,
+        features,
+        edge_info: EdgeInfo,
+        rel_dist,
+        basis,
+        mask = None
+    ):
+        attend_self, one_head_kv = exists(self.to_self_kv), self.single_headed_kv
 
         device, dtype = get_tensor_device_and_dtype(features)
         neighbor_indices, neighbor_mask, edges = edge_info
@@ -484,12 +495,108 @@ def __init__(
         heads: Union[int, Tuple[int, ...]] = 8,
         attend_self = False,
         edge_dim = None,
-        splits = 4
+        splits = 4,
+        attn_leakyrelu_slope = 0.1,
+        attn_hidden_dim_mult = 4,
+        **kwargs
     ):
         super().__init__()
+        num_degrees = len(fiber)
+
+        dim_head = cast_tuple(dim_head, num_degrees)
+        assert len(dim_head) == num_degrees
+
+        heads = cast_tuple(heads, num_degrees)
+        assert len(heads) == num_degrees
+
+        hidden_fiber = tuple(dim * head for dim, head in zip(dim_head, heads))
+
+        self.scale = tuple(dim ** -0.5 for dim in dim_head)
+        self.heads = heads
+
+        self.prenorm = Norm(fiber)
 
-    def forward(self, features, edge_info, rel_dist, basis, mask = None):
-        raise NotImplementedError
+        # type 0 needs greater dimension, for
+        # (1) gating the htypes on the values branch
+        # (2) attention logits, with dimension equal to heads amount for starters
+
+        type0_dim = hidden_fiber[0]
+        htype_dims = sum(hidden_fiber[1:])
+
+        value_gate_fiber = tuple_set_at_index(hidden_fiber, 0, type0_dim + htype_dims)
+
+        attn_hidden_dims = tuple(head * attn_hidden_dim_mult for head in heads)
+
+        intermediate_fiber = tuple_set_at_index(hidden_fiber, 0, sum(attn_hidden_dims) + type0_dim + htype_dims)
+        self.intermediate_type0_split = [*attn_hidden_dims, type0_dim + htype_dims]
+
+        # main branch tensor product
+
+        self.to_attn_and_v = Conv(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
+
+        # non-linear projection of attention branch into the attention logits
+
+        self.to_attn_logits = nn.ModuleList([
+            nn.Sequential(
+                nn.LeakyReLU(attn_leakyrelu_slope),
+                nn.Linear(attn_hidden_dim, h, bias = False)
+            ) for attn_hidden_dim, h in zip(attn_hidden_dims, self.heads)
+        ])
+
+        # non-linear transform of the value branch
+        # todo - needs a DTP here?
+
+        self.to_values = nn.Sequential(
+            Gate(value_gate_fiber)
+        )
+
+        # combining heads and projection out
+
+        self.to_out = Linear(hidden_fiber, fiber)
+
+    def forward(
+        self,
+        features,
+        edge_info: EdgeInfo,
+        rel_dist,
+        basis,
+        mask = None
+    ):
+        features = self.prenorm(features)
+
+        intermediate = self.to_attn_and_v(features, edge_info, rel_dist, basis)
+
+        *attn_branch_type0, value_branch_type0 = intermediate[0].split(self.intermediate_type0_split, dim = -2)
+
+        intermediate[0] = value_branch_type0
+
+        # process the attention branch
+
+        attentions = []
+
+        for fn, attn_intermediate, scale in zip(self.to_attn_logits, attn_branch_type0, self.scale):
+            attn_intermediate = rearrange(attn_intermediate, '... 1 -> ...')
+            attn_logits = fn(attn_intermediate)
+            attn_logits = attn_logits * scale
+            attn = attn_logits.softmax(dim = -1) # (batch, source, target, heads)
+            attentions.append(attn)
+
+        # process values branch
+
+        values = self.to_values(intermediate)
+
+        # aggregate values with attention matrix
+
+        outputs = {}
+
+        for degree, (attn, value, h) in enumerate(zip(attentions, values.values(), self.heads)):
+            value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)
+            out = einsum('b i j h, b i j h d m -> b i h d m', attn, value)
+            out = rearrange(out, 'b i h d m -> b i (h d) m')
+
+            outputs[degree] = out
+
+        return self.to_out(outputs)
 
 # main class
 
@@ -518,7 +625,8 @@ def __init__(
         linear_out = True,
         embedding_grad_frac = 0.5,
         single_headed_kv = False,       # whether to do single headed key/values for dot product attention, to save on memory and compute
-        ff_include_htype_norms = False  # whether for type0 projection to also involve norms of all higher types, in feedforward first projection. this allows for all higher types to be gated by other type norms
+        ff_include_htype_norms = False, # whether for type0 projection to also involve norms of all higher types, in feedforward first projection. this allows for all higher types to be gated by other type norms
+        dot_product_attention = True
     ):
         super().__init__()
 
@@ -587,9 +695,11 @@ def __init__(
 
         self.layers = nn.ModuleList([])
 
+        attention_klass = DotProductAttention if dot_product_attention else MLPAttention
+
         for ind in range(depth):
             self.layers.append(nn.ModuleList([
-                DotProductAttention(self.dim, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, splits = splits, single_headed_kv = single_headed_kv),
+                attention_klass(self.dim, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, splits = splits, single_headed_kv = single_headed_kv),
                 FeedForward(self.dim, include_htype_norms = ff_include_htype_norms)
             ]))
 
@@ -709,7 +819,8 @@ def forward(
 
         # main logic
 
-        edge_info = (neighbor_indices, neighbor_mask, edges)
+        edge_info = EdgeInfo(neighbor_indices, neighbor_mask, edges)
+
         x = feats
 
         # project in
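
To make the difference between the two attention classes concrete: dot-product attention scores each neighbor edge with a scaled q · k per head, while the MLP attention added in this commit feeds a per-edge hidden feature (from the tensor-product Conv branch) through a LeakyReLU + Linear head to produce the logits. The toy, single-degree sketch below illustrates only that contrast; it is not the library's own code, all tensor names and shapes are assumptions, and the softmax here normalizes over the neighbor axis.

```python
import torch
from torch import nn, einsum
from einops import rearrange

# toy shapes: batch, nodes, neighbors per node, heads, dim per head, attn hidden mult
b, i, j, h, d, mult = 2, 16, 8, 4, 32, 4

# per-edge features, standing in for what the tensor-product Conv branch would produce
attn_hidden = torch.randn(b, i, j, h * mult)   # attention branch (hidden dim = heads * mult, as in the commit)
values      = torch.randn(b, i, j, h * d)      # value branch

# --- dot-product attention: logits are scaled q . k per head
q = torch.randn(b, i, h, d)
k = torch.randn(b, i, j, h, d)
dot_logits = einsum('b i h d, b i j h d -> b i j h', q, k) * d ** -0.5

# --- MLP attention: logits come from a small non-linearity over the edge features
to_logits = nn.Sequential(
    nn.LeakyReLU(0.1),                    # attn_leakyrelu_slope
    nn.Linear(h * mult, h, bias = False)
)
mlp_logits = to_logits(attn_hidden) * d ** -0.5   # (b, i, j, h)

# either way, normalize over the neighbors and aggregate the per-edge values
attn = mlp_logits.softmax(dim = -2)               # softmax over the neighbor axis j
v    = rearrange(values, 'b i j (h d) -> b i j h d', h = h)
out  = einsum('b i j h, b i j h d -> b i h d', attn, v)
out  = rearrange(out, 'b i h d -> b i (h d)')     # (b, i, heads * dim_head)
```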

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-__version__ = '0.0.9'
+__version__ = '0.0.11'
 
 __cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'
