|
| 1 | +import dataclasses |
| 2 | +from functools import partial |
| 3 | +import math |
| 4 | +from typing import Any |
| 5 | + |
| 6 | +import jax |
| 7 | +import jax.numpy as jnp |
| 8 | + |
| 9 | +try: |
| 10 | + from jax.experimental.shard import auto_axes |
| 11 | +except ModuleNotFoundError: |
| 12 | + from jax.sharding import auto_axes |
| 13 | + |
| 14 | +QuantArray, PyTree = Any, Any |
| 15 | + |
| 16 | +KVCache = Any |
| 17 | +next_power_of_2 = lambda x: 2 ** math.ceil(math.log2(max(x, 1))) |
| 18 | +_pad_after = lambda x, l, axis: jnp.pad(x, [(0, 0) if i != axis else (0, l - x.shape[i]) for i in range(x.ndim)]) |
| 19 | + |
| 20 | + |
def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
    """From a list of cache entries stacked along layer idx (in transit) to stacked along batch, layers split into list.

    Each element of `kv_list` is one request's (k, v)-style tuple whose leaves carry a
    leading layer axis of size L. Entries are right-padded with zeros along `time_axis`
    to the longest sequence, then re-packed so leaves are stacked along a leading batch
    axis, with layers split into a per-layer list.

    Returns:
      (re-packed entries as a tuple, the common padded sequence length).
    """

    # Split the leading (layer) axis into a list of L arrays, each keeping a size-1 leading dim.
    _split = lambda x: jnp.split(x, x.shape[0], axis=0)
    max_seq_len = max([jax.tree.leaves(kv)[0].shape[time_axis] for kv in kv_list])
    # Zero-pad every entry along the time axis so all sequences share one length.
    kv_list = [jax.tree.map(lambda x: _pad_after(x, max_seq_len, time_axis), kv) for kv in kv_list]
    out = [None for _ in kv_list[0]]  # one slot per component (e.g. k and v)
    for i, c in enumerate(kv_list[0]):
        els = [[_split(z) for z in jax.tree.leaves(kv[i])] for kv in kv_list]  # [B, R_flat, L]
        # Concatenate matching layer slices across requests: the size-1 layer dim becomes the batch dim.
        els = jax.tree.map(lambda *xs: jnp.concatenate(xs, axis=0), *els)  # [R_flat, L]
        leaves_list = list(zip(*els))  # [L, R_flat]
        out[i] = [jax.tree.unflatten(jax.tree.structure(c), leaves) for leaves in leaves_list]  # [L, R]
    return tuple(out), max_seq_len
| 34 | + |
| 35 | + |
| 36 | +######################################################################################################################## |
| 37 | +# KV cache utils ####################################################################################################### |
| 38 | +######################################################################################################################## |
| 39 | + |
| 40 | + |
@partial(jax.jit, donate_argnames=("cache",))
def _kvcache_update_cache(
    cache: KVCache,
    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
    update_mask: list[bool] | None = None,
):
    """Write whole sequences into a ring-buffer KV cache.

    Args:
      cache: KV cache pytree with fields `k`, `v`, `iter`, `size`, `starts`, `time_axis`.
      kvs: per-request entries whose leaves are stacked along a leading layer axis (in transit).
      batch_idxs: destination batch row for each entry of `kvs`.
      actual_lens: unpadded sequence length of each entry.
      update_mask: per-entry validity flag; masked entries are scattered to a dropped index.

    Returns:
      A new cache with the sequences written and `iter`/`starts` updated.
    """
    assert len(kvs) == len(batch_idxs) == len(actual_lens)
    if update_mask is None:
        # `jnp.array(None)` raises; a missing mask means every entry is a real update.
        update_mask = [True] * len(kvs)
    batch_idxs, actual_lens, update_mask = jnp.array(batch_idxs), jnp.array(actual_lens), jnp.array(update_mask)
    uninitialized_cache = cache.iter < 0  # iter < 0 marks a never-written cache
    # Fresh cache: right-align sequences so they all end at the same step.
    # Otherwise: place each sequence so it ends at the current ring-buffer position.
    start_time = jnp.where(
        uninitialized_cache, jnp.max(actual_lens) - actual_lens, (cache.iter - actual_lens) % cache.size
    )
    batch_idxs = jnp.where(update_mask, batch_idxs, 2**30)  # send masked to nowhere (dropped by mode="drop")
    kvs, max_seq_len = _transpose_attention_tree(kvs, time_axis=cache.time_axis)
    time_indices = (jnp.arange(max_seq_len)[None, :] + start_time[:, None]) % cache.size  # ring-buffer wrap

    def _update_element(x, u):
        # The scatter indexes (batch, time) first, so move those two axes to the front of the update.
        update_permute = [0, cache.time_axis] + [i for i in range(u.ndim) if i not in (0, cache.time_axis)]
        return x.at[batch_idxs[:, None], :, time_indices, ...].set(u.transpose(update_permute), mode="drop")

    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
    cache_starts = cache.starts.at[batch_idxs].set(start_time, mode="drop")
    cache_iter = jnp.where(uninitialized_cache, jnp.max(actual_lens), cache.iter)
    return dataclasses.replace(cache, k=cache_k, v=cache_v, iter=cache_iter, starts=cache_starts)
