Commit b74579d

move self interaction logic into TP
1 parent 3622864 commit b74579d

File tree: README.md, equiformer_pytorch/equiformer_pytorch.py, equiformer_pytorch/version.py

3 files changed: +14 -23 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -65,9 +65,9 @@ $ python setup.py test
 ## Todo
 
 - [x] move xi and xj separate project and sum logic into Conv class
+- [x] move self interacting key / value production into Conv, fix no pooling in conv with self interaction
 
 - [ ] figure out DTP heuristic
-- [ ] move self interacting key / value production into Conv, fix no pooling in conv with self interaction
 - [ ] start moving some spherical harmonic stuff to cpp or nim
 
 ## Citations

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 12 additions & 21 deletions
@@ -212,7 +212,6 @@ def __init__(
             self.kernel_unary[f'({di},{do})'] = PairwiseTP(di, mi, do, mo, edge_dim = edge_dim)
 
         if self_interaction:
-            assert self.pool, 'must pool edges if followed with self interaction'
             self.self_interact = Linear(fiber_in, fiber_out)
 
         self.project_out = project_out
@@ -288,13 +287,19 @@ def forward(
 
             outputs[degree_out] = output
 
-        if self.self_interaction:
-            self_interact_out = self.self_interact(inp)
-            outputs = residual_fn(outputs, self_interact_out)
+        if not self.self_interaction and not self.project_out:
+            return outputs
 
         if self.project_out:
            outputs = self.to_out(outputs)
 
+        self_interact_out = self.self_interact(inp)
+
+        if self.pool:
+            return residual_fn(outputs, self_interact_out)
+
+        self_interact_out = {k: rearrange(v, '... d m -> ... 1 d m') for k, v in self_interact_out.items()}
+        outputs = {degree: torch.cat(tensors, dim = -3) for degree, tensors in enumerate(zip(self_interact_out.values(), outputs.values()))}
         return outputs
 
 class RadialFunc(nn.Module):
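
For readers tracing the new no-pool branch above, here is a minimal standalone sketch (toy shapes and degrees, not the repo's fiber types or the actual TP call) of how the self-interaction output picks up a singleton neighbor axis and is prepended onto the per-edge outputs along the neighbor dimension:

```python
# Minimal sketch of the no-pool self-interaction merge shown in the diff above.
# Shapes are toy values chosen for illustration; only the rearrange + cat pattern
# matches the new TP.forward code.
import torch
from einops import rearrange

b, n, neighbors, d = 2, 4, 8, 16

# per-edge outputs keyed by degree: (batch, nodes, neighbors, dim, 2 * degree + 1)
outputs = {
    0: torch.randn(b, n, neighbors, d, 1),
    1: torch.randn(b, n, neighbors, d, 3),
}

# self-interaction output keyed by degree: (batch, nodes, dim, 2 * degree + 1)
self_interact_out = {
    0: torch.randn(b, n, d, 1),
    1: torch.randn(b, n, d, 3),
}

# give the self term a singleton neighbor axis, then prepend it as one extra "neighbor" slot
self_interact_out = {k: rearrange(v, '... d m -> ... 1 d m') for k, v in self_interact_out.items()}
merged = {degree: torch.cat(tensors, dim = -3) for degree, tensors in enumerate(zip(self_interact_out.values(), outputs.values()))}

print(merged[0].shape)  # torch.Size([2, 4, 9, 16, 1]) -> neighbors + 1
print(merged[1].shape)  # torch.Size([2, 4, 9, 16, 3])
```

When `pool = True`, none of this is needed: the diff shows the self-interaction output simply being added back through `residual_fn` after edge pooling.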
@@ -445,8 +450,7 @@ def __init__(
         self.prenorm = Norm(fiber)
 
         self.to_q = Linear(fiber, hidden_fiber)
-        self.to_kv = TP(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
-        self.to_self_kv = Linear(fiber, kv_hidden_fiber) if attend_self else None
+        self.to_kv = TP(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)
 
         self.to_out = Linear(hidden_fiber, fiber)
 
@@ -458,7 +462,7 @@ def forward(
         basis,
         mask = None
     ):
-        attend_self, one_head_kv = exists(self.to_self_kv), self.single_headed_kv
+        one_head_kv = self.single_headed_kv
 
         device, dtype = get_tensor_device_and_dtype(features)
         neighbor_indices, neighbor_mask, edges = edge_info
@@ -479,9 +483,6 @@ def forward(
 
         kv_einsum_eq = 'b h i j d m' if not one_head_kv else 'b i j d m'
 
-        if attend_self:
-            self_keyvalues = self.to_self_kv(features)
-
         outputs = {}
 
         for degree, h, scale in zip(features.keys(), self.heads, self.scale):
@@ -492,16 +493,6 @@ def forward(
             if not one_head_kv:
                 kv = rearrange(kv, f'b i j (h d) m -> b h i j d m', h = h)
 
-            if attend_self:
-                self_kv = self_keyvalues[degree]
-
-                if not one_head_kv:
-                    self_kv = rearrange(self_kv, 'b n (h d) m -> b h n 1 d m', h = h)
-                else:
-                    self_kv = rearrange(self_kv, 'b n d m -> b n 1 d m')
-
-                kv = torch.cat((self_kv, kv), dim = -3)
-
             k, v = kv.chunk(2, dim = -2)
 
             sim = einsum(f'b h i d m, {kv_einsum_eq} -> b h i j', q, k) * scale
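
Downstream, the attention forward no longer needs the removed special casing: the `kv` tensor returned by `TP(..., self_interaction = attend_self, pool = False)` already carries the self slot along the neighbor axis, so chunking and the attention einsum treat it like any other neighbor. A simplified, single-headed sketch with toy shapes (not the repo's actual forward):

```python
# Hedged, single-headed illustration: kv already contains one extra neighbor slot
# for the self token, so no concatenation is needed before chunking into k and v.
import torch
from torch import einsum

b, i, j, d, m = 2, 4, 8 + 1, 32, 1   # j = neighbors + 1 self slot; d holds key dim + value dim

kv = torch.randn(b, i, j, d, m)
k, v = kv.chunk(2, dim = -2)         # each (b, i, j, d // 2, m)

q = torch.randn(b, i, d // 2, m)
sim = einsum('b i d m, b i j d m -> b i j', q, k)   # self slot scored like any neighbor
attn = sim.softmax(dim = -1)
out = einsum('b i j, b i j d m -> b i d m', attn, v)

print(out.shape)  # torch.Size([2, 4, 16, 1])
```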
@@ -567,7 +558,7 @@ def __init__(
 
         # main branch tensor product
 
-        self.to_attn_and_v = TP(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
+        self.to_attn_and_v = TP(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)
 
         # non-linear projection of attention branch into the attention logits
 
equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-__version__ = '0.0.21'
+__version__ = '0.0.22'
 
 __cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'
