@@ -191,6 +191,7 @@ def __init__(
191191 project_out = True , # whether to do a project out after the "tensor product"
192192 pool = True ,
193193 edge_dim = 0 ,
194+ radial_hidden_dim = 16 ,
194195 splits = 4
195196 ):
196197 super ().__init__ ()
@@ -209,7 +210,7 @@ def __init__(
209210 self .kernel_unary = nn .ModuleDict ()
210211
211212 for (di , mi ), (do , mo ) in fiber_product (self .fiber_in , self .fiber_out ):
212- self .kernel_unary [f'({ di } ,{ do } )' ] = PairwiseTP (di , mi , do , mo , edge_dim = edge_dim )
213+ self .kernel_unary [f'({ di } ,{ do } )' ] = PairwiseTP (di , mi , do , mo , radial_hidden_dim = radial_hidden_dim , edge_dim = edge_dim )
213214
214215 if self_interaction :
215216 self .self_interact = Linear (fiber_in , fiber_out )
@@ -227,7 +228,6 @@ def forward(
227228 ):
228229 splits = self .splits
229230 neighbor_indices , neighbor_masks , edges = edge_info
230- rel_dist = rearrange (rel_dist , 'b m n -> b m n 1' )
231231
232232 kernels = {}
233233 outputs = {}
@@ -309,19 +309,17 @@ def __init__(
309309 in_dim ,
310310 out_dim ,
311311 edge_dim = None ,
312- mid_dim = 128
312+ mid_dim = 64 ,
313313 ):
314314 super ().__init__ ()
315- self .num_freq = num_freq
316315 self .in_dim = in_dim
317316 self .mid_dim = mid_dim
318317 self .out_dim = out_dim
319318
319+ edge_dim = default (edge_dim , 0 )
320+
320321 self .net = nn .Sequential (
321- nn .Linear (default (edge_dim , 0 ) + 1 , mid_dim ),
322- nn .SiLU (),
323- LayerNorm (mid_dim ),
324- nn .Linear (mid_dim , mid_dim ),
322+ nn .Linear (edge_dim + mid_dim , mid_dim ),
325323 nn .SiLU (),
326324 LayerNorm (mid_dim ),
327325 nn .Linear (mid_dim , num_freq * in_dim * out_dim )
@@ -338,7 +336,8 @@ def __init__(
338336 nc_in ,
339337 degree_out ,
340338 nc_out ,
341- edge_dim = 0
339+ edge_dim = 0 ,
340+ radial_hidden_dim = 16
342341 ):
343342 super ().__init__ ()
344343 self .degree_in = degree_in
@@ -350,7 +349,7 @@ def __init__(
350349 self .d_out = to_order (degree_out )
351350 self .edge_dim = edge_dim
352351
353- self .rp = RadialFunc (self .num_freq , nc_in , nc_out , edge_dim )
352+ self .rp = RadialFunc (self .num_freq , nc_in , nc_out , mid_dim = radial_hidden_dim , edge_dim = edge_dim )
354353
355354 def forward (self , feat , basis ):
356355 R = self .rp (feat )
@@ -426,6 +425,7 @@ def __init__(
426425 attend_self = False ,
427426 edge_dim = None ,
428427 single_headed_kv = False ,
428+ radial_hidden_dim = 16 ,
429429 splits = 4
430430 ):
431431 super ().__init__ ()
@@ -450,7 +450,7 @@ def __init__(
450450 self .prenorm = Norm (fiber )
451451
452452 self .to_q = Linear (fiber , hidden_fiber )
453- self .to_kv = TP (fiber , kv_hidden_fiber , edge_dim = edge_dim , pool = False , self_interaction = attend_self , splits = splits )
453+ self .to_kv = TP (fiber , kv_hidden_fiber , radial_hidden_dim = radial_hidden_dim , edge_dim = edge_dim , pool = False , self_interaction = attend_self , splits = splits )
454454
455455 self .to_out = Linear (hidden_fiber , fiber )
456456
@@ -521,6 +521,7 @@ def __init__(
521521 single_headed_kv = False ,
522522 attn_leakyrelu_slope = 0.1 ,
523523 attn_hidden_dim_mult = 4 ,
524+ radial_hidden_dim = 16 ,
524525 ** kwargs
525526 ):
526527 super ().__init__ ()
@@ -558,7 +559,7 @@ def __init__(
558559
559560 # main branch tensor product
560561
561- self .to_attn_and_v = TP (fiber , intermediate_fiber , edge_dim = edge_dim , pool = False , self_interaction = attend_self , splits = splits )
562+ self .to_attn_and_v = TP (fiber , intermediate_fiber , radial_hidden_dim = radial_hidden_dim , edge_dim = edge_dim , pool = False , self_interaction = attend_self , splits = splits )
562563
563564 # non-linear projection of attention branch into the attention logits
564565
@@ -654,6 +655,7 @@ def __init__(
654655 valid_radius = 1e5 ,
655656 num_neighbors = float ('inf' ),
656657 reduce_dim_out = False ,
658+ radial_hidden_dim = 64 ,
657659 num_tokens = None ,
658660 num_positions = None ,
659661 num_edge_tokens = None ,
@@ -709,12 +711,21 @@ def __init__(
709711
710712 # edges
711713
712- assert not (exists (num_edge_tokens ) and not exists (edge_dim )), 'edge dimension (edge_dim) must be supplied if SE3 transformer is to have edge tokens'
714+ assert not (exists (num_edge_tokens ) and not exists (edge_dim )), 'edge dimension (edge_dim) must be supplied if equiformer is to have edge tokens'
713715
714716 self .edge_emb = nn .Embedding (num_edge_tokens , edge_dim ) if exists (num_edge_tokens ) else None
715717 self .has_edges = exists (edge_dim ) and edge_dim > 0
716718
717- # whether to differentiate through basis, needed for alphafold2
719+ # initial MLP of relative distances to intermediate representation, akin to time conditioning in ddpm unets
720+
721+ self .to_rel_dist_hidden = nn .Sequential (
722+ nn .Linear (1 , radial_hidden_dim ),
723+ nn .SiLU (),
724+ LayerNorm (radial_hidden_dim ),
725+ nn .Linear (radial_hidden_dim , radial_hidden_dim ),
726+ )
727+
728+ # whether to differentiate through basis, gradients needed for iterative refinement
718729
719730 self .differentiable_coors = differentiable_coors
720731
@@ -723,13 +734,15 @@ def __init__(
723734 self .valid_radius = valid_radius
724735 self .num_neighbors = num_neighbors
725736
726- # define fibers and dimensionality
727-
728- tp_kwargs = dict (edge_dim = edge_dim , splits = splits )
729-
730737 # main network
731738
732- self .tp_in = TP (self .dim_in , self .dim , ** tp_kwargs )
739+ self .tp_in = TP (
740+ self .dim_in ,
741+ self .dim ,
742+ edge_dim = edge_dim ,
743+ radial_hidden_dim = radial_hidden_dim ,
744+ splits = splits
745+ )
733746
734747 # trunk
735748
@@ -739,7 +752,17 @@ def __init__(
739752
740753 for ind in range (depth ):
741754 self .layers .append (nn .ModuleList ([
742- attention_klass (self .dim , heads = heads , dim_head = dim_head , attend_self = attend_self , edge_dim = edge_dim , splits = splits , single_headed_kv = single_headed_kv , ** kwargs ),
755+ attention_klass (
756+ self .dim ,
757+ heads = heads ,
758+ dim_head = dim_head ,
759+ attend_self = attend_self ,
760+ edge_dim = edge_dim ,
761+ splits = splits ,
762+ single_headed_kv = single_headed_kv ,
763+ radial_hidden_dim = radial_hidden_dim ,
764+ ** kwargs
765+ ),
743766 FeedForward (self .dim , include_htype_norms = ff_include_htype_norms )
744767 ]))
745768
@@ -784,7 +807,7 @@ def forward(
784807
785808 num_degrees , neighbors , valid_radius = self .num_degrees , self .num_neighbors , self .valid_radius
786809
787- # se3 transformer by default cannot have a node attend to itself
810+ # cannot have a node attend to itself
788811
789812 exclude_self_mask = rearrange (~ torch .eye (n , dtype = torch .bool , device = device ), 'i j -> 1 i j' )
790813 remove_self = lambda t : t .masked_select (exclude_self_mask ).reshape (b , n , n - 1 )
@@ -853,6 +876,10 @@ def forward(
853876 if exists (edges ):
854877 edges = batched_index_select (edges , nearest_indices , dim = 2 )
855878
879+ # embed relative distances
880+
881+ neighbor_rel_dist = self .to_rel_dist_hidden (rearrange (neighbor_rel_dist , '... -> ... 1' ))
882+
856883 # calculate basis
857884
858885 basis = get_basis (neighbor_rel_pos , num_degrees - 1 , differentiable = self .differentiable_coors )