_pad_after = lambda x, l, axis: jnp.pad(x, [(0, 0) if i != axis else (0, l - x.shape[i]) for i in range(x.ndim)])


+def safe_zip(*args):
+    if len(args) == 0:
+        return []
+    assert all(len(arg) == len(args[0]) for arg in args)
+    return zip(*args)
+
+
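# --- Illustrative aside (not part of the diff): safe_zip behaves like zip()
# but asserts that all inputs share the same length, so silent truncation
# becomes a loud failure:
#
#   list(safe_zip([1, 2], ["a", "b"]))       # [(1, "a"), (2, "b")]
#   list(safe_zip([1, 2], ["a", "b", "c"]))  # AssertionError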
def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
    "From a list of cache entries stacked along layer idx (in transit) to stacked along batch, layers split into list."
@@ -28,7 +35,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
    for i, c in enumerate(kv_list[0]):
        els = [[_split(z) for z in jax.tree.leaves(kv[i])] for kv in kv_list]  # [B, R_flat, L]
        els = jax.tree.map(lambda *xs: jnp.concatenate(xs, axis=0), *els)  # [R_flat, L]
-        leaves_list = list(zip(*els))  # [L, R_flat]
+        leaves_list = list(safe_zip(*els))  # [L, R_flat]
        out[i] = [jax.tree.unflatten(jax.tree.structure(c), leaves) for leaves in leaves_list]  # [L, R]
    return tuple(out), max_seq_len
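# --- Illustrative aside: a toy version of the regrouping above, assuming two
# flattened leaves (R_flat=2), each stacked along a layer axis of size L=2.
# Names here are hypothetical, not from the repo:
#
#   import jax.numpy as jnp
#   els = [list(jnp.zeros((2, 4))), list(jnp.ones((2, 4)))]  # [R_flat=2][L=2]
#   leaves_list = list(safe_zip(*els))                       # [L=2][R_flat=2]
#   # each element of leaves_list now holds one layer's leaves, ready for
#   # jax.tree.unflatten with that layer's tree structure.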
@@ -41,7 +48,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int):
@partial(jax.jit, donate_argnames=("cache",))
def _kvcache_update_cache(
    cache: KVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
    update_mask: list[bool] | None = None,
@@ -62,15 +69,17 @@ def _update_element(x, u):
        # update_permute = [batch_dim, time_dim] + update_permute
        return x.at[batch_idxs[:, None], :, time_indices, ...].set(u.transpose(update_permute), mode="drop")

-    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
+    cache_kvs = jax.tree.map(_update_element, cache.buffers, kvs)
    cache_starts = cache.starts.at[batch_idxs].set(start_time, mode="drop")
    cache_iter = jnp.where(uninitialized_cache, jnp.max(actual_lens), cache.iter)
-    return dataclasses.replace(cache, k=cache_k, v=cache_v, iter=cache_iter, starts=cache_starts)
+
+    buffer_names = [field.name for field in dataclasses.fields(cache)][:len(cache_kvs)]
+    return dataclasses.replace(cache, **dict(safe_zip(buffer_names, cache_kvs)), iter=cache_iter, starts=cache_starts)
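# --- Illustrative aside: the slice of dataclasses.fields above assumes the
# cache dataclass declares its buffer fields first. A hypothetical sketch of
# that layout (the real KVCache is defined elsewhere in this repo):
#
#   @dataclasses.dataclass
#   class KVCacheSketch:
#       k: list[jax.Array]      # buffer fields come first...
#       v: list[jax.Array]
#       iter: jax.Array         # ...ahead of the bookkeeping fields
#       starts: jax.Array
#
#       @property
#       def buffers(self):      # assumed counterpart of cache.buffers
#           return (self.k, self.v)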


def kvcache_update_cache(
    cache: KVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
):
@@ -85,7 +94,7 @@ def kvcache_update_cache(
def kvcache_get_entry(cache: KVCache, batch_idx: jax.Array):
    shift = -cache.starts[batch_idx]
    assert cache.time_axis > 0
-    kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), (cache.k, cache.v))
+    kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), cache.buffers)
    kvs = (jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[0]), jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[1]))
    true_len = cache.fill_len()[batch_idx]
    return kvs, true_len
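# --- Reviewer note: the jnp.stack line above still indexes kvs[0] and kvs[1]
# explicitly, so kvcache_get_entry assumes exactly two buffer groups even
# though the roll now iterates cache.buffers generically.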
@@ -109,13 +118,13 @@ def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array |
    return jax.lax.top_k(free_pages, k)[1]


-def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
-    key_heads = cache.k[layer_idx].shape[0]
-    assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1)
+def _paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int):
+    # key_heads = cache.buffers[0][layer_idx].shape[0]
+    # assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1)  # TODO write this generically
    needs_next_page = (cache.lengths % cache.page_size) == 0
    page_table_idx = cache.lengths // cache.page_size
    current_page_cursor = jnp.take_along_axis(cache.block_tables, page_table_idx[:, None], axis=-1)[..., 0]
-    avg_pages_per_batch_entry = round(cache.k[layer_idx].shape[0] / cache.batch_size)
+    avg_pages_per_batch_entry = round(cache.buffers[0][layer_idx].shape[0] / cache.batch_size)
    even_batch_spread = jnp.arange(cache.batch_size) * avg_pages_per_batch_entry
    proposal_pages = jnp.where(cache.lengths == 0, even_batch_spread, current_page_cursor + 1)
    free_pages = _find_empty_pages(cache.free_pages, cache.batch_size, proposal_pages=proposal_pages)
@@ -127,27 +136,28 @@ def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.A
    # for batch index update the target slice is (heads, i, j, head_dim)
    # so transpose update (batch, heads, seq, head_dim) -> (batch, heads, head_dim) -> (heads, batch, head_dim)
    _update = lambda dest, src: dest.at[:, page_cursor, inpage_cursor, ...].set(src.squeeze(2).swapaxes(0, 1))
-    cache.k[layer_idx], cache.v[layer_idx] = jax.tree.map(_update, (cache.k[layer_idx], cache.v[layer_idx]), (k, v))
+    for buffer, new_buffer in safe_zip(cache.buffers, kv):
+        buffer[layer_idx] = jax.tree.map(_update, buffer[layer_idx], new_buffer)

    batch_idx = jnp.arange(cache.batch_size)
    new_block_tables = cache.block_tables.at[batch_idx, new_lengths // cache.page_size].set(page_cursor)

    new_free_pages = cache.free_pages.at[page_cursor].set(False, mode="drop")
    new_state = dict(lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages)
-    return cache.k[layer_idx], cache.v[layer_idx], new_state
+    return tuple(buffer[layer_idx] for buffer in cache.buffers), new_state
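# --- Illustrative aside: worked numbers for the paging arithmetic above,
# assuming page_size=4 and lengths=[0, 5, 8]:
#   needs_next_page = [True, False, True]   # lengths % page_size == 0
#   page_table_idx  = [0, 1, 2]             # lengths // page_size
# so a batch entry whose length is a multiple of page_size needs a fresh page,
# while the others keep writing into the page block_tables already points at.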


-def paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
+def paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int):
    repl_sharding = jax.typeof(cache.lengths).sharding
-    kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, (cache.k[layer_idx], cache.v[layer_idx]))
-    sharding = (*kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding))
-    return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, k, v)
+    kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, tuple(buffer[layer_idx] for buffer in cache.buffers))
+    sharding = (kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding))
+    return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, kv)
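# --- Reviewer note: out_sharding must mirror the output pytree of
# _paged_update_slice, which now returns (per-layer buffer tuple, state dict);
# hence (kv_sharding, dict(...)) rather than the old (*kv_sharding, dict(...)).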


@partial(jax.jit, donate_argnames=("cache",))
def _batch_paged_update_sequences(
    cache: PagedKVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
    update_mask: list[bool] | None = None,
@@ -156,9 +166,7 @@ def _batch_paged_update_sequences(
    batch_idxs = jnp.where(update_mask, jnp.array(batch_idxs), 2**30)  # send masked to nowhere
    actual_lens = jnp.minimum(jnp.array(actual_lens), jnp.array([jax.tree.leaves(kv)[0].shape[2] for kv in kvs]))

-    kvs, max_seq_len = _transpose_attention_tree(
-        kvs, time_axis=2
-    )  # undo stacking along the layer dimension for transit
+    kvs, max_seq_len = _transpose_attention_tree(kvs, time_axis=2)  # undo stack along layer dimension in transit

    # clear existing pages
    actual_page_num = jnp.rint(jnp.ceil(cache.lengths[batch_idxs] / cache.page_size)).astype(jnp.int32)
@@ -186,21 +194,23 @@ def _update_element(x, u):
        update_permute = [1, 0, 2] + [i for i in range(u.ndim) if i not in (0, 1, 2)]
        return x.at[:, pages_idx, ...].set(u.transpose(update_permute), mode="drop")

-    cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs)
+    new_buffers = jax.tree.map(_update_element, cache.buffers, kvs)
    block_tables_idx = jnp.where(
        update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_arange[None, :], 2**30
    )
    new_block_tables = cache.block_tables.at[batch_idxs[:, None], block_tables_idx].set(pages_idx, mode="drop")
    new_free_pages = new_free_pages.at[pages_idx.reshape(-1)].set(False, mode="drop")
    new_lengths = cache.lengths.at[batch_idxs].set(actual_lens, mode="drop")
+
+    named_buffers = dict(zip([field.name for field in dataclasses.fields(cache)][:len(new_buffers)], new_buffers))
    return dataclasses.replace(
-        cache, k=cache_k, v=cache_v, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages
+        cache, **named_buffers, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages
    )


def batch_paged_update_sequences(
    cache: KVCache,
-    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    kvs: list[tuple[jax.Array | QuantArray, ...]],
    batch_idxs: list[jax.Array],
    actual_lens: list[jax.Array],
):
@@ -222,5 +232,5 @@ def batch_paged_get_entry(cache: PagedKVCache, batch_idx: jax.Array, max_seq_len
    _get = lambda x: jnp.where(mask[None, :, *([None] * (x.ndim - 3))], _reshape_out(x[:, page_indices, ...]), 0)

    # stack along layer dimensions for transit
-    kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, (cache.k, cache.v)))
+    kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, cache.buffers))
    return kvs, true_len