
Commit 14ff57c

serving draft
1 parent e566f95 commit 14ff57c

File tree

11 files changed: +3025 -93 lines changed

llama3/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
*.txt

Lines changed: 226 additions & 0 deletions

@@ -0,0 +1,226 @@
import dataclasses
from functools import partial
import math
from typing import Any

import jax
import jax.numpy as jnp

try:
    from jax.experimental.shard import auto_axes
except ModuleNotFoundError:
    from jax.sharding import auto_axes

QuantArray, PyTree = Any, Any

KVCache = Any
next_power_of_2 = lambda x: 2 ** math.ceil(math.log2(max(x, 1)))
_pad_after = lambda x, l, axis: jnp.pad(x, [(0, 0) if i != axis else (0, l - x.shape[i]) for i in range(x.ndim)])
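

# Hedged sanity check for the two helpers above (illustrative values only; not
# part of the serving path): next_power_of_2 rounds up to a power of two, and
# _pad_after right-pads one axis with zeros up to a target length.
def _example_padding_helpers():
    assert next_power_of_2(5) == 8 and next_power_of_2(1) == 1
    assert _pad_after(jnp.ones((2, 3)), 5, axis=1).shape == (2, 5)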


def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
    """Convert a list of cache entries stacked along the layer axis (the in-transit layout) into
    per-layer entries stacked along the batch axis, with the layers split into a list."""
    _split = lambda x: jnp.split(x, x.shape[0], axis=0)
    max_seq_len = max([jax.tree.leaves(kv)[0].shape[time_axis] for kv in kv_list])
    kv_list = [jax.tree.map(lambda x: _pad_after(x, max_seq_len, time_axis), kv) for kv in kv_list]
    out = [None for _ in kv_list[0]]
    for i, c in enumerate(kv_list[0]):
        els = [[_split(z) for z in jax.tree.leaves(kv[i])] for kv in kv_list]  # [B, R_flat, L]
        els = jax.tree.map(lambda *xs: jnp.concatenate(xs, axis=0), *els)  # [R_flat, L]
        leaves_list = list(zip(*els))  # [L, R_flat]
        out[i] = [jax.tree.unflatten(jax.tree.structure(c), leaves) for leaves in leaves_list]  # [L, R]
    return tuple(out), max_seq_len
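

# Hedged shape check for _transpose_attention_tree (hypothetical sizes, not part
# of the serving path): two requests arrive with K/V stacked along a 3-layer
# axis and different sequence lengths; the shorter entry is right-padded, and
# the result is a per-layer list of batch-stacked arrays.
def _example_transpose_attention_tree():
    num_layers, heads, head_dim = 3, 2, 4
    kv_a = (jnp.zeros((num_layers, heads, 5, head_dim)), jnp.zeros((num_layers, heads, 5, head_dim)))
    kv_b = (jnp.ones((num_layers, heads, 3, head_dim)), jnp.ones((num_layers, heads, 3, head_dim)))
    (ks, vs), max_len = _transpose_attention_tree([kv_a, kv_b], time_axis=2)
    assert max_len == 5 and len(ks) == num_layers
    assert ks[0].shape == (2, heads, 5, head_dim)  # batch-first, one entry per layer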


########################################################################################################################
# KV cache utils #######################################################################################################
########################################################################################################################


@partial(jax.jit, donate_argnames=("cache",))
def _kvcache_update_cache(
    cache: KVCache,
    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
    update_mask: list[bool] | None = None,
):
    assert len(kvs) == len(batch_idxs) == len(actual_lens)
    batch_idxs, actual_lens, update_mask = jnp.array(batch_idxs), jnp.array(actual_lens), jnp.array(update_mask)
    uninitialized_cache = cache.iter < 0
    start_time = jnp.where(
        uninitialized_cache, jnp.max(actual_lens) - actual_lens, (cache.iter - actual_lens) % cache.size
    )
    batch_idxs = jnp.where(update_mask, batch_idxs, 2**30)  # send masked updates nowhere; mode="drop" discards them
    kvs, max_seq_len = _transpose_attention_tree(kvs, time_axis=cache.time_axis)
    time_indices = (jnp.arange(max_seq_len)[None, :] + start_time[:, None]) % cache.size

    def _update_element(x, u):
        # move the batch and time axes first to match the (batch, time) advanced-index grid below
        update_permute = [0, cache.time_axis] + [i for i in range(u.ndim) if i not in (0, cache.time_axis)]
        return x.at[batch_idxs[:, None], :, time_indices, ...].set(u.transpose(update_permute), mode="drop")

    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
    cache_starts = cache.starts.at[batch_idxs].set(start_time, mode="drop")
    cache_iter = jnp.where(uninitialized_cache, jnp.max(actual_lens), cache.iter)
    return dataclasses.replace(cache, k=cache_k, v=cache_v, iter=cache_iter, starts=cache_starts)


def kvcache_update_cache(
    cache: KVCache,
    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
):
    # pad the update list to a power-of-2 length (at least 4) to limit the number of jit specializations
    pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs)
    update_mask = [i < len(kvs) for i in range(len(kvs) + pad_len)]
    kvs = kvs + [kvs[-1]] * pad_len
    batch_idxs, actual_lens = batch_idxs + [batch_idxs[-1]] * pad_len, actual_lens + [actual_lens[-1]] * pad_len
    return _kvcache_update_cache(cache, kvs, batch_idxs, actual_lens, update_mask)
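

# Hedged usage sketch for the wrapper above. KVCache is opaque here (typed Any);
# this assumes the concrete ring-buffer cache dataclass with the fields used in
# this file (k, v, iter, starts, size, time_axis). Padding len(kvs) up to a
# power of two (at least 4) by repeating the last entry bounds how many shapes
# the jitted update is traced for; update_mask sends the duplicate entries to
# index 2**30, which the mode="drop" scatters discard.
def _example_kvcache_insert(cache, entry_kv, entry_len):
    # insert one prefilled sequence (in transit layout) into decode slot 0
    return kvcache_update_cache(cache, [entry_kv], [jnp.array(0)], [entry_len])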


@jax.jit
def kvcache_get_entry(cache: KVCache, batch_idx: jax.Array):
    shift = -cache.starts[batch_idx]
    assert cache.time_axis > 0
    # select the batch entry and unrotate the ring buffer so the sequence starts at position 0
    # (indexing removes the batch dim, so the time axis shifts left by one)
    kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), (cache.k, cache.v))
    # stack the per-layer entries along a leading layer axis for transit
    kvs = (jax.tree.map(lambda *xs: jnp.stack(xs, 0), *kvs[0]), jax.tree.map(lambda *xs: jnp.stack(xs, 0), *kvs[1]))
    true_len = cache.fill_len()[batch_idx]
    return kvs, true_len
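

# Hedged toy illustration of the unrotation above: a ring buffer whose sequence
# starts at index 3 is rolled left by 3 (shift = -start) so it begins at 0.
def _example_ring_unrotate():
    buf = jnp.array([2, 3, 4, 0, 1])  # ring contents; the sequence starts at index 3
    assert (jnp.roll(buf, shift=-3) == jnp.arange(5)).all()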
92+
93+
94+
########################################################################################################################
95+
# Paged KV cache utils #################################################################################################
96+
########################################################################################################################
97+
98+
PagedKVCache = Any
99+
100+


def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array | None = None):
    if proposal_pages is not None:
        assert proposal_pages.size == k
        proposal_mask = free_pages[proposal_pages]  # True where the proposed page is actually free
        indices = jnp.where(~proposal_mask, jnp.cumsum(~proposal_mask, axis=-1) - 1, k - 1)
        newly_free_pages = free_pages.at[jnp.where(proposal_mask, proposal_pages, 2**30)].set(False, mode="drop")
        # keep the free proposals; replace occupied ones with the best remaining free pages
        return jnp.where(proposal_mask, proposal_pages, jax.lax.top_k(newly_free_pages, k)[1][indices])
    else:
        return jax.lax.top_k(free_pages, k)[1]
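

# Hedged toy check for the proposal path above, assuming free_pages is a boolean
# mask with True meaning free (as the .set(False) updates in this file suggest):
# proposed page 1 is free and kept; proposed page 3 is occupied, so it is
# replaced by the best remaining free page found via top_k.
def _example_find_empty_pages():
    free = jnp.array([False, True, True, False, True])
    picked = _find_empty_pages(free, k=2, proposal_pages=jnp.array([1, 3]))
    assert (picked == jnp.array([1, 2])).all()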


def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
    key_heads = cache.k[layer_idx].shape[0]
    assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1)
    needs_next_page = (cache.lengths % cache.page_size) == 0
    page_table_idx = cache.lengths // cache.page_size
    current_page_cursor = jnp.take_along_axis(cache.block_tables, page_table_idx[:, None], axis=-1)[..., 0]
    # spread fresh sequences evenly across the page pool (pages live on axis 1 of the K/V buffers)
    avg_pages_per_batch_entry = round(cache.k[layer_idx].shape[1] / cache.batch_size)
    even_batch_spread = jnp.arange(cache.batch_size) * avg_pages_per_batch_entry
    # ongoing sequences propose the next contiguous page; fresh sequences propose their even-spread slot
    proposal_pages = jnp.where(cache.lengths == 0, even_batch_spread, current_page_cursor + 1)
    free_pages = _find_empty_pages(cache.free_pages, cache.batch_size, proposal_pages=proposal_pages)
    page_cursor = jnp.where(needs_next_page, free_pages, current_page_cursor)

    inpage_cursor = cache.lengths % cache.page_size

    new_lengths = cache.lengths + 1
    # for a batch index update, the target slice is (heads, i, j, head_dim), so transpose the update
    # (batch, heads, seq=1, head_dim) -> (batch, heads, head_dim) -> (heads, batch, head_dim)
    _update = lambda dest, src: dest.at[:, page_cursor, inpage_cursor, ...].set(src.squeeze(2).swapaxes(0, 1))
    cache.k[layer_idx], cache.v[layer_idx] = jax.tree.map(_update, (cache.k[layer_idx], cache.v[layer_idx]), (k, v))

    batch_idx = jnp.arange(cache.batch_size)
    new_block_tables = cache.block_tables.at[batch_idx, new_lengths // cache.page_size].set(page_cursor)

    new_free_pages = cache.free_pages.at[page_cursor].set(False, mode="drop")
    new_state = dict(lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages)
    return cache.k[layer_idx], cache.v[layer_idx], new_state
138+
139+
140+
def paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
141+
repl_sharding = jax.typeof(cache.lengths).sharding
142+
kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, (cache.k[layer_idx], cache.v[layer_idx]))
143+
sharding = (*kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding))
144+
return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, k, v)
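

# Hedged decode-step sketch (field names and shapes assumed from the assert in
# _paged_update_slice): append one new token's K/V per batch row to layer 0,
# then write the returned page-table bookkeeping back onto the cache.
def _example_paged_decode_step(cache, k_new, v_new):
    # k_new, v_new: (batch_size, key_heads, 1, head_dim)
    cache.k[0], cache.v[0], state = paged_update_slice(cache, k_new, v_new, layer_idx=0)
    return dataclasses.replace(cache, **state)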


@partial(jax.jit, donate_argnames=("cache",))
def _batch_paged_update_sequences(
    cache: PagedKVCache,
    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
    update_mask: list[bool] | None = None,
) -> PagedKVCache:
    update_mask = jnp.array(update_mask)
    batch_idxs = jnp.where(update_mask, jnp.array(batch_idxs), 2**30)  # send masked updates nowhere
    actual_lens = jnp.minimum(jnp.array(actual_lens), jnp.array([jax.tree.leaves(kv)[0].shape[2] for kv in kvs]))

    # undo the stacking along the layer dimension used for transit
    kvs, max_seq_len = _transpose_attention_tree(kvs, time_axis=2)

    # free the pages currently held by the batch entries being overwritten
    actual_page_num = jnp.rint(jnp.ceil(cache.lengths[batch_idxs] / cache.page_size)).astype(jnp.int32)
    occupied_mask = jnp.arange(cache.block_tables.shape[-1])[None, :] < actual_page_num[:, None]
    indices_to_free = jnp.where(occupied_mask & update_mask[:, None], cache.block_tables[batch_idxs, :], 2**30)
    new_free_pages = cache.free_pages.at[indices_to_free.reshape(-1)].set(True, mode="drop")

    # get the length of each new sequence and find empty pages for it, ideally contiguous ones
    upper_bound_page_num = math.ceil(max_seq_len / cache.page_size)
    actual_page_num = jnp.rint(jnp.ceil(actual_lens / cache.page_size)).astype(jnp.int32)
    avg_pages_per_batch_entry = round(jax.tree.leaves(cache)[0].shape[1] / cache.batch_size)
    proposal_pages = batch_idxs[:, None] * avg_pages_per_batch_entry + jnp.arange(upper_bound_page_num)[None, :]
    pages_idx = _find_empty_pages(
        new_free_pages, upper_bound_page_num * batch_idxs.size, proposal_pages=proposal_pages.reshape(-1)
    ).reshape(proposal_pages.shape)
    pages_arange = jnp.arange(upper_bound_page_num)
    pages_idx = jnp.where(update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_idx, 2**30)

    # reshape the new sequences into pages for insertion (the padded length must be a multiple of page_size)
    b, h, s, e = jax.tree.leaves(kvs)[0].shape
    kvs = jax.tree.map(lambda x: x.reshape((b, h, s // cache.page_size, cache.page_size) + x.shape[3:]), kvs)

    def _update_element(x, u):
        # the target x[:, pages_idx, ...] has shape (heads, BATCH, PAGE, page_size, head_dim),
        # so move heads first: (BATCH, heads, PAGE, ...) -> (heads, BATCH, PAGE, ...)
        update_permute = [1, 0, 2] + [i for i in range(u.ndim) if i not in (0, 1, 2)]
        return x.at[:, pages_idx, ...].set(u.transpose(update_permute), mode="drop")

    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
    block_tables_idx = jnp.where(
        update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_arange[None, :], 2**30
    )
    new_block_tables = cache.block_tables.at[batch_idxs[:, None], block_tables_idx].set(pages_idx, mode="drop")
    new_free_pages = new_free_pages.at[pages_idx.reshape(-1)].set(False, mode="drop")
    new_lengths = cache.lengths.at[batch_idxs].set(actual_lens, mode="drop")
    return dataclasses.replace(
        cache, k=cache_k, v=cache_v, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages
    )


def batch_paged_update_sequences(
    cache: PagedKVCache,
    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
):
    # pad the update list to a power-of-2 length (at least 4) to limit jit specializations,
    # mirroring kvcache_update_cache above
    pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs)
    update_mask = [i < len(kvs) for i in range(len(kvs) + pad_len)]
    kvs = kvs + [kvs[-1]] * pad_len
    batch_idxs, actual_lens = batch_idxs + [batch_idxs[-1]] * pad_len, actual_lens + [actual_lens[-1]] * pad_len
    return _batch_paged_update_sequences(cache, kvs, batch_idxs, actual_lens, update_mask)


@partial(jax.jit, static_argnames=("max_seq_len",))
def batch_paged_get_entry(cache: PagedKVCache, batch_idx: jax.Array, max_seq_len: int = -1):
    true_len = cache.fill_len()[batch_idx]
    max_seq_len = max_seq_len if max_seq_len > 0 else cache.page_size * cache.block_tables.shape[-1]
    max_seq_len = min(max_seq_len, cache.page_size * cache.block_tables.shape[-1])  # clamp to cache capacity
    page_indices = cache.block_tables[batch_idx, : math.ceil(max_seq_len / cache.page_size)]
    _reshape_out = lambda x: x.reshape((x.shape[0], max_seq_len) + x.shape[3:])
    mask = jnp.arange(max_seq_len) < true_len  # zero out positions past the true sequence length
    _get = lambda x: jnp.where(mask[None, :, *([None] * (x.ndim - 3))], _reshape_out(x[:, page_indices, ...]), 0)

    # stack along the layer dimension for transit
    kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, (cache.k, cache.v)))
    return kvs, true_len
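

# Hedged round-trip sketch: batch_paged_get_entry emits K/V in the same
# layer-stacked transit layout that batch_paged_update_sequences consumes, so a
# finished sequence can be moved between paged caches (or batch slots).
def _example_paged_round_trip(src_cache, dst_cache, src_idx, dst_idx):
    kvs, true_len = batch_paged_get_entry(src_cache, src_idx)
    return batch_paged_update_sequences(dst_cache, [kvs], [dst_idx], [true_len])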
