@@ -122,10 +122,10 @@ def forward(self, batch):
122
122
return results
123
123
124
124
125
- class NodeAttention (nn .Layer ):
125
+ class FirstBodyAxialAttention (nn .Layer ):
126
126
"""Compute self-attention over columns of a 2D input."""
127
127
def __init__ (self , model_config , global_config ):
128
- super (NodeAttention , self ).__init__ ()
128
+ super (FirstBodyAxialAttention , self ).__init__ ()
129
129
self .model_config = model_config
130
130
131
131
node_channel = global_config .node_channel
@@ -253,9 +253,9 @@ def forward(self, x):
253
253
return x
254
254
255
255
256
- class OuterProductMean (nn .Layer ):
256
+ class Low2HighModule (nn .Layer ):
257
257
def __init__ (self , model_config , global_config ):
258
- super (OuterProductMean , self ).__init__ ()
258
+ super (Low2HighModule , self ).__init__ ()
259
259
node_channel = global_config .node_channel
260
260
pair_channel = global_config .pair_channel
261
261
inner_channel = model_config .inner_channel
@@ -285,9 +285,9 @@ def forward(self, node_acts, node_mask):
285
285
return act
286
286
287
287
288
- class TriangleAttentionWithAngle (nn .Layer ):
288
+ class SecondBodyAxialAttentionWithAngle (nn .Layer ):
289
289
def __init__ (self , model_config , global_config ):
290
- super (TriangleAttentionWithAngle , self ).__init__ ()
290
+ super (SecondBodyAxialAttentionWithAngle , self ).__init__ ()
291
291
pair_channel = global_config .pair_channel
292
292
triple_channel = global_config .triple_channel
293
293
self .num_head = model_config .num_head
@@ -342,9 +342,9 @@ def forward(self, pair_acts, triple_acts, bias):
342
342
return out
343
343
344
344
345
- class TriangleAttentionWithAngleBias (nn .Layer ):
345
+ class SecondBodyAxialAttentionWithAngleBias (nn .Layer ):
346
346
def __init__ (self , model_config , global_config ):
347
- super (TriangleAttentionWithAngleBias , self ).__init__ ()
347
+ super (SecondBodyAxialAttentionWithAngleBias , self ).__init__ ()
348
348
pair_channel = global_config .pair_channel
349
349
triple_channel = global_config .triple_channel
350
350
self .num_head = model_config .num_head
@@ -395,15 +395,15 @@ def forward(self, pair_acts, triple_acts, bias):
395
395
return out
396
396
397
397
398
- class TriangleAttention (nn .Layer ):
398
+ class SecondBodyAxialAttention (nn .Layer ):
399
399
def __init__ (self , model_config , global_config ):
400
- super (TriangleAttention , self ).__init__ ()
400
+ super (SecondBodyAxialAttention , self ).__init__ ()
401
401
self .is_start = model_config .is_start
402
402
403
403
if model_config .get ('angle_as_bias' , False ):
404
- self .attn_mod = TriangleAttentionWithAngleBias (model_config , global_config )
404
+ self .attn_mod = SecondBodyAxialAttentionWithAngleBias (model_config , global_config )
405
405
else :
406
- self .attn_mod = TriangleAttentionWithAngle (model_config , global_config )
406
+ self .attn_mod = SecondBodyAxialAttentionWithAngle (model_config , global_config )
407
407
408
408
def forward (self , pair_acts , triple_acts , triple_mask ):
409
409
"""
@@ -431,29 +431,29 @@ def __init__(self, model_config, global_config):
431
431
pair_channel = global_config .pair_channel
432
432
433
433
### node track
434
- self .node_attn = NodeAttention (
435
- model_config .node_attention , global_config )
436
- self .node_attn_dropout = nn .Dropout (model_config .node_dropout_rate )
434
+ self .first_body_axial_attention = FirstBodyAxialAttention (
435
+ model_config .first_body_axial_attention , global_config )
436
+ self .first_body_axial_attention_dropout = nn .Dropout (model_config .first_body_axial_attention_dropout )
437
437
438
438
self .node_ffn = FeedForwardNetwork (
439
439
model_config .node_ffn , node_channel )
440
- self .node_ffn_dropout = nn .Dropout (model_config .node_dropout_rate )
440
+ self .node_ffn_dropout = nn .Dropout (model_config .first_body_axial_attention_dropout )
441
441
442
- ### outer
443
- self .outer_product = OuterProductMean (
444
- model_config .outer_product , global_config )
445
- self .outer_product_dropout = nn .Dropout (model_config .pair_dropout_rate )
442
+ ### low2high
443
+ self .low2high = Low2HighModule (
444
+ model_config .low2high , global_config )
445
+ self .low2high_dropout = nn .Dropout (model_config .pair_dropout_rate )
446
446
447
447
### pair track
448
448
self .pair_before_ln = nn .LayerNorm (pair_channel )
449
449
450
- self .triangle_attn_start = TriangleAttention (
451
- model_config .triangle_attention_start_node , global_config )
452
- self .triangle_attn_start_dropout = nn .Dropout (model_config .pair_dropout_rate )
450
+ self .second_body_first_axis = SecondBodyAxialAttention (
451
+ model_config .second_body_first_axis , global_config )
452
+ self .second_body_first_axis_dropout = nn .Dropout (model_config .pair_dropout_rate )
453
453
454
- self .triangle_attn_end = TriangleAttention (
455
- model_config .triangle_attention_end_node , global_config )
456
- self .triangle_attn_end_dropout = nn .Dropout (model_config .pair_dropout_rate )
454
+ self .second_body_second_axis = SecondBodyAxialAttention (
455
+ model_config .second_body_second_axis , global_config )
456
+ self .second_body_second_axis_dropout = nn .Dropout (model_config .pair_dropout_rate )
457
457
458
458
self .pair_ffn = FeedForwardNetwork (
459
459
model_config .pair_ffn , pair_channel )
@@ -476,24 +476,24 @@ def forward(self, node_acts, pair_acts, triple_acts, mask_dict):
476
476
triple_mask = mask_dict ['triple' ]
477
477
478
478
# node track
479
- residual = self .node_attn (node_acts , pair_acts , node_mask , pair_mask )
480
- node_acts += self .node_attn_dropout (residual )
479
+ residual = self .first_body_axial_attention (node_acts , pair_acts , node_mask , pair_mask )
480
+ node_acts += self .first_body_axial_attention_dropout (residual )
481
481
482
482
residual = self .node_ffn (node_acts )
483
483
node_acts += self .node_ffn_dropout (residual )
484
484
485
485
# outer
486
- outer = self .outer_product (node_acts , node_mask )
487
- pair_acts += self .outer_product_dropout (outer )
486
+ outer = self .low2high (node_acts , node_mask )
487
+ pair_acts += self .low2high_dropout (outer )
488
488
489
489
# pair track
490
490
pair_acts = self .pair_before_ln (pair_acts )
491
491
492
- residual = self .triangle_attn_start (pair_acts , triple_acts , triple_mask )
493
- pair_acts += self .triangle_attn_start_dropout (residual )
492
+ residual = self .second_body_first_axis (pair_acts , triple_acts , triple_mask )
493
+ pair_acts += self .second_body_first_axis_dropout (residual )
494
494
495
- residual = self .triangle_attn_end (pair_acts , triple_acts , triple_mask )
496
- pair_acts += self .triangle_attn_end_dropout (residual )
495
+ residual = self .second_body_second_axis (pair_acts , triple_acts , triple_mask )
496
+ pair_acts += self .second_body_second_axis_dropout (residual )
497
497
498
498
residual = self .pair_ffn (pair_acts )
499
499
pair_acts += self .pair_ffn_dropout (residual )
0 commit comments