@@ -212,7 +212,6 @@ def __init__(
             self.kernel_unary[f'({di},{do})'] = PairwiseTP(di, mi, do, mo, edge_dim = edge_dim)

         if self_interaction:
-            assert self.pool, 'must pool edges if followed with self interaction'
             self.self_interact = Linear(fiber_in, fiber_out)

         self.project_out = project_out
@@ -290,7 +289,12 @@ def forward(

         if self.self_interaction:
             self_interact_out = self.self_interact(inp)
-            outputs = residual_fn(outputs, self_interact_out)
+
+            if self.pool:
+                outputs = residual_fn(outputs, self_interact_out)
+            else:
+                self_interact_out = {k: rearrange(v, '... d m -> ... 1 d m') for k, v in self_interact_out.items()}
+                outputs = {degree: torch.cat(tensors, dim = -3) for degree, tensors in enumerate(zip(self_interact_out.values(), outputs.values()))}

         if self.project_out:
             outputs = self.to_out(outputs)
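
A hedged reading of the new non-pooled branch above: when `pool = False`, the self-interaction output is given a singleton neighbor axis and concatenated in front of the per-neighbor outputs along `dim = -3`, instead of being added through `residual_fn`. A minimal shape sketch, assuming toy tensor shapes and integer degree keys (illustrative only, not the library's real fiber dimensions):

import torch
from einops import rearrange

# toy per-degree outputs of the tensor product: (batch, nodes, neighbors, dim, m), with m = 2 * degree + 1
outputs = {0: torch.randn(2, 8, 16, 32, 1), 1: torch.randn(2, 8, 16, 32, 3)}

# toy self-interaction outputs: (batch, nodes, dim, m) - no neighbor axis yet
self_interact_out = {0: torch.randn(2, 8, 32, 1), 1: torch.randn(2, 8, 32, 3)}

# give the self interaction a singleton neighbor slot, then prepend it along the neighbor axis (dim = -3)
self_interact_out = {k: rearrange(v, '... d m -> ... 1 d m') for k, v in self_interact_out.items()}
combined = {degree: torch.cat(tensors, dim = -3) for degree, tensors in enumerate(zip(self_interact_out.values(), outputs.values()))}

assert combined[0].shape == (2, 8, 17, 32, 1)  # one extra "neighbor" coming from the node itself
assert combined[1].shape == (2, 8, 17, 32, 3)
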
@@ -445,8 +449,7 @@ def __init__(
         self.prenorm = Norm(fiber)

         self.to_q = Linear(fiber, hidden_fiber)
-        self.to_kv = TP(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
-        self.to_self_kv = Linear(fiber, kv_hidden_fiber) if attend_self else None
+        self.to_kv = TP(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)

         self.to_out = Linear(hidden_fiber, fiber)
@@ -458,7 +461,7 @@ def forward(
         basis,
         mask = None
     ):
-        attend_self, one_head_kv = exists(self.to_self_kv), self.single_headed_kv
+        one_head_kv = self.single_headed_kv

         device, dtype = get_tensor_device_and_dtype(features)
         neighbor_indices, neighbor_mask, edges = edge_info
@@ -479,9 +482,6 @@ def forward(

         kv_einsum_eq = 'b h i j d m' if not one_head_kv else 'b i j d m'

-        if attend_self:
-            self_keyvalues = self.to_self_kv(features)
-
         outputs = {}

         for degree, h, scale in zip(features.keys(), self.heads, self.scale):
@@ -492,16 +492,6 @@ def forward(
             if not one_head_kv:
                 kv = rearrange(kv, f'b i j (h d) m -> b h i j d m', h = h)

-            if attend_self:
-                self_kv = self_keyvalues[degree]
-
-                if not one_head_kv:
-                    self_kv = rearrange(self_kv, 'b n (h d) m -> b h n 1 d m', h = h)
-                else:
-                    self_kv = rearrange(self_kv, 'b n d m -> b n 1 d m')
-
-                kv = torch.cat((self_kv, kv), dim = -3)
-
             k, v = kv.chunk(2, dim = -2)

             sim = einsum(f'b h i d m, {kv_einsum_eq} -> b h i j', q, k) * scale
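
One way to read the deletion above, hedged: since `self.to_kv` is now constructed with `self_interaction = attend_self`, the key/value tensor it returns already carries one extra slot along the neighbor axis for the node itself, so the manual rearrange-and-concat is no longer needed before chunking into `k` and `v`. A small sketch with made-up dimensions (the head/degree sizes below are assumptions, not the repository's defaults):

import torch

b, h, i, j, d, m = 2, 4, 8, 16, 32, 3  # batch, heads, nodes, neighbors, head dim, 2 * degree + 1

# keys and values arrive with (j + 1) positions: the extra slot is the node attending to itself
kv = torch.randn(b, h, i, j + 1, 2 * d, m)
k, v = kv.chunk(2, dim = -2)

q = torch.randn(b, h, i, d, m)

# attention logits now cover self + neighbors without any special casing
sim = torch.einsum('b h i d m, b h i j d m -> b h i j', q, k)
assert sim.shape == (b, h, i, j + 1)
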
@@ -567,7 +557,7 @@ def __init__(

         # main branch tensor product

-        self.to_attn_and_v = TP(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits)
+        self.to_attn_and_v = TP(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = attend_self, splits = splits)

         # non-linear projection of attention branch into the attention logits