Skip to content

Commit 7e9bf59

Browse files
committed
residualize a section of the mlp within the radial function
1 parent cffb431 commit 7e9bf59

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,6 @@
2020

2121
EdgeInfo = namedtuple('EdgeInfo', ['neighbor_indices', 'neighbor_mask', 'edges'])
2222

23-
# biasless layernorm
24-
25-
class LayerNorm(nn.Module):
26-
def __init__(self, dim):
27-
super().__init__()
28-
self.gamma = nn.Parameter(torch.ones(dim))
29-
self.register_buffer("beta", torch.zeros(dim))
30-
31-
def forward(self, x):
32-
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
33-
3423
# fiber functions
3524

3625
@beartype
@@ -89,6 +78,23 @@ def feature_fiber(feature):
8978

9079
# classes
9180

81+
class Residual(nn.Module):
82+
def __init__(self, fn):
83+
super().__init__()
84+
self.fn = fn
85+
86+
def forward(self, x, **kwargs):
87+
return self.fn(x, **kwargs) + x
88+
89+
class LayerNorm(nn.Module):
90+
def __init__(self, dim):
91+
super().__init__()
92+
self.gamma = nn.Parameter(torch.ones(dim))
93+
self.register_buffer("beta", torch.zeros(dim))
94+
95+
def forward(self, x):
96+
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
97+
9298
@beartype
9399
class Linear(nn.Module):
94100
def __init__(
@@ -328,9 +334,11 @@ def __init__(
328334
self.rp = nn.Sequential(
329335
nn.Linear(edge_dim + mid_dim, mid_dim),
330336
nn.SiLU(),
331-
LayerNorm(mid_dim),
332-
nn.Linear(mid_dim, mid_dim),
333-
nn.SiLU(),
337+
Residual(nn.Sequential(
338+
LayerNorm(mid_dim),
339+
nn.Linear(mid_dim, mid_dim),
340+
nn.SiLU()
341+
)),
334342
LayerNorm(mid_dim),
335343
nn.Linear(mid_dim, self.num_freq * nc_in * nc_out)
336344
)

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = '0.0.25'
1+
__version__ = '0.0.26'
22

33
__cuda_pkg_name__ = f'equiformer_pytorch_cuda_{__version__.replace(".", "_")}'

0 commit comments

Comments
 (0)