Skip to content

Commit 5cc6dfa

Browse files
committed
the way it was before is better
1 parent 8816684 commit 5cc6dfa

File tree

2 files changed

+6
-16
lines changed

2 files changed

+6
-16
lines changed

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,11 @@ def __init__(
350350
edge_dim = default(edge_dim, 0)
351351

352352
self.rp = nn.Sequential(
353-
nn.Linear(edge_dim + mid_dim, mid_dim),
353+
nn.Linear(edge_dim + 1, mid_dim),
354+
nn.SiLU(),
355+
LayerNorm(mid_dim),
356+
nn.Linear(mid_dim, mid_dim),
354357
nn.SiLU(),
355-
Residual(nn.Sequential(
356-
LayerNorm(mid_dim),
357-
nn.Linear(mid_dim, mid_dim),
358-
nn.SiLU()
359-
)),
360358
LayerNorm(mid_dim),
361359
nn.Linear(mid_dim, self.num_freq * nc_in * nc_out),
362360
Rearrange('... (o i f) -> ... o 1 i 1 f', i = nc_in, o = nc_out)
@@ -729,14 +727,6 @@ def __init__(
729727
self.edge_emb = nn.Embedding(num_edge_tokens, edge_dim) if exists(num_edge_tokens) else None
730728
self.has_edges = exists(edge_dim) and edge_dim > 0
731729

732-
# initial MLP of relative distances to intermediate representation, akin to time conditioning in ddpm unets
733-
734-
self.to_rel_dist_hidden = nn.Sequential(
735-
nn.Linear(1, radial_hidden_dim),
736-
nn.SiLU(),
737-
LayerNorm(radial_hidden_dim)
738-
)
739-
740730
# whether to differentiate through basis, needed gradients for iterative refinement
741731

742732
self.differentiable_coors = differentiable_coors
@@ -890,7 +880,7 @@ def forward(
890880

891881
# embed relative distances
892882

893-
neighbor_rel_dist = self.to_rel_dist_hidden(rearrange(neighbor_rel_dist, '... -> ... 1'))
883+
neighbor_rel_dist = rearrange(neighbor_rel_dist, '... -> ... 1')
894884

895885
# calculate basis
896886

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = '0.0.27'
1+
__version__ = '0.0.29'
22

33
__cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'

0 commit comments

Comments
 (0)