Skip to content

Commit 1ab6ce9

Browse files
committed
cleanup
1 parent 8e47674 commit 1ab6ce9

File tree

2 files changed

+17
-31
lines changed

2 files changed

+17
-31
lines changed

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -302,33 +302,6 @@ def forward(
302302
outputs = {degree: torch.cat(tensors, dim = -3) for degree, tensors in enumerate(zip(self_interact_out.values(), outputs.values()))}
303303
return outputs
304304

305-
class RadialFunc(nn.Module):
306-
def __init__(
307-
self,
308-
num_freq,
309-
in_dim,
310-
out_dim,
311-
edge_dim = None,
312-
mid_dim = 64,
313-
):
314-
super().__init__()
315-
self.in_dim = in_dim
316-
self.mid_dim = mid_dim
317-
self.out_dim = out_dim
318-
319-
edge_dim = default(edge_dim, 0)
320-
321-
self.net = nn.Sequential(
322-
nn.Linear(edge_dim + mid_dim, mid_dim),
323-
nn.SiLU(),
324-
LayerNorm(mid_dim),
325-
nn.Linear(mid_dim, num_freq * in_dim * out_dim)
326-
)
327-
328-
def forward(self, x):
329-
y = self.net(x)
330-
return rearrange(y, '... (o i f) -> ... o 1 i 1 f', i = self.in_dim, o = self.out_dim)
331-
332305
class PairwiseTP(nn.Module):
333306
def __init__(
334307
self,
@@ -337,7 +310,7 @@ def __init__(
337310
degree_out,
338311
nc_out,
339312
edge_dim = 0,
340-
radial_hidden_dim = 16
313+
radial_hidden_dim = 64
341314
):
342315
super().__init__()
343316
self.degree_in = degree_in
@@ -349,10 +322,23 @@ def __init__(
349322
self.d_out = to_order(degree_out)
350323
self.edge_dim = edge_dim
351324

352-
self.rp = RadialFunc(self.num_freq, nc_in, nc_out, mid_dim = radial_hidden_dim, edge_dim = edge_dim)
325+
mid_dim = radial_hidden_dim
326+
edge_dim = default(edge_dim, 0)
327+
328+
self.rp = nn.Sequential(
329+
nn.Linear(edge_dim + mid_dim, mid_dim),
330+
nn.SiLU(),
331+
LayerNorm(mid_dim),
332+
nn.Linear(mid_dim, mid_dim),
333+
nn.SiLU(),
334+
LayerNorm(mid_dim),
335+
nn.Linear(mid_dim, self.num_freq * nc_in * nc_out)
336+
)
353337

354338
def forward(self, feat, basis):
355339
R = self.rp(feat)
340+
R = rearrange(R, '... (o i f) -> ... o 1 i 1 f', i = self.nc_in, o = self.nc_out)
341+
356342
B = basis[f'{self.degree_in},{self.degree_out}']
357343

358344
out_shape = (*R.shape[:3], self.d_out * self.nc_out, -1)
@@ -425,7 +411,7 @@ def __init__(
425411
attend_self = False,
426412
edge_dim = None,
427413
single_headed_kv = False,
428-
radial_hidden_dim = 16,
414+
radial_hidden_dim = 64,
429415
splits = 4
430416
):
431417
super().__init__()

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = '0.0.23'
1+
__version__ = '0.0.24'
22

33
__cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'

0 commit comments

Comments
 (0)