| 69 | + |
| 70 | + |
def kvcache_update_cache(
    cache: KVCache,
    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
):
    """Pad the update batch to a power-of-2 size (min 4) and apply it to the cache.

    Padding replicates the last entry so the jitted update only compiles for a
    small, fixed set of batch sizes; padded entries are masked out and never
    written. Returns `cache` unchanged when `kvs` is empty.
    """
    if not kvs:  # nothing to write; avoids IndexError on kvs[-1] below
        return cache
    pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs)  # an update of power of 2 and at least 4
    update_mask = [i < len(kvs) for i in range(len(kvs) + pad_len)]
    kvs = kvs + [kvs[-1]] * pad_len
    batch_idxs, actual_lens = batch_idxs + [batch_idxs[-1]] * pad_len, actual_lens + [actual_lens[-1]] * pad_len
    return _kvcache_update_cache(cache, kvs, batch_idxs, actual_lens, update_mask)
| 82 | + |
| 83 | + |
@jax.jit
def kvcache_get_entry(cache: KVCache, batch_idx: jax.Array):
    """Extract one batch row from the cache as layer-stacked (k, v) plus its true length.

    Returned leaves carry a leading layer axis (the "in transit" layout that
    `_transpose_attention_tree` consumes), mirroring `batch_paged_get_entry`.
    """
    assert cache.time_axis > 0
    shift = -cache.starts[batch_idx]
    # Roll so the sequence starts at position 0 (the batch dim is indexed away, hence time_axis - 1).
    kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), (cache.k, cache.v))
    # Stack per-layer entries along a new leading layer axis. The `*` unpack is required so
    # tree.map zips across layers (cf. `batch_paged_get_entry`); without it each leaf merely
    # gained a singleton leading dim instead of the [L, ...] layer stack consumers split on.
    kvs = (jax.tree.map(lambda *xs: jnp.stack(xs, 0), *kvs[0]), jax.tree.map(lambda *xs: jnp.stack(xs, 0), *kvs[1]))
    true_len = cache.fill_len()[batch_idx]
    return kvs, true_len
