18 | 18 |
19 | 19 | Return = namedtuple('Return', ['type0', 'type1']) |
20 | 20 |
| 21 | +EdgeInfo = namedtuple('EdgeInfo', ['neighbor_indices', 'neighbor_mask', 'edges']) |
| 22 | + |
21 | 23 | # biasless layernorm |
22 | 24 |
23 | 25 | class LayerNorm(nn.Module): |
@@ -74,6 +76,17 @@ def residual_fn(x, residual): |
74 | 76 | out[degree] = out[degree] + residual[degree] |
75 | 77 | return out |
76 | 78 |
| 79 | +def tuple_set_at_index(tup, index, value): |
| 80 | + l = list(tup) |
| 81 | + l[index] = value |
| 82 | + return tuple(l) |
| 83 | + |
| 84 | +def feature_shapes(feature): |
| 85 | + return tuple(v.shape for v in feature.values()) |
| 86 | + |
| 87 | +def feature_fiber(feature): |
| 88 | + return tuple(v.shape[-2] for v in feature.values()) |
| 89 | + |
77 | 90 | # classes |
78 | 91 |
79 | 92 | @beartype |
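As a side note on the helpers added above: a minimal sketch of what they compute, assuming (as the rest of this file suggests) that a feature map is a dict keyed by degree whose values are tensors shaped (batch, seq, dim, 2 * degree + 1). The tensor sizes below are placeholders for illustration, with the three helpers above in scope.

```python
# illustrative only - toy feature dict with type-0 and type-1 features (placeholder sizes)
import torch

feature = {
    0: torch.randn(2, 16, 32, 1),   # degree 0: (batch, seq, dim, 1)
    1: torch.randn(2, 16, 16, 3),   # degree 1: (batch, seq, dim, 3)
}

print(feature_shapes(feature))              # (torch.Size([2, 16, 32, 1]), torch.Size([2, 16, 16, 3]))
print(feature_fiber(feature))               # (32, 16) - channel dimension per degree
print(tuple_set_at_index((32, 16), 0, 64))  # (64, 16) - e.g. grow only the type-0 dimension of a fiber
```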
@@ -198,7 +211,7 @@ def __init__( |
198 | 211 | def forward( |
199 | 212 | self, |
200 | 213 | inp, |
201 | | - edge_info, |
| 214 | + edge_info: EdgeInfo, |
202 | 215 | rel_dist = None, |
203 | 216 | basis = None |
204 | 217 | ): |
@@ -328,11 +341,6 @@ def forward(self, feat, basis): |
328 | 341 |
329 | 342 | # feed forwards |
330 | 343 |
331 | | -def tuple_set_at_index(tup, index, value): |
332 | | - l = list(tup) |
333 | | - l[index] = value |
334 | | - return tuple(l) |
335 | | - |
336 | 344 | @beartype |
337 | 345 | class FeedForward(nn.Module): |
338 | 346 | def __init__( |
@@ -402,27 +410,30 @@ def __init__( |
402 | 410 |
403 | 411 | self.single_headed_kv = single_headed_kv |
404 | 412 |
405 | | - if not single_headed_kv: |
406 | | - kv_hidden_fiber = tuple(dim * 2 for dim in hidden_fiber) |
407 | | - else: |
408 | | - kv_hidden_fiber = tuple(dim * 2 for dim in dim_head) |
| 413 | + kv_hidden_fiber = hidden_fiber if not single_headed_kv else dim_head |
| 414 | + kv_hidden_fiber = tuple(dim * 2 for dim in kv_hidden_fiber) |
409 | 415 |
410 | 416 | self.scale = tuple(dim ** -0.5 for dim in dim_head) |
411 | 417 | self.heads = heads |
412 | 418 |
413 | 419 | self.prenorm = Norm(fiber) |
414 | 420 |
415 | 421 | self.to_q = Linear(fiber, hidden_fiber) |
416 | | - self.to_kv = Conv(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits) |
417 | | - self.to_out = Linear(hidden_fiber, fiber) |
418 | 422 |
419 | | - self.attend_self = attend_self |
420 | | - if attend_self: |
421 | | - self.to_self_kv = Linear(fiber, kv_hidden_fiber) |
| 423 | + self.to_kv = Conv(fiber, kv_hidden_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits) |
| 424 | + self.to_self_kv = Linear(fiber, kv_hidden_fiber) if attend_self else None |
422 | 425 |
| 426 | + self.to_out = Linear(hidden_fiber, fiber) |
423 | 427 |
424 | | - def forward(self, features, edge_info, rel_dist, basis, mask = None): |
425 | | - attend_self, one_head_kv = self.attend_self, self.single_headed_kv |
| 428 | + def forward( |
| 429 | + self, |
| 430 | + features, |
| 431 | + edge_info: EdgeInfo, |
| 432 | + rel_dist, |
| 433 | + basis, |
| 434 | + mask = None |
| 435 | + ): |
| 436 | + attend_self, one_head_kv = exists(self.to_self_kv), self.single_headed_kv |
426 | 437 |
427 | 438 | device, dtype = get_tensor_device_and_dtype(features) |
428 | 439 | neighbor_indices, neighbor_mask, edges = edge_info |
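Worth noting: since `edge_info` used to be a plain 3-tuple, the new `EdgeInfo` namedtuple keeps unpacking sites like the line above working unchanged while also allowing named access. A small sketch with placeholder shapes:

```python
# illustrative only - EdgeInfo unpacks exactly like the old plain tuple (placeholder shapes)
import torch

edge_info = EdgeInfo(
    neighbor_indices = torch.zeros(2, 16, 8, dtype = torch.long),  # (batch, seq, neighbors)
    neighbor_mask = torch.ones(2, 16, 8, dtype = torch.bool),
    edges = None
)

neighbor_indices, neighbor_mask, edges = edge_info      # tuple-style unpacking still works
assert edge_info.neighbor_indices is neighbor_indices   # named access refers to the same tensor
```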
@@ -484,12 +495,108 @@ def __init__( |
484 | 495 | heads: Union[int, Tuple[int, ...]] = 8, |
485 | 496 | attend_self = False, |
486 | 497 | edge_dim = None, |
487 | | - splits = 4 |
| 498 | + splits = 4, |
| 499 | + attn_leakyrelu_slope = 0.1, |
| 500 | + attn_hidden_dim_mult = 4, |
| 501 | + **kwargs |
488 | 502 | ): |
489 | 503 | super().__init__() |
| 504 | + num_degrees = len(fiber) |
| 505 | + |
| 506 | + dim_head = cast_tuple(dim_head, num_degrees) |
| 507 | + assert len(dim_head) == num_degrees |
| 508 | + |
| 509 | + heads = cast_tuple(heads, num_degrees) |
| 510 | + assert len(heads) == num_degrees |
| 511 | + |
| 512 | + hidden_fiber = tuple(dim * head for dim, head in zip(dim_head, heads)) |
| 513 | + |
| 514 | + self.scale = tuple(dim ** -0.5 for dim in dim_head) |
| 515 | + self.heads = heads |
| 516 | + |
| 517 | + self.prenorm = Norm(fiber) |
490 | 518 |
491 | | - def forward(self, features, edge_info, rel_dist, basis, mask = None): |
492 | | - raise NotImplementedError |
| 519 | + # type 0 needs greater dimension, for |
| 520 | + # (1) gating the htypes on the values branch |
| 521 | + # (2) attention logits, starting from a hidden dimension of attn_hidden_dim_mult * heads per degree
| 522 | + |
| 523 | + type0_dim = hidden_fiber[0] |
| 524 | + htype_dims = sum(hidden_fiber[1:]) |
| 525 | + |
| 526 | + value_gate_fiber = tuple_set_at_index(hidden_fiber, 0, type0_dim + htype_dims) |
| 527 | + |
| 528 | + attn_hidden_dims = tuple(head * attn_hidden_dim_mult for head in heads) |
| 529 | + |
| 530 | + intermediate_fiber = tuple_set_at_index(hidden_fiber, 0, sum(attn_hidden_dims) + type0_dim + htype_dims) |
| 531 | + self.intermediate_type0_split = [*attn_hidden_dims, type0_dim + htype_dims] |
| 532 | + |
| 533 | + # main branch tensor product |
| 534 | + |
| 535 | + self.to_attn_and_v = Conv(fiber, intermediate_fiber, edge_dim = edge_dim, pool = False, self_interaction = False, splits = splits) |
| 536 | + |
| 537 | + # non-linear projection of attention branch into the attention logits |
| 538 | + |
| 539 | + self.to_attn_logits = nn.ModuleList([ |
| 540 | + nn.Sequential( |
| 541 | + nn.LeakyReLU(attn_leakyrelu_slope), |
| 542 | + nn.Linear(attn_hidden_dim, h, bias = False) |
| 543 | + ) for attn_hidden_dim, h in zip(attn_hidden_dims, self.heads) |
| 544 | + ]) |
| 545 | + |
| 546 | + # non-linear transform of the value branch |
| 547 | + # todo - needs a DTP here? |
| 548 | + |
| 549 | + self.to_values = nn.Sequential( |
| 550 | + Gate(value_gate_fiber) |
| 551 | + ) |
| 552 | + |
| 553 | + # combining heads and projection out |
| 554 | + |
| 555 | + self.to_out = Linear(hidden_fiber, fiber) |
| 556 | + |
| 557 | + def forward( |
| 558 | + self, |
| 559 | + features, |
| 560 | + edge_info: EdgeInfo, |
| 561 | + rel_dist, |
| 562 | + basis, |
| 563 | + mask = None |
| 564 | + ): |
| 565 | + features = self.prenorm(features) |
| 566 | + |
| 567 | + intermediate = self.to_attn_and_v(features, edge_info, rel_dist, basis) |
| 568 | + |
| 569 | + *attn_branch_type0, value_branch_type0 = intermediate[0].split(self.intermediate_type0_split, dim = -2) |
| 570 | + |
| 571 | + intermediate[0] = value_branch_type0 |
| 572 | + |
| 573 | + # process the attention branch |
| 574 | + |
| 575 | + attentions = [] |
| 576 | + |
| 577 | + for fn, attn_intermediate, scale in zip(self.to_attn_logits, attn_branch_type0, self.scale): |
| 578 | + attn_intermediate = rearrange(attn_intermediate, '... 1 -> ...') |
| 579 | + attn_logits = fn(attn_intermediate) |
| 580 | + attn_logits = attn_logits * scale |
| 581 | + attn = attn_logits.softmax(dim = -2) # (batch, source, target, heads) - normalize over the target (neighbor) dimension
| 582 | + attentions.append(attn) |
| 583 | + |
| 584 | + # process values branch |
| 585 | + |
| 586 | + values = self.to_values(intermediate) |
| 587 | + |
| 588 | + # aggregate values with attention matrix |
| 589 | + |
| 590 | + outputs = {} |
| 591 | + |
| 592 | + for degree, (attn, value, h) in enumerate(zip(attentions, values.values(), self.heads)): |
| 593 | + value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h) |
| 594 | + out = einsum('b i j h, b i j h d m -> b i h d m', attn, value) |
| 595 | + out = rearrange(out, 'b i h d m -> b i (h d) m') |
| 596 | + |
| 597 | + outputs[degree] = out |
| 598 | + |
| 599 | + return self.to_out(outputs) |
493 | 600 |
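To make the shapes in the new `MLPAttention` aggregation concrete, here is a standalone sketch of the per-degree step with arbitrary placeholder sizes; the attention weights are normalized over the neighbor dimension, which is the dimension the einsum sums over.

```python
# illustrative only - per-degree MLP attention aggregation with placeholder sizes
import torch
from torch import einsum
from einops import rearrange

b, i, j, h, d, m = 2, 16, 8, 4, 8, 3  # batch, nodes, neighbors, heads, dim per head, 2 * degree + 1

attn = torch.randn(b, i, j, h).softmax(dim = -2)  # attention over the j neighbors, per head
value = torch.randn(b, i, j, h * d, m)            # per-neighbor values for this degree

value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)
out = einsum('b i j h, b i j h d m -> b i h d m', attn, value)  # weighted sum over neighbors
out = rearrange(out, 'b i h d m -> b i (h d) m')                # merge heads: (b, i, h * d, m)
```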
494 | 601 | # main class |
495 | 602 |
@@ -518,7 +625,8 @@ def __init__( |
518 | 625 | linear_out = True, |
519 | 626 | embedding_grad_frac = 0.5, |
520 | 627 | single_headed_kv = False, # whether to do single headed key/values for dot product attention, to save on memory and compute |
521 | | - ff_include_htype_norms = False # whether for type0 projection to also involve norms of all higher types, in feedforward first projection. this allows for all higher types to be gated by other type norms |
| 628 | + ff_include_htype_norms = False, # whether for type0 projection to also involve norms of all higher types, in feedforward first projection. this allows for all higher types to be gated by other type norms |
| 629 | + dot_product_attention = True |
522 | 630 | ): |
523 | 631 | super().__init__() |
524 | 632 |
@@ -587,9 +695,11 @@ def __init__( |
587 | 695 |
588 | 696 | self.layers = nn.ModuleList([]) |
589 | 697 |
| 698 | + attention_klass = DotProductAttention if dot_product_attention else MLPAttention |
| 699 | + |
590 | 700 | for ind in range(depth): |
591 | 701 | self.layers.append(nn.ModuleList([ |
592 | | - DotProductAttention(self.dim, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, splits = splits, single_headed_kv = single_headed_kv), |
| 702 | + attention_klass(self.dim, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, splits = splits, single_headed_kv = single_headed_kv), |
593 | 703 | FeedForward(self.dim, include_htype_norms = ff_include_htype_norms) |
594 | 704 | ])) |
595 | 705 |
@@ -709,7 +819,8 @@ def forward( |
709 | 819 |
710 | 820 | # main logic |
711 | 821 |
712 | | - edge_info = (neighbor_indices, neighbor_mask, edges) |
| 822 | + edge_info = EdgeInfo(neighbor_indices, neighbor_mask, edges) |
| 823 | + |
713 | 824 | x = feats |
714 | 825 |
715 | 826 | # project in |
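Finally, a hedged usage sketch of the new `dot_product_attention` flag. The class name `Equiformer` and all concrete argument values below are assumptions that do not appear in this diff; only the keyword flags themselves are introduced or touched by this commit.

```python
# illustrative only - class name and argument values are assumed, not shown in this diff
model = Equiformer(
    dim = (32, 16),                  # per-degree dimensions (type0, type1) - assumed signature
    depth = 2,
    heads = (4, 2),
    dim_head = (16, 8),
    dot_product_attention = False,   # new flag: selects MLPAttention instead of DotProductAttention
    ff_include_htype_norms = False,
    single_headed_kv = False
)
```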