fix a bug with mlp attention (attending on the wrong dimension), also make sure single headed values work with mlp attention. experiments still don't show a big gain over dot product attention
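For intuition, here is an illustrative sketch (not the repository's code) contrasting dot product attention with an MLP-style attention over pairwise features, and showing why the softmax axis matters; all names and shapes below are assumptions for demonstration only.

```python
# illustrative sketch, not the repository's code: dot product vs MLP-style attention logits.
# names and shapes are assumptions for demonstration only.
import torch
from torch import einsum, nn

b, i, j, h, d = 2, 8, 8, 4, 16          # batch, queries, keys/neighbors, heads, head dim

q = torch.randn(b, i, h, d)
k = torch.randn(b, j, h, d)

# dot product attention: logits from a scaled inner product per head
dot_logits = einsum('b i h d, b j h d -> b i j h', q, k) * d ** -0.5

# MLP attention: logits from a small MLP applied to pairwise (query, key) features
to_attn_logit = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Linear(d, 1))
pairwise = q[:, :, None] + k[:, None, :]            # (b, i, j, h, d)
mlp_logits = to_attn_logit(pairwise).squeeze(-1)    # (b, i, j, h)

# in both cases the softmax must normalize over the key/neighbor axis j (dim = -2 for a
# (b, i, j, h) tensor); attending over the wrong dimension yields weights that do not
# sum to 1 across a query's neighbors
attn = mlp_logits.softmax(dim=-2)
```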
-            value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)
-            out = einsum('b i j h, b i j h d m -> b i h d m', attn, value)
-            out = rearrange(out, 'b i h d m -> b i (h d) m')
+            if not one_headed_kv:
+                value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)
+
+            out = einsum(f'b i j h, {value_einsum_eq} -> b i h d m', attn, value)
+            out = rearrange(out, 'b i h d m -> b i (h d) m')
 
             outputs[degree] = out
 
+        # combine heads out
+
         return self.to_out(outputs)
 
 # main class
@@ -626,7 +637,8 @@ def __init__(
         embedding_grad_frac = 0.5,
         single_headed_kv = False,          # whether to do single headed key/values for dot product attention, to save on memory and compute
         ff_include_htype_norms = False,    # whether for type0 projection to also involve norms of all higher types, in feedforward first projection. this allows for all higher types to be gated by other type norms
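For completeness, a minimal runnable sketch of the single-headed value aggregation pattern introduced in the first hunk above; only the `one_headed_kv` / `value_einsum_eq` selection mirrors the change, while the tensor shapes are illustrative assumptions.

```python
# minimal sketch of the aggregation pattern in the diff above; tensor shapes are
# illustrative assumptions, only the one_headed_kv / value_einsum_eq pattern mirrors the change
import torch
from torch import einsum
from einops import rearrange

h = 4                                                 # attention heads
one_headed_kv = True                                  # values shared across heads (no head axis)

attn = torch.randn(2, 8, 8, h).softmax(dim=-2)        # (b, i, j, h)
value = torch.randn(2, 8, 8, 16, 3)                   # (b, i, j, d, m) when single headed

if not one_headed_kv:
    # only split out a head axis when values are per-head, i.e. (b, i, j, h*d, m)
    value = rearrange(value, 'b i j (h d) m -> b i j h d m', h = h)

value_einsum_eq = 'b i j d m' if one_headed_kv else 'b i j h d m'

# single headed values are broadcast across the head axis of the attention weights
out = einsum(f'b i j h, {value_einsum_eq} -> b i h d m', attn, value)

# combine heads out
out = rearrange(out, 'b i h d m -> b i (h d) m')      # (b, i, h*d, m)
```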