@@ -350,13 +350,11 @@ def __init__(
350350 edge_dim = default (edge_dim , 0 )
351351
352352 self .rp = nn .Sequential (
353- nn .Linear (edge_dim + mid_dim , mid_dim ),
353+ nn .Linear (edge_dim + 1 , mid_dim ),
354+ nn .SiLU (),
355+ LayerNorm (mid_dim ),
356+ nn .Linear (mid_dim , mid_dim ),
354357 nn .SiLU (),
355- Residual (nn .Sequential (
356- LayerNorm (mid_dim ),
357- nn .Linear (mid_dim , mid_dim ),
358- nn .SiLU ()
359- )),
360358 LayerNorm (mid_dim ),
361359 nn .Linear (mid_dim , self .num_freq * nc_in * nc_out ),
362360 Rearrange ('... (o i f) -> ... o 1 i 1 f' , i = nc_in , o = nc_out )
@@ -729,14 +727,6 @@ def __init__(
729727 self .edge_emb = nn .Embedding (num_edge_tokens , edge_dim ) if exists (num_edge_tokens ) else None
730728 self .has_edges = exists (edge_dim ) and edge_dim > 0
731729
732- # initial MLP of relative distances to intermediate representation, akin to time conditioning in ddpm unets
733-
734- self .to_rel_dist_hidden = nn .Sequential (
735- nn .Linear (1 , radial_hidden_dim ),
736- nn .SiLU (),
737- LayerNorm (radial_hidden_dim )
738- )
739-
740730 # whether to differentiate through basis, needed gradients for iterative refinement
741731
742732 self .differentiable_coors = differentiable_coors
@@ -890,7 +880,7 @@ def forward(
890880
891881 # embed relative distances
892882
893- neighbor_rel_dist = self . to_rel_dist_hidden ( rearrange (neighbor_rel_dist , '... -> ... 1' ) )
883+ neighbor_rel_dist = rearrange (neighbor_rel_dist , '... -> ... 1' )
894884
895885 # calculate basis
896886
0 commit comments