
Commit dbc16f4

buy into new research that attention needs to be able to attend to nothing
1 parent: d18cddd

File tree: 3 files changed (+73, -4 lines)


README.md

Lines changed: 11 additions & 0 deletions
@@ -267,3 +267,14 @@ $ python denoise.py
     year = {2017}
 }
 ```
+
+```bibtex
+@article{Bondarenko2023QuantizableTR,
+    title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
+    author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
+    journal = {ArXiv},
+    year = {2023},
+    volume = {abs/2306.12929},
+    url = {https://api.semanticscholar.org/CorpusID:259224568}
+}
+```

equiformer_pytorch/equiformer_pytorch.py

Lines changed: 61 additions & 3 deletions
@@ -551,7 +551,8 @@ def __init__(
         radial_hidden_dim = 64,
         splits = 4,
         num_linear_attn_heads = 0,
-        init_out_zero = True
+        init_out_zero = True,
+        gate_attn_head_outputs = True
     ):
         super().__init__()
         num_degrees = len(fiber)
@@ -587,6 +588,21 @@ def __init__(
             self.linear_attn = LinearAttention(degree_zero_dim, dim_head = dim_head[0], heads = num_linear_attn_heads)
             hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + dim_head[0] * num_linear_attn_heads)

+        # gating heads across all degree outputs
+        # to allow for attending to nothing
+
+        self.attn_head_gates = None
+
+        if gate_attn_head_outputs:
+            self.attn_head_gates = nn.Sequential(
+                Rearrange('... d 1 -> ... d'),
+                nn.Linear(fiber[0], sum(heads)),
+                nn.Sigmoid(),
+                Rearrange('... n h -> ... h n 1 1')
+            )
+
+        # combine heads
+
         self.to_out = Linear(hidden_fiber, fiber)

         if init_out_zero:
@@ -615,6 +631,8 @@ def forward(

         features = self.prenorm(features)

+        # generate queries, keys, values
+
         queries = self.to_q(features)

         keyvalues = self.to_kv(
@@ -625,11 +643,20 @@ def forward(
             D = D
         )

+        # create gates
+
+        gates = (None,) * len(self.heads)
+
+        if exists(self.attn_head_gates):
+            gates = self.attn_head_gates(features[0]).split(self.heads, dim = -4)
+
+        # single headed vs not
+
         kv_einsum_eq = 'b h i j d m' if not one_head_kv else 'b i j d m'

         outputs = {}

-        for degree, h, scale in zip(features.keys(), self.heads, self.scale):
+        for degree, gate, h, scale in zip(features.keys(), gates, self.heads, self.scale):
             is_degree_zero = degree == 0

             q, kv = map(lambda t: t[degree], (queries, keyvalues))
@@ -657,6 +684,10 @@ def forward(

             attn = sim.softmax(dim = -1)
             out = einsum(attn, v, f'b h i j, {kv_einsum_eq} -> b h i d m')
+
+            if exists(gate):
+                out = out * gate
+
             outputs[degree] = rearrange(out, 'b h n d m -> b n (h d) m')

         if self.has_linear_attn:
@@ -681,6 +712,7 @@ def __init__(
         radial_hidden_dim = 16,
         num_linear_attn_heads = 0,
         init_out_zero = True,
+        gate_attn_head_outputs = True,
         **kwargs
     ):
         super().__init__()
@@ -748,6 +780,19 @@ def __init__(
             self.linear_attn = LinearAttention(degree_zero_dim, dim_head = dim_head[0], heads = num_linear_attn_heads)
             hidden_fiber = tuple_set_at_index(hidden_fiber, 0, hidden_fiber[0] + dim_head[0] * num_linear_attn_heads)

+        # gating heads across all degree outputs
+        # to allow for attending to nothing
+
+        self.attn_head_gates = None
+
+        if gate_attn_head_outputs:
+            self.attn_head_gates = nn.Sequential(
+                Rearrange('... d 1 -> ... d'),
+                nn.Linear(fiber[0], sum(heads)),
+                nn.Sigmoid(),
+                Rearrange('... h -> ... h 1 1')
+            )
+
         # combining heads and projection out

         self.to_out = Linear(hidden_fiber, fiber)
@@ -789,6 +834,13 @@ def forward(

         intermediate[0] = value_branch_type0

+        # create gates
+
+        gates = (None,) * len(self.heads)
+
+        if exists(self.attn_head_gates):
+            gates = self.attn_head_gates(features[0]).split(self.heads, dim = -3)
+
         # process the attention branch

         attentions = []
@@ -814,11 +866,15 @@ def forward(

         value_einsum_eq = 'b i j h d m' if not one_headed_kv else 'b i j d m'

-        for degree, (attn, value, h) in enumerate(zip(attentions, values.values(), self.heads)):
+        for degree, (attn, value, gate, h) in enumerate(zip(attentions, values.values(), gates, self.heads)):
             if not one_headed_kv:
                 value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)

             out = einsum(attn, value, f'b i j h, {value_einsum_eq} -> b i h d m')
+
+            if exists(gate):
+                out = out * gate
+
             out = rearrange(out, 'b i h d m -> b i (h d) m')
             outputs[degree] = out

@@ -863,6 +919,7 @@ def __init__(
         l2_dist_attention = True, # turn to False to use MLP attention as proposed in paper, but dot product attention with -cdist similarity is still far better, and i haven't even rotated distances (rotary embeddings) into the type 0 features yet
         reversible = False, # turns on reversible networks, to scale depth without incurring depth times memory cost
         attend_sparse_neighbors = False, # ability to accept an adjacency matrix
+        gate_attn_head_outputs = True, # gate each attention head output, to allow for attending to nothing
         num_adj_degrees_embed = None,
         adj_dim = 0,
         max_sparse_neighbors = float('inf'),
@@ -958,6 +1015,7 @@ def __init__(
                     edge_dim = edge_dim,
                     single_headed_kv = single_headed_kv,
                     radial_hidden_dim = radial_hidden_dim,
+                    gate_attn_head_outputs = gate_attn_head_outputs,
                     **kwargs
                 ),
                 FeedForward(self.dim, include_htype_norms = ff_include_htype_norms)
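
The change above wires a per-head sigmoid gate into both attention variants: a small linear layer reads the degree-0 (invariant) features, produces one value in (0, 1) per head, and each head's attention output is multiplied by it before the heads are combined. A head can therefore drive its contribution toward zero, i.e. attend to nothing, in the spirit of the Bondarenko et al. reference added to the README, which proposes gated attention (alongside a clipped softmax) so heads are not forced to put probability mass somewhere. Below is a minimal sketch of the same gating pattern on plain multi-head self-attention; it is illustrative only, and the names (GatedMultiheadAttention, to_gates) are not taken from the repository.

```python
# Minimal sketch (not the repository's code): per-head sigmoid gates on
# attention outputs, so a head can scale its output toward zero and
# effectively attend to nothing.
import torch
from torch import nn
from einops import rearrange

class GatedMultiheadAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner * 3, bias = False)
        # one scalar gate per head, predicted from the token features
        self.to_gates = nn.Sequential(nn.Linear(dim, heads), nn.Sigmoid())
        self.to_out = nn.Linear(inner, dim)

    def forward(self, x):
        # x: (batch, seq, dim)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        attn = (torch.einsum('bhid,bhjd->bhij', q, k) * self.scale).softmax(dim = -1)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)

        # gates: (batch, seq, heads) -> (batch, heads, seq, 1), one multiplier per head and token
        gates = rearrange(self.to_gates(x), 'b n h -> b h n 1')
        out = out * gates

        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))

# usage sketch: layer = GatedMultiheadAttention(dim = 64); layer(torch.randn(2, 16, 64))
```

Gating the combined output is simpler than adding an extra attendable "null" token, and it keeps the softmax itself untouched, which is why it slots into both the L2-distance and MLP attention variants with the same few lines.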

equiformer_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.3.10'
+__version__ = '0.4.0'
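
Since gate_attn_head_outputs defaults to True, the gating is active unless explicitly disabled; the last two hunks of equiformer_pytorch.py expose the flag on the top-level Equiformer and thread it down to every attention layer. A hedged usage sketch follows: the constructor arguments and call signature are assumed from the repository README at this version and may differ, so treat the values as illustrative.

```python
import torch
from equiformer_pytorch import Equiformer

# assumed constructor / call signature, following the repository README;
# gate_attn_head_outputs is the flag introduced by this commit
model = Equiformer(
    dim = (4, 4, 2),                # feature dims for degrees 0, 1, 2
    dim_head = (4, 4, 4),
    heads = (2, 2, 2),
    num_degrees = 3,
    depth = 2,
    gate_attn_head_outputs = True   # default; set False to recover the pre-0.4.0 behavior
)

feats = torch.randn(1, 32, 4)       # degree-0 (invariant) features
coors = torch.randn(1, 32, 3)       # coordinates
mask = torch.ones(1, 32).bool()

out = model(feats, coors, mask = mask)
```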
