|
@@ -20,17 +20,6 @@

 EdgeInfo = namedtuple('EdgeInfo', ['neighbor_indices', 'neighbor_mask', 'edges'])

-# biasless layernorm
-
-class LayerNorm(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.gamma = nn.Parameter(torch.ones(dim))
-        self.register_buffer("beta", torch.zeros(dim))
-
-    def forward(self, x):
-        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
-
 # fiber functions

 @beartype
@@ -89,6 +78,23 @@ def feature_fiber(feature):

 # classes

+class Residual(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) + x
+
+class LayerNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
+
+    def forward(self, x):
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
 @beartype
 class Linear(nn.Module):
     def __init__(
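Aside, not part of the commit: a minimal self-contained sketch of how the two classes added above behave. LayerNorm here is a bias-free layer norm (gamma is learned, beta is registered as a frozen zero buffer), and Residual adds the wrapped module's output back onto its input, so whatever it wraps has to preserve the last dimension. The tensor sizes below are placeholders, not values from this repository.

import torch
from torch import nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    # bias-free layer norm: learnable gamma, frozen zero beta
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

class Residual(nn.Module):
    # adds the wrapped module's output back onto its input
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# placeholder feature size of 64, purely for illustration
block = Residual(nn.Sequential(LayerNorm(64), nn.Linear(64, 64), nn.SiLU()))
out = block(torch.randn(2, 16, 64))  # output shape stays (2, 16, 64)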
@@ -328,9 +334,11 @@ def __init__(
         self.rp = nn.Sequential(
             nn.Linear(edge_dim + mid_dim, mid_dim),
             nn.SiLU(),
-            LayerNorm(mid_dim),
-            nn.Linear(mid_dim, mid_dim),
-            nn.SiLU(),
+            Residual(nn.Sequential(
+                LayerNorm(mid_dim),
+                nn.Linear(mid_dim, mid_dim),
+                nn.SiLU()
+            )),
             LayerNorm(mid_dim),
             nn.Linear(mid_dim, self.num_freq * nc_in * nc_out)
         )
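Aside, not part of the commit: the change above keeps the radial-profile MLP's width and depth the same and only gives the middle LayerNorm, Linear, SiLU block an identity shortcut, which generally makes the deeper MLP easier to optimize; the trailing LayerNorm and output projection are untouched. A rough sketch of the resulting module, reusing the Residual and LayerNorm classes from the sketch above; edge_dim, mid_dim and the output width are placeholder numbers, since the real values (including num_freq * nc_in * nc_out) are set elsewhere in __init__ and are not visible in this hunk.

# placeholder sizes; the real edge_dim / mid_dim / output width live in the
# surrounding __init__, which this diff does not show
edge_dim, mid_dim, out_dim = 16, 32, 48

rp = nn.Sequential(
    nn.Linear(edge_dim + mid_dim, mid_dim),
    nn.SiLU(),
    Residual(nn.Sequential(      # skip connection around the middle block
        LayerNorm(mid_dim),
        nn.Linear(mid_dim, mid_dim),
        nn.SiLU()
    )),
    LayerNorm(mid_dim),
    nn.Linear(mid_dim, out_dim)
)

radial = rp(torch.randn(8, edge_dim + mid_dim))  # -> shape (8, out_dim)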
|