@@ -179,12 +179,16 @@ def forward(self, x):
         return output

 @beartype
-class Conv(nn.Module):
+class TP(nn.Module):
+    """ 'Tensor Product' - in the equivariant sense """
+
     def __init__(
         self,
         fiber_in: Tuple[int, ...],
         fiber_out: Tuple[int, ...],
         self_interaction = True,
+        project_xi_xj = True,   # whether to project xi and xj and then sum, as in paper
+        project_out = True,     # whether to do a project out after the "tensor product"
         pool = True,
         edge_dim = 0,
         splits = 4
@@ -194,27 +198,33 @@ def __init__(
         self.fiber_out = fiber_out
         self.edge_dim = edge_dim
         self.self_interaction = self_interaction
+        self.pool = pool
+        self.splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage

-        self.kernel_unary = nn.ModuleDict()
+        self.project_xi_xj = project_xi_xj
+        if project_xi_xj:
+            self.to_xi = Linear(fiber_in, fiber_in)
+            self.to_xj = Linear(fiber_in, fiber_in)

-        self.splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage
+        self.kernel_unary = nn.ModuleDict()

         for (di, mi), (do, mo) in fiber_product(self.fiber_in, self.fiber_out):
-            self.kernel_unary[f'({di},{do})'] = PairwiseConv(di, mi, do, mo, edge_dim = edge_dim)
-
-        self.pool = pool
+            self.kernel_unary[f'({di},{do})'] = PairwiseTP(di, mi, do, mo, edge_dim = edge_dim)

         if self_interaction:
             assert self.pool, 'must pool edges if followed with self interaction'
             self.self_interact = Linear(fiber_in, fiber_out)

+        self.project_out = project_out
+        if project_out:
+            self.to_out = Linear(fiber_out, fiber_out)
+
     def forward(
         self,
         inp,
         edge_info: EdgeInfo,
         rel_dist = None,
-        basis = None,
-        neighbors = None
+        basis = None
     ):
         splits = self.splits
         neighbor_indices, neighbor_masks, edges = edge_info
@@ -231,8 +241,10 @@ def forward(

         # neighbors

-        neighbors_separate_embed = exists(neighbors)
-        neighbors = default(neighbors, inp)
+        if self.project_xi_xj:
+            source, target = self.to_xi(inp), self.to_xj(inp)
+        else:
+            source, target = inp, inp

         # go through every permutation of input degree type to output degree type

@@ -242,11 +254,11 @@ def forward(
             for degree_in, m_in in enumerate(self.fiber_in):
                 etype = f'({degree_in},{degree_out})'

-                xi, xj = inp[degree_in], neighbors[degree_in]
+                xi, xj = source[degree_in], target[degree_in]

                 x = batched_index_select(xj, neighbor_indices, dim = 1)

-                if neighbors_separate_embed:
+                if self.project_xi_xj:
                     xi = rearrange(xi, 'b i d m -> b i 1 d m')
                     x = x + xi

@@ -280,6 +292,9 @@ def forward(
             self_interact_out = self.self_interact(inp)
             outputs = residual_fn(outputs, self_interact_out)

+        if self.project_out:
+            outputs = self.to_out(outputs)
+
         return outputs

 class RadialFunc(nn.Module):
@@ -311,7 +326,7 @@ def forward(self, x):
         y = self.net(x)
         return rearrange(y, '... (o i f) -> ... o 1 i 1 f', i = self.in_dim, o = self.out_dim)

-class PairwiseConv(nn.Module):
+class PairwiseTP(nn.Module):
     def __init__(
         self,
         degree_in,
@@ -434,7 +449,7 @@ def __init__(
         self.to_xi = Linear(fiber, fiber)
         self.to_xj = Linear(fiber, fiber)

-        self.to_kv = Conv(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
+        self.to_kv = TP(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
         self.to_self_kv = Linear(fiber, kv_hidden_fiber) if attend_self else None

         self.to_out = Linear(hidden_fiber, fiber)
@@ -463,7 +478,6 @@ def forward(

         keyvalues = self.to_kv(
             xi,
-            neighbors = xj,
             edge_info = edge_info,
             rel_dist = rel_dist,
             basis = basis
@@ -557,16 +571,9 @@ def __init__(
         intermediate_fiber = tuple_set_at_index(value_hidden_fiber, 0, sum(attn_hidden_dims) + type0_dim + htype_dims)
         self.intermediate_type0_split = [*attn_hidden_dims, type0_dim + htype_dims]

-        # linear project xi and xj separately
-
-        self.to_xi = Linear(fiber, fiber)
-        self.to_xj = Linear(fiber, fiber)
-
         # main branch tensor product

-        self.to_attn_and_v = Conv(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
-
-        self.post_to_attn_and_v_linear = Linear(intermediate_fiber, intermediate_fiber)
+        self.to_attn_and_v = TP(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)

         # non-linear projection of attention branch into the attention logits

@@ -601,19 +608,13 @@ def forward(

         features = self.prenorm(features)

-        xi = self.to_xi(features)
-        xj = self.to_xj(features)
-
         intermediate = self.to_attn_and_v(
-            xi,
-            neighbors = xj,
+            features,
             edge_info = edge_info,
             rel_dist = rel_dist,
             basis = basis
         )

-        intermediate = self.post_to_attn_and_v_linear(intermediate)
-
         *attn_branch_type0, value_branch_type0 = intermediate[0].split(self.intermediate_type0_split, dim = -2)

         intermediate[0] = value_branch_type0
@@ -739,11 +740,11 @@ def __init__(

         # define fibers and dimensionality

-        conv_kwargs = dict(edge_dim = edge_dim, splits = splits)
+        tp_kwargs = dict(edge_dim = edge_dim, splits = splits)

         # main network

-        self.conv_in = Conv(self.dim_in, self.dim, **conv_kwargs)
+        self.tp_in = TP(self.dim_in, self.dim, **tp_kwargs)

         # trunk

@@ -879,7 +880,7 @@ def forward(

         # project in

-        x = self.conv_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)
+        x = self.tp_in(x, edge_info, rel_dist = neighbor_rel_dist, basis = basis)

         # transformer layers

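For context on the new project_xi_xj flag: the TP module now projects the input features twice, gathers the xj projection along each node's neighbors, and broadcast-adds the xi projection into every neighbor slot before the kernels are applied. Below is a minimal, self-contained PyTorch sketch of that gather-and-sum pattern with plain tensors; the names to_xi, to_xj, and neighbor_indices mirror the diff, while the shapes, the use of plain nn.Linear in place of the fiber-wise Linear, and the dropped spherical-degree dimension are illustrative assumptions, not the repository's API.

import torch
from torch import nn

# toy sizes: batch of 2, 8 nodes, 4 neighbors per node, feature dim 16 (illustrative)
b, n, k, d = 2, 8, 4, 16

to_xi = nn.Linear(d, d)  # projection applied to the receiving node features (xi)
to_xj = nn.Linear(d, d)  # projection applied to the neighboring node features (xj)

feats = torch.randn(b, n, d)                       # per-node features
neighbor_indices = torch.randint(0, n, (b, n, k))  # each node's k neighbor indices

xi = to_xi(feats)  # (b, n, d)
xj = to_xj(feats)  # (b, n, d)

# gather the projected features of every node's neighbors (analogous to batched_index_select)
gather_idx = neighbor_indices.reshape(b, n * k, 1).expand(-1, -1, d)
xj_neighbors = xj.gather(1, gather_idx).reshape(b, n, k, d)

# add the projected self features into every neighbor slot, as project_xi_xj does before the kernel
messages = xj_neighbors + xi.unsqueeze(2)          # (b, n, k, d)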