Commit 8bf2023 (parent 3622864)

move self interaction logic into TP

3 files changed (+11, -21 lines)


README.md

Lines changed: 1 addition & 1 deletion
@@ -65,9 +65,9 @@ $ python setup.py test
 ## Todo
 
 - [x] move xi and xj separate project and sum logic into Conv class
+- [x] move self interacting key / value production into Conv, fix no pooling in conv with self interaction
 
 - [ ] figure out DTP heuristic
-- [ ] move self interacting key / value production into Conv, fix no pooling in conv with self interaction
 - [ ] start moving some spherical harmonic stuff to cpp or nim
 
 ## Citations

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 9 additions & 19 deletions
@@ -212,7 +212,6 @@ def __init__(
         self.kernel_unary[f'({di},{do})'] = PairwiseTP(di, mi, do, mo, edge_dim = edge_dim)
 
         if self_interaction:
-            assert self.pool, 'must pool edges if followed with self interaction'
             self.self_interact = Linear(fiber_in, fiber_out)
 
         self.project_out = project_out

@@ -290,7 +289,12 @@ def forward(
 
         if self.self_interaction:
             self_interact_out = self.self_interact(inp)
-            outputs = residual_fn(outputs, self_interact_out)
+
+            if self.pool:
+                outputs = residual_fn(outputs, self_interact_out)
+            else:
+                self_interact_out = {k: rearrange(v, '... d m -> ... 1 d m') for k, v in self_interact_out.items()}
+                outputs = {degree: torch.cat(tensors, dim = -3) for degree, tensors in enumerate(zip(self_interact_out.values(), outputs.values()))}
 
         if self.project_out:
             outputs = self.to_out(outputs)
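The new `else` branch is the substantive change: when the tensor product does not pool over edges, the self-interaction output is merged in as one extra neighbor instead of going through the residual function. Below is a minimal sketch of that merge with assumed toy shapes (per-degree tensors of shape `(batch, nodes, neighbors, dim, 2 * degree + 1)`); the dictionary names mirror the diff, everything else is illustrative:

```python
# Hypothetical shapes for illustration only -- the real tensors come from PairwiseTP.
import torch
from einops import rearrange

batch, nodes, neighbors, dim = 2, 4, 3, 8

# toy per-degree outputs of the unpooled tensor product (pool = False)
outputs = {
    0: torch.randn(batch, nodes, neighbors, dim, 1),
    1: torch.randn(batch, nodes, neighbors, dim, 3)
}

# toy self-interaction output: one entry per node, no neighbor axis yet
self_interact_out = {
    0: torch.randn(batch, nodes, dim, 1),
    1: torch.randn(batch, nodes, dim, 3)
}

# give the self output a singleton neighbor axis ...
self_interact_out = {k: rearrange(v, '... d m -> ... 1 d m') for k, v in self_interact_out.items()}

# ... then prepend it along the neighbor dimension (dim = -3), as in the diff
outputs = {degree: torch.cat(tensors, dim = -3)
           for degree, tensors in enumerate(zip(self_interact_out.values(), outputs.values()))}

assert outputs[0].shape == (batch, nodes, neighbors + 1, dim, 1)
assert outputs[1].shape == (batch, nodes, neighbors + 1, dim, 3)
```

Prepending puts the node itself at neighbor index 0, which is what lets the attention classes below drop their own self key/value handling.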
@@ -445,8 +449,7 @@ def __init__(
         self.prenorm = Norm(fiber)
 
         self.to_q = Linear(fiber, hidden_fiber)
-        self.to_kv = TP(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
-        self.to_self_kv = Linear(fiber, kv_hidden_fiber) if attend_self else None
+        self.to_kv = TP(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)
 
         self.to_out = Linear(hidden_fiber, fiber)
 

@@ -458,7 +461,7 @@ def forward(
         basis,
         mask = None
     ):
-        attend_self, one_head_kv = exists(self.to_self_kv), self.single_headed_kv
+        one_head_kv = self.single_headed_kv
 
         device, dtype = get_tensor_device_and_dtype(features)
         neighbor_indices, neighbor_mask, edges = edge_info

@@ -479,9 +482,6 @@ def forward(
 
         kv_einsum_eq = 'b h i j d m' if not one_head_kv else 'b i j d m'
 
-        if attend_self:
-            self_keyvalues = self.to_self_kv(features)
-
         outputs = {}
 
         for degree, h, scale in zip(features.keys(), self.heads, self.scale):

@@ -492,16 +492,6 @@ def forward(
             if not one_head_kv:
                 kv = rearrange(kv, f'b i j (h d) m -> b h i j d m', h = h)
 
-            if attend_self:
-                self_kv = self_keyvalues[degree]
-
-                if not one_head_kv:
-                    self_kv = rearrange(self_kv, 'b n (h d) m -> b h n 1 d m', h = h)
-                else:
-                    self_kv = rearrange(self_kv, 'b n d m -> b n 1 d m')
-
-                kv = torch.cat((self_kv, kv), dim = -3)
-
             k, v = kv.chunk(2, dim = -2)
 
             sim = einsum(f'b h i d m, {kv_einsum_eq} -> b h i j', q, k) * scale
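The deleted block is exactly what the `pool = False` concat in `TP` now does upstream: with `self_interaction = attend_self` (see the `__init__` hunk above), `kv` arrives with the node itself already occupying the first slot of the neighbor axis `j`, so the per-degree loop chunks and contracts unchanged. A shape check under assumed toy sizes, using `torch.einsum` in place of the module's `einsum` import:

```python
# Illustrative sizes only; j = 1 (self) + 6 neighbors, as produced by the unpooled TP.
import torch

b, h, i, j, d, m = 2, 4, 8, 1 + 6, 16, 3

kv = torch.randn(b, h, i, j, 2 * d, m)   # fused keys and values from to_kv
k, v = kv.chunk(2, dim = -2)             # same split as in the diff

q = torch.randn(b, h, i, d, m)
scale = d ** -0.5                        # stand-in for the per-degree scale

sim = torch.einsum('b h i d m, b h i j d m -> b h i j', q, k) * scale
assert sim.shape == (b, h, i, j)
```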
@@ -567,7 +557,7 @@ def __init__(
 
         # main branch tensor product
 
-        self.to_attn_and_v = TP(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
+        self.to_attn_and_v = TP(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)
 
         # non-linear projection of attention branch into the attention logits

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-__version__ = '0.0.21'
+__version__ = '0.0.22'
 
 __cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'