| 92 | + |
| 93 | + |
| 94 | +######################################################################################################################## |
| 95 | +# Paged KV cache utils ################################################################################################# |
| 96 | +######################################################################################################################## |
| 97 | + |
PagedKVCache = Any  # placeholder alias for the project's paged KV-cache dataclass (defined elsewhere)
| 99 | + |
| 100 | + |
| 101 | +def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array | None = None): |
| 102 | + if proposal_pages is not None: |
| 103 | + assert proposal_pages.size == k |
| 104 | + proposal_mask = free_pages[proposal_pages] |
| 105 | + indicies = jnp.where(~proposal_mask, jnp.cumsum(~proposal_mask, axis=-1) - 1, k - 1) |
| 106 | + newly_free_pages = free_pages.at[jnp.where(proposal_mask, proposal_pages, 2**30)].set(False, mode="drop") |
| 107 | + return jnp.where(proposal_mask, proposal_pages, jax.lax.top_k(newly_free_pages, k)[1][indicies]) |
| 108 | + else: |
| 109 | + return jax.lax.top_k(free_pages, k)[1] |
| 110 | + |
| 111 | + |
def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
    """Append one token's k/v for every batch row into the paged cache at `layer_idx`.

    Expects k/v shaped (batch, key_heads, 1, head_dim). Returns the updated
    per-layer k/v page pools plus new bookkeeping state (lengths, block_tables,
    free_pages).
    """
    key_heads = cache.k[layer_idx].shape[0]
    assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1)
    needs_next_page = (cache.lengths % cache.page_size) == 0  # current page full (or row empty)
    page_table_idx = cache.lengths // cache.page_size
    current_page_cursor = jnp.take_along_axis(cache.block_tables, page_table_idx[:, None], axis=-1)[..., 0]
    # Pages live on axis 1 of cache.k[layer_idx]: (heads, pages, page_size, head_dim).
    # Fixed: this previously divided shape[0] (the head count) by batch_size, mis-sizing the spread.
    avg_pages_per_batch_entry = round(cache.k[layer_idx].shape[1] / cache.batch_size)
    even_batch_spread = jnp.arange(cache.batch_size) * avg_pages_per_batch_entry
    # First page of a row: spread rows evenly across the pool; otherwise try the next contiguous page.
    proposal_pages = jnp.where(cache.lengths == 0, even_batch_spread, current_page_cursor + 1)
    free_pages = _find_empty_pages(cache.free_pages, cache.batch_size, proposal_pages=proposal_pages)
    page_cursor = jnp.where(needs_next_page, free_pages, current_page_cursor)

    inpage_cursor = cache.lengths % cache.page_size  # slot within the destination page

    new_lengths = cache.lengths + 1
    # for batch index update the target slice is (heads, i, j, head_dim)
    # so transpose update (batch, heads, seq, head_dim) -> (batch, heads, head_dim) -> (heads, batch, head_dim)
    _update = lambda dest, src: dest.at[:, page_cursor, inpage_cursor, ...].set(src.squeeze(2).swapaxes(0, 1))
    cache.k[layer_idx], cache.v[layer_idx] = jax.tree.map(_update, (cache.k[layer_idx], cache.v[layer_idx]), (k, v))

    batch_idx = jnp.arange(cache.batch_size)
    new_block_tables = cache.block_tables.at[batch_idx, new_lengths // cache.page_size].set(page_cursor)

    new_free_pages = cache.free_pages.at[page_cursor].set(False, mode="drop")
    new_state = dict(lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages)
    return cache.k[layer_idx], cache.v[layer_idx], new_state
| 138 | + |
| 139 | + |
def paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
    """Run `_paged_update_slice` under `auto_axes`, pinning output shardings.

    The per-layer k/v keep their current shardings; the bookkeeping state is
    replicated like `cache.lengths`.
    """
    replicated = jax.typeof(cache.lengths).sharding
    state_shardings = dict(lengths=replicated, block_tables=replicated, free_pages=replicated)
    layer_kv = (cache.k[layer_idx], cache.v[layer_idx])
    kv_shardings = jax.tree.map(lambda leaf: jax.typeof(leaf).sharding, layer_kv)
    update_fn = partial(_paged_update_slice, layer_idx=layer_idx)
    return auto_axes(update_fn, out_sharding=(*kv_shardings, state_shardings))(cache, k, v)
| 145 | + |
| 146 | + |
@partial(jax.jit, donate_argnames=("cache",))
def _batch_paged_update_sequences(
    cache: PagedKVCache,
    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
    update_mask: list[bool] | None = None,
) -> PagedKVCache:
    """Write whole sequences into the paged cache, freeing each row's old pages first.

    Args:
      cache: paged KV cache (k/v page pools, lengths, block_tables, free_pages).
      kvs: per-request entries stacked along the layer axis (in-transit layout).
      batch_idxs: destination rows; masked entries are routed to a dropped index.
      actual_lens: unpadded sequence lengths (clamped to the provided data).
      update_mask: per-entry validity flags; None means all entries are real.
    """
    if update_mask is None:
        # `jnp.array(None)` raises; a missing mask means every entry is a real update.
        update_mask = [True] * len(kvs)
    update_mask = jnp.array(update_mask)
    batch_idxs = jnp.where(update_mask, jnp.array(batch_idxs), 2**30)  # send masked to nowhere
    actual_lens = jnp.minimum(jnp.array(actual_lens), jnp.array([jax.tree.leaves(kv)[0].shape[2] for kv in kvs]))

    kvs, max_seq_len = _transpose_attention_tree(
        kvs, time_axis=2
    )  # undo stacking along the layer dimension for transit

    # clear existing pages held by the rows being overwritten
    actual_page_num = jnp.rint(jnp.ceil(cache.lengths[batch_idxs] / cache.page_size)).astype(jnp.int32)
    occupied_mask = jnp.arange(cache.block_tables.shape[-1])[None, :] < actual_page_num[:, None]
    indices_to_free = jnp.where(occupied_mask & update_mask[:, None], cache.block_tables[batch_idxs, :], 2**30)
    new_free_pages = cache.free_pages.at[indices_to_free.reshape(-1)].set(True, mode="drop")

    # get the length of the new sequence and find empty pages for the new sequence ideally contiguous
    upper_bound_page_num = math.ceil(max_seq_len / cache.page_size)
    actual_page_num = jnp.rint(jnp.ceil(actual_lens / cache.page_size)).astype(jnp.int32)
    avg_pages_per_batch_entry = round(jax.tree.leaves(cache)[0].shape[1] / cache.batch_size)
    proposal_pages = batch_idxs[:, None] * avg_pages_per_batch_entry + jnp.arange(upper_bound_page_num)[None, :]
    pages_idx = _find_empty_pages(
        new_free_pages, upper_bound_page_num * batch_idxs.size, proposal_pages=proposal_pages.reshape(-1)
    ).reshape(proposal_pages.shape)
    pages_arange = jnp.arange(upper_bound_page_num)
    # drop page slots beyond each row's actual page count, and all slots of masked rows
    pages_idx = jnp.where(update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_idx, 2**30)

    # reshape the new pages for insertion and possibly quantize
    # NOTE(review): the reshape assumes max_seq_len is a multiple of page_size — confirm upstream padding.
    b, h, s, e = jax.tree.leaves(kvs)[0].shape
    kvs = jax.tree.map(lambda x: x.reshape((b, h, s // cache.page_size, cache.page_size) + x.shape[3:]), kvs)

    def _update_element(x, u):
        # we're updating (batch, page_entries) with (BATCH, heads, PAGE, page_size, head_dim), so (BATCH, PAGE) go first
        update_permute = [1, 0, 2] + [i for i in range(u.ndim) if i not in (0, 1, 2)]
        return x.at[:, pages_idx, ...].set(u.transpose(update_permute), mode="drop")

    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
    block_tables_idx = jnp.where(
        update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_arange[None, :], 2**30
    )
    new_block_tables = cache.block_tables.at[batch_idxs[:, None], block_tables_idx].set(pages_idx, mode="drop")
    new_free_pages = new_free_pages.at[pages_idx.reshape(-1)].set(False, mode="drop")
    new_lengths = cache.lengths.at[batch_idxs].set(actual_lens, mode="drop")
    return dataclasses.replace(
        cache, k=cache_k, v=cache_v, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages
    )
| 199 | + |
| 200 | + |
def batch_paged_update_sequences(
    cache: KVCache,
    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
):
    """Pad the update batch to a power-of-2 size (min 4), then apply the paged update.

    Padding replicates the last entry so the jitted update only compiles for a
    small, fixed set of batch sizes; padded entries are masked out and never
    written. Returns `cache` unchanged when `kvs` is empty.
    """
    if not kvs:  # nothing to write; avoids IndexError on kvs[-1] below
        return cache
    pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs)  # an update of power of 2 and at least 4
    update_mask = [i < len(kvs) for i in range(len(kvs) + pad_len)]
    kvs = kvs + [kvs[-1]] * pad_len
    batch_idxs, actual_lens = batch_idxs + [batch_idxs[-1]] * pad_len, actual_lens + [actual_lens[-1]] * pad_len
    return _batch_paged_update_sequences(cache, kvs, batch_idxs, actual_lens, update_mask)
| 212 | + |
| 213 | + |
@partial(jax.jit, static_argnames=("max_seq_len",))
def batch_paged_get_entry(cache: PagedKVCache, batch_idx: jax.Array, max_seq_len: int = -1):
    """Read one batch row out of the paged cache as dense, layer-stacked (k, v).

    `max_seq_len <= 0` means "use the full cache capacity". It is a static jit
    argument, so each distinct value triggers a fresh trace. Positions past the
    row's true length are zeroed out. Returns ((k, v) with a leading layer
    axis, true_len).
    """
    true_len = cache.fill_len()[batch_idx]
    max_seq_len = max_seq_len if max_seq_len > 0 else cache.page_size * cache.block_tables.shape[-1]
    max_seq_len = min(max_seq_len, cache.page_size * cache.block_tables.shape[-1])  # cache capacity
    page_indices = cache.block_tables[batch_idx, : round(math.ceil(max_seq_len / cache.page_size))]
    # flatten (heads, pages, page_size, ...) -> (heads, seq, ...)
    # NOTE(review): assumes max_seq_len is a multiple of page_size, else the reshape fails — confirm
    _reshape_out = lambda x: x.reshape((x.shape[0], max_seq_len) + x.shape[3:])
    mask = jnp.arange(max_seq_len) < true_len  # valid (written) positions for this row
    # gather the row's pages, flatten to a contiguous sequence, zero everything past true_len
    _get = lambda x: jnp.where(mask[None, :, *([None] * (x.ndim - 3))], _reshape_out(x[:, page_indices, ...]), 0)

    # stack along layer dimensions for transit
    kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, (cache.k, cache.v)))
    return kvs, true_len
0 commit comments