@@ -302,33 +302,6 @@ def forward(
302302 outputs = {degree : torch .cat (tensors , dim = - 3 ) for degree , tensors in enumerate (zip (self_interact_out .values (), outputs .values ()))}
303303 return outputs
304304
305- class RadialFunc (nn .Module ):
306- def __init__ (
307- self ,
308- num_freq ,
309- in_dim ,
310- out_dim ,
311- edge_dim = None ,
312- mid_dim = 64 ,
313- ):
314- super ().__init__ ()
315- self .in_dim = in_dim
316- self .mid_dim = mid_dim
317- self .out_dim = out_dim
318-
319- edge_dim = default (edge_dim , 0 )
320-
321- self .net = nn .Sequential (
322- nn .Linear (edge_dim + mid_dim , mid_dim ),
323- nn .SiLU (),
324- LayerNorm (mid_dim ),
325- nn .Linear (mid_dim , num_freq * in_dim * out_dim )
326- )
327-
328- def forward (self , x ):
329- y = self .net (x )
330- return rearrange (y , '... (o i f) -> ... o 1 i 1 f' , i = self .in_dim , o = self .out_dim )
331-
332305class PairwiseTP (nn .Module ):
333306 def __init__ (
334307 self ,
@@ -337,7 +310,7 @@ def __init__(
337310 degree_out ,
338311 nc_out ,
339312 edge_dim = 0 ,
340- radial_hidden_dim = 16
313+ radial_hidden_dim = 64
341314 ):
342315 super ().__init__ ()
343316 self .degree_in = degree_in
@@ -349,10 +322,23 @@ def __init__(
349322 self .d_out = to_order (degree_out )
350323 self .edge_dim = edge_dim
351324
352- self .rp = RadialFunc (self .num_freq , nc_in , nc_out , mid_dim = radial_hidden_dim , edge_dim = edge_dim )
325+ mid_dim = radial_hidden_dim
326+ edge_dim = default (edge_dim , 0 )
327+
328+ self .rp = nn .Sequential (
329+ nn .Linear (edge_dim + mid_dim , mid_dim ),
330+ nn .SiLU (),
331+ LayerNorm (mid_dim ),
332+ nn .Linear (mid_dim , mid_dim ),
333+ nn .SiLU (),
334+ LayerNorm (mid_dim ),
335+ nn .Linear (mid_dim , self .num_freq * nc_in * nc_out )
336+ )
353337
354338 def forward (self , feat , basis ):
355339 R = self .rp (feat )
340+ R = rearrange (R , '... (o i f) -> ... o 1 i 1 f' , i = self .nc_in , o = self .nc_out )
341+
356342 B = basis [f'{ self .degree_in } ,{ self .degree_out } ' ]
357343
358344 out_shape = (* R .shape [:3 ], self .d_out * self .nc_out , - 1 )
@@ -425,7 +411,7 @@ def __init__(
425411 attend_self = False ,
426412 edge_dim = None ,
427413 single_headed_kv = False ,
428- radial_hidden_dim = 16 ,
414+ radial_hidden_dim = 64 ,
429415 splits = 4
430416 ):
431417 super ().__init__ ()
0 commit comments