
Commit 3622864

cleanup
1 parent 447768f commit 3622864

3 files changed: +38 -42 lines changed

README.md

Lines changed: 2 additions & 1 deletion
@@ -64,10 +64,11 @@ $ python setup.py test
 
 ## Todo
 
+- [x] move xi and xj separate project and sum logic into Conv class
+
 - [ ] figure out DTP heuristic
 - [ ] move self interacting key / value production into Conv, fix no pooling in conv with self interaction
 - [ ] start moving some spherical harmonic stuff to cpp or nim
-- [ ] move xi and xj separate project and sum logic into Conv class
 
 ## Citations
 

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 35 additions & 40 deletions
@@ -179,12 +179,16 @@ def forward(self, x):
         return output
 
 @beartype
-class Conv(nn.Module):
+class TP(nn.Module):
+    """ 'Tensor Product' - in the equivariant sense """
+
     def __init__(
         self,
         fiber_in: Tuple[int, ...],
         fiber_out: Tuple[int, ...],
         self_interaction = True,
+        project_xi_xj = True,   # whether to project xi and xj and then sum, as in paper
+        project_out = True,     # whether to do a project out after the "tensor product"
         pool = True,
         edge_dim = 0,
         splits = 4
@@ -194,27 +198,33 @@ def __init__(
         self.fiber_out = fiber_out
         self.edge_dim = edge_dim
         self.self_interaction = self_interaction
+        self.pool = pool
+        self.splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage
 
-        self.kernel_unary = nn.ModuleDict()
+        self.project_xi_xj = project_xi_xj
+        if project_xi_xj:
+            self.to_xi = Linear(fiber_in, fiber_in)
+            self.to_xj = Linear(fiber_in, fiber_in)
 
-        self.splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage
+        self.kernel_unary = nn.ModuleDict()
 
         for (di, mi), (do, mo) in fiber_product(self.fiber_in, self.fiber_out):
-            self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim = edge_dim)
-
-        self.pool = pool
+            self.kernel_unary[f'({di},{do})'] = PairwiseTP(di, mi, do, mo, edge_dim = edge_dim)
 
         if self_interaction:
             assert self.pool, 'must pool edges if followed with self interaction'
             self.self_interact = Linear(fiber_in, fiber_out)
 
+        self.project_out = project_out
+        if project_out:
+            self.to_out = Linear(fiber_out, fiber_out)
+
     def forward(
         self,
         inp,
         edge_info: EdgeInfo,
         rel_dist = None,
-        basis = None,
-        neighbors = None
+        basis = None
     ):
         splits = self.splits
         neighbor_indices, neighbor_masks, edges = edge_info
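A minimal construction sketch for the renamed module, based on the signature above. It assumes TP is importable from equiformer_pytorch/equiformer_pytorch.py; the fiber sizes, edge_dim and splits values are illustrative, not taken from the commit:

    # hypothetical usage sketch; constructor arguments mirror the diff above
    from equiformer_pytorch.equiformer_pytorch import TP

    tp = TP(
        fiber_in = (16, 8),      # e.g. 16 channels of degree 0, 8 channels of degree 1 (made up)
        fiber_out = (16, 8),
        self_interaction = True,
        project_xi_xj = True,    # project xi and xj separately, then sum (now handled inside TP)
        project_out = True,      # extra Linear applied to the outputs after the "tensor product"
        edge_dim = 4,
        splits = 4
    )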
@@ -231,8 +241,10 @@ def forward(
 
         # neighbors
 
-        neighbors_separate_embed = exists(neighbors)
-        neighbors = default(neighbors, inp)
+        if self.project_xi_xj:
+            source, target = self.to_xi(inp), self.to_xj(inp)
+        else:
+            source, target = inp, inp
 
         # go through every permutation of input degree type to output degree type
 
@@ -242,11 +254,11 @@
             for degree_in, m_in in enumerate(self.fiber_in):
                 etype = f'({degree_in},{degree_out})'
 
-                xi, xj = inp[degree_in], neighbors[degree_in]
+                xi, xj = source[degree_in], target[degree_in]
 
                 x = batched_index_select(xj, neighbor_indices, dim = 1)
 
-                if neighbors_separate_embed:
+                if self.project_xi_xj:
                     xi = rearrange(xi, 'b i d m -> b i 1 d m')
                     x = x + xi
 
@@ -280,6 +292,9 @@ def forward(
             self_interact_out = self.self_interact(inp)
             outputs = residual_fn(outputs, self_interact_out)
 
+        if self.project_out:
+            outputs = self.to_out(outputs)
+
         return outputs
 
 class RadialFunc(nn.Module):
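The new project_out option applies one more degree-wise Linear to the outputs. As a rough analogy only (this is not the repo's Linear implementation), an equivariant "project out" mixes channels within each degree while leaving the spherical axis untouched:

    import torch

    # hypothetical per-degree outputs: degree -> (batch, nodes, channels, 2 * degree + 1)
    outputs = {0: torch.randn(2, 8, 16, 1), 1: torch.randn(2, 8, 16, 3)}
    weights = {deg: torch.randn(16, 16) for deg in outputs}       # one channel-mixing matrix per degree

    projected = {
        deg: torch.einsum('b n d m, e d -> b n e m', t, weights[deg])  # acts on channels only
        for deg, t in outputs.items()
    }
    print({deg: t.shape for deg, t in projected.items()})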
@@ -311,7 +326,7 @@ def forward(self, x):
         y = self.net(x)
         return rearrange(y, '... (o i f) -> ... o 1 i 1 f', i = self.in_dim, o = self.out_dim)
 
-class PairwiseConv(nn.Module):
+class PairwiseTP(nn.Module):
     def __init__(
         self,
         degree_in,
@@ -430,11 +445,7 @@ def __init__(
         self.prenorm = Norm(fiber)
 
         self.to_q = Linear(fiber, hidden_fiber)
-
-        self.to_xi = Linear(fiber, fiber)
-        self.to_xj = Linear(fiber, fiber)
-
-        self.to_kv = Conv(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
+        self.to_kv = TP(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
         self.to_self_kv = Linear(fiber, kv_hidden_fiber) if attend_self else None
 
         self.to_out = Linear(hidden_fiber, fiber)
@@ -459,11 +470,8 @@ def forward(
 
         queries = self.to_q(features)
 
-        xi, xj = self.to_xi(features), self.to_xj(features)
-
         keyvalues = self.to_kv(
-            xi,
-            neighbors = xj,
+            features,
             edge_info = edge_info,
             rel_dist = rel_dist,
             basis = basis
@@ -557,16 +565,9 @@ def __init__(
         intermediate_fiber = tuple_set_at_index(value_hidden_fiber, 0, sum(attn_hidden_dims) + type0_dim + htype_dims)
         self.intermediate_type0_split = [*attn_hidden_dims, type0_dim + htype_dims]
 
-        # linear project xi and xj separately
-
-        self.to_xi = Linear(fiber, fiber)
-        self.to_xj = Linear(fiber, fiber)
-
         # main branch tensor product
 
-        self.to_attn_and_v = Conv(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
-
-        self.post_to_attn_and_v_linear = Linear(intermediate_fiber, intermediate_fiber)
+        self.to_attn_and_v = TP(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
 
         # non-linear projection of attention branch into the attention logits
 
@@ -601,19 +602,13 @@ def forward(
 
         features = self.prenorm(features)
 
-        xi = self.to_xi(features)
-        xj = self.to_xj(features)
-
         intermediate = self.to_attn_and_v(
-            xi,
-            neighbors = xj,
+            features,
            edge_info = edge_info,
            rel_dist = rel_dist,
            basis = basis
        )
 
-        intermediate = self.post_to_attn_and_v_linear(intermediate)
-
         *attn_branch_type0, value_branch_type0 = intermediate[0].split(self.intermediate_type0_split, dim = -2)
 
         intermediate[0] = value_branch_type0
@@ -739,11 +734,11 @@ def __init__(
 
         # define fibers and dimensionality
 
-        conv_kwargs = dict(edge_dim = edge_dim, splits = splits)
+        tp_kwargs = dict(edge_dim = edge_dim, splits = splits)
 
         # main network
 
-        self.conv_in = Conv(self.dim_in, self.dim, **conv_kwargs)
+        self.tp_in = TP(self.dim_in, self.dim, **tp_kwargs)
 
         # trunk
 
@@ -879,7 +874,7 @@ def forward(
 
         # project in
 
-        x = self.conv_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)
+        x = self.tp_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)
 
         # transformer layers
 
equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-__version__ = '0.0.17'
+__version__ = '0.0.21'
 
 __cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'
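For reference, the version bump also changes the derived CUDA extension package name; evaluating the line above gives:

    __version__ = '0.0.21'
    __cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'
    print(__cuda_pkg_name__)  # equiformer_pytorch_cuda_0_0_21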
