
Commit f64f415

generalize attention utils
1 parent 14ff57c commit f64f415

6 files changed: +75, -59 lines changed

llama3/llama3_jax/attention_cache_utils.py

Lines changed: 34 additions & 24 deletions
@@ -18,6 +18,13 @@
 _pad_after = lambda x, l, axis: jnp.pad(x, [(0, 0) if i != axis else (0, l - x.shape[i]) for i in range(x.ndim)])
 
 
+def safe_zip(*args):
+    if len(args) == 0:
+        return []
+    assert all(len(arg) == len(args[0]) for arg in args)
+    return zip(*args)
+
+
 def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
     "From a list of cache entries stacked along layer idx (in transit) to stacked along batch, layers split into list."

@@ -28,7 +35,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
     for i, c in enumerate(kv_list[0]):
         els = [[_split(z) for z in jax.tree.leaves(kv[i])] for kv in kv_list]  # [B, R_flat, L]
         els = jax.tree.map(lambda *xs: jnp.concatenate(xs, axis=0), *els)  # [R_flat, L]
-        leaves_list = list(zip(*els))  # [L, R_flat]
+        leaves_list = list(safe_zip(*els))  # [L, R_flat]
         out[i] = [jax.tree.unflatten(jax.tree.structure(c), leaves) for leaves in leaves_list]  # [L, R]
     return tuple(out), max_seq_len
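Aside: safe_zip is a strict variant of the builtin zip that fails loudly when the per-layer leaf lists disagree in length, which plain zip would silently truncate. A minimal, self-contained rerun of that behaviour (the example inputs are hypothetical):

    def safe_zip(*args):
        if len(args) == 0:
            return []
        assert all(len(arg) == len(args[0]) for arg in args)
        return zip(*args)

    print(list(safe_zip(["k", "v"], [0, 1])))  # [('k', 0), ('v', 1)]
    # safe_zip(["k", "v"], [0])  # raises AssertionError instead of silently dropping 'v'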

@@ -41,7 +48,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
 @partial(jax.jit, donate_argnames=("cache",))
 def _kvcache_update_cache(
     cache: KVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
     batch_idxs: list[jax.Array],
     actual_lens: list[jax.Array],
     update_mask: list[bool] | None = None,
@@ -62,15 +69,17 @@ def _update_element(x, u):
         # update_permute = [batch_dim, time_dim] + update_permute
         return x.at[batch_idxs[:, None], :, time_indices, ...].set(u.transpose(update_permute), mode="drop")
 
-    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
+    cache_kvs = jax.tree.map(_update_element, cache.buffers, kvs)
     cache_starts = cache.starts.at[batch_idxs].set(start_time, mode="drop")
     cache_iter = jnp.where(uninitialized_cache, jnp.max(actual_lens), cache.iter)
-    return dataclasses.replace(cache, k=cache_k, v=cache_v, iter=cache_iter, starts=cache_starts)
+
+    buffer_names = [field.name for field in dataclasses.fields(cache)][:len(cache_kvs)]
+    return dataclasses.replace(cache, **dict(safe_zip(buffer_names, cache_kvs)), iter=cache_iter, starts=cache_starts)
 
 
 def kvcache_update_cache(
     cache: KVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
     batch_idxs: list[jax.Array],
     actual_lens: list[jax.Array],
 ):
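Aside: rebuilding the cache via dataclasses.fields relies on the buffer fields (k, v) being declared first on the cache dataclass, so their names can be recovered positionally. A toy sketch of that pattern, with a hypothetical ToyCache standing in for the real KVCache:

    import dataclasses

    @dataclasses.dataclass
    class ToyCache:  # hypothetical stand-in; buffers must be the first declared fields
        k: list
        v: list
        iter: int = -1

    cache = ToyCache(k=[0], v=[0])
    new_buffers = ([1], [2])  # e.g. the result of mapping _update_element over cache.buffers
    buffer_names = [f.name for f in dataclasses.fields(cache)][:len(new_buffers)]  # ['k', 'v']
    cache = dataclasses.replace(cache, **dict(zip(buffer_names, new_buffers)), iter=0)
    print(cache)  # ToyCache(k=[1], v=[2], iter=0)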
@@ -85,7 +94,7 @@ def kvcache_update_cache(
 def kvcache_get_entry(cache: KVCache, batch_idx: jax.Array):
     shift = -cache.starts[batch_idx]
     assert cache.time_axis > 0
-    kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), (cache.k, cache.v))
+    kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), cache.buffers)
     kvs = (jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[0]), jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[1]))
     true_len = cache.fill_len()[batch_idx]
     return kvs, true_len
@@ -109,13 +118,13 @@ def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array |
     return jax.lax.top_k(free_pages, k)[1]
 
 
-def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
-    key_heads = cache.k[layer_idx].shape[0]
-    assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1)
+def _paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int):
+    #key_heads = cache.buffers[0][layer_idx].shape[0]
+    #assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1)  # TODO write this generically
     needs_next_page = (cache.lengths % cache.page_size) == 0
     page_table_idx = cache.lengths // cache.page_size
     current_page_cursor = jnp.take_along_axis(cache.block_tables, page_table_idx[:, None], axis=-1)[..., 0]
-    avg_pages_per_batch_entry = round(cache.k[layer_idx].shape[0] / cache.batch_size)
+    avg_pages_per_batch_entry = round(cache.buffers[0][layer_idx].shape[0] / cache.batch_size)
     even_batch_spread = jnp.arange(cache.batch_size) * avg_pages_per_batch_entry
     proposal_pages = jnp.where(cache.lengths == 0, even_batch_spread, current_page_cursor + 1)
     free_pages = _find_empty_pages(cache.free_pages, cache.batch_size, proposal_pages=proposal_pages)
@@ -127,27 +136,28 @@ def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.A
     # for batch index update the target slice is (heads, i, j, head_dim)
     # so transpose update (batch, heads, seq, head_dim) -> (batch, heads, head_dim) -> (heads, batch, head_dim)
     _update = lambda dest, src: dest.at[:, page_cursor, inpage_cursor, ...].set(src.squeeze(2).swapaxes(0, 1))
-    cache.k[layer_idx], cache.v[layer_idx] = jax.tree.map(_update, (cache.k[layer_idx], cache.v[layer_idx]), (k, v))
+    for buffer, new_buffer in safe_zip(cache.buffers, kv):
+        buffer[layer_idx] = jax.tree.map(_update, buffer[layer_idx], new_buffer)
 
     batch_idx = jnp.arange(cache.batch_size)
     new_block_tables = cache.block_tables.at[batch_idx, new_lengths // cache.page_size].set(page_cursor)
 
     new_free_pages = cache.free_pages.at[page_cursor].set(False, mode="drop")
     new_state = dict(lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages)
-    return cache.k[layer_idx], cache.v[layer_idx], new_state
+    return tuple(buffer[layer_idx] for buffer in cache.buffers), new_state
 
 
-def paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
+def paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int):
     repl_sharding = jax.typeof(cache.lengths).sharding
-    kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, (cache.k[layer_idx], cache.v[layer_idx]))
-    sharding = (*kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding))
-    return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, k, v)
+    kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, tuple(buffer[layer_idx] for buffer in cache.buffers))
+    sharding = (kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding))
+    return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, kv)
 
 
 @partial(jax.jit, donate_argnames=("cache",))
 def _batch_paged_update_sequences(
     cache: PagedKVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
     batch_idxs: list[jax.Array],
     actual_lens: list[jax.Array],
     update_mask: list[bool] | None = None,
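Aside: the per-layer slice update now loops over whatever buffers the cache exposes rather than naming k and v explicitly, so the same loop would cover additional buffers if a cache ever carried them. A simplified illustration with plain per-layer lists standing in for the JAX arrays (everything below is hypothetical):

    k_buf = ["k_layer0", "k_layer1"]  # one entry per layer
    v_buf = ["v_layer0", "v_layer1"]
    buffers, kv = (k_buf, v_buf), ("k_new", "v_new")  # what cache.buffers and the incoming slice look like

    layer_idx = 1
    for buffer, new_buffer in zip(buffers, kv):  # the patch uses safe_zip and a jax.tree.map'd .at[...].set
        buffer[layer_idx] = new_buffer

    print(tuple(buffer[layer_idx] for buffer in buffers))  # ('k_new', 'v_new')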
@@ -156,9 +166,7 @@ def _batch_paged_update_sequences(
     batch_idxs = jnp.where(update_mask, jnp.array(batch_idxs), 2**30)  # send masked to nowhere
     actual_lens = jnp.minimum(jnp.array(actual_lens), jnp.array([jax.tree.leaves(kv)[0].shape[2] for kv in kvs]))
 
-    kvs, max_seq_len = _transpose_attention_tree(
-        kvs, time_axis=2
-    )  # undo stacking along the layer dimension for transit
+    kvs, max_seq_len = _transpose_attention_tree(kvs, time_axis=2)  # undo stack along layer dimension in transit
 
     # clear existing pages
     actual_page_num = jnp.rint(jnp.ceil(cache.lengths[batch_idxs] / cache.page_size)).astype(jnp.int32)
@@ -186,21 +194,23 @@ def _update_element(x, u):
         update_permute = [1, 0, 2] + [i for i in range(u.ndim) if i not in (0, 1, 2)]
         return x.at[:, pages_idx, ...].set(u.transpose(update_permute), mode="drop")
 
-    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
+    new_buffers = jax.tree.map(_update_element, cache.buffers, kvs)
     block_tables_idx = jnp.where(
         update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_arange[None, :], 2**30
     )
     new_block_tables = cache.block_tables.at[batch_idxs[:, None], block_tables_idx].set(pages_idx, mode="drop")
     new_free_pages = new_free_pages.at[pages_idx.reshape(-1)].set(False, mode="drop")
     new_lengths = cache.lengths.at[batch_idxs].set(actual_lens, mode="drop")
+
+    named_buffers = dict(zip([field.name for field in dataclasses.fields(cache)][:len(new_buffers)], new_buffers))
     return dataclasses.replace(
-        cache, k=cache_k, v=cache_v, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages
+        cache, **named_buffers, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages
     )
 
 
 def batch_paged_update_sequences(
     cache: KVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
     batch_idxs: list[jax.Array],
     actual_lens: list[jax.Array],
 ):
@@ -222,5 +232,5 @@ def batch_paged_get_entry(cache: PagedKVCache, batch_idx: jax.Array, max_seq_len
     _get = lambda x: jnp.where(mask[None, :, *([None] * (x.ndim - 3))], _reshape_out(x[:, page_indices, ...]), 0)
 
     # stack along layer dimensions for transit
-    kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, (cache.k, cache.v)))
+    kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, cache.buffers))
     return kvs, true_len

llama3/llama3_jax/model.py

Lines changed: 25 additions & 24 deletions
@@ -37,6 +37,7 @@
 except ModuleNotFoundError:
     from jax.sharding import auto_axes as _auto_axes, reshard
 from jax.experimental.pallas.ops.gpu import paged_attention
+from etils import epath
 
 from . import ragged_attention
 from . import attention_cache_utils
@@ -216,7 +217,7 @@ class ArrayInfo:
 _count_left_padding = lambda ids, pad_id=0: auto_axes(
     lambda ids: jnp.sum(jnp.cumsum(ids != pad_id, axis=-1) == 0, axis=-1), out_sharding=P(None)
 )(ids)
-_length_minus_padding = lambda segment_ids: auto_axes(
+_length_minus_right_padding = lambda segment_ids: auto_axes(
     lambda segment_ids: jnp.sum(jnp.cumsum(jnp.flip(segment_ids != 0, -1), axis=-1) > 0, -1), out_sharding=P(None)
 )(segment_ids)
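Aside: the rename spells out that the two helpers count padding on opposite ends of a sequence. A quick worked example of the underlying formulas on a toy segment_ids row, written with plain jnp and without the auto_axes sharding wrapper used above:

    import jax.numpy as jnp

    segment_ids = jnp.array([[0, 0, 1, 1, 1, 0]])  # two pad tokens on the left, one on the right

    left_pad = jnp.sum(jnp.cumsum(segment_ids != 0, axis=-1) == 0, axis=-1)
    len_minus_right_pad = jnp.sum(jnp.cumsum(jnp.flip(segment_ids != 0, -1), axis=-1) > 0, -1)

    print(left_pad)             # [2]  -> _count_left_padding
    print(len_minus_right_pad)  # [5]  -> _length_minus_right_padding (left padding still included)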

@@ -411,7 +412,7 @@ class KVCache(_Init):
     iter: jax.Array  # []  # sequences are right-aligned for slice update performance
     starts: jax.Array  # [batch_size]  # sequences are right-aligned, we need start indices
     batch_size: int = 0
-    size: int = 0
+    size: int = 2 ** 30
     time_axis: int = 2
 
     @classmethod
@@ -428,6 +429,7 @@ def abstract(cls, cfg: Config, batch_size: int):
             # -1 means unintialized since iter (cursor) must be 0 <= iter < len - 1
             iter=ArrayInfo((), jnp.int32, (), jax.nn.initializers.constant(-1)),
             starts=ArrayInfo((batch_size,), jnp.int32, ("batch",), jax.nn.initializers.zeros),
+            size=cfg.max_seq_len,
         )
         if cfg.quant_cache:
             _quantize = partial(quantize, axis=-1, scale_dtype=cfg.quant_scale_dtype, zero_init=True)
@@ -447,8 +449,11 @@ def abstract(cls, cfg: Config, batch_size: int):
         return cache
 
     def fill_len(self) -> jax.Array:
-        length = jnp.where(self.iter > self.starts, self.iter - self.starts, self.size + self.iter - self.starts)
-        return jnp.where(self.iter >= 0, length, 0)
+        return jnp.where(self.iter >= 0, (self.iter - self.starts) % self.size, 0)
+
+    @property
+    def buffers(self) -> tuple[jax.Array, ...]:
+        return (self.k, self.v)
 
     update_slice = None
     insert_sequences = staticmethod(attention_cache_utils.kvcache_update_cache)
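Aside: the new fill_len folds the explicit wrap-around branch into a single modulo. For 0 <= starts, iter < size the two forms agree except at iter == starts, where the modulo form reads 0 (empty) rather than size. A small numeric check with illustrative values:

    import jax.numpy as jnp

    size = 8
    starts = jnp.array([2, 6, 4])
    it = jnp.array([5, 1, 4])  # second entry has wrapped past the end, third equals its start

    old = jnp.where(it > starts, it - starts, size + it - starts)
    new = (it - starts) % size
    print(old)  # [3 3 8]
    print(new)  # [3 3 0]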
@@ -463,7 +468,7 @@ class PagedKVCache(_Init):
     block_tables: jax.Array  # [batch_size, pages_per_seq]
     free_pages: jax.Array  # [total_num_pages]
     batch_size: int = 0
-    size: int = 2**31 - 1
+    size: int = 2**30
     page_size: int = 0
 
     @classmethod
@@ -501,6 +506,10 @@ def abstract(cls, cfg: "Config", batch_size: int, total_num_pages: int, page_siz
     def fill_len(self) -> jax.Array:
         return self.lengths
 
+    @property
+    def buffers(self) -> tuple[jax.Array, ...]:
+        return (self.k, self.v)
+
     update_slice = staticmethod(attention_cache_utils.paged_update_slice)
     insert_sequences = staticmethod(attention_cache_utils.batch_paged_update_sequences)
     get_sequence = staticmethod(attention_cache_utils.batch_paged_get_entry)
@@ -807,12 +816,9 @@ def attention_block(
     q, k = apply_rotary_embedding(q, sin, cos), apply_rotary_embedding(k, sin, cos)
 
     if cfg.quant_cache:
-        k = QuantArray(
-            *quantize(k, -1, scale_dtype=cfg.quant_scale_dtype), out_scaling=True, scale_expand_dims=(-2, -3)
-        )
-        v = QuantArray(
-            *quantize(v, -1, scale_dtype=cfg.quant_scale_dtype), out_scaling=False, scale_expand_dims=(-2, -3)
-        )
+        _quantize = partial(quantize, axis=-1, scale_dtype=cfg.quant_scale_dtype)
+        k = QuantArray(*_quantize(k), out_scaling=True, scale_expand_dims=(-2, -3))
+        v = QuantArray(*_quantize(v), out_scaling=False, scale_expand_dims=(-2, -3))
 
     with jax.named_scope("cache_update"):
         paged_state, starts = None, None
@@ -825,23 +831,21 @@ def attention_block(
            ) % cache.size  # [B, T]
 
            q_segment_ids = jnp.where(segment_ids != 0, 1, 0)
-           incremental_position = jnp.max(_length_minus_padding(segment_ids))
+           incremental_position = jnp.max(_length_minus_right_padding(segment_ids))
            # i.e. valid below where we've written things [B, T]
-           kv_segment_ids = (
-               (time_indices >= 0) & (time_indices < cache.fill_len()[:, None] + incremental_position)
-           ).astype(jnp.int32)
-           q_offset = cache.fill_len() - _count_left_padding(segment_ids)
+           kv_segment_ids = (time_indices >= 0) & (time_indices < cache.fill_len()[:, None] + incremental_position)
+           q_offset = cache.fill_len() - _count_left_padding(segment_ids, 0)  # 0 is the pad "token" for segment_ids
            starts, lengths = cache.starts, cache.fill_len()
            cache_updates = (k, v)
        elif is_type(cache, PagedKVCache):
            cache: PagedKVCache
-           k, v, paged_state = PagedKVCache.update_slice(cache, k=k, v=v, layer_idx=idx)
+           (k, v), paged_state = PagedKVCache.update_slice(cache, (k, v), layer_idx=idx)
            cache_updates = (k, v, paged_state)
        else:
            # this supports prefill only; no support for a ring cache buffer here
            q_segment_ids, kv_segment_ids = segment_ids, segment_ids
            q_offset = jnp.zeros(x.shape[0], dtype=jnp.int32)
-           starts, lengths = _count_left_padding(segment_ids, 0), _length_minus_padding(kv_segment_ids)
+           starts, lengths = _count_left_padding(segment_ids, 0), _length_minus_right_padding(kv_segment_ids)
            cache_updates = (k, v)
 
     # Compute attention
@@ -931,15 +935,12 @@ def forward(
         x, cache_updates = forward_layer(x, segment_ids, layer, sin, cos, idx, cfg, cache)
         all_cache_updates.append(cache_updates)
 
-    # Final layer norm.
-    x = rms_norm(x, weights.gamma_final)
-
-    # Project to vocabulary size
-    logits = einsum("btd,dv->btv", x, weights.lm_head)
+    x = rms_norm(x, weights.gamma_final)  # Final layer norm.
+    logits = einsum("btd,dv->btv", x, weights.lm_head)  # Project to vocabulary size
 
     if is_type(cache, KVCache):
         cache.k, cache.v = [z[0] for z in all_cache_updates], [z[1] for z in all_cache_updates]
-        new_iter = (jnp.maximum(0, cache.iter) + jnp.max(_length_minus_padding(segment_ids))) % cache.size
+        new_iter = (jnp.maximum(0, cache.iter) + jnp.max(_length_minus_right_padding(segment_ids))) % cache.size
         cache = dataclasses.replace(cache, iter=new_iter)
         return logits, cache
     elif is_type(cache, PagedKVCache):

llama3/pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ dependencies = [
     #"datasets",
     "gcsfs",
     "etils",
+    "importlib_resources",
+    "absl-py",
 ]
 
 # we don't need CUDA torch

llama3/tests/test_model.py

Lines changed: 2 additions & 2 deletions
@@ -54,15 +54,15 @@ def test_model_init(self, quant):
     @parameterized.product(quant=[False, True])
     def test_cache_init(self, quant):
         cfg = dataclasses.replace(self.small_cfg, quant_cache=quant)
-        cache = l3jax.KVCache.init(random.key(0), cfg, 2, cfg.max_seq_len)
+        cache = l3jax.KVCache.init(random.key(0), cfg, 2)
         del cache
 
     @parameterized.product(quant_weights=[False, True], quant_cache=[True, False])
     def test_prefill_decode(self, quant_weights, quant_cache):
         cfg = dataclasses.replace(self.small_cfg, quant_layer=quant_weights, quant_cache=quant_cache)
         tokens = jnp.ones((1, 32), dtype=jnp.int32)
         weights = l3jax.Weights.init(random.key(0), cfg)
-        cache = l3jax.KVCache.init(random.key(0), cfg, tokens.shape[0], cfg.max_seq_len)
+        cache = l3jax.KVCache.init(random.key(0), cfg, tokens.shape[0])
         with use_mesh(cfg.mesh):
             max_tokens, _, cache = l3jax.prefill(tokens, weights, cache, cfg)
             next_tokens = max_tokens[:, :-1]
