@@ -179,12 +179,16 @@ def forward(self, x):
179179 return output
180180
181181@beartype
182- class Conv (nn .Module ):
182+ class TP (nn .Module ):
183+ """ 'Tensor Product' - in the equivariant sense """
184+
183185 def __init__ (
184186 self ,
185187 fiber_in : Tuple [int , ...],
186188 fiber_out : Tuple [int , ...],
187189 self_interaction = True ,
190+ project_xi_xj = True , # whether to project xi and xj and then sum, as in paper
191+ project_out = True , # whether to do a project out after the "tensor product"
188192 pool = True ,
189193 edge_dim = 0 ,
190194 splits = 4
@@ -194,27 +198,33 @@ def __init__(
194198 self .fiber_out = fiber_out
195199 self .edge_dim = edge_dim
196200 self .self_interaction = self_interaction
201+ self .pool = pool
202+ self .splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage
197203
198- self .kernel_unary = nn .ModuleDict ()
204+ self .project_xi_xj = project_xi_xj
205+ if project_xi_xj :
206+ self .to_xi = Linear (fiber_in , fiber_in )
207+ self .to_xj = Linear (fiber_in , fiber_in )
199208
200- self .splits = splits # for splitting the computation of kernel and basis, to reduce peak memory usage
209+ self .kernel_unary = nn . ModuleDict ()
201210
202211 for (di , mi ), (do , mo ) in fiber_product (self .fiber_in , self .fiber_out ):
203- self .kernel_unary [f'({ di } ,{ do } )' ] = PairwiseConv (di , mi , do , mo , edge_dim = edge_dim )
204-
205- self .pool = pool
212+ self .kernel_unary [f'({ di } ,{ do } )' ] = PairwiseTP (di , mi , do , mo , edge_dim = edge_dim )
206213
207214 if self_interaction :
208215 assert self .pool , 'must pool edges if followed with self interaction'
209216 self .self_interact = Linear (fiber_in , fiber_out )
210217
218+ self .project_out = project_out
219+ if project_out :
220+ self .to_out = Linear (fiber_out , fiber_out )
221+
211222 def forward (
212223 self ,
213224 inp ,
214225 edge_info : EdgeInfo ,
215226 rel_dist = None ,
216- basis = None ,
217- neighbors = None
227+ basis = None
218228 ):
219229 splits = self .splits
220230 neighbor_indices , neighbor_masks , edges = edge_info
@@ -231,8 +241,10 @@ def forward(
231241
232242 # neighbors
233243
234- neighbors_separate_embed = exists (neighbors )
235- neighbors = default (neighbors , inp )
244+ if self .project_xi_xj :
245+ source , target = self .to_xi (inp ), self .to_xj (inp )
246+ else :
247+ source , target = inp , inp
236248
237249 # go through every permutation of input degree type to output degree type
238250
@@ -242,11 +254,11 @@ def forward(
242254 for degree_in , m_in in enumerate (self .fiber_in ):
243255 etype = f'({ degree_in } ,{ degree_out } )'
244256
245- xi , xj = inp [degree_in ], neighbors [degree_in ]
257+ xi , xj = source [degree_in ], target [degree_in ]
246258
247259 x = batched_index_select (xj , neighbor_indices , dim = 1 )
248260
249- if neighbors_separate_embed :
261+ if self . project_xi_xj :
250262 xi = rearrange (xi , 'b i d m -> b i 1 d m' )
251263 x = x + xi
252264
@@ -280,6 +292,9 @@ def forward(
280292 self_interact_out = self .self_interact (inp )
281293 outputs = residual_fn (outputs , self_interact_out )
282294
295+ if self .project_out :
296+ outputs = self .to_out (outputs )
297+
283298 return outputs
284299
285300class RadialFunc (nn .Module ):
@@ -311,7 +326,7 @@ def forward(self, x):
311326 y = self .net (x )
312327 return rearrange (y , '... (o i f) -> ... o 1 i 1 f' , i = self .in_dim , o = self .out_dim )
313328
314- class PairwiseConv (nn .Module ):
329+ class PairwiseTP (nn .Module ):
315330 def __init__ (
316331 self ,
317332 degree_in ,
@@ -430,11 +445,7 @@ def __init__(
430445 self .prenorm = Norm (fiber )
431446
432447 self .to_q = Linear (fiber , hidden_fiber )
433-
434- self .to_xi = Linear (fiber , fiber )
435- self .to_xj = Linear (fiber , fiber )
436-
437- self .to_kv = Conv (fiber , kv_hidden_fiber , edge_dim = edge_dim , pool = False , self_interaction = False , splits = splits )
448+ self .to_kv = TP (fiber , kv_hidden_fiber , edge_dim = edge_dim , pool = False , self_interaction = False , splits = splits )
438449 self .to_self_kv = Linear (fiber , kv_hidden_fiber ) if attend_self else None
439450
440451 self .to_out = Linear (hidden_fiber , fiber )
@@ -459,11 +470,8 @@ def forward(
459470
460471 queries = self .to_q (features )
461472
462- xi , xj = self .to_xi (features ), self .to_xj (features )
463-
464473 keyvalues = self .to_kv (
465- xi ,
466- neighbors = xj ,
474+ features ,
467475 edge_info = edge_info ,
468476 rel_dist = rel_dist ,
469477 basis = basis
@@ -557,16 +565,9 @@ def __init__(
557565 intermediate_fiber = tuple_set_at_index (value_hidden_fiber , 0 , sum (attn_hidden_dims ) + type0_dim + htype_dims )
558566 self .intermediate_type0_split = [* attn_hidden_dims , type0_dim + htype_dims ]
559567
560- # linear project xi and xj separately
561-
562- self .to_xi = Linear (fiber , fiber )
563- self .to_xj = Linear (fiber , fiber )
564-
565568 # main branch tensor product
566569
567- self .to_attn_and_v = Conv (fiber , intermediate_fiber , edge_dim = edge_dim , pool = False , self_interaction = False , splits = splits )
568-
569- self .post_to_attn_and_v_linear = Linear (intermediate_fiber , intermediate_fiber )
570+ self .to_attn_and_v = TP (fiber , intermediate_fiber , edge_dim = edge_dim , pool = False , self_interaction = False , splits = splits )
570571
571572 # non-linear projection of attention branch into the attention logits
572573
@@ -601,19 +602,13 @@ def forward(
601602
602603 features = self .prenorm (features )
603604
604- xi = self .to_xi (features )
605- xj = self .to_xj (features )
606-
607605 intermediate = self .to_attn_and_v (
608- xi ,
609- neighbors = xj ,
606+ features ,
610607 edge_info = edge_info ,
611608 rel_dist = rel_dist ,
612609 basis = basis
613610 )
614611
615- intermediate = self .post_to_attn_and_v_linear (intermediate )
616-
617612 * attn_branch_type0 , value_branch_type0 = intermediate [0 ].split (self .intermediate_type0_split , dim = - 2 )
618613
619614 intermediate [0 ] = value_branch_type0
@@ -739,11 +734,11 @@ def __init__(
739734
740735 # define fibers and dimensionality
741736
742- conv_kwargs = dict (edge_dim = edge_dim , splits = splits )
737+ tp_kwargs = dict (edge_dim = edge_dim , splits = splits )
743738
744739 # main network
745740
746- self .conv_in = Conv (self .dim_in , self .dim , ** conv_kwargs )
741+ self .tp_in = TP (self .dim_in , self .dim , ** tp_kwargs )
747742
748743 # trunk
749744
@@ -879,7 +874,7 @@ def forward(
879874
880875 # project in
881876
882- x = self .conv_in (x , edge_info , rel_dist = neighbor_rel_dist , basis = basis )
877+ x = self .tp_in (x , edge_info , rel_dist = neighbor_rel_dist , basis = basis )
883878
884879 # transformer layers
885880
0 commit comments