From 14ff57c981efa3cc702f1ec1a67f5e529d407222 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 11 Mar 2025 18:24:05 -0700 Subject: [PATCH 01/11] serving draft --- llama3/.gitignore | 1 + llama3/llama3_jax/attention_cache_utils.py | 226 +++ llama3/llama3_jax/model.py | 307 +++- llama3/llama3_jax/ragged_attention.py | 13 +- llama3/pyproject.toml | 8 +- serving/README.md | 3 + serving/client_demo.py | 1511 ++++++++++++++++++++ serving/main_serving.py | 224 +++ serving/pyproject.toml | 30 + serving/serving_jax/__init__.py | 731 ++++++++++ serving/serving_jax/cross_host.py | 64 + 11 files changed, 3025 insertions(+), 93 deletions(-) create mode 100644 llama3/.gitignore create mode 100644 llama3/llama3_jax/attention_cache_utils.py create mode 100644 serving/README.md create mode 100644 serving/client_demo.py create mode 100644 serving/main_serving.py create mode 100644 serving/pyproject.toml create mode 100644 serving/serving_jax/__init__.py create mode 100644 serving/serving_jax/cross_host.py diff --git a/llama3/.gitignore b/llama3/.gitignore new file mode 100644 index 0000000..2211df6 --- /dev/null +++ b/llama3/.gitignore @@ -0,0 +1 @@ +*.txt diff --git a/llama3/llama3_jax/attention_cache_utils.py b/llama3/llama3_jax/attention_cache_utils.py new file mode 100644 index 0000000..abfcbe0 --- /dev/null +++ b/llama3/llama3_jax/attention_cache_utils.py @@ -0,0 +1,226 @@ +import dataclasses +from functools import partial +import math +from typing import Any + +import jax +import jax.numpy as jnp + +try: + from jax.experimental.shard import auto_axes +except ModuleNotFoundError: + from jax.sharding import auto_axes + +QuantArray, PyTree = Any, Any + +KVCache = Any +next_power_of_2 = lambda x: 2 ** math.ceil(math.log2(max(x, 1))) +_pad_after = lambda x, l, axis: jnp.pad(x, [(0, 0) if i != axis else (0, l - x.shape[i]) for i in range(x.ndim)]) + + +def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int): + "From a list of cache entries stacked along layer idx (in transit) to stacked along batch, layers split into list." 
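+    # Input: one entry per request, each a per-layer-stacked pytree whose leaves are [num_layers, ...]
+    # (the transit layout). Output: a per-layer list whose leaves are concatenated along the batch axis,
+    # with every request right-padded along `time_axis` to the longest sequence in the group.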
+ + _split = lambda x: jnp.split(x, x.shape[0], axis=0) + max_seq_len = max([jax.tree.leaves(kv)[0].shape[time_axis] for kv in kv_list]) + kv_list = [jax.tree.map(lambda x: _pad_after(x, max_seq_len, time_axis), kv) for kv in kv_list] + out = [None for _ in kv_list[0]] + for i, c in enumerate(kv_list[0]): + els = [[_split(z) for z in jax.tree.leaves(kv[i])] for kv in kv_list] # [B, R_flat, L] + els = jax.tree.map(lambda *xs: jnp.concatenate(xs, axis=0), *els) # [R_flat, L] + leaves_list = list(zip(*els)) # [L, R_flat] + out[i] = [jax.tree.unflatten(jax.tree.structure(c), leaves) for leaves in leaves_list] # [L, R] + return tuple(out), max_seq_len + + +######################################################################################################################## +# KV cache utils ####################################################################################################### +######################################################################################################################## + + +@partial(jax.jit, donate_argnames=("cache",)) +def _kvcache_update_cache( + cache: KVCache, + kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]], + batch_idxs: list[jax.Array], + actual_lens: list[jax.Array], + update_mask: list[bool] | None = None, +): + assert len(kvs) == len(batch_idxs) == len(actual_lens) + batch_idxs, actual_lens, update_mask = jnp.array(batch_idxs), jnp.array(actual_lens), jnp.array(update_mask) + uninitialized_cache = cache.iter < 0 + start_time = jnp.where( + uninitialized_cache, jnp.max(actual_lens) - actual_lens, (cache.iter - actual_lens) % cache.size + ) + batch_idxs = jnp.where(update_mask, batch_idxs, 2**30) # send masked to nowhere + kvs, max_seq_len = _transpose_attention_tree(kvs, time_axis=cache.time_axis) + time_indices = (jnp.arange(max_seq_len)[None, :] + start_time[:, None]) % cache.size + + def _update_element(x, u): + update_permute = [0, cache.time_axis] + [i for i in range(u.ndim) if i not in (0, cache.time_axis)] + # time_dim, batch_dim = update_permute.pop(cache.time_axis), update_permute.pop(0) # first pop time_axis + # update_permute = [batch_dim, time_dim] + update_permute + return x.at[batch_idxs[:, None], :, time_indices, ...].set(u.transpose(update_permute), mode="drop") + + cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs) + cache_starts = cache.starts.at[batch_idxs].set(start_time, mode="drop") + cache_iter = jnp.where(uninitialized_cache, jnp.max(actual_lens), cache.iter) + return dataclasses.replace(cache, k=cache_k, v=cache_v, iter=cache_iter, starts=cache_starts) + + +def kvcache_update_cache( + cache: KVCache, + kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]], + batch_idxs: list[jax.Array], + actual_lens: list[jax.Array], +): + pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs) # an update of power of 2 and at least 4 + update_mask = [i < len(kvs) for i in range(len(kvs) + pad_len)] + kvs = kvs + [kvs[-1]] * pad_len + batch_idxs, actual_lens = batch_idxs + [batch_idxs[-1]] * pad_len, actual_lens + [actual_lens[-1]] * pad_len + return _kvcache_update_cache(cache, kvs, batch_idxs, actual_lens, update_mask) + + +@jax.jit +def kvcache_get_entry(cache: KVCache, batch_idx: jax.Array): + shift = -cache.starts[batch_idx] + assert cache.time_axis > 0 + kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), (cache.k, cache.v)) + kvs = (jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[0]), jax.tree.map(lambda 
*xs: jnp.stack(xs, 0), kvs[1]))
+    true_len = cache.fill_len()[batch_idx]
+    return kvs, true_len
+
+
+########################################################################################################################
+# Paged KV cache utils #################################################################################################
+########################################################################################################################
+
+PagedKVCache = Any
+
+
+def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array | None = None):
+    if proposal_pages is not None:
+        assert proposal_pages.size == k
+        proposal_mask = free_pages[proposal_pages]
+        indices = jnp.where(~proposal_mask, jnp.cumsum(~proposal_mask, axis=-1) - 1, k - 1)
+        newly_free_pages = free_pages.at[jnp.where(proposal_mask, proposal_pages, 2**30)].set(False, mode="drop")
+        return jnp.where(proposal_mask, proposal_pages, jax.lax.top_k(newly_free_pages, k)[1][indices])
+    else:
+        return jax.lax.top_k(free_pages, k)[1]
+
+
+def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
+    key_heads = cache.k[layer_idx].shape[0]
+    assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1)
+    needs_next_page = (cache.lengths % cache.page_size) == 0
+    page_table_idx = cache.lengths // cache.page_size
+    current_page_cursor = jnp.take_along_axis(cache.block_tables, page_table_idx[:, None], axis=-1)[..., 0]
+    # layer caches are [key_heads, total_num_pages, page_size, head_dim], so pages live on axis 1
+    avg_pages_per_batch_entry = round(cache.k[layer_idx].shape[1] / cache.batch_size)
+    even_batch_spread = jnp.arange(cache.batch_size) * avg_pages_per_batch_entry
+    proposal_pages = jnp.where(cache.lengths == 0, even_batch_spread, current_page_cursor + 1)
+    free_pages = _find_empty_pages(cache.free_pages, cache.batch_size, proposal_pages=proposal_pages)
+    page_cursor = jnp.where(needs_next_page, free_pages, current_page_cursor)
+
+    inpage_cursor = cache.lengths % cache.page_size
+
+    new_lengths = cache.lengths + 1
+    # for a batch-index update the target slice is (heads, i, j, head_dim), so transpose the update
+    # (batch, heads, seq, head_dim) -> (batch, heads, head_dim) -> (heads, batch, head_dim)
+    _update = lambda dest, src: dest.at[:, page_cursor, inpage_cursor, ...].set(src.squeeze(2).swapaxes(0, 1))
+    cache.k[layer_idx], cache.v[layer_idx] = jax.tree.map(_update, (cache.k[layer_idx], cache.v[layer_idx]), (k, v))
+
+    batch_idx = jnp.arange(cache.batch_size)
+    new_block_tables = cache.block_tables.at[batch_idx, new_lengths // cache.page_size].set(page_cursor)
+
+    new_free_pages = cache.free_pages.at[page_cursor].set(False, mode="drop")
+    new_state = dict(lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages)
+    return cache.k[layer_idx], cache.v[layer_idx], new_state
+
+
+def paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int):
+    repl_sharding = jax.typeof(cache.lengths).sharding
+    kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, (cache.k[layer_idx], cache.v[layer_idx]))
+    sharding = (*kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding))
+    return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, k, v)
+
+
+@partial(jax.jit, donate_argnames=("cache",))
+def _batch_paged_update_sequences(
+    cache: PagedKVCache,
+    kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]],
+    batch_idxs: list[jax.Array],
+    actual_lens:
list[jax.Array], + update_mask: list[bool] | None = None, +) -> PagedKVCache: + update_mask = jnp.array(update_mask) + batch_idxs = jnp.where(update_mask, jnp.array(batch_idxs), 2**30) # send masked to nowhere + actual_lens = jnp.minimum(jnp.array(actual_lens), jnp.array([jax.tree.leaves(kv)[0].shape[2] for kv in kvs])) + + kvs, max_seq_len = _transpose_attention_tree( + kvs, time_axis=2 + ) # undo stacking along the layer dimension for transit + + # clear existing pages + actual_page_num = jnp.rint(jnp.ceil(cache.lengths[batch_idxs] / cache.page_size)).astype(jnp.int32) + occupied_mask = jnp.arange(cache.block_tables.shape[-1])[None, :] < actual_page_num[:, None] + indices_to_free = jnp.where(occupied_mask & update_mask[:, None], cache.block_tables[batch_idxs, :], 2**30) + new_free_pages = cache.free_pages.at[indices_to_free.reshape(-1)].set(True, mode="drop") + + # get the length of the new sequence and find empty pages for the new sequence ideally contiguous + upper_bound_page_num = math.ceil(max_seq_len / cache.page_size) + actual_page_num = jnp.rint(jnp.ceil(actual_lens / cache.page_size)).astype(jnp.int32) + avg_pages_per_batch_entry = round(jax.tree.leaves(cache)[0].shape[1] / cache.batch_size) + proposal_pages = batch_idxs[:, None] * avg_pages_per_batch_entry + jnp.arange(upper_bound_page_num)[None, :] + pages_idx = _find_empty_pages( + new_free_pages, upper_bound_page_num * batch_idxs.size, proposal_pages=proposal_pages.reshape(-1) + ).reshape(proposal_pages.shape) + pages_arange = jnp.arange(upper_bound_page_num) + pages_idx = jnp.where(update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_idx, 2**30) + + # reshape the new pages for insertion and possibly quantize + b, h, s, e = jax.tree.leaves(kvs)[0].shape + kvs = jax.tree.map(lambda x: x.reshape((b, h, s // cache.page_size, cache.page_size) + x.shape[3:]), kvs) + + def _update_element(x, u): + # we're updating (batch, page_entries) with (BATCH, heads, PAGE, page_size, head_dim), so (BATCH, PAGE) go first + update_permute = [1, 0, 2] + [i for i in range(u.ndim) if i not in (0, 1, 2)] + return x.at[:, pages_idx, ...].set(u.transpose(update_permute), mode="drop") + + cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs) + block_tables_idx = jnp.where( + update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_arange[None, :], 2**30 + ) + new_block_tables = cache.block_tables.at[batch_idxs[:, None], block_tables_idx].set(pages_idx, mode="drop") + new_free_pages = new_free_pages.at[pages_idx.reshape(-1)].set(False, mode="drop") + new_lengths = cache.lengths.at[batch_idxs].set(actual_lens, mode="drop") + return dataclasses.replace( + cache, k=cache_k, v=cache_v, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages + ) + + +def batch_paged_update_sequences( + cache: KVCache, + kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]], + batch_idxs: list[jax.Array], + actual_lens: list[jax.Array], +): + pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs) # an update of power of 2 and at least 4 + update_mask = [i < len(kvs) for i in range(len(kvs) + pad_len)] + kvs = kvs + [kvs[-1]] * pad_len + batch_idxs, actual_lens = batch_idxs + [batch_idxs[-1]] * pad_len, actual_lens + [actual_lens[-1]] * pad_len + return _batch_paged_update_sequences(cache, kvs, batch_idxs, actual_lens, update_mask) + + +@partial(jax.jit, static_argnames=("max_seq_len",)) +def batch_paged_get_entry(cache: PagedKVCache, batch_idx: jax.Array, 
max_seq_len: int = -1):
+    true_len = cache.fill_len()[batch_idx]
+    max_seq_len = max_seq_len if max_seq_len > 0 else cache.page_size * cache.block_tables.shape[-1]
+    max_seq_len = min(max_seq_len, cache.page_size * cache.block_tables.shape[-1])  # cache capacity
+    page_indices = cache.block_tables[batch_idx, : round(math.ceil(max_seq_len / cache.page_size))]
+    _reshape_out = lambda x: x.reshape((x.shape[0], max_seq_len) + x.shape[3:])
+    mask = jnp.arange(max_seq_len) < true_len
+    _get = lambda x: jnp.where(mask[None, :, *([None] * (x.ndim - 3))], _reshape_out(x[:, page_indices, ...]), 0)
+
+    # stack along layer dimensions for transit
+    kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, (cache.k, cache.v)))
+    return kvs, true_len
diff --git a/llama3/llama3_jax/model.py b/llama3/llama3_jax/model.py
index 8d38a34..3d07be9 100644
--- a/llama3/llama3_jax/model.py
+++ b/llama3/llama3_jax/model.py
@@ -21,6 +21,7 @@
 import math
 from functools import partial
 from typing import Callable, Any, TypeVar
+from types import ModuleType
 from inspect import signature
 
 import jax
@@ -35,9 +36,10 @@
     from jax.experimental.shard import auto_axes as _auto_axes, reshard
 except ModuleNotFoundError:
     from jax.sharding import auto_axes as _auto_axes, reshard
 
-from etils import epath
+from jax.experimental.pallas.ops.gpu import paged_attention
 
 from . import ragged_attention
+from . import attention_cache_utils
 
 AxisName = str | tuple[str, ...] | None
 Axes = tuple[AxisName, ...]
@@ -94,7 +96,7 @@ def auto_axes(x, out_sharding):  # TODO(rdyro): remove once in JAX >= 0.7.0
 def logical_to_physical(logical: Axes, rules: ShardingRules) -> jax.sharding.PartitionSpec:
-    """Returns how to physically shard a given sequence of logical array dimensions (i.e. the logical shape of an array)."""
+    """Translate logical axis names to a physical sharding (PartitionSpec)."""
     spec = [getattr(rules, axis) if axis is not None else None for axis in logical]
     # `spec` may contain tuples, flatten to check that `spec` maps each physical mesh axis to at most one logical array
     # axis.
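
As a sketch of what `logical_to_physical` computes (standalone, with a hypothetical trimmed-down `ShardingRules`; the real rules in model.py carry many more axes):

    from dataclasses import dataclass
    from jax.sharding import PartitionSpec as P

    @dataclass
    class ShardingRules:  # hypothetical stand-in: one field per logical axis name
        batch: str | None = "x"      # logical "batch" axis maps to mesh axis "x"
        sequence: str | None = None  # unsharded
        head_dim: str | None = None  # unsharded

    def logical_to_physical(logical, rules):
        # look up each logical axis name in the rules; None means "leave unsharded"
        return P(*(getattr(rules, axis) if axis is not None else None for axis in logical))

    assert logical_to_physical(("batch", "sequence", "head_dim"), ShardingRules()) == P("x", None, None)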
@@ -116,7 +118,9 @@ def jax_pytree_struct(cls, meta_fields: tuple = ()): cls = dataclasses.dataclass(cls) all_fields = tuple(f.name for f in dataclasses.fields(cls) if f.init) data_fields = tuple(f for f in all_fields if f not in meta_fields) - return tree_util.register_dataclass(cls, data_fields=data_fields, meta_fields=meta_fields) + # return register_dataclass_serialization( + return tree_util.register_dataclass(cls, data_fields=data_fields, meta_fields=meta_fields) # , + # serialize_auxdata=lambda *args: b"", deserialize_auxdata=lambda *args: ()) jax_static = lambda cls: tree_util.register_static(dataclasses.dataclass(cls)) @@ -178,8 +182,10 @@ def llama_to_jax_config(llama_config: Any | dict[str, Any]) -> "Config": def load_config(config_path: str | os.PathLike[str] | Path) -> "Config": return llama_to_jax_config(json.loads(Path(config_path).read_text())) + PreTrainedTokenizerFast = TypeVar("PreTrainedTokenizerFast") + def load_tokenizer( tokenizer_path: str | os.PathLike[str] | Path, tokenizer_config_path: str | os.PathLike[str] | Path ) -> PreTrainedTokenizerFast: @@ -305,19 +311,19 @@ def quantize(x: jax.Array | ArrayInfo, axis: int | tuple[int, ...], scale_dtype= raise ValueError(f"quantize got unexpected type: {type(x)}") -def update_slice(x: jax.Array | QuantArray, y: jax.Array, pos: int, update_axis: int, quant_axis: int = -1): +def update_slice( + x: jax.Array | QuantArray, y: jax.Array | QuantArray, pos: int, update_axis: int, quant_axis: int = -1 +): """dynamic_update_slice wrapper that handles regular arrays and QuantArrays""" if is_type(x, QuantArray): assert x.quant.ndim == y.ndim quant_axis, update_axis = quant_axis % x.quant.ndim, update_axis % x.quant.ndim # normalize axis numbers - y_quant, y_scale = quantize(y, axis=quant_axis, scale_dtype=x.scale.dtype) # quantize rhs + y_quant, y_scale = y.quant, y.scale y_quant = reshard(y_quant.astype(x.quant.dtype), jax.typeof(x.quant).sharding.spec) y_scale = reshard(y_scale.astype(x.scale.dtype), jax.typeof(x.scale).sharding.spec) new_quant = jax.lax.dynamic_update_slice_in_dim(x.quant, y_quant, pos, axis=update_axis) scale_update_axis = [ax for ax in range(x.quant.ndim) if ax != quant_axis][update_axis] - new_scale = jax.lax.dynamic_update_slice_in_dim( - x.scale, y_scale, pos, axis=scale_update_axis - ) # update axis in `scale` + new_scale = jax.lax.dynamic_update_slice_in_dim(x.scale, y_scale, pos, axis=scale_update_axis) return dataclasses.replace(x, quant=new_quant, scale=new_scale) else: assert x.ndim == y.ndim @@ -398,17 +404,20 @@ def abstract(cls, cfg: Config): ) -@jax_pytree_struct +@partial(jax_pytree_struct, meta_fields=("batch_size", "size", "time_axis")) class KVCache(_Init): - k: list[jax.Array] # (batch_size, key_heads, max_seq_len, head_dim) - v: list[jax.Array] # (batch_size, key_heads, max_seq_len, head_dim) - length: jax.Array # [] # sequences are right-aligned for slice udpate performance + k: list[tuple[jax.Array | QuantArray, ...]] # (batch_size, key_heads, max_seq_len, head_dim) + v: list[tuple[jax.Array | QuantArray, ...]] # (batch_size, key_heads, max_seq_len, head_dim) + iter: jax.Array # [] # sequences are right-aligned for slice update performance starts: jax.Array # [batch_size] # sequences are right-aligned, we need start indices + batch_size: int = 0 + size: int = 0 + time_axis: int = 2 @classmethod - def abstract(cls, cfg: Config, batch_size: int, max_seq_len: int): + def abstract(cls, cfg: Config, batch_size: int): val_info = ArrayInfo( - (batch_size, cfg.kv_heads, max_seq_len, cfg.head_dim), 
+ (batch_size, cfg.kv_heads, cfg.max_seq_len, cfg.head_dim), cfg.dtype, ("batch", "kv_heads", "sequence", "head_dim"), jax.nn.initializers.zeros, @@ -416,11 +425,65 @@ def abstract(cls, cfg: Config, batch_size: int, max_seq_len: int): cache = KVCache( k=[val_info for _ in range(cfg.num_layers)], v=[val_info for _ in range(cfg.num_layers)], - length=ArrayInfo((), jnp.int32, (), jax.nn.initializers.zeros), + # -1 means unintialized since iter (cursor) must be 0 <= iter < len - 1 + iter=ArrayInfo((), jnp.int32, (), jax.nn.initializers.constant(-1)), starts=ArrayInfo((batch_size,), jnp.int32, ("batch",), jax.nn.initializers.zeros), ) if cfg.quant_cache: _quantize = partial(quantize, axis=-1, scale_dtype=cfg.quant_scale_dtype, zero_init=True) + cache = dataclasses.replace( + cache, + k=[ + QuantArray(*_quantize(cache.k[idx]), out_scaling=True, scale_expand_dims=(-2, -3)) + for idx in range(cfg.num_layers) + ], + v=[ + QuantArray(*_quantize(cache.v[idx]), out_scaling=False, scale_expand_dims=(-2, -3)) + for idx in range(cfg.num_layers) + ], + ) + + cache.batch_size, cache.size = batch_size, cfg.max_seq_len + return cache + + def fill_len(self) -> jax.Array: + length = jnp.where(self.iter > self.starts, self.iter - self.starts, self.size + self.iter - self.starts) + return jnp.where(self.iter >= 0, length, 0) + + update_slice = None + insert_sequences = staticmethod(attention_cache_utils.kvcache_update_cache) + get_sequence = staticmethod(attention_cache_utils.kvcache_get_entry) + + +@partial(jax_pytree_struct, meta_fields=("batch_size", "size", "page_size")) +class PagedKVCache(_Init): + k: list[jax.Array | QuantArray] # [key_heads, total_num_pages, page_size, head_dim] + v: list[jax.Array | QuantArray] # [key_heads, total_num_pages, page_size, head_dim] + lengths: jax.Array # [batch_size] # true length of the cache entries + block_tables: jax.Array # [batch_size, pages_per_seq] + free_pages: jax.Array # [total_num_pages] + batch_size: int = 0 + size: int = 2**31 - 1 + page_size: int = 0 + + @classmethod + def abstract(cls, cfg: "Config", batch_size: int, total_num_pages: int, page_size: int): + pages_per_seq = math.ceil(cfg.max_seq_len / page_size) + val_info = ArrayInfo( + (cfg.kv_heads, total_num_pages, page_size, cfg.head_dim), + cfg.dtype, + ("kv_heads", None, None, "head_dim"), + jax.nn.initializers.zeros, + ) + cache = PagedKVCache( + k=[val_info for _ in range(cfg.num_layers)], + v=[val_info for _ in range(cfg.num_layers)], + lengths=ArrayInfo((batch_size,), jnp.int32, (), jax.nn.initializers.constant(0)), + block_tables=ArrayInfo((batch_size, pages_per_seq), jnp.int32, (), jax.nn.initializers.constant(0)), + free_pages=ArrayInfo((total_num_pages,), jnp.bool, (), jax.nn.initializers.constant(1)), + ) + if cfg.quant_cache: + _quantize = partial(quantize, axis=-1, scale_dtype=cfg.quant_scale_dtype) cache = dataclasses.replace( cache, k=[ @@ -432,11 +495,15 @@ def abstract(cls, cfg: Config, batch_size: int, max_seq_len: int): for idx in range(len(cache.v)) ], ) + cache.batch_size, cache.page_size = batch_size, page_size return cache - @property - def time_axis(self) -> int: - return 2 + def fill_len(self) -> jax.Array: + return self.lengths + + update_slice = staticmethod(attention_cache_utils.paged_update_slice) + insert_sequences = staticmethod(attention_cache_utils.batch_paged_update_sequences) + get_sequence = staticmethod(attention_cache_utils.batch_paged_get_entry) def segment_ids_to_positions(segment_ids): @@ -525,32 +592,42 @@ def apply_rotary_embedding(x, sin, cos): return 
jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) -def make_attention_mask(q_len, k_len, q_segment_ids, k_segment_ids, q_offset, causal: bool): +def rms_norm(x: jax.Array, gamma: jax.Array) -> jax.Array: + """Apply RMS normalization.""" + rms = jnp.sqrt(jnp.mean(jnp.astype(x, jnp.float32) ** 2, axis=-1, keepdims=True) + 1e-6) + return jnp.astype(gamma * x / rms, jnp.bfloat16) + + +def make_attention_mask(q_len, k_len, q_segment_ids, kv_segment_ids, q_offset, causal: bool, cache_starts: jax.Array): + cache_size = kv_segment_ids.shape[-1] # [B, t, T] - segment_mask = q_segment_ids[:, :, None] == k_segment_ids[:, None, :] + segment_mask = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :] # [B, t, T] -> [B, 1, t, T] segment_mask = segment_mask[:, None, :, :] if causal: # [b, h, t, T] qk = (1, 1, q_len, k_len) - q_iota = jax.lax.broadcasted_iota(jnp.int32, qk, 2) - k_iota = jax.lax.broadcasted_iota(jnp.int32, qk, 3) - q_positions = q_iota + q_offset[:, None, None, None] - causal_mask = q_positions >= k_iota + q_positions = jax.lax.broadcasted_iota(jnp.int32, qk, 2) + q_offset[:, None, None, None] + k_positions = ( + jax.lax.broadcasted_iota(jnp.int32, qk, 3) + (-1 * cache_starts)[:, None, None, None] + ) % cache_size + causal_mask = q_positions >= k_positions combined_mask = jnp.logical_and(segment_mask, causal_mask) return combined_mask else: return segment_mask -def _attention( +@partial(auto_axes, out_sharding=P(BATCH_AXIS_NAME, TENSOR_AXIS_NAME, None, None)) +def attention( q: jax.Array, k: jax.Array | tuple[jax.Array, jax.Array], v: jax.Array | tuple[jax.Array, jax.Array], q_segment_ids: jax.Array, - k_segment_ids: jax.Array, + kv_segment_ids: jax.Array, q_offset: jax.Array, + cache_starts: jax.Array | None, cfg: Config, ) -> jax.Array: """ @@ -561,7 +638,7 @@ def _attention( k: Key tensor of shape (batch_size, num_heads, k_len, head_dim) v: Value tensor of shape (batch_size, num_heads, k_len, head_dim) q_segment_ids: Query segment IDs of shape (batch_size, q_len) - k_segment_ids: Key segment IDs of shape (batch_size, k_len) + kv_segment_ids: Key segment IDs of shape (batch_size, k_len) q_offset: Query offset of shape (batch_size,) cfg: Configuration object @@ -579,7 +656,7 @@ def _attention( qk = einsum("bhgtd,bhTd->bhgtT", q_, k) * scale qk = qk.reshape((b, qh, t, T)) - mask = make_attention_mask(t, T, q_segment_ids, k_segment_ids, q_offset, cfg.causal) + mask = make_attention_mask(t, T, q_segment_ids, kv_segment_ids, q_offset, cfg.causal, cache_starts) # Apply the combined mask qk = jnp.where(mask, qk, -1e30) # jax softmax impl includes max subtraction for numerical stability, no need to do it outside. 
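
To make the ring-buffer indexing in `make_attention_mask` concrete, here is a minimal numeric sketch (standalone, with made-up sizes; it mirrors the causal branch above):

    import jax.numpy as jnp

    cache_size, start = 8, 3                      # ring buffer of 8 slots, sequence starts at slot 3
    k_slots = jnp.arange(cache_size)              # physical slot indices into the cache
    k_positions = (k_slots - start) % cache_size  # logical position stored in each slot
    q_position = 4                                # logical position of the current query token
    causal_mask = q_position >= k_positions       # True where the query may attend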
@@ -591,9 +668,6 @@ def _attention( return qkv.reshape((b, qh, t, d)) -attention = auto_axes(_attention, out_sharding=P(BATCH_AXIS_NAME, TENSOR_AXIS_NAME, None, None)) - - def attention_kernel(q, k, v, q_segment_ids, kv_segment_ids, q_offset, starts, lengths, cfg: Config): """Flash attention kernel!""" @@ -665,10 +739,48 @@ def _f(q, k, v, q_segment_ids, kv_segment_ids, starts, lengths, k_scale, v_scale return jax.lax.reshape(ret, q_shape__, out_sharding=l2p("batch", "q_heads", "sequence", "head_dim")) -def rms_norm(x: jax.Array, gamma: jax.Array) -> jax.Array: - """Apply RMS normalization.""" - rms = jnp.sqrt(jnp.mean(jnp.astype(x, jnp.float32) ** 2, axis=-1, keepdims=True) + 1e-6) - return jnp.astype(gamma * x / rms, jnp.bfloat16) +def paged_attention_kernel(q, k, v, block_tables, lengths, cfg: Config): + k, k_scale = (k.quant, k.scale) if is_type(k, QuantArray) else (k, None) + v, v_scale = (v.quant, v.scale) if is_type(v, QuantArray) else (v, None) + + # handle grouped query attention + assert q.shape[-3] % cfg.kv_heads == 0 and k.shape[0] == cfg.kv_heads + scale = q.shape[-1] ** -0.5 + + l2p = lambda *logical: logical_to_physical(logical, cfg.rules) + + kv_repeats = q.shape[-3] // cfg.kv_heads + q_spec = P( + *(l2p("batch", "kv_heads") + tuple(set(*l2p("q_heads")) - set(*l2p("kv_heads"))) + l2p("sequence", "head_dim")) + ) + q_shape__ = q.shape + q = jax.lax.reshape(q, (q.shape[:-3] + (cfg.kv_heads, kv_repeats, q.shape[-2], q.shape[-1])), out_sharding=q_spec) + + # shard_map + in_specs = ( + q_spec, # q + l2p("kv_heads", None, "sequence", "head_dim"), # k / k_quant + None if k_scale is None else l2p("kv_heads", None, "sequence"), # k_scale or None + l2p("kv_heads", None, "sequence", "head_dim"), # v / v_quant + None if v_scale is None else l2p("kv_heads", None, "sequence"), # v_scale or None + l2p("batch", None), # block_tables + l2p("batch"), # lengths + ) + out_specs = q_spec + + @partial(shard_map, mesh=cfg.mesh, in_specs=in_specs, out_specs=out_specs, check_rep=False) + def _f(q, k, k_scale, v, v_scale, block_tables, lengths): + # q in [batch_size, kv_heads_local, kv_repeats, 1, head_dim] + if k_scale is not None: + k = (k * k_scale[..., None]).astype(jnp.bfloat16) + if v_scale is not None: + v = (v * v_scale[..., None]).astype(jnp.bfloat16) + q_ = q[..., 0, :].reshape((q.shape[0], -1, q.shape[-1])) + ret = paged_attention.paged_attention(q_ * scale, k, v, block_tables, lengths) + return ret.reshape(q.shape) + + ret = _f(q, k, k_scale, v, v_scale, block_tables, lengths).astype(jnp.bfloat16) + return jax.lax.reshape(ret, q_shape__, out_sharding=l2p("batch", "q_heads", "sequence", "head_dim")) def attention_block( @@ -678,7 +790,7 @@ def attention_block( sin: jax.Array, cos: jax.Array, cfg: Config, - cache: KVCache | None = None, + cache: KVCache | PagedKVCache | None = None, idx: int | None = None, ): l2p = lambda *specs: logical_to_physical(specs, cfg.rules) @@ -694,41 +806,61 @@ def attention_block( with jax.named_scope("rope"): q, k = apply_rotary_embedding(q, sin, cos), apply_rotary_embedding(k, sin, cos) + if cfg.quant_cache: + k = QuantArray( + *quantize(k, -1, scale_dtype=cfg.quant_scale_dtype), out_scaling=True, scale_expand_dims=(-2, -3) + ) + v = QuantArray( + *quantize(v, -1, scale_dtype=cfg.quant_scale_dtype), out_scaling=False, scale_expand_dims=(-2, -3) + ) + with jax.named_scope("cache_update"): - if cache is not None: - k = update_slice(cache.k[idx], k, cache.length, update_axis=cache.time_axis, quant_axis=-1) - v = update_slice(cache.v[idx], v, cache.length, 
update_axis=cache.time_axis, quant_axis=-1) - time_indices = jnp.arange(0, v.shape[cache.time_axis])[None, :] # [1, T] + paged_state, starts = None, None + if is_type(cache, KVCache): + it = jnp.maximum(cache.iter, 0) + k = update_slice(cache.k[idx], k, it, update_axis=cache.time_axis, quant_axis=-1) + v = update_slice(cache.v[idx], v, it, update_axis=cache.time_axis, quant_axis=-1) + time_indices = ( + jnp.arange(0, v.shape[cache.time_axis])[None, :] - cache.starts[:, None] + ) % cache.size # [B, T] q_segment_ids = jnp.where(segment_ids != 0, 1, 0) incremental_position = jnp.max(_length_minus_padding(segment_ids)) # i.e. valid below where we've written things [B, T] - k_segment_ids = ( - (time_indices >= cache.starts[:, None]) & (time_indices < (cache.length + incremental_position)) + kv_segment_ids = ( + (time_indices >= 0) & (time_indices < cache.fill_len()[:, None] + incremental_position) ).astype(jnp.int32) - - q_offset = cache.length[None] - starts, lengths = cache.starts, (cache.length + incremental_position)[None] + q_offset = cache.fill_len() - _count_left_padding(segment_ids) + starts, lengths = cache.starts, cache.fill_len() + cache_updates = (k, v) + elif is_type(cache, PagedKVCache): + cache: PagedKVCache + k, v, paged_state = PagedKVCache.update_slice(cache, k=k, v=v, layer_idx=idx) + cache_updates = (k, v, paged_state) else: - q_segment_ids, k_segment_ids = segment_ids, segment_ids + # this supports prefill only; no support for a ring cache buffer here + q_segment_ids, kv_segment_ids = segment_ids, segment_ids q_offset = jnp.zeros(x.shape[0], dtype=jnp.int32) - starts, lengths = _count_left_padding(k_segment_ids, 0), _length_minus_padding(k_segment_ids) + starts, lengths = _count_left_padding(segment_ids, 0), _length_minus_padding(kv_segment_ids) + cache_updates = (k, v) # Compute attention with jax.named_scope("attention"): - if (cfg.use_prefill_attn_kernel and q.shape[-2] != 1) or (cfg.use_decode_attn_kernel and q.shape[-2] == 1): + if is_type(cache, PagedKVCache): + attn_out = paged_attention_kernel(q, k, v, paged_state["block_tables"], paged_state["lengths"], cfg) + elif (cfg.use_prefill_attn_kernel and q.shape[-2] != 1) or (cfg.use_decode_attn_kernel and q.shape[-2] == 1): attn_out = attention_kernel( - q, k, v, q_segment_ids, k_segment_ids, q_offset, starts=starts, lengths=lengths, cfg=cfg + q, k, v, q_segment_ids, kv_segment_ids, q_offset, starts=starts, lengths=lengths, cfg=cfg ) else: - attn_out = attention(q, k, v, q_segment_ids, k_segment_ids, q_offset, cfg) + attn_out = attention(q, k, v, q_segment_ids, kv_segment_ids, q_offset, starts, cfg) # Project attention output with jax.named_scope("projection"): attn_out = einsum( "bhtq,hqd->btd", attn_out, layer.o, out_sharding=l2p("batch", "sequence", "act_embed") ).astype(cfg.dtype) - return attn_out, k, v + return attn_out, cache_updates def ffn_block(x: jax.Array, layer: Layer, cfg: Config): @@ -753,14 +885,14 @@ def forward_layer( cos: jax.Array, idx: int, cfg: Config, - cache: KVCache | None = None, + cache: KVCache | PagedKVCache | None = None, ) -> tuple[jax.Array, jax.Array, jax.Array]: x = x.astype(cfg.dtype) # Attention block with jax.named_scope("attn_pre_norm"): attn_in = rms_norm(x, layer.attn_pre_gamma) - attn_out, k, v = attention_block(attn_in, segment_ids, layer, sin, cos, cfg, cache, idx) + attn_out, cache_updates = attention_block(attn_in, segment_ids, layer, sin, cos, cfg, cache, idx) with jax.named_scope("residual"): x = x + attn_out.astype(cfg.dtype) @@ -772,7 +904,7 @@ def forward_layer( with 
jax.named_scope("residual"): x = x + ff_out.astype(cfg.dtype) - return x, k, v + return x, cache_updates def forward( @@ -780,38 +912,42 @@ def forward( segment_ids: jax.Array, weights: Weights, cfg: Config, - cache: KVCache | None = None, + cache: KVCache | PagedKVCache | None = None, ): l2p = lambda *args: logical_to_physical(args, cfg.rules) # Embed input tokens [B, T] -> [B, T D] x = weights.embedding.at[x, :].get(out_sharding=l2p("batch", "sequence", "act_embed")) - batch = x.shape[0] - positions = segment_ids_to_positions(segment_ids) + positions = segment_ids_to_positions(segment_ids) # already shifted by padding # Apply rotary embeddings: [B, T, head_dim] if cache is not None: # For inference with cache, we need to index the positional embeddings - start_indices = jnp.where(cache.length != 0, cache.length - cache.starts, 0) - else: - start_indices = jnp.zeros((batch,), dtype=jnp.int32) + positions = cache.fill_len()[:, None] + positions # NOTE: At inference time this only works for UNPACKED sequences. - positions = start_indices[:, None] + positions - # [B, T, head_dim] - sin, cos = _generate_pos_embeddings(positions, cfg.head_dim, cfg) + sin, cos = _generate_pos_embeddings(positions, cfg.head_dim, cfg) # [B, T, head_dim] sin, cos = sin.astype(cfg.dtype), cos.astype(cfg.dtype) + all_cache_updates = [] for idx, layer in enumerate(weights.layers): - x, k, v = forward_layer(x, segment_ids, layer, sin, cos, idx, cfg, cache) - cache.k[idx], cache.v[idx] = k, v + x, cache_updates = forward_layer(x, segment_ids, layer, sin, cos, idx, cfg, cache) + all_cache_updates.append(cache_updates) # Final layer norm. x = rms_norm(x, weights.gamma_final) + # Project to vocabulary size logits = einsum("btd,dv->btv", x, weights.lm_head) - if cache is not None: - # Sum where there is a valid segment id (i.e. non padding tokens) [B, T] -> [B,] - cache = dataclasses.replace(cache, length=cache.length + jnp.max(_length_minus_padding(segment_ids))) + + if is_type(cache, KVCache): + cache.k, cache.v = [z[0] for z in all_cache_updates], [z[1] for z in all_cache_updates] + new_iter = (jnp.maximum(0, cache.iter) + jnp.max(_length_minus_padding(segment_ids))) % cache.size + cache = dataclasses.replace(cache, iter=new_iter) + return logits, cache + elif is_type(cache, PagedKVCache): + kv, new_state = tuple(map(list, zip(*[z[:2] for z in all_cache_updates]))), all_cache_updates[0][2] + cache = dataclasses.replace(cache, k=kv[0], v=kv[1], **new_state) return logits, cache - return logits + else: + return logits, all_cache_updates # serialization @@ -844,26 +980,41 @@ def prepare_chunk(chunk, pad_to: int, pad_id: int): return chunk, segment_ids -def prefill(tokens: jax.Array, weights: Weights, cache: KVCache, cfg: Config, pad_id: int = 0): +def prefill(tokens: jax.Array, weights: Weights, cache: KVCache | None, cfg: Config, pad_id: int = 0): """Samples from a prompt.""" # Calculate the next power of 2 for padding, up to cfg.max_seq. 
assert tokens.shape[-1] <= cfg.max_seq_len pad_to = 2 ** math.ceil(math.log2((tokens.shape[-1]))) - with use_mesh(cfg.mesh): - prompt, prompt_segment_ids = prepare_chunk(tokens, pad_to=pad_to, pad_id=pad_id) - assert prompt.ndim == 2 - cache_shardings = KVCache.shardings(cfg, cache.k[0].shape[0], cache.k[0].shape[cache.time_axis]) - logits_shardings = jax.sharding.NamedSharding(cfg.mesh, P(BATCH_AXIS_NAME, None, TENSOR_AXIS_NAME)) + prompt, prompt_segment_ids = prepare_chunk(tokens, pad_to=pad_to, pad_id=pad_id) + assert prompt.ndim == 2 + + logits_shardings = jax.sharding.NamedSharding(cfg.mesh, P(BATCH_AXIS_NAME, None, TENSOR_AXIS_NAME)) + cache_shardings = KVCache.shardings(cfg, cache.batch_size if cache is not None else tokens.shape[0]) + if is_type(cache, KVCache): cache = dataclasses.replace( - cache, length=jnp.zeros_like(cache.length), starts=_count_left_padding(tokens, pad_id=pad_id) + cache, + iter=-jnp.ones_like(cache.iter), + starts=_count_left_padding(tokens, pad_id=pad_id), ) logits, cache = jax.jit(forward, donate_argnums=(4,), out_shardings=(logits_shardings, cache_shardings))( prompt, prompt_segment_ids, weights, cfg, cache ) - next_tokens = jax.jit(jnp.argmax, static_argnames=("axis",))(logits, axis=-1) - return next_tokens, logits, cache + elif is_type(cache, PagedKVCache): + raise ValueError("Prefill with Paged KV Cache is not currently supported.") + else: + cache_shardings = KVCache.shardings(dataclasses.replace(cfg), tokens.shape[0]) + kv_sharding = [(cache_shardings.k[idx], cache_shardings.v[idx]) for idx in range(cfg.num_layers)] + logits, kv_list = jax.jit(forward, out_shardings=(logits_shardings, kv_sharding))( + prompt, prompt_segment_ids, weights, cfg, None + ) + cache = kv_list + next_tokens = jax.jit(jnp.argmax, static_argnames=("axis",))(logits, axis=-1) + return next_tokens, logits, cache + + +prefill.forward = forward @partial(jax.jit, donate_argnames=("cache",)) diff --git a/llama3/llama3_jax/ragged_attention.py b/llama3/llama3_jax/ragged_attention.py index 4a0b6e3..c5eb83d 100644 --- a/llama3/llama3_jax/ragged_attention.py +++ b/llama3/llama3_jax/ragged_attention.py @@ -316,17 +316,7 @@ def test_main(interpret=False): mesh = jax.make_mesh((jax.device_count(),), ("x",)) @partial(jax.jit, static_argnames=("which", "block_kv", "block_bs")) - def fn( - q, - k, - v, - starts, - lengths, - qk_prev=None, - which: str = "pallas", - block_kv: int = 128, - block_bs: int = 8 - ): + def fn(q, k, v, starts, lengths, qk_prev=None, which: str = "pallas", block_kv: int = 128, block_bs: int = 8): k, k_scale = k if isinstance(k, tuple) else (k, None) v, v_scale = v if isinstance(v, tuple) else (v, None) kv_heads = k.shape[1] @@ -412,5 +402,6 @@ def _fn(q, k, v, starts, lengths, k_scale, v_scale, qk_prev): err = jnp.mean(err, -1) print(f"{err = }") + if __name__ == "__main__": test_main() diff --git a/llama3/pyproject.toml b/llama3/pyproject.toml index 116963a..5fa6905 100644 --- a/llama3/pyproject.toml +++ b/llama3/pyproject.toml @@ -6,17 +6,17 @@ authors = [ { name = "Robert Dyro" }, ] readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" license = { text = "Apache-2.0" } dependencies = [ "jax", "torch", - "transformers", # for the model config and the tokenizer + #"transformers", # for the model config and the tokenizer "tqdm", "numpy", - "orbax-checkpoint", - "datasets", + #"orbax-checkpoint", + #"datasets", "gcsfs", "etils", ] diff --git a/serving/README.md b/serving/README.md new file mode 100644 index 0000000..a1e6fb1 --- /dev/null +++ 
b/serving/README.md @@ -0,0 +1,3 @@ +# Minimal Serving in JAX + +Work in progress. diff --git a/serving/client_demo.py b/serving/client_demo.py new file mode 100644 index 0000000..431702b --- /dev/null +++ b/serving/client_demo.py @@ -0,0 +1,1511 @@ +#!/usr/bin/env python3 +import threading +import time +import requests +from typing import List +import textwrap +import sys +import gzip +import json +import base64 +from pathlib import Path + +from rich.live import Live +from rich.panel import Panel +from rich.layout import Layout +from rich.console import Console +from rich.text import Text + +import numpy as np + +# --- Configuration --- +SERVER_URL = "http://localhost:8081" + +def fetch_stream(request_id: int, prompt_text: str): + """ + Worker function to fetch a single stream from the server in a thread. + """ + payload = {"id": request_id, "text": prompt_text} + headers = {"accept": "application/json", "Content-Type": "application/json"} + global responses, responses_lock, responses_done + + try: + t_first, t_start = None, time.perf_counter() + with requests.get( + SERVER_URL + "/stream", + headers=headers, + json=payload, + stream=True, + timeout=5 * 60, # 5 minutes + ) as response: + response.raise_for_status() + for i, chunk in enumerate(response.iter_content(chunk_size=None, decode_unicode=True)): + if i == 0: + t_first = time.perf_counter() + with open("log.txt", "a") as f: + f.write(f"Time to first response: {t_first - t_start:.4e} s\n") + if chunk: + with responses_lock: + responses[request_id]["response"] += chunk + t_stop = time.perf_counter() + with open("log.txt", "a") as f: + f.write(f"Decode time: {t_stop - t_first:.4e} s\n") + f.write(f"Total time: {t_stop - t_start:.4e} s\n") + + except requests.exceptions.RequestException as e: + with responses_lock: + responses[request_id]["status"] = f"[bold red]Error: {e}[/]" + finally: + with responses_lock: + if responses[request_id]["status"] == "Running...": + responses[request_id]["status"] = "[bold green]Done[/]" + responses_done[request_id] = True + + +def generate_layout() -> Layout: + """Creates the Rich Layout for the TUI.""" + layout, panels = Layout(name="root"), [] + panels = [] + HEIGHT, WIDTH = 20, 80 + global responses_lock, responses + + with responses_lock: + for i, data in responses.items(): + # --- Scrolling/Truncation Logic --- + full_response_text = data["response"] + display_text = full_response_text + new_text = textwrap.wrap(display_text, width=WIDTH - 4) + new_text = "\n".join(new_text[-(HEIGHT - 2) :]) + + # Use Text to prevent re-parsing of Rich markup and handle wrapping + # response_content = Text(full_response_text, no_wrap=True, overflow="fold") + response_content = Text(new_text, no_wrap=True, overflow="crop") + + # Determine border style based on status + border_style = "red" if "Error" in data["status"] else "yellow" + border_style = "green" if "Done" in data["status"] else border_style + + opts = dict(subtitle_align="left", border_style=border_style, height=HEIGHT, width=WIDTH) + panels.append( + Panel(response_content, title=f"Request #{i}", subtitle=f"Prompt: {data['prompt'][:50]}...", **opts) + ) + layout.split_column(*panels) + return layout + + +def profile_issue(): + headers = {"accept": "application/json", "Content-Type": "application/json"} + server_url = SERVER_URL + "/profile" + requests.get(server_url, headers=headers) + + +def set_generation_length(length: int): + headers = {"accept": "application/json", "Content-Type": "application/json"} + server_url = SERVER_URL + 
"/set_generation_length" + requests.get(server_url, headers=headers, params={"length": length}) + + +def retrieve(id: int): + headers = {"accept": "application/json", "Content-Type": "application/json"} + server_url = SERVER_URL + "/retrieve" + print(requests.get(server_url, headers=headers, params={"id": id})) + + +def investigate(id: int): + headers = {"accept": "application/json", "Content-Type": "application/json"} + server_url = SERVER_URL + "/investigate" + requests.get(server_url, headers=headers, params={"id": id}) + + +def main(): + global responses, MAX_PANEL_LINES, responses_lock, responses_done + all_prompts = get_prompts() + prompts_num = 18 + idxs = np.random.randint(0, len(all_prompts) - 1, prompts_num) + PROMPTS = [all_prompts[idx] for idx in idxs] + + # This controls the "scrolling" effect. It's the max number of lines + # displayed in a panel. When text exceeds this, only the latest lines are shown. + MAX_PANEL_LINES = 15 + # --------------------- + + # A thread-safe dictionary to store the state of each streaming response. + # The structure will be: { request_id: {"prompt": str, "response": str, "status": str} } + GLOBAL_ID = time.time_ns() % 2**30 + responses = { + (GLOBAL_ID + i): {"prompt": prompt, "response": "", "status": "Running..."} for i, prompt in enumerate(PROMPTS) + } + responses_done = {k: False for k in responses.keys()} + responses_lock = threading.Lock() + console = Console() + + if len(sys.argv) > 1 and sys.argv[1] == "profile": + profile_issue() + return + + if len(sys.argv) > 2 and sys.argv[1] == "investigate": + investigate(int(sys.argv[2])) + return + + set_generation_length(64 if len(sys.argv) <= 1 else int(sys.argv[1])) + threads = [] + console.print("[bold cyan]Starting streaming requests... Press Ctrl+C to exit.[/]") + time.sleep(1) # Give user time to read the message + + for i, prompt in enumerate(PROMPTS): + thread = threading.Thread(target=fetch_stream, args=(GLOBAL_ID + i, prompt)) + threads.append(thread) + thread.start() + + layout = None + with Live(generate_layout(), screen=True, redirect_stderr=False, transient=True) as live: + try: + # This loop now runs indefinitely until the user presses Ctrl+C. + # It will continue to update the TUI even after all threads are done. + # while True: + while not all(v for v in responses_done.values()): + layout = generate_layout() + live.update(layout) + time.sleep(0.1) # Refresh rate + with open("log.txt", "a") as f: + for id, response in responses.items(): + f.write(f"RESPONSE {id}: ------{response['prompt']}\n") + f.write(response["response"]) + f.write("\n") + + layout = generate_layout() + live.update(layout) + time.sleep(0.1) # Refresh rate + except KeyboardInterrupt: + # This block allows for a graceful exit. + console.print("\n[bold orange3]Exiting gracefully...[/]") + + # Optional: Join threads on exit to ensure clean shutdown, though the program + # will exit anyway. This is good practice if threads were daemonic. + for thread in threads: + thread.join(timeout=0.1) + + with open("log.txt", "a") as f: + f.write("-" * 80 + "\n") + + console.print("[bold cyan]Application terminated.[/]") + console.print(layout) + + +def get_prompts(): + return [ + "Analyze the ethical implications of AI-generated art. Discuss copyright.", + "Summarize the plot of '1984' by George Orwell concisely.", + "Explain quantum entanglement for a high school student. Keep it simple.", + "Describe the characteristics of a healthy forest ecosystem. 
Mention biodiversity.", + "Write a haiku about a blooming cherry blossom tree.", + "Compare and contrast renewable and non-renewable energy sources.", + "List three major causes of the French Revolution. Be brief.", + "Define 'epistemology' and provide a simple example of its application.", + "Imagine a world without internet. How would daily life change?", + "Discuss the benefits of regular exercise on mental health.", + "Outline the basic steps of brewing a cup of coffee.", + "What is blockchain technology? Explain it in layman's terms.", + "Describe the typical diet of a prehistoric Neanderthal. Be speculative.", + "Write a short, optimistic sentence about the future of humanity.", + "Explain the concept of 'net neutrality' and its importance.", + "List two advantages and two disadvantages of remote work.", + "Summarize the Socratic method of questioning. Keep it concise.", + "Describe the feeling of awe when gazing at the night sky.", + "What is the Turing Test? How does it evaluate AI?", + "Write a short, humorous sentence about a self-aware toaster.", + "Explain the difference between weather and climate. Give examples.", + "List three common cybersecurity threats. How can they be mitigated?", + "Describe the sound of rain falling on a tin roof.", + "What are the primary functions of the human liver? Briefly explain.", + "Imagine a conversation between a dog and a cat. What do they discuss?", + "Explain the concept of 'opportunity cost' in economics.", + "List three famous landmarks in Rome. Name their historical significance.", + "Describe the taste of fresh-baked bread right out of the oven.", + "What is photosynthesis? Explain its importance to life.", + "Write a short, inspiring sentence about overcoming adversity.", + "Explain the difference between a simile and a metaphor.", + "List three essential tools for a basic woodworking project.", + "Describe the feeling of sand between your toes at the beach.", + "What is the scientific method? Outline its core steps.", + "Imagine a new color. Describe its appearance and emotional impact.", + "Explain the concept of 'supply and demand' in a market.", + "List three benefits of learning a new language.", + "Describe the smell of a pine forest after a summer rain.", + "What are black holes? Explain their formation briefly.", + "Write a short, whimsical sentence about a cloud shaped like an animal.", + "Explain the concept of 'biodiversity hotspots' and their importance.", + "List three major inventions from the Industrial Revolution.", + "Describe the sound of wind chimes on a breezy afternoon.", + "What is natural selection? Explain its role in evolution.", + "Imagine a city powered entirely by renewable energy. Describe it.", + "Explain the 'butterfly effect' in chaos theory. Give an example.", + "List three key features of a well-written short story.", + "Describe the feeling of excitement before a big adventure.", + "What is the difference between a virus and bacteria?", + "Write a short, thought-provoking question about the nature of reality.", + "Explain the concept of 'cognitive bias.' Provide an example.", + "List three common ingredients in traditional Italian pasta sauce.", + "Describe the view from the top of a very tall mountain.", + "What is the internet of things (IoT)? Give a simple example.", + "Imagine a world where dreams are shared. 
How would society change?", + "Explain the 'tragedy of the commons' in environmental science.", + "List three benefits of reading fiction regularly.", + "Describe the smell of freshly cut grass on a summer day.", + "What is artificial intelligence (AI)? Briefly explain its goals.", + "Write a short, funny sentence about a talking squirrel.", + "Explain the difference between a fact and an opinion.", + "List three different types of clouds and their typical weather.", + "Describe the feeling of warmth from a crackling fireplace.", + "What is climate change? List two of its major impacts.", + "Imagine a new musical instrument. Describe its sound and appearance.", + "Explain the concept of 'confirmation bias' and its dangers.", + "List three important figures from ancient Greek philosophy.", + "Describe the sound of waves crashing gently on a sandy shore.", + "What is a supernova? Briefly explain its significance.", + "Write a short, inspiring sentence about the power of knowledge.", + "Explain the difference between a planet and a dwarf planet.", + "List three common symptoms of a cold or flu.", + "Describe the taste of a ripe, juicy strawberry.", + "What is cryptography? Explain its basic purpose.", + "Imagine a day when gravity briefly disappears. What happens?", + "Explain the concept of 'digital footprint' and its implications.", + "List three advantages of living in a small town.", + "Describe the smell of freshly brewed coffee on a cold morning.", + "What is a galaxy? Briefly describe its components.", + "Write a short, whimsical sentence about a mischievous elf.", + "Explain the difference between 'weather' and 'climate' in simple terms.", + "List three common types of trees found in temperate forests.", + "Describe the feeling of peacefulness in a quiet library.", + "What is the ozone layer? Why is it important?", + "Imagine a machine that translates animal thoughts into human language.", + "Explain the concept of 'sustainable development.' Give an example.", + "List three benefits of volunteering in your community.", + "Describe the sound of a distant train whistle at night.", + "What is DNA? Briefly explain its function.", + "Write a short, humorous sentence about a clumsy robot.", + "Explain the difference between a democratic and autocratic government.", + "List three healthy food choices for a balanced breakfast.", + "Describe the smell of rain on hot pavement.", + "What is a neuron? Briefly explain its role.", + "Imagine a device that allows you to instantly learn any skill.", + "Explain the concept of 'cultural relativism' concisely.", + "List three major rivers in the world and their continents.", + "Describe the feeling of cool water on a hot day.", + "What is quantum computing? Explain it simply.", + "Write a short, optimistic sentence about technological progress.", + "Explain the difference between a hypothesis and a theory.", + "List three benefits of meditation for mental well-being.", + "Describe the sound of birds chirping at dawn.", + "What is the Big Bang theory? Explain its core idea.", + "Imagine a plant that glows in the dark. Describe its appearance.", + "Explain the concept of 'echo chambers' in online communities.", + "List three famous explorers and their discoveries.", + "Describe the taste of dark chocolate melting in your mouth.", + "What is a volcano? 
Briefly explain its eruption process.",
+    "Write a short, inspiring sentence about embracing change.",
+    "Explain the difference between empathy and sympathy.",
+    "List three common ingredients in a classic Margherita pizza.",
+    "Describe the feeling of soft grass under bare feet.",
+    "What is the human microbiome? Briefly explain its importance.",
+    "Imagine a world where everyone can fly. How does society adapt?",
+    "Explain the concept of 'confirmation bias' with an example.",
+    "List three major events in the American Civil War.",
+    "Describe the sound of laughter echoing through a room.",
+    "What is gravity? Briefly explain its effect.",
+    "Write a short, whimsical sentence about a talking teapot.",
+    "Explain the difference between 'data' and 'information.'",
+    "List three common symptoms of a common cold.",
+    "Describe the smell of freshly baked bread. Delicious!",
+    "What is a satellite? Briefly explain its purpose.",
+    "Imagine a personal AI assistant with a sense of humor.",
+    "Explain the 'Pareto principle' (80/20 rule) simply.",
+    "List three benefits of regular physical activity.",
+    "Describe the feeling of warmth from a cozy blanket.",
+    "What is virtual reality (VR)? Give an example.",
+    "Write a short, inspiring sentence about perseverance.",
+    "Explain the difference between 'fact' and 'opinion' clearly.",
+    "List three major types of renewable energy.",
+    "Describe the sound of ocean waves gently crashing.",
+    "What is a black hole? Explain its basic nature.",
+    "Imagine a tree that grows money. What would happen?",
+    "Explain the concept of 'cognitive dissonance' with an example.",
+    "List three benefits of mindful eating.",
+    "Describe the smell of rain on a hot summer night.",
+    "What is the Internet of Things? Give a simple example.",
+    "Write a short, funny sentence about a robot failing.",
+    "Explain the difference between 'argument' and 'debate.'",
+    "List three common spices used in Indian cuisine.",
+    "Describe the taste of a crisp, green apple.",
+    "What is a galaxy? Briefly describe its components.",
+    "Imagine a world without money. How would transactions occur?",
+    "Explain the concept of 'herd immunity' simply.",
+    "List three famous landmarks in Paris.",
+    "Describe the sound of crickets chirping on a summer evening.",
+    "What is a neuron? Explain its basic function.",
+    "Write a short, optimistic sentence about human ingenuity.",
+    "Explain the difference between 'privacy' and 'security.'",
+    "List three benefits of learning to play a musical instrument.",
+    "Describe the feeling of soft fur on a friendly pet.",
+    "What is a fossil? Briefly explain its formation.",
+    "Imagine a portal that leads to any past time period.",
+    "Explain the concept of 'supply chain' concisely.",
+    "List three common types of clouds.",
+    "Describe the smell of freshly cut grass.",
+    "What is climate change? List two causes.",
+    "Write a short, inspiring sentence about embracing challenges.",
+    "Explain the difference between 'equity' and 'equality.'",
+    "List three major historical events in the 20th century.",
+    "Describe the sound of children laughing joyfully.",
+    "What is photosynthesis? Briefly explain its importance.",
+    "Imagine a pet that can talk. What would it say?",
+    "Explain the 'halo effect' in social psychology.",
+    "List three benefits of getting enough sleep.",
+    "Describe the taste of sweet, ripe mango.",
+    "What is DNA? Explain its primary role.",
+    "Write a short, whimsical sentence about a flying pig.",
+    "Explain the difference between a 'virus' and a 'worm.'",
+    "List three common ingredients in a Caesar salad.",
+    "Describe the feeling of cold snow on your skin.",
+    "What is the Big Bang theory? Briefly explain it.",
+    "Imagine a robot designed for companionship. Describe it.",
+    "Explain the concept of 'confirmation bias' simply.",
+    "List three major world religions.",
+    "Describe the sound of gentle waves on a lake.",
+    "What is a black hole? Explain its basic properties.",
+    "Write a short, optimistic sentence about space exploration.",
+    "Explain the difference between 'qualitative' and 'quantitative' data.",
+    "List three benefits of spending time in nature.",
+    "Describe the smell of pine needles in a forest.",
+    "What is artificial intelligence? Briefly define it.",
+    "Imagine a device that lets you visit your own dreams.",
+    "Explain the 'Dunning-Kruger effect' with an example.",
+    "List three essential items for a hiking trip.",
+    "Describe the taste of bitter dark coffee.",
+    "What is blockchain technology? Explain its core principle.",
+    "Write a short, funny sentence about a clumsy wizard.",
+    "Explain the difference between 'responsibility' and 'accountability.'",
+    "List three common uses of renewable energy.",
+    "Describe the sound of a babbling brook.",
+    "What is the atmosphere? Briefly explain its layers.",
+    "Imagine a world without borders. How would it function?",
+    "Explain the concept of 'cognitive load' simply.",
+    "List three major components of a healthy diet.",
+    "Describe the feeling of warm sunshine on your face.",
+    "What is photosynthesis? Explain its main purpose.",
+    "Write a short, inspiring sentence about collective action.",
+    "Explain the difference between 'anarchy' and 'democracy.'",
+    "List three famous landmarks in New York City.",
+    "Describe the sound of sizzling bacon.",
+    "What is a galaxy? Briefly describe its formation.",
+    "Imagine a plant that can communicate. What would it say?",
+    "Explain the 'observer effect' in quantum physics (simplified).",
+    "List three benefits of learning basic first aid.",
+    "Describe the smell of cinnamon rolls baking.",
+    "What is the ozone layer? Briefly explain its function.",
+    "Write a short, whimsical sentence about a talking hat.",
+    "Explain the difference between 'data mining' and 'machine learning.'",
+    "List three common types of exercise.",
+    "Describe the feeling of cool silk on your skin.",
+    "What is gravity? Briefly explain its discovery.",
+    "Imagine a book that changes its story based on your mood.",
+    "Explain the concept of 'gamification' with an example.",
+    "List three major innovations of the Renaissance.",
+    "Describe the sound of rain tapping on a window.",
+    "What is a neuron? Explain its basic communication.",
+    "Write a short, optimistic sentence about peace.",
+    "Explain the difference between 'capitalism' and 'socialism.'",
+    "List three common materials used in construction.",
+    "Describe the taste of a sour lemon.",
+    "What is the Internet? Briefly explain its origins.",
+    "Imagine a world where colors have sounds. Describe a symphony.",
+    "Explain the 'butterfly effect' in daily life examples.",
+    "List three benefits of being a good listener.",
+    "Describe the smell of a campfire in the woods.",
+    "What is a volcano? Explain its dangers briefly.",
+    "Write a short, funny sentence about a dog driving.",
+    "Explain the difference between 'prejudice' and 'discrimination.'",
+    "List three major events of World War II.",
+    "Describe the sound of wind rustling through leaves.",
+    "What is DNA? Briefly explain its structure.",
+    "Imagine a future where robots do all housework.",
+    "Explain the concept of 'confirmation bias' in media.",
+    "List three benefits of daily journaling.",
+    "Describe the taste of sweet, creamy ice cream.",
+    "What is AI? Briefly define its subfields.",
+    "Write a short, inspiring sentence about embracing failure.",
+    "Explain the difference between 'gross' and 'net' income.",
+    "List three common types of human emotions.",
+    "Describe the feeling of walking barefoot on dewy grass.",
+    "What is a satellite? Explain its orbital mechanics simply.",
+    "Imagine a device that records and plays back smells.",
+    "Explain the 'fundamental attribution error' simply.",
+    "List three major characteristics of a desert ecosystem.",
+    "Describe the sound of church bells ringing.",
+    "What is the Big Bang theory? Explain its evidence.",
+    "Write a short, whimsical sentence about a dragon librarian.",
+    "Explain the difference between 'ethics' and 'morality.'",
+    "List three benefits of learning to code.",
+    "Describe the smell of baking cookies.",
+    "What is climate change? List its main human causes.",
+    "Imagine a world where plants can move. What happens?",
+    "Explain the concept of 'cognitive load' in learning.",
+    "List three essential tools for a gardener.",
+    "Describe the taste of a juicy, grilled steak.",
+    "What is photosynthesis? Briefly explain its chemical process.",
+    "Write a short, optimistic sentence about global cooperation.",
+    "Explain the difference between 'data privacy' and 'data security.'",
+    "List three common types of birds in your region.",
+    "Describe the feeling of deep relaxation after a long day.",
+    "What is a black hole? Explain its event horizon.",
+    "Imagine a pet rock that responds to your thoughts.",
+    "Explain the 'Pareto principle' in time management.",
+    "List three benefits of practicing gratitude.",
+    "Describe the sound of a distant thunderstorm approaching.",
+    "What is the Internet of Things? Give a practical example.",
+    "Write a short, funny sentence about a shy ghost.",
+    "Explain the difference between 'deductive' and 'inductive' reasoning.",
+    "List three common spices used in Mexican cuisine.",
+    "Describe the taste of fresh, sweet watermelon.",
+    "What is a galaxy? Briefly explain different types.",
+    "Imagine a world where gravity is optional. How would sports change?",
+    "Explain the concept of 'herd mentality' in groups.",
+    "List three famous scientists and their contributions.",
+    "Describe the sound of a crackling fire on a cold night.",
+    "What is a neuron? Explain its parts simply.",
+    "Write a short, inspiring sentence about the power of imagination.",
+    "Explain the difference between 'conscious' and 'subconscious.'",
+    "List three benefits of eating vegetables daily.",
+    "Describe the smell of a new book.",
+    "What is the ozone layer? Why is it depleting?",
+    "Imagine a personal assistant that anticipates your needs.",
+    "Explain the 'bystander effect' with a simple example.",
+    "List three major challenges facing humanity today.",
+    "Describe the taste of a tangy, juicy orange.",
+    "What is DNA? Explain its role in heredity.",
+    "Write a short, whimsical sentence about a talking fish.",
+    "Explain the difference between 'stress' and 'anxiety.'",
+    "List three common ingredients in a vegetable stir-fry.",
+    "Describe the feeling of warm sand on a sunny beach.",
+    "What is the Big Bang theory? Explain its expansion.",
+    "Imagine a device that allows you to control your dreams.",
+    "Explain the concept of 'confirmation bias' in political views.",
+    "List three benefits of mindfulness exercises.",
+    "Describe the sound of laughter filling a room.",
+    "What is gravity? Briefly explain its influence on tides.",
+    "Write a short, optimistic sentence about scientific discovery.",
+    "Explain the difference between 'inflation' and 'deflation.'",
+    "List three common types of trees in urban areas.",
+    "Describe the smell of freshly brewed coffee beans.",
+    "What is artificial intelligence? Give a real-world example.",
+    "Imagine a world where emotions are visible colors.",
+    "Explain the 'Dunning-Kruger effect' in everyday skills.",
+    "List three essential items for a picnic.",
+    "Describe the taste of a freshly baked cookie.",
+    "What is blockchain technology? Explain its security.",
+    "Write a short, funny sentence about a confused alien.",
+    "Explain the difference between 'theory' and 'law' in science.",
+    "List three common uses of solar energy.",
+    "Describe the sound of a distant ambulance siren.",
+    "What is the atmosphere? Briefly explain its composition.",
+    "Imagine a society where kindness is currency.",
+    "Explain the concept of 'cognitive bias' in decision-making.",
+    "List three major challenges of space travel.",
+    "Describe the feeling of fresh air after a storm.",
+    "What is photosynthesis? Explain its byproduct.",
+    "Write a short, inspiring sentence about courage.",
+    "Explain the difference between 'democracy' and 'republic.'",
+    "List three famous structures from ancient Egypt.",
+    "Describe the sound of ocean waves during a storm.",
+    "What is a black hole? Explain its singularity.",
+    "Imagine a pet that can change its form.",
+    "Explain the 'Pareto principle' in customer service.",
+    "List three benefits of outdoor activities.",
+    "Describe the smell of a damp forest floor.",
+    "What is the Internet of Things? Explain its connectivity.",
+    "Write a short, whimsical sentence about a unicorn scientist.",
+    "Explain the difference between 'objective' and 'subjective.'",
+    "List three common types of pollution.",
+    "Describe the taste of a hot, savory soup.",
+    "What is a galaxy? Briefly explain its life cycle.",
+    "Imagine a world where animals can vote.",
+    "Explain the concept of 'groupthink' in decision making.",
+    "List three famous composers of classical music.",
+    "Describe the sound of wind chimes in a gentle breeze.",
+    "What is a neuron? Briefly explain its signal transmission.",
+    "Write a short, optimistic sentence about human potential.",
+    "Explain the difference between 'copyright' and 'trademark.'",
+    "List three benefits of learning history.",
+    "Describe the feeling of accomplishment after finishing a task.",
+    "What is DNA? Explain its code for life.",
+    "Imagine a mirror that shows your past lives.",
+    "Explain the 'halo effect' in job interviews.",
+    "List three major historical figures from the American Revolution.",
+    "Describe the sound of leaves crunching underfoot.",
+    "What is climate change? List two solutions.",
+    "Write a short, funny sentence about a very slow snail.",
+    "Explain the difference between 'data encryption' and 'data backup.'",
+    "List three common types of vegetables.",
+    "Describe the smell of a freshly mown lawn.",
+    "What is a satellite? Explain its uses in communication.",
+    "Imagine a city built entirely underground.",
+    "Explain the 'bystander effect' in emergency situations.",
+    "List three benefits of having a positive attitude.",
+    "Describe the taste of a warm, gooey brownie.",
+    "What is AI? Briefly explain its learning process.",
+    "Write a short, inspiring sentence about unity.",
+    "Explain the difference between 'analogy' and 'metaphor.'",
+    "List three common types of fruits.",
+    "Describe the feeling of calm in a quiet forest.",
+    "What is the Big Bang theory? Explain its implications.",
+    "Imagine a world where thoughts are telepathic.",
+    "Explain the concept of 'confirmation bias' in science.",
+    "List three benefits of practicing mindfulness daily.",
+    "Describe the sound of a distant train passing by.",
+    "What is gravity? Briefly explain its effect on space-time.",
+    "Write a short, optimistic sentence about the future of education.",
+    "Explain the difference between 'explicit' and 'implicit' memory.",
+    "List three common forms of transportation.",
+    "Describe the smell of a warm, cozy bakery.",
+    "What is photosynthesis? Explain its role in ecosystems.",
+    "Imagine a pet that changes color with its mood.",
+    "Explain the 'Pareto principle' in sales.",
+    "List three benefits of critical thinking.",
+    "Describe the sound of a busy city street.",
+    "What is the Internet of Things? Explain its impact.",
+    "Write a short, whimsical sentence about a talking flower.",
+    "Explain the difference between 'absolute' and 'relative' poverty.",
+    "List three common types of renewable energy sources.",
+    "Describe the taste of a refreshing glass of water.",
+    "What is a black hole? Explain its event horizon simply.",
+    "Imagine a device that allows you to relive memories.",
+    "Explain the concept of 'cognitive dissonance' in choices.",
+    "List three famous rivers in South America.",
+    "Describe the sound of a gentle rainfall.",
+    "What is a neuron? Briefly explain its function in the brain.",
+    "Write a short, inspiring sentence about overcoming fear.",
+    "Explain the difference between 'passion' and 'hobby.'",
+    "List three common ingredients in a homemade sandwich.",
+    "Describe the feeling of warmth from a sunny window.",
+    "What is DNA? Explain its role in evolution.",
+    "Imagine a world where animals wear clothes.",
+    "Explain the 'halo effect' in consumer behavior.",
+    "List three major events of the Cold War.",
+    "Describe the sound of leaves rustling in a gentle breeze.",
+    "What is climate change? List its environmental impacts.",
+    "Write a short, funny sentence about a robot chef.",
+    "Explain the difference between 'router' and 'modem.'",
+    "List three common types of birds.",
+    "Describe the smell of fresh-cut pine wood.",
+    "What is a satellite? Explain its use in navigation.",
+    "Imagine a house that cleans itself automatically.",
+    "Explain the 'bystander effect' in a workplace setting.",
+    "List three benefits of practicing gratitude daily.",
+    "Describe the taste of sweet, juicy berries.",
+    "What is AI? Briefly explain its impact on jobs.",
+    "Write a short, optimistic sentence about innovation.",
+    "Explain the difference between 'simile' and 'analogy' concisely.",
+    "List three common types of desserts.",
+    "Describe the feeling of a cool breeze on a hot day.",
+    "What is the Big Bang theory? Explain its core idea simply.",
+    "Imagine a device that can read minds. What happens?",
+    "Explain the concept of 'confirmation bias' in social media.",
+    "List three benefits of learning a new skill.",
+    "Describe the sound of wind whistling through trees.",
+    "What is gravity? Briefly explain its role in planetary orbits.",
+    "Write a short, whimsical sentence about a playful cloud.",
+    "Explain the difference between 'qualitative' and 'quantitative' research.",
+    "List three common types of rocks.",
+    "Describe the smell of fresh coffee brewing.",
+    "What is photosynthesis? Explain its significance to plants.",
+    "Imagine a world where plants can sing.",
+    "Explain the 'Pareto principle' in personal finance.",
+    "List three benefits of engaging in creative activities.",
+    "Describe the sound of a distant waterfall.",
+    "What is the Internet of Things? Give a practical application.",
+    "Write a short, inspiring sentence about reaching goals.",
+    "Explain the difference between 'ethical' and 'legal.'",
+    "List three common forms of renewable energy.",
+    "Describe the taste of crunchy, salty popcorn.",
+    "What is a black hole? Briefly explain its formation process.",
+    "Imagine a pet robot that understands emotions.",
+    "Explain the concept of 'cognitive dissonance' in beliefs.",
+    "List three major mountain ranges in the world.",
+    "Describe the sound of soft piano music.",
+    "What is a neuron? Explain its role in thought.",
+    "Write a short, optimistic sentence about the power of community.",
+    "Explain the difference between 'accuracy' and 'precision.'",
+    "List three common ingredients in a traditional pizza.",
+    "Describe the feeling of warmth from a cozy fireplace.",
+    "What is DNA? Briefly explain its double helix.",
+    "Imagine a world where animals can build cities.",
+    "Explain the 'halo effect' in product marketing.",
+    "List three major historical periods.",
+    "Describe the sound of footsteps on a wooden floor.",
+    "What is climate change? List its economic impacts.",
+    "Write a short, funny sentence about a grumpy cat.",
+    "Explain the difference between 'phishing' and 'malware.'",
+    "List three common types of trees in forests.",
+    "Describe the smell of fresh laundry.",
+    "What is a satellite? Briefly explain its orbit.",
+    "Imagine a world where thoughts are visible.",
+    "Explain the 'bystander effect' in everyday life.",
+    "List three benefits of spending time alone.",
+    "Describe the taste of a hot, crispy french fry.",
+    "What is AI? Briefly explain its future potential.",
+    "Write a short, inspiring sentence about self-discovery.",
+    "Explain the difference between 'explicit' and 'implicit' bias.",
+    "List three common types of musical instruments.",
+    "Describe the feeling of coolness from a gentle breeze.",
+    "What is the Big Bang theory? Explain its cosmic implications.",
+    "Imagine a device that makes you invisible.",
+    "Explain the concept of 'confirmation bias' in relationships.",
+    "List three benefits of practicing gratitude daily.",
+    "Describe the sound of a bubbling stream.",
+    "What is gravity? Briefly explain its role in galaxies.",
+    "Write a short, optimistic sentence about human resilience.",
+    "Explain the difference between 'weather' and 'climate' in detail.",
+    "List three common types of birds in your neighborhood.",
+    "Describe the smell of baking bread from a distance.",
+    "What is photosynthesis? Explain its role in the food chain.",
+    "Imagine a plant that can cure any disease.",
+    "Explain the 'Pareto principle' in project management.",
+    "List three benefits of daily exercise.",
+    "Describe the sound of a distant, ringing bell.",
+    "What is the Internet of Things? Explain its security risks.",
+    "Write a short, whimsical sentence about a talking moon.",
+    "Explain the difference between 'analogue' and 'digital.'",
+    "List three common types of pasta.",
+    "Describe the taste of sweet, ripe berries.",
+    "What is a black hole? Explain its gravitational pull.",
+    "Imagine a pet that can grant wishes.",
+    "Explain the concept of 'cognitive dissonance' in consumerism.",
+    "List three major deserts in the world.",
+    "Describe the sound of laughter echoing through an empty hall.",
+    "What is a neuron? Briefly explain its electrical impulses.",
+    "Write a short, inspiring sentence about overcoming challenges.",
+    "Explain the difference between 'empathy' and 'sympathy' with examples.",
+    "List three common ingredients in a basic omelette.",
+    "Describe the feeling of relief after a stressful event.",
+    "What is DNA? Explain its role in cloning.",
+    "Imagine a world where animals can read.",
+    "Explain the 'halo effect' in education.",
+    "List three major events of the Roman Empire.",
+    "Describe the sound of gentle rain falling.",
+    "What is climate change? List its social impacts.",
+    "Write a short, funny sentence about a cat playing piano.",
+    "Explain the difference between 'software' and 'hardware.'",
+    "List three common types of fish.",
+    "Describe the smell of fresh linen.",
+    "What is a satellite? Explain its use in weather forecasting.",
+    "Imagine a device that can instantly transport you anywhere.",
+    "Explain the 'bystander effect' in online spaces.",
+    "List three benefits of having a strong support system.",
+    "Describe the taste of a warm, comforting bowl of soup.",
+    "What is AI? Briefly explain its ethical considerations.",
+    "Write a short, optimistic sentence about future energy.",
+    "Explain the difference between 'fact' and 'hypothesis.'",
+    "List three common types of cheese.",
+    "Describe the feeling of warmth from a sunny window.",
+    "What is the Big Bang theory? Briefly explain its origin.",
+    "Imagine a world where plants glow in response to music.",
+    "Explain the concept of 'confirmation bias' in political discussions.",
+    "List three benefits of regular reading.",
+    "Describe the sound of a distant church bell.",
+    "What is gravity? Briefly explain its impact on stars.",
+    "Write a short, whimsical sentence about a dancing teapot.",
+    "Explain the difference between 'democracy' and 'autocracy' in detail.",
+    "List three common types of forests.",
+    "Describe the smell of fresh cut flowers.",
+    "What is photosynthesis? Explain its role in oxygen production.",
+    "Imagine a pet that can talk to other animals.",
+    "Explain the 'Pareto principle' in personal development.",
+    "List three benefits of learning new languages.",
+    "Describe the sound of a chirping bird outside your window.",
+    "What is the Internet of Things? Briefly explain its privacy concerns.",
+    "Write a short, inspiring sentence about embracing diversity.",
+    "Explain the difference between 'supply' and 'demand' with examples.",
+    "List three common types of grains.",
+    "Describe the taste of a sweet, ripe peach.",
+    "What is a black hole? Explain its effect on light.",
+    "Imagine a robot designed for artistic creation.",
+    "Explain the concept of 'cognitive dissonance' in political affiliation.",
+    "List three major rivers in Asia.",
+    "Describe the sound of a gentle flowing river.",
+    "What is a neuron? Briefly explain its connections.",
+    "Write a short, optimistic sentence about space travel.",
+    "Explain the difference between 'efficiency' and 'effectiveness.'",
+    "List three common ingredients in a salad.",
+    "Describe the feeling of contentment after a good meal.",
+    "What is DNA? Explain its role in genetic engineering.",
+    "Imagine a world where trees can walk.",
+    "Explain the 'halo effect' in personal relationships.",
+    "List three major achievements of the space race.",
+    "Describe the sound of a distant foghorn.",
+    "What is climate change? List its potential solutions.",
+    "Write a short, funny sentence about a talking shoe.",
+    "Explain the difference between 'cloud computing' and 'local storage.'",
+    "List three common types of transportation in cities.",
+    "Describe the smell of fresh-baked cookies right out of the oven.",
+    "What is a satellite? Briefly explain its role in GPS.",
+    "Imagine a building that can reconfigure itself.",
+    "Explain the 'bystander effect' in online forums.",
+    "List three benefits of learning a musical instrument.",
+    "Describe the taste of a warm, sweet waffle.",
+    "What is AI? Briefly explain its societal impacts.",
+    "Write a short, inspiring sentence about collaboration.",
+    "Explain the difference between 'cybersecurity' and 'information security.'",
+    "List three common types of fruits in a tropical climate.",
+    "Describe the feeling of joy when seeing a loved one.",
+    "What is the Big Bang theory? Explain its timeline.",
+    "Imagine a device that lets you experience other people's memories.",
+    "Explain the concept of 'confirmation bias' in everyday life.",
+    "List three benefits of engaging in hobbies.",
+    "Describe the sound of a cat purring contentedly.",
+    "What is gravity? Briefly explain its influence on tides.",
+    "Write a short, whimsical sentence about a mischievous gnome.",
+    "Explain the difference between 'weather' and 'climate' in research.",
+    "List three common types of insects.",
+    "Describe the smell of a fresh rain shower.",
+    "What is photosynthesis? Explain its importance to animal life.",
+    "Imagine a plant that can clean all pollution.",
+    "Explain the 'Pareto principle' in customer satisfaction.",
+    "List three benefits of spending time in nature regularly.",
+    "Describe the sound of gentle waves on a pebble beach.",
+    "What is the Internet of Things? Explain its future implications.",
+    "Write a short, inspiring sentence about environmental protection.",
+    "Explain the difference between 'qualitative' and 'quantitative' research methods.",
+    "List three common types of metals.",
+    "Describe the taste of a sweet and tangy raspberry.",
+    "What is a black hole? Briefly explain its spaghettification.",
+    "Imagine a pet that can change its size.",
+    "Explain the concept of 'cognitive dissonance' in moral dilemmas.",
+    "List three major volcanoes in the world.",
+    "Describe the sound of wind rustling through tall grass.",
+    "What is a neuron? Explain its dendrites and axons.",
+    "Write a short, optimistic sentence about global peace.",
+    "Explain the difference between 'privacy' and 'anonymity.'",
+    "List three common ingredients in a basic soup.",
+    "Describe the feeling of warmth from a campfire.",
+    "What is DNA? Explain its role in cloning simply.",
+    "Imagine a world where dreams are prophecies.",
+    "Explain the 'halo effect' in brand perception.",
+    "List three major events of the Renaissance.",
+    "Describe the sound of a distant, echoing laugh.",
+    "What is climate change? List its long-term effects.",
+    "Write a short, funny sentence about a very smart dog.",
+    "Explain the difference between 'virus' and 'worm' concisely.",
+    "List three common types of trees in cold climates.",
+    "Describe the smell of freshly brewed tea.",
+    "What is a satellite? Explain its use in communication clearly.",
+    "Imagine a city powered by human thought.",
+    "Explain the 'bystander effect' in a public setting.",
+    "List three benefits of learning a new language as an adult.",
+    "Describe the taste of a hot, crispy pizza crust.",
+    "What is AI? Briefly explain its potential societal benefits.",
+    "Write a short, inspiring sentence about hope.",
+    "Explain the difference between 'empathy' and 'pity.'",
+    "List three common types of vegetables in a garden.",
+    "Describe the feeling of soft sand between your toes.",
+    "What is the Big Bang theory? Explain its implications for matter.",
+    "Imagine a device that lets you talk to animals.",
+    "Explain the concept of 'confirmation bias' in news consumption.",
+    "List three benefits of regular exercise on physical health.",
+    "Describe the sound of a gentle waterfall in a forest.",
+    "What is gravity? Briefly explain its impact on stars.",
+    "Write a short, whimsical sentence about a talking tree.",
+    "Explain the difference between 'fact' and 'belief.'",
+    "List three common types of flowers.",
+    "Describe the smell of a pine forest in winter.",
+    "What is photosynthesis? Explain its role in carbon cycle.",
+    "Imagine a plant that can grow any object.",
+    "Explain the 'Pareto principle' in personal productivity.",
+    "List three benefits of daily meditation.",
+    "Describe the sound of wind chimes on a calm evening.",
+    "What is the Internet of Things? Explain its impact on homes.",
+    "Write a short, inspiring sentence about sustainability.",
+    "Explain the difference between 'ethics' and 'law' concisely.",
+    "List three common types of birds in a city.",
+    "Describe the taste of sweet, creamy yogurt.",
+    "What is a black hole? Briefly explain its powerful gravity.",
+    "Imagine a pet robot that teaches you new skills.",
+    "Explain the concept of 'cognitive dissonance' in relationships.",
+    "List three major seas in the world.",
+    "Describe the sound of gentle waves on a lake shore.",
+    "What is a neuron? Explain its role in memory.",
+    "Write a short, optimistic sentence about the future of medicine.",
+    "Explain the difference between 'accuracy' and 'validity' in research.",
+    "List three common ingredients in a healthy smoothie.",
+    "Describe the feeling of warmth from a hot bath.",
+    "What is DNA? Explain its role in evolution simply.",
+    "Imagine a world where animals can create art.",
+    "Explain the 'halo effect' in social interactions.",
+    "List three major events of the industrial revolution.",
+    "Describe the sound of gentle rain on a rooftop.",
+    "What is climate change? List its impact on biodiversity.",
+    "Write a short, funny sentence about a clumsy robot chef.",
+    "Explain the difference between 'URL' and 'IP address.'",
+    "List three common types of landforms.",
+    "Describe the smell of freshly baked bread from a bakery.",
+    "What is a satellite? Explain its use in remote sensing.",
+    "Imagine a device that can translate any language instantly.",
+    "Explain the 'bystander effect' in online gaming.",
+    "List three benefits of practicing positive affirmations.",
+    "Describe the taste of a juicy, sweet apple.",
+    "What is AI? Briefly explain its impact on industries.",
+    "Write a short, inspiring sentence about overcoming obstacles.",
+    "Explain the difference between 'debt' and 'equity.'",
+    "List three common types of natural disasters.",
+    "Describe the feeling of softness of a fluffy cloud.",
+    "What is the Big Bang theory? Explain its implications for life.",
+    "Imagine a world where time travel is possible.",
+    "Explain the concept of 'confirmation bias' in scientific research.",
+    "List three benefits of daily gratitude journaling.",
+    "Describe the sound of wind howling in a storm.",
+    "What is gravity? Briefly explain its role in black holes.",
+    "Write a short, whimsical sentence about a talking cloud.",
+    "Explain the difference between 'weather' and 'climate' in terms of scale.",
+    "List three common types of reptiles.",
+    "Describe the smell of freshly ground coffee.",
+    "What is photosynthesis? Explain its process in plants.",
+    "Imagine a plant that can generate electricity.",
+    "Explain the 'Pareto principle' in customer service (briefly).",
+    "List three benefits of volunteering for a cause.",
+    "Describe the sound of ocean waves crashing against rocks.",
+    "What is the Internet of Things? Briefly explain its impact on cities.",
+    "Write a short, inspiring sentence about unity and progress.",
+    "Explain the difference between 'ethical' and 'moral' choices.",
+    "List three common types of birds of prey.",
+    "Describe the taste of a warm, comforting cup of tea.",
+    "What is a black hole? Briefly explain its properties.",
+    "Imagine a pet robot that cleans your house.",
+    "Explain the concept of 'cognitive dissonance' in voting.",
+    "List three major lakes in North America.",
+    "Describe the sound of a distant, mournful owl hoot.",
+    "What is a neuron? Briefly explain its role in movement.",
+    "Write a short, optimistic sentence about human kindness.",
+    "Explain the difference between 'supply chain' and 'logistics.'",
+    "List three common ingredients in a homemade pizza.",
+    "Describe the feeling of exhilaration after achieving a goal.",
+    "What is DNA? Explain its role in forensic science.",
+    "Imagine a world where trees can talk.",
+    "Explain the 'halo effect' in job interviews again.",
+    "List three major battles of the American Civil War.",
+    "Describe the sound of gentle rain on a forest floor.",
+    "What is climate change? List its impact on oceans.",
+    "Write a short, funny sentence about a very sleepy cat.",
+    "Explain the difference between 'CPU' and 'GPU.'",
+    "List three common types of amphibians.",
+    "Describe the smell of warm, fresh cookies.",
+    "What is a satellite? Briefly explain its use in disaster relief.",
+    "Imagine a device that lets you control your dreams.",
+    "Explain the 'bystander effect' in online harassment.",
+    "List three benefits of learning critical thinking skills.",
+    "Describe the taste of sweet, ripe banana.",
+    "What is AI? Briefly explain its challenges.",
+    "Write a short, inspiring sentence about embracing challenges.",
+    "Explain the difference between 'data' and 'information' again.",
+    "List three common types of desserts in American cuisine.",
+    "Describe the feeling of coolness from an air conditioner.",
+    "What is the Big Bang theory? Explain its evidence simply.",
+    "Imagine a robot designed for medical assistance.",
+    "Explain the concept of 'confirmation bias' in personal beliefs.",
+    "List three benefits of engaging in artistic endeavors.",
+    "Describe the sound of a distant thunderstorm.",
+    "What is gravity? Briefly explain its role in planetary formation.",
+    "Write a short, whimsical sentence about a flying dog.",
+    "Explain the difference between 'analogy' and 'simile' in rhetoric.",
+    "List three common types of fish in freshwater.",
+    "Describe the smell of fresh laundry drying outside.",
+    "What is photosynthesis? Explain its importance to life concisely.",
+    "Imagine a plant that grows diamonds.",
+    "Explain the 'Pareto principle' in software development.",
+    "List three benefits of self-care practices.",
+    "Describe the sound of a distant train horn at night.",
+    "What is the Internet of Things? Explain its impact on industries.",
+    "Write a short, inspiring sentence about building a better future.",
+    "Explain the difference between 'risk' and 'uncertainty.'",
+    "List three common types of cereals.",
+    "Describe the taste of a crispy, salty potato chip.",
+    "What is a black hole? Briefly explain its escape velocity.",
+    "Imagine a pet that can teleport.",
+    "Explain the concept of 'cognitive dissonance' in political actions.",
+    "List three major rivers in Europe.",
+    "Describe the sound of gentle classical music.",
+    "What is a neuron? Explain its basic communication in a network.",
+    "Write a short, optimistic sentence about technological advancements.",
+    "Explain the difference between 'privacy policy' and 'terms of service.'",
+    "List three common ingredients in a traditional curry.",
+    "Describe the feeling of peacefulness in a quiet garden.",
+    "What is DNA? Explain its role in genetic diseases.",
+    "Imagine a world where animals have human-like intelligence.",
+    "Explain the 'halo effect' in customer reviews.",
+    "List three major historical figures from the American West.",
+    "Describe the sound of rain tapping on a window pane.",
+    "What is climate change? List its impact on human health.",
+    "Write a short, funny sentence about a talking cat detective.",
+    "Explain the difference between 'URL' and 'domain name.'",
+    "List three common types of birds that migrate.",
+    "Describe the smell of warm, fresh baked pie.",
+    "What is a satellite? Briefly explain its function in space.",
+    "Imagine a device that lets you fast-forward time.",
+    "Explain the 'bystander effect' in online forums, briefly.",
+    "List three benefits of practicing gratitude daily.",
+    "Describe the taste of a sweet, juicy pear.",
+    "What is AI? Briefly explain its role in automation.",
+    "Write a short, inspiring sentence about the power of belief.",
+    "Explain the difference between 'ethical' and 'legal' implications.",
+    "List three common types of flowers that bloom in spring.",
+    "Describe the feeling of warmth from a sunny beach.",
+    "What is the Big Bang theory? Explain its expansion and cooling.",
+    "Imagine a robot designed for artistic creation and expression.",
+    "Explain the concept of 'confirmation bias' in political news.",
+    "List three benefits of engaging in community service.",
+    "Describe the sound of a distant dog barking.",
+    "What is gravity? Briefly explain its effect on space and time.",
+    "Write a short, whimsical sentence about a magical library.",
+    "Explain the difference between 'gross' and 'net' profit.",
+    "List three common types of insects found in gardens.",
+    "Describe the smell of fresh cut grass on a summer morning.",
+    "What is photosynthesis? Explain its role in ecosystems concisely.",
+    "Imagine a plant that can talk.",
+    "Explain the 'Pareto principle' in customer support.",
+    "List three benefits of practicing self-compassion.",
+    "Describe the sound of waves gently lapping at the shore.",
+    "What is the Internet of Things? Explain its impact on privacy.",
+    "Write a short, inspiring sentence about resilience in adversity.",
+    "Explain the difference between 'cybersecurity' and 'network security.'",
+    "List three common types of fish found in oceans.",
+    "Describe the taste of a hot, savory slice of pizza.",
+    "What is a black hole? Briefly explain its effect on light.",
+    "Imagine a pet that can change its color at will.",
+    "Explain the concept of 'cognitive dissonance' in personal choices.",
+    "List three major deserts in Africa.",
+    "Describe the sound of a gentle lullaby.",
+    "What is a neuron? Explain its basic function in the brain.",
+    "Write a short, optimistic sentence about the future of humanity.",
+    "Explain the difference between 'empathy' and 'sympathy' clearly.",
+    "List three common ingredients in a classic sandwich.",
+    "Describe the feeling of calm in a quiet, secluded spot.",
+    "What is DNA? Explain its role in inheritance briefly.",
+    "Imagine a world where animals have super powers.",
+    "Explain the 'halo effect' in social perception.",
+    "List three major events of the Middle Ages.",
+    "Describe the sound of a distant, lonely owl hoot.",
+    "What is climate change? List its impact on polar regions.",
+    "Write a short, funny sentence about a robot losing its keys.",
+    "Explain the difference between 'hard drive' and 'RAM.'",
+    "List three common types of vegetables you can grow at home.",
+    "Describe the smell of warm, fresh rain on pavement.",
+    "What is a satellite? Briefly explain its use in climate monitoring.",
+    "Imagine a device that lets you explore other planets.",
+    "Explain the 'bystander effect' in online forums, with context.",
+    "List three benefits of learning basic cooking skills.",
+    "Describe the taste of a crispy, fresh salad.",
+    "What is AI? Briefly explain its role in scientific research.",
+    "Write a short, inspiring sentence about the power of imagination.",
+    "Explain the difference between 'fact' and 'opinion' clearly.",
+    "List three common types of nuts.",
+    "Describe the feeling of warmth from a cozy sweater.",
+    "What is the Big Bang theory? Explain its implications for matter.",
+    "Imagine a robot designed for companionship and learning.",
+    "Explain the concept of 'confirmation bias' in news consumption.",
+    "List three benefits of engaging in artistic endeavors regularly.",
+    "Describe the sound of a gentle breeze through tree leaves.",
+    "What is gravity? Briefly explain its effect on celestial bodies.",
+    "Write a short, whimsical sentence about a talking book.",
+    "Explain the difference between 'qualitative' and 'quantitative' data analysis.",
+    "List three common types of rocks found in sedimentary layers.",
+    "Describe the smell of fresh-baked bread from a bustling bakery.",
+    "What is photosynthesis? Explain its significance to all life.",
+    "Imagine a plant that can sing, emitting beautiful melodies.",
+    "Explain the 'Pareto principle' in personal finance with examples.",
+    "List three benefits of self-care and relaxation.",
+    "Describe the sound of gentle waves on a sandy beach.",
+    "What is the Internet of Things? Explain its impact on health.",
+    "Write a short, inspiring sentence about embracing new beginnings.",
+    "Explain the difference between 'ethical' and 'moral' decision making.",
+    "List three common types of birds that sing beautifully.",
+    "Describe the taste of sweet, juicy watermelon on a hot day.",
+    "What is a black hole? Briefly explain its event horizon again.",
+    "Imagine a pet robot that understands and responds to emotions.",
+    "Explain the concept of 'cognitive dissonance' in consumer choices.",
+    "List three major mountain ranges in North America.",
+    "Describe the sound of distant thunder rumbling.",
+    "What is a neuron? Explain its role in sensory perception.",
+    "Write a short, optimistic sentence about the future of clean energy.",
+    "Explain the difference between 'efficiency' and 'productivity.'",
+    "List three common ingredients in a traditional stir-fry.",
+    "Describe the feeling of accomplishment after a challenging workout.",
+    "What is DNA? Explain its role in genetic engineering simply.",
+    "Imagine a world where trees can walk and talk.",
+    "Explain the 'halo effect' in job candidate evaluation.",
+    "List three major historical figures from the American Civil War.",
+    "Describe the sound of gentle rain on a canvas tent.",
+    "What is climate change? List its impact on water resources.",
+    "Write a short, funny sentence about a very clumsy wizard.",
+    "Explain the difference between 'URL' and 'web address.'",
+    "List three common types of mammals in your region.",
+    "Describe the smell of a warm, inviting fireplace.",
+    "What is a satellite? Briefly explain its role in communication networks.",
+    "Imagine a device that allows you to control the weather.",
+    "Explain the 'bystander effect' in emergency situations, simply.",
+    "List three benefits of regular physical activity for health.",
+    "Describe the taste of a sweet, tangy orange.",
+    "What is AI? Briefly explain its role in autonomous vehicles.",
+    "Write a short, inspiring sentence about finding inner strength.",
+    "Explain the difference between 'empathy' and 'compassion.'",
+    "List three common types of fruits in temperate climates.",
+    "Describe the feeling of coolness from a mountain breeze.",
+    "What is the Big Bang theory? Explain its core concepts.",
+    "Imagine a robot designed for artistic expression.",
+    "Explain the concept of 'confirmation bias' in political discussions.",
+    "List three benefits of daily mindfulness meditation.",
+    "Describe the sound of a distant bird singing at dawn.",
+    "What is gravity? Briefly explain its effect on planetary motion.",
+    "Write a short, whimsical sentence about a talking shoe.",
+    "Explain the difference between 'gross' and 'net' income for taxes.",
+    "List three common types of trees in deciduous forests.",
+    "Describe the smell of freshly brewed coffee in the morning.",
+    "What is photosynthesis? Explain its importance to food production.",
+    "Imagine a plant that grows currency.",
+    "Explain the 'Pareto principle' in software testing.",
+    "List three benefits of practicing gratitude for well-being.",
+    "Describe the sound of a distant, gentle wind chimes.",
+    "What is the Internet of Things? Explain its smart home applications.",
+    "Write a short, inspiring sentence about hope for the future.",
+    "Explain the difference between 'ethics' and 'morality' in practice.",
+    "List three common types of birds that live near water.",
+    "Describe the taste of a warm, comforting bowl of oatmeal.",
+    "What is a black hole? Briefly explain its accretion disk.",
+    "Imagine a pet robot that can play any musical instrument.",
+    "Explain the concept of 'cognitive dissonance' in purchasing decisions.",
+    "List three major rivers in Asia and their countries.",
+    "Describe the sound of a gentle stream flowing over rocks.",
+    "What is a neuron? Explain its role in learning and memory.",
+    "Write a short, optimistic sentence about global collaboration.",
+    "Explain the difference between 'efficiency' and 'effectiveness' in projects.",
+    "List three common ingredients in a classic pasta dish.",
+    "Describe the feeling of contentment after a satisfying meal.",
+    "What is DNA? Explain its role in genetic counseling.",
+    "Imagine a world where animals have a universal language.",
+    "Explain the 'halo effect' in celebrity endorsements.",
+    "List three major wars in human history.",
+    "Describe the sound of gentle rain falling on leaves.",
+    "What is climate change? List its impact on ecosystems.",
+    "Write a short, funny sentence about a robot that loves disco.",
+    "Explain the difference between 'HTTPS' and 'HTTP.'",
+    "List three common types of wildflowers.",
+    "Describe the smell of freshly cut grass on a summer evening.",
+    "What is a satellite? Briefly explain its use in defense.",
+    "Imagine a device that allows you to visit alternate realities.",
+    "Explain the 'bystander effect' in a school setting.",
+    "List three benefits of engaging in creative writing.",
+    "Describe the taste of a sweet, juicy pineapple.",
+    "What is AI? Briefly explain its impact on creativity.",
+    "Write a short, inspiring sentence about the power of kindness.",
+    "Explain the difference between 'fact' and 'theory' in science.",
+    "List three common types of desserts around the world.",
+    "Describe the feeling of cool water on your feet.",
+    "What is the Big Bang theory? Explain its expansion of space.",
+    "Imagine a robot designed for exploring deep space.",
+    "Explain the concept of 'confirmation bias' in scientific discovery.",
+    "List three benefits of daily gratitude practice.",
+    "Describe the sound of a distant train whistle blowing.",
+    "What is gravity? Briefly explain its role in planet formation.",
+    "Write a short, whimsical sentence about a talking pillow.",
+    "Explain the difference between 'qualitative' and 'quantitative' data collection.",
+    "List three common types of metamorphic rocks.",
+    "Describe the smell of fresh-baked apple pie.",
+    "What is photosynthesis? Explain its byproduct, oxygen.",
+    "Imagine a plant that provides infinite energy.",
+    "Explain the 'Pareto principle' in personal relationships.",
+    "List three benefits of practicing positive self-talk.",
+    "Describe the sound of gentle waves lapping on a wooden dock.",
+    "What is the Internet of Things? Explain its impact on industry.",
+    "Write a short, inspiring sentence about pursuing your dreams.",
+    "Explain the difference between 'ethical' and 'legal' boundaries.",
+    "List three common types of birds that sing at night.",
+    "Describe the taste of warm, creamy mashed potatoes.",
+    "What is a black hole? Briefly explain its intense gravity.",
+    "Imagine a pet robot that helps with homework.",
+    "Explain the concept of 'cognitive dissonance' in health choices.",
+    "List three major lakes in Africa.",
+    "Describe the sound of a gentle breeze through tall trees.",
+    "What is a neuron? Explain its basic function in the nervous system.",
+    "Write a short, optimistic sentence about the power of learning.",
+    "Explain the difference between 'privacy' and 'confidentiality.'",
+    "List three common ingredients in a basic chili.",
+    "Describe the feeling of warmth from a sunny winter day.",
+    "What is DNA? Explain its role in inherited traits.",
+    "Imagine a world where animals build complex structures.",
+    "Explain the 'halo effect' in leadership.",
+    "List three major historical figures from World War II.",
+    "Describe the sound of gentle rain falling on a metal roof.",
+    "What is climate change? List its impact on agriculture.",
+    "Write a short, funny sentence about a sleepy, talking cloud.",
+    "Explain the difference between 'broadband' and 'dial-up.'",
+    "List three common types of mammals in cold climates.",
+    "Describe the smell of freshly laundered clothes.",
+    "What is a satellite? Briefly explain its use in scientific research.",
+    "Imagine a device that allows you to communicate with plants.",
+    "Explain the 'bystander effect' in online forums, with context.",
+    "List three benefits of learning basic coding skills.",
+    "Describe the taste of a tangy, sweet grapefruit.",
+    "What is AI? Briefly explain its role in healthcare.",
+    "Write a short, inspiring sentence about the joy of discovery.",
+    "Explain the difference between 'simile' and 'analogy' in detail.",
+    "List three common types of pastries.",
+    "Describe the feeling of cool water splashing on your face.",
+    "What is the Big Bang theory? Explain its implications for energy.",
+    "Imagine a robot designed for space exploration and construction.",
+    "Explain the concept of 'confirmation bias' in personal decisions.",
+    "List three benefits of regular meditation practice.",
+    "Describe the sound of a distant owl hooting softly.",
+    "What is gravity? Briefly explain its role in stellar evolution.",
+    "Write a short, whimsical sentence about a talking telescope.",
+    "Explain the difference between 'accuracy' and 'precision' in measurement.",
+    "List three common types of sedimentary rocks.",
+    "Describe the smell of fresh baked cinnamon rolls.",
+    "What is photosynthesis? Explain its light-dependent reactions.",
+    "Imagine a plant that can predict the future.",
+    "Explain the 'Pareto principle' in time management with examples.",
+    "List three benefits of engaging in outdoor recreation.",
+    "Describe the sound of ocean waves receding from the shore.",
+    "What is the Internet of Things? Explain its impact on transportation.",
+    "Write a short, inspiring sentence about environmental stewardship.",
+    "Explain the difference between 'ethics' and 'values.'",
+    "List three common types of trees found in tropical regions.",
+    "Describe the taste of warm, creamy chocolate pudding.",
+    "What is a black hole? Briefly explain its formation from stars.",
+    "Imagine a pet robot that helps you stay organized.",
+    "Explain the concept of 'cognitive dissonance' in societal norms.",
+    "List three major rivers in South America and their countries.",
+    "Describe the sound of a gentle wind whispering through leaves.",
+    "What is a neuron? Explain its role in complex thought.",
+    "Write a short, optimistic sentence about human ingenuity.",
+    "Explain the difference between 'supply' and 'demand' curves.",
+    "List three common ingredients in a classic omelette.",
+    "Describe the feeling of peacefulness in a quiet, secluded forest.",
+    "What is DNA? Explain its role in biotechnology.",
+    "Imagine a world where animals can invent things.",
+    "Explain the 'halo effect' in political campaigns.",
+    "List three major battles of World War I.",
+    "Describe the sound of gentle rain falling on an umbrella.",
+    "What is climate change? List its impact on marine life.",
+    "Write a short, funny sentence about a robot afraid of dust.",
+    "Explain the difference between 'HTTP' and 'HTTPS' simply.",
+    "List three common types of amphibians in rainforests.",
+    "Describe the smell of fresh baked apple crisp.",
+    "What is a satellite? Briefly explain its use in disaster monitoring.",
+    "Imagine a device that lets you explore historical events.",
+    "Explain the 'bystander effect' in a digital community.",
+    "List three benefits of continuous learning.",
+    "Describe the taste of a sweet, juicy plum.",
+    "What is AI? Briefly explain its role in education.",
+    "Write a short, inspiring sentence about the power of unity.",
+    "Explain the difference between 'debt' and 'credit.'",
+    "List three common types of natural resources.",
+    "Describe the feeling of softness of a warm blanket.",
+    "What is the Big Bang theory? Explain its earliest moments.",
+    "Imagine a robot designed for emergency rescue.",
+    "Explain the concept of 'confirmation bias' in legal judgments.",
+    "List three benefits of daily exercise for mental health.",
+    "Describe the sound of a distant bell tolling.",
+    "What is gravity? Briefly explain its role in galaxy clusters.",
+    "Write a short, whimsical sentence about a talking flowerpot.",
+    "Explain the difference between 'explicit' and 'implicit' memory briefly.",
+    "List three common types of musical genres.",
+    "Describe the smell of freshly brewed chamomile tea.",
+    "What is photosynthesis? Explain its carbon dioxide absorption.",
+    "Imagine a plant that can heal wounds instantly.",
+    "Explain the 'Pareto principle' in sales performance.",
+    "List three benefits of practicing self-awareness.",
+    "Describe the sound of gentle waves on a rocky shore.",
+    "What is the Internet of Things? Explain its impact on healthcare.",
+    "Write a short, inspiring sentence about overcoming adversity.",
+    "Explain the difference between 'ethical' and 'moral' reasoning.",
+    "List three common types of birds that live in deserts.",
+    "Describe the taste of a warm, comforting bowl of soup on a cold day.",
+    "What is a black hole? Briefly explain its gravitational lensing.",
+    "Imagine a pet robot that helps with cooking.",
+    "Explain the concept of 'cognitive dissonance' in personal values.",
+    "List three major lakes in Europe.",
+    "Describe the sound of a gentle breeze through an open window.",
+    "What is a neuron? Explain its role in motor control.",
+    "Write a short, optimistic sentence about finding solutions.",
+    "Explain the difference between 'supply chain management' and 'logistics.'",
+    "List three common ingredients in a basic vegetable soup.",
+    "Describe the feeling of excitement before a new journey.",
+    "What is DNA? Explain its role in gene therapy.",
+    "Imagine a world where animals can design buildings.",
+    "Explain the 'halo effect' in classroom settings.",
+    "List three major figures from ancient Roman history.",
+    "Describe the sound of gentle rain on a tent roof.",
+    "What is climate change? List its impact on coastal areas.",
+    "Write a short, funny sentence about a robot that loves dancing.",
+    "Explain the difference between 'Wi-Fi' and 'Bluetooth.'",
+    "List three common types of reptiles in tropical climates.",
+    "Describe the smell of warm, freshly baked cookies.",
+    "What is a satellite? Briefly explain its use in global positioning.",
+    "Imagine a device that allows you to experience other people's dreams.",
+    "Explain the 'bystander effect' in online gaming forums.",
+    "List three benefits of learning mindfulness techniques.",
+    "Describe the taste of a hot, crispy churro.",
+    "What is AI? Briefly explain its role in cybersecurity.",
+    "Write a short, inspiring sentence about courage and determination.",
+    "Explain the difference between 'fact' and 'misinformation.'",
+    "List three common types of berries.",
+    "Describe the feeling of warmth from a sunny afternoon.",
+    "What is the Big Bang theory? Explain its implications for the universe.",
+    "Imagine a robot designed for companionship and emotional support.",
+    "Explain the concept of 'confirmation bias' in legal cases.",
+    "List three benefits of daily gratitude journaling, concisely.",
+    "Describe the sound of wind whispering through tall grass.",
+    "What is gravity? Briefly explain its role in galaxy formation.",
+    "Write a short, whimsical sentence about a mischievous talking cat.",
+    "Explain the difference between 'qualitative' and 'quantitative' data analysis in research.",
+    "List three common types of igneous rocks.",
+    "Describe the smell of fresh, damp earth after rain.",
+    "What is photosynthesis? Explain its dependence on sunlight.",
+    "Imagine a plant that can communicate telepathically.",
+    "Explain the 'Pareto principle' in customer relationships.",
+    "List three benefits of spending time in nature for mental health.",
+    "Describe the sound of gentle waves on a sandy beach.",
+    "What is the Internet of Things? Explain its impact on cities.",
+    "Write a short, inspiring sentence about global cooperation and peace.",
+    "Explain the difference between 'ethical' and 'moral' dilemmas.",
+    "List three common types of birds that live in mountains.",
+    "Describe the taste of warm, sweet maple syrup.",
+    "What is a black hole? Briefly explain its powerful gravitational pull.",
+    "Imagine a pet robot that can play any sport.",
+    "Explain the concept of 'cognitive dissonance' in ethical decisions.",
+    "List three major deserts in South America.",
+    "Describe the sound of a distant, joyful song.",
+    "What is a neuron? Explain its role in sensory processing.",
+    "Write a short, optimistic sentence about the future of exploration.",
+    "Explain the difference between 'privacy' and 'confidentiality' in data.",
+    "List three common ingredients in a basic sandwich.",
+    "Describe the feeling of peace in a quiet, secluded forest.",
+    "What is DNA? Explain its role in genetic engineering processes.",
+    "Imagine a world where animals have their own technology.",
+    "Explain the 'halo effect' in product branding.",
+    "List three major historical figures from the Civil Rights Movement.",
+    "Describe the sound of gentle rain on a windowpane.",
+    "What is climate change? List its impact on animal habitats.",
+    "Write a short, funny sentence about a robot trying to bake.",
+    "Explain the difference between 'cloud storage' and 'local storage.'",
+    "List three common types of mammals found in forests.",
+    "Describe the smell of warm, fresh bread from a bakery.",
+    "What is a satellite? Briefly explain its use in environmental monitoring.",
+    "Imagine a device that allows you to experience the future.",
+    "Explain the 'bystander effect' in online harassment, with nuance.",
+    "List three benefits of learning basic programming skills.",
+    "Describe the taste of a crisp, refreshing cucumber.",
+    "What is AI? Briefly explain its role in art generation.",
+    "Write a short, inspiring sentence about finding your purpose.",
+    "Explain the difference between 'deductive' and 'inductive' reasoning clearly.",
+    "List three common types of musical instruments in an orchestra.",
+    "Describe the feeling of lightness after a good laugh.",
+    "What is the Big Bang theory? Explain its implications for life's origin.",
+    "Imagine a robot designed for creative writing.",
+    "Explain the concept of 'confirmation bias' in personal interactions.",
+    "List three benefits of engaging in team sports.",
+    "Describe the sound of a distant bird singing melodiously.",
+    "What is gravity? Briefly explain its effect on planetary orbits.",
+    "Write a short, whimsical sentence about a talking chessboard.",
+    "Explain the difference between 'accuracy' and 'reliability' in research.",
+    "List three common types of metamorphic rocks, briefly.",
+    "Describe the smell of fresh soil after rain.",
+    "What is photosynthesis? Explain its light-independent reactions.",
+    "Imagine a plant that can grow any material.",
+    "Explain the 'Pareto principle' in business operations.",
+    "List three benefits of practicing self-care consistently.",
+    "Describe the sound of gentle waves lapping on a tranquil lake.",
+    "What is the Internet of Things? Explain its impact on smart cities.",
+    "Write a short, inspiring sentence about embracing positive change.",
+    "Explain the difference between 'ethical' and 'legal' obligations.",
+    "List three common types of birds that are migratory.",
+    "Describe the taste of warm, gooey cheese on pizza.",
+    "What is a black hole? Briefly explain its gravitational singularity.",
+    "Imagine a pet robot that can teach you languages.",
+    "Explain the concept of 'cognitive dissonance' in moral choices.",
+    "List three major mountain ranges in Europe.",
+    "Describe the sound of distant, echoing laughter.",
+    "What is a neuron? Explain its role in decision-making.",
+    "Write a short, optimistic sentence about the power of imagination.",
+    "Explain the difference between 'supply' and 'demand' in economics.",
+    "List three common ingredients in a simple salad dressing.",
+    "Describe the feeling of warmth from a gentle embrace.",
+    "What is DNA? Explain its role in genetic diseases, briefly.",
+    "Imagine a world where animals can build civilizations.",
+    "Explain the 'halo effect' in social interactions (briefly).",
+    "List three major historical periods of artistic movements.",
+    "Describe the sound of gentle rain falling on a forest floor.",
+    "What is climate change? List its impact on human migration.",
+    "Write a short, funny sentence about a robot trying to dance.",
+    "Explain the difference between 'LAN' and 'WAN.'",
+    "List three common types of fish in saltwater.",
+    "Describe the smell of warm, fresh cinnamon buns.",
+    "What is a satellite? Briefly explain its use in reconnaissance.",
+    "Imagine a device that lets you explore fictional worlds.",
+    "Explain the 'bystander effect' in online communities, with examples.",
+    "List three benefits of learning public speaking skills.",
+    "Describe the taste of a sweet, ripe mango.",
+    "What is AI? Briefly explain its role in scientific discovery.",
+    "Write a short, inspiring sentence about the pursuit of knowledge.",
+    "Explain the difference between 'fact' and 'assumption.'",
+    "List three common types of vegetables that grow underground.",
+    "Describe the feeling of cool water flowing over your hands.",
+    "What is the Big Bang theory? Explain its cosmic evolution.",
+    "Imagine a robot designed for environmental cleanup.",
+    "Explain the concept of 'confirmation bias' in political views.",
+    "List three benefits of daily journaling for mental clarity.",
+    "Describe the sound of wind rustling through dry leaves.",
+    "What is gravity? Briefly explain its effect on galaxy formation.",
+    "Write a short, whimsical sentence about a talking pencil.",
+    "Explain the difference between 'weather' and 'climate' in scientific terms.",
+    "List three common types of reptiles found in deserts.",
+    "Describe the smell of fresh cut lumber.",
+    "What is photosynthesis? Explain its efficiency in different plants.",
+    "Imagine a plant that can telekinetically move objects.",
+    "Explain the 'Pareto principle' in personal productivity, concisely.",
+    "List three benefits of spending time in nature for overall health.",
+    "Describe the sound of gentle waves on a pebble beach, peacefully.",
+    "What is the Internet of Things? Explain its impact on smart homes.",
+    "Write a short, inspiring sentence about the importance of kindness.",
+    "Explain the difference between 'ethical' and 'moral' choices in life.",
+    "List three common types of birds that live in urban areas.",
+    "Describe the taste of warm, comforting apple pie.",
+    "What is a black hole? Briefly explain its properties again.",
+    "Imagine a pet robot that helps with personal fitness.",
+    "Explain the concept of 'cognitive dissonance' in a workplace.",
+    "List three major lakes in Asia.",
+    "Describe the sound of a distant, happy dog barking.",
+    "What is a neuron? Explain its role in emotions.",
+    "Write a short, optimistic sentence about overcoming challenges.",
+    "Explain the difference between 'supply chain' and 'logistics' in detail.",
+    "List three common ingredients in a traditional chicken soup.",
+    "Describe the feeling of excitement before an unknown adventure.",
+    "What is DNA? Explain its role in gene editing.",
+    "Imagine a world where animals can communicate telepathically.",
+    "Explain the 'halo effect' in consumer decisions.",
+    "List three major battles of the American Revolutionary War.",
+    "Describe the sound of gentle rain on a metal roof.",
+    "What is climate change? List its impact on weather patterns.",
+    "Write a short, funny sentence about a robot trying to sing.",
+    "Explain the difference between 'HTML' and 'CSS.'",
+    "List three common types of mammals in temperate climates.",
+    "Describe the smell of fresh baked bread on a cold day.",
+    "What is a satellite? Briefly explain its use in education.",
+    "Imagine a device that lets you revisit your past.",
+    "Explain the 'bystander effect' in a digital social setting.",
+    "List three benefits of practicing self-awareness regularly.",
+    "Describe the taste of a hot, crispy donut.",
+    "What is AI? Briefly explain its role in gaming.",
+    "Write a short, inspiring sentence about perseverance and growth.",
+    "Explain the difference between 'assets' and 'liabilities.'",
+    "List three common types of grains used in baking.",
+    "Describe the feeling of warmth from a cozy fireplace in winter.",
+    "What is the Big Bang theory? Explain its implications for the universe's size.",
+    "Imagine a robot designed for artistic expression and healing.",
+    "Explain the concept of 'confirmation bias' in personal growth.",
+    "List three benefits of daily gratitude practice, with details.",
+    "Describe the sound of wind whispering through pine trees.",
+    "What is gravity? Briefly explain its role in stellar life cycles.",
+    "Write a short, whimsical sentence about a talking teacup.",
+    "Explain the difference between 'qualitative' and 'quantitative' research design.",
+    "List three common types of metamorphic rocks, simply.",
+    "Describe the smell of fresh, damp earth after a light rain.",
+    "What is photosynthesis? Explain its role in the global carbon cycle.",
+    "Imagine a plant that can heal any illness.",
+    "Explain the 'Pareto principle' in customer retention.",
+    "List three benefits of engaging in outdoor activities consistently.",
+    "Describe the sound of gentle waves lapping on a quiet lake.",
+    "What is the Internet of Things? Explain its impact on retail.",
+    "Write a short, inspiring sentence about community building.",
+    "Explain the difference between 'ethical' and 'moral' choices (briefly).",
+    "List three common types of birds found in forests.",
+    "Describe the taste of warm, comforting chicken noodle soup.",
+    "What is a black hole? Briefly explain its event horizon and singularity.",
+    "Imagine a pet robot that can play any sport with you.",
+    "Explain the concept of 'cognitive dissonance' in political views.",
+    "List three major deserts in North America.",
+    "Describe the sound of a distant, peaceful choir singing.",
+    "What is a neuron? Explain its role in muscle control.",
+    "Write a short, optimistic sentence about the power of learning.",
+    "Explain the difference between 'supply chain' and 'value chain.'",
+    "List three common ingredients in a traditional chili recipe.",
+    "Describe the feeling of warmth from a sunny window seat.",
+    "What is DNA? Explain its role in ancestry tracing.",
+    "Imagine a world where animals can build cities and roads.",
+    "Explain the 'halo effect' in political elections.",
+    "List three major historical figures from ancient Greek history.",
+    "Describe the sound of gentle rain on a tent, calmly.",
+    "What is climate change? List its impact on extreme weather events.",
+    "Write a short, funny sentence about a robot trying to paint.",
+    "Explain the difference between 'HTML' and 'JavaScript.'",
+    "List three common types of amphibians in temperate climates.",
+    "Describe the smell of fresh baked apple pie, warm and sweet.",
+    "What is a satellite? Briefly explain its use in disaster prediction.",
+    "Imagine a device that allows you to communicate with animals.",
+    "Explain the 'bystander effect' in online emergencies.",
+    "List three benefits of learning critical thinking skills (briefly).",
+    "Describe the taste of a crispy, salty piece of bacon.",
+    "What is AI? Briefly explain its role in art generation processes.",
+    "Write a short, inspiring sentence about the power of dreams.",
+    "Explain the difference between 'debt' and 'equity' in finance.",
+    "List three common types of grains used in breakfast foods.",
+    "Describe the feeling of softness of a warm, fluffy blanket.",
+    "What is the Big Bang theory? Explain its implications for cosmic expansion.",
+    "Imagine a robot designed for creative writing and storytelling.",
+    "Explain the concept of 'confirmation bias' in scientific fields.",
+    "List three benefits of daily gratitude journaling for well-being.",
+    "Describe the sound of wind whispering through dry grass fields.",
+    "What is gravity? Briefly explain its role in stellar formation.",
+    "Write a short, whimsical sentence about a talking broom.",
+    "Explain the difference between 'weather' and 'climate' in academic terms.",
+    "List three common types of reptiles found in rainforests.",
+    "Describe the smell of fresh pine needles on a forest floor.",
+    "What is photosynthesis? Explain its process in detail, simply.",
+    "Imagine a plant that can change its form at will.",
+    "Explain the 'Pareto principle' in customer service and satisfaction.",
+    "List three benefits of spending time in nature for mental clarity.",
+    "Describe the sound of gentle waves lapping on a quiet lake shore.",
+    "What is the Internet of Things? Explain its impact on manufacturing.",
+    "Write a short, inspiring sentence about collective human effort.",
+    "Explain the difference between 'ethical' and 'moral' choices in society.",
+    "List three common types of birds that sing beautiful melodies.",
+    "Describe the taste of warm, creamy tomato soup.",
+    "What is a black hole? Briefly explain its event horizon and gravity.",
+    "Imagine a pet robot that helps with personal health monitoring.",
+    "Explain the concept of 'cognitive dissonance' in consumer decisions.",
+    "List three major mountain ranges in South America.",
+    "Describe the sound of a distant, peaceful melody.",
+    "What is a neuron? Explain its role in learning new things.",
+    "Write a short, optimistic sentence about future possibilities.",
+    "Explain the difference between 'efficiency' and 'effectiveness' in work.",
+    "List three common ingredients in a homemade pizza sauce.",
+    "Describe the feeling of joy when reunited with loved ones.",
+    "What is DNA? Explain its role in personalized medicine.",
+    "Imagine a world where animals have their own social networks.",
+    "Explain the 'halo effect' in job interviews, providing examples.",
+    "List three major historical figures from the Renaissance.",
+    "Describe the sound of gentle rain on a canvas tent, softly.",
+    "What is climate change? List its impact on global temperatures.",
+    "Write a short, funny sentence about a robot trying to cook.",
+    "Explain the difference between 'Wi-Fi' and 'Ethernet.'",
+    "List three common types of mammals in your local area.",
+    "Describe the smell of warm, freshly baked cookies from a nearby kitchen.",
+    "What is a satellite? 
Briefly explain its use in disaster management.",
+        "Imagine a device that lets you explore different dimensions.",
+        "Explain the 'bystander effect' in online communities, with greater detail.",
+        "List three benefits of learning basic first aid skills.",
+        "Describe the taste of a crispy, fresh apple.",
+        "What is AI? Briefly explain its role in personal assistants.",
+        "Write a short, inspiring sentence about the human spirit.",
+    ]
+
+
+if __name__ == "__main__":
+    main()
diff --git a/serving/main_serving.py b/serving/main_serving.py
new file mode 100644
index 0000000..1ee265c
--- /dev/null
+++ b/serving/main_serving.py
@@ -0,0 +1,224 @@
+import dataclasses
+import time
+from pathlib import Path
+import threading
+import asyncio
+import socket
+import signal
+from typing import AsyncGenerator
+from contextlib import asynccontextmanager
+import os
+
+import jax
+from jax import random
+from jax.sharding import PartitionSpec as P, AxisType, NamedSharding
+from llama3_jax import model as l3jax
+import serving_jax as serving
+import numpy as np
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse, Response
+from pydantic import BaseModel
+import uvicorn
+
+
+TOKENIZER, SERVE_LOOP, SERVING_THREAD = None, None, None
+
+jax.config.update("jax_explain_cache_misses", True)
+jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser()))
+jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)
+jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)
+jax.config.update("jax_enable_empty_arrays", True)
+
+try:  # newer JAX only
+    my_id = int(socket.gethostname().split("-")[-1]) - 1
+    my_ip = socket.getaddrinfo(socket.gethostname(), 80)[0][-1][0]
+    jax.config.update("jax_cross_host_transfer_socket_address", f"{my_ip}:{17007 + my_id}")
+    jax.config.update("jax_cross_host_transport_addresses", ",".join([f"{my_ip}:0"] * 8))
+except:  # noqa: E722
+    pass
+
+shutdown_signal = threading.Event()
+
+
+def encode_input(tokenizer, texts, pad_id: int = 0):
+    assert isinstance(texts, list)
+    inputs = [
+        tokenizer.apply_chat_template([{"role": "user", "content": text}], add_generation_prompt=True) for text in texts
+    ]
+    max_len = max([len(x) for x in inputs])
+    return np.array([(max_len - len(x)) * [pad_id] + x for x in inputs])
+
+
+def _place_local(tree, sharding: NamedSharding, present: bool):
+    return jax.tree.map(
+        lambda z, s: jax.make_array_from_single_device_arrays(
+            z.shape, s, [] if not present else [y.data for y in z.addressable_shards], dtype=z.dtype
+        ),
+        tree,
+        sharding,
+    )
+
+
+def load_model():
+    global SERVE_LOOP, SERVING_THREAD, TOKENIZER
+    process_idx = int(socket.gethostname().split("-")[-1]) - 1  # a scheme where hosts are (host-1, host-2, ...)
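+    # jax.distributed.initialize(coordinator_address, num_processes, process_id):
+    # num_processes is hard-coded to 2 here, and COORDINATOR_ADDRESS must name a
+    # host:port reachable from all processes (the rank-0 host acts as coordinator).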
+ jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx) + print(jax.devices()) + print("-" * 80) + print(jax.local_devices()) + + model_name = "Llama-3.1-8B-Instruct" + ckpt_path = Path(f"~/{model_name}").expanduser() + cfg = l3jax.load_config(ckpt_path / "config.json") + TOKENIZER = l3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json") + assert ckpt_path.is_dir() + print("---> Model config loaded") + + # two hosts, different device and host meshes + local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3) + decode_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3) + prefill_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[8:], axis_types=(AxisType.Explicit,) * 3) + + # single host, same decode and prefill meshes + #local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3) + #decode_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3) + #prefill_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3) + + # single host, separate decode and prefill meshes + #local_mesh = jax.make_mesh((1, 4, 1), P("x", "y", "z"), devices=jax.local_devices()[:4], axis_types=(AxisType.Explicit,) * 3) + #decode_mesh = jax.make_mesh((1, 4, 1), P("x", "y", "z"), devices=jax.devices()[:4], axis_types=(AxisType.Explicit,) * 3) + #prefill_mesh = jax.make_mesh((1, 4, 1), P("x", "y", "z"), devices=jax.devices()[4:], axis_types=(AxisType.Explicit,) * 3) + + cfg = dataclasses.replace(cfg, mesh=decode_mesh, quant_layer=True, quant_cache=True) + cfg = dataclasses.replace(cfg, use_prefill_attn_kernel=False, use_decode_attn_kernel=False, max_seq_len=8192) + cfg = dataclasses.replace(cfg, quant_layer=False, quant_cache=False) + + weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=local_mesh))) + + # multi-host: until orbax update + decode_weights = _place_local(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh)), present=jax.process_index() == 0) + prefill_weights = _place_local(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh)), present=jax.process_index() == 1) + + # single-host: until orbax update + #decode_weights = serving.device_put(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh))) + #prefill_weights = serving.device_put(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh))) + + print("---> Weights loaded") + + serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64) + # decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size) + decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32) + SERVE_LOOP = serving.ServingLoop( + serve_cfg, cfg, l3jax.prefill, prefill_weights, l3jax.decode_step, decode_weights, decode_cache + ) + print("---> Created the serving loop") + + def serve_forever(): + try: + while not shutdown_signal.is_set(): + SERVE_LOOP.serving_step() + finally: + print("Received a shutdown signal") + time.sleep(0.1) + signal.raise_signal(signal.SIGINT) # shut down the web server + print("Exiting the serving loop") + + SERVING_THREAD = threading.Thread(target=serve_forever) + SERVING_THREAD.start() + + 
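For reference, `encode_input` above pads on the left, so batched prompts end up right-aligned, matching the right-aligned KV-cache convention used by the model code. A minimal standalone sketch of that behavior (hypothetical token ids, NumPy only):

    import numpy as np

    def left_pad(batch: list[list[int]], pad_id: int = 0) -> np.ndarray:
        # right-align ragged token lists, mirroring encode_input's padding step
        max_len = max(len(x) for x in batch)
        return np.array([(max_len - len(x)) * [pad_id] + x for x in batch])

    assert (left_pad([[7, 8, 9], [5, 6]]) == np.array([[7, 8, 9], [0, 5, 6]])).all()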
+########################################################################################################################
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    yield
+    shutdown_signal.set()
+
+
+_ = load_model()
+APP = FastAPI(lifespan=lifespan)
+
+
+class GenerateRequest(BaseModel):
+    id: int
+    text: str
+
+
+async def generate_generator(id: int, text: str, request: Request) -> AsyncGenerator[str, None]:
+    if id in SERVE_LOOP.results:
+        del SERVE_LOOP.results[id]
+
+    input_tokens = encode_input(TOKENIZER, [text])[0].tolist()
+    cursor = len(input_tokens)
+    SERVE_LOOP.add_request(serving.UserRequestPrompt(id, input_tokens))
+    while id not in SERVE_LOOP.results:
+        await asyncio.sleep(0.1)
+    try:
+        result: serving.DecodeResult = SERVE_LOOP.results[id]
+        while not result.done:
+            if await request.is_disconnected():  # check if the client disconnected
+                print("Client disconnected.")
+                break
+            if len(result.token_list) > cursor:
+                new_segment, cursor = TOKENIZER.decode(result.token_list[cursor:]), len(result.token_list)
+                yield new_segment
+            await asyncio.sleep(0.1)  # poll for new tokens every 100 ms
+        if len(result.token_list) > cursor:
+            new_segment, cursor = TOKENIZER.decode(result.token_list[cursor:]), len(result.token_list)
+            yield new_segment
+    except asyncio.CancelledError:
+        pass
+
+
+@APP.get("/stream")
+async def stream_response(params: GenerateRequest, request: Request):
+    return StreamingResponse(generate_generator(params.id, params.text, request), media_type="text/event-stream")
+
+
+@APP.get("/generate")
+async def generate(id: int, text: str):  # enqueue generation without streaming the output
+    print(f"Input text: {text}")
+    SERVE_LOOP.add_request(serving.UserRequestPrompt(id, encode_input(TOKENIZER, [text])[0].tolist()))
+    return Response("OK")
+
+
+@APP.get("/retrieve")
+async def retrieve(id: int):
+    if id in SERVE_LOOP.results:
+        return Response(TOKENIZER.decode(SERVE_LOOP.results[id].token_list))
+    return Response("NO TEXT")
+
+
+@APP.get("/set_generation_length")
+async def set_generation_length(length: int):
+    SERVE_LOOP.serve_cfg.max_decode_length = max(length, 32)
+    return Response("OK")
+
+
+@APP.get("/profile")
+async def profile(request: Request):
+    del request
+    SERVE_LOOP.profile_start_time = time.perf_counter()
+    return Response("OK")
+
+
+@APP.get("/")
+async def root():
+    return {"message": "Welcome! Try the /stream endpoint."}
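Since /generate and /retrieve take plain query parameters, the server can be smoke-tested with a small polling client; a minimal sketch, assuming the server is reachable at localhost:8081 (the port used below) and that the `requests` package is installed; the request id 123 is arbitrary:

    import time
    import requests

    BASE = "http://localhost:8081"
    requests.get(f"{BASE}/generate", params={"id": 123, "text": "Hello!"})
    while (out := requests.get(f"{BASE}/retrieve", params={"id": 123}).text) == "NO TEXT":
        time.sleep(0.5)  # /retrieve returns the tokens decoded so far once the request is scheduled
    print(out)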
+
+
+if __name__ == "__main__":
+    if jax.process_index() == 0:
+        uvicorn.run(APP, host="0.0.0.0", port=8081, reload=False, server_header=False)
+    else:
+        try:
+            while True:
+                time.sleep(1)
+        except KeyboardInterrupt:
+            shutdown_signal.set()
diff --git a/serving/pyproject.toml b/serving/pyproject.toml
new file mode 100644
index 0000000..0df6079
--- /dev/null
+++ b/serving/pyproject.toml
@@ -0,0 +1,30 @@
+[project]
+name = "serving_jax"
+version = "0.1.0"
+description = ""
+authors = [
+    { name = "Robert Dyro" },
+]
+readme = "README.md"
+requires-python = ">=3.11"
+license = { text = "Apache-2.0" }
+
+dependencies = [
+    "jax",
+    "tqdm",
+    "numpy",
+    #"orbax-checkpoint",
+    #"datasets",
+    "gcsfs",
+    "etils",
+]
+
+[build-system]
+requires = ["setuptools>=61.0.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["serving_jax"]
diff --git a/serving/serving_jax/__init__.py b/serving/serving_jax/__init__.py
new file mode 100644
index 0000000..f17ccf8
--- /dev/null
+++ b/serving/serving_jax/__init__.py
@@ -0,0 +1,731 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable
+import math
+from concurrent.futures import ThreadPoolExecutor, Future
+import threading
+import time
+import json
+
+import jax
+import jax.numpy as jnp
+from jax.sharding import Mesh, PartitionSpec as P, NamedSharding, use_mesh
+
+try:
+    from jax.experimental.shard import auto_axes
+except ModuleNotFoundError:
+    from jax.sharding import auto_axes
+from jax._src import distributed
+
+from jax._src.lib import xla_client as xc
+import numpy as np
+
+from .cross_host import transfer_tree_A2B
+
+
+KVCache, Weights, Config = Any, Any, Any
+PyTree, PyTreeStruct = Any, Any
+
+TIME_AXIS = 2
+USE_PREFIX_CACHE = True  # the eviction mechanism is extremely simple right now
+is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__)
+
+########################################################################################################################
+# device put for cross-process/hosts transfers #########################################################################
+########################################################################################################################
+
+
+def unsafe_device_put(xs: PyTree, spec: PyTree, dest_mesh: Mesh):
+    """Fastest, but local single-process JAX only for now."""
+    xs_flat, xs_struct = jax.tree.flatten(xs)
+    shardings_list = [NamedSharding(dest_mesh, s) for s in jax.tree.leaves(spec)]
+    devices_list = [s._internal_device_list for s in shardings_list]
+    copy_semantics = [xc.ArrayCopySemantics.ALWAYS_COPY] * len(devices_list)
+    out = xc.batched_copy_array_to_devices_with_sharding(xs_flat, devices_list, shardings_list, copy_semantics)
+    return jax.tree.unflatten(xs_struct, out)
+
+
+def jax_device_put(xs: PyTree, sharding: PyTree):
+    """Async, available in future JAX."""
+    is_source = len(getattr(jax.tree.leaves(xs)[0], "addressable_shards", [])) > 0
+    if is_source:
+        return jax.device_put(xs, sharding)
+    else:
+        empty_arrays = jax.tree.map(
+            lambda x: jax.make_array_from_single_device_arrays(x.shape, x.sharding, [], dtype=x.dtype), xs
+        )
+        return jax.device_put(empty_arrays, sharding)
+
+
+def jit_device_put(xs: PyTree, sharding: PyTree):
+    """Most compatible; uses jit, so it requires blocking dispatch."""
+    jax.sharding.set_mesh(None)  # not compatible with a context mesh
+    meshA, meshB = jax.tree.leaves(xs)[0].sharding.mesh, jax.tree.leaves(sharding)[0].mesh
+    return transfer_tree_A2B(xs, meshA, meshB)
+
+
+device_put = jit_device_put  # the most compatible option currently, but NOT async (see jax_device_put above)
+
+
+def _ensure_all_args_on_mesh(*args, mesh: Mesh):
+    args_len = len(args)
+    if not all(jax.tree.leaves(arg)[0].sharding.mesh == mesh for arg in args):
+        _correct_mesh = lambda value: jax.tree.leaves(value)[0].sharding.mesh == mesh
+        _args = {i: arg for i, arg in enumerate(args) if not _correct_mesh(arg)}
+        if len(_args) > 0:
+            args = dict(enumerate(args)) | device_put(_args, like_shard(_args, mesh))
+            args = tuple(args[i] for i in range(len(args)))
+    return args if args_len > 1 else args[0]
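The three variants above trade speed against compatibility: `unsafe_device_put` uses a private XLA batched-copy API, `jax_device_put` needs a future JAX release, and `jit_device_put` routes through the combined-mesh transfer in cross_host.py (added later in this patch). A rough single-host smoke test of the default path; the 4-CPU-device setup and mesh shapes are hypothetical, and `jax_num_cpu_devices` must be set before JAX initializes its backend:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    jax.config.update("jax_num_cpu_devices", 4)  # hypothetical stand-in for two hosts

    devices = np.array(jax.devices())
    mesh_a = Mesh(devices[:2].reshape(1, 2, 1), ("x", "y", "z"))
    mesh_b = Mesh(devices[2:].reshape(1, 2, 1), ("x", "y", "z"))

    x = jax.device_put(jnp.arange(16.0).reshape(8, 2), NamedSharding(mesh_a, P(None, "y")))
    # the transfer keeps each leaf's PartitionSpec and only retargets the mesh
    y = jit_device_put({"w": x}, {"w": NamedSharding(mesh_b, P(None, "y"))})
    assert jax.tree.leaves(y)[0].sharding.mesh == mesh_b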
+
+
+########################################################################################################################
+# trie utils ###########################################################################################################
+########################################################################################################################
+
+_GLOBAL_NODE_ID = 0
+
+
+@dataclasses.dataclass
+class OffloadedValue:
+    ref: str | np.ndarray
+    spec: Any
+    shape_dtypes: Any
+
+
+@dataclasses.dataclass
+class TrieNode:
+    id: int
+    key: jax.Array
+    value: PyTree | OffloadedValue
+    children: list["TrieNode"] = dataclasses.field(default_factory=list)
+    child_keys: jax.Array | None = None
+    lock: "threading.Lock | None" = None
+    usage: int = 1
+
+    def __repr__(self, indent: int = 0):
+        lines = [" " * indent + "TrieNode("]
+        lines.append((" " * indent) + f"  key={str(self.key.tolist() if hasattr(self.key, 'tolist') else self.key)},")
+        lines.append((" " * indent) + f"  usage={self.usage},")
+        if is_type(self.value, OffloadedValue):
+            lines.append((" " * indent) + f"  value={self.value.ref},")
+        else:
+            lines.append(
+                (" " * indent)
+                + f"  value={jax.tree.map(jax.typeof, self.value) if self.value is not None else 'None'},"
+            )
+        lines.append(
+            (" " * indent) + f"  child_keys={jax.typeof(self.child_keys) if self.child_keys is not None else 'None'},"
+        )
+        lines.append((" " * indent) + "  children=[")
+        if self.children:
+            for child in self.children:
+                lines.append(f"{child.__repr__(indent + 2)},")
+            lines.append(" " * indent + "  ],")
+        else:
+            lines[-1] += "],"
+        lines.append(" " * indent + ")")
+        return "\n".join(lines)
+
+    @staticmethod
+    def new_id():
+        global _GLOBAL_NODE_ID
+        _GLOBAL_NODE_ID += 1
+        return _GLOBAL_NODE_ID - 1
+
+    @staticmethod
+    def _dist_to_key(key, keys, mask, pad_idx: int):
+        invalid_rows = np.all(keys == pad_idx, axis=-1)
+        return np.where(invalid_rows, 2**30, np.sum(mask * np.abs(key - keys), axis=-1))
+
+    @staticmethod
+    def _append_key(keys, new_key, keys_len: int, pad_idx: int):
+        if keys is None:
+            return new_key[None, ...]  # start with a single-key buffer (length 2**0)
+        if keys_len == keys.shape[0]:  # need to double the keys buffer
+            new_buf = np.pad(
+                new_key[None, ...], ((0, keys.shape[0] - 1), (0, 0)), mode="constant", constant_values=pad_idx
+            )
+            return np.concatenate([keys, new_buf], 0)
+        else:
+            keys[keys_len, ...] = new_key
+            return keys
+
+    @staticmethod
+    def _pad_to_multiple_of(sequence: jax.Array, chunk_size: int, pad_idx: int):
+        sequence_pad_len = math.ceil(sequence.size / chunk_size) * chunk_size
+        return np.pad(sequence, ((0, sequence_pad_len - sequence.shape[-1])), mode="constant", constant_values=pad_idx)

+    @staticmethod
+    def _overlap_dist(key1, key2, mask):
+        return np.sum(np.cumsum(np.logical_not(mask & (key1 == key2)), axis=-1) == 0, axis=-1)
+
+
+@partial(jax.jit, static_argnames=("axis", "chunk_size", "ns"))
+def _split(val: jax.Array | list[jax.Array], axis: int, chunk_size: int, ns: int) -> list[jax.Array]:
+    spec = jax.tree.map(lambda x: [x] * ns, like_spec(val))
+
+    def _fn(val):
+        axis_ = axis % val.ndim
+        size = val.shape[axis_]
+        if size < chunk_size * ns:
+            min_len = chunk_size * ns
+            val = jnp.pad(val, [(0, 0) if i != axis_ else (0, min_len - val.shape[axis_]) for i in range(val.ndim)])
+        index = [slice(None) if i != axis_ else slice(0, ns * chunk_size) for i in range(val.ndim)]
+        return jnp.split(val[*index], ns, axis=axis_)[:ns]
+
+    return auto_axes(lambda vals: jax.tree.map(_fn, vals), out_sharding=spec)(val)
+
+
+@partial(jax.jit, static_argnames=("split_axis",))
+def _concat(values, split_axis: int):
+    _fn = lambda vals: jax.tree.map(lambda *args: jnp.concatenate(args, axis=split_axis), *vals)
+    return auto_axes(_fn, out_sharding=like_spec(values[0]))(values)
+
+
+def insert_prefix(
+    prefix_cache: TrieNode,
+    sequence: jax.Array,
+    value: PyTree,
+    *,
+    chunk_size: int,
+    split_axis: int,
+    pad_idx: int = 2**30,
+    executor: ThreadPoolExecutor | None = None,
+    mesh: Any | None = None,
+):
+    del executor
+    sequence = np.array(sequence)
+    assert sequence.ndim == 1
+    sequence = TrieNode._pad_to_multiple_of(sequence, chunk_size, pad_idx=pad_idx)
+    ns = sequence.shape[-1] // chunk_size
+    sequence_chunks = np.split(sequence, ns)
+
+    # split the value lazily, only when a chunk misses the cache and actually needs it
+    value_chunks = None
+
+    def lazy_get_value(idx):
+        nonlocal value_chunks
+        if value_chunks is None:
+            value_leaves, value_struct = jax.tree.flatten(value)
+            with use_mesh(mesh):
+                split_leaves = _split(value_leaves, axis=split_axis, chunk_size=chunk_size, ns=ns)
+            value_chunks = [jax.tree.unflatten(value_struct, [x[i] for x in split_leaves]) for i in range(ns)]
+        return value_chunks[idx]
+
+    # walk the prefix cache tree
+    with prefix_cache.lock:
+        node = prefix_cache
+        for seq_idx, seq in enumerate(sequence_chunks):
+            if len(node.children) == 0:
+                node.child_keys = TrieNode._append_key(node.child_keys, seq, len(node.children), pad_idx=pad_idx)
+                node.children.append(TrieNode(TrieNode.new_id(), seq, lazy_get_value(seq_idx)))
+                node = node.children[-1]
+                continue
+            left_mask, right_mask = (seq != pad_idx), (node.child_keys != pad_idx)
+            left_dist = TrieNode._dist_to_key(seq, node.child_keys, left_mask, pad_idx=pad_idx)
+            right_dist = TrieNode._dist_to_key(seq, node.child_keys, right_mask, pad_idx=pad_idx)
+            left_idx, right_idx = np.argmin(left_dist), np.argmin(right_dist)
+            if node.children and right_dist[right_idx] == 0:  # this sequence is longer
+                if left_dist[right_idx] > 0:
+                    node.children[right_idx].key = seq
+                    node.children[right_idx].value = lazy_get_value(seq_idx)
+                    node.child_keys[right_idx, :] = seq
+                else:  # the exact sequence already exists
+                    node.children[right_idx].usage += 1
+                node = node.children[right_idx]
+            elif left_dist[left_idx] == 0:  # a longer sequence already exists
+                node.children[left_idx].usage += 1
+                assert seq_idx == len(sequence_chunks) - 1
+                return
+            else:  # no exact match
+                node.child_keys = TrieNode._append_key(node.child_keys, seq, len(node.children), pad_idx=pad_idx)
+                node.children.append(TrieNode(TrieNode.new_id(), seq, lazy_get_value(seq_idx)))
+                node = node.children[-1]
+
+
+def retrieve_prefix(
+    prefix_cache: TrieNode,
+    sequence: jax.Array,
+    *,
+    chunk_size: int,
+    split_axis: int,
+    pad_idx: int = 2**30,
+    executor: ThreadPoolExecutor | None = None,
+    mesh: Any | None = None,
+):
+    sequence, total_match = np.array(sequence), 0
+    assert sequence.ndim == 1
+    sequence_len, sequence = sequence.size, TrieNode._pad_to_multiple_of(sequence, chunk_size, pad_idx=pad_idx)
+    ns = sequence.shape[-1] // chunk_size
+    values, sequence_chunks = [], np.split(sequence, ns)
+
+    def _construct_output():
+        if sequence_len != total_match:
+            return None, total_match
+        for i, value in enumerate(values):
+            if is_type(value, OffloadedValue):
+                _load = lambda value: jax.block_until_ready(device_put(value.ref, like_shard(value.spec, mesh)))
+                values[i] = _load(value) if executor is None else executor.submit(_load, value)
+
+        values_future = lambda: [value.result() if hasattr(value, "result") else value for value in values]
+        return (executor.submit(values_future) if executor is not None else values_future()), total_match
+
+    node = prefix_cache
+    for seq in sequence_chunks:
+        if len(node.children) == 0:  # the cache ran out of nodes
+            return _construct_output()
+        left_mask = seq != pad_idx
+        overlaps = TrieNode._overlap_dist(node.child_keys, seq, left_mask)
+        max_idx = np.argmax(overlaps)
+        max_overlap = overlaps[max_idx]
+        if max_overlap == 0:
+            return _construct_output()
+        with prefix_cache.lock:
+            node.children[max_idx].usage += 1
+        values.append(node.children[max_idx].value)
+        node, total_match = node.children[max_idx], total_match + max_overlap
+        # exit early if the entire chunk wasn't found
+        if max_overlap != np.sum(left_mask):
+            break
+    return _construct_output()
+
+
+def offload_nodes(prefix_cache: TrieNode, how_many: int = 3):
+    # work in progress, not tested, will probably not work
+    # TODO: switch to [memories](https://docs.jax.dev/en/latest/notebooks/host-offloading.html)
+    node_queue, all_nodes = [prefix_cache], []
+    with prefix_cache.lock:
+        while len(node_queue) > 0:
+            node = node_queue.pop(0)
+            for child in node.children:
+                node_queue.append(child)
+                all_nodes.append(child)
+    sorted_nodes = sorted(all_nodes, key=lambda x: x.usage)
+    offloaded = 0
+    for node in sorted_nodes:
+        if offloaded >= how_many:
+            break
+        if is_type(node.value, OffloadedValue):
+            continue
+        value = jax.tree.map(partial(np.asarray, copy=False), jax.device_put(node.value, jax.devices("cpu")[0]))
+        node.value = OffloadedValue(value, like_spec(node.value), jax.tree.map(jax.typeof, node.value))
+        offloaded += 1  # count successful offloads so how_many takes effect
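retrieve_prefix matches one chunk at a time using `TrieNode._overlap_dist`, which counts how many leading positions of two keys agree under a validity mask (pad positions excluded). A standalone recheck of that rule with toy keys, NumPy only:

    import numpy as np

    def overlap(key1, key2, mask):
        # count leading positions where the keys agree and the mask is valid,
        # mirroring TrieNode._overlap_dist
        return np.sum(np.cumsum(np.logical_not(mask & (key1 == key2)), axis=-1) == 0, axis=-1)

    k1, k2 = np.array([1, 2, 3, 4]), np.array([1, 2, 9, 4])
    assert overlap(k1, k2, np.ones(4, dtype=bool)) == 2  # the match stops at the first mismatch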
+
+
+########################################################################################################################
+# serving loop #########################################################################################################
+########################################################################################################################
+
+next_power_of_2 = lambda x: 2 ** math.ceil(math.log2(max(x, 1)))
+like_spec = lambda z: jax.tree.map(lambda x: jax.typeof(x).sharding.spec, z)
+like_shard = lambda z, mesh: jax.tree.map(lambda x: NamedSharding(mesh, jax.typeof(x).sharding.spec), z)
+
+
+@dataclasses.dataclass
+class ServingConfig:
+    decode_steps: int = 10
+    decode_batch_size: int = 16
+    prefill_batch_size: int = 4
+    prefix_chunk_size: int = 512
+    eos_tokens: tuple[int, ...] | jax.Array = ()
+    token_pad_idx: int = 0
+    max_decode_length: int = 64
+
+
+@dataclasses.dataclass
+class UserRequestPrompt:
+    id: int
+    text: list[int]  # token ids, already encoded
+
+
+@dataclasses.dataclass
+class DecodeResult:
+    id: int
+    token_list: list[int]
+    tokens_decoded: int = 0
+    done: bool = False
+
+
+@dataclasses.dataclass
+class PrefillResult:
+    id: int
+    input: np.ndarray
+    next_token: jax.Array
+    cache_entry: Any
+    len: int
+
+
+@dataclasses.dataclass
+class DecodeWork:
+    curr_tokens: jax.Array  # [B, 1] to conform with the general forward fn expecting a sequence dimension
+    cache: KVCache
+    active_results: list[DecodeResult | None]
+
+
+@dataclasses.dataclass
+class PrefillWork:
+    requests: list[UserRequestPrompt]
+    to_prefill: list[UserRequestPrompt]
+    to_decode: list[PrefillResult]
+    pending_prefill: Future | None = None
+    pending_cache_retrievals: list[tuple[UserRequestPrompt, Future]] = dataclasses.field(default_factory=list)
+
+
+def return_request(resp: DecodeResult):
+    # an optional callback, invoked (on decode nodes only) when a request finishes;
+    # hook this up to push the finished response onto a global output queue
+    # print(f"Finished request: {resp.id}")
+    pass
+
+
+class SyncServer:
+    """A regular local network server for syncing between JAX processes in the multi-process JAX setup."""
+
+    CLIENT = None
+    TIMEOUT_SEC = 60
+
+    @staticmethod
+    def _get_client():
+        if SyncServer.CLIENT is None:
+            SyncServer.CLIENT = distributed.global_state.client
+        return SyncServer.CLIENT
+
+    @staticmethod
+    def barrier(key: str, current_it: int) -> None:
+        client = SyncServer._get_client()
+        if client is None:
+            return
+        client.wait_at_barrier(key + str(current_it), timeout_in_ms=SyncServer.TIMEOUT_SEC * 1000)
+
+    @staticmethod
+    def broadcast(key: str, current_it: int, value: Any, is_source: bool = False, jsonify: bool = True) -> Any:
+        client = SyncServer._get_client()
+        if client is None:
+            return value
+        if is_source:
+            client.key_value_set(key + str(current_it), json.dumps(value) if jsonify else value)
+            return value
+        else:
+            value = client.blocking_key_value_get(key + str(current_it), SyncServer.TIMEOUT_SEC * 1000)
+            return json.loads(value) if jsonify else value
+
+
+def _make_multistep_decode_fn(decode_fn):
+    @partial(jax.jit, static_argnames=("steps",), donate_argnames=("cache",))
+    def multistep_decode_fn(curr_tokens, decode_weights, cache, cfg, steps: int = 32):
+        def body(carry, _):
+            curr_tokens, cache = carry
+            next_tokens, cache = decode_fn(curr_tokens, decode_weights, cache, cfg)
+            return (next_tokens, cache), next_tokens
+
+        (curr_tokens, cache), output_tokens = jax.lax.scan(body, (curr_tokens, cache), length=steps)
+        return (curr_tokens, cache), output_tokens[..., 0].T
+
+    return multistep_decode_fn
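`_make_multistep_decode_fn` above fuses `steps` decode calls into one jitted computation with `jax.lax.scan`: the carry holds `(tokens, cache)` and each step's tokens are stacked into the output. The same pattern in miniature, with a toy counter standing in for the model state:

    import jax
    import jax.numpy as jnp

    def step(carry, _):
        nxt = carry + 1
        return nxt, nxt  # (new carry, per-step stacked output)

    final, outs = jax.lax.scan(step, jnp.int32(0), length=4)
    assert int(final) == 4 and outs.tolist() == [1, 2, 3, 4]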
+
+
+def _make_stacked_prefill(prefill_fn):
+    def _numpy_pad_tokens(tokens):
+        opts = dict(mode="constant", constant_values=0)
+        return np.pad(tokens, [(0, 0), (0, next_power_of_2(tokens.shape[-1]) - tokens.shape[-1])], **opts)
+
+    @jax.jit
+    def stacked_prefill(inputs, weights, cfg):
+        next_tokens, logits, kv_list = prefill_fn(inputs, weights, None, cfg)
+        assert len(kv_list) == cfg.num_layers, "The output kv values have to be in a list of kv pairs."
+        stacked_kv = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *kv_list)
+        return next_tokens, logits, stacked_kv
+
+    return lambda inputs, weights, cfg: stacked_prefill(_numpy_pad_tokens(inputs), weights, cfg)
+
+
+class ServingLoop:
+    def __init__(
+        self,
+        serve_cfg: ServingConfig,
+        cfg: Config,
+        prefill_fn: Callable,
+        prefill_weights: Weights,
+        decode_fn: Callable,
+        decode_weights: Weights,
+        decode_cache: KVCache,
+    ):
+        self.serve_cfg, self.cfg = serve_cfg, cfg
+
+        # setup decode
+        self.decode_fn, self.decode_weights = decode_fn, decode_weights
+        self.decode_mesh = [x for x in jax.tree.leaves(decode_weights) if hasattr(x, "sharding")][0].sharding.mesh
+        with use_mesh(self.decode_mesh):
+            self.decode_work = DecodeWork(None, decode_cache, [None for _ in range(serve_cfg.decode_batch_size)])
+            self.decode_work.curr_tokens = jax.device_put(
+                jnp.zeros((serve_cfg.decode_batch_size, 1), dtype=jnp.int32), P()
+            )
+        self.multistep_decode_fn = _make_multistep_decode_fn(self.decode_fn)
+        self._update_index = jax.jit(lambda x, i, new: x.at[i, ...].set(new[:, None], mode="drop"))
+
+        def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens, kvs, batch_idxs, actual_lens):
+            length_sort = sorted(
+                range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2]
+            )  # sort by length to minimize the number of compiled variants
+            new_cache = decode_cache.insert_sequences(
+                cache, *[[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)]
+            )
+            with use_mesh(self.decode_mesh):
+                new_curr_tokens = self._update_index(curr_tokens, np.array(batch_idxs), new_tokens)
+            return new_cache, new_curr_tokens
+
+        self._update_cache_and_index = _update_cache_and_index
+        self.decode_output = (None, None)
+
+        # setup prefill
+        self.prefill_fn = _make_stacked_prefill(prefill_fn)
+        self.prefill_weights = prefill_weights
+        self.prefill_mesh = [x for x in jax.tree.leaves(prefill_weights) if hasattr(x, "sharding")][0].sharding.mesh
+        self.prefill_work = PrefillWork([], [], [])
+        self.prefix_cache = TrieNode(TrieNode.new_id(), None, None, lock=threading.Lock())
+        self._get_index = jax.jit(lambda z, idx: jax.tree.map(lambda x: x[:, idx, ...], z))
+        self._get_cache_entry = jax.jit(self.decode_work.cache.get_sequence)
+
+        # setup misc
+        self.pending_requests, self.requests_lock, self.results = [], threading.Lock(), {}
+        self.pad_id, self.eos_tokens, self.time_axis = 0, np.array(serve_cfg.eos_tokens), TIME_AXIS
+        self._background = ThreadPoolExecutor(max_workers=1024)
+
+        # setup profiling
+        self.profile_start_time, self.profiling = -1, False
+
+        # setup cache management
+        # -1 for the missing batch dimension and +1 for the stacked layer dimension
+        self.prefix_cache, self._retrieve_prefix, self._insert_prefix = None, None, None
+        self.new_prefix_cache()
+
+        # setup the sync server for multi-host
+        self._it, self.roles = 0, (("server",) if jax.process_index() == 0 else ())  # main server
+        if any(d.id in [d_.id for d_ in self.decode_mesh.devices.reshape(-1)] for d in jax.local_devices()):
+            self.roles += ("decode",)  # any node which has decode mesh devices
+        if any(d.id in [d_.id for d_ in self.prefill_mesh.devices.reshape(-1)] for d in jax.local_devices()):
+            self.roles += ("prefill",)  # any node which has prefill devices
+        if any(d.id == min([d_.id for d_ in self.decode_mesh.devices.reshape(-1)]) for d in jax.local_devices()):
+            self.roles += ("decode_coordinator",)  # the decode node which holds the smallest decode mesh device
+        if any(d.id == min([d_.id for d_ in self.prefill_mesh.devices.reshape(-1)]) 
for d in jax.local_devices()): + self.roles += ("prefill_coordinator",) # the prefill node which holds the smallest prefill mesh device + self.total_requests = 0 + + def decode_step(self): + # TODO: a more intelligent decision between decode and prefill (adaptive strategies, prefill queue size) + + # 1. add outstanding ready to decode prefill result to the active decode + # - some cache entries require some computation, so they're a callable + # - some cache entries are not on the correct decode_mesh + if len(self.prefill_work.to_decode) > 0: + batch_cache_updates = [] + for i, active_result in enumerate(self.decode_work.active_results): + if active_result is not None: + continue + if len(self.prefill_work.to_decode) == 0: + break + result: PrefillResult = self.prefill_work.to_decode.pop(0) + self.decode_work.active_results[i] = DecodeResult(result.id, result.input.tolist()) + with use_mesh(self.decode_mesh): + result.cache_entry = result.cache_entry() if callable(result.cache_entry) else result.cache_entry + result.cache_entry = _ensure_all_args_on_mesh(result.cache_entry, mesh=self.decode_mesh) + self.results[result.id] = self.decode_work.active_results[i] + batch_cache_updates.append((result.cache_entry, i, result.len, result.next_token)) + if len(self.prefill_work.to_decode) == 0: + break + if "decode" in self.roles and len(batch_cache_updates) > 0: # batch cache update + entries, batch_idxs, lens, next_tokens = map(list, zip(*batch_cache_updates)) + entries = [entry.result() if hasattr(entry, "result") else entry for entry in entries] # maybe collect + _control_args = (np.array(next_tokens), entries, batch_idxs, lens) + self.decode_work.cache, self.decode_work.curr_tokens = self._update_cache_and_index( + self.decode_work.cache, self.decode_work.curr_tokens, *_control_args + ) + + if all(x is None for x in self.decode_work.active_results): + return # skip decoding if no decoding tasks are present + + # 2. run N decode steps + output_tokens, output_mapping = [], [] + if "decode" in self.roles: # cut a corner, don't issue the decode call on non-participating machines + with use_mesh(self.decode_mesh): + config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps) + (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = self.multistep_decode_fn( + self.decode_work.curr_tokens, self.decode_weights, self.decode_work.cache, **config + ) + output_mapping = [ + [getattr(result, "id", -1) for result in self.decode_work.active_results] + ] * self.serve_cfg.decode_steps + output_mapping = np.array(output_mapping).T + print( + f"Decoding with fill rate of {np.mean([result is not None for result in self.decode_work.active_results])}" + ) + + # 3. 
parse the output tokens from the previous decode loop, giving them time to arrive (delayed EOS detection)
+        self.decode_output, (output_tokens, output_mapping) = (output_tokens, output_mapping), self.decode_output
+        if output_tokens is not None:
+            SyncServer.barrier("output_tokens", self._it)
+            if "decode" in self.roles:
+                output_tokens = np.array(output_tokens)
+                done = np.any(output_tokens[..., None] == self.eos_tokens, (-1, -2)).tolist()  # check for done
+                done = [
+                    d or getattr(result, "tokens_decoded", 0) >= self.serve_cfg.max_decode_length
+                    for d, result in zip(done, self.decode_work.active_results)
+                ]
+            else:
+                output_tokens, done = None, None
+            done = SyncServer.broadcast("done_sync", self._it, done, is_source="decode" in self.roles)
+            if "server" in self.roles:
+                for token, id in zip(output_tokens.reshape(-1).tolist(), output_mapping.reshape(-1).tolist()):
+                    if id >= 0:  # -1 marks an empty decode slot
+                        self.results[id].token_list.append(token)
+                        self.results[id].tokens_decoded += 1
+            with use_mesh(self.decode_mesh):
+                for i, result in enumerate(self.decode_work.active_results):
+                    if result is None:
+                        continue
+                    # check for done sequences; evict them if done and return them
+                    if done[i]:
+                        if USE_PREFIX_CACHE:
+                            sequence = np.array(result.token_list)
+                            with use_mesh(self.decode_mesh):
+                                cache_entry, _ = self._get_cache_entry(self.decode_work.cache, i)
+                            self._background.submit(self._insert_prefix, sequence, cache_entry, mesh=self.decode_mesh)
+                        return_request(result)
+                        result.done, self.decode_work.active_results[i] = True, None
+
+    def prefill_step(self):
+        # 1. check on any finished prefill jobs
+        if self.prefill_work.pending_prefill is not None:
+            prefill_is_done, is_source = self.prefill_work.pending_prefill.done(), "prefill_coordinator" in self.roles
+            prefill_is_done = SyncServer.broadcast("prefill_done", self._it, prefill_is_done, is_source=is_source)
+            if prefill_is_done:
+                prefill_input, prefill_results = self.prefill_work.pending_prefill.result()
+                for i, request in enumerate(prefill_input):
+                    with use_mesh(self.prefill_mesh):
+                        kv_list = self._get_index(prefill_results, i)
+                    req_id, input_tokens = request.id, np.array(request.text)
+                    self.prefill_work.to_decode.append(
+                        PrefillResult(req_id, input_tokens, input_tokens[-1], kv_list, len(input_tokens) - 1)
+                    )
+                self.prefill_work.pending_prefill = None
+
+        # 2. triage the requests queue into cached work (-> decode) and not-cached work (-> prefill)
+        new_pending_retrievals = []
+        for request, cache_entry_fut in self.prefill_work.pending_cache_retrievals:
+            if len(self.prefill_work.to_decode) < self.serve_cfg.decode_batch_size and cache_entry_fut.done():
+                with use_mesh(self.decode_mesh):
+                    # concat along time: -1 for the missing batch dimension, +1 for the stacked layer dimension
+                    cache_entry = partial(_concat, cache_entry_fut.result(), self.time_axis - 1 + 1)  # jit work future
+                new_decode = PrefillResult(
+                    request.id, np.array(request.text), request.text[-1], cache_entry, len(request.text) - 1
+                )
+                self.prefill_work.to_decode.append(new_decode)
+            else:
+                new_pending_retrievals.append((request, cache_entry_fut))  # not yet ready
+        self.prefill_work.pending_cache_retrievals = new_pending_retrievals
+
+        # 3. check if prefixes are in the cache
+        retrieval_results = self._background.map(
+            lambda request: (self._retrieve_prefix(np.array(request.text[:-1])), request), self.prefill_work.requests
+        )
+        for (cache_entry_fut, length), request in retrieval_results:
+            if length == len(request.text) - 1:
+                self.prefill_work.pending_cache_retrievals.append((request, cache_entry_fut))
+                print("Found a full prefill match in the cache")
+            else:
+                print(f"Need to prefill the request, matched fraction {length / (len(request.text) - 1):.2f}")
+                self.prefill_work.to_prefill.append(request)
+        self.prefill_work.requests.clear()
+
+        if self.prefill_work.pending_prefill is not None:  # a current prefill is still running, skip scheduling another
+            return
+
+        # 4. prefill the requests that still need prefilling
+        prefill_input = self.prefill_work.to_prefill[: self.serve_cfg.prefill_batch_size]
+        self.prefill_work.to_prefill = self.prefill_work.to_prefill[len(prefill_input) :]
+        if len(prefill_input) > 0:
+            # disaggregated serving: run prefill asynchronously on its own subset of devices
+            def _prefill_job():
+                max_len = max([len(request.text) for request in prefill_input])
+                inputs = [[self.pad_id] * (max_len - len(request.text)) + request.text for request in prefill_input]
+                inputs = np.stack([np.array(tokens) for tokens in inputs], 0)
+                row_pad = self.serve_cfg.prefill_batch_size - inputs.shape[0]
+                col_pad = next_power_of_2(inputs.shape[-1]) - inputs.shape[-1]
+                inputs = np.pad(inputs, ((0, row_pad), (0, col_pad)), mode="constant", constant_values=self.pad_id)
+                cfg = dataclasses.replace(self.cfg, mesh=self.prefill_mesh)
+                with use_mesh(self.prefill_mesh):
+                    _, _, prefill_results = self.prefill_fn(inputs, self.prefill_weights, cfg)
+                    prefill_results = jax.block_until_ready(prefill_results)
+                return prefill_input, prefill_results
+
+            self.prefill_work.pending_prefill = self._background.submit(_prefill_job)
+
+    def serving_step(self):
+        # this event loop relies on determinism for issuing computation to multiple processes (multi-process JAX)
+        # frequent barriers should keep it in sync
+
+        # profiling: start a trace when one was requested ###########################################
+        should_start_profile = self.profile_start_time > 0 and not self.profiling
+        should_start_profile = SyncServer.broadcast(
+            "profile", self._it, should_start_profile, is_source="server" in self.roles
+        )
+        if should_start_profile:
+            self.profile_start_time, self.profiling = time.perf_counter(), True
+            jax.profiler.start_trace("/tmp/online")
+            print("STARTING TRACE")
+        should_stop_profile = self.profile_start_time > 0 and time.perf_counter() - self.profile_start_time > 5.0
+        should_stop_profile = SyncServer.broadcast(
+            "stop_profile", self._it, should_stop_profile, is_source="server" in self.roles
+        )
+        if should_stop_profile:
+            self.profile_start_time, self.profiling = -1, False
+            print("STOPPING TRACE")
+            jax.profiler.stop_trace()
+        # profiling: start a trace when one was requested ###########################################
+
+        # sync on the server requests received #####################################################
+        SyncServer.barrier("serving_step", self._it)
+        self._it, requests = self._it + 1, None
+        if "server" in self.roles:
+            with self.requests_lock:
+                self.pending_requests, requests = [], list(self.pending_requests)
+        requests = SyncServer.broadcast("requests", self._it, requests, is_source="server" in self.roles)
+        for request in requests:
+            self.total_requests += 1
+            self.prefill_work.requests.append(UserRequestPrompt(**request))
+        # sync on the server 
requests received ##################################################### + + # main event loop work ##################################################################### + self.decode_step() + self.prefill_step() + # main event loop work ##################################################################### + + # manage cache ############################################################################# + # TODO: test and configure host offloading for the cache + if USE_PREFIX_CACHE and len(self.prefix_cache.children) > 100: # clear the cache after 100 root children + self.new_prefix_cache() + # manage cache ############################################################################# + + def add_request(self, request: UserRequestPrompt): + with self.requests_lock: + self.pending_requests.append(dataclasses.asdict(request)) + + def new_prefix_cache(self): + self.prefix_cache = TrieNode(TrieNode.new_id(), None, None, lock=threading.Lock()) + _prefix_opts = dict(chunk_size=self.serve_cfg.prefix_chunk_size) + _prefix_opts |= dict(split_axis=self.time_axis - 1 + 1, mesh=self.decode_mesh, executor=self._background) + self._retrieve_prefix = partial(retrieve_prefix, self.prefix_cache, **_prefix_opts) + self._insert_prefix = partial(insert_prefix, self.prefix_cache, **_prefix_opts) diff --git a/serving/serving_jax/cross_host.py b/serving/serving_jax/cross_host.py new file mode 100644 index 0000000..7c59fce --- /dev/null +++ b/serving/serving_jax/cross_host.py @@ -0,0 +1,64 @@ +"""This file implements cross-host device_put in multi-process JAX - a temporary workaround until jax.device_put update.""" + +from functools import lru_cache +from typing import Any + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import numpy as np + +jax.config.update("jax_enable_empty_arrays", True) +PyTree = Any + + +@lru_cache +def combine_meshes(meshA, meshB): + if not (meshA.devices.shape == meshB.devices.shape and meshA.axis_names == meshB.axis_names): + raise ValueError("Meshes shapes and specs must match") + devices = np.stack([meshA.devices, meshB.devices], axis=0) + axis_names = ("cross_mesh",) + tuple(meshA.axis_names) + axis_types = tuple(meshA.axis_types)[:1] + tuple(meshA.axis_types) + return Mesh(devices, axis_names, axis_types=axis_types) + + +@jax.jit +def _prepare_arrays(xs: list[jax.Array]): + return jax.tree.map(lambda x: x[None, ...], xs) + + +@lru_cache +def _make_zeros(sds, shardings): + new_shardings = tuple(NamedSharding(sd.mesh, P(None, *sd.spec)) for sd in shardings) + new_sds = tuple(jax.ShapeDtypeStruct((1,) + sd.shape, sd.dtype) for sd in sds) + return jax.jit( + lambda: jax.tree.map(lambda s: jnp.zeros(s.shape, dtype=s.dtype), new_sds), out_shardings=new_shardings + )() + + +@jax.jit +def _combine(xs: list[jax.Array]): + return jax.tree.map(lambda x: jnp.sum(x, axis=0).astype(x.dtype), xs) + + +def transfer_tree_A2B(xs: PyTree, meshA, meshB): + if meshA == meshB: + return xs + meshC = combine_meshes(meshA, meshB) + xs, xs_struct = jax.tree.flatten(xs) + combined_sharding = [NamedSharding(meshC, P("cross_mesh", *x.sharding.spec)) for x in xs] + dest_sharding = [NamedSharding(meshB, x.sharding.spec) for x in xs] + dest_arrays = _make_zeros(tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in xs), tuple(dest_sharding)) + all_arrays = [x_src._arrays + x_dest._arrays for x_src, x_dest in zip(_prepare_arrays(xs), dest_arrays)] + xs_combined = [ + jax.make_array_from_single_device_arrays((2,) + x.shape, sharding, arrays, dtype=x.dtype) + for 
(x, arrays, sharding) in zip(xs, all_arrays, combined_sharding) + ] + xs_repl = _combine(xs_combined) # issue collectives under jit + xs_new = [ + jax.make_array_from_single_device_arrays( + x_src.shape, sharding, x_new._arrays[len(x_src._arrays) :], dtype=x_src.dtype + ) + for x_new, x_src, sharding, x_dest in zip(xs_repl, xs, dest_sharding, dest_arrays) + ] + return jax.tree.unflatten(xs_struct, xs_new) From f64f4152da84d927fe8c811b284b83b3643ac88b Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Thu, 24 Jul 2025 17:18:13 -0700 Subject: [PATCH 02/11] generalize attention utils --- llama3/llama3_jax/attention_cache_utils.py | 58 +++++++++++++--------- llama3/llama3_jax/model.py | 49 +++++++++--------- llama3/pyproject.toml | 2 + llama3/tests/test_model.py | 4 +- serving/main_serving.py | 19 ++++--- serving/serving_jax/__init__.py | 2 +- 6 files changed, 75 insertions(+), 59 deletions(-) diff --git a/llama3/llama3_jax/attention_cache_utils.py b/llama3/llama3_jax/attention_cache_utils.py index abfcbe0..5367cd1 100644 --- a/llama3/llama3_jax/attention_cache_utils.py +++ b/llama3/llama3_jax/attention_cache_utils.py @@ -18,6 +18,13 @@ _pad_after = lambda x, l, axis: jnp.pad(x, [(0, 0) if i != axis else (0, l - x.shape[i]) for i in range(x.ndim)]) +def safe_zip(*args): + if len(args) == 0: + return [] + assert all(len(arg) == len(args[0]) for arg in args) + return zip(*args) + + def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int): "From a list of cache entries stacked along layer idx (in transit) to stacked along batch, layers split into list." @@ -28,7 +35,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int): for i, c in enumerate(kv_list[0]): els = [[_split(z) for z in jax.tree.leaves(kv[i])] for kv in kv_list] # [B, R_flat, L] els = jax.tree.map(lambda *xs: jnp.concatenate(xs, axis=0), *els) # [R_flat, L] - leaves_list = list(zip(*els)) # [L, R_flat] + leaves_list = list(safe_zip(*els)) # [L, R_flat] out[i] = [jax.tree.unflatten(jax.tree.structure(c), leaves) for leaves in leaves_list] # [L, R] return tuple(out), max_seq_len @@ -41,7 +48,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int): @partial(jax.jit, donate_argnames=("cache",)) def _kvcache_update_cache( cache: KVCache, - kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]], + kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], actual_lens: list[jax.Array], update_mask: list[bool] | None = None, @@ -62,15 +69,17 @@ def _update_element(x, u): # update_permute = [batch_dim, time_dim] + update_permute return x.at[batch_idxs[:, None], :, time_indices, ...].set(u.transpose(update_permute), mode="drop") - cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs) + cache_kvs = jax.tree.map(_update_element, cache.buffers, kvs) cache_starts = cache.starts.at[batch_idxs].set(start_time, mode="drop") cache_iter = jnp.where(uninitialized_cache, jnp.max(actual_lens), cache.iter) - return dataclasses.replace(cache, k=cache_k, v=cache_v, iter=cache_iter, starts=cache_starts) + + buffer_names = [field.name for field in dataclasses.fields(cache)][:len(cache_kvs)] + return dataclasses.replace(cache, **dict(safe_zip(buffer_names, cache_kvs)), iter=cache_iter, starts=cache_starts) def kvcache_update_cache( cache: KVCache, - kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]], + kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], actual_lens: list[jax.Array], ): @@ -85,7 +94,7 @@ def 
kvcache_update_cache( def kvcache_get_entry(cache: KVCache, batch_idx: jax.Array): shift = -cache.starts[batch_idx] assert cache.time_axis > 0 - kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), (cache.k, cache.v)) + kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), cache.buffers) kvs = (jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[0]), jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[1])) true_len = cache.fill_len()[batch_idx] return kvs, true_len @@ -109,13 +118,13 @@ def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array | return jax.lax.top_k(free_pages, k)[1] -def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int): - key_heads = cache.k[layer_idx].shape[0] - assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1) +def _paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int): + #key_heads = cache.buffers[0][layer_idx].shape[0] + #assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1) # TODO write this generically needs_next_page = (cache.lengths % cache.page_size) == 0 page_table_idx = cache.lengths // cache.page_size current_page_cursor = jnp.take_along_axis(cache.block_tables, page_table_idx[:, None], axis=-1)[..., 0] - avg_pages_per_batch_entry = round(cache.k[layer_idx].shape[0] / cache.batch_size) + avg_pages_per_batch_entry = round(cache.buffers[0][layer_idx].shape[0] / cache.batch_size) even_batch_spread = jnp.arange(cache.batch_size) * avg_pages_per_batch_entry proposal_pages = jnp.where(cache.lengths == 0, even_batch_spread, current_page_cursor + 1) free_pages = _find_empty_pages(cache.free_pages, cache.batch_size, proposal_pages=proposal_pages) @@ -127,27 +136,28 @@ def _paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.A # for batch index update the target slice is (heads, i, j, head_dim) # so transpose update (batch, heads, seq, head_dim) -> (batch, heads, head_dim) -> (heads, batch, head_dim) _update = lambda dest, src: dest.at[:, page_cursor, inpage_cursor, ...].set(src.squeeze(2).swapaxes(0, 1)) - cache.k[layer_idx], cache.v[layer_idx] = jax.tree.map(_update, (cache.k[layer_idx], cache.v[layer_idx]), (k, v)) + for buffer, new_buffer in safe_zip(cache.buffers, kv): + buffer[layer_idx] = jax.tree.map(_update, buffer[layer_idx], new_buffer) batch_idx = jnp.arange(cache.batch_size) new_block_tables = cache.block_tables.at[batch_idx, new_lengths // cache.page_size].set(page_cursor) new_free_pages = cache.free_pages.at[page_cursor].set(False, mode="drop") new_state = dict(lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages) - return cache.k[layer_idx], cache.v[layer_idx], new_state + return tuple(buffer[layer_idx] for buffer in cache.buffers), new_state -def paged_update_slice(cache: PagedKVCache, k: jax.Array | QuantArray, v: jax.Array | QuantArray, *, layer_idx: int): +def paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int): repl_sharding = jax.typeof(cache.lengths).sharding - kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, (cache.k[layer_idx], cache.v[layer_idx])) - sharding = (*kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding)) - return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, k, v) + kv_sharding = jax.tree.map(lambda 
x: jax.typeof(x).sharding, tuple(buffer[layer_idx] for buffer in cache.buffers)) + sharding = (kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding)) + return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, kv) @partial(jax.jit, donate_argnames=("cache",)) def _batch_paged_update_sequences( cache: PagedKVCache, - kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]], + kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], actual_lens: list[jax.Array], update_mask: list[bool] | None = None, @@ -156,9 +166,7 @@ def _batch_paged_update_sequences( batch_idxs = jnp.where(update_mask, jnp.array(batch_idxs), 2**30) # send masked to nowhere actual_lens = jnp.minimum(jnp.array(actual_lens), jnp.array([jax.tree.leaves(kv)[0].shape[2] for kv in kvs])) - kvs, max_seq_len = _transpose_attention_tree( - kvs, time_axis=2 - ) # undo stacking along the layer dimension for transit + kvs, max_seq_len = _transpose_attention_tree(kvs, time_axis=2) # undo stack along layer dimension in transit # clear existing pages actual_page_num = jnp.rint(jnp.ceil(cache.lengths[batch_idxs] / cache.page_size)).astype(jnp.int32) @@ -186,21 +194,23 @@ def _update_element(x, u): update_permute = [1, 0, 2] + [i for i in range(u.ndim) if i not in (0, 1, 2)] return x.at[:, pages_idx, ...].set(u.transpose(update_permute), mode="drop") - cache_k, cache_v = jax.tree.map(_update_element, (cache.k, cache.v), kvs) + new_buffers = jax.tree.map(_update_element, cache.buffers, kvs) block_tables_idx = jnp.where( update_mask[:, None] & (pages_arange[None, :] < actual_page_num[:, None]), pages_arange[None, :], 2**30 ) new_block_tables = cache.block_tables.at[batch_idxs[:, None], block_tables_idx].set(pages_idx, mode="drop") new_free_pages = new_free_pages.at[pages_idx.reshape(-1)].set(False, mode="drop") new_lengths = cache.lengths.at[batch_idxs].set(actual_lens, mode="drop") + + named_buffers = dict(zip([field.name for field in dataclasses.fields(cache)][:len(new_buffers)], new_buffers)) return dataclasses.replace( - cache, k=cache_k, v=cache_v, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages + cache, **named_buffers, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages ) def batch_paged_update_sequences( cache: KVCache, - kvs: list[tuple[list[jax.Array | QuantArray], list[jax.Array | QuantArray]]], + kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], actual_lens: list[jax.Array], ): @@ -222,5 +232,5 @@ def batch_paged_get_entry(cache: PagedKVCache, batch_idx: jax.Array, max_seq_len _get = lambda x: jnp.where(mask[None, :, *([None] * (x.ndim - 3))], _reshape_out(x[:, page_indices, ...]), 0) # stack along layer dimensions for transit - kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, (cache.k, cache.v))) + kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in jax.tree.map(_get, cache.buffers)) return kvs, true_len diff --git a/llama3/llama3_jax/model.py b/llama3/llama3_jax/model.py index 3d07be9..2866156 100644 --- a/llama3/llama3_jax/model.py +++ b/llama3/llama3_jax/model.py @@ -37,6 +37,7 @@ except ModuleNotFoundError: from jax.sharding import auto_axes as _auto_axes, reshard from jax.experimental.pallas.ops.gpu import paged_attention +from etils import epath from . import ragged_attention from . 
import attention_cache_utils @@ -216,7 +217,7 @@ class ArrayInfo: _count_left_padding = lambda ids, pad_id=0: auto_axes( lambda ids: jnp.sum(jnp.cumsum(ids != pad_id, axis=-1) == 0, axis=-1), out_sharding=P(None) )(ids) -_length_minus_padding = lambda segment_ids: auto_axes( +_length_minus_right_padding = lambda segment_ids: auto_axes( lambda segment_ids: jnp.sum(jnp.cumsum(jnp.flip(segment_ids != 0, -1), axis=-1) > 0, -1), out_sharding=P(None) )(segment_ids) @@ -411,7 +412,7 @@ class KVCache(_Init): iter: jax.Array # [] # sequences are right-aligned for slice update performance starts: jax.Array # [batch_size] # sequences are right-aligned, we need start indices batch_size: int = 0 - size: int = 0 + size: int = 2 ** 30 time_axis: int = 2 @classmethod @@ -428,6 +429,7 @@ def abstract(cls, cfg: Config, batch_size: int): # -1 means unintialized since iter (cursor) must be 0 <= iter < len - 1 iter=ArrayInfo((), jnp.int32, (), jax.nn.initializers.constant(-1)), starts=ArrayInfo((batch_size,), jnp.int32, ("batch",), jax.nn.initializers.zeros), + size=cfg.max_seq_len, ) if cfg.quant_cache: _quantize = partial(quantize, axis=-1, scale_dtype=cfg.quant_scale_dtype, zero_init=True) @@ -447,8 +449,11 @@ def abstract(cls, cfg: Config, batch_size: int): return cache def fill_len(self) -> jax.Array: - length = jnp.where(self.iter > self.starts, self.iter - self.starts, self.size + self.iter - self.starts) - return jnp.where(self.iter >= 0, length, 0) + return jnp.where(self.iter >= 0, (self.iter - self.starts) % self.size, 0) + + @property + def buffers(self) -> tuple[jax.Array, ...]: + return (self.k, self.v) update_slice = None insert_sequences = staticmethod(attention_cache_utils.kvcache_update_cache) @@ -463,7 +468,7 @@ class PagedKVCache(_Init): block_tables: jax.Array # [batch_size, pages_per_seq] free_pages: jax.Array # [total_num_pages] batch_size: int = 0 - size: int = 2**31 - 1 + size: int = 2**30 page_size: int = 0 @classmethod @@ -501,6 +506,10 @@ def abstract(cls, cfg: "Config", batch_size: int, total_num_pages: int, page_siz def fill_len(self) -> jax.Array: return self.lengths + @property + def buffers(self) -> tuple[jax.Array, ...]: + return (self.k, self.v) + update_slice = staticmethod(attention_cache_utils.paged_update_slice) insert_sequences = staticmethod(attention_cache_utils.batch_paged_update_sequences) get_sequence = staticmethod(attention_cache_utils.batch_paged_get_entry) @@ -807,12 +816,9 @@ def attention_block( q, k = apply_rotary_embedding(q, sin, cos), apply_rotary_embedding(k, sin, cos) if cfg.quant_cache: - k = QuantArray( - *quantize(k, -1, scale_dtype=cfg.quant_scale_dtype), out_scaling=True, scale_expand_dims=(-2, -3) - ) - v = QuantArray( - *quantize(v, -1, scale_dtype=cfg.quant_scale_dtype), out_scaling=False, scale_expand_dims=(-2, -3) - ) + _quantize = partial(quantize, axis=-1, scale_dtype=cfg.quant_scale_dtype) + k = QuantArray(*_quantize(k), out_scaling=True, scale_expand_dims=(-2, -3)) + v = QuantArray(*_quantize(v), out_scaling=False, scale_expand_dims=(-2, -3)) with jax.named_scope("cache_update"): paged_state, starts = None, None @@ -825,23 +831,21 @@ def attention_block( ) % cache.size # [B, T] q_segment_ids = jnp.where(segment_ids != 0, 1, 0) - incremental_position = jnp.max(_length_minus_padding(segment_ids)) + incremental_position = jnp.max(_length_minus_right_padding(segment_ids)) # i.e. 
valid below where we've written things [B, T] - kv_segment_ids = ( - (time_indices >= 0) & (time_indices < cache.fill_len()[:, None] + incremental_position) - ).astype(jnp.int32) - q_offset = cache.fill_len() - _count_left_padding(segment_ids) + kv_segment_ids = (time_indices >= 0) & (time_indices < cache.fill_len()[:, None] + incremental_position) + q_offset = cache.fill_len() - _count_left_padding(segment_ids, 0) # 0 is the pad "token" for segment_ids starts, lengths = cache.starts, cache.fill_len() cache_updates = (k, v) elif is_type(cache, PagedKVCache): cache: PagedKVCache - k, v, paged_state = PagedKVCache.update_slice(cache, k=k, v=v, layer_idx=idx) + (k, v), paged_state = PagedKVCache.update_slice(cache, (k, v), layer_idx=idx) cache_updates = (k, v, paged_state) else: # this supports prefill only; no support for a ring cache buffer here q_segment_ids, kv_segment_ids = segment_ids, segment_ids q_offset = jnp.zeros(x.shape[0], dtype=jnp.int32) - starts, lengths = _count_left_padding(segment_ids, 0), _length_minus_padding(kv_segment_ids) + starts, lengths = _count_left_padding(segment_ids, 0), _length_minus_right_padding(kv_segment_ids) cache_updates = (k, v) # Compute attention @@ -931,15 +935,12 @@ def forward( x, cache_updates = forward_layer(x, segment_ids, layer, sin, cos, idx, cfg, cache) all_cache_updates.append(cache_updates) - # Final layer norm. - x = rms_norm(x, weights.gamma_final) - - # Project to vocabulary size - logits = einsum("btd,dv->btv", x, weights.lm_head) + x = rms_norm(x, weights.gamma_final) # Final layer norm. + logits = einsum("btd,dv->btv", x, weights.lm_head) # Project to vocabulary size if is_type(cache, KVCache): cache.k, cache.v = [z[0] for z in all_cache_updates], [z[1] for z in all_cache_updates] - new_iter = (jnp.maximum(0, cache.iter) + jnp.max(_length_minus_padding(segment_ids))) % cache.size + new_iter = (jnp.maximum(0, cache.iter) + jnp.max(_length_minus_right_padding(segment_ids))) % cache.size cache = dataclasses.replace(cache, iter=new_iter) return logits, cache elif is_type(cache, PagedKVCache): diff --git a/llama3/pyproject.toml b/llama3/pyproject.toml index 5fa6905..89e5261 100644 --- a/llama3/pyproject.toml +++ b/llama3/pyproject.toml @@ -19,6 +19,8 @@ dependencies = [ #"datasets", "gcsfs", "etils", + "importlib_resources", + "absl-py", ] # we don't need CUDA torch diff --git a/llama3/tests/test_model.py b/llama3/tests/test_model.py index c1f749e..51ad44e 100644 --- a/llama3/tests/test_model.py +++ b/llama3/tests/test_model.py @@ -54,7 +54,7 @@ def test_model_init(self, quant): @parameterized.product(quant=[False, True]) def test_cache_init(self, quant): cfg = dataclasses.replace(self.small_cfg, quant_cache=quant) - cache = l3jax.KVCache.init(random.key(0), cfg, 2, cfg.max_seq_len) + cache = l3jax.KVCache.init(random.key(0), cfg, 2) del cache @parameterized.product(quant_weights=[False, True], quant_cache=[True, False]) @@ -62,7 +62,7 @@ def test_prefill_decode(self, quant_weights, quant_cache): cfg = dataclasses.replace(self.small_cfg, quant_layer=quant_weights, quant_cache=quant_cache) tokens = jnp.ones((1, 32), dtype=jnp.int32) weights = l3jax.Weights.init(random.key(0), cfg) - cache = l3jax.KVCache.init(random.key(0), cfg, tokens.shape[0], cfg.max_seq_len) + cache = l3jax.KVCache.init(random.key(0), cfg, tokens.shape[0]) with use_mesh(cfg.mesh): max_tokens, _, cache = l3jax.prefill(tokens, weights, cache, cfg) next_tokens = max_tokens[:, :-1] diff --git a/serving/main_serving.py b/serving/main_serving.py index 1ee265c..ebaf1ca 100644 
--- a/serving/main_serving.py +++ b/serving/main_serving.py @@ -63,8 +63,8 @@ def _place_local(tree, sharding: NamedSharding, present: bool): def load_model(): global SERVE_LOOP, SERVING_THREAD, TOKENIZER - process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...) - jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx) + #process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...) + #jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx) print(jax.devices()) print("-" * 80) print(jax.local_devices()) @@ -78,8 +78,9 @@ def load_model(): # two hosts, different device and host meshes local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3) - decode_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3) - prefill_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[8:], axis_types=(AxisType.Explicit,) * 3) + decode_mesh, prefill_mesh = local_mesh, local_mesh + #decode_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3) + #prefill_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[8:], axis_types=(AxisType.Explicit,) * 3) # single host, same decode and prefill meshes #local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3) @@ -94,12 +95,14 @@ def load_model(): cfg = dataclasses.replace(cfg, mesh=decode_mesh, quant_layer=True, quant_cache=True) cfg = dataclasses.replace(cfg, use_prefill_attn_kernel=False, use_decode_attn_kernel=False, max_seq_len=8192) cfg = dataclasses.replace(cfg, quant_layer=False, quant_cache=False) + cfg.quant_cache = True weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=local_mesh))) # multi-host: until orbax update - decode_weights = _place_local(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh)), present=jax.process_index() == 0) - prefill_weights = _place_local(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh)), present=jax.process_index() == 1) + decode_weights, prefill_weights = weights, weights + #decode_weights = _place_local(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh)), present=jax.process_index() == 0) + #prefill_weights = _place_local(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh)), present=jax.process_index() == 1) # single-host: until orbax update #decode_weights = serving.device_put(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh))) @@ -108,7 +111,7 @@ def load_model(): print("---> Weights loaded") serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64) - # decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size) + #decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size) decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32) SERVE_LOOP = serving.ServingLoop( serve_cfg, cfg, l3jax.prefill, prefill_weights, l3jax.decode_step, decode_weights, decode_cache @@ -122,7 +125,7 @@ def serve_forever(): finally: print("Received a shutdown signal") time.sleep(0.1) - signal.raise_signal(signal.SIGINT) # shut down the web server + 
signal.raise_signal(signal.SIGKILL) # shut down the web server print("Exiting the serving loop") SERVING_THREAD = threading.Thread(target=serve_forever) diff --git a/serving/serving_jax/__init__.py b/serving/serving_jax/__init__.py index f17ccf8..3f16965 100644 --- a/serving/serving_jax/__init__.py +++ b/serving/serving_jax/__init__.py @@ -588,7 +588,7 @@ def decode_step(self): ] else: output_tokens, done = None, None - done = SyncServer.broadcast("done_sync", self._it, done, is_source="decode" in self.roles) + done = SyncServer.broadcast("done_sync", self._it, done, is_source="decode_coordinator" in self.roles) if "server" in self.roles: for token, id in zip(output_tokens.reshape(-1).tolist(), output_mapping.reshape(-1).tolist()): if id > 0: From 9eaf9710cb3f0119aa550e46e440eea7f6b027fe Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Fri, 25 Jul 2025 22:39:53 -0700 Subject: [PATCH 03/11] serving bug fixes, testing with deepseek --- deepseek_r1_jax/deepseek_r1_jax/model.py | 32 ++- llama3/llama3_jax/model.py | 63 +++++- serving/main_serving_ds_r1.py | 209 ++++++++++++++++++ serving/serving_jax/__init__.py | 48 ++-- .../serving_jax}/attention_cache_utils.py | 43 +--- 5 files changed, 323 insertions(+), 72 deletions(-) create mode 100644 serving/main_serving_ds_r1.py rename {llama3/llama3_jax => serving/serving_jax}/attention_cache_utils.py (79%) diff --git a/deepseek_r1_jax/deepseek_r1_jax/model.py b/deepseek_r1_jax/deepseek_r1_jax/model.py index 32c017c..93dc5e7 100644 --- a/deepseek_r1_jax/deepseek_r1_jax/model.py +++ b/deepseek_r1_jax/deepseek_r1_jax/model.py @@ -327,7 +327,8 @@ def quantize(x: jax.Array | ArrayInfo, axis: int | tuple[int, ...], scale_dtype= def quantize_update_slice(x: QuantArray, y: jax.Array, pos: int, update_axis: int, quant_axis: int): assert x.quant.ndim == y.ndim quant_axis, update_axis = quant_axis % x.quant.ndim, update_axis % x.quant.ndim # normalize axis numbers - y_quant, y_scale = quantize(y, axis=quant_axis, scale_dtype=x.scale.dtype) # quantize rhs + #y_quant, y_scale = quantize(y, axis=quant_axis, scale_dtype=x.scale.dtype) # quantize rhs + y_quant, y_scale = y.quant, y.scale scale_update_axis = [ax for ax in range(x.quant.ndim) if ax != quant_axis][update_axis] # update axis in `scale` z_quant = jax.lax.dynamic_update_slice_in_dim(x.quant, y_quant.astype(x.quant.dtype), pos, axis=update_axis) z_scale = jax.lax.dynamic_update_slice_in_dim(x.scale, y_scale.astype(x.scale.dtype), pos, axis=scale_update_axis) @@ -587,9 +588,9 @@ def abstract(cls, cfg: Config, batch_size: int, max_seq_len: int, dtype: int = j _init, ) k_pe_info = ArrayInfo( - (batch_size, max_seq_len, cfg.qk_rope_head_dim), + (batch_size, 1, max_seq_len, cfg.qk_rope_head_dim), dtype, - ("batch", "sequence", "head_dim"), + ("batch", None, "sequence", "head_dim"), _init, ) v_info = ArrayInfo( @@ -613,7 +614,7 @@ def abstract(cls, cfg: Config, batch_size: int, max_seq_len: int, dtype: int = j for k_nope in cache.k_nope ] cache.k_pe = [ - QuantArray(*_quantize(k_pe), out_scaling=True, scale_expand_dims=(-2, -3)) + QuantArray(*_quantize(k_pe), out_scaling=True, scale_expand_dims=-2) for k_pe in cache.k_pe ] cache.v = [ @@ -772,7 +773,8 @@ def attention( _, h, T, _ = k_nope.shape qk = einsum("bhtd,bhTd->bhtT", q_nope, k_nope) - qk = qk + einsum("bhtd,bTd->bhtT", q_pe, k_pe) + #qk = qk + einsum("bhtd,bTd->bhtT", q_pe, k_pe) + qk = qk + einsum("bhtd,b1Td->bhtT", q_pe, k_pe) qk = qk * scale # [b, h, t, T] mask = make_attention_mask(t, T, q_segment_ids, kv_segment_ids, q_offset, kv_offset, cfg.causal) 
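
The k_pe hunks above change the rope key from [batch, time, head_dim] to [batch, 1, time, head_dim]: with an explicit singleton head axis the shared positional key has the same rank as k_nope and v, letting the cache update in the following hunks use the same update_axis=cache.time_axis for every buffer. A minimal sketch of the score computation this converges on (shapes invented for illustration; plain jnp.einsum stands in for the model's einsum wrapper, so the size-1 head axis is indexed away instead of using the "b1Td" subscript):

```python
import jax
import jax.numpy as jnp

B, H, T, D_nope, D_rope = 2, 4, 8, 16, 8  # hypothetical sizes
ks = jax.random.split(jax.random.key(0), 4)
q_nope = jax.random.normal(ks[0], (B, H, T, D_nope))
k_nope = jax.random.normal(ks[1], (B, H, T, D_nope))
q_pe = jax.random.normal(ks[2], (B, H, T, D_rope))
k_pe = jax.random.normal(ks[3], (B, 1, T, D_rope))  # one rope key, explicit singleton head axis

qk = jnp.einsum("bhtd,bhTd->bhtT", q_nope, k_nope)
# equivalent of the "bhtd,b1Td->bhtT" contraction above: the single
# positional key broadcasts across all H query heads
qk = qk + jnp.einsum("bhtd,bTd->bhtT", q_pe, k_pe[:, 0])
assert qk.shape == (B, H, T, T)
```
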
@@ -871,18 +873,27 @@ def mla_attention_block( with jax.named_scope("kv_compressed_embed"): kv_compressed = einsum("btd,dr->btr", x, attn_layer.kv_a).astype(dtype) kv_compressed = rms_norm(kv_compressed, attn_layer.kv_gamma).astype(dtype) - k_pe = einsum("btd,dq->btq", x, attn_layer.k_pe) - k_pe = apply_rotary_embedding(k_pe[..., None, :, :], sin, cos)[..., 0, :, :].astype(dtype) + #k_pe = einsum("btd,dq->btq", x, attn_layer.k_pe) + #k_pe = apply_rotary_embedding(k_pe[..., None, :, :], sin, cos)[..., 0, :, :].astype(dtype) + k_pe = einsum("btd,dq->btq", x, attn_layer.k_pe)[..., None, :, :] + k_pe = apply_rotary_embedding(k_pe, sin, cos).astype(dtype) with jax.named_scope("kv_embed"): k_nope = einsum("btr,rhq->bhtq", kv_compressed, attn_layer.k_b) v = einsum("btr,rhv->bhtv", kv_compressed, attn_layer.v_b) + if cfg.quantize_cache: + _quantize = partial(quantize, axis=-1, scale_dtype=cfg.quant_scale_dtype) + k_nope = QuantArray(*_quantize(k_nope), out_scaling=True, scale_expand_dims=-2) + k_pe = QuantArray(*_quantize(k_pe), out_scaling=True, scale_expand_dims=-2) + v = QuantArray(*_quantize(v), out_scaling=False, scale_expand_dims=-2) + with jax.named_scope("full_cache_update"): if is_type(cache, KVCache): it = jnp.maximum(cache.iter, 0) k_nope = update_slice(cache.k_nope[idx], k_nope, it, update_axis=cache.time_axis) - k_pe = update_slice(cache.k_pe[idx], k_pe, it, update_axis=cache.time_axis - 1) + #k_pe = update_slice(cache.k_pe[idx], k_pe, it, update_axis=cache.time_axis - 1) + k_pe = update_slice(cache.k_pe[idx], k_pe, it, update_axis=cache.time_axis) v = update_slice(cache.v[idx], v, it, update_axis=cache.time_axis) cache_updates = (k_nope, k_pe, v) @@ -905,7 +916,8 @@ def mla_attention_block( lsc = partial(logical_sharding_constraint, mesh=cfg.mesh, rules=cfg.rules) spec = ("batch", "act_heads", "sequence", "head_dim") q_nope, q_pe = lsc(q_nope, spec), lsc(q_pe, spec) - k_nope, k_pe, v = lsc(k_nope, spec), lsc(k_pe, ("batch", "sequence", "head_dim")), lsc(v, spec) + #k_nope, k_pe, v = lsc(k_nope, spec), lsc(k_pe, ("batch", "sequence", "head_dim")), lsc(v, spec) + k_nope, k_pe, v = lsc(k_nope, spec), lsc(k_pe, ("batch", None, "sequence", "head_dim")), lsc(v, spec) # Compute attention with jax.named_scope("attention"): @@ -1251,7 +1263,7 @@ def prefill(tokens: jax.Array, weights: Weights, cache: KVCache, cfg: Config, pa uninitialized_iter = -jnp.ones_like(cache.iter) cache = dataclasses.replace(cache, starts=_count_left_padding(prompt, pad_id=pad_id), iter=uninitialized_iter) else: - cache_shardings = tuple([z[idx] for idx in range(cfg.num_layers)] for z in cache_shardings) + cache_shardings = [tuple(z[idx] for z in cache_shardings.buffers) for idx in range(cfg.num_layers)] logits_shardings = logical_to_sharding(("batch", "sequence", "act_embed"), cfg.mesh, cfg.rules) logits, cache = jax.jit(forward, donate_argnums=(4,), out_shardings=(logits_shardings, cache_shardings))( prompt, prompt_segment_ids, weights, cfg, cache diff --git a/llama3/llama3_jax/model.py b/llama3/llama3_jax/model.py index 2866156..61cceed 100644 --- a/llama3/llama3_jax/model.py +++ b/llama3/llama3_jax/model.py @@ -40,7 +40,6 @@ from etils import epath from . import ragged_attention -from . import attention_cache_utils AxisName = str | tuple[str, ...] | None Axes = tuple[AxisName, ...] 
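
Both models now build their quantized cache entries the same way: quantize along the last axis, keep a low-precision payload plus a per-slice scale in a QuantArray, and let k fold its scale into the attention scores (out_scaling=True) while v applies its scale on the output side. A rough sketch of the per-axis symmetric int8 scheme this pattern suggests; the repository's actual quantize and QuantArray may differ in detail:

```python
import jax
import jax.numpy as jnp

def quantize_int8(x: jax.Array, axis: int = -1, scale_dtype=jnp.float32):
    # symmetric per-slice int8 quantization along `axis`; an illustrative
    # guess at what model.quantize computes, not its actual implementation
    amax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
    scale = (jnp.maximum(amax, 1e-6) / 127.0).astype(scale_dtype)
    quant = jnp.round(x / scale.astype(x.dtype)).astype(jnp.int8)
    return quant, scale

k = jnp.ones((2, 4, 8, 16), jnp.bfloat16)  # [batch, heads, time, head_dim]
k_quant, k_scale = quantize_int8(k)        # int8 payload + [2, 4, 8, 1] scale
# out_scaling=True style: contract q against the int8 payload first and
# multiply the resulting scores by the (broadcast) key scale afterwards
```
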
@@ -455,9 +454,9 @@ def fill_len(self) -> jax.Array:
     def buffers(self) -> tuple[jax.Array, ...]:
         return (self.k, self.v)
 
-    update_slice = None
-    insert_sequences = staticmethod(attention_cache_utils.kvcache_update_cache)
-    get_sequence = staticmethod(attention_cache_utils.kvcache_get_entry)
+    #update_slice = None
+    #insert_sequences = staticmethod(attention_cache_utils.kvcache_update_cache)
+    #get_sequence = staticmethod(attention_cache_utils.kvcache_get_entry)
 
 
 @partial(jax_pytree_struct, meta_fields=("batch_size", "size", "page_size"))
@@ -510,9 +509,59 @@ def fill_len(self) -> jax.Array:
     def buffers(self) -> tuple[jax.Array, ...]:
         return (self.k, self.v)
 
-    update_slice = staticmethod(attention_cache_utils.paged_update_slice)
-    insert_sequences = staticmethod(attention_cache_utils.batch_paged_update_sequences)
-    get_sequence = staticmethod(attention_cache_utils.batch_paged_get_entry)
+    #update_slice = staticmethod(paged_update_slice)
+    #insert_sequences = staticmethod(attention_cache_utils.batch_paged_update_sequences)
+    #get_sequence = staticmethod(attention_cache_utils.batch_paged_get_entry)
+
+    @staticmethod
+    def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array | None = None):
+        if proposal_pages is not None:
+            assert proposal_pages.size == k
+            proposal_mask = free_pages[proposal_pages]
+            indices = jnp.where(~proposal_mask, jnp.cumsum(~proposal_mask, axis=-1) - 1, k - 1)
+            newly_free_pages = free_pages.at[jnp.where(proposal_mask, proposal_pages, 2**30)].set(False, mode="drop")
+            return jnp.where(proposal_mask, proposal_pages, jax.lax.top_k(newly_free_pages, k)[1][indices])
+        else:
+            return jax.lax.top_k(free_pages, k)[1]
+
+
+    @staticmethod
+    def _paged_update_slice(cache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int):
+        #key_heads = cache.buffers[0][layer_idx].shape[0]
+        #assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1) # TODO write this generically
+        needs_next_page = (cache.lengths % cache.page_size) == 0
+        page_table_idx = cache.lengths // cache.page_size
+        current_page_cursor = jnp.take_along_axis(cache.block_tables, page_table_idx[:, None], axis=-1)[..., 0]
+        avg_pages_per_batch_entry = round(cache.buffers[0][layer_idx].shape[0] / cache.batch_size)
+        even_batch_spread = jnp.arange(cache.batch_size) * avg_pages_per_batch_entry
+        proposal_pages = jnp.where(cache.lengths == 0, even_batch_spread, current_page_cursor + 1)
+        free_pages = PagedKVCache._find_empty_pages(cache.free_pages, cache.batch_size, proposal_pages=proposal_pages)
+        page_cursor = jnp.where(needs_next_page, free_pages, current_page_cursor)
+
+        inpage_cursor = cache.lengths % cache.page_size
+
+        new_lengths = cache.lengths + 1
+        # for batch index update the target slice is (heads, i, j, head_dim)
+        # so transpose update (batch, heads, seq, head_dim) -> (batch, heads, head_dim) -> (heads, batch, head_dim)
+        _update = lambda dest, src: dest.at[:, page_cursor, inpage_cursor, ...].set(src.squeeze(2).swapaxes(0, 1))
+        for buffer, new_buffer in zip(cache.buffers, kv):
+            buffer[layer_idx] = jax.tree.map(_update, buffer[layer_idx], new_buffer)
+
+        batch_idx = jnp.arange(cache.batch_size)
+        new_block_tables = cache.block_tables.at[batch_idx, new_lengths // cache.page_size].set(page_cursor)
+
+        new_free_pages = cache.free_pages.at[page_cursor].set(False, mode="drop")
+        new_state = dict(lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages)
+        return tuple(buffer[layer_idx] for buffer in cache.buffers), new_state
+
+
+    @staticmethod
+    def
update_slice(cache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int): + repl_sharding = jax.typeof(cache.lengths).sharding + kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, tuple(buffer[layer_idx] for buffer in cache.buffers)) + sharding = (kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding)) + return auto_axes(partial(PagedKVCache._paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, kv) + def segment_ids_to_positions(segment_ids): diff --git a/serving/main_serving_ds_r1.py b/serving/main_serving_ds_r1.py new file mode 100644 index 0000000..340a95d --- /dev/null +++ b/serving/main_serving_ds_r1.py @@ -0,0 +1,209 @@ +import sys +import dataclasses +import time +from pathlib import Path +import threading +import asyncio +import socket +import signal +import time +from typing import AsyncGenerator +from contextlib import asynccontextmanager +import os + +import jax +from jax import random +from jax.sharding import PartitionSpec as P, AxisType, NamedSharding, auto_axes +#from llama3_jax import model as l3jax +from deepseek_r1_jax import model as dsjax +from deepseek_r1_jax import chkpt_utils as dsjax_utils +import serving_jax as serving +from serving_jax import attention_cache_utils +import numpy as np + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse, Response +from pydantic import BaseModel +import uvicorn + +IS_SERVER = len(sys.argv) > 1 and sys.argv[1] == "server" +TOKENIZER, SERVE_LOOP, SERVING_THREAD = None, None, None + +jax.config.update("jax_explain_cache_misses", True) +#jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) +#jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0) +#jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0) +jax.config.update("jax_enable_empty_arrays", True) + +try: # newer JAX only + assert False + my_id = int(socket.gethostname().split("-")[-1]) - 1 + my_ip = socket.getaddrinfo(socket.gethostname(), 80)[0][-1][0] + jax.config.update("jax_cross_host_transfer_socket_address", f"{my_ip}:{17007 + my_id}") + jax.config.update("jax_cross_host_transport_addresses", ",".join([f"{my_ip}:0"] * 8)) +except: # noqa: E722 + pass + +shutdown_signal = threading.Event() + + +def encode_input(tokenizer, texts, pad_id: int = 0): + assert isinstance(texts, list) + inputs = [ + tokenizer.apply_chat_template([{"role": "user", "content": text}], add_generation_prompt=True) for text in texts + ] + max_len = max([len(x) for x in inputs]) + return np.array([(max_len - len(x)) * [pad_id] + x for x in inputs]) + + +def _place_local(tree, sharding: NamedSharding, present: bool): + return jax.tree.map( + lambda z, s: jax.make_array_from_single_device_arrays( + z.shape, s, [] if not present else [y.data for y in z.addressable_shards], dtype=z.dtype + ), + tree, + sharding, + ) + + +def load_model(): + global SERVE_LOOP, SERVING_THREAD, TOKENIZER + jax.distributed.initialize() + print(jax.devices()) + print("-" * 80) + print(jax.local_devices()) + + ckpt_path = Path(f"~/bucket/deepseek-r1-jax-chkpt").expanduser() + TOKENIZER = dsjax.load_tokenizer() + assert ckpt_path.is_dir() + print("---> Model config loaded") + + mesh = jax.make_mesh((1, 8, jax.device_count() // 8), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Auto,) * 3) + decode_mesh, prefill_mesh = mesh, mesh + cfg = dataclasses.replace(dsjax.Config(), mesh=mesh) + weights = dsjax_utils.load_model(ckpt_path, cfg) + decode_weights, 
prefill_weights = weights, weights + + print("---> Weights loaded") + serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64, + decode_batch_size=8, prefill_batch_size=1, prefix_chunk_size=64) + decode_cache = dsjax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, cfg.max_seq_len) + decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry + decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache + SERVE_LOOP = serving.ServingLoop( + serve_cfg, cfg, dsjax.prefill, prefill_weights, dsjax.decode_step, decode_weights, decode_cache, is_server=IS_SERVER + ) + print("---> Created the serving loop") + + def serve_forever(): + try: + while not shutdown_signal.is_set(): + SERVE_LOOP.serving_step() + except Exception as e: + import traceback + print(traceback.format_exc(), flush=True) + print(f"Exception {e}", flush=True) + finally: + print("Received a shutdown signal") + time.sleep(0.1) + signal.raise_signal(signal.SIGKILL) # shut down the web server + print("Exiting the serving loop") + + SERVING_THREAD = threading.Thread(target=serve_forever) + SERVING_THREAD.start() + + +######################################################################################################################## + + +@asynccontextmanager +async def lifespan(app: FastAPI): + yield + shutdown_signal.set() + + +_ = load_model() +APP = FastAPI(lifespan=lifespan) + + +class GenerateRequest(BaseModel): + id: int + text: str + + +#async def generate_generator(params: GenerateRequest, request: Request) -> AsyncGenerator[str, None]: +async def generate_generator(id: int, text: str, request: Request) -> AsyncGenerator[str, None]: + if id in SERVE_LOOP.results: + del SERVE_LOOP.results[id] + + input = encode_input(TOKENIZER, [text])[0].tolist() + iter = len(input) + SERVE_LOOP.add_request(serving.UserRequestPrompt(id, input)) + while id not in SERVE_LOOP.results: + await asyncio.sleep(0.1) + try: + result: serving.DecodeResult = SERVE_LOOP.results[id] + while not result.done: + if await request.is_disconnected(): # Check if client disconnected + print("Client disconnected.") + break + if len(result.token_list) > iter: + new_segment, iter = TOKENIZER.decode(result.token_list[iter:]), len(result.token_list) + yield f"{new_segment}" + await asyncio.sleep(0.1) # Stream a new message every 1 second + if len(result.token_list) > iter: + new_segment, iter = TOKENIZER.decode(result.token_list[iter:]), len(result.token_list) + yield f"{new_segment}" + except asyncio.CancelledError: + pass + finally: + pass + + +@APP.get("/stream") +async def stream_response(params: GenerateRequest, request: Request): + return StreamingResponse(generate_generator(params.id, params.text, request), media_type="text/event-stream") + + +@APP.get("/generate") +async def generate(id: int, text: str): # generate without output + print(f"Input text: {text}") + SERVE_LOOP.add_request(serving.UserRequestPrompt(id, encode_input(TOKENIZER, [text])[0].tolist())) + return Response("OK") + + +@APP.get("/retrieve") +async def retrieve(id: int): + if id in SERVE_LOOP.results: + return Response(TOKENIZER.decode(SERVE_LOOP.results[id].token_list)) + return Response("NO TEXT") + + +@APP.get("/set_generation_length") +async def set_generation_length(length: int): + SERVE_LOOP.update_params({"max_decode_length": max(length, 32)}) + return Response("OK") + + +@APP.get("/profile") +async def profile(request: Request): + del request + SERVE_LOOP.profile_start_time = time.perf_counter() + return Response("OK") + + 
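
For reference, a hypothetical client session against the routes defined above, assuming the uvicorn server from the `__main__` block below is reachable on localhost:8081. /generate and /retrieve take plain query parameters; /stream binds a Pydantic model, which FastAPI reads from the request body, so the polling endpoints are the simpler ones to script:

```python
import time
import requests  # assumed available on the client machine

BASE = "http://localhost:8081"

requests.get(f"{BASE}/set_generation_length", params={"length": 128})
requests.get(f"{BASE}/generate", params={"id": 1, "text": "What is a KV cache?"})

text = "NO TEXT"
while text == "NO TEXT":  # poll until the serving loop has produced tokens
    time.sleep(0.5)
    text = requests.get(f"{BASE}/retrieve", params={"id": 1}).text
print(text)
```
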
+@APP.get("/")
+async def root():
+    return {"message": "Welcome! Try the /stream-text endpoint."}
+
+
+if __name__ == "__main__":
+    if IS_SERVER:
+        print(f"jax.process_idx() == {jax.process_index()}")
+        uvicorn.run(APP, host="0.0.0.0", port=8081, reload=False, server_header=False)
+    else:
+        try:
+            while True:
+                time.sleep(1)
+        except KeyboardInterrupt:
+            shutdown_signal.set()
diff --git a/serving/serving_jax/__init__.py b/serving/serving_jax/__init__.py
index 3f16965..474e8bf 100644
--- a/serving/serving_jax/__init__.py
+++ b/serving/serving_jax/__init__.py
@@ -397,7 +397,7 @@ class SyncServer:
     """A regular local network server for syncing between JAX processes in the multi-process JAX setup."""
 
     CLIENT = None
-    TIMEOUT_SEC = 60
+    TIMEOUT_SEC = 600
 
     @staticmethod
     def _get_client():
@@ -464,6 +464,7 @@ def __init__(
         decode_fn: Callable,
         decode_weights: Weights,
         decode_cache: KVCache,
+        is_server: bool = False,
     ):
         self.serve_cfg, self.cfg = serve_cfg, cfg
 
@@ -503,6 +504,7 @@ def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens,
 
         # setup misc
         self.pending_requests, self.requests_lock, self.results = [], threading.Lock(), {}
+        self.params_lock = threading.Lock()
         self.pad_id, self.eos_tokens, self.time_axis = 0, np.array(serve_cfg.eos_tokens), TIME_AXIS
         self._background = ThreadPoolExecutor(max_workers=1024)
 
@@ -515,7 +517,7 @@ def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens,
         self.new_prefix_cache()
 
         # setup the sync server for multi-host
-        self._it, self.roles = 0, (("server",) if jax.process_index() == 0 else ())  # main server
+        self._it, self.roles = 0, (("server",) if is_server else ())  # main server
         if any(d.id in [d_.id for d_ in self.decode_mesh.devices.reshape(-1)] for d in jax.local_devices()):
             self.roles += ("decode",)  # any node which has decode mesh devices
         if any(d.id in [d_.id for d_ in self.prefill_mesh.devices.reshape(-1)] for d in jax.local_devices()):
@@ -586,14 +588,21 @@ def decode_step(self):
                 d or getattr(result, "tokens_decoded", 0) >= self.serve_cfg.max_decode_length
                 for d, result in zip(done, self.decode_work.active_results)
             ]
+            output_tokens_flat = output_tokens.reshape(-1).tolist()
+            output_mapping_flat = output_mapping.reshape(-1).tolist()
         else:
-            output_tokens, done = None, None
-        done = SyncServer.broadcast("done_sync", self._it, done, is_source="decode_coordinator" in self.roles)
-        if "server" in self.roles:
-            for token, id in zip(output_tokens.reshape(-1).tolist(), output_mapping.reshape(-1).tolist()):
-                if id > 0:
-                    self.results[id].token_list.append(token)
-                    self.results[id].tokens_decoded += 1
+            output_tokens, done, output_tokens_flat, output_mapping_flat = None, None, None, None
+        output_tokens_flat, output_mapping_flat, done = SyncServer.broadcast(
+            "decode_output",
+            self._it,
+            (output_tokens_flat, output_mapping_flat, done),
+            is_source="decode_coordinator" in self.roles,
+        )
+        #if "server" in self.roles or "decode_coordinator" in self.roles:
+        for token, id in zip(output_tokens_flat, output_mapping_flat):
+            if id > 0:
+                self.results[id].token_list.append(token)
+                self.results[id].tokens_decoded += 1
         with use_mesh(self.decode_mesh):
             for i, result in enumerate(self.decode_work.active_results):
                 if result is None:
@@ -604,7 +613,8 @@ def decode_step(self):
                 sequence = np.array(result.token_list)
                 with use_mesh(self.decode_mesh):
                     cache_entry, _ = self._get_cache_entry(self.decode_work.cache, i)
-                self._background.submit(self._insert_prefix, sequence, cache_entry,
mesh=self.decode_mesh) + # self._background.submit(self._insert_prefix, sequence, cache_entry, mesh=self.decode_mesh) + self._insert_prefix(sequence, cache_entry, mesh=self.decode_mesh) return_request(result) result.done, self.decode_work.active_results[i] = True, None @@ -624,8 +634,11 @@ def prefill_step(self): # 2. triage requests queue into cached (-> decode) and not-cached work (-> prefill) new_pending_retrievals = [] - for request, cache_entry_fut in self.prefill_work.pending_cache_retrievals: - if len(self.prefill_work.to_decode) < self.serve_cfg.decode_batch_size and cache_entry_fut.done(): + done_mask = [cache_entry_fut.done() for (_, cache_entry_fut) in self.prefill_work.pending_cache_retrievals] + done_mask = SyncServer.broadcast("retrievals_done", self._it, done_mask, is_source="prefill_coordinator" in self.roles) + for i, (request, cache_entry_fut) in enumerate(self.prefill_work.pending_cache_retrievals): + #if len(self.prefill_work.to_decode) < self.serve_cfg.decode_batch_size and cache_entry_fut.done(): + if len(self.prefill_work.to_decode) < self.serve_cfg.decode_batch_size and done_mask[i]: with use_mesh(self.decode_mesh): # batch missing (-1) layers concatenated (+1) cache_entry = partial(_concat, cache_entry_fut.result(), self.time_axis - 1 + 1) # jit work future @@ -641,6 +654,7 @@ def prefill_step(self): retrieval_results = self._background.map( lambda request: (self._retrieve_prefix(np.array(request.text[:-1])), request), self.prefill_work.requests ) + #retrieval_results = [[(None, -100), request] for request in self.prefill_work.requests] for (cache_entry_fut, length), request in retrieval_results: if length == len(request.text) - 1: self.prefill_work.pending_cache_retrievals.append((request, cache_entry_fut)) @@ -702,7 +716,11 @@ def serving_step(self): if "server" in self.roles: with self.requests_lock: self.pending_requests, requests = [], list(self.pending_requests) - requests = SyncServer.broadcast("requests", self._it, requests, is_source="server" in self.roles) + with self.params_lock: + serve_cfg, requests = SyncServer.broadcast( + "requests", self._it, (dataclasses.asdict(self.serve_cfg), requests), is_source="server" in self.roles + ) + self.serve_cfg = dataclasses.replace(self.serve_cfg, **serve_cfg) for request in requests: self.total_requests += 1 self.prefill_work.requests.append(UserRequestPrompt(**request)) @@ -723,6 +741,10 @@ def add_request(self, request: UserRequestPrompt): with self.requests_lock: self.pending_requests.append(dataclasses.asdict(request)) + def update_params(self, params: dict[str, Any]): + with self.params_lock: + self.serve_cfg = dataclasses.replace(self.serve_cfg, **params) + def new_prefix_cache(self): self.prefix_cache = TrieNode(TrieNode.new_id(), None, None, lock=threading.Lock()) _prefix_opts = dict(chunk_size=self.serve_cfg.prefix_chunk_size) diff --git a/llama3/llama3_jax/attention_cache_utils.py b/serving/serving_jax/attention_cache_utils.py similarity index 79% rename from llama3/llama3_jax/attention_cache_utils.py rename to serving/serving_jax/attention_cache_utils.py index 5367cd1..854be86 100644 --- a/llama3/llama3_jax/attention_cache_utils.py +++ b/serving/serving_jax/attention_cache_utils.py @@ -6,11 +6,6 @@ import jax import jax.numpy as jnp -try: - from jax.experimental.shard import auto_axes -except ModuleNotFoundError: - from jax.sharding import auto_axes - QuantArray, PyTree = Any, Any KVCache = Any @@ -95,7 +90,7 @@ def kvcache_get_entry(cache: KVCache, batch_idx: jax.Array): shift = -cache.starts[batch_idx] 
assert cache.time_axis > 0 kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), cache.buffers) - kvs = (jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[0]), jax.tree.map(lambda *xs: jnp.stack(xs, 0), kvs[1])) + kvs = tuple(jax.tree.map(lambda *xs: jnp.stack(xs, 0), *z) for z in kvs) true_len = cache.fill_len()[batch_idx] return kvs, true_len @@ -118,42 +113,6 @@ def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array | return jax.lax.top_k(free_pages, k)[1] -def _paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int): - #key_heads = cache.buffers[0][layer_idx].shape[0] - #assert v.shape[:-1] == k.shape[:-1] == (cache.batch_size, key_heads, 1) # TODO write this generically - needs_next_page = (cache.lengths % cache.page_size) == 0 - page_table_idx = cache.lengths // cache.page_size - current_page_cursor = jnp.take_along_axis(cache.block_tables, page_table_idx[:, None], axis=-1)[..., 0] - avg_pages_per_batch_entry = round(cache.buffers[0][layer_idx].shape[0] / cache.batch_size) - even_batch_spread = jnp.arange(cache.batch_size) * avg_pages_per_batch_entry - proposal_pages = jnp.where(cache.lengths == 0, even_batch_spread, current_page_cursor + 1) - free_pages = _find_empty_pages(cache.free_pages, cache.batch_size, proposal_pages=proposal_pages) - page_cursor = jnp.where(needs_next_page, free_pages, current_page_cursor) - - inpage_cursor = cache.lengths % cache.page_size - - new_lengths = cache.lengths + 1 - # for batch index update the target slice is (heads, i, j, head_dim) - # so transpose update (batch, heads, seq, head_dim) -> (batch, heads, head_dim) -> (heads, batch, head_dim) - _update = lambda dest, src: dest.at[:, page_cursor, inpage_cursor, ...].set(src.squeeze(2).swapaxes(0, 1)) - for buffer, new_buffer in safe_zip(cache.buffers, kv): - buffer[layer_idx] = jax.tree.map(_update, buffer[layer_idx], new_buffer) - - batch_idx = jnp.arange(cache.batch_size) - new_block_tables = cache.block_tables.at[batch_idx, new_lengths // cache.page_size].set(page_cursor) - - new_free_pages = cache.free_pages.at[page_cursor].set(False, mode="drop") - new_state = dict(lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages) - return tuple(buffer[layer_idx] for buffer in cache.buffers), new_state - - -def paged_update_slice(cache: PagedKVCache, kv: tuple[jax.Array | QuantArray, ...], *, layer_idx: int): - repl_sharding = jax.typeof(cache.lengths).sharding - kv_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, tuple(buffer[layer_idx] for buffer in cache.buffers)) - sharding = (kv_sharding, dict(lengths=repl_sharding, block_tables=repl_sharding, free_pages=repl_sharding)) - return auto_axes(partial(_paged_update_slice, layer_idx=layer_idx), out_sharding=sharding)(cache, kv) - - @partial(jax.jit, donate_argnames=("cache",)) def _batch_paged_update_sequences( cache: PagedKVCache, From 499497427769f670d3a04a3c4387fa59002ae47f Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 29 Jul 2025 18:56:25 -0700 Subject: [PATCH 04/11] Improved prefix cache with host offloading Deepseek and llama3 should work with serving now --- serving/client_demo.py | 2 +- serving/main_serving.py | 65 ++--- serving/main_serving_ds_r1.py | 43 +-- serving/serving_jax/__init__.py | 480 ++++++++++++++++---------------- 4 files changed, 279 insertions(+), 311 deletions(-) diff --git a/serving/client_demo.py b/serving/client_demo.py index 431702b..c83ad6c 100644 --- a/serving/client_demo.py +++ 
b/serving/client_demo.py @@ -121,7 +121,7 @@ def main(): global responses, MAX_PANEL_LINES, responses_lock, responses_done all_prompts = get_prompts() prompts_num = 18 - idxs = np.random.randint(0, len(all_prompts) - 1, prompts_num) + idxs = np.random.randint(0, len(all_prompts), prompts_num) PROMPTS = [all_prompts[idx] for idx in idxs] # This controls the "scrolling" effect. It's the max number of lines diff --git a/serving/main_serving.py b/serving/main_serving.py index ebaf1ca..e79c543 100644 --- a/serving/main_serving.py +++ b/serving/main_serving.py @@ -9,29 +9,30 @@ from typing import AsyncGenerator from contextlib import asynccontextmanager import os +from argparse import ArgumentParser import jax from jax import random from jax.sharding import PartitionSpec as P, AxisType, NamedSharding -from llama3_jax import model as l3jax -import serving_jax as serving import numpy as np - from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse, Response from pydantic import BaseModel import uvicorn +from llama3_jax import model as l3jax +import serving_jax as serving +from serving_jax import attention_cache_utils + -TOKENIZER, SERVE_LOOP, SERVING_THREAD = None, None, None +TOKENIZER, SERVE_LOOP, SERVING_THREAD, ARGS = None, None, None, None jax.config.update("jax_explain_cache_misses", True) jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) -jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0) -jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0) jax.config.update("jax_enable_empty_arrays", True) try: # newer JAX only + assert False my_id = int(socket.gethostname().split("-")[-1]) - 1 my_ip = socket.getaddrinfo(socket.gethostname(), 80)[0][-1][0] jax.config.update("jax_cross_host_transfer_socket_address", f"{my_ip}:{17007 + my_id}") @@ -41,7 +42,6 @@ shutdown_signal = threading.Event() - def encode_input(tokenizer, texts, pad_id: int = 0): assert isinstance(texts, list) inputs = [ @@ -51,18 +51,13 @@ def encode_input(tokenizer, texts, pad_id: int = 0): return np.array([(max_len - len(x)) * [pad_id] + x for x in inputs]) -def _place_local(tree, sharding: NamedSharding, present: bool): - return jax.tree.map( - lambda z, s: jax.make_array_from_single_device_arrays( - z.shape, s, [] if not present else [y.data for y in z.addressable_shards], dtype=z.dtype - ), - tree, - sharding, - ) +def load_model(): + global SERVE_LOOP, SERVING_THREAD, TOKENIZER, ARGS + parser = ArgumentParser() + parser.add_argument("--server", action="store_true", help="Make this node the main server.", default=False) + ARGS = parser.parse_args() -def load_model(): - global SERVE_LOOP, SERVING_THREAD, TOKENIZER #process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...) 
#jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx) print(jax.devices()) @@ -79,42 +74,25 @@ def load_model(): # two hosts, different device and host meshes local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3) decode_mesh, prefill_mesh = local_mesh, local_mesh - #decode_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3) - #prefill_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[8:], axis_types=(AxisType.Explicit,) * 3) - - # single host, same decode and prefill meshes - #local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3) - #decode_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3) - #prefill_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3) - - # single host, separate decode and prefill meshes - #local_mesh = jax.make_mesh((1, 4, 1), P("x", "y", "z"), devices=jax.local_devices()[:4], axis_types=(AxisType.Explicit,) * 3) - #decode_mesh = jax.make_mesh((1, 4, 1), P("x", "y", "z"), devices=jax.devices()[:4], axis_types=(AxisType.Explicit,) * 3) - #prefill_mesh = jax.make_mesh((1, 4, 1), P("x", "y", "z"), devices=jax.devices()[4:], axis_types=(AxisType.Explicit,) * 3) - cfg = dataclasses.replace(cfg, mesh=decode_mesh, quant_layer=True, quant_cache=True) cfg = dataclasses.replace(cfg, use_prefill_attn_kernel=False, use_decode_attn_kernel=False, max_seq_len=8192) cfg = dataclasses.replace(cfg, quant_layer=False, quant_cache=False) cfg.quant_cache = True - weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=local_mesh))) - - # multi-host: until orbax update - decode_weights, prefill_weights = weights, weights - #decode_weights = _place_local(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh)), present=jax.process_index() == 0) - #prefill_weights = _place_local(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh)), present=jax.process_index() == 1) - - # single-host: until orbax update - #decode_weights = serving.device_put(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh))) - #prefill_weights = serving.device_put(weights, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh))) + decode_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh))) + prefill_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh))) print("---> Weights loaded") serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64) #decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size) + #decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry + #decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32) + decode_cache.get_sequence = attention_cache_utils.batch_paged_get_entry + decode_cache.insert_sequences = attention_cache_utils.batch_paged_update_sequences SERVE_LOOP = serving.ServingLoop( - serve_cfg, cfg, l3jax.prefill, prefill_weights, l3jax.decode_step, decode_weights, decode_cache + serve_cfg, cfg, l3jax.prefill, prefill_weights, l3jax.decode_step, 
decode_weights, decode_cache, ARGS.server ) print("---> Created the serving loop") @@ -122,6 +100,9 @@ def serve_forever(): try: while not shutdown_signal.is_set(): SERVE_LOOP.serving_step() + except: # noqa: E722 + import traceback + print(traceback.format_exc(), flush=True) finally: print("Received a shutdown signal") time.sleep(0.1) @@ -217,7 +198,7 @@ async def root(): if __name__ == "__main__": - if jax.process_index() == 0: + if ARGS.server: uvicorn.run(APP, host="0.0.0.0", port=8081, reload=False, server_header=False) else: try: diff --git a/serving/main_serving_ds_r1.py b/serving/main_serving_ds_r1.py index 340a95d..663fece 100644 --- a/serving/main_serving_ds_r1.py +++ b/serving/main_serving_ds_r1.py @@ -9,12 +9,11 @@ import time from typing import AsyncGenerator from contextlib import asynccontextmanager -import os +from argparse import ArgumentParser import jax from jax import random -from jax.sharding import PartitionSpec as P, AxisType, NamedSharding, auto_axes -#from llama3_jax import model as l3jax +from jax.sharding import PartitionSpec as P, AxisType from deepseek_r1_jax import model as dsjax from deepseek_r1_jax import chkpt_utils as dsjax_utils import serving_jax as serving @@ -26,27 +25,16 @@ from pydantic import BaseModel import uvicorn -IS_SERVER = len(sys.argv) > 1 and sys.argv[1] == "server" -TOKENIZER, SERVE_LOOP, SERVING_THREAD = None, None, None +TOKENIZER, SERVE_LOOP, SERVING_THREAD, ARGS = None, None, None, None jax.config.update("jax_explain_cache_misses", True) -#jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) +jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) #jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0) #jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0) jax.config.update("jax_enable_empty_arrays", True) -try: # newer JAX only - assert False - my_id = int(socket.gethostname().split("-")[-1]) - 1 - my_ip = socket.getaddrinfo(socket.gethostname(), 80)[0][-1][0] - jax.config.update("jax_cross_host_transfer_socket_address", f"{my_ip}:{17007 + my_id}") - jax.config.update("jax_cross_host_transport_addresses", ",".join([f"{my_ip}:0"] * 8)) -except: # noqa: E722 - pass - shutdown_signal = threading.Event() - def encode_input(tokenizer, texts, pad_id: int = 0): assert isinstance(texts, list) inputs = [ @@ -55,19 +43,13 @@ def encode_input(tokenizer, texts, pad_id: int = 0): max_len = max([len(x) for x in inputs]) return np.array([(max_len - len(x)) * [pad_id] + x for x in inputs]) +def load_model(): + global SERVE_LOOP, SERVING_THREAD, TOKENIZER, ARGS -def _place_local(tree, sharding: NamedSharding, present: bool): - return jax.tree.map( - lambda z, s: jax.make_array_from_single_device_arrays( - z.shape, s, [] if not present else [y.data for y in z.addressable_shards], dtype=z.dtype - ), - tree, - sharding, - ) - + parser = ArgumentParser() + parser.add_argument("--server", action="store_true", help="Make this node the main server.", default=False) + ARGS = parser.parse_args() -def load_model(): - global SERVE_LOOP, SERVING_THREAD, TOKENIZER jax.distributed.initialize() print(jax.devices()) print("-" * 80) @@ -79,8 +61,7 @@ def load_model(): print("---> Model config loaded") mesh = jax.make_mesh((1, 8, jax.device_count() // 8), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Auto,) * 3) - decode_mesh, prefill_mesh = mesh, mesh - cfg = dataclasses.replace(dsjax.Config(), mesh=mesh) + cfg = dataclasses.replace(dsjax.Config(), mesh=mesh)#, 
num_layers=4) weights = dsjax_utils.load_model(ckpt_path, cfg) decode_weights, prefill_weights = weights, weights @@ -91,7 +72,7 @@ def load_model(): decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache SERVE_LOOP = serving.ServingLoop( - serve_cfg, cfg, dsjax.prefill, prefill_weights, dsjax.decode_step, decode_weights, decode_cache, is_server=IS_SERVER + serve_cfg, cfg, dsjax.prefill, prefill_weights, dsjax.decode_step, decode_weights, decode_cache, ARGS.server ) print("---> Created the serving loop") @@ -198,7 +179,7 @@ async def root(): if __name__ == "__main__": - if IS_SERVER: + if ARGS.server: print(f"jax.process_idx() == {jax.process_index()}") uvicorn.run(APP, host="0.0.0.0", port=8081, reload=False, server_header=False) else: diff --git a/serving/serving_jax/__init__.py b/serving/serving_jax/__init__.py index 474e8bf..60eede8 100644 --- a/serving/serving_jax/__init__.py +++ b/serving/serving_jax/__init__.py @@ -43,6 +43,7 @@ TIME_AXIS = 2 USE_PREFIX_CACHE = True # the eviction mechanism is extremely simple right now +#USE_PREFIX_CACHE = False is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__) ######################################################################################################################## @@ -94,237 +95,246 @@ def _ensure_all_args_on_mesh(*args, mesh: Mesh): ######################################################################################################################## -# trie utils ########################################################################################################### +# kv cache buffer management ########################################################################################### ######################################################################################################################## -_GLOBAL_NODE_ID = 0 +@partial(jax.jit, static_argnames=("axis", "chunk_size", "ns")) +def _split(val: jax.Array | list[jax.Array], axis: int, chunk_size: int, ns: int) -> list[jax.Array]: + def _fn(val): + axis_ = axis % val.ndim + size = val.shape[axis_] + if size < chunk_size * ns: + min_len = chunk_size * ns + val = jnp.pad(val, [(0, 0) if i != axis_ else (0, min_len - val.shape[axis_]) for i in range(val.ndim)]) + index = [slice(None) if i != axis_ else slice(0, ns * chunk_size) for i in range(val.ndim)] + return jnp.split(val[*index], ns, axis=axis_)[:ns] + + val_leaves, val_structure = jax.tree.flatten(val) + spec = [[x] * ns for x in like_spec(val_leaves)] + split_leaves = auto_axes(lambda vals: [_fn(val) for val in vals], out_sharding=spec)(val_leaves) + return [jax.tree.unflatten(val_structure, [x[i] for x in split_leaves]) for i in range(ns)] + + +@partial(jax.jit, static_argnames=("split_axis",)) +def _concat(values, split_axis: int): + _fn = lambda vals: jax.tree.map(lambda *args: jnp.concatenate(args, axis=split_axis), *vals) + return auto_axes(_fn, out_sharding=like_spec(values[0]))(values) + +class KVBufferStore: + def __init__(self): + self.usecount, self.ondevice, self._store, self.unique_id, self.livecount = {}, {}, {}, 18, 0 + + def _get_unique_buffer_ids(self, n: int): + ids = list(range(self.unique_id, self.unique_id + n)) + self.unique_id += n + return ids + + def offload_buffers(self, how_many: int): + if how_many == 0: + return + candidates = sorted(self._store.keys(), key=lambda i: self.usecount[i] if self.ondevice[i] else 2 ** 60) + for i in 
candidates[:how_many]: + if self.ondevice[i]: + shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("pinned_host"), self._store[i]) + self._store[i] = jax.device_put(self._store[i], shrd) + self.livecount -= 1 + + def load(self, id: int): + if isinstance(id, (tuple, list)): + return [self.load(i) for i in id] + if self.ondevice[id]: + return self._store[id] + self.ondevice[id] = True + self.livecount += 1 + shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("device"), self._store[id]) + self._store[id] = jax.device_put(self._store[id], shrd) + return self._store[id] + + def delete(self, id: int): + if isinstance(id, (list, tuple)): + return [self.delete(i) for i in id] + self.livecount -= self.ondevice[id] + del self.usecount[id], self.ondevice[id], self._store[id] + + def store(self, id: int, val: Any): + if isinstance(id, (tuple, list)): + return [self.store(i, v) for i, v in zip(id, val)] + self.livecount += 1 + self.usecount[id], self.ondevice[id], self._store[id] = 1, True, val + + def mark_visited(self, id: int): + if isinstance(id, (list, tuple)): + return [self.mark_visited(i) for i in id] + self.usecount[id] += 1 + +BUFFER_STORE = KVBufferStore() + +######################################################################################################################## +# trie utils ########################################################################################################### +######################################################################################################################## + +EMPTY, HASH_BITWIDTH = -1, 1 @dataclasses.dataclass -class OffloadedValue: - ref: str | np.ndarray - spec: Any - shape_dtypes: Any +class ChildKeys: + keys: np.ndarray + keys_hash: np.ndarray + keys_hash_mask: np.ndarray + key_lens: np.ndarray + num: int = 0 + + +def _hash_encode(v: np.ndarray, hash_bitwidth: int = HASH_BITWIDTH, pad_idx: int = EMPTY): + v, last_dim = v.astype(np.int64), min(64 // hash_bitwidth, v.shape[-1]) + v_, el_mask = v.reshape(v.shape[:-1] + (-1, last_dim)), (1 << hash_bitwidth) - 1 + mask = np.bitwise_or.reduce(((v_ != pad_idx) * el_mask) << (hash_bitwidth * np.arange(v_.shape[-1])), axis=-1) + h = np.bitwise_or.reduce((v_ & el_mask) << (hash_bitwidth * np.arange(v_.shape[-1])), axis=-1) + return h, mask + + +def _prefilter_on_hash( + w: np.ndarray, keys: np.ndarray, vh: np.ndarray, vm: np.ndarray, hash_bitwidth: int = HASH_BITWIDTH, pad_idx: int = EMPTY +): + wh, wm = _hash_encode(w, hash_bitwidth=hash_bitwidth, pad_idx=pad_idx) + inv_match = (wh ^ vh) & vm & wm + # count full hash chunk matches, but don't miss sequences not matching at least one full hash + match_len = np.sum(np.cumsum(inv_match, axis=-1) == 0, axis=-1) + (w[0] == keys[:, 0]) + max_match_len = max(np.max(match_len), 1) + return np.where(match_len == max_match_len)[0] + +def _fast_pad(x, size, axis, pad_val=0): + new_buf = pad_val * np.ones([size - s if i == axis else s for i, s in enumerate(x.shape)], dtype=x.dtype) + return np.concat([x, new_buf], axis) @dataclasses.dataclass class TrieNode: - id: int - key: jax.Array - value: PyTree | OffloadedValue + value: int children: list["TrieNode"] = dataclasses.field(default_factory=list) - child_keys: jax.Array | None = None + child_keys: ChildKeys | None = None lock: "threading.Lock | None" = None usage: int = 1 def __repr__(self, indent: int = 0): - lines = [" " * indent + "TrieNode("] - lines.append((" " * indent) + f" key={str(self.key.tolist() if hasattr(self.key, 'tolist') else self.key)},") - lines.append((" " * 
indent) + f" usage={self.usage},") - if is_type(self.value, OffloadedValue): - lines.append((" " * indent) + f" value={self.value.ref},") + lines = [f"TrieNode(value={self.value}, usage={self.usage}, children={{"] + if len(self.children) == 0: + lines[-1] = lines[-1][:-1] + "})" else: - lines.append( - (" " * indent) - + f" value={jax.tree.map(jax.typeof, self.value) if self.value is not None else 'None'}," - ) - lines.append( - (" " * indent) + f" child_keys={jax.typeof(self.child_keys) if self.child_keys is not None else 'None'}," - ) - lines.append((" " * indent) + " children=[") - if self.children: - for child in self.children: - lines.append(f"{child.__repr__(indent + 2)},") - lines.append(" " * indent + " ],") - else: - lines[-1] += "]," - lines.append(" " * indent + ")") - return "\n".join(lines) - - @staticmethod - def new_id(): - global _GLOBAL_NODE_ID - _GLOBAL_NODE_ID += 1 - return _GLOBAL_NODE_ID - 1 + for i, child in enumerate(self.children): + child_key = self.child_keys.keys[i, : self.child_keys.key_lens[i]].tolist() + lines.append(f"{' ' * indent} {child_key}: {child.__repr__(indent + 2).strip()},") + lines.append(")") + return "\n".join([(" " * indent) + line for line in lines]) @staticmethod - def _dist_to_key(key, keys, mask, pad_idx: int): - invalid_rows = np.all(keys == pad_idx, axis=-1) - return np.where(invalid_rows, 2**30, np.sum(mask * np.abs(key - keys), axis=-1)) + def _overlap(child_keys: ChildKeys, key, key_len, pad_idx: int = EMPTY): + keys = child_keys.keys[: child_keys.num, :] + keys_hash = child_keys.keys_hash[: child_keys.num, :] + keys_hash_mask = child_keys.keys_hash_mask[: child_keys.num, :] + + # pre-filter sequences + relevant_idx = _prefilter_on_hash(key, keys, keys_hash, keys_hash_mask, pad_idx=pad_idx) + if len(relevant_idx) == 0: + return np.zeros((child_keys.num,), dtype=np.int32), np.zeros((child_keys.num,), dtype=np.int32) + keys = keys[relevant_idx, :] + + mask = np.cumsum((key == keys) | (key == pad_idx) | (keys == pad_idx), -1) == np.arange(1, key.shape[-1] + 1) + overlap = np.zeros((child_keys.num,), dtype=np.int32) + overlap[relevant_idx] = np.sum(mask, axis=-1) + return np.minimum(overlap, key_len), np.minimum(overlap, child_keys.key_lens[: child_keys.num]) @staticmethod - def _append_key(keys, new_key, keys_len: int, pad_idx: int): + def _append_key(keys: ChildKeys | None, new_key: np.ndarray, key_len: int, pad_idx: int = EMPTY): if keys is None: - return new_key[None, ...] # 2 ** 0 power of 2 - if keys_len == keys.shape[0]: # need to double the keys buffer - new_buf = np.pad( - new_key[None, ...], ((0, keys.shape[0] - 1), (0, 0)), mode="constant", constant_values=pad_idx - ) - return np.concatenate([keys, new_buf], 0) - else: - keys[keys_len, ...] 
= new_key - return keys + key_hash, key_hash_mask = _hash_encode(new_key[None, :], pad_idx=pad_idx) + return ChildKeys(new_key[None, :], key_hash, key_hash_mask, np.array([key_len], dtype=np.int32), 1) + if keys.num == keys.keys.shape[0]: # need to double the keys buffer + keys.keys = _fast_pad(keys.keys, 2 * keys.num, 0, 0) + keys.key_lens = _fast_pad(keys.key_lens, 2 * keys.num, 0) + keys.keys_hash = _fast_pad(keys.keys_hash, 2 * keys.num, 0, 0) + keys.keys_hash_mask = _fast_pad(keys.keys_hash_mask, 2 * keys.num, 0, 0) + keys.keys[keys.num, :], keys.key_lens[keys.num] = new_key, key_len + keys.keys_hash[keys.num, :], keys.keys_hash_mask[keys.num, :] = _hash_encode(new_key, pad_idx=pad_idx) + keys.num += 1 + return keys @staticmethod - def _pad_to_multiple_of(sequence: jax.Array, chunk_size: int, pad_idx: int): + def _pad_to_multiple_of(sequence: np.ndarray, chunk_size: int, pad_idx: int = EMPTY): sequence_pad_len = math.ceil(sequence.size / chunk_size) * chunk_size - return np.pad(sequence, ((0, sequence_pad_len - sequence.shape[-1])), mode="constant", constant_values=pad_idx) - - @staticmethod - def _overlap_dist(key1, key2, mask): - return np.sum(np.cumsum(np.logical_not(mask & (key1 == key2)), axis=-1) == 0, axis=-1) - - -@partial(jax.jit, static_argnames=("axis", "chunk_size", "ns")) -def _split(val: jax.Array | list[jax.Array], axis: int, chunk_size: int, ns: int) -> list[jax.Array]: - spec = jax.tree.map(lambda x: [x] * ns, like_spec(val)) - - def _fn(val): - axis_ = axis % val.ndim - size = val.shape[axis_] - if size < chunk_size * ns: - min_len = chunk_size * ns - val = jnp.pad(val, [(0, 0) if i != axis_ else (0, min_len - val.shape[axis_]) for i in range(val.ndim)]) - index = [slice(None) if i != axis_ else slice(0, ns * chunk_size) for i in range(val.ndim)] - return jnp.split(val[*index], ns, axis=axis_)[:ns] - - return auto_axes(lambda vals: jax.tree.map(_fn, vals), out_sharding=spec)(val) + return _fast_pad(sequence, sequence_pad_len, 0, pad_idx) -@partial(jax.jit, static_argnames=("split_axis",)) -def _concat(values, split_axis: int): - _fn = lambda vals: jax.tree.map(lambda *args: jnp.concatenate(args, axis=split_axis), *vals) - return auto_axes(_fn, out_sharding=like_spec(values[0]))(values) - - -def insert_prefix( - prefix_cache: TrieNode, - sequence: jax.Array, - value: PyTree, - *, - chunk_size: int, - split_axis: int, - pad_idx: int = 2**30, - executor: ThreadPoolExecutor | None = None, - mesh: Any | None = None, -): - del executor +def insert_prefix(root: TrieNode, sequence: np.ndarray, ref_vals: list[int], *, chunk_size: int, pad_idx: int = 2**30): + if len(sequence) == 0: + return [], [], [] sequence = np.array(sequence) assert sequence.ndim == 1 - sequence = TrieNode._pad_to_multiple_of(sequence, chunk_size, pad_idx=pad_idx) + sequence_len, sequence = sequence.size, TrieNode._pad_to_multiple_of(sequence, chunk_size, pad_idx=pad_idx) ns = sequence.shape[-1] // chunk_size + seq_actual_lens = [(chunk_size if i != ns - 1 else (sequence_len - (ns - 1) * chunk_size)) for i in range(ns)] sequence_chunks = np.split(sequence, ns) - - # split the value, but only if it's needed for non-cache hit - value_chunks = None - - def lazy_get_value(idx): - nonlocal value_chunks - if value_chunks is None: - value_leaves, value_struct = jax.tree.flatten(value) - with use_mesh(mesh): - split_leaves = _split(value_leaves, axis=split_axis, chunk_size=chunk_size, ns=ns) - value_chunks = [jax.tree.unflatten(value_struct, [x[i] for x in split_leaves]) for i in range(ns)] - return 
value_chunks[idx] + if len(ref_vals) < ns: + msg = f"Pass at least as many references as there are chunks (size={chunk_size}) in the sequence " + msg += f" (size={sequence_len}), so expected at least {ns} references, got {len(ref_vals)=} instead." + raise ValueError(msg) + visited_refs, store_refs, delete_refs = [], [], [] # which refs to retain and which to delete # walk the prefix cache tree - with prefix_cache.lock: - node = prefix_cache - for seq_idx, seq in enumerate(sequence_chunks): - if len(node.children) == 0: - node.child_keys = TrieNode._append_key(node.child_keys, seq, len(node.children), pad_idx=pad_idx) - node.children.append(TrieNode(TrieNode.new_id(), seq, lazy_get_value(seq_idx))) - node = node.children[-1] - continue - left_mask, right_mask = (seq != pad_idx), (node.child_keys != pad_idx) - left_dist = TrieNode._dist_to_key(seq, node.child_keys, left_mask, pad_idx=pad_idx) - right_dist = TrieNode._dist_to_key(seq, node.child_keys, right_mask, pad_idx=pad_idx) - left_idx, right_idx = np.argmin(left_dist), np.argmin(right_dist) - if node.children and right_dist[right_idx] == 0: # this sequence is longer - if left_dist[right_idx] > 0: - node.children[right_idx].key = seq - node.children[right_idx].value = lazy_get_value(seq_idx) - node.child_keys[right_idx, :] = seq - else: # exact sequence exists - node.children[right_idx].usage += 1 - pass - node = node.children[right_idx] - elif left_dist[left_idx] == 0: # longer sequence already exists - node.children[left_idx].usage += 1 - assert seq_idx == len(sequence_chunks) - 1 - return - else: # no exact match - node.child_keys = TrieNode._append_key(node.child_keys, seq, len(node.children), pad_idx=pad_idx) - node.children.append(TrieNode(TrieNode.new_id(), seq, lazy_get_value(seq_idx))) + with root.lock: + node = root + for seq_idx, (seq, seq_len) in enumerate(zip(sequence_chunks, seq_actual_lens)): + if len(node.children) > 0: + left_match, right_match = TrieNode._overlap(node.child_keys, seq, seq_len, pad_idx=pad_idx) + best_idx = np.argmax(left_match) + left_match, right_match = left_match[best_idx], right_match[best_idx] + else: + left_match, right_match, best_idx = 0, 0, 2**30 # case 0: no children, add new child + if left_match != seq_len: # append new node + node.child_keys = TrieNode._append_key(node.child_keys, seq, seq_len, pad_idx=pad_idx) + node.children.append(TrieNode(int(ref_vals[seq_idx]))) + store_refs.append(int(ref_vals[seq_idx])) node = node.children[-1] + elif right_match < left_match: # replace the node + delete_refs.append(node.children[best_idx].value) + node.children[best_idx] = TrieNode(int(ref_vals[seq_idx])) + node.child_keys.keys[best_idx, :], node.child_keys.key_lens[best_idx] = seq, seq_len + store_refs.append(int(ref_vals[seq_idx])) + node = node.children[best_idx] + else: # full match, do nothing + if best_idx > len(node.children): + break + visited_refs.append(int(node.children[best_idx].value)) + node = node.children[best_idx] + visited_refs = list(set(visited_refs) | set(store_refs)) + return visited_refs, store_refs, delete_refs -def retrieve_prefix( - prefix_cache: TrieNode, - sequence: jax.Array, - *, - chunk_size: int, - split_axis: int, - pad_idx: int = 2**30, - executor: ThreadPoolExecutor | None = None, - mesh: Any | None = None, -): - sequence, total_match = np.array(sequence), 0 +def retrieve_prefix(root: TrieNode, sequence: np.ndarray, *, chunk_size: int, pad_idx: int = 2**30): + sequence, total_match, ref_vals = np.array(sequence), 0, [] assert sequence.ndim == 1 sequence_len, 
sequence = sequence.size, TrieNode._pad_to_multiple_of(sequence, chunk_size, pad_idx=pad_idx) ns = sequence.shape[-1] // chunk_size - values, sequence_chunks = [], np.split(sequence, ns) - - def _construct_output(): - if sequence_len != total_match: - return None, total_match - for i, value in enumerate(values): - if is_type(value, OffloadedValue): - _load = lambda value: jax.block_until_ready(device_put(value.ref, like_shard(value.spec, mesh))) - values[i] = _load(value) if executor is None else executor.submit(_load, value) - - values_future = lambda: [value.result() if hasattr(value, "result") else value for value in values] - return (executor.submit(values_future) if executor is not None else values_future()), total_match - - node = prefix_cache - for seq in sequence_chunks: - if len(node.children) == 0: # cache ran out of node - return _construct_output() - left_mask = seq != pad_idx - overlaps = TrieNode._overlap_dist(node.child_keys, seq, left_mask) - max_idx = np.argmax(overlaps) - max_overlap = overlaps[max_idx] - if max_overlap == 0: - return _construct_output() - with prefix_cache.lock: - node.children[max_idx].usage += 1 - values.append(node.children[max_idx].value) - node, total_match = node.children[max_idx], total_match + max_overlap - # exit early if the entire chunk wasn't found - if max_overlap != np.sum(left_mask): - break - return _construct_output() - - -def offload_nodes(prefix_cache: TrieNode, how_many: int = 3): - # work in progress, not tested, will probably not work - # TODO: switch to [memories](https://docs.jax.dev/en/latest/notebooks/host-offloading.html) - node_queue, all_nodes = [prefix_cache], [] - with prefix_cache.lock: - while len(node_queue) > 0: - node = node_queue.pop(0) - for child in node.children: - node_queue.append(child) - all_nodes.append(child) - sorted_nodes = sorted(all_nodes, key=lambda x: x.usage) - offloaded = 0 - for i, node in enumerate(sorted_nodes): - if offloaded >= how_many: + seq_actual_lens = [(chunk_size if i != ns - 1 else (sequence_len - (ns - 1) * chunk_size)) for i in range(ns)] + visited_refs = [] + + with root.lock: + node = root + for seq, seq_len in zip(np.split(sequence, ns), seq_actual_lens): + if len(node.children) == 0: # cache ran out of node + return (total_match, ref_vals), visited_refs + left_match, right_match = TrieNode._overlap(node.child_keys, seq, seq_len, pad_idx=pad_idx) + exact_match = np.minimum(left_match, right_match) + best_idx = np.argmax(exact_match) + match_length = exact_match[best_idx] + if match_length > 0: + visited_refs.append(int(node.children[best_idx].value)) + node = node.children[best_idx] + total_match += int(match_length) + if match_length != seq_len: break - if is_type(node.value, OffloadedValue): - continue - value = jax.tree.map(partial(np.asarray, copy=False), jax.device_put(node.value, jax.devices("cpu")[0])) - node.value = OffloadedValue(value, like_spec(node.value), jax.tree.map(jax.typeof, node.value)) + ref_vals.append(node.value) + return (total_match, ref_vals), visited_refs ######################################################################################################################## @@ -466,6 +476,8 @@ def __init__( decode_cache: KVCache, is_server: bool = False, ): + if not (SyncServer.broadcast("welcome", 0, True, is_server) if jax.process_count() > 1 else is_server): + raise ValueError("No processes registered as the main server, at least one process must.") self.serve_cfg, self.cfg = serve_cfg, cfg # setup decode @@ -498,13 +510,11 @@ def 
_update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens, self.prefill_weights = prefill_weights self.prefill_mesh = [x for x in jax.tree.leaves(prefill_weights) if hasattr(x, "sharding")][0].sharding.mesh self.prefill_work = PrefillWork([], [], []) - self.prefix_cache = TrieNode(TrieNode.new_id(), None, None, lock=threading.Lock()) self._get_index = jax.jit(lambda z, idx: jax.tree.map(lambda x: x[:, idx, ...], z)) self._get_cache_entry = jax.jit(self.decode_work.cache.get_sequence) # setup misc - self.pending_requests, self.requests_lock, self.results = [], threading.Lock(), {} - self.params_lock = threading.Lock() + self.pending_requests, self.state_lock, self.results = [], threading.Lock(), {} self.pad_id, self.eos_tokens, self.time_axis = 0, np.array(serve_cfg.eos_tokens), TIME_AXIS self._background = ThreadPoolExecutor(max_workers=1024) @@ -512,7 +522,6 @@ def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens, self.profile_start_time, self.profiling = -1, False # setup cache management - # -1 for missing batch dimensiona and + 1 for layers being stacked self.prefix_cache, self._retrieve_prefix, self._insert_prefix = None, None, None self.new_prefix_cache() @@ -563,7 +572,7 @@ def decode_step(self): # 2. run N decode steps output_tokens, output_mapping = [], [] - if "decode" in self.roles: # cut a corner, don't issue the decode call on non-participating machines + if "decode" in self.roles: # TODO(rdyro): revisit: don't issue the decode call on non-participating machines with use_mesh(self.decode_mesh): config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps) (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = self.multistep_decode_fn( @@ -609,12 +618,21 @@ def decode_step(self): continue # 2. check for done sequences; evict them if done and return them if done[i]: - if USE_PREFIX_CACHE: + if USE_PREFIX_CACHE: # store the results in the prefix cache buffer store sequence = np.array(result.token_list) with use_mesh(self.decode_mesh): cache_entry, _ = self._get_cache_entry(self.decode_work.cache, i) # self._background.submit(self._insert_prefix, sequence, cache_entry, mesh=self.decode_mesh) - self._insert_prefix(sequence, cache_entry, mesh=self.decode_mesh) + ns = math.ceil(sequence.size / self.serve_cfg.prefix_chunk_size) + buffer_ids = BUFFER_STORE._get_unique_buffer_ids(ns) + visited_ids, store_ids, del_ids = self._insert_prefix(sequence, buffer_ids) + if len(store_ids) > 0: + axis = self.time_axis - 1 + 1 # batch missing (-1) layers concatenated (+1) + chunked_cache_entry = _split(cache_entry, axis, self.serve_cfg.prefix_chunk_size, ns) + vals = [chunked_cache_entry[buffer_ids.index(id)] for id in store_ids] + BUFFER_STORE.store(store_ids, vals) + BUFFER_STORE.delete(del_ids) + BUFFER_STORE.mark_visited(visited_ids) return_request(result) result.done, self.decode_work.active_results[i] = True, None @@ -632,42 +650,27 @@ def prefill_step(self): self.prefill_work.to_decode.append(PrefillResult(id, input, input[-1], kv_list, len(input) - 1)) self.prefill_work.pending_prefill = None - # 2. 
triage requests queue into cached (-> decode) and not-cached work (-> prefill) - new_pending_retrievals = [] - done_mask = [cache_entry_fut.done() for (_, cache_entry_fut) in self.prefill_work.pending_cache_retrievals] - done_mask = SyncServer.broadcast("retrievals_done", self._it, done_mask, is_source="prefill_coordinator" in self.roles) - for i, (request, cache_entry_fut) in enumerate(self.prefill_work.pending_cache_retrievals): - #if len(self.prefill_work.to_decode) < self.serve_cfg.decode_batch_size and cache_entry_fut.done(): - if len(self.prefill_work.to_decode) < self.serve_cfg.decode_batch_size and done_mask[i]: + # 2. triage requests based on whether they need to go to prefill or there's a cache match, so decode directly + while len(self.prefill_work.requests) > 0: + request = self.prefill_work.requests.pop(0) + sequence = np.array(request.text) + (total_match, buffer_ids), visited_ids = self._retrieve_prefix(sequence) + BUFFER_STORE.mark_visited(visited_ids) + if total_match == sequence.size: with use_mesh(self.decode_mesh): - # batch missing (-1) layers concatenated (+1) - cache_entry = partial(_concat, cache_entry_fut.result(), self.time_axis - 1 + 1) # jit work future - new_decode = PrefillResult( - request.id, np.array(request.text), request.text[-1], cache_entry, len(request.text) - 1 - ) + time_axis = self.time_axis - 1 + 1 # batch missing (-1) layers concatenated (+1) + cache_entry = partial(_concat, BUFFER_STORE.load(buffer_ids), time_axis) + new_decode = PrefillResult(request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1) self.prefill_work.to_decode.append(new_decode) + print(f"Found a full match") else: - new_pending_retrievals.append((request, cache_entry_fut)) # not yet ready - self.prefill_work.pending_cache_retrievals = new_pending_retrievals - - # 3. check if prefixes are in the cache - retrieval_results = self._background.map( - lambda request: (self._retrieve_prefix(np.array(request.text[:-1])), request), self.prefill_work.requests - ) - #retrieval_results = [[(None, -100), request] for request in self.prefill_work.requests] - for (cache_entry_fut, length), request in retrieval_results: - if length == len(request.text) - 1: - self.prefill_work.pending_cache_retrievals.append((request, cache_entry_fut)) - print(f"Found full prefill match in the cache") - else: - print(f"Need to prefill the request, only found a match for length {length / (len(request.text) - 1)}") + print(f"Need to prefill the request, only found a match for length {total_match / (len(request.text) - 1)}") self.prefill_work.to_prefill.append(request) - self.prefill_work.requests.clear() if self.prefill_work.pending_prefill is not None: # a current prefill is still running, skip scheduling another return - # 4. prefill requests to be prefilled + # 3. 
prefill requests to be prefilled prefill_input = self.prefill_work.to_prefill[: self.serve_cfg.prefill_batch_size] self.prefill_work.to_prefill = self.prefill_work.to_prefill[len(prefill_input) :] if len(prefill_input) > 0: @@ -714,9 +717,9 @@ def serving_step(self): SyncServer.barrier("serving_step", self._it) self._it, requests = self._it + 1, None if "server" in self.roles: - with self.requests_lock: + with self.state_lock: self.pending_requests, requests = [], list(self.pending_requests) - with self.params_lock: + with self.state_lock: serve_cfg, requests = SyncServer.broadcast( "requests", self._it, (dataclasses.asdict(self.serve_cfg), requests), is_source="server" in self.roles ) @@ -737,17 +740,20 @@ def serving_step(self): self.new_prefix_cache() # manage cache ############################################################################# + # offload buffers to keep a max of N ####################################################### + max_buffers = 100 + BUFFER_STORE.offload_buffers(max(0, BUFFER_STORE.livecount - max_buffers)) + # offload buffers to keep a max of N ####################################################### + def add_request(self, request: UserRequestPrompt): - with self.requests_lock: + with self.state_lock: self.pending_requests.append(dataclasses.asdict(request)) def update_params(self, params: dict[str, Any]): - with self.params_lock: + with self.state_lock: self.serve_cfg = dataclasses.replace(self.serve_cfg, **params) def new_prefix_cache(self): - self.prefix_cache = TrieNode(TrieNode.new_id(), None, None, lock=threading.Lock()) - _prefix_opts = dict(chunk_size=self.serve_cfg.prefix_chunk_size) - _prefix_opts |= dict(split_axis=self.time_axis - 1 + 1, mesh=self.decode_mesh, executor=self._background) - self._retrieve_prefix = partial(retrieve_prefix, self.prefix_cache, **_prefix_opts) - self._insert_prefix = partial(insert_prefix, self.prefix_cache, **_prefix_opts) + self.prefix_cache = TrieNode(None, lock=threading.Lock()) + self._retrieve_prefix = partial(retrieve_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) + self._insert_prefix = partial(insert_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) From 365284920ce52e9a96d8c67ee9eb1627b0d2c4c9 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Wed, 30 Jul 2025 19:54:48 -0700 Subject: [PATCH 05/11] Splash attention for deepseek prefill on TPU --- deepseek_r1_jax/deepseek_r1_jax/model.py | 228 ++++++++++++----------- deepseek_r1_jax/main.ipynb | 147 --------------- 2 files changed, 118 insertions(+), 257 deletions(-) delete mode 100644 deepseek_r1_jax/main.ipynb diff --git a/deepseek_r1_jax/deepseek_r1_jax/model.py b/deepseek_r1_jax/deepseek_r1_jax/model.py index 93dc5e7..506965f 100644 --- a/deepseek_r1_jax/deepseek_r1_jax/model.py +++ b/deepseek_r1_jax/deepseek_r1_jax/model.py @@ -24,6 +24,7 @@ import gzip import json from pathlib import Path +from warnings import warn import jax import jax.numpy as jnp @@ -67,6 +68,7 @@ class ShardingRules: taking this mapping and eventually turning it into the correct JAX shardings and sharding contraints. 
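    As an illustrative sketch (an editor's example, not part of the original file): with a
    mesh laid out over the axes ("x", "y", "z"), a rule such as batch="x" shards every array
    whose logical axes include "batch" along the "x" mesh axis. Resolution through
    logical_to_sharding then behaves roughly like:

        rules = ShardingRules(batch="x")  # hypothetical rules
        # an activation with logical axes ("batch", "sequence") resolves to roughly:
        spec = P(*(getattr(rules, name) for name in ("batch", "sequence")))  # P("x", None)
        sharding = jax.sharding.NamedSharding(mesh, spec)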
""" + batch: AxisName = BATCH_AXIS_NAME sequence: AxisName = None head_dim: AxisName = None @@ -209,6 +211,7 @@ def load_tokenizer( # module reload friendly check for type(x) == cls is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__) is_param = lambda x: is_type(x, ArrayInfo) +which_platform = lambda cfg: cfg.mesh.devices.reshape(-1)[0].platform _count_left_padding = lambda ids, pad_id=PAD_ID: jnp.sum(jnp.cumsum(ids != pad_id, axis=-1) == 0, axis=-1) _length_minus_right_padding = lambda segment_ids: jnp.sum(jnp.cumsum(jnp.flip(segment_ids != 0, -1), axis=-1) > 0, -1) @@ -225,6 +228,7 @@ class ArrayInfo: the ArrayInfo approach instead to decouple data and its sharding from the functions we'll apply the data to. """ + shape: tuple[int, ...] dtype: "jnp.dtype" logical_axes: tuple @@ -240,6 +244,7 @@ def _init_fn(key): return jax.tree.map( lambda info: info.initializer(next(key_iter), info.shape, info.dtype), abstract, is_leaf=is_param ) + return _init_fn(key) @@ -289,9 +294,11 @@ class QuantArray: shape = property(lambda self: self.quant.shape) ndim = property(lambda self: self.quant.ndim) + _int8_quant_init = lambda key, shape, dtype=jnp.int8: random.randint(key, shape, -128, 128, dtype=dtype) _int8_scale_init = lambda key, shape, dtype: random.normal(key, shape, dtype=dtype) / math.sqrt(math.prod(shape)) / 127 + def quantize(x: jax.Array | ArrayInfo, axis: int | tuple[int, ...], scale_dtype=jnp.float16, zero_init: bool = False): if is_type(x, QuantArray): raise ValueError("Attempting to quantize an already quantized QuantArray.") @@ -318,7 +325,7 @@ def quantize(x: jax.Array | ArrayInfo, axis: int | tuple[int, ...], scale_dtype= quant_init, scale_init = _int8_quant_init, _int8_scale_init return ( dataclasses.replace(x, shape=x.shape, dtype=jnp.int8, initializer=quant_init), - ArrayInfo(new_shape, scale_dtype, new_logical_axes, scale_init) + ArrayInfo(new_shape, scale_dtype, new_logical_axes, scale_init), ) raise ValueError(f"quantize got unexpected type: {type(x)}") @@ -327,7 +334,7 @@ def quantize(x: jax.Array | ArrayInfo, axis: int | tuple[int, ...], scale_dtype= def quantize_update_slice(x: QuantArray, y: jax.Array, pos: int, update_axis: int, quant_axis: int): assert x.quant.ndim == y.ndim quant_axis, update_axis = quant_axis % x.quant.ndim, update_axis % x.quant.ndim # normalize axis numbers - #y_quant, y_scale = quantize(y, axis=quant_axis, scale_dtype=x.scale.dtype) # quantize rhs + # y_quant, y_scale = quantize(y, axis=quant_axis, scale_dtype=x.scale.dtype) # quantize rhs y_quant, y_scale = y.quant, y.scale scale_update_axis = [ax for ax in range(x.quant.ndim) if ax != quant_axis][update_axis] # update axis in `scale` z_quant = jax.lax.dynamic_update_slice_in_dim(x.quant, y_quant.astype(x.quant.dtype), pos, axis=update_axis) @@ -386,39 +393,41 @@ def abstract(cls, cfg: Config): _sinit = jax.nn.initializers.he_normal(in_axis=0, out_axis=1) dtype = cfg.dtype layer = MoELayer( - w_router=ArrayInfo( - (cfg.embed, cfg.n_routed_experts), cfg.moe_gate_dtype, ("moe_e_up_embed", None), _sinit - ), - b_router=ArrayInfo( - (cfg.n_routed_experts,), cfg.moe_gate_dtype, (None,), jax.nn.initializers.constant(0.0) - ), + w_router=ArrayInfo((cfg.embed, cfg.n_routed_experts), cfg.moe_gate_dtype, ("moe_e_up_embed", None), _sinit), + b_router=ArrayInfo((cfg.n_routed_experts,), cfg.moe_gate_dtype, (None,), jax.nn.initializers.constant(0.0)), we_gate=ArrayInfo( - (cfg.n_routed_experts, cfg.embed, cfg.moe_ffw_size), dtype, + (cfg.n_routed_experts, 
cfg.embed, cfg.moe_ffw_size), + dtype, ("moe_e_experts", "moe_e_up_embed", "moe_e_up_ffw"), _einit, ), we_up=ArrayInfo( - (cfg.n_routed_experts, cfg.embed, cfg.moe_ffw_size), dtype, + (cfg.n_routed_experts, cfg.embed, cfg.moe_ffw_size), + dtype, ("moe_e_experts", "moe_e_up_embed", "moe_e_up_ffw"), _einit, ), we_down=ArrayInfo( - (cfg.n_routed_experts, cfg.moe_ffw_size, cfg.embed), dtype, + (cfg.n_routed_experts, cfg.moe_ffw_size, cfg.embed), + dtype, ("moe_e_experts", "moe_e_down_ffw", "moe_e_down_embed"), _einit, ), ws_gate=ArrayInfo( - (cfg.embed, cfg.n_shared_experts * cfg.moe_ffw_size), dtype, + (cfg.embed, cfg.n_shared_experts * cfg.moe_ffw_size), + dtype, ("moe_s_up_embed", "moe_s_up_ffw"), _sinit, ), ws_up=ArrayInfo( - (cfg.embed, cfg.n_shared_experts * cfg.moe_ffw_size), dtype, + (cfg.embed, cfg.n_shared_experts * cfg.moe_ffw_size), + dtype, ("moe_s_up_embed", "moe_s_up_ffw"), _sinit, ), ws_down=ArrayInfo( - (cfg.moe_ffw_size, cfg.n_shared_experts * cfg.embed), dtype, + (cfg.moe_ffw_size, cfg.n_shared_experts * cfg.embed), + dtype, ("moe_s_down_ffw", "moe_s_down_embed"), _sinit, ), @@ -464,7 +473,8 @@ def abstract(cls, cfg: Config): q_a=ArrayInfo((cfg.embed, cfg.q_lora_rank), dtype, ("qkv_embed", "q_lora"), _init(1)), q_gamma=ArrayInfo((cfg.q_lora_rank,), dtype, ("q_lora",), _ones_init), q_b=ArrayInfo( - (cfg.q_lora_rank, cfg.num_heads, q_head_dim), dtype, + (cfg.q_lora_rank, cfg.num_heads, q_head_dim), + dtype, ("q_lora", "qkv_heads", "head_dim"), _init(1, 2), ), @@ -472,7 +482,8 @@ def abstract(cls, cfg: Config): k_pe=ArrayInfo((cfg.embed, cfg.qk_rope_head_dim), dtype, ("qkv_embed", "head_dim"), _init(1)), kv_gamma=ArrayInfo((cfg.kv_lora_rank,), dtype, ("kv_lora",), _ones_init), k_b=ArrayInfo( - (cfg.kv_lora_rank, cfg.num_heads, cfg.qk_nope_head_dim), dtype, + (cfg.kv_lora_rank, cfg.num_heads, cfg.qk_nope_head_dim), + dtype, ("kv_lora", "qkv_heads", "head_dim"), _init(1, 2), ), @@ -610,17 +621,10 @@ def abstract(cls, cfg: Config, batch_size: int, max_seq_len: int, dtype: int = j if cfg.quantize_cache: _quantize = partial(quantize, axis=-1, scale_dtype=cfg.quant_scale_dtype, zero_init=True) cache.k_nope = [ - QuantArray(*_quantize(k_nope), out_scaling=True, scale_expand_dims=-2) - for k_nope in cache.k_nope - ] - cache.k_pe = [ - QuantArray(*_quantize(k_pe), out_scaling=True, scale_expand_dims=-2) - for k_pe in cache.k_pe - ] - cache.v = [ - QuantArray(*_quantize(v), out_scaling=False, scale_expand_dims=-2) - for v in cache.v + QuantArray(*_quantize(k_nope), out_scaling=True, scale_expand_dims=-2) for k_nope in cache.k_nope ] + cache.k_pe = [QuantArray(*_quantize(k_pe), out_scaling=True, scale_expand_dims=-2) for k_pe in cache.k_pe] + cache.v = [QuantArray(*_quantize(v), out_scaling=False, scale_expand_dims=-2) for v in cache.v] return cache def fill_len(self) -> jax.Array: @@ -651,7 +655,10 @@ def update_slice(x: jax.Array | QuantArray, y: jax.Array, pos: int, update_axis: else: return jax.lax.dynamic_update_slice_in_dim(x, y.astype(x.dtype), pos, axis=update_axis) -def logical_sharding_constraint(x: jax.Array | QuantArray, logical_axes: Axes, mesh: jax.sharding.Mesh, rules: ShardingRules): + +def logical_sharding_constraint( + x: jax.Array | QuantArray, logical_axes: Axes, mesh: jax.sharding.Mesh, rules: ShardingRules +): """Generate a sharding constraint for a regular or QuantArray given its logical axes.""" sharding = logical_to_sharding(logical_axes, mesh, rules) if is_type(x, QuantArray): @@ -718,8 +725,8 @@ def apply_rotary_embedding(x, sin, cos): def 
make_attention_mask(q_len, k_len, q_segment_ids, kv_segment_ids, q_offset, kv_offset, causal: bool): - segment_mask = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :] # [B, t, T] - segment_mask = segment_mask[:, None, :, :] # [B, t, T] -> [B, 1, t, T] + segment_mask = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :] # [B, t, T] + segment_mask = segment_mask[:, None, :, :] # [B, t, T] -> [B, 1, t, T] if causal: qk = (1, 1, q_len, k_len) # [b, h, t, T] @@ -773,7 +780,7 @@ def attention( _, h, T, _ = k_nope.shape qk = einsum("bhtd,bhTd->bhtT", q_nope, k_nope) - #qk = qk + einsum("bhtd,bTd->bhtT", q_pe, k_pe) + # qk = qk + einsum("bhtd,bTd->bhtT", q_pe, k_pe) qk = qk + einsum("bhtd,b1Td->bhtT", q_pe, k_pe) qk = qk * scale # [b, h, t, T] @@ -787,63 +794,75 @@ def attention( return qkv.reshape((b, h, t, v.shape[-1])) -def attention_kernel(q, k, v, q_segment_ids, kv_segment_ids, q_offset, starts, lengths, cfg: Config): +def attention_kernel( + q_nope: jax.Array, + q_pe: jax.Array, + k_nope: jax.Array | tuple[jax.Array, jax.Array], + k_pe: jax.Array | tuple[jax.Array, jax.Array], + v: jax.Array | tuple[jax.Array, jax.Array], + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + q_offset: jax.Array, + kv_offset: jax.Array, + cfg: Config, +) -> jax.Array: """Flash attention kernel!""" - k, k_scale = (k.quant, k.scale) if is_type(k, QuantArray) else (k, None) + k_nope, k_nope_scale = (k_nope.quant, k_nope.scale) if is_type(k_nope, QuantArray) else (k_nope, None) + k_pe, k_pe_scale = (k_pe.quant, k_pe.scale) if is_type(k_pe, QuantArray) else (k_pe, None) v, v_scale = (v.quant, v.scale) if is_type(v, QuantArray) else (v, None) + scale = _get_attn_scale(q_nope.shape[-1] + q_pe.shape[-1], cfg) - # handle grouped query attention - assert q.shape[-3] % k.shape[-3] == 0 - scale = _get_attn_scale(q.shape[-1], cfg) + l2p = lambda *logical: logical_to_physical(logical, cfg.rules) + q_spec = l2p("batch", "qkv_heads", "sequence", "head_dim") - l2p = lambda *xs: logical_to_physical(xs, cfg.rules) in_specs = ( - l2p("batch", "act_heads", "sequence", "head_dim"), - l2p("batch", "act_heads", "sequence", "head_dim"), - l2p("batch", "act_heads", "sequence", "head_dim"), - l2p("batch", "sequence"), - l2p("batch", "sequence"), - l2p("batch") if starts is not None else None, - l2p("batch") if lengths is not None else None, - l2p("batch", "act_heads", "sequence") if k_scale is not None else None, - l2p("batch", "act_heads", "sequence") if v_scale is not None else None, + q_spec, # q_nope + q_spec, # q_pe + l2p("batch", "qkv_heads", "sequence", "head_dim"), # k_nope + l2p("batch", None, "sequence", "head_dim"), # k_pe + l2p("batch", "qkv_heads", "sequence", "head_dim"), # v + l2p("batch", "sequence"), # q_segment_ids + l2p("batch", "sequence"), # kv_segment_ids + None if k_nope_scale is None else l2p("batch", "qkv_heads", "sequence"), # k_nope_scale + None if k_pe_scale is None else l2p("batch", None, "sequence"), # k_pe_scale + None if v_scale is None else l2p("batch", "qkv_heads", "sequence"), # v_scale ) - out_specs = l2p("batch", "act_heads", "sequence", "head_dim") + out_specs = q_spec @partial(shard_map, mesh=cfg.mesh, in_specs=in_specs, out_specs=out_specs, check_rep=False) - def _f(q, k, v, q_segment_ids, kv_segment_ids, starts, lengths, k_scale, v_scale): - q_org_shape = q.shape - kv_repeats = q.shape[-3] // k.shape[-3] - q = q.reshape(q.shape[:-3] + (k.shape[-3], kv_repeats, q.shape[-2], q.shape[-1])) - - if q.shape[-2] != 1: - mask = mask_lib.MultiHeadMask([mask_lib.CausalMask((q.shape[-2], 
k.shape[-2])) for _ in range(q.shape[-3])]) - block_q, block_kv = min(q.shape[-2], 512), min(k.shape[-2], 1024) - block_sizes = splash.BlockSizes(block_q=block_q, block_kv=block_kv, block_kv_compute=block_kv) - attn_fn = splash.make_splash_mqa_single_device(mask=mask, block_sizes=block_sizes) - attn_fn = jax.vmap(jax.vmap(attn_fn, in_axes=(0, 0, 0, None)), in_axes=(0, 0, 0, 0)) - - segment_ids = splash.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) - if k_scale is not None: - k = (k * k_scale[..., None]).astype(jnp.bfloat16) - if v_scale is not None: - v = (v * v_scale[..., None]).astype(jnp.bfloat16) - ret = attn_fn(q * scale, k, v, segment_ids) - else: - raise NotImplementedError - assert q.shape[-2] == 1, "This is a decode kernel, q.shape[-2] must be 1" - q = q[..., 0, :] - in_axes = (1, 1, 1, None, None) - in_axes += ((None if k_scale is None else 1),) - in_axes += ((None if v_scale is None else 1),) - hyperparams = dict(scale=scale, block_kv=min(k.shape[-2], 8192)) - ret = jax.vmap(partial(ragged_attention.ragged_decode_fwd, **hyperparams), in_axes=in_axes, out_axes=1)( # noqa: F821 - q, k, v, starts, lengths, k_scale, v_scale - ) - return ret.reshape(q_org_shape[:-1] + (v.shape[-1],)) + def _f(q_nope, q_pe, k_nope, k_pe, v, q_segment_ids, kv_segment_ids, k_nope_scale, k_pe_scale, v_scale): + q_seq, kv_seq, heads = q_nope.shape[-2], v.shape[-2], v.shape[-3] + block_q, block_kv = min(q_seq, 512), min(kv_seq, 1024) + block_sizes = splash.BlockSizes(block_q=block_q, block_kv=block_kv, block_kv_compute=block_kv) + + mask = mask_lib.MultiHeadMask([mask_lib.CausalMask((q_seq, kv_seq)) for _ in range(heads)]) + attn_static_fn = splash.make_splash_mha_single_device(mask=mask, block_sizes=block_sizes) + attn_static_fn = jax.vmap(attn_static_fn, in_axes=(0, 0, 0, 0)) # for prefill with an empty cache + + def attn_dynamic_fn(q, k, v, segment_ids): # when the offsets are different (chunked prefill) + mask = make_attention_mask(q_seq, kv_seq, q_segment_ids, kv_segment_ids, q_offset, kv_offset, causal=True) + attn_fn = lambda q, k, v, segment_ids, mask: splash.make_splash_mha_single_device( + mask=mask, block_sizes=block_sizes + )(q, k, v, segment_ids) + return jax.vmap(attn_fn, in_axes=(0, 0, 0, 0, 0))(q, k, v, segment_ids, mask) + + segment_ids = splash.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) + if k_nope_scale is not None: + k_nope = (k_nope * k_nope_scale[..., None]).astype(jnp.bfloat16) + if k_pe_scale is not None: + k_pe = (k_pe * k_pe_scale[..., None]).astype(jnp.bfloat16) + if v_scale is not None: + v = (v * v_scale[..., None]).astype(jnp.bfloat16) + k = jnp.concatenate([k_nope, jnp.broadcast_to(k_pe, k_nope.shape[:-1] + k_pe.shape[-1:])], -1) + q = jnp.concatenate([q_nope, q_pe], -1) + # jax.debug.print("Using a static mask: {}", jnp.all(q_offset == kv_offset)) + return jax.lax.cond( + jnp.all(q_offset == kv_offset), attn_static_fn, attn_dynamic_fn, q * scale, k, v, segment_ids + ) - lengths = jnp.broadcast_to(lengths, starts.shape) - return _f(q, k, v, q_segment_ids, kv_segment_ids, starts, lengths, k_scale, v_scale).astype(jnp.bfloat16) + return _f(q_nope, q_pe, k_nope, k_pe, v, q_segment_ids, kv_segment_ids, k_nope_scale, k_pe_scale, v_scale).astype( + jnp.bfloat16 + ) def rms_norm(x: jax.Array, gamma: jax.Array) -> jax.Array: @@ -873,8 +892,8 @@ def mla_attention_block( with jax.named_scope("kv_compressed_embed"): kv_compressed = einsum("btd,dr->btr", x, attn_layer.kv_a).astype(dtype) kv_compressed = rms_norm(kv_compressed, attn_layer.kv_gamma).astype(dtype) - #k_pe = 
einsum("btd,dq->btq", x, attn_layer.k_pe) - #k_pe = apply_rotary_embedding(k_pe[..., None, :, :], sin, cos)[..., 0, :, :].astype(dtype) + # k_pe = einsum("btd,dq->btq", x, attn_layer.k_pe) + # k_pe = apply_rotary_embedding(k_pe[..., None, :, :], sin, cos)[..., 0, :, :].astype(dtype) k_pe = einsum("btd,dq->btq", x, attn_layer.k_pe)[..., None, :, :] k_pe = apply_rotary_embedding(k_pe, sin, cos).astype(dtype) @@ -892,7 +911,7 @@ def mla_attention_block( if is_type(cache, KVCache): it = jnp.maximum(cache.iter, 0) k_nope = update_slice(cache.k_nope[idx], k_nope, it, update_axis=cache.time_axis) - #k_pe = update_slice(cache.k_pe[idx], k_pe, it, update_axis=cache.time_axis - 1) + # k_pe = update_slice(cache.k_pe[idx], k_pe, it, update_axis=cache.time_axis - 1) k_pe = update_slice(cache.k_pe[idx], k_pe, it, update_axis=cache.time_axis) v = update_slice(cache.v[idx], v, it, update_axis=cache.time_axis) cache_updates = (k_nope, k_pe, v) @@ -916,25 +935,14 @@ def mla_attention_block( lsc = partial(logical_sharding_constraint, mesh=cfg.mesh, rules=cfg.rules) spec = ("batch", "act_heads", "sequence", "head_dim") q_nope, q_pe = lsc(q_nope, spec), lsc(q_pe, spec) - #k_nope, k_pe, v = lsc(k_nope, spec), lsc(k_pe, ("batch", "sequence", "head_dim")), lsc(v, spec) + # k_nope, k_pe, v = lsc(k_nope, spec), lsc(k_pe, ("batch", "sequence", "head_dim")), lsc(v, spec) k_nope, k_pe, v = lsc(k_nope, spec), lsc(k_pe, ("batch", None, "sequence", "head_dim")), lsc(v, spec) # Compute attention with jax.named_scope("attention"): - if (cfg.use_prefill_attn_kernel and q.shape[-2] != 1) or (cfg.use_decode_attn_kernel and q.shape[-2] == 1): - raise NotImplementedError + if which_platform(cfg) == "tpu" and cfg.use_prefill_attn_kernel and q.shape[-2] != 1: attn_out = attention_kernel( - q_nope, - q_pe, - k_nope, - k_pe, - v, - q_segment_ids, - kv_segment_ids, - q_offset, - starts=starts, - lengths=lengths, - cfg=cfg, + q_nope, q_pe, k_nope, k_pe, v, q_segment_ids, kv_segment_ids, q_offset, kv_offset, cfg=cfg ) else: attn_out = attention(q_nope, q_pe, k_nope, k_pe, v, q_segment_ids, kv_segment_ids, q_offset, kv_offset, cfg) @@ -1120,26 +1128,26 @@ def _expert_fn(x, we_gate, we_up, we_down, topk_weights, topk_idx): if expert_axname is not None: ff_out_expert = jax.lax.psum(ff_out_expert, expert_axname) else: - # collectives - if is_embedding_sharded: # activations are supposed to be sharded on out - with jax.named_scope("tp_e_psum_scatter"): - ff_out_expert = jax.lax.psum_scatter( - ff_out_expert, tensor_axname, scatter_dimension=1, tiled=True - ) - with jax.named_scope("ep_e_psum"): - if expert_axname is not None: - ff_out_expert = jax.lax.psum(ff_out_expert, expert_axname) - else: - psum_axes = tensor_axname if expert_axname is None else (expert_axname, tensor_axname) - ff_out_expert = jax.lax.psum(ff_out_expert, psum_axes) + # collectives + if is_embedding_sharded: # activations are supposed to be sharded on out + with jax.named_scope("tp_e_psum_scatter"): + ff_out_expert = jax.lax.psum_scatter( + ff_out_expert, tensor_axname, scatter_dimension=1, tiled=True + ) + with jax.named_scope("ep_e_psum"): + if expert_axname is not None: + ff_out_expert = jax.lax.psum(ff_out_expert, expert_axname) + else: + psum_axes = tensor_axname if expert_axname is None else (expert_axname, tensor_axname) + ff_out_expert = jax.lax.psum(ff_out_expert, psum_axes) ff_out_expert = ff_out_expert.reshape((b, s, ff_out_expert.shape[-1])) return ff_out_expert with jax.named_scope("moe_routed_expert"): x_ = psc(x, x_spec) - ff_out_expert = _expert_fn(x_, 
we_gate, we_up, we_down, topk_weights, topk_idx)[..., :x.shape[-1]] + ff_out_expert = _expert_fn(x_, we_gate, we_up, we_down, topk_weights, topk_idx)[..., : x.shape[-1]] with jax.named_scope("moe_shared_expert"): - ff_out_shared = mlp_block(x, MLPLayer(layer.ws_gate, layer.ws_up, layer.ws_down), cfg)[..., :x.shape[-1]] + ff_out_shared = mlp_block(x, MLPLayer(layer.ws_gate, layer.ws_up, layer.ws_down), cfg)[..., : x.shape[-1]] return psc(ff_out_expert + ff_out_shared, l2p("batch", "sequence", "act_embed")) @@ -1197,7 +1205,7 @@ def forward(x: jax.Array, segment_ids: jax.Array, weights: Weights, cfg: Config, positions = segment_ids_to_positions(segment_ids) if is_type(cache, KVCache): positions = positions + cache.fill_len()[:, None] - sin, cos = generate_pos_embeddings(positions, cfg.qk_rope_head_dim, cfg) # [B, T, head_dim] + sin, cos = generate_pos_embeddings(positions, cfg.qk_rope_head_dim, cfg) # [B, T, head_dim] sin, cos = sin.astype(cfg.dtype), cos.astype(cfg.dtype) all_cache_updates = [] diff --git a/deepseek_r1_jax/main.ipynb b/deepseek_r1_jax/main.ipynb deleted file mode 100644 index 77997c4..0000000 --- a/deepseek_r1_jax/main.ipynb +++ /dev/null @@ -1,147 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0c88bb62", - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "import ipyparallel as ip\n", - "\n", - "NODES = 16 # adjust\n", - "client = ip.Client(connection_info=str(Path(\"~/nfs/security/ipcontroller-client.json\").expanduser()))\n", - "client.wait_for_engines(NODES - 1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4b715596", - "metadata": {}, - "outputs": [], - "source": [ - "%%px --local\n", - "import dataclasses\n", - "from pprint import pformat\n", - "from pathlib import Path\n", - "\n", - "import jax\n", - "from jax import numpy as jnp\n", - "from jax import random\n", - "import numpy as np\n", - "from etils import epath\n", - "\n", - "jax.config.update(\"jax_compilation_cache_dir\", str(Path(\"~/.jax_cache\").expanduser()))\n", - "if not jax.distributed.is_initialized():\n", - " jax.distributed.initialize()\n", - "print(jax.devices())\n", - "\n", - "from deepseek_r1_jax import model as dsjax\n", - "from deepseek_r1_jax import chkpt_utils as utils" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d991e2a0", - "metadata": {}, - "outputs": [], - "source": [ - "%%px --local\n", - "def encode_input(tokenizer, texts, pad_id: int = 0):\n", - " assert isinstance(texts, list)\n", - " inputs = [\n", - " tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": text}]) + tokenizer.encode(\"<|Assistant|>\")\n", - " for text in texts\n", - " ]\n", - " max_len = max([len(x) for x in inputs])\n", - " inputs = [(max_len - len(x)) * [pad_id] + x for x in inputs]\n", - " return np.array(inputs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "81a2400d", - "metadata": {}, - "outputs": [], - "source": [ - "%%px --local\n", - "ckpt_path = epath.Path(f\"~/bucket/deepseek-r1-jax-chkpt\").expanduser()\n", - "tokenizer = dsjax.load_tokenizer()\n", - "mesh = jax.make_mesh((1, 4, jax.device_count() // 4), (\"x\", \"y\", \"z\"), devices=jax.devices())\n", - "cfg = dataclasses.replace(dsjax.Config(), mesh=mesh)\n", - "weights = utils.load_model(epath.Path(ckpt_path).expanduser(), cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "882caef8", - "metadata": {}, - "outputs": [ - { - "ename": "", - "evalue": "", - "output_type": 
"error", - "traceback": [ - "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", - "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", - "\u001b[1;31mClick here for more info. \n", - "\u001b[1;31mView Jupyter log for further details." - ] - } - ], - "source": [ - "%%px --local\n", - "input = encode_input(\n", - " tokenizer,\n", - " [\n", - " \"Tell me your name\",\n", - " \"What is the weather like expressed in long prose in Old English\",\n", - " \"Do you like ice cream, be extremely precise\",\n", - " ],\n", - ")\n", - "\n", - "zero_cache = dsjax.KVCache.init(random.key(1), cfg, input.shape[0], cfg.max_seq_len)\n", - "curr_tokens, logits, cache = dsjax.prefill(input, weights, zero_cache, cfg)\n", - "curr_tokens, tokens_list = curr_tokens[:, cache.length - 1 : cache.length], []\n", - "tokens_list = []\n", - "for _ in range(32):\n", - " tokens_list.append(curr_tokens)\n", - " curr_tokens, cache = dsjax.decode_step(curr_tokens, weights, cache, cfg)\n", - "tokens = np.array(jnp.concatenate(tokens_list, axis=-1))\n", - "responses = [tokenizer.decode(row) for row in tokens]\n", - "print(\"Responses:\\n\" + pformat(responses))" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "-all", - "main_language": "python", - "notebook_metadata_filter": "-all" - }, - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.15" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From bfc06656704503ee0f1824d5032260b2d4aaa1d5 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Thu, 31 Jul 2025 20:37:15 -0700 Subject: [PATCH 06/11] Skip non-participating hosts in computation --- serving/serving_jax/__init__.py | 84 ++++++++++++++++++++----------- serving/serving_jax/cross_host.py | 2 +- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/serving/serving_jax/__init__.py b/serving/serving_jax/__init__.py index 60eede8..3cb1f17 100644 --- a/serving/serving_jax/__init__.py +++ b/serving/serving_jax/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. 
import dataclasses -from functools import partial +from functools import partial, wraps from typing import Any, Callable import math from concurrent.futures import ThreadPoolExecutor, Future @@ -43,7 +43,7 @@ TIME_AXIS = 2 USE_PREFIX_CACHE = True # the eviction mechanism is extremely simple right now -#USE_PREFIX_CACHE = False +# USE_PREFIX_CACHE = False is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__) ######################################################################################################################## @@ -98,6 +98,7 @@ def _ensure_all_args_on_mesh(*args, mesh: Mesh): # kv cache buffer management ########################################################################################### ######################################################################################################################## + @partial(jax.jit, static_argnames=("axis", "chunk_size", "ns")) def _split(val: jax.Array | list[jax.Array], axis: int, chunk_size: int, ns: int) -> list[jax.Array]: def _fn(val): @@ -133,7 +134,7 @@ def _get_unique_buffer_ids(self, n: int): def offload_buffers(self, how_many: int): if how_many == 0: return - candidates = sorted(self._store.keys(), key=lambda i: self.usecount[i] if self.ondevice[i] else 2 ** 60) + candidates = sorted(self._store.keys(), key=lambda i: self.usecount[i] if self.ondevice[i] else 2**60) for i in candidates[:how_many]: if self.ondevice[i]: shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("pinned_host"), self._store[i]) @@ -168,6 +169,7 @@ def mark_visited(self, id: int): return [self.mark_visited(i) for i in id] self.usecount[id] += 1 + BUFFER_STORE = KVBufferStore() ######################################################################################################################## @@ -176,6 +178,7 @@ def mark_visited(self, id: int): EMPTY, HASH_BITWIDTH = -1, 1 + @dataclasses.dataclass class ChildKeys: keys: np.ndarray @@ -194,7 +197,12 @@ def _hash_encode(v: np.ndarray, hash_bitwidth: int = HASH_BITWIDTH, pad_idx: int def _prefilter_on_hash( - w: np.ndarray, keys: np.ndarray, vh: np.ndarray, vm: np.ndarray, hash_bitwidth: int = HASH_BITWIDTH, pad_idx: int = EMPTY + w: np.ndarray, + keys: np.ndarray, + vh: np.ndarray, + vm: np.ndarray, + hash_bitwidth: int = HASH_BITWIDTH, + pad_idx: int = EMPTY, ): wh, wm = _hash_encode(w, hash_bitwidth=hash_bitwidth, pad_idx=pad_idx) inv_match = (wh ^ vh) & vm & wm @@ -203,6 +211,7 @@ def _prefilter_on_hash( max_match_len = max(np.max(match_len), 1) return np.where(match_len == max_match_len)[0] + def _fast_pad(x, size, axis, pad_val=0): new_buf = pad_val * np.ones([size - s if i == axis else s for i, s in enumerate(x.shape)], dtype=x.dtype) return np.concat([x, new_buf], axis) @@ -344,6 +353,9 @@ def retrieve_prefix(root: TrieNode, sequence: np.ndarray, *, chunk_size: int, pa next_power_of_2 = lambda x: 2 ** round(math.ceil(math.log2(x))) like_spec = lambda z: jax.tree.map(lambda x: jax.typeof(x).sharding.spec, z) like_shard = lambda z, mesh: jax.tree.map(lambda x: NamedSharding(mesh, jax.typeof(x).sharding.spec), z) +_make_empty = lambda x, mesh: jax.make_array_from_single_device_arrays( + x.shape, NamedSharding(mesh, x.sharding.spec), [], dtype=x.dtype +) @dataclasses.dataclass @@ -446,7 +458,15 @@ def body(carry, _): (curr_tokens, cache), output_tokens = jax.lax.scan(body, (curr_tokens, cache), length=steps) return (curr_tokens, cache), output_tokens[..., 0].T - return multistep_decode_fn + @wraps(multistep_decode_fn) + def 
wrapped(curr_tokens, decode_weights, cache, cfg, steps: int = 32, *, participate: bool = True):
+        if participate:
+            return multistep_decode_fn(curr_tokens, decode_weights, cache, cfg, steps=steps)
+        else:
+            _make_empty_, fn = partial(_make_empty, mesh=cfg.mesh), multistep_decode_fn
+            return jax.tree.map(_make_empty_, jax.eval_shape(fn, curr_tokens, decode_weights, cache, cfg, steps=steps))
+
+    return wrapped


 def _make_stacked_prefill(prefill_fn):
@@ -461,7 +481,15 @@ def stacked_prefill(inputs, weights, cfg):
         stacked_kv = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *kv_list)
         return next_tokens, logits, stacked_kv

-    return lambda inputs, weights, cfg: stacked_prefill(_numpy_pad_tokens(inputs), weights, cfg)
+    @wraps(stacked_prefill)
+    def wrapped(inputs, weights, cfg, *, participate: bool = True):
+        if participate:
+            return stacked_prefill(_numpy_pad_tokens(inputs), weights, cfg)
+        else:
+            _make_empty_ = partial(_make_empty, mesh=cfg.mesh)
+            return jax.tree.map(_make_empty_, jax.eval_shape(stacked_prefill, _numpy_pad_tokens(inputs), weights, cfg))
+
+    return wrapped


 class ServingLoop:
@@ -476,8 +504,8 @@ def __init__(
         decode_cache: KVCache,
         is_server: bool = False,
     ):
-        if not (SyncServer.broadcast("welcome", 0, True, is_server) if jax.process_count() > 1 else is_server):
-            raise ValueError("No processes registered as the main server, at least one process must.")
+        if not SyncServer.broadcast("welcome", 0, is_server, is_server):
+            raise ValueError("Neither this process nor any other process is the main server; at least one must be.")
         self.serve_cfg, self.cfg = serve_cfg, cfg

         # setup decode
@@ -506,7 +534,7 @@ def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens,
         self.decode_output = (None, None)

         # setup prefill
-        self.prefill_fn = staticmethod(_make_stacked_prefill(prefill_fn))
+        self.prefill_fn = _make_stacked_prefill(prefill_fn)
         self.prefill_weights = prefill_weights
         self.prefill_mesh = [x for x in jax.tree.leaves(prefill_weights) if hasattr(x, "sharding")][0].sharding.mesh
         self.prefill_work = PrefillWork([], [], [])
@@ -572,19 +600,16 @@ def decode_step(self):

         # 2. run N decode steps
         output_tokens, output_mapping = [], []
-        if "decode" in self.roles:  # TODO(rdyro): revisit: don't issue the decode call on non-participating machines
-            with use_mesh(self.decode_mesh):
-                config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps)
-                (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = self.multistep_decode_fn(
-                    self.decode_work.curr_tokens, self.decode_weights, self.decode_work.cache, **config
-                )
-            output_mapping = [
-                [getattr(result, "id", -1) for result in self.decode_work.active_results]
-            ] * self.serve_cfg.decode_steps
-            output_mapping = np.array(output_mapping).T
-            print(
-                f"Decoding with fill rate of {np.mean([result is not None for result in self.decode_work.active_results])}"
+        with use_mesh(self.decode_mesh):
+            config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps, participate="decode" in self.roles)
+            (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = self.multistep_decode_fn(
+                self.decode_work.curr_tokens, self.decode_weights, self.decode_work.cache, **config
             )
+        output_mapping = [
+            [getattr(result, "id", -1) for result in self.decode_work.active_results]
+        ] * self.serve_cfg.decode_steps
+        output_mapping = np.array(output_mapping).T
+        print(f"Decoding with fill rate: {np.mean([result is not None for result in self.decode_work.active_results])}")

         # 3. 
parse output tokens from previous decoding loop to allow for the tokens arrive (delayed EOS detection) self.decode_output, (output_tokens, output_mapping) = (output_tokens, output_mapping), self.decode_output @@ -607,8 +632,7 @@ def decode_step(self): (output_tokens_flat, output_mapping_flat, done), is_source="decode_coordinator" in self.roles, ) - #if "server" in self.roles or "decode_coordinator" in self.roles: - for token, id in zip(output_tokens.reshape(-1).tolist(), output_mapping.reshape(-1).tolist()): + for token, id in zip(output_tokens_flat, output_mapping_flat): if id > 0: self.results[id].token_list.append(token) self.results[id].tokens_decoded += 1 @@ -664,7 +688,9 @@ def prefill_step(self): self.prefill_work.to_decode.append(new_decode) print(f"Found a full match") else: - print(f"Need to prefill the request, only found a match for length {total_match / (len(request.text) - 1)}") + print( + f"Need to prefill the request, only found a match for length {total_match / (len(request.text) - 1)}" + ) self.prefill_work.to_prefill.append(request) if self.prefill_work.pending_prefill is not None: # a current prefill is still running, skip scheduling another @@ -684,7 +710,9 @@ def _prefill_job(): inputs = np.pad(inputs, ((0, row_pad), (0, col_pad)), mode="constant", constant_values=self.pad_id) cfg = dataclasses.replace(self.cfg, mesh=self.prefill_mesh) with use_mesh(self.prefill_mesh): - _, _, prefill_results = self.prefill_fn(inputs, self.prefill_weights, cfg) + _, _, prefill_results = self.prefill_fn( + inputs, self.prefill_weights, cfg, participate="prefill" in self.roles + ) prefill_results = jax.block_until_ready(prefill_results) return prefill_input, prefill_results @@ -719,10 +747,10 @@ def serving_step(self): if "server" in self.roles: with self.state_lock: self.pending_requests, requests = [], list(self.pending_requests) + serve_cfg, requests = SyncServer.broadcast( + "requests", self._it, (dataclasses.asdict(self.serve_cfg), requests), is_source="server" in self.roles + ) with self.state_lock: - serve_cfg, requests = SyncServer.broadcast( - "requests", self._it, (dataclasses.asdict(self.serve_cfg), requests), is_source="server" in self.roles - ) self.serve_cfg = dataclasses.replace(self.serve_cfg, **serve_cfg) for request in requests: self.total_requests += 1 diff --git a/serving/serving_jax/cross_host.py b/serving/serving_jax/cross_host.py index 7c59fce..65f3bc4 100644 --- a/serving/serving_jax/cross_host.py +++ b/serving/serving_jax/cross_host.py @@ -8,7 +8,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec as P import numpy as np -jax.config.update("jax_enable_empty_arrays", True) +#jax.config.update("jax_enable_empty_arrays", True) PyTree = Any From 13b52cca1369928e4a3f26129e7820cd0b3fa746 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Mon, 4 Aug 2025 19:36:17 -0700 Subject: [PATCH 07/11] Partial prefix match support via chunked prefill --- .../scripts/convert_hf_r1_checkpoint.py | 27 +- llama3/llama3_jax/model.py | 26 +- llama3/pyproject.toml | 2 + serving/main_serving.py | 60 ++-- serving/serving_jax/__init__.py | 275 +++++++++--------- serving/serving_jax/cross_host.py | 9 +- 6 files changed, 235 insertions(+), 164 deletions(-) diff --git a/deepseek_r1_jax/scripts/convert_hf_r1_checkpoint.py b/deepseek_r1_jax/scripts/convert_hf_r1_checkpoint.py index c6832d4..34cf846 100644 --- a/deepseek_r1_jax/scripts/convert_hf_r1_checkpoint.py +++ b/deepseek_r1_jax/scripts/convert_hf_r1_checkpoint.py @@ -19,13 +19,15 @@ import jax from jax.sharding import 
PartitionSpec as P
+from argparse import ArgumentParser

-from deepseek_r1_jax.model import ShardingRules, Config
-from deepseek_r1_jax import chkpt_utils as utils

-def main():
-    root_path = Path("/mnt/storage/DeepSeek-R1")
-    dest_path = Path("/mnt/storage/deepseek-r1-jax-chkpt")
+def main(root_path, dest_path):
+    from deepseek_r1_jax.model import ShardingRules, Config
+    from deepseek_r1_jax import chkpt_utils as utils
+
+    root_path, dest_path = Path(root_path), Path(dest_path)
+    dest_path.mkdir(exist_ok=True, parents=True)

     cfg = Config()
     cfg.quantize_mlp = False
@@ -39,4 +41,15 @@ def main():
     utils.convert_hf_checkpoint(params_map, root_path, dest_path, cfg)

 if __name__ == "__main__":
-    main()
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--source-path", default="/mnt/storage/DeepSeek-R1-weights-only", required=True, help="HF model directory path"
+    )
+    parser.add_argument(
+        "--dest-path",
+        default="~/deepseek_r1_jax",
+        required=True,
+        help="JAX model directory (to be created).",
+    )
+    args = parser.parse_args()
+    main(args.source_path, args.dest_path)
diff --git a/llama3/llama3_jax/model.py b/llama3/llama3_jax/model.py
index 61cceed..a54539a 100644
--- a/llama3/llama3_jax/model.py
+++ b/llama3/llama3_jax/model.py
@@ -21,8 +21,8 @@
 import math
 from functools import partial
 from typing import Callable, Any, TypeVar
-from types import ModuleType
 from inspect import signature
+from collections import OrderedDict as odict

 import jax
 import jax.numpy as jnp
@@ -31,7 +31,8 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
 from jax.experimental.shard_map import shard_map
-from jax.sharding import PartitionSpec as P, use_mesh
+from jax.sharding import PartitionSpec as P
+from jax.experimental.array_serialization import pytree_serialization as ser
 try:
     from jax.experimental.shard import auto_axes as _auto_axes, reshard
 except ModuleNotFoundError:
@@ -213,6 +214,7 @@ class ArrayInfo:
 # module reload friendly isinstance check
 is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__)
 is_param = lambda x: is_type(x, ArrayInfo)
+which_platform = lambda cfg: cfg.mesh.devices.reshape(-1)[0].platform
 _count_left_padding = lambda ids, pad_id=0: auto_axes(
     lambda ids: jnp.sum(jnp.cumsum(ids != pad_id, axis=-1) == 0, axis=-1), out_sharding=P(None)
 )(ids)
@@ -404,15 +406,18 @@ def abstract(cls, cfg: Config):
     )


-@partial(jax_pytree_struct, meta_fields=("batch_size", "size", "time_axis"))
+@partial(jax_pytree_struct, meta_fields=("batch_size", "size", "time_axis", "insert_sequences"))
 class KVCache(_Init):
     k: list[tuple[jax.Array | QuantArray, ...]]  # (batch_size, key_heads, max_seq_len, head_dim)
     v: list[tuple[jax.Array | QuantArray, ...]]  # (batch_size, key_heads, max_seq_len, head_dim)
     iter: jax.Array  # []  # sequences are right-aligned for slice update performance
     starts: jax.Array  # [batch_size]  # sequences are right-aligned, we need start indices
-    batch_size: int = 0
+    batch_size: int = 1
     size: int = 2 ** 30
     time_axis: int = 2
+    #update_slice: Callable = None
+    insert_sequences: Callable = None
+    #get_sequence: Callable = None

     @classmethod
     def abstract(cls, cfg: Config, batch_size: int):
@@ -798,6 +803,8 @@ def _f(q, k, v, q_segment_ids, kv_segment_ids, starts, lengths, k_scale, v_scale


 def paged_attention_kernel(q, k, v, block_tables, lengths, cfg: Config):
+    if which_platform(cfg) not in 
("gpu", "cuda"): + raise ValueError("Paged attention is only supported on GPU.") k, k_scale = (k.quant, k.scale) if is_type(k, QuantArray) else (k, None) v, v_scale = (v.quant, v.scale) if is_type(v, QuantArray) else (v, None) @@ -1030,6 +1037,17 @@ def prepare_chunk(chunk, pad_to: int, pad_id: int): return chunk, segment_ids +## serialization +#def save_pytree(data, path): +# flat_data = odict(("weights" + "".join(map(str, k)), v) for k, v in jax.tree.flatten_with_path(data)[0]) +# ser.save(flat_data, path) # save a flatten with path to avoid custom +# +# +#def load_pytree(path, sharding=None): +# flat_sharding = odict(("weights" + "".join(map(str, k)), v) for k, v in jax.tree.flatten_with_path(sharding)[0]) +# return jax.tree.unflatten(jax.tree.structure(sharding), jax.tree.leaves(ser.load(path, flat_sharding))) + + def prefill(tokens: jax.Array, weights: Weights, cache: KVCache | None, cfg: Config, pad_id: int = 0): """Samples from a prompt.""" # Calculate the next power of 2 for padding, up to cfg.max_seq. diff --git a/llama3/pyproject.toml b/llama3/pyproject.toml index 89e5261..f0bdeb5 100644 --- a/llama3/pyproject.toml +++ b/llama3/pyproject.toml @@ -15,6 +15,8 @@ dependencies = [ #"transformers", # for the model config and the tokenizer "tqdm", "numpy", + "flatbuffers", + "tensorstore", #"orbax-checkpoint", #"datasets", "gcsfs", diff --git a/serving/main_serving.py b/serving/main_serving.py index e79c543..d8d80cc 100644 --- a/serving/main_serving.py +++ b/serving/main_serving.py @@ -8,8 +8,8 @@ import time from typing import AsyncGenerator from contextlib import asynccontextmanager -import os from argparse import ArgumentParser +from typing import Any import jax from jax import random @@ -24,16 +24,15 @@ import serving_jax as serving from serving_jax import attention_cache_utils +Config = Any TOKENIZER, SERVE_LOOP, SERVING_THREAD, ARGS = None, None, None, None jax.config.update("jax_explain_cache_misses", True) -jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) -jax.config.update("jax_enable_empty_arrays", True) +#jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) try: # newer JAX only - assert False - my_id = int(socket.gethostname().split("-")[-1]) - 1 + my_id = int(socket.gethostname().split("-")[-1]) my_ip = socket.getaddrinfo(socket.gethostname(), 80)[0][-1][0] jax.config.update("jax_cross_host_transfer_socket_address", f"{my_ip}:{17007 + my_id}") jax.config.update("jax_cross_host_transport_addresses", ",".join([f"{my_ip}:0"] * 8)) @@ -60,39 +59,60 @@ def load_model(): #process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...) 
    #jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx)
+    jax.distributed.initialize()
     print(jax.devices())
     print("-" * 80)
     print(jax.local_devices())
 
-    model_name = "Llama-3.1-8B-Instruct"
-    ckpt_path = Path(f"~/{model_name}").expanduser()
+    #model_name = "Llama-3.1-8B-Instruct"
+    #ckpt_path = Path(f"~/{model_name}").expanduser()
+    #model_name = "Llama-3.1-8B-Instruct-quant"
+    model_name = "Llama-3.1-70B-Instruct-quant"
+    ckpt_path = Path(f"~/bucket/llama3_jax_old/{model_name}").expanduser()
     cfg = l3jax.load_config(ckpt_path / "config.json")
     TOKENIZER = l3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json")
     assert ckpt_path.is_dir()
     print("---> Model config loaded")
 
     # two hosts, different device and host meshes
-    local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3)
-    decode_mesh, prefill_mesh = local_mesh, local_mesh
+    #local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3)
+    #local_mesh = jax.make_mesh((1, 1, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3)
+    #decode_mesh, prefill_mesh = local_mesh, local_mesh
+    decode_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3)
+    prefill_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[8:], axis_types=(AxisType.Explicit,) * 3)
+    #decode_mesh = jax.make_mesh((1, 8, 2), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3)
+    #prefill_mesh = jax.make_mesh((1, 8, 2), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3)
     cfg = dataclasses.replace(cfg, mesh=decode_mesh, quant_layer=True, quant_cache=True)
-    cfg = dataclasses.replace(cfg, use_prefill_attn_kernel=False, use_decode_attn_kernel=False, max_seq_len=8192)
-    cfg = dataclasses.replace(cfg, quant_layer=False, quant_cache=False)
-    cfg.quant_cache = True
+    cfg = dataclasses.replace(cfg, use_prefill_attn_kernel=False, use_decode_attn_kernel=False, max_seq_len=2048)
+    cfg.quant_cache = False
     decode_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh)))
     prefill_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh)))
     print("---> Weights loaded")
 
-    serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64)
-    #decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size)
-    #decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry
-    #decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache
-    decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32)
-    decode_cache.get_sequence = attention_cache_utils.batch_paged_get_entry
-    decode_cache.insert_sequences = attention_cache_utils.batch_paged_update_sequences
+    serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64, prefix_chunk_size=64)
+    decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size)
+    decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry
+    decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache
+    #decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32)
+    #decode_cache.get_sequence = attention_cache_utils.batch_paged_get_entry
+    #decode_cache.insert_sequences = attention_cache_utils.batch_paged_update_sequences
+
+    def init_cache(cfg: Config, batch_size: int, actual_len: int):
+        cache = l3jax.KVCache.init(random.key(0), cfg, batch_size)
+        cache.get_sequence = attention_cache_utils.kvcache_get_entry
+        cache.insert_sequences = attention_cache_utils.kvcache_update_cache
+        cache.iter = actual_len
+        return cache
+
+    with jax.sharding.set_mesh(prefill_mesh):
+        prefill_cache = init_cache(dataclasses.replace(cfg, mesh=prefill_mesh), serve_cfg.prefill_batch_size, 8192)
+
+    forward_fn = l3jax.decode_step  # TODO: the model file needs to call it forward explicitly
     SERVE_LOOP = serving.ServingLoop(
-        serve_cfg, cfg, l3jax.prefill, prefill_weights, l3jax.decode_step, decode_weights, decode_cache, ARGS.server
+        #serve_cfg, cfg, init_cache, l3jax.decode_step, prefill_weights, decode_weights, decode_cache, ARGS.server
        serve_cfg, cfg, forward_fn, prefill_weights, prefill_cache, decode_weights, decode_cache, ARGS.server
     )
     print("---> Created the serving loop")
diff --git a/serving/serving_jax/__init__.py b/serving/serving_jax/__init__.py
index 3cb1f17..32d4d77 100644
--- a/serving/serving_jax/__init__.py
+++ b/serving/serving_jax/__init__.py
@@ -13,18 +13,18 @@
 # limitations under the License.
 
 import dataclasses
-from functools import partial, wraps
-from typing import Any, Callable
+import contextlib
+from functools import partial
+from typing import Any, Callable, Sequence
 import math
 from concurrent.futures import ThreadPoolExecutor, Future
 import threading
 import time
 import json
-from typing import Any
 
 import jax
 import jax.numpy as jnp
-from jax.sharding import Mesh, PartitionSpec as P, NamedSharding, use_mesh
+from jax.sharding import Mesh, PartitionSpec as P, NamedSharding, set_mesh
 try:
     from jax.experimental.shard import auto_axes
@@ -75,23 +75,23 @@ def jax_device_put(xs: PyTree, sharding: PyTree):
 
 def jit_device_put(xs: PyTree, sharding: PyTree):
     """Most compatabile, uses jit, so requires blocking dispatch."""
-    jax.sharding.set_mesh(None)  # not compatible with context mesh
+    # jax.sharding.set_mesh(None)  # not compatible with context mesh
     meshA, meshB = jax.tree.leaves(xs)[0].sharding.mesh, jax.tree.leaves(sharding)[0].mesh
     return transfer_tree_A2B(xs, meshA, meshB)
 
-device_put = jit_device_put  # the most compatible options currently, but NOT async, need
+#device_put = jit_device_put  # the most compatible option currently, but NOT async
+device_put = jax.device_put
 
-def _ensure_all_args_on_mesh(*args, mesh: Mesh):
-    args_len = len(args)
+def _ensure_all_args_on_mesh(args, mesh: Mesh):
     if not all(jax.tree.leaves(arg)[0].sharding.mesh == mesh for arg in args):
         _correct_mesh = lambda value: jax.tree.leaves(value)[0].sharding.mesh == mesh
         _args = {i: arg for i, arg in enumerate(args) if not _correct_mesh(arg)}
         if len(_args) > 0:
             args = dict(enumerate(args)) | device_put(_args, like_shard(_args, mesh))
             args = tuple(args[i] for i in range(len(args)))
-    return args if args_len > 1 else args[0]
+    return args
 
@@ -137,8 +137,9 @@ def offload_buffers(self, how_many: int):
         candidates = sorted(self._store.keys(), key=lambda i: self.usecount[i] if self.ondevice[i] else 2**60)
         for i in candidates[:how_many]:
             if self.ondevice[i]:
-                shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("pinned_host"), self._store[i])
-                self._store[i] = jax.device_put(self._store[i], shrd)
+                host_shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("pinned_host"), self._store[i])
+                self._store[i] = jax.device_put(self._store[i], host_shrd)
+                self.ondevice[i] = False
                 self.livecount -= 1
 
     def load(self, id: int):
@@ -148,8 +149,8 @@ def load(self, id: int):
             return self._store[id]
         self.ondevice[id] = True
         self.livecount += 1
-        shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("device"), self._store[id])
-        self._store[id] = jax.device_put(self._store[id], shrd)
+        device_shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("device"), self._store[id])
+        self._store[id] = jax.device_put(self._store[id], device_shrd)
         return self._store[id]
 
     def delete(self, id: int):
@@ -178,7 +179,6 @@ def mark_visited(self, id: int):
 
 EMPTY, HASH_BITWIDTH = -1, 1
 
-
 @dataclasses.dataclass
 class ChildKeys:
     keys: np.ndarray
@@ -268,6 +268,17 @@ def _append_key(keys: ChildKeys | None, new_key: np.ndarray, key_len: int, pad_i
         keys.num += 1
         return keys
 
+    @staticmethod
+    def _delete_keys(keys: ChildKeys, delete_idxs: np.ndarray):
+        if keys is None:
+            return
+        mask = np.ones(keys.keys.shape[0], dtype=bool)
+        mask[np.array(list(delete_idxs) if isinstance(delete_idxs, set) else delete_idxs, int)] = False
+        if np.sum(mask) == 0:
+            return None
+        num = max(keys.num - sum(1 for idx in set(delete_idxs) if idx < keys.num), 0)
+        return ChildKeys(*(z[mask, ...] for z in [keys.keys, keys.keys_hash, keys.keys_hash_mask, keys.key_lens]), num)
+
     @staticmethod
     def _pad_to_multiple_of(sequence: np.ndarray, chunk_size: int, pad_idx: int = EMPTY):
         sequence_pad_len = math.ceil(sequence.size / chunk_size) * chunk_size
@@ -338,13 +349,28 @@ def retrieve_prefix(root: TrieNode, sequence: np.ndarray, *, chunk_size: int, pa
         match_length = exact_match[best_idx]
         if match_length > 0:
             visited_refs.append(int(node.children[best_idx].value))
+        if match_length == 0:
+            break
         node = node.children[best_idx]
         total_match += int(match_length)
+        ref_vals.append(node.value)
         if match_length != seq_len:
             break
-        ref_vals.append(node.value)
     return (total_match, ref_vals), visited_refs
 
+def remove_prefix_nodes(node: TrieNode, refs_to_delete: Sequence[int]):
+    refs_to_delete, deleted_refs = set(refs_to_delete), set()
+    ctx = node.lock if node.lock is not None else contextlib.nullcontext()
+    with ctx:
+        for child in node.children:
+            deleted_refs |= remove_prefix_nodes(child, refs_to_delete)
+        deleted_refs |= set(child.value for child in node.children if child.value in refs_to_delete)
+        delete_idxs = set([i for i, child in enumerate(node.children) if child.value in refs_to_delete])
+        for idx in delete_idxs:  # if we're removing a full child, tell it to remove all its children first
+            deleted_refs |= remove_prefix_nodes(node.children[idx], [c.value for c in node.children[idx].children])
+        node.child_keys = TrieNode._delete_keys(node.child_keys, delete_idxs)
+        node.children = [child for i, child in enumerate(node.children) if i not in delete_idxs]
+    return set(deleted_refs)
 
 ########################################################################################################################
 # serving loop #########################################################################################################
@@ -354,7 +380,7 @@ def retrieve_prefix(root: TrieNode, sequence: np.ndarray, *, chunk_size: int, pa
 
 like_spec = lambda z: jax.tree.map(lambda x: jax.typeof(x).sharding.spec, z)
 like_shard = lambda z, mesh: jax.tree.map(lambda x: NamedSharding(mesh, jax.typeof(x).sharding.spec), z)
 _make_empty = lambda x, mesh: jax.make_array_from_single_device_arrays(
-    x.shape, NamedSharding(mesh, x.sharding.spec), [], dtype=x.dtype
+    x.shape, NamedSharding(mesh, jax.typeof(x).sharding.spec), [], dtype=x.dtype
 )
 
@@ -367,6 +393,8 @@ class ServingConfig:
     eos_tokens: tuple[int, ...] | jax.Array = ()
     token_pad_idx: int = 0
     max_decode_length: int = 64
+    max_ondevice_buffers: int = 100
+    max_buffers: int = 256
 
 
 @dataclasses.dataclass
@@ -383,6 +411,13 @@ class DecodeResult:
     done: bool = False
 
 
+@dataclasses.dataclass
+class PrefillJob:
+    request: UserRequestPrompt
+    cache_entry: Any
+    match_len: int
+
+
 @dataclasses.dataclass
 class PrefillResult:
     id: int
@@ -446,6 +481,13 @@ def broadcast(key: str, current_it: int, value: Any, is_source: bool = False, js
             value = client.blocking_key_value_get(key + str(current_it), SyncServer.TIMEOUT_SEC * 1000)
         return json.loads(value) if jsonify else value
 
+def maybe_call(fn: Callable, mesh: Mesh):
+    """Only call the program if this host's worker is participating; otherwise return (truly) empty arrays with the correct sharding."""
+    mesh_devices = set(d.id for d in mesh.devices.flat)
+    if any(d.id in mesh_devices for d in jax.local_devices()):  # host has some participating devices
+        return fn
+    return (lambda *args, **kw: jax.tree.map(partial(_make_empty, mesh=mesh), jax.eval_shape(fn, *args, **kw)))
+
 
 def _make_multistep_decode_fn(decode_fn):
     @partial(jax.jit, static_argnames=("steps",), donate_argnames=("cache",))
@@ -458,38 +500,7 @@ def body(carry, _):
         (curr_tokens, cache), output_tokens = jax.lax.scan(body, (curr_tokens, cache), length=steps)
         return (curr_tokens, cache), output_tokens[..., 0].T
 
-    @wraps(multistep_decode_fn)
-    def wrapped(curr_tokens, decode_weights, cache, cfg, steps: int = 32, *, participate: bool = True):
-        if participate:
-            return multistep_decode_fn(curr_tokens, decode_weights, cache, cfg, steps=steps)
-        else:
-            _make_empty_, fn = partial(_make_empty, mesh=cfg.mesh), multistep_decode_fn
-            return jax.tree.map(_make_empty_, jax.eval_shape(fn, curr_tokens, decode_weights, cache, cfg, steps=steps))
-
-    return wrapped
-
-
-def _make_stacked_prefill(prefill_fn):
-    def _numpy_pad_tokens(tokens):
-        opts = dict(mode="constant", constant_values=0)
-        return np.pad(tokens, [(0, 0), (0, next_power_of_2(tokens.shape[-1]) - tokens.shape[-1])], **opts)
-
-    @jax.jit
-    def stacked_prefill(inputs, weights, cfg):
-        next_tokens, logits, kv_list = prefill_fn(inputs, weights, None, cfg)
-        assert len(kv_list) == cfg.num_layers, "The output kv values have to be in a list kv pairs."
-        stacked_kv = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *kv_list)
-        return next_tokens, logits, stacked_kv
-
-    @wraps(stacked_prefill)
-    def wrapped(inputs, weights, cfg, *, participate: bool = True):
-        if participate:
-            return stacked_prefill(_numpy_pad_tokens(inputs), weights, cfg)
-        else:
-            _make_empty_ = partial(_make_empty, mesh=cfg.mesh)
-            return jax.tree.map(_make_empty_, jax.eval_shape(stacked_prefill, _numpy_pad_tokens(inputs), weights, cfg))
-
-    return wrapped
+    return multistep_decode_fn
 
 
 class ServingLoop:
@@ -497,44 +508,43 @@ def __init__(
         self,
         serve_cfg: ServingConfig,
         cfg: Config,
-        prefill_fn: Callable,
+        forward_fn: Callable,
         prefill_weights: Weights,
-        decode_fn: Callable,
+        prefill_cache: KVCache,
         decode_weights: Weights,
         decode_cache: KVCache,
         is_server: bool = False,
     ):
+        #self.init_cache = init_cache
+        self.prefill_cache = prefill_cache
         if not SyncServer.broadcast("welcome", 0, is_server, is_server):
             raise ValueError("Neither this proccess nor any other processe is the main server, at least one must.")
         self.serve_cfg, self.cfg = serve_cfg, cfg
 
         # setup decode
-        self.decode_fn, self.decode_weights = decode_fn, decode_weights
+        self.forward_fn, self.decode_weights = forward_fn, decode_weights
         self.decode_mesh = [x for x in jax.tree.leaves(decode_weights) if hasattr(x, "sharding")][0].sharding.mesh
-        with use_mesh(self.decode_mesh):
+        with set_mesh(self.decode_mesh):
             self.decode_work = DecodeWork(None, decode_cache, [None for _ in range(serve_cfg.decode_batch_size)])
             self.decode_work.curr_tokens = jax.device_put(
                 jnp.zeros((serve_cfg.decode_batch_size, 1), dtype=jnp.int32), P()
             )
-        self.multistep_decode_fn = _make_multistep_decode_fn(self.decode_fn)
+        self.multistep_decode_fn = _make_multistep_decode_fn(self.forward_fn)
         self._update_index = jax.jit(lambda x, i, new: x.at[i, ...].set(new[:, None], mode="drop"))
 
         def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens, kvs, batch_idxs, actual_lens):
-            length_sort = sorted(
-                range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2]
-            )  # sort to minimize variants num
-            new_cache = decode_cache.insert_sequences(
-                cache, *[[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)]
-            )
-            with use_mesh(self.decode_mesh):
-                new_curr_tokens = self._update_index(curr_tokens, np.array(batch_idxs), new_tokens)
+            # sort to minimize variants num
+            length_sort = sorted(range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2])
+            sorted_args = [[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)]
+            new_cache = decode_cache.insert_sequences(cache, *sorted_args)
+            with set_mesh(self.decode_mesh):
+                new_curr_tokens = self._update_index(curr_tokens, np.array(batch_idxs), np.array(new_tokens))
             return new_cache, new_curr_tokens
 
         self._update_cache_and_index = _update_cache_and_index
         self.decode_output = (None, None)
 
         # setup prefill
-        self.prefill_fn = _make_stacked_prefill(prefill_fn)
         self.prefill_weights = prefill_weights
         self.prefill_mesh = [x for x in jax.tree.leaves(prefill_weights) if hasattr(x, "sharding")][0].sharding.mesh
         self.prefill_work = PrefillWork([], [], [])
@@ -580,9 +590,8 @@ def decode_step(self):
                 break
             result: PrefillResult = self.prefill_work.to_decode.pop(0)
             self.decode_work.active_results[i] = DecodeResult(result.id, result.input.tolist())
-            with use_mesh(self.decode_mesh):
+            with set_mesh(self.decode_mesh):
                 result.cache_entry = result.cache_entry() if callable(result.cache_entry) else result.cache_entry
-                result.cache_entry = _ensure_all_args_on_mesh(result.cache_entry, mesh=self.decode_mesh)
             self.results[result.id] = self.decode_work.active_results[i]
             batch_cache_updates.append((result.cache_entry, i, result.len, result.next_token))
         if len(self.prefill_work.to_decode) == 0:
@@ -590,9 +599,8 @@ def decode_step(self):
         if "decode" in self.roles and len(batch_cache_updates) > 0:  # batch cache update
             entries, batch_idxs, lens, next_tokens = map(list, zip(*batch_cache_updates))
             entries = [entry.result() if hasattr(entry, "result") else entry for entry in entries]  # maybe collect
-            _control_args = (np.array(next_tokens), entries, batch_idxs, lens)
             self.decode_work.cache, self.decode_work.curr_tokens = self._update_cache_and_index(
-                self.decode_work.cache, self.decode_work.curr_tokens, *_control_args
+                self.decode_work.cache, self.decode_work.curr_tokens, next_tokens, entries, batch_idxs, lens
             )
 
         if all(x is None for x in self.decode_work.active_results):
@@ -600,9 +608,12 @@ def decode_step(self):
 
         # 2. run N decode steps
         output_tokens, output_mapping = [], []
-        with use_mesh(self.decode_mesh):
-            config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps, participate="decode" in self.roles)
-            (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = self.multistep_decode_fn(
+        with set_mesh(self.decode_mesh):
+            # config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps, participate="decode" in self.roles)
+            config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps)
+            decode_fn = maybe_call(self.multistep_decode_fn, self.decode_mesh)
+            #decode_fn = self.multistep_decode_fn
+            (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = decode_fn(
                 self.decode_work.curr_tokens, self.decode_weights, self.decode_work.cache, **config
             )
             output_mapping = [
@@ -636,7 +647,7 @@ def decode_step(self):
             if id > 0:
                 self.results[id].token_list.append(token)
                 self.results[id].tokens_decoded += 1
-        with use_mesh(self.decode_mesh):
+        with set_mesh(self.decode_mesh):
             for i, result in enumerate(self.decode_work.active_results):
                 if result is None:
                     continue
@@ -644,9 +655,7 @@ def decode_step(self):
                 if done[i]:
                     if USE_PREFIX_CACHE:  # store the results in the prefix cache buffer store
                         sequence = np.array(result.token_list)
-                        with use_mesh(self.decode_mesh):
-                            cache_entry, _ = self._get_cache_entry(self.decode_work.cache, i)
-                        # self._background.submit(self._insert_prefix, sequence, cache_entry, mesh=self.decode_mesh)
+                        cache_entry, _ = self._get_cache_entry(self.decode_work.cache, i)
                         ns = math.ceil(sequence.size / self.serve_cfg.prefix_chunk_size)
                         buffer_ids = BUFFER_STORE._get_unique_buffer_ids(ns)
                         visited_ids, store_ids, del_ids = self._insert_prefix(sequence, buffer_ids)
@@ -661,80 +670,84 @@ def decode_step(self):
                     result.done, self.decode_work.active_results[i] = True, None
 
     def prefill_step(self):
-        # 1. check on any finished prefill jobs
-        if self.prefill_work.pending_prefill is not None:
-            prefill_is_done, is_source = self.prefill_work.pending_prefill.done(), "prefill_coordinator" in self.roles
-            prefill_is_done = SyncServer.broadcast("prefill_done", self._it, prefill_is_done, is_source=is_source)
-            if prefill_is_done:
-                prefill_input, prefill_results = self.prefill_work.pending_prefill.result()
-                for i, request in enumerate(prefill_input):
-                    with use_mesh(self.prefill_mesh):
-                        kv_list = self._get_index(prefill_results, i)
-                    id, input = request.id, np.array(request.text)
-                    self.prefill_work.to_decode.append(PrefillResult(id, input, input[-1], kv_list, len(input) - 1))
-                self.prefill_work.pending_prefill = None
+        # 1. prefill requests to be prefilled (do this first to overlap with decode)
+        prefill_input: list[PrefillJob] = self.prefill_work.to_prefill[: self.serve_cfg.prefill_batch_size]
+        self.prefill_work.to_prefill = self.prefill_work.to_prefill[len(prefill_input) :]
+        if len(prefill_input) > 0:
+            prefill_texts = [job.request.text[job.match_len :] for job in prefill_input]
+            max_len = max([len(text) for text in prefill_texts])
+            inputs = [text + [self.pad_id] * (max_len - len(text)) for text in prefill_texts]
+            inputs = np.stack([np.array(input) for input in inputs], 0)
+            row_pad = self.serve_cfg.prefill_batch_size - inputs.shape[0]
+            col_pad = max(next_power_of_2(inputs.shape[-1]), 64) - inputs.shape[-1]
+            inputs = np.pad(inputs, ((0, row_pad), (0, col_pad)), mode="constant", constant_values=self.pad_id)
+
+            with set_mesh(self.prefill_mesh):
+                actual_cache_len = np.array(max(job.match_len for job in prefill_input), dtype=np.int32)
+                self.prefill_cache.iter = actual_cache_len  # TODO: make this an explicit part of the cache's public interface
+                kvs = [job.cache_entry() if job.cache_entry is not None else None for job in prefill_input]
+                batch_idxs = np.array([i for i, kv in enumerate(kvs) if kv is not None])
+                actual_lens = np.array([job.match_len for kv, job in zip(kvs, prefill_input) if kv is not None])
+                kvs = [kv for kv in kvs if kv is not None]
+
+                if len(kvs) > 0:
+                    # sort to minimize variants num
+                    length_sort = sorted(range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2])
+                    sorted_args = [[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)]
+                    insert_sequences = maybe_call(self.prefill_cache.insert_sequences, self.prefill_mesh)
+                    self.prefill_cache = insert_sequences(self.prefill_cache, *sorted_args)
+                cfg = dataclasses.replace(self.cfg, mesh=self.prefill_mesh)
+                forward_fn = maybe_call(self.forward_fn, self.prefill_mesh)
+                _, self.prefill_cache = forward_fn(inputs, self.prefill_weights, self.prefill_cache, cfg)
+
+            with set_mesh(self.prefill_mesh):
+                for i, job in enumerate(prefill_input):
+                    request = job.request
+                    cache_entry, _ = maybe_call(self._get_cache_entry, self.prefill_mesh)(self.prefill_cache, i)
+                    cache_entry = _ensure_all_args_on_mesh(cache_entry, self.decode_mesh)
+                    sequence = np.array(request.text)
+                    new_decode = PrefillResult(
+                        request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1
+                    )
+                    self.prefill_work.to_decode.append(new_decode)
 
         # 2. triage requests based on whether they need to go to prefill or there's a cache match, so decode directly
         while len(self.prefill_work.requests) > 0:
             request = self.prefill_work.requests.pop(0)
             sequence = np.array(request.text)
             (total_match, buffer_ids), visited_ids = self._retrieve_prefix(sequence)
+            assert total_match <= sequence.size
             BUFFER_STORE.mark_visited(visited_ids)
+            _axis = self.time_axis - 1 + 1  # batch missing (-1) layers concatenated (+1)
+            buffers = BUFFER_STORE.load(buffer_ids)
             if total_match == sequence.size:
-                with use_mesh(self.decode_mesh):
-                    time_axis = self.time_axis - 1 + 1  # batch missing (-1) layers concatenated (+1)
-                    cache_entry = partial(_concat, BUFFER_STORE.load(buffer_ids), time_axis)
+                cache_entry = partial(_concat, _ensure_all_args_on_mesh(buffers, mesh=self.decode_mesh), _axis)
                 new_decode = PrefillResult(request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1)
                 self.prefill_work.to_decode.append(new_decode)
                 print("Found a full match")
             else:
-                print(
-                    f"Need to prefill the request, only found a match for length {total_match / (len(request.text) - 1)}"
-                )
-                self.prefill_work.to_prefill.append(request)
-
-        if self.prefill_work.pending_prefill is not None:  # a current prefill is still running, skip scheduling another
-            return
-
-        # 3. prefill requests to be prefilled
-        prefill_input = self.prefill_work.to_prefill[: self.serve_cfg.prefill_batch_size]
-        self.prefill_work.to_prefill = self.prefill_work.to_prefill[len(prefill_input) :]
-        if len(prefill_input) > 0:
-            # disaggregated server via async on a subset of devices
-            def _prefill_job():
-                max_len = max([len(request.text) for request in prefill_input])
-                inputs = [[self.pad_id] * (max_len - len(request.text)) + request.text for request in prefill_input]
-                inputs = np.stack([np.array(input) for input in inputs], 0)
-                row_pad = self.serve_cfg.prefill_batch_size - inputs.shape[0]
-                col_pad = next_power_of_2(inputs.shape[-1]) - inputs.shape[-1]
-                inputs = np.pad(inputs, ((0, row_pad), (0, col_pad)), mode="constant", constant_values=self.pad_id)
-                cfg = dataclasses.replace(self.cfg, mesh=self.prefill_mesh)
-                with use_mesh(self.prefill_mesh):
-                    _, _, prefill_results = self.prefill_fn(
-                        inputs, self.prefill_weights, cfg, participate="prefill" in self.roles
-                    )
-                prefill_results = jax.block_until_ready(prefill_results)
-                return prefill_input, prefill_results
-
-            self.prefill_work.pending_prefill = self._background.submit(_prefill_job)
+                print(f"Need to prefill; only matched {total_match / (len(request.text) - 1):.2%} of the prompt")
+                print(f"That equals {len(buffer_ids)} buffers or {total_match=}")
+                if total_match > 0:
+                    cache_entry = partial(_concat, _ensure_all_args_on_mesh(buffers, mesh=self.prefill_mesh), _axis)
+                else:
+                    cache_entry = None
+                self.prefill_work.to_prefill.append(PrefillJob(request, cache_entry, total_match))
 
     def serving_step(self):
         # this event loop relies on determinism for issuing computation to multiple processes (multi-process JAX)
         # frequent barriers should keep it in sync
 
         # potentially profile when received the request to #########################################
+        is_server = "server" in self.roles
         should_start_profile = self.profile_start_time > 0 and not self.profiling
-        should_start_profile = SyncServer.broadcast(
-            "profile", self._it, should_start_profile, is_source="server" in self.roles
-        )
+        should_start_profile = SyncServer.broadcast("profile", self._it, should_start_profile, is_source=is_server)
         if should_start_profile:
             self.profile_start_time, self.profiling = time.perf_counter(), True
             jax.profiler.start_trace("/tmp/online")
             print("STARTING TRACE")
 
         should_stop_profile = self.profile_start_time > 0 and time.perf_counter() - self.profile_start_time > 5.0
-        should_stop_profile = SyncServer.broadcast(
-            "stop_profile", self._it, should_stop_profile, is_source="server" in self.roles
-        )
+        should_stop_profile = SyncServer.broadcast("stop_profile", self._it, should_stop_profile, is_source=is_server)
         if should_stop_profile:
             self.profile_start_time, self.profiling = -1, False
             print("STOPPING TRACE")
@@ -762,15 +775,15 @@ def serving_step(self):
             self.prefill_step()
         # main event loop work #####################################################################
 
-        # manage cache #############################################################################
-        # TODO: test and configure host offloading for the cache
-        if USE_PREFIX_CACHE and len(self.prefix_cache.children) > 100:  # clear the cache after 100 root children
-            self.new_prefix_cache()
-        # manage cache #############################################################################
-
         # offload buffers to keep a max of N #######################################################
-        max_buffers = 100
-        BUFFER_STORE.offload_buffers(max(0, BUFFER_STORE.livecount - max_buffers))
+        BUFFER_STORE.offload_buffers(max(0, BUFFER_STORE.livecount - self.serve_cfg.max_ondevice_buffers))
+        extra_buffer_count = max(len(BUFFER_STORE.usecount) - self.serve_cfg.max_buffers, 0)
+        if extra_buffer_count > 0:
+            refs_to_delete = sorted(BUFFER_STORE.usecount.keys())[:extra_buffer_count]
+            deleted_buffers = remove_prefix_nodes(self.prefix_cache, refs_to_delete)
+            BUFFER_STORE.delete(list(deleted_buffers))
+        if len(BUFFER_STORE._store) > self.serve_cfg.max_buffers:
+            raise ValueError(f"Buffer store holds {len(BUFFER_STORE._store)} buffers, more than max_buffers={self.serve_cfg.max_buffers}.")
         # offload buffers to keep a max of N #######################################################
 
     def add_request(self, request: UserRequestPrompt):
diff --git a/serving/serving_jax/cross_host.py b/serving/serving_jax/cross_host.py
index 65f3bc4..c9ee95e 100644
--- a/serving/serving_jax/cross_host.py
+++ b/serving/serving_jax/cross_host.py
@@ -48,13 +48,16 @@ def transfer_tree_A2B(xs: PyTree, meshA, meshB):
     xs, xs_struct = jax.tree.flatten(xs)
     combined_sharding = [NamedSharding(meshC, P("cross_mesh", *x.sharding.spec)) for x in xs]
     dest_sharding = [NamedSharding(meshB, x.sharding.spec) for x in xs]
-    dest_arrays = _make_zeros(tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in xs), tuple(dest_sharding))
-    all_arrays = [x_src._arrays + x_dest._arrays for x_src, x_dest in zip(_prepare_arrays(xs), dest_arrays)]
+    with jax.sharding.set_mesh(meshB):
+        dest_arrays = _make_zeros(tuple(jax.ShapeDtypeStruct(x.shape, x.dtype) for x in xs), tuple(dest_sharding))
+    with jax.sharding.set_mesh(meshA):
+        all_arrays = [x_src._arrays + x_dest._arrays for x_src, x_dest in zip(_prepare_arrays(xs), dest_arrays)]
     xs_combined = [
         jax.make_array_from_single_device_arrays((2,) + x.shape, sharding, arrays, dtype=x.dtype)
         for (x, arrays, sharding) in zip(xs, all_arrays, combined_sharding)
     ]
-    xs_repl = _combine(xs_combined)  # issue collectives under jit
+    with jax.sharding.set_mesh(meshC):
+        xs_repl = _combine(xs_combined)  # issue collectives under jit
     xs_new = [
         jax.make_array_from_single_device_arrays(
             x_src.shape, sharding, x_new._arrays[len(x_src._arrays) :], dtype=x_src.dtype

From d4f5482fae008e840a8f1053a6bc88edab13d810 Mon Sep 17 00:00:00 2001
From: Robert Dyro
Date: Fri, 15 Aug 2025 17:56:33 -0700
Subject: [PATCH 08/11] Import gpt_oss from main, clean up serving files

---
 gpt_oss/.gitignore                           |   14 +
 gpt_oss/README.md                            |   21 +
 gpt_oss/gpt_oss_jax/chkpt_utils.py           |  282 +++++
 gpt_oss/gpt_oss_jax/decode_ragged_dot.py     |  394 ++++++
 gpt_oss/gpt_oss_jax/model.py                 | 1156 ++++++++++++++++++
 gpt_oss/main.py                              |   97 ++
 gpt_oss/pyproject.toml                       |   36 +
 gpt_oss/scripts/convert_weights.py           |   70 ++
 gpt_oss/scripts/download_model.py            |   31 +
 gpt_oss/scripts/quantize_model.py            |   31 +
 gpt_oss/tests/test_model.py                  |  112 ++
 llama3/llama3_jax/model.py                   |   39 +-
 llama3/main.py                               |   18 +-
 serving/client_demo.py                       |   21 +-
 serving/main_serving_ds_r1.py                |   58 +-
 serving/serving_jax/__init__.py              |  788 +-----------
 serving/serving_jax/attention_cache_utils.py |   71 +-
 serving/serving_jax/cross_host.py            |    5 +-
 serving/serving_jax/http_server.py           |  115 ++
 serving/serving_jax/serving_loop.py          |  814 ++++++++++++
 20 files changed, 3302 insertions(+), 871 deletions(-)
 create mode 100644 gpt_oss/.gitignore
 create mode 100644 gpt_oss/README.md
 create mode 100644 gpt_oss/gpt_oss_jax/chkpt_utils.py
 create mode 100644 gpt_oss/gpt_oss_jax/decode_ragged_dot.py
 create mode 100644 gpt_oss/gpt_oss_jax/model.py
 create mode 100644 gpt_oss/main.py
 create mode 100644 gpt_oss/pyproject.toml
 create mode 100644 gpt_oss/scripts/convert_weights.py
 create mode 100644 gpt_oss/scripts/download_model.py
 create mode 100644 gpt_oss/scripts/quantize_model.py
 create mode 100644 gpt_oss/tests/test_model.py
 create mode 100644 serving/serving_jax/http_server.py
 create mode 100644 serving/serving_jax/serving_loop.py

diff --git a/gpt_oss/.gitignore b/gpt_oss/.gitignore
new file mode 100644
index 0000000..3afd13a
--- /dev/null
+++ b/gpt_oss/.gitignore
@@ -0,0 +1,14 @@
+poetry.lock
+scratch/**
+
+projects/charformer/data/
+projects/bio/data/
+
+# Python ignores
+__pycache__/
+*.pyc
+*.egg-info
+build/**
+
+.venv
+.vscode
\ No newline at end of file
diff --git a/gpt_oss/README.md b/gpt_oss/README.md
new file mode 100644
index 0000000..9d1c188
--- /dev/null
+++ b/gpt_oss/README.md
@@ -0,0 +1,21 @@
+# Minimal OpenAI GPT OSS inference
+
+**tl;dr: open-source OpenAI GPT OSS inference using JAX, minimal yet performant**
+
+This model is a work in progress, but it should already work well on both TPU and GPU.
+
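+For reference, loading a converted checkpoint looks roughly like the sketch below
+(this mirrors the calls in `gpt_oss_jax/chkpt_utils.py`; the checkpoint path is a
+placeholder, not a path this repo creates for you):
+
+```python
+import dataclasses
+from pathlib import Path
+
+import jax
+from jax.sharding import PartitionSpec as P
+
+from gpt_oss_jax import model as gpt_jax
+
+ckpt_path = Path("~/gpt_oss_jax_ckpt").expanduser()  # placeholder location
+cfg = gpt_jax.load_config(ckpt_path / "config.json")
+mesh = jax.make_mesh((1, jax.device_count(), 1), P("x", "y", "z"))
+cfg = dataclasses.replace(cfg, mesh=mesh)
+weights = gpt_jax.load_pytree(ckpt_path, gpt_jax.Weights.shardings(cfg))
+```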
+
+This is a pure JAX implementation of OpenAI's GPT OSS for inference, including a
+checkpoint converter for the Hugging Face weights. It works on TPU and should
+also work on GPU.
+
+The entire model is defined in [model.py](gpt_oss_jax/model.py) and invoked
+via [main.py](main.py).
+
+## Quickstart
+
+Run:
+```
+$ python3 main.py
+```
diff --git a/gpt_oss/gpt_oss_jax/chkpt_utils.py b/gpt_oss/gpt_oss_jax/chkpt_utils.py
new file mode 100644
index 0000000..f683da9
--- /dev/null
+++ b/gpt_oss/gpt_oss_jax/chkpt_utils.py
@@ -0,0 +1,282 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import re
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+import dataclasses
+from typing import Callable
+
+import jax
+from jax import numpy as jnp
+from jax.sharding import PartitionSpec as P
+import torch
+from tqdm import tqdm
+
+from gpt_oss_jax import model as gpt_jax
+
+
+def quantize_model(ckpt_path: Path, quant_ckpt_path: Path):
+    ckpt_path, quant_ckpt_path = Path(ckpt_path).expanduser(), Path(quant_ckpt_path).expanduser()
+    assert ckpt_path.is_dir()
+    cfg = gpt_jax.load_config(ckpt_path / "config.json")
+    mesh = jax.make_mesh((1, jax.device_count(), 1), P("x", "y", "z"))
+    cfg = dataclasses.replace(cfg, mesh=mesh, quant_moe=True, quant_attn=False)  # do not quantize attention
+
+    print("Loading weights...")
+    weights = gpt_jax.load_pytree(
+        ckpt_path, gpt_jax.Weights.shardings(dataclasses.replace(cfg, quant_moe=False, quant_attn=False))
+    )
+
+    print("Converting weights...")
+    quant_layers = [gpt_jax.Layer.quantize(layer, cfg) for layer in tqdm(weights.layers, total=len(weights.layers))]
+    quant_weights = dataclasses.replace(weights, layers=quant_layers)
+
+    print("Saving weights...")
+    if quant_ckpt_path.exists():
+        shutil.rmtree(quant_ckpt_path)
+    quant_ckpt_path.parent.mkdir(exist_ok=True)
+    gpt_jax.save_pytree(quant_weights, quant_ckpt_path)
+
+    additional_files = [
+        "config.json",
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "chat_template.json",
+        "chat_template.jinja",
+        "generation_config.json",
+    ]
+    for file in additional_files:
+        if (ckpt_path / file).exists():
+            shutil.copyfile(ckpt_path / file, quant_ckpt_path / file)
+
+
+# mxfp4 utilities, a reimplementation based on:
+# - https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+# - https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/mxfp4.py
+def e2m1_to_fp(x):
+    x = torch.as_tensor(x, dtype=torch.int8)
+    sign = 1 - 2 * ((x >> 3) & 0x01)
+    exp = 2.0 ** (((x >> 1) & 0x3) - 1)
+    is_subnormal = (x & 0b111) == 0b001
+    m = torch.where(is_subnormal, 1.0, (1 / 2 * (x & 0x1) + 1).to(torch.float32))
+    is_zero = (x & 0b111) == 0
+    return torch.where(is_zero, 0.0, m * exp * sign)
+
+
+def dequantize_mxfp4(blocks_2x_e2m1: torch.Tensor, scales_e8m0: torch.Tensor, dtype=torch.bfloat16):
+    scales_e8m0 = torch.as_tensor(scales_e8m0, dtype=torch.float32)
+    scales = (2.0 ** (scales_e8m0 - 127))[..., None, None]
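+    # e2m1 nibble decoding (see e2m1_to_fp above): the 3 value bits map codes
+    # 0b000..0b111 to {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}, negated when the
+    # sign bit (bit 3) is set; each e8m0 scale is a biased power of two, 2**(s - 127),
+    # applied to its block. Each packed byte holds two fp4 values, low nibble first: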
+    x = torch.stack([e2m1_to_fp(blocks_2x_e2m1), e2m1_to_fp(blocks_2x_e2m1 >> 4)], -1) * scales
+    return x.reshape((x.shape[:-3] + (-1,))).to(dtype)
+
+
+is_leaf = lambda x: isinstance(x, gpt_jax.ArrayInfo)
+j2t = lambda x: torch.from_dlpack(x)
+
+
+def t2j(x):
+    try:
+        prev_level, os.environ["TF_CPP_MIN_LOG_LEVEL"] = os.environ.get("TF_CPP_MIN_LOG_LEVEL", None), "9"
+        return jnp.from_dlpack(x.detach().contiguous())
+    finally:
+        if prev_level is not None:
+            os.environ["TF_CPP_MIN_LOG_LEVEL"] = prev_level
+
+
+def _index_to_str(x):
+    """Convert objects from jax.tree.flatten_with_path to dot separated strings."""
+    for field in ["name", "idx", "key"]:
+        if hasattr(x, field):
+            return str(getattr(x, field))
+    raise ValueError
+
+
+def convert_weight(key: str, value: torch.Tensor, cfg: gpt_jax.Config):
+    value = value.detach()
+    # HF checkpoint naming convention ------------------------------------------
+    # attention ################################################################
+    if re.search(r"q_proj\.weight", key) is not None:
+        assert value.shape == (cfg.q_heads * cfg.head_dim, cfg.embed)
+        return t2j(value.T.reshape((cfg.embed, cfg.q_heads, cfg.head_dim)))
+    elif re.search(r"[kv]_proj\.weight", key) is not None:
+        assert value.shape == (cfg.kv_heads * cfg.head_dim, cfg.embed)
+        return t2j(value.T.reshape((cfg.embed, cfg.kv_heads, cfg.head_dim)))
+    elif re.search(r"o_proj\.weight", key) is not None:
+        assert value.shape == (cfg.embed, cfg.q_heads * cfg.head_dim)
+        return t2j(value.T.reshape((cfg.q_heads, cfg.head_dim, cfg.embed)))
+    elif re.search(r"(k|v)_proj\.bias", key) is not None:
+        assert value.shape == (cfg.kv_heads * cfg.head_dim,)
+        return t2j(value.reshape((cfg.kv_heads, cfg.head_dim)))
+    elif re.search(r"q_proj\.bias", key) is not None:
+        assert value.shape == (cfg.q_heads * cfg.head_dim,)
+        return t2j(value.reshape((cfg.q_heads, cfg.head_dim)))
+    elif re.search(r"o_proj\.bias", key) is not None:
+        assert value.shape == (cfg.embed,)
+        return t2j(value)
+    elif re.search(r"sinks", key) is not None:
+        assert value.shape == (cfg.head_dim,)
+        return t2j(value)
+    # MoE ######################################################################
+    elif re.search(r"router\.weight", key) is not None:
+        assert value.shape == (cfg.moe_num_experts, cfg.embed)
+        return t2j(value.T)
+    elif re.search(r"router\.bias", key) is not None:
+        assert value.shape == (cfg.moe_num_experts,)
+        return t2j(value)
+    elif re.search(r"experts\.down_proj_bias", key) is not None:
+        assert value.shape == (cfg.moe_num_experts, cfg.moe_ffw_size)
+        return t2j(value)
+    elif re.search(r"experts\.gate_up_proj_bias", key) is not None:
+        assert value.shape == (cfg.moe_num_experts, 2 * cfg.moe_ffw_size)
+        return t2j(value)
+    elif re.search(r"experts\.down_proj$", key) is not None:
+        assert value.shape == (cfg.moe_num_experts, cfg.moe_ffw_size, cfg.embed)
+        return t2j(value)
+    elif re.search(r"experts\.gate_up_proj$", key) is not None:
+        assert value.shape == (cfg.moe_num_experts, cfg.embed, 2 * cfg.moe_ffw_size)
+        return t2j(value)
+    # misc #####################################################################
+    elif re.search(r"embed_tokens", key) is not None:
+        assert value.shape == (cfg.vocab_size, cfg.embed)
+        return t2j(value)
+    elif re.search(r"lm_head", key) is not None:
+        assert value.shape == (cfg.vocab_size, cfg.embed)
+        return t2j(value.T)
+    elif re.search(r"layernorm", key) is not None:
+        assert value.shape == (cfg.embed,)
+        return t2j(value)
+    elif re.search(r"norm", key) is not None:
+        assert value.shape == (cfg.embed,)
+        return t2j(value)
+    else:
+        raise ValueError(f"Unknown weight {key = }")
+
+
+_HF_KEY_MAPPING = {
+    r"model\.embed_tokens\.weight": "embedding",
+    r"model\.norm\.weight": "gamma_final",
+    r"lm_head\.weight": "lm_head",
+    # attention projection weights
+    r"model\.layers\.([0-9]+)\.self_attn\.q_proj\.weight": r"layers.\1.attn.q",
+    r"model\.layers\.([0-9]+)\.self_attn\.k_proj\.weight": r"layers.\1.attn.k",
+    r"model\.layers\.([0-9]+)\.self_attn\.v_proj\.weight": r"layers.\1.attn.v",
+    r"model\.layers\.([0-9]+)\.self_attn\.o_proj\.weight": r"layers.\1.attn.o",
+    r"model\.layers\.([0-9]+)\.self_attn\.q_proj\.bias": r"layers.\1.attn.q_bias",
+    r"model\.layers\.([0-9]+)\.self_attn\.k_proj\.bias": r"layers.\1.attn.k_bias",
+    r"model\.layers\.([0-9]+)\.self_attn\.v_proj\.bias": r"layers.\1.attn.v_bias",
+    r"model\.layers\.([0-9]+)\.self_attn\.o_proj\.bias": r"layers.\1.attn.o_bias",
+    r"model\.layers\.([0-9]+)\.self_attn\.sinks": r"layers.\1.attn.sinks",
+    # layer norms (pre/post attention)
+    r"model\.layers\.([0-9]+)\.input_layernorm\.weight": r"layers.\1.attn_pre_gamma",
+    r"model\.layers\.([0-9]+)\.post_attention_layernorm\.weight": r"layers.\1.attn_post_gamma",
+    # moe router
+    r"model\.layers\.([0-9]+)\.mlp\.router\.weight": r"layers.\1.ffw.w_router",
+    r"model\.layers\.([0-9]+)\.mlp\.router\.bias": r"layers.\1.ffw.w_router_bias",
+    # moe experts
+    r"model\.layers\.([0-9]+)\.mlp\.experts\.gate_up_proj$": r"layers.\1.ffw.we_gate_up",
+    r"model\.layers\.([0-9]+)\.mlp\.experts\.gate_up_proj_bias": r"layers.\1.ffw.we_gate_up_bias",
+    r"model\.layers\.([0-9]+)\.mlp\.experts\.down_proj$": r"layers.\1.ffw.we_down",
+    r"model\.layers\.([0-9]+)\.mlp\.experts\.down_proj_bias": r"layers.\1.ffw.we_down_bias",
+}
+
+
+def _torch_key_to_jax_key(source_key, custom_key_map: dict[str, str] | None = None):
+    key_maps = dict(_HF_KEY_MAPPING, **(dict() if custom_key_map is None else custom_key_map))
+    subs = [re.sub(pat, repl, source_key) for pat, repl in key_maps.items() if re.match(pat, source_key)]
+    if len(subs) > 1:
+        raise ValueError(f"More than 1 key matched: {subs}")
+    else:
+        return None if len(subs) == 0 else subs[0]
+
+
+def _map_weight(source_key, value: torch.Tensor, custom_transform_map: dict[str, Callable] | None = None):
+    key_maps = dict(dict(), **(dict() if custom_transform_map is None else custom_transform_map))
+    fns = {pat: fn for pat, fn in key_maps.items() if re.match(pat, source_key)}
+    if len(fns) > 1:
+        raise ValueError(f"More than 1 key matched: {fns}")
+    else:
+        return value if len(fns) == 0 else list(fns.values())[0](value)
+
+
+def convert_model_or_layer(
+    layer: gpt_jax.Weights | gpt_jax.Layer,
+    ref_layer: torch.nn.Module,
+    cfg: gpt_jax.Config,
+    device: jax.Device | None = None,
+    sequential: bool = True,
+    custom_key_map: dict[str, str] | None = None,
+    custom_transform_map: dict[str, Callable] | None = None,
+    allow_unconverted_parameters: bool = False,
+    prefix: str | None = None,
+):
+    device = device if device is not None else jax.devices("cpu")[0]
+    torch_params = dict(ref_layer.named_parameters() if hasattr(ref_layer, "named_parameters") else ref_layer)
+    torch_params = {k: v for (k, v) in torch_params.items() if prefix is None or k.startswith(prefix)}
+    mxfp4_keys = [key for key in torch_params if re.match(r".*(gate_up|down)_proj_(scales|blocks)$", key)]
+    if len(mxfp4_keys) > 0:
+        print("Converting mxfp4 weights to bfloat16 for conversion.")
+        for key in tqdm(mxfp4_keys):
+            if re.match(r"^.*_scales$", key):
+                continue
+            root = re.match(r"^(.*)_blocks$", key)[1]
+            weight = dequantize_mxfp4(torch_params[root + "_blocks"], torch_params[root + "_scales"]).contiguous()
+            del torch_params[root + "_blocks"], torch_params[root + "_scales"]
+            torch_params[root] = weight.transpose(1, 2)
+
+    layer_params = {
+        ".".join(map(_index_to_str, k)): v for (k, v) in jax.tree.flatten_with_path(layer, is_leaf=is_leaf)[0]
+    }
+    new_params = {k: None for k in layer_params.keys()}
+
+    def convert_weight_thread(tkey, tweight):
+        with jax.default_device(device):
+            jweight = convert_weight(tkey, _map_weight(tkey, tweight, custom_transform_map=custom_transform_map), cfg)
+        jkey = _torch_key_to_jax_key(tkey, custom_key_map=custom_key_map)
+        if jkey is None:
+            raise ValueError(f"Could not find parameter mapping for torch parameter: `{tkey}`.")
+        if jkey not in new_params:
+            raise ValueError(f"The JAX model is not expecting `{jkey}`! Expected keys are {list(new_params.keys())}")
+        if new_params[jkey] is not None:
+            raise ValueError(f"Parameter `{jkey}` already set!")
+        new_params[jkey] = jweight
+
+    if sequential:
+        for tkey, tweight in torch_params.items():
+            convert_weight_thread(tkey, tweight)
+    else:
+        futures, executor = [], ThreadPoolExecutor(max_workers=16)
+        for tkey, tweight in torch_params.items():
+            futures.append(executor.submit(convert_weight_thread, tkey, tweight))
+        for fut in tqdm(futures, desc="Converting weights"):
+            fut.result()
+
+    if not allow_unconverted_parameters:
+        assert all(v is not None for v in new_params.values()), str({k: v for k, v in new_params.items() if v is None})
+        for (key, param), new_param in zip(layer_params.items(), new_params.values()):
+            if param.shape != new_param.shape:
+                raise ValueError(f"Shape of {key=} does not match, expected = {param.shape}, got {new_param.shape}")
+
+    if isinstance(layer, gpt_jax.Weights):
+        return jax.tree.unflatten(jax.tree.structure(layer, is_leaf=is_leaf), new_params.values())
+    else:
+        return jax.tree.unflatten(
+            jax.tree.structure(layer, is_leaf=is_leaf),
+            [
+                new_param if new_param is not None else param
+                for (new_param, param) in zip(new_params.values(), layer_params.values())
+            ],
+        )
diff --git a/gpt_oss/gpt_oss_jax/decode_ragged_dot.py b/gpt_oss/gpt_oss_jax/decode_ragged_dot.py
new file mode 100644
index 0000000..df5ff6e
--- /dev/null
+++ b/gpt_oss/gpt_oss_jax/decode_ragged_dot.py
@@ -0,0 +1,394 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
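+
+# A ragged dot computes y[i, :] = lhs[i, :] @ rhs[group(i), :, :], where rows of
+# lhs are assigned to groups by contiguous runs: e.g. group_sizes = [3, 1, 2]
+# means rows 0-2 multiply rhs[0], row 3 multiplies rhs[1], and rows 4-5 rhs[2].
+# An unoptimized reference of the same contract (assuming sum(group_sizes) == n):
+#
+#   def ragged_dot_reference(lhs, rhs, group_sizes):  # lhs: [n, k], rhs: [g, k, m]
+#       out, start = [], 0
+#       for g_idx, size in enumerate(group_sizes):
+#           out.append(lhs[start : start + size] @ rhs[g_idx])
+#           start += size
+#       return jnp.concatenate(out, axis=0)  # [n, m]
+#
+# which matches jax.lax.ragged_dot(lhs, rhs, group_sizes), used as the reference below.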
+ +from pathlib import Path +import random as pyrandom +from functools import partial + +import jax +from jax import numpy as jnp +from jax import random +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from tqdm import tqdm + + +def decode_ragged_dot_kernel( + # scalar prefetch + lhs_idx_map_ref, # [g // block_g, n // block_n] + rhs_idx_map_ref, # [g // block_g, n // block_n] + # inputs + x_ref, # [block_n, k] + A_ref, # [block_g, k, m] + group_sizes_ref, # [g] + # outputs + y_ref, # [block_n, m] # hbm scratch output, to-be-reduced over 0-axis + # (scratch) persistent lhs idx + lhs_idx_ref, # [1] + group_id_ref, # [1] + group_size_ref, # [1] + # hyperparameters + block_n: int, + block_g: int, + block_compute: int, + n: int, + g: int, +): + del rhs_idx_map_ref + pid_g, pid_i = pl.program_id(1), pl.program_id(2) # (out column tiles, matrix groups, lhs row tiles) + (_, k), _, m = x_ref.shape, A_ref.shape[0], A_ref.shape[-1] + block_n_id = lhs_idx_map_ref[pid_g, pid_i] + + lhs_idx = jnp.where((pid_g == 0) & (pid_i == 0), 0, lhs_idx_ref[0]) + group_id = jnp.where((pid_g == 0) & (pid_i == 0), 0, group_id_ref[0]) + group_size = jnp.where(pid_i == 0, group_sizes_ref[pid_g * block_g], group_size_ref[0]) + + idx = jnp.maximum(pid_g * lhs_idx_map_ref.shape[-1] + pid_i - 1, 0) + prev_block_n_id = lhs_idx_map_ref[idx // lhs_idx_map_ref.shape[-1], idx % lhs_idx_map_ref.shape[-1]] + is_block_n_new = ((pid_g == 0) & (pid_i == 0)) | (prev_block_n_id != block_n_id) + + @pl.when(is_block_n_new) + def _(): + y_ref[...] = jnp.zeros_like(y_ref) + + # for i in range(lhs_idx // block_compute, n // block_compute): # blockwise over rows in lhs + def outer_body_fn(i, carry): + lhs_idx, group_id, group_size = carry + local_i = i - block_n_id * (block_n // block_compute) + y = y_ref[pl.ds(local_i * block_compute, block_compute), :].astype(jnp.float32) + x = x_ref[pl.ds(local_i * block_compute, block_compute), :] + + # iterate until lhs rows are exhausted or we use up all rhs groups + def cond_fn(val): + y, lhs_idx, group_id, group_size = val + del y, group_size + local_group_id = group_id - pid_g * block_g + return (lhs_idx < (i + 1) * block_compute) & (local_group_id < A_ref.shape[0]) + + def body_fn(val): + y, lhs_idx, group_id, group_size = val + + # check how many valid elements we computed and + els2compute = jnp.maximum(jnp.minimum(group_size, ((i + 1) * block_compute - lhs_idx).astype(jnp.int32)), 0) + group_exhausted = els2compute >= group_size + local_group_id = group_id - pid_g * block_g + + def _compute(): + # compute the actual product with the group_id group + A = A_ref[local_group_id, :, :] + xA = jax.lax.dot_general(x, A, (((1,), (0,)), ((), ())), preferred_element_type=jnp.float32) + xA = xA.astype(y.dtype) + # write to y accumulator masking already computed values + iota = jax.lax.broadcasted_iota(jnp.int32, (block_compute, m), dimension=0) + i * block_compute + mask = (iota >= lhs_idx) & (iota < (lhs_idx + els2compute)) + return jnp.where(mask, xA, y) + + new_y = jax.lax.cond(els2compute > 0, _compute, lambda: y) + + new_group_id = jnp.where(group_exhausted, group_id + 1, group_id) + next_group_size = group_sizes_ref[jnp.clip(pid_g * block_g + local_group_id + 1, max=g - 1)] + new_group_size = jnp.where(group_exhausted, next_group_size, group_size - els2compute) + new_lhs_idx = lhs_idx + els2compute + + return new_y, new_lhs_idx, new_group_id, new_group_size + + y, new_lhs_idx, new_group_id, new_group_size = jax.lax.while_loop( + cond_fn, body_fn, (y, 
lhs_idx, group_id, group_size) + ) + y_ref[pl.ds(local_i * block_compute, block_compute), :] = y.astype(y_ref.dtype) + return new_lhs_idx, new_group_id, new_group_size + + start_idx = jnp.maximum(lhs_idx, block_n_id * block_n) // block_compute + end_idx = jnp.minimum(n, (block_n_id + 1) * block_n) // block_compute + new_lhs_idx, new_group_id, new_group_size = jax.lax.fori_loop( + start_idx, end_idx, outer_body_fn, (lhs_idx, group_id, group_size) + ) + lhs_idx_ref[0], group_id_ref[0], group_size_ref[0] = new_lhs_idx, new_group_id, new_group_size + + +################################################################################ + + +@partial(jax.jit, static_argnames=("block_n", "block_g", "block_compute", "block_out", "interpret")) +def decode_ragged_dot( + lhs: jax.Array, # [n, k] + rhs: jax.Array, # [g, k, m] + group_sizes: jax.Array, # g[] + block_n: int = int(1e20), # by default replicate activations fully + block_g: int = 2, + block_compute: int = 8, + block_out: int = int(1e20), # by default write full output columns at the same time + interpret: bool = False, +) -> jax.Array: + """Computes y = x @ A, x.shape=[n, k], A.shape=[g, k, m]; rows in x are assigned to groups via `group_sizes`. + + To use a quantized version pass quantized arguments and either pre-scale lhs before or post-scale the result. + + The implementation attempts to maximize HBM BW of rhs by loading batches along g axis. It works most efficiently + when lhs fits into VMEM entirely. Alternatively provide `block_n` splits to split lhs. + THIS REQUIRES [g, n, m] EXTRA HBM SCRATCH. + + Args: + lhs: The input array of shape (n, k) where groups of rows are raggedly assigned to g axis via `group_sizes`. + rhs: The stack of matrices (k, m) [g, k, m]. For example g axis can represent experts. + group_sizes: An array of shape (g,) containing the sizes of each group in the ragged array `A`. + block_n: Splitting rows in x if activations do not fit in VMEM memory - do not use unless necessary. + block_g: The batch of group entries in A to preload at the same time. + block_compute: The compute window moving over dimension n. + block_out: The tiling of the output columns (to manage vmem usage). + interpret: Enable the pallas interpret mode. + + Returns: + The result of the ragged dot product, an array of shape (g, n). 
+ """ + + block_n = min(block_n, lhs.shape[0]) + block_compute = min(block_compute, block_n) + block_out = 128 * pl.cdiv(max(128, min(block_out, rhs.shape[-1])), 128) # min 128, multiple of 128 + assert rhs.ndim == 3 and lhs.ndim == 2, "lhs must have 2 dims, rhs 3 dims" + assert rhs.shape[0] % block_g == 0, f"{block_g=} must divide {rhs.shape[0]=} (# of groups) must divide" + assert block_n % block_compute == 0, f"{block_n = } {block_compute = } {lhs.shape = }" + assert rhs.shape[:1] == group_sizes.shape + (n, k), (g, _, m) = lhs.shape, rhs.shape + + grid = (pl.cdiv(rhs.shape[-1], block_out), g // block_g, n // block_n) + + # compute lhs prefetch map, only increment lhs idx if work is exhausted to avoid revisiting rows in lhs/output + # [[ 0 0 1 1 ] + # [ 1 1 1 1 ] + # [ 1 2 2 2 ] + # [ 3 4 4 4 ]] + work_total = jnp.pad(jnp.cumsum(jnp.sum(group_sizes.reshape((-1, block_g)), -1), axis=-1), (1, 0)) + min_lhs_j = work_total[:-1] // block_n + lhs_idx_map = min_lhs_j[:, None] + jnp.arange(grid[-1])[None, :] + max_lhs_j = jnp.concatenate([min_lhs_j[1:], jnp.array([grid[-1] - 1])]) + lhs_idx_map = jnp.clip(lhs_idx_map, max=jnp.minimum(max_lhs_j[:, None], grid[-1] - 1)) + + # compute rhs prefetch map + # [ 1 1 1 3 4 5 5 5] if 8 groups, but only {1, 3, 4, 5} active + rhs_work_mask = jnp.sum(group_sizes.reshape((-1, block_g)), -1) > 0 + unique_rhs_groups = jnp.sort(jnp.arange(rhs_work_mask.shape[-1]) * rhs_work_mask, descending=True) + flipped_rhs_groups_mapping = jnp.maximum(jnp.cumsum(jnp.flip(rhs_work_mask, axis=-1)) - 1, 0) + rhs_idx_map = jnp.flip(unique_rhs_groups[flipped_rhs_groups_mapping], axis=-1) + + def lhs_prefetch(out_i, i, j, lhs_idx_map_ref, rhs_idx_map_ref): + # as opposed to: `return j, 0` + del rhs_idx_map_ref + return lhs_idx_map_ref[i, j], 0 + + def rhs_prefetch(out_i, i, j, lhs_idx_map_ref, rhs_idx_map_ref): + # as opposed to: `return i, 0, 0` + del j, lhs_idx_map_ref + return rhs_idx_map_ref[i], 0, out_i + + def out_prefetch(out_i, i, j, lhs_idx_map_ref, rhs_idx_map_ref): + # as opposed to: `return j, out_i` + del rhs_idx_map_ref + return lhs_idx_map_ref[i, j], out_i + + in_specs = [ + pl.BlockSpec((block_n, k), lhs_prefetch), + pl.BlockSpec((block_g, rhs.shape[-2], block_out), rhs_prefetch), + pl.BlockSpec((group_sizes.size,), lambda i, j, *_: (0,), memory_space=pltpu.SMEM), + ] + out_specs = pl.BlockSpec((block_n, block_out), out_prefetch) + + scratch_shapes = [ + pltpu.SMEM((1,), dtype=jnp.int32), + pltpu.SMEM((1,), dtype=jnp.int32), + pltpu.SMEM((1,), dtype=jnp.int32), + ] + + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=2, + grid=grid, + in_specs=in_specs, + out_specs=out_specs, + scratch_shapes=scratch_shapes, + ) + out_shape = jax.ShapeDtypeStruct((n, m), dtype=lhs.dtype) + return pl.pallas_call( + partial(decode_ragged_dot_kernel, block_n=block_n, block_g=block_g, block_compute=block_compute, n=n, g=g), + out_shape=out_shape, + grid_spec=grid_spec, + interpret=interpret, + )(lhs_idx_map, rhs_idx_map, lhs, rhs, group_sizes.astype(jnp.int32)) + + +@partial(jax.jit, static_argnames=("block_n", "block_g", "block_compute", "block_out")) +def decode_ragged_dot_ref( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + block_n: int = 64, + block_g: int = 16, + block_out: int = 2**30, + block_compute: int = 8, +) -> jax.Array: + return jax.lax.ragged_dot(lhs, rhs, group_sizes) + + +def test_tune(): + from jax.experimental.layout import Format, Layout + import tune_jax + + tune_jax.CONFIG.allow_fallback_timing = False + 
tune_jax.logger.setLevel("DEBUG") + seed = 25 + g, n, k, m = 32, 1024, 2880, 1440 + + keys = iter(random.split(random.key(seed), 1024)) + x = random.normal(next(keys), (n, k), dtype=jnp.bfloat16) + A = random.normal(next(keys), (g, k, m), dtype=jnp.bfloat16) + A = A / jnp.linalg.norm(A, axis=-1)[..., None] + A = jnp.round(A * 127).astype(jnp.int8) + A = jax.device_put(A, Format(Layout((0, 1, 2), tiling=((8, 128), (4, 1))), A.sharding)) + + group_sizes = jnp.exp(1e-1 * random.uniform(next(keys), g)) + group_sizes = jnp.round(n * (group_sizes / jnp.sum(group_sizes))).astype(jnp.int32) + while jnp.sum(group_sizes) > n: + idx = jnp.argmax(group_sizes) + group_sizes = group_sizes.at[idx].set(group_sizes[idx] - 1) + while jnp.sum(group_sizes) < n: + idx = jnp.argmax(group_sizes) + group_sizes = group_sizes.at[idx].set(group_sizes[idx] + 1) + + print(jnp.sum(group_sizes)) + print(group_sizes) + assert jnp.sum(group_sizes) <= n + + # place the inputs with optimal shardings so that no copies for data-reformatting are included in tuning + auto_layouts = jax.tree.map(lambda x: Format(Layout.AUTO), (x, A, group_sizes)) + shapes = jax.tree.map(jax.typeof, (x, A, group_sizes)) + opt_shrd = jax.jit(decode_ragged_dot, in_shardings=auto_layouts).lower(*shapes).compile().input_formats[0] + x, A, group_sizes = jax.device_put((x, A, group_sizes), opt_shrd) + + hyperparams = dict( + block_n=[8, 16, 1e20], + block_compute=[4, 8, 16, 32], + block_g=[1, 2, 4, 8], + block_out=[128, 256, 512, 1024, 2048, 4096], + ) + + fn = tune_jax.tune(decode_ragged_dot, hyperparams=hyperparams) + fn(x, A, group_sizes) + print(tune_jax.tabulate(fn)) + + +def test_profile_speed(interpret): + seed = 25 + # n, k, g, m = 32, 128, 64, 256 + # n, k, g, m = 64, 128, 64, 7168 + n, k, g, m = 64, 7168, 64, 128 + + # k, m = m, k + + keys = iter(random.split(random.key(seed), 1024)) + x = random.normal(next(keys), (n, k), dtype=jnp.bfloat16) + A = random.normal(next(keys), (g, k, m), dtype=jnp.bfloat16) + A = A / jnp.linalg.norm(A, axis=-1)[..., None] + A = jnp.round(A * 127).astype(jnp.int8) + + block_g, block_compute, block_n = g // 8, 8, n // 4 + + group_sizes = jnp.exp(1e1 * random.uniform(next(keys), g)) + group_sizes = jnp.round(n * (group_sizes / jnp.sum(group_sizes))).astype(jnp.int32) + + # group_sizes = jnp.zeros(g, dtype=jnp.int32) + # group_sizes = group_sizes.at[7].set(n) + print(group_sizes) + print(group_sizes.reshape((-1, block_g))) + print(jnp.sum(group_sizes)) + assert jnp.sum(group_sizes) <= n + + opts = dict(block_n=block_n, block_g=block_g, block_compute=block_compute, interpret=interpret) + for _ in range(1): + ret = decode_ragged_dot(x, A, group_sizes, **opts).block_until_ready() + ret_ref = decode_ragged_dot_ref(x, A, group_sizes).block_until_ready() + print(f"error = {float(jnp.linalg.norm(ret - ret_ref) / (jnp.linalg.norm(ret_ref) + 1e-5)):.4e}") + rowwise_error = jnp.linalg.norm((ret - ret_ref).astype(jnp.float32), axis=-1) / ( + jnp.linalg.norm(ret_ref.astype(jnp.float32), axis=-1) + 1e-7 + ) + print(f"mean row error = {jnp.mean(rowwise_error):.4e}") + print(f"row-wise error = {rowwise_error}") + print(1 * (jnp.arange(group_sizes.size) < jnp.sum(group_sizes))) + + opts = dict(block_n=block_n, block_g=block_g, block_compute=block_compute) + with jax.profiler.trace(str(Path("~/profiles/decode_ragged2").expanduser())): + for _ in range(3): + ret = decode_ragged_dot(x, A, group_sizes, **opts).block_until_ready() + s = jnp.linalg.norm(ret).block_until_ready() # dummy computation as a profile barrier + + for _ in 
range(3):
+            ret = decode_ragged_dot_ref(x, A, group_sizes, **opts).block_until_ready()
+            s = jnp.linalg.norm(ret).block_until_ready()  # dummy computation as a profile barrier
+
+
+########################################################################################################################
+
+
+def _numeric_test_case(seed, interpret, n, k, g, m, block_g, block_n, block_compute):
+    keys = iter(random.split(random.key(seed), 1024))
+    x = random.normal(next(keys), (n, k), dtype=jnp.bfloat16)
+    A = random.normal(next(keys), (g, k, m), dtype=jnp.bfloat16)
+    # A = A / jnp.linalg.norm(A, axis=-1)[..., None]
+    # A = jnp.round(A * 127).astype(jnp.int8)
+
+    group_sizes = jnp.exp(1e-2 * random.uniform(next(keys), g))
+    group_sizes = jnp.round(n * (group_sizes / jnp.sum(group_sizes))).astype(jnp.int32)
+    while jnp.sum(group_sizes) > n:
+        idx = jnp.argmax(group_sizes)
+        group_sizes = group_sizes.at[idx].set(group_sizes[idx] - 1)
+    assert jnp.sum(group_sizes) <= n
+
+    opts = dict(block_n=block_n, block_g=block_g, block_compute=block_compute)
+    try:
+        ret = decode_ragged_dot(x, A, group_sizes, **opts, interpret=interpret).block_until_ready()
+    except jax.errors.JaxRuntimeError:
+        return float("nan")
+    ret_ref = decode_ragged_dot_ref(x, A, group_sizes).block_until_ready()
+    error = float(jnp.linalg.norm(ret - ret_ref) / (jnp.linalg.norm(ret_ref) + 1e-5))
+    return error
+
+
+def test_numerics():
+    tests = [
+        (seed, n, k, g, m, g // g_splits, n // n_splits, block_compute)
+        for seed in [0, 1, 2]
+        for n in [128, 64, 32]
+        for k in [128, 7168]
+        for g in [32, 64, 256]
+        for m in [7168, 128]
+        for block_compute in [8, 16]
+        for g_splits in [1, 2, 4, 8]
+        for n_splits in [1, 2, 4, 8]
+        if n // n_splits >= block_compute and g // g_splits >= 8 and not (m == 7168 and k == 7168)
+    ]
+    pyrandom.shuffle(tests)
+
+    max_error = 0
+    it = 0
+    for seed, n, k, g, m, block_g, block_n, block_compute in tqdm(tests):
+        error = _numeric_test_case(seed, False, n, k, g, m, block_g, block_n, block_compute)
+        max_error = max(error, max_error)
+        if max_error > 1e-4:
+            raise ValueError(f"failing {(seed, n, k, g, m, block_g, block_n, block_compute)=} with {error=:.4}")
+        if (it + 1) % 100 == 0:
+            tqdm.write(f"{max_error = :.4e}")
+        it += 1
+
+
+if __name__ == "__main__":
+    # test_numerics()
+    test_tune()
diff --git a/gpt_oss/gpt_oss_jax/model.py b/gpt_oss/gpt_oss_jax/model.py
new file mode 100644
index 0000000..650274e
--- /dev/null
+++ b/gpt_oss/gpt_oss_jax/model.py
@@ -0,0 +1,1156 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
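+# NOTE (a summary of what follows, not original code): this module sketches a minimal GPT-OSS
+# model in JAX: a quantization-aware MoE with expert/tensor parallelism, sliding-window
+# attention with sink logits, and YaRN RoPE position embeddings.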
+
+"""Minimal model definition."""
+
+import dataclasses
+import os
+import json
+from pathlib import Path
+import math
+from functools import partial, lru_cache
+from typing import Callable, Any
+from inspect import signature
+from collections import OrderedDict as odict
+
+import jax
+import jax.numpy as jnp
+from jax import random
+from jax import tree_util
+from jax.experimental.layout import Format, Layout
+from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash
+from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
+
+# from jax.experimental.shard_map import shard_map
+from jax.sharding import PartitionSpec as P
+from jax.experimental.array_serialization import pytree_serialization as ser
+
+try:
+    from jax.experimental.shard import auto_axes as _auto_axes, reshard
+except ModuleNotFoundError:
+    from jax.sharding import auto_axes as _auto_axes, reshard
+
+from .decode_ragged_dot import decode_ragged_dot
+
+PAD_ID = 199999
+
+AxisName = str | tuple[str, ...] | None
+Axes = tuple[AxisName, ...]
+AutoTokenizer = Any
+
+# Expected physical mesh axis names:
+# x - batch
+# y - 1st of 2D tensor sharding
+# z - 2nd of 2D tensor sharding
+BATCH_AXIS_NAME = "x"
+EXPERT_AXIS_NAME = "z"
+TENSOR_ONLY_AXIS_NAME = "y"
+ATTN_HEADS_AXIS_NAME = "y"
+TENSOR_AXIS_NAME = ("y", "z")
+
+
+@dataclasses.dataclass
+class ShardingRules:
+    """Mapping from logical data axes to physical mesh axes.
+
+    To manage the different shardings in the model, we define the "logical"
+    dimensions of various arrays (each dimension for each layer's weights,
+    etc.). Each of these logical axes may then be sharded over a physical mesh
+    axis, i.e. over multiple devices. For example, any values with a batch
+    dimension should always be sharded over the batch axis of the mesh.
+
+    Defining the shardings this way allows us to easily try out new sharding
+    strategies by just changing this mapping. The rest of the code handles
+    taking this mapping and eventually turning it into the correct JAX shardings
+    and sharding constraints.
+    """
+
+    batch: AxisName = BATCH_AXIS_NAME
+    sequence: AxisName = None
+    act_embed: AxisName = None
+    act_heads: AxisName = None
+    head_dim: AxisName = None
+    # attention
+    qkv_embed: AxisName = None
+    q_heads: AxisName = ATTN_HEADS_AXIS_NAME
+    kv_heads: AxisName = ATTN_HEADS_AXIS_NAME
+    o_heads: AxisName = ATTN_HEADS_AXIS_NAME
+    o_embed: AxisName = None
+    # MoE layer
+    moe_e_experts: AxisName = EXPERT_AXIS_NAME
+    moe_e_up_embed: AxisName = None
+    moe_e_up_ffw: AxisName = TENSOR_ONLY_AXIS_NAME
+    moe_e_down_ffw: AxisName = TENSOR_ONLY_AXIS_NAME
+    moe_e_down_embed: AxisName = None
+    moe_e_tp: AxisName = TENSOR_ONLY_AXIS_NAME  # moe forward function tensor parallelism
+    moe_e_ep: AxisName = EXPERT_AXIS_NAME  # moe forward function expert parallelism
+    # vocab
+    vocab_in: AxisName = None
+    vocab_out: AxisName = TENSOR_AXIS_NAME
+
+
+def auto_axes(x, out_sharding):  # TODO(rdyro): remove once in JAX >= 0.7.0
+    argname = "out_sharding" if "out_sharding" in signature(_auto_axes).parameters else "out_shardings"
+    return _auto_axes(x, **{argname: out_sharding})
+
+
+def logical_to_physical(logical: Axes, rules: ShardingRules) -> jax.sharding.PartitionSpec:
+    """Returns how to physically shard a given sequence of logical array dimensions (i.e. 
the logical shape of an array).""" + spec = [getattr(rules, axis) if axis is not None else None for axis in logical] + # `spec` may contain tuples, flatten to check that `spec` maps each physical mesh axis to at most one logical array + # axis. + flat_axes = jax.tree.leaves(spec) + if len(set(flat_axes)) != len(flat_axes): + raise ValueError(f"Colliding physical axes from translating logical spec {logical} -> {spec}") + return P(*spec) + + +def logical_to_sharding(logical: Axes, mesh: jax.sharding.Mesh, rules: ShardingRules) -> jax.sharding.Sharding: + """Returns the sharding for a given sequence of logical array dimensions (i.e. the logical shape of an array).""" + assert mesh is not None + return jax.sharding.NamedSharding(mesh, logical_to_physical(logical, rules)) + + +def jax_pytree_struct(cls, meta_fields: tuple = ()): + """jax.tree_util.register_dataclass wrapper that automatically infers data_fields.""" + if not dataclasses.is_dataclass(cls): + cls = dataclasses.dataclass(cls) + all_fields = tuple(f.name for f in dataclasses.fields(cls) if f.init) + data_fields = tuple(f for f in all_fields if f not in meta_fields) + return tree_util.register_dataclass(cls, data_fields=data_fields, meta_fields=meta_fields) + + +jax_static = lambda cls: tree_util.register_static(dataclasses.dataclass(cls)) + + +@jax_static +class Config: + embed: int + q_heads: int + kv_heads: int + num_layers: int + head_dim: int + vocab_size: int + max_seq_len: int + # Attention + causal: bool + sliding_attention_map: list[str] + sliding_window_size: int + # MoE + moe_ffw_size: int + moe_experts_per_tok: int + moe_num_experts: int + moe_gate_up_alpha: float = 1.702 + moe_gate_up_limit: float = 7.0 + moe_gate_dtype: "jnp.dtype" = jnp.float32 + ep_strategy: str = "decode" + # kernel config + use_prefill_attn_kernel: bool = False + use_decode_attn_kernel: bool = False + use_ragged_dot_kernel: bool = True + decode_ragged_dot_tiling: dict[str, int] = dataclasses.field( + default_factory=lambda: {"block_g": 1, "block_n": 2**30, "block_compute": 32, "block_out": 2048} + ) + dtype: "jnp.dtype" = jnp.bfloat16 + norm_eps: float = 1e-5 + # sharding + rules: ShardingRules = dataclasses.field(default_factory=ShardingRules) + mesh: jax.sharding.Mesh | None = None + max_position_embeddings: int = 131072 + rope_theta: float = 500000.0 + rope_factor: float = 32.0 + rope_original_max_position_embeddings: int = 4096 + rope_beta_slow: float = 1.0 + rope_beta_fast: float = 32.0 + quant_moe: bool = False + quant_attn: bool = False # OpenAI doesn't seem to use this, i.e., always False + quant_cache: bool = True + quant_scale_dtype: "jnp.dtype" = jnp.bfloat16 + + +def hf_to_jax_config(hf_config: Any | dict[str, Any]) -> "Config": + _get = lambda x, k, default=None: ( + getattr(x, k, default) if not isinstance(hf_config, dict) else hf_config.get(k, default) + ) + return Config( + embed=_get(hf_config, "hidden_size"), + moe_ffw_size=_get(hf_config, "intermediate_size"), + q_heads=_get(hf_config, "num_attention_heads"), + kv_heads=_get(hf_config, "num_key_value_heads"), + num_layers=_get(hf_config, "num_hidden_layers"), + head_dim=_get(hf_config, "head_dim"), + vocab_size=_get(hf_config, "vocab_size"), + norm_eps=_get(hf_config, "rms_norm_eps"), + moe_experts_per_tok=_get(hf_config, "num_experts_per_tok"), + moe_num_experts=_get(hf_config, "num_local_experts"), + moe_gate_up_alpha=_get(hf_config, "alpha", 1.702), + moe_gate_up_limit=_get(hf_config, "swiglu_limit", 7.0), + max_seq_len=1024, + dtype=jnp.bfloat16, + causal=True, + 
sliding_attention_map=_get(hf_config, "layer_types"), + sliding_window_size=_get(hf_config, "sliding_window", 128), + use_prefill_attn_kernel=False, + use_decode_attn_kernel=False, + rope_theta=_get(hf_config, "rope_theta"), + ) + + +def load_config(config_path: str | os.PathLike[str] | Path) -> "Config": + return hf_to_jax_config(json.loads(Path(config_path).read_text())) + + +def load_tokenizer(chkpt_path: str | os.PathLike[str] | Path) -> AutoTokenizer: # noqa: F821 + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained(chkpt_path) + + +@partial(jax_pytree_struct, meta_fields=("shape", "logical_axes", "initializer")) +@dataclasses.dataclass(frozen=True) +class ArrayInfo: + shape: tuple[int, ...] + dtype: "jnp.dtype" + logical_axes: tuple + initializer: Callable | None = None + + +# module reload friendly isinstance check +is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__) +is_param = lambda x: is_type(x, ArrayInfo) +_count_left_padding = lambda ids, pad_id=PAD_ID: auto_axes( + lambda ids: jnp.sum(jnp.cumsum(ids != pad_id, axis=-1) == 0, axis=-1), out_sharding=P(None) +)(ids) +_length_minus_right_padding = lambda segment_ids: auto_axes( + lambda segment_ids: jnp.sum(jnp.cumsum(jnp.flip(segment_ids != 0, -1), axis=-1) > 0, -1), out_sharding=P(None) +)(segment_ids) +which_platform = lambda cfg: cfg.mesh.devices.reshape(-1)[0].platform +# he_normal generates a new lambda every time it's called which defeats caching for `Layer.init`, so cache the lambda +_he_normal = lru_cache(lambda *args, **kw: jax.nn.initializers.he_normal(*args, **kw)) + + +@partial(jax.jit, static_argnames=("abstract", "shardings")) +def _init_leaves(key, abstract, shardings): + @partial(jax.jit, out_shardings=shardings) + def _init_fn(key): + num_leaves = len(jax.tree.leaves(abstract, is_leaf=is_param)) # one new RNG key per tensor + key_iter = iter(random.split(key, num_leaves)) + return jax.tree.map( + lambda info: info.initializer(next(key_iter), info.shape, info.dtype), abstract, is_leaf=is_param + ) + + return _init_fn(key) + + +class _Init: + @classmethod + def abstract(cls, cfg: Config, *args, **kw): + raise NotImplementedError + + @classmethod + def shardings(cls, cfg: Config, *args, **kw): + abstract = cls.abstract(cfg, *args, **kw) + + return jax.tree.map( + lambda info: logical_to_sharding(info.logical_axes, cfg.mesh, cfg.rules), + abstract, + is_leaf=is_param, + ) + + @classmethod + def init(cls, key: random.PRNGKey, cfg: Config, *args, **kw): + """Returns a pytree of randomly-initialized jax.Arrays corresponding to abstract().""" + abstract = cls.abstract(cfg, *args, **kw) + shardings = jax.tree.map( + lambda info: logical_to_sharding(info.logical_axes, cfg.mesh, cfg.rules), abstract, is_leaf=is_param + ) + abstract_leaves, abstract_struct = jax.tree.flatten(abstract, is_leaf=is_param) + shardings_leaves = jax.tree.leaves(shardings, is_leaf=is_param) + return jax.tree.unflatten(abstract_struct, _init_leaves(key, tuple(abstract_leaves), tuple(shardings_leaves))) + + +@partial(jax_pytree_struct, meta_fields=("out_scaling", "scale_expand_dims")) +class QuantArray: + quant: jax.Array | ArrayInfo + scale: jax.Array | ArrayInfo + out_scaling: bool = False + scale_expand_dims: int | tuple[int, ...] 
= () + shape = property(lambda self: self.quant.shape) + ndim = property(lambda self: self.quant.ndim) + + +def einsum(subscripts: str, lhs: jax.Array, rhs: jax.Array | QuantArray, out_sharding: P | None = None): + """jnp.einsum wrapper that handles regular arrays and QuantArrays""" + if is_type(rhs, QuantArray): + scale = jnp.expand_dims(rhs.scale, rhs.scale_expand_dims) + if rhs.out_scaling: + return jnp.einsum(subscripts, lhs, rhs.quant, out_sharding=out_sharding) * scale + else: + return jnp.einsum(subscripts, lhs * scale, rhs.quant, out_sharding=out_sharding) + else: + return jnp.einsum(subscripts, lhs, rhs, out_sharding=out_sharding) + + +_int8_quant_init = lambda key, shape, dtype=jnp.int8: random.randint(key, shape, -128, 128, dtype=dtype) +_int8_scale_init = lambda key, shape, dtype: random.normal(key, shape, dtype=dtype) / math.sqrt(math.prod(shape)) / 127 + + +def quantize(x: jax.Array | ArrayInfo, axis: int | tuple[int, ...], scale_dtype=jnp.bfloat16, zero_init: bool = False): + if is_type(x, QuantArray): + raise ValueError("Attempting to quantize an already quantized QuantArray.") + if not isinstance(axis, (list, tuple)): + axis = (axis,) + axis = tuple(z % len(x.shape) for z in axis) + + if isinstance(x, jax.Array): + axis = tuple(z % x.ndim for z in axis) + amax = jnp.max(jnp.abs(x), axis=axis, keepdims=True) + scale = (amax / 127.0 + jnp.finfo(scale_dtype).tiny).astype(scale_dtype) + quant = jnp.round(x / scale).astype(jnp.int8) + scale = scale.reshape([z for i, z in enumerate(scale.shape) if i not in axis]) + return quant, scale + + if is_type(x, ArrayInfo): + new_shape = tuple(ax for i, ax in enumerate(x.shape) if i not in axis) + new_logical_axes = tuple(ax for i, ax in enumerate(x.logical_axes) if i not in axis) + if zero_init: + quant_init, scale_init = jax.nn.initializers.zeros, jax.nn.initializers.ones + else: + quant_init, scale_init = _int8_quant_init, _int8_scale_init + quant = dataclasses.replace(x, shape=x.shape, dtype=jnp.int8, initializer=quant_init) + scale = ArrayInfo(new_shape, scale_dtype, new_logical_axes, scale_init) + return quant, scale + raise ValueError(f"quantize got unexpected type: {type(x)}") + + +def update_slice(x: jax.Array | QuantArray, y: jax.Array, pos: int, update_axis: int, quant_axis: int = -1): + """dynamic_update_slice wrapper that handles regular arrays and QuantArrays""" + if is_type(x, QuantArray): + assert x.quant.ndim == y.ndim + quant_axis, update_axis = quant_axis % x.quant.ndim, update_axis % x.quant.ndim # normalize axis numbers + y_quant, y_scale = quantize(y, axis=quant_axis, scale_dtype=x.scale.dtype) # quantize rhs + y_quant = reshard(y_quant.astype(x.quant.dtype), jax.typeof(x.quant).sharding.spec) + y_scale = reshard(y_scale.astype(x.scale.dtype), jax.typeof(x.scale).sharding.spec) + new_quant = jax.lax.dynamic_update_slice_in_dim(x.quant, y_quant, pos, axis=update_axis) + scale_update_axis = [ax for ax in range(x.quant.ndim) if ax != quant_axis][update_axis] + new_scale = jax.lax.dynamic_update_slice_in_dim( + x.scale, y_scale, pos, axis=scale_update_axis + ) # update axis in `scale` + return dataclasses.replace(x, quant=new_quant, scale=new_scale) + else: + assert x.ndim == y.ndim + y = reshard(y.astype(x.dtype), jax.typeof(x).sharding.spec) + return jax.lax.dynamic_update_slice_in_dim(x, y, pos, axis=update_axis) + + +@jax_pytree_struct +class AttentionLayer(_Init): + q: jax.Array | ArrayInfo | QuantArray + q_bias: jax.Array | ArrayInfo | QuantArray + k: jax.Array | ArrayInfo | QuantArray + k_bias: jax.Array | 
ArrayInfo | QuantArray + v: jax.Array | ArrayInfo | QuantArray + v_bias: jax.Array | ArrayInfo | QuantArray + o: jax.Array | ArrayInfo | QuantArray + o_bias: jax.Array | ArrayInfo | QuantArray + sinks: jax.Array | ArrayInfo | QuantArray + + @classmethod + def abstract(cls, cfg: Config) -> "AttentionLayer": + _init = _he_normal(in_axis=0, out_axis=(1, 2)) + _zero_init = jax.nn.initializers.zeros + layer = AttentionLayer( + q=ArrayInfo((cfg.embed, cfg.q_heads, cfg.head_dim), cfg.dtype, ("qkv_embed", "q_heads", "head_dim"), _init), + q_bias=ArrayInfo((cfg.q_heads, cfg.head_dim), cfg.dtype, ("q_heads", "head_dim"), _zero_init), + k=ArrayInfo( + (cfg.embed, cfg.kv_heads, cfg.head_dim), cfg.dtype, ("qkv_embed", "kv_heads", "head_dim"), _init + ), + k_bias=ArrayInfo((cfg.kv_heads, cfg.head_dim), cfg.dtype, ("kv_heads", "head_dim"), _zero_init), + v=ArrayInfo( + (cfg.embed, cfg.kv_heads, cfg.head_dim), cfg.dtype, ("qkv_embed", "kv_heads", "head_dim"), _init + ), + v_bias=ArrayInfo((cfg.kv_heads, cfg.head_dim), cfg.dtype, ("kv_heads", "head_dim"), _zero_init), + o=ArrayInfo((cfg.q_heads, cfg.head_dim, cfg.embed), cfg.dtype, ("o_heads", "head_dim", "o_embed"), _init), + o_bias=ArrayInfo((cfg.embed,), cfg.dtype, ("o_embed",), _zero_init), + sinks=ArrayInfo((cfg.head_dim,), cfg.dtype, (None,), _zero_init), + ) + layer = cls.quantize(layer, cfg) + return layer + + @staticmethod + def quantize(layer: "AttentionLayer", cfg: Config): + if not cfg.quant_attn: + return layer + scale_dtype = cfg.quant_scale_dtype + return dataclasses.replace( + layer, + q=QuantArray(*quantize(layer.q, 0, scale_dtype), out_scaling=True, scale_expand_dims=-2), + k=QuantArray(*quantize(layer.k, 0, scale_dtype), out_scaling=True, scale_expand_dims=-2), + v=QuantArray(*quantize(layer.v, 0, scale_dtype), out_scaling=True, scale_expand_dims=-2), + o=QuantArray(*quantize(layer.o, (0, 1), scale_dtype), out_scaling=True), + ) + + +@jax_pytree_struct +class MoELayer(_Init): + # router + w_router: jax.Array | ArrayInfo | QuantArray + w_router_bias: jax.Array | ArrayInfo | QuantArray + # experts + we_gate_up: jax.Array | ArrayInfo | QuantArray + we_gate_up_bias: jax.Array | ArrayInfo | QuantArray + we_down: jax.Array | ArrayInfo | QuantArray + we_down_bias: jax.Array | ArrayInfo | QuantArray + + @classmethod + def abstract(cls, cfg: Config): + _einit, _sinit = _he_normal(in_axis=0, out_axis=(1, 2)), _he_normal(in_axis=0, out_axis=1) + _zero_init = jax.nn.initializers.zeros + dtype = cfg.dtype + layer = MoELayer( + w_router=ArrayInfo((cfg.embed, cfg.moe_num_experts), cfg.moe_gate_dtype, ("moe_e_up_embed", None), _sinit), + w_router_bias=ArrayInfo((cfg.moe_num_experts,), cfg.moe_gate_dtype, (None,), _zero_init), + we_gate_up=ArrayInfo( + (cfg.moe_num_experts, cfg.embed, 2 * cfg.moe_ffw_size), + dtype, + ("moe_e_experts", "moe_e_up_embed", "moe_e_up_ffw"), + _einit, + ), + we_gate_up_bias=ArrayInfo( + (cfg.moe_num_experts, 2 * cfg.moe_ffw_size), dtype, ("moe_e_experts", "moe_e_up_ffw"), _zero_init + ), + we_down=ArrayInfo( + (cfg.moe_num_experts, cfg.moe_ffw_size, cfg.embed), + dtype, + ("moe_e_experts", "moe_e_down_ffw", "moe_e_down_embed"), + _einit, + ), + we_down_bias=ArrayInfo( + (cfg.moe_num_experts, cfg.embed), dtype, ("moe_e_experts", "moe_e_down_embed"), _zero_init + ), + ) + layer = cls.quantize(layer, cfg) + return layer + + @staticmethod + def quantize(layer: "MoELayer", cfg: Config): + if not cfg.quant_moe: + return layer + scale_dtype = cfg.quant_scale_dtype + return dataclasses.replace( + layer, + 
we_gate_up=QuantArray(*quantize(layer.we_gate_up, 1, scale_dtype), out_scaling=True),
+            we_down=QuantArray(*quantize(layer.we_down, 1, scale_dtype), out_scaling=True),
+        )
+
+
+@jax_pytree_struct
+class Layer(_Init):
+    ffw: MoELayer
+    attn: AttentionLayer
+    attn_pre_gamma: jax.Array | ArrayInfo
+    attn_post_gamma: jax.Array | ArrayInfo
+
+    ########################################################################################################################
+    @classmethod
+    def abstract(cls, cfg: Config, layer_idx: int) -> "Layer":
+        layer = Layer(
+            ffw=MoELayer.abstract(cfg),
+            attn=AttentionLayer.abstract(cfg),
+            attn_pre_gamma=ArrayInfo((cfg.embed,), cfg.dtype, ("act_embed",), jax.nn.initializers.constant(1.0)),
+            attn_post_gamma=ArrayInfo((cfg.embed,), cfg.dtype, ("act_embed",), jax.nn.initializers.constant(1.0)),
+        )
+        # layer = cls.quantize(layer, cfg)  # abstract already quantized
+        return layer
+
+    @staticmethod
+    def quantize(layer: "Layer", cfg: Config):
+        return dataclasses.replace(
+            layer, ffw=layer.ffw.quantize(layer.ffw, cfg), attn=layer.attn.quantize(layer.attn, cfg)
+        )
+
+
+@jax_pytree_struct
+class Weights(_Init):
+    layers: list[Layer]
+    embedding: jax.Array | ArrayInfo
+    gamma_final: jax.Array | ArrayInfo
+    lm_head: jax.Array | ArrayInfo
+
+    @classmethod
+    def abstract(cls, cfg: Config):
+        layers = [Layer.abstract(cfg, layer_idx) for layer_idx in range(cfg.num_layers)]
+        init01, init10 = _he_normal(in_axis=0, out_axis=1), _he_normal(in_axis=1, out_axis=0)
+        return Weights(
+            layers=layers,
+            embedding=ArrayInfo((cfg.vocab_size, cfg.embed), cfg.dtype, ("vocab_in", "vocab_in"), init01),
+            gamma_final=ArrayInfo((cfg.embed,), cfg.dtype, ("act_embed",), jax.nn.initializers.constant(1.0)),
+            lm_head=ArrayInfo((cfg.embed, cfg.vocab_size), cfg.dtype, ("vocab_in", "vocab_out"), init10),
+        )
+
+
+@partial(jax_pytree_struct, meta_fields=["time_axis", "size", "get_sequence", "insert_sequences"])
+class KVCache(_Init):
+    k: list[jax.Array]  # (batch_size, key_heads, max_seq_len, head_dim)
+    v: list[jax.Array]  # (batch_size, key_heads, max_seq_len, head_dim)
+    iter: jax.Array  # []  # sequences are right-aligned for slice update performance
+    starts: jax.Array  # [batch_size]  # sequences are right-aligned, we need start indices
+    time_axis: int = 2
+    size: int = -1
+    get_sequence: Callable | None = None
+    insert_sequences: Callable | None = None
+
+    @classmethod
+    def abstract(cls, cfg: Config, batch_size: int, max_seq_len: int):
+        val_info = ArrayInfo(
+            (batch_size, cfg.kv_heads, max_seq_len, cfg.head_dim),
+            cfg.dtype,
+            ("batch", "kv_heads", "sequence", "head_dim"),
+            jax.nn.initializers.zeros,
+        )
+        cache = KVCache(
+            k=[val_info for _ in range(cfg.num_layers)],
+            v=[val_info for _ in range(cfg.num_layers)],
+            iter=ArrayInfo((), jnp.int32, (), jax.nn.initializers.constant(-1)),
+            starts=ArrayInfo((batch_size,), jnp.int32, ("batch",), jax.nn.initializers.zeros),
+            size=max_seq_len,
+        )
+        if cfg.quant_cache:
+            _quantize = partial(quantize, axis=-1, scale_dtype=cfg.quant_scale_dtype, zero_init=True)
+            cache = dataclasses.replace(
+                cache,
+                k=[
+                    QuantArray(*_quantize(cache.k[idx]), out_scaling=True, scale_expand_dims=(-2, -3))
+                    for idx in range(len(cache.k))
+                ],
+                v=[
+                    QuantArray(*_quantize(cache.v[idx]), out_scaling=False, scale_expand_dims=(-2, -3))
+                    for idx in range(len(cache.v))
+                ],
+            )
+        return cache
+
+    def fill_len(self) -> jax.Array:
+        return jnp.where(self.iter >= 0, (self.iter - self.starts) % self.size, 0)
+
+    @property
+    def buffers(self) -> tuple[jax.Array | 
QuantArray, ...]: + return (self.k, self.v) + + +def segment_ids_to_positions(segment_ids): + """Counts positions for segment ids.""" + + def scan_fun(a, b): + return ((a[0] + 1) * (a[1] == b[1]) + b[0], b[1]) + + vals = (jnp.zeros_like(segment_ids), segment_ids) + return jnp.array(jax.lax.associative_scan(scan_fun, vals, axis=-1)[0], dtype="int32") + + +def _generate_pos_embeddings(positions: jax.Array, features: int, cfg: Config) -> tuple[jax.Array, jax.Array]: + """Yarn rope""" + base, factor = cfg.rope_theta, cfg.rope_factor + original_max_pos = cfg.rope_original_max_position_embeddings + low = (features * math.log(original_max_pos / (cfg.rope_beta_fast * 2 * math.pi))) / (2 * math.log(base)) + high = (features * math.log(original_max_pos / (cfg.rope_beta_slow * 2 * math.pi))) / (2 * math.log(base)) + low, high = max(low, 0), min(high, features - 1) + + timescale = base ** (jnp.arange(0, features, 2, dtype=jnp.float32) / features) + rot_freq_extra, rot_freq_inter = 1.0 / timescale, 1.0 / (factor * timescale) + + high = high if low != high else (high + 0.001) + interp_factor = 1 - jnp.clip((jnp.arange(features // 2, dtype=jnp.float32) - low) / (high - low), min=0, max=1) + + rotational_frequency = rot_freq_inter * (1 - interp_factor) + rot_freq_extra * interp_factor + # Must use high precision einsum here, since rounding off to a bfloat16 is catastrophic. bfloat16 rounds 257 to 256, + # but sin(257) is very different from sin(256). + sinusoid_inp = jnp.einsum( + "BT,k->BTk", + positions, + rotational_frequency, + precision=jax.lax.Precision.HIGHEST, + out_sharding=P(None, None, None), + ) + + m_scale = 1.0 + attention_scaling = 1.0 if factor <= 1 else (0.1 * m_scale * math.log(factor) + 1.0) + return jnp.sin(sinusoid_inp) * attention_scaling, jnp.cos(sinusoid_inp) * attention_scaling + + +def apply_rotary_embedding(x: jax.Array, sin: jax.Array, cos: jax.Array) -> jax.Array: + assert x.ndim == 4 and sin.ndim == 3 and cos.ndim == 3 + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + sin, cos = sin[:, None, :, :], cos[:, None, :, :] # [B, T, head_dim] -> [B, h, T, head_dim] + return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + + +def make_attention_mask( + q_len, k_len, q_segment_ids, kv_segment_ids, q_offset, kv_offset, causal: bool, sliding_window: int | None = None +): + segment_mask = (q_segment_ids[:, :, None] == kv_segment_ids[:, None, :])[:, None, :, :] # [B, 1, t, T] + segment_mask &= (q_segment_ids != 0)[:, None, :, None] & (kv_segment_ids != 0)[:, None, None, :] + if causal: + qk = (1, 1, q_len, k_len) # [b, h, t, T] + q_positions = jax.lax.broadcasted_iota(jnp.int32, qk, 2) + q_offset[:, None, None, None] + kv_positions = (jax.lax.broadcasted_iota(jnp.int32, qk, 3) + kv_offset[:, None, None, None]) % k_len + causal_mask = q_positions >= kv_positions + if sliding_window is not None: + causal_mask &= q_positions < (kv_positions + sliding_window) + return segment_mask & causal_mask + return segment_mask + + +@partial(auto_axes, out_sharding=P(BATCH_AXIS_NAME, ATTN_HEADS_AXIS_NAME, None, None)) +def attention( + q: jax.Array, + k: jax.Array | tuple[jax.Array, jax.Array], + v: jax.Array | tuple[jax.Array, jax.Array], + sinks: jax.Array, + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, + q_offset: jax.Array, + kv_offset: jax.Array, + starts: jax.Array, + lengths: jax.Array, + *, + sliding_window: int | None = None, + cfg: Config, +) -> jax.Array: + """ + Compute attention. 
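+
+    A note on the normalization used below: the softmax denominator additionally includes a
+    per-head exp(sinks - max) term ("attention sinks"), so the attention weights over real
+    keys can sum to less than one.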
+
+    Args:
+      q: Query tensor of shape (batch_size, q_heads, q_len, head_dim)
+      k: Key tensor of shape (batch_size, kv_heads, k_len, head_dim)
+      v: Value tensor of shape (batch_size, kv_heads, k_len, head_dim)
+      sinks: Attention sink logits, broadcast against the query-head axis
+      q_segment_ids: Query segment IDs of shape (batch_size, q_len)
+      kv_segment_ids: Key segment IDs of shape (batch_size, k_len)
+      q_offset: Query position offset of shape (batch_size,)
+      kv_offset: Key position offset of shape (batch_size,)
+      starts: Sequence start indices of shape (batch_size,) (unused here)
+      lengths: Sequence lengths of shape (batch_size,) (unused here)
+      sliding_window: Optional sliding attention window size
+      cfg: Configuration object
+
+    Returns:
+      Attention output of shape (batch_size, q_heads, q_len, head_dim)
+    """
+    del starts, lengths
+
+    scale = cfg.head_dim**-0.5
+
+    # grouped-query attention
+    b, qh, t, d = q.shape
+    _, kh, T, _ = k.shape
+
+    q_ = q.reshape((b, kh, qh // kh, t, d))
+    qk = einsum("bhgtd,bhTd->bhgtT", q_, k) * scale
+    qk = qk.reshape((b, qh, t, T))
+
+    mask = make_attention_mask(t, T, q_segment_ids, kv_segment_ids, q_offset, kv_offset, cfg.causal, sliding_window)
+
+    # Apply the combined mask
+    qk = jnp.where(mask, qk, -1e30).astype(jnp.float32)
+    # attn = jax.nn.softmax(qk.astype(jnp.float32), axis=-1)
+    qk_max = jnp.maximum(jnp.max(qk, axis=-1, keepdims=True), sinks[..., None, None])
+    exp = jnp.exp(qk - qk_max)
+    attn = exp / (jnp.sum(exp, axis=-1, keepdims=True) + jnp.exp(sinks[..., None, None] - qk_max))
+
+    # grouped-query attention
+    attn_ = attn.reshape((b, kh, qh // kh, t, T))
+    qkv = einsum("bhgtT,bhTd->bhgtd", attn_, v).astype(cfg.dtype)
+    return qkv.reshape((b, qh, t, d))
+
+
+def attention_kernel(q, k, v, q_segment_ids, kv_segment_ids, q_offset, kv_offset, starts, lengths, cfg: Config):
+    """Flash attention kernel!"""
+
+    # On TPUv3, pallas seems to only work with float32.
+    # q, k, v = jnp.float32(q), jnp.float32(k), jnp.float32(v)
+
+    k, k_scale = (k.quant, k.scale) if is_type(k, QuantArray) else (k, None)
+    v, v_scale = (v.quant, v.scale) if is_type(v, QuantArray) else (v, None)
+
+    # handle grouped query attention
+    assert q.shape[-3] % k.shape[-3] == 0
+    scale = q.shape[-1] ** -0.5
+
+    l2p = lambda *logical: logical_to_physical(logical, cfg.rules)
+
+    kv_repeats = q.shape[-3] // k.shape[-3]
+    q_spec = P(
+        *(l2p("batch", "kv_heads") + tuple(set(*l2p("q_heads")) - set(*l2p("kv_heads"))) + l2p("sequence", "head_dim"))
+    )
+    q_shape__ = q.shape
+    q = jax.lax.reshape(q, (q.shape[:-3] + (k.shape[-3], kv_repeats, q.shape[-2], q.shape[-1])), out_sharding=q_spec)
+
+    # shard_map
+    in_specs = (
+        q_spec,
+        l2p("batch", "kv_heads", "sequence", "head_dim"),
+        l2p("batch", "kv_heads", "sequence", "head_dim"),
+        l2p("batch", "sequence"),
+        l2p("batch", "sequence"),
+        l2p("batch"),
+        l2p("batch"),
+    )
+    in_specs += (None if k_scale is None else l2p("batch", "kv_heads", "sequence"),)
+    in_specs += (None if v_scale is None else l2p("batch", "kv_heads", "sequence"),)
+    out_specs = q_spec
+
+    @partial(jax.shard_map, mesh=cfg.mesh, in_specs=in_specs, out_specs=out_specs, check_vma=False)
+    def _f(q, k, v, q_segment_ids, kv_segment_ids, starts, lengths, k_scale, v_scale):
+        q_org_shape = q.shape
+
+        if q.shape[-2] != 1:
+            mask = mask_lib.MultiHeadMask([mask_lib.CausalMask((q.shape[-2], k.shape[-2])) for _ in range(q.shape[-3])])
+            block_q, block_kv = min(q.shape[-2], 512), min(k.shape[-2], 1024)
+            block_sizes = splash.BlockSizes(block_q=block_q, block_kv=block_kv, block_kv_compute=block_kv)
+            attn_fn = splash.make_splash_mqa_single_device(mask=mask, block_sizes=block_sizes)
+            attn_fn = jax.vmap(jax.vmap(attn_fn, in_axes=(0, 0, 0, None)), in_axes=(0, 0, 0, 0))
+
+            segment_ids = splash.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+            if k_scale is not None:
+                k = (k * k_scale[..., 
None]).astype(jnp.bfloat16) + if v_scale is not None: + v = (v * v_scale[..., None]).astype(jnp.bfloat16) + ret = attn_fn(q * scale, k, v, segment_ids) + else: + assert q.shape[-2] == 1, "This is a decode kernel, q.shape[-2] must be 1" + q = q[..., 0, :] + in_axes = (1, 1, 1, None, None) + in_axes += ((None if k_scale is None else 1),) + in_axes += ((None if v_scale is None else 1),) + hyperparams = dict(scale=scale, block_kv=512, block_bs=32) + raise NotImplementedError + # ret = jax.vmap(partial(ragged_attention.ragged_decode_fwd, **hyperparams), in_axes=in_axes, out_axes=1)( + # q, k, v, starts, lengths, k_scale, v_scale + # ) + return ret.reshape(q_org_shape) + + lengths = jnp.broadcast_to(lengths, starts.shape) + ret = _f(q, k, v, q_segment_ids, kv_segment_ids, starts, lengths, k_scale, v_scale).astype(jnp.bfloat16) + return jax.lax.reshape(ret, q_shape__, out_sharding=l2p("batch", "q_heads", "sequence", "head_dim")) + + +def rms_norm(x: jax.Array, gamma: jax.Array | None, eps: jax.Array | float) -> jax.Array: + """Apply RMS normalization.""" + rms = jnp.sqrt(jnp.mean(jnp.astype(x, jnp.float32) ** 2, axis=-1, keepdims=True) + eps) + return jnp.astype((gamma if gamma is not None else 1) * x / rms, jnp.bfloat16) + + +def attention_block( + x: jax.Array, + segment_ids: jax.Array, + layer: AttentionLayer, + sin: jax.Array, + cos: jax.Array, + cfg: Config, + cache: KVCache | None = None, + idx: int | None = None, +): + assert idx is not None + l2p = lambda *specs: logical_to_physical(specs, cfg.rules) + x = x.astype(cfg.dtype) + + # Multi-head attention + with jax.named_scope("qkv_matmul"): + q = (einsum("btd,dhq->bhtq", x, layer.q) + layer.q_bias[:, None, :]).astype(cfg.dtype) + k = (einsum("btd,dhq->bhtq", x, layer.k) + layer.k_bias[:, None, :]).astype(cfg.dtype) + v = (einsum("btd,dhq->bhtq", x, layer.v) + layer.v_bias[:, None, :]).astype(cfg.dtype) + + # Apply rotary embeddings + with jax.named_scope("rope"): + q, k = apply_rotary_embedding(q, sin, cos), apply_rotary_embedding(k, sin, cos) + + with jax.named_scope("cache_update"): + if is_type(cache, KVCache): + it = jnp.maximum(cache.iter, 0) + k = update_slice(cache.k[idx], k, it, update_axis=cache.time_axis, quant_axis=-1) + v = update_slice(cache.v[idx], v, it, update_axis=cache.time_axis, quant_axis=-1) + cache_updates = (k, v) + + # create position embeddings + additional_tokens = jnp.max(_length_minus_right_padding(segment_ids)) + time_indices = (jnp.arange(0, v.shape[-2])[None, :] - cache.starts[:, None]) % cache.size + q_segment_ids = jnp.where(segment_ids != 0, 1, 0) + kv_segment_ids = (time_indices >= 0) & (time_indices < cache.fill_len()[:, None] + additional_tokens) + q_offset = cache.fill_len() - _count_left_padding(q_segment_ids, pad_id=0) # pad_id=0 for segment_ids + kv_offset = -cache.starts + starts, lengths = cache.starts, cache.fill_len() + additional_tokens + else: + q_segment_ids, kv_segment_ids = segment_ids, segment_ids + starts = _count_left_padding(kv_segment_ids, 0) # pad_id=0 for segment_ids + lengths = _length_minus_right_padding(kv_segment_ids) + q_offset, kv_offset = -starts, -starts + cache_updates = (k, v) + sliding_window = None if ("sliding" not in cfg.sliding_attention_map[idx]) else cfg.sliding_window_size + + # Compute attention + with jax.named_scope("attention"): + attn_args = (q, k, v, layer.sinks, q_segment_ids, kv_segment_ids, q_offset, kv_offset, starts, lengths) + if (cfg.use_prefill_attn_kernel and q.shape[-2] != 1) or (cfg.use_decode_attn_kernel and q.shape[-2] == 1): + raise 
NotImplementedError + attn_out = attention_kernel(*attn_args, cfg=cfg, sliding_window=sliding_window) + else: + attn_out = attention(*attn_args, cfg=cfg, sliding_window=sliding_window) + + # Project attention output + with jax.named_scope("projection"): + attn_out = ( + einsum("bhtq,hqd->btd", attn_out, layer.o, out_sharding=l2p("batch", "sequence", "act_embed")) + + layer.o_bias + ).astype(cfg.dtype) + return attn_out, cache_updates + + +@partial(jax.jit, static_argnames=("replicated_routing",)) +def _route_tokens_to_experts(x: jax.Array, weight: jax.Array, bias: jax.Array, replicated_routing: bool, cfg: Config): + lsc = lambda x, spec: reshard(x, logical_to_physical(spec, cfg.rules)) + x_shape = x.shape + x = x.reshape((-1, x.shape[-1])) + # not distributing the routing work avoids communication for small batches + x = lsc(x, (None, None)) if replicated_routing else reshard(x, P(TENSOR_AXIS_NAME, None)) + weight, bias = lsc(weight, (None, None)), lsc(bias, (None,)) + scores = (jnp.einsum("Sk,kj->Sj", x, weight) + bias).astype(jnp.float32) + topk_weights, topk_idx = jax.lax.top_k(scores, cfg.moe_experts_per_tok) + topk_weights = jax.nn.softmax(topk_weights, axis=-1) + topk_weights = lsc(topk_weights, (None, None)).reshape(x_shape[:-1] + (cfg.moe_experts_per_tok,)) + topk_idx = lsc(topk_idx, (None, None)).reshape(x_shape[:-1] + (cfg.moe_experts_per_tok,)) + return topk_weights, topk_idx + + +def _moe_gmm(lhs, rhs, group_sizes, topk_idx, cfg: Config): + assert lhs.ndim == 2 and rhs.ndim == 3, f"{lhs.ndim=} != 2 and {rhs.ndim=} != 3" + group_sizes = group_sizes.astype(jnp.int32) + if cfg.use_ragged_dot_kernel and which_platform(cfg) == "tpu": + with jax.named_scope("decode_ragged_dot"): + if is_type(rhs, QuantArray): + assert rhs.scale.ndim == 2 and rhs.scale.shape == (rhs.quant.shape[0], rhs.quant.shape[2]) + scale = jnp.take_along_axis(rhs.scale, topk_idx[:, None], axis=-2) + ret = decode_ragged_dot(lhs, rhs.quant, group_sizes, **cfg.decode_ragged_dot_tiling) + ret = ret * scale + else: + ret = decode_ragged_dot(lhs, rhs, group_sizes, **cfg.decode_ragged_dot_tiling) + else: + with jax.named_scope("jax.lax.ragged_dot"): + if is_type(rhs, QuantArray): + assert rhs.scale.ndim == 2 and rhs.scale.shape == (rhs.quant.shape[0], rhs.quant.shape[2]) + scale = jnp.take_along_axis(rhs.scale, topk_idx[:, None], axis=-2) + ret = jax.lax.ragged_dot(lhs, rhs.quant, group_sizes) * scale + else: + ret = jax.lax.ragged_dot(lhs, rhs, group_sizes) + return ret.astype(cfg.dtype) + + +def moe_block(x: jax.Array, layer: MoELayer, cfg: Config): + assert x.ndim == 3 + l2p = lambda *axes: logical_to_physical(axes, cfg.rules) + _psc = lambda z, spec: reshard(z, P(*spec)) + _qpsc = lambda z, spec: dataclasses.replace(z, quant=_psc(z.quant, spec.quant), scale=_psc(z.scale, spec.scale)) + psc = lambda z, spec: _qpsc(z, spec) if is_type(z, QuantArray) else _psc(z, spec) + + # we're decoding or device count does not divide total token count + replicated_routing = x.shape[-2] == 1 or (x.shape[-2] * x.shape[-3]) % jax.device_count() != 0 + topk_weights, topk_idx = _route_tokens_to_experts(x, layer.w_router, layer.w_router_bias, replicated_routing, cfg) + tensor_axname, expert_axname = l2p("moe_e_tp")[0], l2p("moe_e_ep")[0] + + x_spec = l2p("batch", "sequence", None) + topk_weights_spec, topk_idx_spec = l2p("batch", "sequence", None), l2p("batch", "sequence", None) + out_spec = l2p("batch", "sequence", None) + + we_gate_up_spec, we_gate_up_bias_spec = l2p("moe_e_ep", None, "moe_e_tp"), l2p("moe_e_ep", "moe_e_tp") + 
we_down_spec, we_down_bias_spec = l2p("moe_e_ep", "moe_e_tp", None), l2p("moe_e_ep", None) + if all(is_type(z, QuantArray) for z in [layer.we_gate_up, layer.we_down]): + we_gate_up_spec = dataclasses.replace( + layer.we_gate_up, quant=we_gate_up_spec, scale=P(we_gate_up_spec[0], we_gate_up_spec[2]) + ) + we_down_spec = dataclasses.replace(layer.we_down, quant=we_down_spec, scale=P(we_down_spec[0], we_down_spec[2])) + we_gate_up = psc(layer.we_gate_up, we_gate_up_spec) + we_gate_up_bias = psc(layer.we_gate_up_bias, we_gate_up_bias_spec) + we_down = psc(layer.we_down, we_down_spec) + we_down_bias = psc(layer.we_down_bias, we_down_bias_spec) + + in_specs = ( + x_spec, + we_gate_up_spec, + we_gate_up_bias_spec, + we_down_spec, + we_down_bias_spec, + topk_weights_spec, + topk_idx_spec, + ) + + is_embedding_sharded = l2p("act_embed")[0] is not None + if is_embedding_sharded: # activations are sharded + out_spec = P(*(out_spec[:-1] + (tensor_axname,))) # override last axis name + if cfg.ep_strategy == "prefill": + out_spec = P(*(out_spec[:-1] + (tensor_axname,))) # override last axis name + + expert_count = cfg.mesh.axis_sizes[cfg.mesh.axis_names.index(expert_axname)] if expert_axname is not None else 1 + tensor_count = cfg.mesh.axis_sizes[cfg.mesh.axis_names.index(tensor_axname)] if tensor_axname is not None else 1 + assert cfg.moe_num_experts % expert_count == 0 + expert_size = cfg.moe_num_experts // expert_count + + @partial(jax.shard_map, mesh=cfg.mesh, in_specs=in_specs, out_specs=out_spec, check_vma=False) + def _expert_fn(x, we_gate_up, we_gate_up_bias, we_down, we_down_bias, topk_weights, topk_idx): + (b, s, d), e = x.shape, cfg.moe_experts_per_tok + expert_idx = jax.lax.axis_index(expert_axname) if expert_axname is not None else 0 + tensor_idx = jax.lax.axis_index(tensor_axname) if tensor_axname is not None else 0 + topk_idx_ = topk_idx.reshape(-1) + valid_group_mask_ = (topk_idx_ >= expert_size * expert_idx) & (topk_idx_ < expert_size * (expert_idx + 1)) + expert_mapped_topk_idx_ = jnp.where(valid_group_mask_, topk_idx_ - expert_idx * expert_size, 2**30) + + sort_idx_ = jnp.argsort(expert_mapped_topk_idx_, axis=-1) # [b * s * e] + isort_idx_ = jnp.argsort(sort_idx_) + + if cfg.ep_strategy == "prefill": + truncate_size = round(2 * sort_idx_.size / expert_count) + sort_idx_, isort_idx_ = sort_idx_[:truncate_size], isort_idx_[:truncate_size] + + topk_idx_sort_ = topk_idx_[sort_idx_] # [b * s * e] + expert_mapped_topk_idx_sort_ = expert_mapped_topk_idx_[sort_idx_] + valid_group_mask_sort_ = expert_mapped_topk_idx_sort_ < 2**30 + expert_mapped_topk_idx_sort_ = jnp.where(expert_mapped_topk_idx_sort_ < 2**30, expert_mapped_topk_idx_sort_, 0) + + # equivalent to: + # ``` + # x_repeat_ = jnp.repeat(x.reshape((-1, x.shape[-1])), e, axis=0) + # x_repeat_sort_ = jnp.take_along_axis(x_repeat_, sort_idx_[:, None], axis=-2) # [b * s, d] + # ``` + x_repeat_sort_ = jnp.take_along_axis(x.reshape((-1, x.shape[-1])), sort_idx_[:, None] // e, axis=-2) + # [b * s * e, d] # "// e" is an index trick to avoid jnp.repeat + + group_sizes = jnp.bincount(topk_idx_sort_, length=cfg.moe_num_experts) + group_sizes_shard = jax.lax.dynamic_slice_in_dim(group_sizes, expert_idx * expert_size, expert_size, 0) + + with jax.named_scope("we_gate"): + ff_gate_up = _moe_gmm(x_repeat_sort_, we_gate_up, group_sizes_shard, expert_mapped_topk_idx_sort_, cfg) + ff_gate_up = ff_gate_up + we_gate_up_bias[expert_mapped_topk_idx_sort_, :] + ff_gate = jnp.clip(ff_gate_up[..., ::2], max=cfg.moe_gate_up_limit) + ff_up = 
jnp.clip(ff_gate_up[..., 1::2], min=-cfg.moe_gate_up_limit, max=cfg.moe_gate_up_limit) + ff_gate_up = (ff_up + 1) * (ff_gate * jax.nn.sigmoid(ff_gate * cfg.moe_gate_up_alpha)) + ff_gate_up = jnp.where(valid_group_mask_sort_[..., None], ff_gate_up, 0) + with jax.named_scope("we_down"): + ff_out = _moe_gmm(ff_gate_up, we_down, group_sizes_shard, expert_mapped_topk_idx_sort_, cfg) + ff_out = ff_out + (tensor_idx == 0) * we_down_bias[expert_mapped_topk_idx_sort_, :] + ff_out = jnp.where(valid_group_mask_sort_[..., None], ff_out, 0) # expensive + + if cfg.ep_strategy == "prefill": + rs_shape = math.ceil((ff_out.shape[-1] // tensor_count) / 256) * 256 * tensor_count + pad_size = rs_shape - ff_out.shape[-1] + ff_out = jnp.pad(ff_out, ((0, 0), (0, pad_size))) + ff_out = jax.lax.psum_scatter(ff_out, axis_name=tensor_axname, scatter_dimension=1, tiled=True) + ff_out = ff_out * topk_weights.reshape(-1)[sort_idx_][:, None] + + if cfg.ep_strategy == "prefill": + with jax.named_scope("unpermute"): + # unpermute tokens + dtype = jnp.bfloat16 + dim_nums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,) + ) + ff_out_expert = jax.lax.scatter_add( + jnp.zeros((b * s, ff_out.shape[-1]), dtype=dtype), + sort_idx_[..., None] // e, + ff_out.astype(dtype), + dim_nums, + ).astype(dtype) + ff_out_expert = ff_out_expert.astype(cfg.dtype) + else: + with jax.named_scope("unpermute"): + ff_out = jnp.take_along_axis(ff_out, isort_idx_[..., None], axis=-2) + with jax.named_scope("expert_summing"): + ff_out_expert = jnp.sum(ff_out.reshape((b * s, e, d)), -2) + ff_out_expert = ff_out_expert.astype(cfg.dtype) + + with jax.named_scope("experts_collective"): + if cfg.ep_strategy == "prefill": + if expert_axname is not None: + ff_out_expert = jax.lax.psum(ff_out_expert, expert_axname) + else: + # collectives + if is_embedding_sharded: # activations are supposed to be sharded on out + with jax.named_scope("tp_e_psum_scatter"): + ff_out_expert = jax.lax.psum_scatter( + ff_out_expert, tensor_axname, scatter_dimension=1, tiled=True + ) + with jax.named_scope("ep_e_psum"): + if expert_axname is not None: + ff_out_expert = jax.lax.psum(ff_out_expert, expert_axname) + else: + psum_axes = tensor_axname if expert_axname is None else (expert_axname, tensor_axname) + ff_out_expert = jax.lax.psum(ff_out_expert, psum_axes) + ff_out_expert = ff_out_expert.reshape((b, s, ff_out_expert.shape[-1])) + return ff_out_expert + + with jax.named_scope("moe_routed_expert"): + x_ = psc(x, x_spec) + ff_out_expert = _expert_fn(x_, we_gate_up, we_gate_up_bias, we_down, we_down_bias, topk_weights, topk_idx) + return psc(ff_out_expert, l2p("batch", "sequence", "act_embed")) + + +def forward_layer( + x: jax.Array, + segment_ids: jax.Array, + layer: Layer, + sin: jax.Array, + cos: jax.Array, + idx: int, + cfg: Config, + cache: KVCache | None = None, +) -> tuple[jax.Array, jax.Array, jax.Array]: + x = x.astype(cfg.dtype) + + # Attention block + with jax.named_scope("attn_pre_norm"): + attn_in = rms_norm(x, layer.attn_pre_gamma, cfg.norm_eps) + attn_out, cache_updates = attention_block(attn_in, segment_ids, layer.attn, sin, cos, cfg, cache, idx) + with jax.named_scope("residual"): + x = x + attn_out.astype(cfg.dtype) + + # FFN block + with jax.named_scope("attn_post_norm"): + ff_in = rms_norm(x, layer.attn_post_gamma, cfg.norm_eps) + with jax.named_scope("ffn"): + ff_out = moe_block(ff_in, layer.ffw, cfg) + with jax.named_scope("residual"): + x = x + ff_out.astype(cfg.dtype) + + return x, 
cache_updates
+
+
+def forward(x: jax.Array, segment_ids: jax.Array, weights: Weights, cfg: Config, cache: KVCache | None = None):
+    l2p = lambda *args: logical_to_physical(args, cfg.rules)
+    # Embed input tokens [B, T] -> [B, T, D]
+    x = weights.embedding.at[x, :].get(out_sharding=l2p("batch", "sequence", "act_embed"))[..., : cfg.embed]
+
+    positions = segment_ids_to_positions(segment_ids)
+    if is_type(cache, KVCache):
+        positions = positions + cache.fill_len()[:, None]
+    sin, cos = _generate_pos_embeddings(positions, cfg.head_dim, cfg)  # [B, T, head_dim]
+    sin, cos = sin.astype(cfg.dtype), cos.astype(cfg.dtype)
+
+    all_cache_updates = []
+    for idx, layer in enumerate(weights.layers):
+        x, cache_updates = forward_layer(x, segment_ids, layer, sin, cos, idx, cfg, cache)
+        all_cache_updates.append(cache_updates)
+
+    x = rms_norm(x, weights.gamma_final, cfg.norm_eps)  # Final layer norm.
+    logits = einsum("btd,dv->btv", x, weights.lm_head)  # Project to vocabulary size
+    if is_type(cache, KVCache):
+        cache.k, cache.v = [[z[i] for z in all_cache_updates] for i in range(2)]
+        additional_tokens = jnp.max(_length_minus_right_padding(segment_ids))
+        return logits, dataclasses.replace(cache, iter=(jnp.maximum(0, cache.iter) + additional_tokens) % cache.size)
+    else:
+        return logits, all_cache_updates
+
+
+def optimal_formats(cfg: Config):
+    SDS, tree_map, bs = jax.ShapeDtypeStruct, partial(jax.tree.map, is_leaf=is_param), 16
+    weights_abstract, cache_abstract = Weights.abstract(cfg), KVCache.abstract(cfg, bs, cfg.max_seq_len)
+    weights_shardings, cache_shardings = Weights.shardings(cfg), KVCache.shardings(cfg, bs, cfg.max_seq_len)
+    weights_shapes = tree_map(lambda x, s: SDS(x.shape, x.dtype, sharding=s), weights_abstract, weights_shardings)
+    cache_shapes = tree_map(lambda x, s: SDS(x.shape, x.dtype, sharding=s), cache_abstract, cache_shardings)
+    _forward = lambda weights, cache: forward(*([jnp.ones((bs, 1), jnp.int32)] * 2), weights, cfg, cache=cache)
+    with jax.sharding.set_mesh(cfg.mesh):
+        fn = jax.jit(
+            _forward, in_shardings=Format(Layout.AUTO), out_shardings=Format(Layout.AUTO), donate_argnames=("cache",)
+        )
+        weights_formats, cache_formats = fn.trace(weights_shapes, cache_shapes).lower().compile().input_formats[0]
+    weights = tree_map(lambda x, f: SDS(x.shape, x.dtype, sharding=f), weights_abstract, weights_formats)
+    cache = tree_map(lambda x, f: SDS(x.shape, x.dtype, sharding=f), cache_abstract, cache_formats)
+    return weights, cache
+
+
+# serialization
+def save_pytree(weights, path):
+    flat_data = odict(("weights" + "".join(map(str, k)), v) for k, v in jax.tree.flatten_with_path(weights)[0])
+    ser.save(flat_data, path)  # save a dict flattened with paths to avoid custom pytree node serialization
+
+
+def load_pytree(path, sharding=None):
+    flat_sharding = odict(("weights" + "".join(map(str, k)), v) for k, v in jax.tree.flatten_with_path(sharding)[0])
+    data = jax.tree.unflatten(jax.tree.structure(sharding), jax.tree.leaves(ser.load(path, flat_sharding)))
+    return data
+
+
+# Inference.
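+# A typical call sequence (a sketch mirroring the driver in gpt_oss/main.py; names illustrative):
+#
+#   cache = KVCache.init(random.key(0), cfg, batch_size, cfg.max_seq_len)
+#   next_tokens, logits, cache = prefill(input_tokens, weights, cache, cfg)
+#   for _ in range(num_steps):
+#       curr_tokens, cache = decode_step(curr_tokens, weights, cache, cfg)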
+@partial(jax.jit, static_argnums=(1, 2))
+def prepare_chunk(chunk, pad_to: int, pad_id: int):
+    # [bs, length] -> [bs, padded]
+    if chunk.ndim == 1:
+        chunk = chunk[None, :]
+    chunk = jnp.pad(chunk, [(0, 0), (0, pad_to - chunk.shape[-1])], mode="constant", constant_values=pad_id)
+    segment_ids = jnp.where(chunk != pad_id, 1, 0).astype(jnp.int32)
+    return chunk, segment_ids
+
+
+def prefill(
+    tokens: jax.Array, weights: Weights, cache: KVCache, cfg: Config, pad_id: int = PAD_ID
+) -> tuple[jax.Array, jax.Array, KVCache]:
+    """Prefills the KV cache from a prompt; returns the argmax next tokens, the logits, and the cache."""
+    # Calculate the next power of 2 for padding, up to cfg.max_seq_len.
+    assert tokens.shape[-1] <= cfg.max_seq_len
+    pad_to = 2 ** math.ceil(math.log2(tokens.shape[-1]))
+    prompt, prompt_segment_ids = prepare_chunk(tokens, pad_to=pad_to, pad_id=pad_id)
+    assert prompt.ndim == 2
+
+    cache_shardings = KVCache.shardings(cfg, prompt.shape[0], cfg.max_seq_len)
+    if is_type(cache, KVCache):
+        uninitialized_iter = -jnp.ones_like(cache.iter)
+        cache = dataclasses.replace(cache, starts=_count_left_padding(prompt, pad_id=pad_id), iter=uninitialized_iter)
+    else:
+        cache_shardings = tuple([z[idx] for idx in range(cfg.num_layers)] for z in cache_shardings)
+    logits_shardings = jax.sharding.NamedSharding(cfg.mesh, P(BATCH_AXIS_NAME, None, TENSOR_AXIS_NAME))
+    logits, cache = jax.jit(forward, donate_argnums=(4,), out_shardings=(logits_shardings, cache_shardings))(
+        prompt, prompt_segment_ids, weights, cfg, cache
+    )
+    next_tokens = jax.jit(partial(jnp.argmax, axis=-1))(logits)
+    return next_tokens, logits, cache
+
+
+def sample_top(key: jax.Array, logits: jax.Array, k: int = 16, temp: float = 1.0):
+    def sample_multinomial(logits):
+        probs = jax.nn.softmax(logits / temp, axis=-1)
+
+        @partial(
+            jax.shard_map,
+            in_specs=(P(), P(BATCH_AXIS_NAME, None, TENSOR_AXIS_NAME)),
+            out_specs=P(BATCH_AXIS_NAME, None),
+        )
+        def _(key, probs):
+            idx = jax.lax.axis_index(TENSOR_AXIS_NAME)
+            top_probs, top_tokens = jax.lax.approx_max_k(probs, k=k)
+            top_tokens = top_tokens + probs.shape[-1] * idx
+            top_probs = jax.lax.all_gather(top_probs, TENSOR_AXIS_NAME, axis=-1, tiled=True)
+            top_tokens = jax.lax.all_gather(top_tokens, TENSOR_AXIS_NAME, axis=-1, tiled=True)
+            top_probs, idx = jax.lax.top_k(top_probs, k=k)
+            top_tokens = jnp.take_along_axis(top_tokens, idx, -1)
+
+            # by-hand multinomial sampling via the inverse CDF of the top-k probabilities
+            norm_probs = jnp.cumsum(top_probs, axis=-1) / jnp.sum(top_probs, axis=-1)[..., None]
+            idx = jnp.argmax(random.uniform(key, top_probs.shape[:-1])[..., None] <= norm_probs, axis=-1)
+            return jax.lax.pmax(jnp.take_along_axis(top_tokens, idx[..., None], -1)[..., 0], TENSOR_AXIS_NAME)
+
+        return _(key, probs)
+
+    return jax.lax.cond(temp > 1e-3, sample_multinomial, partial(jnp.argmax, axis=-1), logits)
+
+
+@partial(jax.jit, donate_argnames=("cache",))
+def decode_step(last_tokens: jax.Array, weights: Weights, cache: KVCache, cfg: Config, pad_id: int = PAD_ID, key=None):
+    assert last_tokens.ndim == 2
+    segment_ids = (last_tokens != pad_id).astype(jnp.int32)
+    next_logits, cache = forward(last_tokens, segment_ids, weights, cfg, cache)
+    key = key if key is not None else random.key(cache.iter)  # poor man's random key
+    next_tokens = sample_top(key, next_logits, temp=0.7)
+    return reshard(next_tokens, P()), cache
diff --git a/gpt_oss/main.py b/gpt_oss/main.py
new file mode 100644
index 0000000..e0624be
--- /dev/null
+++ b/gpt_oss/main.py
@@ -0,0 +1,97 @@
+# Copyright 2025 The JAX Authors.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from etils import epath +import json + +import jax +from jax import numpy as jnp +from jax import random +from jax.sharding import set_mesh, AxisType, PartitionSpec as P + +try: + from jax.sharding import use_mesh + + set_mesh = use_mesh +except ImportError: + pass +import numpy as np + +from transformers import AutoTokenizer +from gpt_oss_jax import model as gpt_jax + + +jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax_cache").expanduser())) + + +def encode_input(tokenizer, texts, pad_id: int = gpt_jax.PAD_ID): + assert isinstance(texts, list) + inputs = [ + tokenizer.apply_chat_template([{"role": "user", "content": text}], add_bos=True, add_generation_prompt=True) + for text in texts + ] + max_len = max([len(x) for x in inputs]) + inputs = [(max_len - len(x)) * [pad_id] + x for x in inputs] + return np.array(inputs) + + +if __name__ == "__main__": + # jax.distributed.initialize() # if you want to run multi-host + quant = True + + ckpt_path = epath.Path("~/bucket/gpt_oss_jax/gpt_oss_20b").expanduser() + if quant: + ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant" + tokenizer = AutoTokenizer.from_pretrained(ckpt_path) + + tp = 2 + mesh = jax.make_mesh( + (1, tp, jax.device_count() // tp), ("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3 + ) + cfg = gpt_jax.hf_to_jax_config(json.loads((ckpt_path / "config.json").read_text())) + cfg = dataclasses.replace(cfg, mesh=mesh, quant_moe=quant, quant_cache=quant, max_seq_len=2048) + + input = encode_input( + tokenizer, + [ + "Tell me your name", + "What is the weather like expressed in long prose in Old English", + "Do you like ice cream, be extremely precise", + ] + + ["Do you like ice cream, be extremely precise"] * (4 - 3), + ) + weights = gpt_jax.load_pytree(ckpt_path, gpt_jax.Weights.shardings(cfg)) + weights = jax.device_put(weights, gpt_jax.compute_optimal_weights_layouts(weights, cfg)) + + profile = True + with set_mesh(cfg.mesh): + zero_cache = gpt_jax.KVCache.init(random.key(1), cfg, input.shape[0], cfg.max_seq_len) + next_tokens, logits, cache = gpt_jax.prefill(input, weights, zero_cache, cfg) + curr_tokens = next_tokens.at[:, cache.iter - 1 : cache.iter].get(out_sharding=P(None, None)) + tokens_list = [] + for i in range(1024): + if profile and i == 2: + jax.profiler.start_trace("/tmp/gpt_profile") + tokens_list.append(curr_tokens) + curr_tokens, cache = gpt_jax.decode_step(curr_tokens, weights, cache, cfg) + if profile and i == 6: + jax.block_until_ready(tokens_list) + jax.profiler.stop_trace() + tokens = np.array(jnp.concatenate(tokens_list, axis=-1)) + responses = [tokenizer.decode(row) for row in tokens] + print("Responses:") + for response in responses: + print(response) + print("\n".join(3 * ["-" * 80])) diff --git a/gpt_oss/pyproject.toml b/gpt_oss/pyproject.toml new file mode 100644 index 0000000..806f353 --- /dev/null +++ b/gpt_oss/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "gpt_oss_jax" 
+version = "0.1.0" +description = "" +authors = [ + { name = "Robert Dyro" }, +] +readme = "README.md" +requires-python = ">=3.10" +license = { text = "Apache-2.0" } + +dependencies = [ + "jax", + "torch", + "transformers", # for the model config and the tokenizer + "flatbuffers", + "tensorstore", + "tqdm", + "numpy", + "datasets", + "gcsfs", + "etils", + "absl-py", +] + +# we don't need CUDA torch +[[tool.uv.index]] +name = "pytorch" +url = "https://download.pytorch.org/whl/cpu" + +[build-system] +requires = ["setuptools>=61.0.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.dynamic] +dependencies = { file = ["pyproject.toml"] } diff --git a/gpt_oss/scripts/convert_weights.py b/gpt_oss/scripts/convert_weights.py new file mode 100644 index 0000000..2cddccd --- /dev/null +++ b/gpt_oss/scripts/convert_weights.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path +from argparse import ArgumentParser +import dataclasses +import shutil + + +def main(model_path: str | Path, ckpt_path: str | Path): + try: + from gpt_oss_jax import model as gpt_jax + from gpt_oss_jax import chkpt_utils as utils + except ImportError: + sys.path.append(str(Path(__file__).parents[1].absolute())) + + from gpt_oss_jax import model as gpt_jax + from gpt_oss_jax import chkpt_utils as utils + + from transformers import AutoConfig + from safetensors import safe_open + from tqdm import tqdm + + model_path, ckpt_path = Path(model_path).expanduser(), Path(ckpt_path).expanduser() + files = list(model_path.glob("*safetensors")) + assert len(files) > 1 + config_file = model_path / "config.json" + assert config_file.exists(), "Must have only one `config.json` file in the model path" + config = AutoConfig.from_pretrained(config_file) + cfg = gpt_jax.hf_to_jax_config(config) + + # we convert the model unquantized when reading a GPT OSS model checkpoint + weights = gpt_jax.Weights.abstract(dataclasses.replace(cfg, quant_moe=False, quant_attn=False)) + + if not ckpt_path.exists(): + model = {} + for file in tqdm(files): + with safe_open(file, framework="torch") as f: + for key in tqdm(f.keys(), leave=False): + model[key] = f.get_tensor(key) + converted_weights = utils.convert_model_or_layer(weights, model, cfg, sequential=False) + gpt_jax.save_pytree(converted_weights, ckpt_path) + + additional_files = [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "chat_template.json", + "chat_template.jinja", + "generation_config.json", + ] + for additional_file in additional_files: + full_path = model_path / f"{additional_file}" + if not full_path.exists(): + print(f"Could not find {additional_file}, skipping...") + full_path = full_path + shutil.copyfile(full_path, ckpt_path / full_path.name) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--source-path", default="~/gpt-oss-20b", required=True, help="HF model directory path") + parser.add_argument( + "--dest-path", + default="~/gpt_oss_jax/gpt-oss-20b", + required=True, + help="JAX model model directory (to be created).", + ) + args = parser.parse_args() + main(args.source_path, args.dest_path) diff --git a/gpt_oss/scripts/download_model.py b/gpt_oss/scripts/download_model.py new file mode 100644 index 0000000..cf5369c --- /dev/null +++ b/gpt_oss/scripts/download_model.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +from argparse import ArgumentParser +from pathlib import Path + +example_models = [ + "openai/gpt-oss-20b", + "openai/gpt-oss-120b", +] + + +def main(model_id: str, 
dest_root_path: str | Path): + from huggingface_hub import snapshot_download + + local_dir = Path(dest_root_path).expanduser().absolute() / str(model_id).replace("/", "--") + snapshot_download(repo_id=model_id, local_dir=local_dir) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument( + "--model-id", required=True, help=f"HuggingFace model / repo id. Examples include: {example_models}" + ) + parser.add_argument( + "--dest-root-path", + required=True, + default="~/", + help="Destination root directory, the model will be saved into its own directory.", + ) + args = parser.parse_args() + main(args.model_id, args.dest_root_path) diff --git a/gpt_oss/scripts/quantize_model.py b/gpt_oss/scripts/quantize_model.py new file mode 100644 index 0000000..0c7fd08 --- /dev/null +++ b/gpt_oss/scripts/quantize_model.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path +from argparse import ArgumentParser + + +def main(path: str | Path, suffix: str): + try: + from gpt_oss_jax import chkpt_utils as utils + except ImportError: + sys.path.append(str(Path(__file__).parents[1].absolute())) + + from gpt_oss_jax import chkpt_utils as utils + + path = Path(path).expanduser().absolute() + dest_path = path.parent / f"{path.name}{suffix}" + utils.quantize_model(path, dest_path) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--path", default="~/gpt_oss_20b", required=True, help="Existing JAX model checkpoint path") + parser.add_argument( + "--suffix", + default="quant", + help="Suffix for a new checkpoint directory, e.g., path=~/model, suffix=-quant -> ~/model-quant", + ) + + args = parser.parse_args() + main(args.path, args.suffix if args.suffix.startswith("-") else f"-{args.suffix}") diff --git a/gpt_oss/tests/test_model.py b/gpt_oss/tests/test_model.py new file mode 100644 index 0000000..ee12c58 --- /dev/null +++ b/gpt_oss/tests/test_model.py @@ -0,0 +1,112 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
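NOTE (editorial, not part of the patch): the three scripts above form the checkpoint pipeline,
download -> convert -> quantize. An illustrative end-to-end invocation, assuming the scripts are run from
gpt_oss/scripts/ so their main() functions are importable; all paths are examples only:

    import convert_weights, download_model, quantize_model

    download_model.main("openai/gpt-oss-20b", "~/")  # -> ~/openai--gpt-oss-20b
    convert_weights.main("~/openai--gpt-oss-20b", "~/gpt_oss_jax/gpt-oss-20b")  # HF safetensors -> JAX pytree
    quantize_model.main("~/gpt_oss_jax/gpt-oss-20b", "-quant")  # -> ~/gpt_oss_jax/gpt-oss-20b-quant

Each script is normally driven from the command line with the flags defined above.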
+
+import dataclasses
+from absl.testing import absltest, parameterized
+
+import jax
+from jax import numpy as jnp
+from jax import random
+from jax.sharding import PartitionSpec as P, AxisType, set_mesh
+try:
+    from jax.sharding import use_mesh
+    set_mesh = use_mesh
+except ImportError:
+    pass
+
+
+from gpt_oss_jax import model as gpt_jax
+
+jax.config.update("jax_platforms", "cpu")
+jax.config.update("jax_num_cpu_devices", 4)
+
+MOE_CFG = gpt_jax.Config(
+    embed=2880,
+    q_heads=64,
+    kv_heads=8,
+    num_layers=24,
+    head_dim=64,
+    vocab_size=201088,
+    max_seq_len=128,
+    causal=True,
+    moe_ffw_size=2880,
+    moe_experts_per_tok=4,
+    moe_num_experts=32,
+    sliding_window_size=128,
+    sliding_attention_map=["sliding_attention", "full_attention"] * 12,  # alternating, one entry per layer
+    ep_strategy="decode",
+)
+
+class TestModel(parameterized.TestCase):
+    def setUp(self):
+        self.mesh = jax.make_mesh((1, len(jax.devices()), 1), P("x", "y", "z"), axis_types=(AxisType.Explicit,) * 3)
+        self.small_moe_cfg = dataclasses.replace(MOE_CFG, mesh=self.mesh, num_layers=2, embed=32, vocab_size=128)
+
+    @parameterized.product(quant=[False, True])
+    def test_model_init(self, quant):
+        cfg = self.small_moe_cfg
+        cfg = dataclasses.replace(cfg, quant_attn=quant, quant_moe=quant)
+        weights = gpt_jax.Weights.init(random.key(0), cfg)
+        del weights
+
+    @parameterized.product(quant=[False, True])
+    def test_cache_init(self, quant):
+        cfg = dataclasses.replace(self.small_moe_cfg, quant_cache=quant)
+        cache = gpt_jax.KVCache.init(random.key(0), cfg, 2, cfg.max_seq_len)
+        del cache
+
+    @parameterized.product(quant_weights=[False, True], quant_cache=[True, False])
+    def test_prefill_decode(self, quant_weights, quant_cache):
+        cfg = self.small_moe_cfg
+        cfg = dataclasses.replace(
+            cfg, quant_attn=quant_weights, quant_moe=quant_weights, quant_cache=quant_cache
+        )
+        tokens = jnp.ones((1, 32), dtype=jnp.int32)
+        weights = gpt_jax.Weights.init(random.key(0), cfg)
+        cache = gpt_jax.KVCache.init(random.key(0), cfg, tokens.shape[0], cfg.max_seq_len)
+        with set_mesh(cfg.mesh):
+            max_tokens, _, cache = gpt_jax.prefill(tokens, weights, cache, cfg)
+        next_tokens = max_tokens[:, -1:]
+        with set_mesh(cfg.mesh):
+            for _ in range(2):
+                next_tokens, cache = gpt_jax.decode_step(next_tokens, weights, cache, cfg)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/llama3/llama3_jax/model.py b/llama3/llama3_jax/model.py
index a54539a..e69022b 100644
--- a/llama3/llama3_jax/model.py
+++ b/llama3/llama3_jax/model.py
@@ -119,9 +119,7 @@ def jax_pytree_struct(cls, meta_fields: tuple = ()):
     cls = dataclasses.dataclass(cls)
     all_fields = tuple(f.name for f in dataclasses.fields(cls) if f.init)
     data_fields = tuple(f for f in all_fields if f not in meta_fields)
-    # return register_dataclass_serialization(
-    return tree_util.register_dataclass(cls, data_fields=data_fields, meta_fields=meta_fields)  # ,
-    # serialize_auxdata=lambda *args: b"", deserialize_auxdata=lambda *args: ())
+    return tree_util.register_dataclass(cls, data_fields=data_fields, meta_fields=meta_fields)


 jax_static = lambda cls:
tree_util.register_static(dataclasses.dataclass(cls))
@@ -1008,22 +1006,33 @@ def forward(


 # serialization
-def save_pytree(data, path):
-    import orbax.checkpoint as ocp
+#def save_pytree(data, path):
+#    import orbax.checkpoint as ocp
+#
+#    with ocp.PyTreeCheckpointer() as ckptr:
+#        ckptr.save(epath.Path(path), data, ocp.args.PyTreeSave(data, ocdbt_target_data_file_size=1024 * 1024 * 100))
+#
+#
+#def load_pytree(path, sharding=None):
+#    import orbax.checkpoint as ocp
+#
+#    item, transforms = sharding, None
+#    restore_args = jax.tree.map(lambda s: ocp.ArrayRestoreArgs(sharding=s), sharding)
+#    with ocp.PyTreeCheckpointer() as ckptr:
+#        return ckptr.restore(
+#            epath.Path(path), ocp.args.PyTreeRestore(item=item, transforms=transforms, restore_args=restore_args)
+#        )

-    with ocp.PyTreeCheckpointer() as ckptr:
-        ckptr.save(epath.Path(path), data, ocp.args.PyTreeSave(data, ocdbt_target_data_file_size=1024 * 1024 * 100))
+# serialization
+def save_pytree(weights, path):
+    flat_data = odict(("weights" + "".join(map(str, k)), v) for k, v in jax.tree.flatten_with_path(weights)[0])
+    ser.save(flat_data, path)  # save leaves keyed by flattened path to avoid custom pytree registration


 def load_pytree(path, sharding=None):
-    import orbax.checkpoint as ocp
-
-    item, transforms = sharding, None
-    restore_args = jax.tree.map(lambda s: ocp.ArrayRestoreArgs(sharding=s), sharding)
-    with ocp.PyTreeCheckpointer() as ckptr:
-        return ckptr.restore(
-            epath.Path(path), ocp.args.PyTreeRestore(item=item, transforms=transforms, restore_args=restore_args)
-        )
+    flat_sharding = odict(("weights" + "".join(map(str, k)), v) for k, v in jax.tree.flatten_with_path(sharding)[0])
+    data = jax.tree.unflatten(jax.tree.structure(sharding), jax.tree.leaves(ser.load(path, flat_sharding)))
+    return data


 # Inference.
diff --git a/llama3/main.py b/llama3/main.py
index ad21bfc..f1d1365 100644
--- a/llama3/main.py
+++ b/llama3/main.py
@@ -20,7 +20,12 @@
 import jax
 from jax import numpy as jnp
 from jax import random
-from jax.sharding import use_mesh, AxisType, PartitionSpec as P
+from jax.sharding import set_mesh, AxisType, PartitionSpec as P
+try:
+    from jax.sharding import use_mesh
+    set_mesh = use_mesh
+except ImportError:
+    pass
 import numpy as np

 from llama3_jax import model as l3jax
@@ -43,13 +48,14 @@ def encode_input(tokenizer, texts: list[str], model_name: str, pad_id: int = 0):
     #jax.distributed.initialize() # if you want to run multi-host
     quant = True

-    ckpt_path = epath.Path("~/bucket/llama3_jax/DeepSeek-R1-Distill-Llama-3.1-8B-Instruct").expanduser()
+    ckpt_path = epath.Path("~/bucket/llama3_jax_old/DeepSeek-R1-Distill-Llama-3.1-8B-Instruct").expanduser()
     if quant:
         ckpt_path = ckpt_path.parent / f"{ckpt_path.name}-quant"
     tokenizer = l3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json")

     mesh = jax.make_mesh(
-        (1, 8, jax.device_count() // 8), ("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3
+        #(1, 8, jax.device_count() // 8), ("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3
+        (1, 4, jax.device_count() // 4), ("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3
     )
     cfg = l3jax.llama_to_jax_config(json.loads((ckpt_path / "config.json").read_text()))
     cfg = dataclasses.replace(cfg, mesh=mesh, quant_layer=quant, quant_cache=quant)
@@ -65,10 +71,10 @@ def encode_input(tokenizer, texts: list[str], model_name: str, pad_id: int = 0):
         model_name=ckpt_path.name,
     )

-    with use_mesh(cfg.mesh):
-        zero_cache = l3jax.KVCache.init(random.key(1), cfg,
input.shape[0], cfg.max_seq_len) + with set_mesh(cfg.mesh): + zero_cache = l3jax.KVCache.init(random.key(1), cfg, input.shape[0]) next_tokens, logits, cache = l3jax.prefill(input, weights, zero_cache, cfg) - curr_tokens = next_tokens.at[:, cache.length - 1 : cache.length].get(out_sharding=P(None, None)) + curr_tokens = next_tokens.at[:, cache.iter - 1 : cache.iter].get(out_sharding=P(None, None)) tokens_list = [] for _ in range(16): tokens_list.append(curr_tokens) diff --git a/serving/client_demo.py b/serving/client_demo.py index c83ad6c..7c28cd0 100644 --- a/serving/client_demo.py +++ b/serving/client_demo.py @@ -1,23 +1,22 @@ #!/usr/bin/env python3 -import threading -import time -import requests -from typing import List -import textwrap -import sys +import base64 import gzip import json -import base64 +import sys +import textwrap +import threading +import time from pathlib import Path +from typing import List +import numpy as np +import requests +from rich.console import Console +from rich.layout import Layout from rich.live import Live from rich.panel import Panel -from rich.layout import Layout -from rich.console import Console from rich.text import Text -import numpy as np - # --- Configuration --- SERVER_URL = "http://localhost:8081" diff --git a/serving/main_serving_ds_r1.py b/serving/main_serving_ds_r1.py index 663fece..ab4cb80 100644 --- a/serving/main_serving_ds_r1.py +++ b/serving/main_serving_ds_r1.py @@ -1,40 +1,41 @@ -import sys -import dataclasses -import time -from pathlib import Path -import threading import asyncio -import socket +import dataclasses import signal +import socket +import sys +import threading import time -from typing import AsyncGenerator -from contextlib import asynccontextmanager from argparse import ArgumentParser +from contextlib import asynccontextmanager +from pathlib import Path +from typing import AsyncGenerator import jax -from jax import random -from jax.sharding import PartitionSpec as P, AxisType -from deepseek_r1_jax import model as dsjax -from deepseek_r1_jax import chkpt_utils as dsjax_utils -import serving_jax as serving -from serving_jax import attention_cache_utils import numpy as np - +import serving_jax as serving +import uvicorn from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse, Response +from fastapi.responses import Response, StreamingResponse +from jax import random +from jax.sharding import AxisType +from jax.sharding import PartitionSpec as P from pydantic import BaseModel -import uvicorn +from serving_jax import attention_cache_utils + +from deepseek_r1_jax import chkpt_utils as dsjax_utils +from deepseek_r1_jax import model as dsjax TOKENIZER, SERVE_LOOP, SERVING_THREAD, ARGS = None, None, None, None jax.config.update("jax_explain_cache_misses", True) jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) -#jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0) -#jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0) +# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0) +# jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0) jax.config.update("jax_enable_empty_arrays", True) shutdown_signal = threading.Event() + def encode_input(tokenizer, texts, pad_id: int = 0): assert isinstance(texts, list) inputs = [ @@ -43,6 +44,7 @@ def encode_input(tokenizer, texts, pad_id: int = 0): max_len = max([len(x) for x in inputs]) return np.array([(max_len - len(x)) * [pad_id] + x for x in inputs]) + def load_model(): global 
SERVE_LOOP, SERVING_THREAD, TOKENIZER, ARGS @@ -60,17 +62,20 @@ def load_model(): assert ckpt_path.is_dir() print("---> Model config loaded") - mesh = jax.make_mesh((1, 8, jax.device_count() // 8), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Auto,) * 3) - cfg = dataclasses.replace(dsjax.Config(), mesh=mesh)#, num_layers=4) + mesh = jax.make_mesh( + (1, 8, jax.device_count() // 8), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Auto,) * 3 + ) + cfg = dataclasses.replace(dsjax.Config(), mesh=mesh) # , num_layers=4) weights = dsjax_utils.load_model(ckpt_path, cfg) decode_weights, prefill_weights = weights, weights print("---> Weights loaded") - serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64, - decode_batch_size=8, prefill_batch_size=1, prefix_chunk_size=64) + serve_cfg = serving.ServingConfig( + decode_steps=32, max_decode_length=64, decode_batch_size=8, prefill_batch_size=1, prefix_chunk_size=64 + ) decode_cache = dsjax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, cfg.max_seq_len) - decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry - decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache + decode_cache.get_sequence = attention_cache_utils.kvcache_get_sequence + decode_cache.insert_sequences = attention_cache_utils.kvcache_insert_sequences SERVE_LOOP = serving.ServingLoop( serve_cfg, cfg, dsjax.prefill, prefill_weights, dsjax.decode_step, decode_weights, decode_cache, ARGS.server ) @@ -82,6 +87,7 @@ def serve_forever(): SERVE_LOOP.serving_step() except Exception as e: import traceback + print(traceback.format_exc(), flush=True) print(f"Exception {e}", flush=True) finally: @@ -112,7 +118,7 @@ class GenerateRequest(BaseModel): text: str -#async def generate_generator(params: GenerateRequest, request: Request) -> AsyncGenerator[str, None]: +# async def generate_generator(params: GenerateRequest, request: Request) -> AsyncGenerator[str, None]: async def generate_generator(id: int, text: str, request: Request) -> AsyncGenerator[str, None]: if id in SERVE_LOOP.results: del SERVE_LOOP.results[id] diff --git a/serving/serving_jax/__init__.py b/serving/serving_jax/__init__.py index 32d4d77..f77e9b9 100644 --- a/serving/serving_jax/__init__.py +++ b/serving/serving_jax/__init__.py @@ -12,789 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
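NOTE (editorial, not part of the patch): the large block removed below is relocated to
serving_jax/serving_loop.py (see the re-exports added at the end of this file's diff). It implements
chunk-granular prefix caching with a trie keyed on token chunks. A minimal usage sketch of the matching
semantics, assuming insert_prefix, retrieve_prefix and TrieNode keep the signatures shown in the removed code:

    import threading
    import numpy as np

    root = TrieNode(None, lock=threading.Lock())
    tokens = np.arange(10)  # a finished sequence of 10 token ids
    # one KV-buffer reference per chunk: ceil(10 / 4) = 3 chunks of size 4
    visited, stored, deleted = insert_prefix(root, tokens, [100, 101, 102], chunk_size=4)

    # a new prompt sharing the first 6 tokens matches chunk 0 fully and chunk 1 partially
    (match_len, refs), visited = retrieve_prefix(root, tokens[:6], chunk_size=4)
    # match_len == 6 and refs == [100, 101]; the caller reuses the concatenated
    # KV buffers for the first match_len positions and prefills only the remainder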
-import dataclasses -import contextlib -from functools import partial -from typing import Any, Callable, Sequence -import math -from concurrent.futures import ThreadPoolExecutor, Future -import threading -import time -import json - -import jax -import jax.numpy as jnp -from jax.sharding import Mesh, PartitionSpec as P, NamedSharding, set_mesh - -try: - from jax.experimental.shard import auto_axes -except ModuleNotFoundError: - from jax.sharding import auto_axes -from jax._src import distributed - -from jax._src.lib import xla_client as xc -import numpy as np - -from .cross_host import transfer_tree_A2B - - -KVCache, Weights, Config = Any, Any, Any -PyTree, PyTreeStruct = Any, Any - -TIME_AXIS = 2 -USE_PREFIX_CACHE = True # the eviction mechanism is extremely simple right now -# USE_PREFIX_CACHE = False -is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__) - -######################################################################################################################## -# device put for cross-process/hosts transfers ######################################################################### -######################################################################################################################## - - -def unsafe_device_put(xs: PyTree, spec: PyTree, dest_mesh: Mesh): - """Fastest, but local single-process JAX only for now.""" - xs_flat, xs_struct = jax.tree.flatten(xs) - shardings_list = [NamedSharding(dest_mesh, s) for s in jax.tree.leaves(spec)] - devices_list = [s._internal_device_list for s in shardings_list] - copy_semantics = [xc.ArrayCopySemantics.ALWAYS_COPY] * len(devices_list) - out = xc.batched_copy_array_to_devices_with_sharding(xs_flat, devices_list, shardings_list, copy_semantics) - return jax.tree.unflatten(xs_struct, out) - - -def jax_device_put(xs: PyTree, sharding: PyTree): - """Async, available in future JAX.""" - is_source = len(getattr(jax.tree.leaves(xs)[0], "addressable_shards", [])) > 0 - if is_source: - return jax.device_put(xs, sharding) - else: - empty_arrays = jax.tree.map( - lambda x: jax.make_array_from_single_device_arrays(x.shape, x.sharding, [], dtype=x.dtype), xs - ) - return jax.device_put(empty_arrays, sharding) - - -def jit_device_put(xs: PyTree, sharding: PyTree): - """Most compatabile, uses jit, so requires blocking dispatch.""" - # jax.sharding.set_mesh(None) # not compatible with context mesh - meshA, meshB = jax.tree.leaves(xs)[0].sharding.mesh, jax.tree.leaves(sharding)[0].mesh - return transfer_tree_A2B(xs, meshA, meshB) - - -#device_put = jit_device_put # the most compatible options currently, but NOT async, need -device_put = jax.device_put - - -def _ensure_all_args_on_mesh(args, mesh: Mesh): - if not all(jax.tree.leaves(arg)[0].sharding.mesh == mesh for arg in args): - _correct_mesh = lambda value: jax.tree.leaves(value)[0].sharding.mesh == mesh - _args = {i: arg for i, arg in enumerate(args) if not _correct_mesh(arg)} - if len(_args) > 0: - args = dict(enumerate(args)) | device_put(_args, like_shard(_args, mesh)) - args = tuple(args[i] for i in range(len(args))) - return args - - -######################################################################################################################## -# kv cache buffer management ########################################################################################### -######################################################################################################################## - - -@partial(jax.jit, 
static_argnames=("axis", "chunk_size", "ns")) -def _split(val: jax.Array | list[jax.Array], axis: int, chunk_size: int, ns: int) -> list[jax.Array]: - def _fn(val): - axis_ = axis % val.ndim - size = val.shape[axis_] - if size < chunk_size * ns: - min_len = chunk_size * ns - val = jnp.pad(val, [(0, 0) if i != axis_ else (0, min_len - val.shape[axis_]) for i in range(val.ndim)]) - index = [slice(None) if i != axis_ else slice(0, ns * chunk_size) for i in range(val.ndim)] - return jnp.split(val[*index], ns, axis=axis_)[:ns] - - val_leaves, val_structure = jax.tree.flatten(val) - spec = [[x] * ns for x in like_spec(val_leaves)] - split_leaves = auto_axes(lambda vals: [_fn(val) for val in vals], out_sharding=spec)(val_leaves) - return [jax.tree.unflatten(val_structure, [x[i] for x in split_leaves]) for i in range(ns)] - - -@partial(jax.jit, static_argnames=("split_axis",)) -def _concat(values, split_axis: int): - _fn = lambda vals: jax.tree.map(lambda *args: jnp.concatenate(args, axis=split_axis), *vals) - return auto_axes(_fn, out_sharding=like_spec(values[0]))(values) - - -class KVBufferStore: - def __init__(self): - self.usecount, self.ondevice, self._store, self.unique_id, self.livecount = {}, {}, {}, 18, 0 - - def _get_unique_buffer_ids(self, n: int): - ids = list(range(self.unique_id, self.unique_id + n)) - self.unique_id += n - return ids - - def offload_buffers(self, how_many: int): - if how_many == 0: - return - candidates = sorted(self._store.keys(), key=lambda i: self.usecount[i] if self.ondevice[i] else 2**60) - for i in candidates[:how_many]: - if self.ondevice[i]: - host_shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("pinned_host"), self._store[i]) - self._store[i] = jax.device_put(self._store[i], host_shrd) - self.ondevice[i] = False - self.livecount -= 1 - - def load(self, id: int): - if isinstance(id, (tuple, list)): - return [self.load(i) for i in id] - if self.ondevice[id]: - return self._store[id] - self.ondevice[id] = True - self.livecount += 1 - device_shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("device"), self._store[id]) - self._store[id] = jax.device_put(self._store[id], device_shrd) - return self._store[id] - - def delete(self, id: int): - if isinstance(id, (list, tuple)): - return [self.delete(i) for i in id] - self.livecount -= self.ondevice[id] - del self.usecount[id], self.ondevice[id], self._store[id] - - def store(self, id: int, val: Any): - if isinstance(id, (tuple, list)): - return [self.store(i, v) for i, v in zip(id, val)] - self.livecount += 1 - self.usecount[id], self.ondevice[id], self._store[id] = 1, True, val - - def mark_visited(self, id: int): - if isinstance(id, (list, tuple)): - return [self.mark_visited(i) for i in id] - self.usecount[id] += 1 - - -BUFFER_STORE = KVBufferStore() - -######################################################################################################################## -# trie utils ########################################################################################################### -######################################################################################################################## - -EMPTY, HASH_BITWIDTH = -1, 1 - -@dataclasses.dataclass -class ChildKeys: - keys: np.ndarray - keys_hash: np.ndarray - keys_hash_mask: np.ndarray - key_lens: np.ndarray - num: int = 0 - - -def _hash_encode(v: np.ndarray, hash_bitwidth: int = HASH_BITWIDTH, pad_idx: int = EMPTY): - v, last_dim = v.astype(np.int64), min(64 // hash_bitwidth, v.shape[-1]) - v_, el_mask = 
v.reshape(v.shape[:-1] + (-1, last_dim)), (1 << hash_bitwidth) - 1 - mask = np.bitwise_or.reduce(((v_ != pad_idx) * el_mask) << (hash_bitwidth * np.arange(v_.shape[-1])), axis=-1) - h = np.bitwise_or.reduce((v_ & el_mask) << (hash_bitwidth * np.arange(v_.shape[-1])), axis=-1) - return h, mask - - -def _prefilter_on_hash( - w: np.ndarray, - keys: np.ndarray, - vh: np.ndarray, - vm: np.ndarray, - hash_bitwidth: int = HASH_BITWIDTH, - pad_idx: int = EMPTY, -): - wh, wm = _hash_encode(w, hash_bitwidth=hash_bitwidth, pad_idx=pad_idx) - inv_match = (wh ^ vh) & vm & wm - # count full hash chunk matches, but don't miss sequences not matching at least one full hash - match_len = np.sum(np.cumsum(inv_match, axis=-1) == 0, axis=-1) + (w[0] == keys[:, 0]) - max_match_len = max(np.max(match_len), 1) - return np.where(match_len == max_match_len)[0] - - -def _fast_pad(x, size, axis, pad_val=0): - new_buf = pad_val * np.ones([size - s if i == axis else s for i, s in enumerate(x.shape)], dtype=x.dtype) - return np.concat([x, new_buf], axis) - - -@dataclasses.dataclass -class TrieNode: - value: int - children: list["TrieNode"] = dataclasses.field(default_factory=list) - child_keys: ChildKeys | None = None - lock: "threading.Lock | None" = None - usage: int = 1 - - def __repr__(self, indent: int = 0): - lines = [f"TrieNode(value={self.value}, usage={self.usage}, children={{"] - if len(self.children) == 0: - lines[-1] = lines[-1][:-1] + "})" - else: - for i, child in enumerate(self.children): - child_key = self.child_keys.keys[i, : self.child_keys.key_lens[i]].tolist() - lines.append(f"{' ' * indent} {child_key}: {child.__repr__(indent + 2).strip()},") - lines.append(")") - return "\n".join([(" " * indent) + line for line in lines]) - - @staticmethod - def _overlap(child_keys: ChildKeys, key, key_len, pad_idx: int = EMPTY): - keys = child_keys.keys[: child_keys.num, :] - keys_hash = child_keys.keys_hash[: child_keys.num, :] - keys_hash_mask = child_keys.keys_hash_mask[: child_keys.num, :] - - # pre-filter sequences - relevant_idx = _prefilter_on_hash(key, keys, keys_hash, keys_hash_mask, pad_idx=pad_idx) - if len(relevant_idx) == 0: - return np.zeros((child_keys.num,), dtype=np.int32), np.zeros((child_keys.num,), dtype=np.int32) - keys = keys[relevant_idx, :] - - mask = np.cumsum((key == keys) | (key == pad_idx) | (keys == pad_idx), -1) == np.arange(1, key.shape[-1] + 1) - overlap = np.zeros((child_keys.num,), dtype=np.int32) - overlap[relevant_idx] = np.sum(mask, axis=-1) - return np.minimum(overlap, key_len), np.minimum(overlap, child_keys.key_lens[: child_keys.num]) - - @staticmethod - def _append_key(keys: ChildKeys | None, new_key: np.ndarray, key_len: int, pad_idx: int = EMPTY): - if keys is None: - key_hash, key_hash_mask = _hash_encode(new_key[None, :], pad_idx=pad_idx) - return ChildKeys(new_key[None, :], key_hash, key_hash_mask, np.array([key_len], dtype=np.int32), 1) - if keys.num == keys.keys.shape[0]: # need to double the keys buffer - keys.keys = _fast_pad(keys.keys, 2 * keys.num, 0, 0) - keys.key_lens = _fast_pad(keys.key_lens, 2 * keys.num, 0) - keys.keys_hash = _fast_pad(keys.keys_hash, 2 * keys.num, 0, 0) - keys.keys_hash_mask = _fast_pad(keys.keys_hash_mask, 2 * keys.num, 0, 0) - keys.keys[keys.num, :], keys.key_lens[keys.num] = new_key, key_len - keys.keys_hash[keys.num, :], keys.keys_hash_mask[keys.num, :] = _hash_encode(new_key, pad_idx=pad_idx) - keys.num += 1 - return keys - - @staticmethod - def _delete_keys(keys: ChildKeys, delete_idxs: np.ndarray): - if keys is None: - return - mask 
= np.ones(keys.keys.shape[0], dtype=bool) - mask[np.array(list(delete_idxs) if isinstance(delete_idxs, set) else delete_idxs, int)] = False - if np.sum(mask) == 0: - return None - num = max(keys.num - sum(1 for idx in set(delete_idxs) if idx < keys.num), 0) - return ChildKeys(*(z[mask, ...] for z in [keys.keys, keys.keys_hash, keys.keys_hash_mask, keys.key_lens]), num) - - @staticmethod - def _pad_to_multiple_of(sequence: np.ndarray, chunk_size: int, pad_idx: int = EMPTY): - sequence_pad_len = math.ceil(sequence.size / chunk_size) * chunk_size - return _fast_pad(sequence, sequence_pad_len, 0, pad_idx) - - -def insert_prefix(root: TrieNode, sequence: np.ndarray, ref_vals: list[int], *, chunk_size: int, pad_idx: int = 2**30): - if len(sequence) == 0: - return [], [], [] - sequence = np.array(sequence) - assert sequence.ndim == 1 - sequence_len, sequence = sequence.size, TrieNode._pad_to_multiple_of(sequence, chunk_size, pad_idx=pad_idx) - ns = sequence.shape[-1] // chunk_size - seq_actual_lens = [(chunk_size if i != ns - 1 else (sequence_len - (ns - 1) * chunk_size)) for i in range(ns)] - sequence_chunks = np.split(sequence, ns) - if len(ref_vals) < ns: - msg = f"Pass at least as many references as there are chunks (size={chunk_size}) in the sequence " - msg += f" (size={sequence_len}), so expected at least {ns} references, got {len(ref_vals)=} instead." - raise ValueError(msg) - visited_refs, store_refs, delete_refs = [], [], [] # which refs to retain and which to delete - - # walk the prefix cache tree - with root.lock: - node = root - for seq_idx, (seq, seq_len) in enumerate(zip(sequence_chunks, seq_actual_lens)): - if len(node.children) > 0: - left_match, right_match = TrieNode._overlap(node.child_keys, seq, seq_len, pad_idx=pad_idx) - best_idx = np.argmax(left_match) - left_match, right_match = left_match[best_idx], right_match[best_idx] - else: - left_match, right_match, best_idx = 0, 0, 2**30 # case 0: no children, add new child - if left_match != seq_len: # append new node - node.child_keys = TrieNode._append_key(node.child_keys, seq, seq_len, pad_idx=pad_idx) - node.children.append(TrieNode(int(ref_vals[seq_idx]))) - store_refs.append(int(ref_vals[seq_idx])) - node = node.children[-1] - elif right_match < left_match: # replace the node - delete_refs.append(node.children[best_idx].value) - node.children[best_idx] = TrieNode(int(ref_vals[seq_idx])) - node.child_keys.keys[best_idx, :], node.child_keys.key_lens[best_idx] = seq, seq_len - store_refs.append(int(ref_vals[seq_idx])) - node = node.children[best_idx] - else: # full match, do nothing - if best_idx > len(node.children): - break - visited_refs.append(int(node.children[best_idx].value)) - node = node.children[best_idx] - visited_refs = list(set(visited_refs) | set(store_refs)) - return visited_refs, store_refs, delete_refs - - -def retrieve_prefix(root: TrieNode, sequence: np.ndarray, *, chunk_size: int, pad_idx: int = 2**30): - sequence, total_match, ref_vals = np.array(sequence), 0, [] - assert sequence.ndim == 1 - sequence_len, sequence = sequence.size, TrieNode._pad_to_multiple_of(sequence, chunk_size, pad_idx=pad_idx) - ns = sequence.shape[-1] // chunk_size - seq_actual_lens = [(chunk_size if i != ns - 1 else (sequence_len - (ns - 1) * chunk_size)) for i in range(ns)] - visited_refs = [] - - with root.lock: - node = root - for seq, seq_len in zip(np.split(sequence, ns), seq_actual_lens): - if len(node.children) == 0: # cache ran out of node - return (total_match, ref_vals), visited_refs - left_match, right_match = 
TrieNode._overlap(node.child_keys, seq, seq_len, pad_idx=pad_idx) - exact_match = np.minimum(left_match, right_match) - best_idx = np.argmax(exact_match) - match_length = exact_match[best_idx] - if match_length > 0: - visited_refs.append(int(node.children[best_idx].value)) - if match_length == 0: - break - node = node.children[best_idx] - total_match += int(match_length) - ref_vals.append(node.value) - if match_length != seq_len: - break - return (total_match, ref_vals), visited_refs - -def remove_prefix_nodes(node: TrieNode, refs_to_delete: Sequence[int]): - refs_to_delete, deleted_refs = set(refs_to_delete), set() - ctx = node.lock if node.lock is not None else contextlib.nullcontext() - with ctx: - for child in node.children: - deleted_refs |= remove_prefix_nodes(child, refs_to_delete) - deleted_refs |= set(child.value for child in node.children if child.value in refs_to_delete) - delete_idxs = set([i for i, child in enumerate(node.children) if child.value in refs_to_delete]) - for idx in delete_idxs: # if we're removing a full child, tell it to remove all its children first - deleted_refs |= remove_prefix_nodes(node.children[idx], [c.value for c in node.children[idx].children]) - node.child_keys = TrieNode._delete_keys(node.child_keys, delete_idxs) - node.children = [child for i, child in enumerate(node.children) if i not in delete_idxs] - return set(deleted_refs) - -######################################################################################################################## -# serving loop ######################################################################################################### -######################################################################################################################## - -next_power_of_2 = lambda x: 2 ** round(math.ceil(math.log2(x))) -like_spec = lambda z: jax.tree.map(lambda x: jax.typeof(x).sharding.spec, z) -like_shard = lambda z, mesh: jax.tree.map(lambda x: NamedSharding(mesh, jax.typeof(x).sharding.spec), z) -_make_empty = lambda x, mesh: jax.make_array_from_single_device_arrays( - x.shape, NamedSharding(mesh, jax.typeof(x).sharding.spec), [], dtype=x.dtype -) - - -@dataclasses.dataclass -class ServingConfig: - decode_steps: int = 10 - decode_batch_size: int = 16 - prefill_batch_size: int = 4 - prefix_chunk_size: int = 512 - eos_tokens: tuple[int, ...] 
| jax.Array = () - token_pad_idx: int = 0 - max_decode_length: int = 64 - max_ondevice_buffers: int = 100 - max_buffers: int = 256 - - -@dataclasses.dataclass -class UserRequestPrompt: - id: int - text: str - - -@dataclasses.dataclass -class DecodeResult: - id: int - token_list: list[int] - tokens_decoded: int = 0 - done: bool = False - - -@dataclasses.dataclass -class PrefillJob: - request: UserRequestPrompt - cache_entry: Any - match_len: int - - -@dataclasses.dataclass -class PrefillResult: - id: int - input: np.ndarray - next_token: jax.Array - cache_entry: Any - len: int - - -@dataclasses.dataclass -class DecodeWork: - curr_tokens: jax.Array # [B, 1] to conform with the general forward fn expecting a sequence dimension - cache: KVCache - active_results: list[DecodeResult | None] - - -@dataclasses.dataclass -class PrefillWork: - requests: list[UserRequestPrompt] - to_prefill: list[UserRequestPrompt] - to_decode: list[PrefillResult] - pending_prefill: Future | None = None - pending_cache_retrievals: list[tuple[UserRequestPrompt, Future]] = dataclasses.field(default_factory=list) - - -def return_request(resp: DecodeResult): - # an optional callback called with results available on decode nodes only - # something happens here to output the response to the global queue - # print(f"Finished request: {resp.id}") - pass - - -class SyncServer: - """A regular local network server for syncing between JAX processes in the multi-process JAX setup.""" - - CLIENT = None - TIMEOUT_SEC = 600 - - @staticmethod - def _get_client(): - if SyncServer.CLIENT is None: - SyncServer.CLIENT = distributed.global_state.client - return SyncServer.CLIENT - - @staticmethod - def barrier(key: str, current_it: int) -> None: - client = SyncServer._get_client() - if client is None: - return - client.wait_at_barrier(key + str(current_it), timeout_in_ms=SyncServer.TIMEOUT_SEC * 1000) - - @staticmethod - def broadcast(key: str, current_it: int, value: Any, is_source: bool = False, jsonify: bool = True) -> None: - client = SyncServer._get_client() - if client is None: - return value - if is_source: - client.key_value_set(key + str(current_it), json.dumps(value) if jsonify else value) - return value - else: - value = client.blocking_key_value_get(key + str(current_it), SyncServer.TIMEOUT_SEC * 1000) - return json.loads(value) if jsonify else value - -def maybe_call(fn: Callable, mesh: Mesh): - """Only call the program if the host worker is participating, get (truly) empty arrys with correct sharding.""" - mesh_devices = set(d.id for d in mesh.devices.flat) - if any(d.id in mesh_devices for d in jax.local_devices()): # host has some participating devices - return fn - return (lambda *args, **kw: jax.tree.map(partial(_make_empty, mesh=mesh), jax.eval_shape(fn, *args, **kw))) - - -def _make_multistep_decode_fn(decode_fn): - @partial(jax.jit, static_argnames=("steps",), donate_argnames=("cache",)) - def multistep_decode_fn(curr_tokens, decode_weights, cache, cfg, steps: int = 32): - def body(carry, _): - curr_tokens, cache = carry - next_tokens, cache = decode_fn(curr_tokens, decode_weights, cache, cfg) - return (next_tokens, cache), next_tokens - - (curr_tokens, cache), output_tokens = jax.lax.scan(body, (curr_tokens, cache), length=steps) - return (curr_tokens, cache), output_tokens[..., 0].T - - return multistep_decode_fn - - -class ServingLoop: - def __init__( - self, - serve_cfg: ServingConfig, - cfg: Config, - forward_fn: Callable, - prefill_weights: Weights, - prefill_cache: KVCache, - decode_weights: Weights, - 
decode_cache: KVCache, - is_server: bool = False, - ): - #self.init_cache = init_cache - self.prefill_cache = prefill_cache - if not SyncServer.broadcast("welcome", 0, is_server, is_server): - raise ValueError("Neither this proccess nor any other processe is the main server, at least one must.") - self.serve_cfg, self.cfg = serve_cfg, cfg - - # setup decode - self.forward_fn, self.decode_weights = forward_fn, decode_weights - self.decode_mesh = [x for x in jax.tree.leaves(decode_weights) if hasattr(x, "sharding")][0].sharding.mesh - with set_mesh(self.decode_mesh): - self.decode_work = DecodeWork(None, decode_cache, [None for _ in range(serve_cfg.decode_batch_size)]) - self.decode_work.curr_tokens = jax.device_put( - jnp.zeros((serve_cfg.decode_batch_size, 1), dtype=jnp.int32), P() - ) - self.multistep_decode_fn = _make_multistep_decode_fn(self.forward_fn) - self._update_index = jax.jit(lambda x, i, new: x.at[i, ...].set(new[:, None], mode="drop")) - - def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens, kvs, batch_idxs, actual_lens): - # sort to minimize variants num - length_sort = sorted(range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2]) - sorted_args = [[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)] - new_cache = decode_cache.insert_sequences(cache, *sorted_args) - with set_mesh(self.decode_mesh): - new_curr_tokens = self._update_index(curr_tokens, np.array(batch_idxs), np.array(new_tokens)) - return new_cache, new_curr_tokens - - self._update_cache_and_index = _update_cache_and_index - self.decode_output = (None, None) - - # setup prefill - self.prefill_weights = prefill_weights - self.prefill_mesh = [x for x in jax.tree.leaves(prefill_weights) if hasattr(x, "sharding")][0].sharding.mesh - self.prefill_work = PrefillWork([], [], []) - self._get_index = jax.jit(lambda z, idx: jax.tree.map(lambda x: x[:, idx, ...], z)) - self._get_cache_entry = jax.jit(self.decode_work.cache.get_sequence) - - # setup misc - self.pending_requests, self.state_lock, self.results = [], threading.Lock(), {} - self.pad_id, self.eos_tokens, self.time_axis = 0, np.array(serve_cfg.eos_tokens), TIME_AXIS - self._background = ThreadPoolExecutor(max_workers=1024) - - # setup profiling - self.profile_start_time, self.profiling = -1, False - - # setup cache management - self.prefix_cache, self._retrieve_prefix, self._insert_prefix = None, None, None - self.new_prefix_cache() - - # setup the sync server for multi-host - self._it, self.roles = 0, (("server",) if is_server else ()) # main server - if any(d.id in [d_.id for d_ in self.decode_mesh.devices.reshape(-1)] for d in jax.local_devices()): - self.roles += ("decode",) # any node which has decode mesh devices - if any(d.id in [d_.id for d_ in self.prefill_mesh.devices.reshape(-1)] for d in jax.local_devices()): - self.roles += ("prefill",) # any node which has prefill devices - if any(d.id == min([d_.id for d_ in self.decode_mesh.devices.reshape(-1)]) for d in jax.local_devices()): - self.roles += ("decode_coordinator",) # the decode node which holds the smallest decode mesh device - if any(d.id == min([d_.id for d_ in self.prefill_mesh.devices.reshape(-1)]) for d in jax.local_devices()): - self.roles += ("prefill_coordinator",) # the prefill node which holds the smallest prefill mesh device - self.total_requests = 0 - - def decode_step(self): - # TODO: a more intelligent decision between decode and prefill (adaptive strategies, prefill queue size) - - # 1. 
add outstanding ready to decode prefill result to the active decode - # - some cache entries require some computation, so they're a callable - # - some cache entries are not on the correct decode_mesh - if len(self.prefill_work.to_decode) > 0: - batch_cache_updates = [] - for i, active_result in enumerate(self.decode_work.active_results): - if active_result is not None: - continue - if len(self.prefill_work.to_decode) == 0: - break - result: PrefillResult = self.prefill_work.to_decode.pop(0) - self.decode_work.active_results[i] = DecodeResult(result.id, result.input.tolist()) - with set_mesh(self.decode_mesh): - result.cache_entry = result.cache_entry() if callable(result.cache_entry) else result.cache_entry - self.results[result.id] = self.decode_work.active_results[i] - batch_cache_updates.append((result.cache_entry, i, result.len, result.next_token)) - if len(self.prefill_work.to_decode) == 0: - break - if "decode" in self.roles and len(batch_cache_updates) > 0: # batch cache update - entries, batch_idxs, lens, next_tokens = map(list, zip(*batch_cache_updates)) - entries = [entry.result() if hasattr(entry, "result") else entry for entry in entries] # maybe collect - self.decode_work.cache, self.decode_work.curr_tokens = self._update_cache_and_index( - self.decode_work.cache, self.decode_work.curr_tokens, next_tokens, entries, batch_idxs, lens - ) - - if all(x is None for x in self.decode_work.active_results): - return # skip decoding if no decoding tasks are present - - # 2. run N decode steps - output_tokens, output_mapping = [], [] - with set_mesh(self.decode_mesh): - # config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps, participate="decode" in self.roles) - config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps) - decode_fn = maybe_call(self.multistep_decode_fn, self.decode_mesh) - #decode_fn = self.multistep_decode_fn - (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = decode_fn( - self.decode_work.curr_tokens, self.decode_weights, self.decode_work.cache, **config - ) - output_mapping = [ - [getattr(result, "id", -1) for result in self.decode_work.active_results] - ] * self.serve_cfg.decode_steps - output_mapping = np.array(output_mapping).T - print(f"Decoding with fill rate: {np.mean([result is not None for result in self.decode_work.active_results])}") - - # 3. 
parse output tokens from previous decoding loop to allow for the tokens arrive (delayed EOS detection) - self.decode_output, (output_tokens, output_mapping) = (output_tokens, output_mapping), self.decode_output - if output_tokens is not None: - SyncServer.barrier("output_tokens", self._it) - if "decode" in self.roles: - output_tokens = np.array(output_tokens) - done = np.any(output_tokens[..., None] == self.eos_tokens, (-1, -2)).tolist() # check for done - done = [ - d or getattr(result, "tokens_decoded", 0) >= self.serve_cfg.max_decode_length - for d, result in zip(done, self.decode_work.active_results) - ] - output_tokens_flat = output_tokens.reshape(-1).tolist() - output_mapping_flat = output_mapping.reshape(-1).tolist() - else: - output_tokens, done, output_tokens_flat, output_mapping_flat = None, None, None, None - output_tokens_flat, output_mapping_flat, done = SyncServer.broadcast( - "decode_output", - self._it, - (output_tokens_flat, output_mapping_flat, done), - is_source="decode_coordinator" in self.roles, - ) - for token, id in zip(output_tokens_flat, output_mapping_flat): - if id > 0: - self.results[id].token_list.append(token) - self.results[id].tokens_decoded += 1 - with set_mesh(self.decode_mesh): - for i, result in enumerate(self.decode_work.active_results): - if result is None: - continue - # 2. check for done sequences; evict them if done and return them - if done[i]: - if USE_PREFIX_CACHE: # store the results in the prefix cache buffer store - sequence = np.array(result.token_list) - cache_entry, _ = self._get_cache_entry(self.decode_work.cache, i) - ns = math.ceil(sequence.size / self.serve_cfg.prefix_chunk_size) - buffer_ids = BUFFER_STORE._get_unique_buffer_ids(ns) - visited_ids, store_ids, del_ids = self._insert_prefix(sequence, buffer_ids) - if len(store_ids) > 0: - axis = self.time_axis - 1 + 1 # batch missing (-1) layers concatenated (+1) - chunked_cache_entry = _split(cache_entry, axis, self.serve_cfg.prefix_chunk_size, ns) - vals = [chunked_cache_entry[buffer_ids.index(id)] for id in store_ids] - BUFFER_STORE.store(store_ids, vals) - BUFFER_STORE.delete(del_ids) - BUFFER_STORE.mark_visited(visited_ids) - return_request(result) - result.done, self.decode_work.active_results[i] = True, None - - def prefill_step(self): - # 1. 
prefill requests to be prefilled (do this first to overlap with decode) - prefill_input: list[PrefillJob] = self.prefill_work.to_prefill[: self.serve_cfg.prefill_batch_size] - self.prefill_work.to_prefill = self.prefill_work.to_prefill[len(prefill_input) :] - if len(prefill_input) > 0: - prefill_texts = [job.request.text[job.match_len :] for job in prefill_input] - max_len = max([len(text) for text in prefill_texts]) - inputs = [text + [self.pad_id] * (max_len - len(text)) for text in prefill_texts] - inputs = np.stack([np.array(input) for input in inputs], 0) - row_pad = self.serve_cfg.prefill_batch_size - inputs.shape[0] - col_pad = max(next_power_of_2(inputs.shape[-1]), 64) - inputs.shape[-1] - inputs = np.pad(inputs, ((0, row_pad), (0, col_pad)), mode="constant", constant_values=self.pad_id) - - with set_mesh(self.prefill_mesh): - actual_cache_len = np.array(max(job.match_len for job in prefill_input), dtype=np.int32) - self.prefill_cache.iter = actual_cache_len # TODO: make this explictly cache public interface - kvs = [job.cache_entry() if job.cache_entry is not None else None for job in prefill_input] - batch_idxs = np.array([i for i, kv in enumerate(kvs) if kv is not None]) - actual_lens = np.array([job.match_len for kv, job in zip(kvs, prefill_input) if kv is not None]) - kvs = [kv for kv in kvs if kv is not None] - - if len(kvs) > 0: - # sort to minimize variants num - length_sort = sorted(range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2]) - sorted_args = [[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)] - insert_sequences = maybe_call(self.prefill_cache.insert_sequences, self.prefill_mesh) - self.prefill_cache = insert_sequences(self.prefill_cache, *sorted_args) - cfg = dataclasses.replace(self.cfg, mesh=self.prefill_mesh) - forward_fn = maybe_call(self.forward_fn, self.prefill_mesh) - _, self.prefill_cache = forward_fn(inputs, self.prefill_weights, self.prefill_cache, cfg) - - with set_mesh(self.prefill_mesh): - for i, job in enumerate(prefill_input): - request = job.request - cache_entry, _ = maybe_call(self._get_cache_entry, self.prefill_mesh)(self.prefill_cache, i) - cache_entry = _ensure_all_args_on_mesh(cache_entry, self.decode_mesh) - sequence = np.array(request.text) - new_decode = PrefillResult( - request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1 - ) - self.prefill_work.to_decode.append(new_decode) - - # 2. 
triage requests based on whether they need to go to prefill or there's a cache match, so decode directly - while len(self.prefill_work.requests) > 0: - request = self.prefill_work.requests.pop(0) - sequence = np.array(request.text) - (total_match, buffer_ids), visited_ids = self._retrieve_prefix(sequence) - assert total_match <= sequence.size - BUFFER_STORE.mark_visited(visited_ids) - _axis = self.time_axis - 1 + 1 # batch missing (-1) layers concatenated (+1) - buffers = BUFFER_STORE.load(buffer_ids) - if total_match == sequence.size: - cache_entry = partial(_concat, _ensure_all_args_on_mesh(buffers, mesh=self.decode_mesh), _axis) - new_decode = PrefillResult(request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1) - self.prefill_work.to_decode.append(new_decode) - print(f"Found a full match") - else: - print(f"Need to prefill, only found a match for length {total_match / (len(request.text) - 1):.2%}") - print(f"That equals {len(buffer_ids)} buffers or {total_match=}") - if total_match > 0: - cache_entry = partial(_concat, _ensure_all_args_on_mesh(buffers, mesh=self.prefill_mesh), _axis) - else: - cache_entry = None - self.prefill_work.to_prefill.append(PrefillJob(request, cache_entry, total_match)) - - def serving_step(self): - # this event loop relies on determinism for issuing computation to multiple processes (multi-process JAX) - # frequent barriers should keep it in sync - - # potentially profile when received the request to ######################################### - is_server = "server" in self.roles - should_start_profile = self.profile_start_time > 0 and not self.profiling - should_start_profile = SyncServer.broadcast("profile", self._it, should_start_profile, is_source=is_server) - if should_start_profile: - self.profile_start_time, self.profiling = time.perf_counter(), True - jax.profiler.start_trace("/tmp/online") - print("STARTING TRACE") - should_stop_profile = self.profile_start_time > 0 and time.perf_counter() - self.profile_start_time > 5.0 - should_stop_profile = SyncServer.broadcast("stop_profile", self._it, should_stop_profile, is_source=is_server) - if should_stop_profile: - self.profile_start_time, self.profiling = -1, False - print("STOPPING TRACE") - jax.profiler.stop_trace() - # potentially profile when received the request to ######################################### - - # sync on the server requests received ##################################################### - SyncServer.barrier("serving_step", self._it) - self._it, requests = self._it + 1, None - if "server" in self.roles: - with self.state_lock: - self.pending_requests, requests = [], list(self.pending_requests) - serve_cfg, requests = SyncServer.broadcast( - "requests", self._it, (dataclasses.asdict(self.serve_cfg), requests), is_source="server" in self.roles - ) - with self.state_lock: - self.serve_cfg = dataclasses.replace(self.serve_cfg, **serve_cfg) - for request in requests: - self.total_requests += 1 - self.prefill_work.requests.append(UserRequestPrompt(**request)) - # sync on the server requests received ##################################################### - - # main event loop work ##################################################################### - self.decode_step() - self.prefill_step() - # main event loop work ##################################################################### - - # offload buffers to keep a max of N ####################################################### - BUFFER_STORE.offload_buffers(max(0, BUFFER_STORE.livecount - 
self.serve_cfg.max_ondevice_buffers)) - extra_buffer_count = max(len(BUFFER_STORE.usecount) - self.serve_cfg.max_buffers, 0) - if extra_buffer_count > 0: - refs_to_delete = sorted(BUFFER_STORE.usecount.keys())[:extra_buffer_count] - deleted_buffers = remove_prefix_nodes(self.prefix_cache, refs_to_delete) - BUFFER_STORE.delete(list(deleted_buffers)) - if len(BUFFER_STORE._store) > self.serve_cfg.max_buffers: - raise ValueError() - # offload buffers to keep a max of N ####################################################### - - def add_request(self, request: UserRequestPrompt): - with self.state_lock: - self.pending_requests.append(dataclasses.asdict(request)) - - def update_params(self, params: dict[str, Any]): - with self.state_lock: - self.serve_cfg = dataclasses.replace(self.serve_cfg, **params) - - def new_prefix_cache(self): - self.prefix_cache = TrieNode(None, lock=threading.Lock()) - self._retrieve_prefix = partial(retrieve_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) - self._insert_prefix = partial(insert_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) +from .http_server import run_http_server +from .serving_loop import DecodeResult, ServingConfig, ServingLoop, UserRequestPrompt diff --git a/serving/serving_jax/attention_cache_utils.py b/serving/serving_jax/attention_cache_utils.py index 854be86..ba40d48 100644 --- a/serving/serving_jax/attention_cache_utils.py +++ b/serving/serving_jax/attention_cache_utils.py @@ -1,25 +1,31 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import dataclasses -from functools import partial import math +from functools import partial from typing import Any import jax import jax.numpy as jnp -QuantArray, PyTree = Any, Any +QuantArray, PyTree, KVCache, PagedKVCache = Any, Any, Any, Any -KVCache = Any next_power_of_2 = lambda x: 2 ** math.ceil(math.log2(max(x, 1))) _pad_after = lambda x, l, axis: jnp.pad(x, [(0, 0) if i != axis else (0, l - x.shape[i]) for i in range(x.ndim)]) -def safe_zip(*args): - if len(args) == 0: - return [] - assert all(len(arg) == len(args[0]) for arg in args) - return zip(*args) - - def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int): "From a list of cache entries stacked along layer idx (in transit) to stacked along batch, layers split into list." 
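NOTE (editorial, not part of the patch): in the insert wrappers below, the batch of sequences handed to the
jitted update is padded up to max(next_power_of_2(n), 4) entries by repeating the last sequence and masking the
copies out via update_mask, so jax.jit compiles only a logarithmic number of batch-size variants. The padding
arithmetic, for reference:

    import math

    next_power_of_2 = lambda x: 2 ** math.ceil(math.log2(max(x, 1)))
    for n in [1, 2, 3, 5, 9]:
        padded = max(next_power_of_2(n), 4)
        print(f"{n} sequences -> compiled batch of {padded} ({padded - n} masked copies)")
    # 1 -> 4, 2 -> 4, 3 -> 4, 5 -> 8, 9 -> 16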
@@ -30,7 +36,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int): for i, c in enumerate(kv_list[0]): els = [[_split(z) for z in jax.tree.leaves(kv[i])] for kv in kv_list] # [B, R_flat, L] els = jax.tree.map(lambda *xs: jnp.concatenate(xs, axis=0), *els) # [R_flat, L] - leaves_list = list(safe_zip(*els)) # [L, R_flat] + leaves_list = list(zip(*els, strict=True)) # [L, R_flat] out[i] = [jax.tree.unflatten(jax.tree.structure(c), leaves) for leaves in leaves_list] # [L, R] return tuple(out), max_seq_len @@ -41,16 +47,17 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int): @partial(jax.jit, donate_argnames=("cache",)) -def _kvcache_update_cache( +def _kvcache_insert_sequences( cache: KVCache, kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], actual_lens: list[jax.Array], update_mask: list[bool] | None = None, + erase: bool = False, ): assert len(kvs) == len(batch_idxs) == len(actual_lens) batch_idxs, actual_lens, update_mask = jnp.array(batch_idxs), jnp.array(actual_lens), jnp.array(update_mask) - uninitialized_cache = cache.iter < 0 + uninitialized_cache = jnp.logical_or(cache.iter < 0, erase) start_time = jnp.where( uninitialized_cache, jnp.max(actual_lens) - actual_lens, (cache.iter - actual_lens) % cache.size ) @@ -62,31 +69,43 @@ def _update_element(x, u): update_permute = [0, cache.time_axis] + [i for i in range(u.ndim) if i not in (0, cache.time_axis)] # time_dim, batch_dim = update_permute.pop(cache.time_axis), update_permute.pop(0) # first pop time_axis # update_permute = [batch_dim, time_dim] + update_permute - return x.at[batch_idxs[:, None], :, time_indices, ...].set(u.transpose(update_permute), mode="drop") + return x.at[batch_idxs[:, None], :, time_indices, ...].set( + u.transpose(update_permute), mode="drop", out_sharding=jax.typeof(x).sharding + ) cache_kvs = jax.tree.map(_update_element, cache.buffers, kvs) - cache_starts = cache.starts.at[batch_idxs].set(start_time, mode="drop") + cache_starts = cache.starts.at[batch_idxs].set( + start_time, mode="drop", out_sharding=jax.typeof(cache.starts).sharding + ) cache_iter = jnp.where(uninitialized_cache, jnp.max(actual_lens), cache.iter) - buffer_names = [field.name for field in dataclasses.fields(cache)][:len(cache_kvs)] - return dataclasses.replace(cache, **dict(safe_zip(buffer_names, cache_kvs)), iter=cache_iter, starts=cache_starts) + buffer_names = [field.name for field in dataclasses.fields(cache)][: len(cache_kvs)] + return dataclasses.replace(cache, **dict(zip(buffer_names, cache_kvs, strict=True)), iter=cache_iter, starts=cache_starts) + + +@partial(jax.jit, donate_argnames=("cache",)) +def _maybe_erase_only(cache: KVCache, erase: bool): + return dataclasses.replace(cache, iter=jnp.where(erase, -1, cache.iter)) -def kvcache_update_cache( +def kvcache_insert_sequences( cache: KVCache, kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], actual_lens: list[jax.Array], + erase: bool = False, ): + if len(kvs) == 0: + return _maybe_erase_only(cache, erase) pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs) # an update of power of 2 and at least 4 update_mask = [i < len(kvs) for i in range(len(kvs) + pad_len)] kvs = kvs + [kvs[-1]] * pad_len batch_idxs, actual_lens = batch_idxs + [batch_idxs[-1]] * pad_len, actual_lens + [actual_lens[-1]] * pad_len - return _kvcache_update_cache(cache, kvs, batch_idxs, actual_lens, update_mask) + return _kvcache_insert_sequences(cache, kvs, batch_idxs, actual_lens, update_mask, erase=erase) @jax.jit -def 
kvcache_get_entry(cache: KVCache, batch_idx: jax.Array): +def kvcache_get_sequence(cache: KVCache, batch_idx: jax.Array): shift = -cache.starts[batch_idx] assert cache.time_axis > 0 kvs = jax.tree.map(lambda x: jnp.roll(x[batch_idx, ...], shift=shift, axis=cache.time_axis - 1), cache.buffers) @@ -114,7 +133,7 @@ def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array | @partial(jax.jit, donate_argnames=("cache",)) -def _batch_paged_update_sequences( +def _paged_kvcache_insert_sequences( cache: PagedKVCache, kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], @@ -161,27 +180,29 @@ def _update_element(x, u): new_free_pages = new_free_pages.at[pages_idx.reshape(-1)].set(False, mode="drop") new_lengths = cache.lengths.at[batch_idxs].set(actual_lens, mode="drop") - named_buffers = dict(zip([field.name for field in dataclasses.fields(cache)][:len(new_buffers)], new_buffers)) + named_buffers = dict(zip([field.name for field in dataclasses.fields(cache)][: len(new_buffers)], new_buffers)) return dataclasses.replace( cache, **named_buffers, lengths=new_lengths, block_tables=new_block_tables, free_pages=new_free_pages ) -def batch_paged_update_sequences( +def paged_kvcache_insert_sequences( cache: KVCache, kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], actual_lens: list[jax.Array], + erase: bool = False, ): + del erase # inapplicable pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs) # an update of power of 2 and at least 4 update_mask = [i < len(kvs) for i in range(len(kvs) + pad_len)] kvs = kvs + [kvs[-1]] * pad_len batch_idxs, actual_lens = batch_idxs + [batch_idxs[-1]] * pad_len, actual_lens + [actual_lens[-1]] * pad_len - return _batch_paged_update_sequences(cache, kvs, batch_idxs, actual_lens, update_mask) + return _paged_kvcache_insert_sequences(cache, kvs, batch_idxs, actual_lens, update_mask) @partial(jax.jit, static_argnames=("max_seq_len",)) -def batch_paged_get_entry(cache: PagedKVCache, batch_idx: jax.Array, max_seq_len: int = -1): +def paged_kvcache_get_sequence(cache: PagedKVCache, batch_idx: jax.Array, max_seq_len: int = -1): true_len = cache.fill_len()[batch_idx] max_seq_len = max_seq_len if max_seq_len > 0 else cache.page_size * cache.block_tables.shape[-1] max_seq_len = min(max_seq_len, cache.page_size * cache.block_tables.shape[-1]) # cache capacity diff --git a/serving/serving_jax/cross_host.py b/serving/serving_jax/cross_host.py index c9ee95e..498ad7c 100644 --- a/serving/serving_jax/cross_host.py +++ b/serving/serving_jax/cross_host.py @@ -5,10 +5,11 @@ import jax import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P import numpy as np +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P -#jax.config.update("jax_enable_empty_arrays", True) +# jax.config.update("jax_enable_empty_arrays", True) PyTree = Any diff --git a/serving/serving_jax/http_server.py b/serving/serving_jax/http_server.py new file mode 100644 index 0000000..d4622ae --- /dev/null +++ b/serving/serving_jax/http_server.py @@ -0,0 +1,115 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import threading
+import time
+from contextlib import asynccontextmanager
+from typing import Any, AsyncGenerator, Callable
+
+import uvicorn
+from fastapi import FastAPI, Request
+from fastapi.responses import Response, StreamingResponse
+from pydantic import BaseModel
+
+from serving_jax import serving_loop
+
+
+class GenerateRequest(BaseModel):
+    id: int
+    text: str
+
+
+def run_http_server(
+    serve_loop: serving_loop.ServingLoop,
+    tokenizer_encode: Callable[[str], list[int]],
+    tokenizer_decode: Callable[[list[int]], str],
+    is_server: bool = False,
+    shutdown_signal: threading.Event | None = None,
+) -> None:
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        yield
+        if shutdown_signal is not None:
+            shutdown_signal.set()
+
+    APP = FastAPI(lifespan=lifespan)
+
+    async def generate_generator(id: int, text: str, request: Request) -> AsyncGenerator[str, None]:
+        if id in serve_loop.results:  # delete the previous request if it exists
+            del serve_loop.results[id]
+
+        input = tokenizer_encode(text)
+        iter = len(input)  # iterator for finding our current place in an append-only output text
+        serve_loop.add_request(serving_loop.UserRequestPrompt(id, input))
+        while id not in serve_loop.results:  # wait for the request to be prefilled
+            await asyncio.sleep(0.1)
+        try:
+            result_ref: serving_loop.DecodeResult = serve_loop.results[id]
+            while not result_ref.done:  # return text to the client as it becomes available
+                if await request.is_disconnected():  # check if the client disconnected
+                    print("Client disconnected.")
+                    break
+                if len(result_ref.token_list) > iter:
+                    new_segment, iter = tokenizer_decode(result_ref.token_list[iter:]), len(result_ref.token_list)
+                    yield f"{new_segment}"
+                await asyncio.sleep(0.1)
+
+            # return the final piece of generated text to the client
+            if len(result_ref.token_list) > iter:
+                new_segment, iter = tokenizer_decode(result_ref.token_list[iter:]), len(result_ref.token_list)
+                yield f"{new_segment}"
+        except asyncio.CancelledError:
+            pass
+
+    @APP.get("/stream")
+    async def stream_response(params: GenerateRequest, request: Request):
+        return StreamingResponse(generate_generator(params.id, params.text, request), media_type="text/event-stream")
+
+    @APP.get("/generate")
+    async def generate(id: int, text: str):  # generate without streaming the output
+        print(f"Input text: {text}")
+        serve_loop.add_request(serving_loop.UserRequestPrompt(id, tokenizer_encode(text)))
+        return Response("OK")
+
+    @APP.get("/retrieve")
+    async def retrieve(id: int):
+        if id in serve_loop.results:
+            return Response(tokenizer_decode(serve_loop.results[id].token_list))
+        return Response("NO TEXT")
+
+    @APP.get("/set_generation_length")
+    async def set_generation_length(length: int):
+        serve_loop.serve_cfg.max_decode_length = max(length, 32)
+        return Response("OK")
+
+    @APP.get("/profile")
+    async def profile(request: Request):
+        del request
+        serve_loop.profile_start_time = time.perf_counter()
+        return Response("OK")
+
+    @APP.get("/")
+    async def root():
+        return {"message": "Welcome! Try the /stream endpoint."}
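+
+    # Hedged client sketch (not part of the server): one way to consume /stream, mirroring
+    # client_demo.py; the payload shape and port are illustrative assumptions.
+    #   import requests
+    #   payload = {"id": 1, "text": "hello"}
+    #   with requests.get("http://localhost:8081/stream", json=payload, stream=True) as r:
+    #       for chunk in r.iter_content(chunk_size=None):
+    #           print(chunk.decode(), end="", flush=True)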
+
+    if is_server:
+        uvicorn.run(APP, host="0.0.0.0", port=8081, reload=False, server_header=False)
+    else:
+        try:
+            while True:
+                time.sleep(1)
+        except KeyboardInterrupt:
+            if shutdown_signal is not None:
+                shutdown_signal.set()
diff --git a/serving/serving_jax/serving_loop.py b/serving/serving_jax/serving_loop.py
new file mode 100644
index 0000000..46d4765
--- /dev/null
+++ b/serving/serving_jax/serving_loop.py
@@ -0,0 +1,814 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import dataclasses
+import json
+import math
+import threading
+import time
+from concurrent.futures import Future, ThreadPoolExecutor
+from functools import partial
+from typing import Any, Callable, Sequence
+
+import jax
+import jax.numpy as jnp
+from jax.sharding import Mesh, NamedSharding, set_mesh
+from jax.sharding import PartitionSpec as P
+
+try:
+    from jax.experimental.shard import auto_axes
+except ModuleNotFoundError:
+    from jax.sharding import auto_axes
+import numpy as np
+from jax._src import distributed
+from jax._src.lib import xla_client as xc
+
+from .cross_host import transfer_tree_A2B
+
+KVCache, Weights, Config = Any, Any, Any
+PyTree, PyTreeStruct = Any, Any
+
+is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__)
+
+########################################################################################################################
+# device put for cross-process/hosts transfers #########################################################################
+########################################################################################################################
+
+
+def unsafe_device_put(xs: PyTree, spec: PyTree, dest_mesh: Mesh):
+    """Fastest, but only works with local single-process JAX for now."""
+    xs_flat, xs_struct = jax.tree.flatten(xs)
+    shardings_list = [NamedSharding(dest_mesh, s) for s in jax.tree.leaves(spec)]
+    devices_list = [s._internal_device_list for s in shardings_list]
+    copy_semantics = [xc.ArrayCopySemantics.ALWAYS_COPY] * len(devices_list)
+    out = xc.batched_copy_array_to_devices_with_sharding(xs_flat, devices_list, shardings_list, copy_semantics)
+    return jax.tree.unflatten(xs_struct, out)
+
+
+def jax_device_put(xs: PyTree, sharding: PyTree):
+    """Async, available in future JAX."""
+    is_source = len(getattr(jax.tree.leaves(xs)[0], "addressable_shards", [])) > 0
+    if is_source:
+        return jax.device_put(xs, sharding)
+    else:
+        empty_arrays = jax.tree.map(
+            lambda x: jax.make_array_from_single_device_arrays(x.shape, x.sharding, [], dtype=x.dtype), xs
+        )
+        return jax.device_put(empty_arrays, sharding)
+
+
+def jit_device_put(xs: PyTree, sharding: PyTree):
+    """Most compatible; uses jit, so it requires blocking dispatch."""
+    # jax.sharding.set_mesh(None)  # not compatible with context mesh
+    meshA, meshB = jax.tree.leaves(xs)[0].sharding.mesh, jax.tree.leaves(sharding)[0].mesh
+    return transfer_tree_A2B(xs, meshA, meshB)
+
+
+# device_put 
= jit_device_put # the most compatible options currently, but NOT async, need +device_put = jax.device_put + + +def _ensure_all_args_on_mesh(args, mesh: Mesh): + if not all(jax.tree.leaves(arg)[0].sharding.mesh == mesh for arg in args): + _correct_mesh = lambda value: jax.tree.leaves(value)[0].sharding.mesh == mesh + _args = {i: arg for i, arg in enumerate(args) if not _correct_mesh(arg)} + if len(_args) > 0: + args = dict(enumerate(args)) | device_put(_args, like_shard(_args, mesh)) + args = tuple(args[i] for i in range(len(args))) + return args + + +######################################################################################################################## +# kv cache buffer management ########################################################################################### +######################################################################################################################## + + +@partial(jax.jit, static_argnames=("axis", "chunk_size", "ns")) +def _split(val: jax.Array | list[jax.Array], axis: int, chunk_size: int, ns: int) -> list[jax.Array]: + def _fn(val): + axis_ = axis % val.ndim + size = val.shape[axis_] + if size < chunk_size * ns: + min_len = chunk_size * ns + val = jnp.pad(val, [(0, 0) if i != axis_ else (0, min_len - val.shape[axis_]) for i in range(val.ndim)]) + index = [slice(None) if i != axis_ else slice(0, ns * chunk_size) for i in range(val.ndim)] + return jnp.split(val[*index], ns, axis=axis_)[:ns] + + val_leaves, val_structure = jax.tree.flatten(val) + spec = [[x] * ns for x in like_spec(val_leaves)] + split_leaves = auto_axes(lambda vals: [_fn(val) for val in vals], out_sharding=spec)(val_leaves) + return [jax.tree.unflatten(val_structure, [x[i] for x in split_leaves]) for i in range(ns)] + + +@partial(jax.jit, static_argnames=("split_axis",)) +def _concat(values, split_axis: int): + _fn = lambda vals: jax.tree.map(lambda *args: jnp.concatenate(args, axis=split_axis), *vals) + return auto_axes(_fn, out_sharding=like_spec(values[0]))(values) + + +class KVBufferStore: + def __init__(self): + self.usecount, self.ondevice, self._store, self.unique_id, self.livecount = {}, {}, {}, 18, 0 + + def _get_unique_buffer_ids(self, n: int): + ids = list(range(self.unique_id, self.unique_id + n)) + self.unique_id += n + return ids + + def offload_buffers(self, how_many: int): + if how_many == 0: + return + candidates = sorted(self._store.keys(), key=lambda i: self.usecount[i] if self.ondevice[i] else 2**60) + for i in candidates[:how_many]: + if self.ondevice[i]: + host_shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("pinned_host"), self._store[i]) + self._store[i] = jax.device_put(self._store[i], host_shrd) + self.ondevice[i] = False + self.livecount -= 1 + + def load(self, id: int): + if isinstance(id, (tuple, list)): + return [self.load(i) for i in id] + if self.ondevice[id]: + return self._store[id] + self.ondevice[id] = True + self.livecount += 1 + device_shrd = jax.tree.map(lambda x: x.sharding.with_memory_kind("device"), self._store[id]) + self._store[id] = jax.device_put(self._store[id], device_shrd) + return self._store[id] + + def delete(self, id: int): + if isinstance(id, (list, tuple)): + return [self.delete(i) for i in id] + self.livecount -= self.ondevice[id] + del self.usecount[id], self.ondevice[id], self._store[id] + + def store(self, id: int, val: Any): + if isinstance(id, (tuple, list)): + return [self.store(i, v) for i, v in zip(id, val)] + self.livecount += 1 + self.usecount[id], self.ondevice[id], self._store[id] = 1, 
True, val + + def mark_visited(self, id: int): + if isinstance(id, (list, tuple)): + return [self.mark_visited(i) for i in id] + self.usecount[id] += 1 + + +BUFFER_STORE = KVBufferStore() + +######################################################################################################################## +# trie utils ########################################################################################################### +######################################################################################################################## + +EMPTY, HASH_BITWIDTH = -1, 1 + + +@dataclasses.dataclass +class ChildKeys: + keys: np.ndarray + keys_hash: np.ndarray + keys_hash_mask: np.ndarray + key_lens: np.ndarray + num: int = 0 + + +def _hash_encode(v: np.ndarray, hash_bitwidth: int = HASH_BITWIDTH, pad_idx: int = EMPTY): + v, last_dim = v.astype(np.int64), min(64 // hash_bitwidth, v.shape[-1]) + v_, el_mask = v.reshape(v.shape[:-1] + (-1, last_dim)), (1 << hash_bitwidth) - 1 + mask = np.bitwise_or.reduce(((v_ != pad_idx) * el_mask) << (hash_bitwidth * np.arange(v_.shape[-1])), axis=-1) + h = np.bitwise_or.reduce((v_ & el_mask) << (hash_bitwidth * np.arange(v_.shape[-1])), axis=-1) + return h, mask + + +def _prefilter_on_hash( + w: np.ndarray, + keys: np.ndarray, + vh: np.ndarray, + vm: np.ndarray, + hash_bitwidth: int = HASH_BITWIDTH, + pad_idx: int = EMPTY, +): + wh, wm = _hash_encode(w, hash_bitwidth=hash_bitwidth, pad_idx=pad_idx) + inv_match = (wh ^ vh) & vm & wm + # count full hash chunk matches, but don't miss sequences not matching at least one full hash + match_len = np.sum(np.cumsum(inv_match, axis=-1) == 0, axis=-1) + (w[0] == keys[:, 0]) + max_match_len = max(np.max(match_len), 1) + return np.where(match_len == max_match_len)[0] + + +def _fast_pad(x, size, axis, pad_val=0): + new_buf = pad_val * np.ones([size - s if i == axis else s for i, s in enumerate(x.shape)], dtype=x.dtype) + return np.concat([x, new_buf], axis) + + +@dataclasses.dataclass +class TrieNode: + value: int + children: list["TrieNode"] = dataclasses.field(default_factory=list) + child_keys: ChildKeys | None = None + lock: "threading.Lock | None" = None + usage: int = 1 + + def __repr__(self, indent: int = 0): + lines = [f"TrieNode(value={self.value}, usage={self.usage}, children={{"] + if len(self.children) == 0: + lines[-1] = lines[-1][:-1] + "})" + else: + for i, child in enumerate(self.children): + child_key = self.child_keys.keys[i, : self.child_keys.key_lens[i]].tolist() + lines.append(f"{' ' * indent} {child_key}: {child.__repr__(indent + 2).strip()},") + lines.append(")") + return "\n".join([(" " * indent) + line for line in lines]) + + @staticmethod + def _overlap(child_keys: ChildKeys, key, key_len, pad_idx: int = EMPTY): + keys = child_keys.keys[: child_keys.num, :] + keys_hash = child_keys.keys_hash[: child_keys.num, :] + keys_hash_mask = child_keys.keys_hash_mask[: child_keys.num, :] + + # pre-filter sequences + relevant_idx = _prefilter_on_hash(key, keys, keys_hash, keys_hash_mask, pad_idx=pad_idx) + if len(relevant_idx) == 0: + return np.zeros((child_keys.num,), dtype=np.int32), np.zeros((child_keys.num,), dtype=np.int32) + keys = keys[relevant_idx, :] + + mask = np.cumsum((key == keys) | (key == pad_idx) | (keys == pad_idx), -1) == np.arange(1, key.shape[-1] + 1) + overlap = np.zeros((child_keys.num,), dtype=np.int32) + overlap[relevant_idx] = np.sum(mask, axis=-1) + return np.minimum(overlap, key_len), np.minimum(overlap, child_keys.key_lens[: child_keys.num]) + + @staticmethod + def 
_append_key(keys: ChildKeys | None, new_key: np.ndarray, key_len: int, pad_idx: int = EMPTY): + if keys is None: + key_hash, key_hash_mask = _hash_encode(new_key[None, :], pad_idx=pad_idx) + return ChildKeys(new_key[None, :], key_hash, key_hash_mask, np.array([key_len], dtype=np.int32), 1) + if keys.num == keys.keys.shape[0]: # need to double the keys buffer + keys.keys = _fast_pad(keys.keys, 2 * keys.num, 0, 0) + keys.key_lens = _fast_pad(keys.key_lens, 2 * keys.num, 0) + keys.keys_hash = _fast_pad(keys.keys_hash, 2 * keys.num, 0, 0) + keys.keys_hash_mask = _fast_pad(keys.keys_hash_mask, 2 * keys.num, 0, 0) + keys.keys[keys.num, :], keys.key_lens[keys.num] = new_key, key_len + keys.keys_hash[keys.num, :], keys.keys_hash_mask[keys.num, :] = _hash_encode(new_key, pad_idx=pad_idx) + keys.num += 1 + return keys + + @staticmethod + def _delete_keys(keys: ChildKeys, delete_idxs: np.ndarray): + if keys is None: + return + mask = np.ones(keys.keys.shape[0], dtype=bool) + mask[np.array(list(delete_idxs) if isinstance(delete_idxs, set) else delete_idxs, int)] = False + if np.sum(mask) == 0: + return None + num = max(keys.num - sum(1 for idx in set(delete_idxs) if idx < keys.num), 0) + return ChildKeys(*(z[mask, ...] for z in [keys.keys, keys.keys_hash, keys.keys_hash_mask, keys.key_lens]), num) + + @staticmethod + def _pad_to_multiple_of(sequence: np.ndarray, chunk_size: int, pad_idx: int = EMPTY): + sequence_pad_len = math.ceil(sequence.size / chunk_size) * chunk_size + return _fast_pad(sequence, sequence_pad_len, 0, pad_idx) + + +def insert_prefix(root: TrieNode, sequence: np.ndarray, ref_vals: list[int], *, chunk_size: int, pad_idx: int = 2**30): + if len(sequence) == 0: + return [], [], [] + sequence = np.array(sequence) + assert sequence.ndim == 1 + sequence_len, sequence = sequence.size, TrieNode._pad_to_multiple_of(sequence, chunk_size, pad_idx=pad_idx) + ns = sequence.shape[-1] // chunk_size + seq_actual_lens = [(chunk_size if i != ns - 1 else (sequence_len - (ns - 1) * chunk_size)) for i in range(ns)] + sequence_chunks = np.split(sequence, ns) + if len(ref_vals) < ns: + msg = f"Pass at least as many references as there are chunks (size={chunk_size}) in the sequence " + msg += f" (size={sequence_len}), so expected at least {ns} references, got {len(ref_vals)=} instead." 
+        raise ValueError(msg)
+    visited_refs, store_refs, delete_refs = [], [], []  # which refs to retain and which to delete
+
+    # walk the prefix cache tree
+    with root.lock:
+        node = root
+        for seq_idx, (seq, seq_len) in enumerate(zip(sequence_chunks, seq_actual_lens)):
+            if len(node.children) > 0:
+                left_match, right_match = TrieNode._overlap(node.child_keys, seq, seq_len, pad_idx=pad_idx)
+                best_idx = np.argmax(left_match)
+                left_match, right_match = left_match[best_idx], right_match[best_idx]
+            else:
+                left_match, right_match, best_idx = 0, 0, 2**30  # case 0: no children, add a new child
+            if left_match != seq_len:  # append a new node
+                node.child_keys = TrieNode._append_key(node.child_keys, seq, seq_len, pad_idx=pad_idx)
+                node.children.append(TrieNode(int(ref_vals[seq_idx])))
+                store_refs.append(int(ref_vals[seq_idx]))
+                node = node.children[-1]
+            elif right_match < left_match:  # replace the node
+                delete_refs.append(node.children[best_idx].value)
+                node.children[best_idx] = TrieNode(int(ref_vals[seq_idx]))
+                node.child_keys.keys[best_idx, :], node.child_keys.key_lens[best_idx] = seq, seq_len
+                store_refs.append(int(ref_vals[seq_idx]))
+                node = node.children[best_idx]
+            else:  # full match, do nothing
+                if best_idx > len(node.children):
+                    break
+                visited_refs.append(int(node.children[best_idx].value))
+                node = node.children[best_idx]
+    visited_refs = list(set(visited_refs) | set(store_refs))
+    return visited_refs, store_refs, delete_refs
+
+
+def retrieve_prefix(root: TrieNode, sequence: np.ndarray, *, chunk_size: int, pad_idx: int = 2**30):
+    sequence, total_match, ref_vals = np.array(sequence), 0, []
+    assert sequence.ndim == 1
+    sequence_len, sequence = sequence.size, TrieNode._pad_to_multiple_of(sequence, chunk_size, pad_idx=pad_idx)
+    ns = sequence.shape[-1] // chunk_size
+    seq_actual_lens = [(chunk_size if i != ns - 1 else (sequence_len - (ns - 1) * chunk_size)) for i in range(ns)]
+    visited_refs = []
+
+    with root.lock:
+        node = root
+        for seq, seq_len in zip(np.split(sequence, ns), seq_actual_lens):
+            if len(node.children) == 0:  # the cache ran out of nodes
+                return (total_match, ref_vals), visited_refs
+            left_match, right_match = TrieNode._overlap(node.child_keys, seq, seq_len, pad_idx=pad_idx)
+            exact_match = np.minimum(left_match, right_match)
+            best_idx = np.argmax(exact_match)
+            match_length = exact_match[best_idx]
+            if match_length > 0:
+                visited_refs.append(int(node.children[best_idx].value))
+            if match_length == 0:
+                break
+            node = node.children[best_idx]
+            total_match += int(match_length)
+            ref_vals.append(node.value)
+            if match_length != seq_len:
+                break
+    return (total_match, ref_vals), visited_refs
+
+
+def remove_prefix_nodes(node: TrieNode, refs_to_delete: Sequence[int]):
+    refs_to_delete, deleted_refs = set(refs_to_delete), set()
+    ctx = node.lock if node.lock is not None else contextlib.nullcontext()
+    with ctx:
+        for child in node.children:
+            deleted_refs |= remove_prefix_nodes(child, refs_to_delete)
+        deleted_refs |= set(child.value for child in node.children if child.value in refs_to_delete)
+        delete_idxs = set([i for i, child in enumerate(node.children) if child.value in refs_to_delete])
+        for idx in delete_idxs:  # if we're removing a full child, tell it to remove all its children first
+            deleted_refs |= remove_prefix_nodes(node.children[idx], [c.value for c in node.children[idx].children])
+        node.child_keys = TrieNode._delete_keys(node.child_keys, delete_idxs)
+        node.children = [child for i, child in enumerate(node.children) if i not in delete_idxs]
+    return set(deleted_refs)
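+
+
+# Hedged usage sketch of the trie API above (values are illustrative; chunk_size must match
+# the chunking used when the corresponding KV buffers were stored):
+#   root = TrieNode(None, lock=threading.Lock())
+#   tokens = np.arange(128)
+#   refs = list(range(2))  # one buffer reference per 64-token chunk
+#   visited, stored, deleted = insert_prefix(root, tokens, refs, chunk_size=64)
+#   (match_len, refs_found), visited = retrieve_prefix(root, tokens, chunk_size=64)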
+
+
+########################################################################################################################
+# worker sync server ###################################################################################################
+########################################################################################################################
+
+
+class SyncServer:
+    """A regular local network server for syncing between JAX processes in the multi-process JAX setup."""
+
+    CLIENT = None
+    TIMEOUT_SEC = 600
+
+    @staticmethod
+    def _get_client():
+        if SyncServer.CLIENT is None:
+            SyncServer.CLIENT = distributed.global_state.client
+        return SyncServer.CLIENT
+
+    @staticmethod
+    def barrier(key: str, current_it: int) -> None:
+        client = SyncServer._get_client()
+        if client is None:
+            return
+        client.wait_at_barrier(key + str(current_it), timeout_in_ms=SyncServer.TIMEOUT_SEC * 1000)
+
+    @staticmethod
+    def broadcast(key: str, current_it: int, value: Any, is_source: bool = False, jsonify: bool = True) -> None:
+        client = SyncServer._get_client()
+        if client is None:
+            return value
+        if is_source:
+            client.key_value_set(key + str(current_it), json.dumps(value) if jsonify else value)
+            return value
+        else:
+            value = client.blocking_key_value_get(key + str(current_it), SyncServer.TIMEOUT_SEC * 1000)
+            return json.loads(value) if jsonify else value
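+
+
+# Hedged usage sketch: every process calls broadcast with the same key and iteration counter;
+# only the coordinator passes is_source=True, and the remaining processes block until the value is set.
+#   value = SyncServer.broadcast("requests", it, value, is_source=is_coordinator)
+#   SyncServer.barrier("step_done", it)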
+
+
+########################################################################################################################
+# serving data structures ##############################################################################################
+########################################################################################################################
+
+
+@dataclasses.dataclass
+class ServingConfig:
+    decode_steps: int = 10
+    decode_batch_size: int = 16
+    prefill_batch_size: int = 4
+    prefix_chunk_size: int = 512
+    eos_tokens: tuple[int, ...] | jax.Array = ()
+    token_pad_idx: int = 0
+    max_decode_length: int = 64
+    max_ondevice_buffers: int = 100
+    max_buffers: int = 256
+    use_prefix_cache: bool = True
+    time_axis: int = 2
+
+
+@dataclasses.dataclass
+class UserRequestPrompt:
+    id: int
+    text: list[int]  # token ids; the request text is tokenized before it reaches the loop
+
+
+@dataclasses.dataclass
+class DecodeResult:
+    id: int
+    token_list: list[int]
+    tokens_decoded: int = 0
+    done: bool = False
+
+
+@dataclasses.dataclass
+class PrefillJob:
+    request: UserRequestPrompt
+    cache_entry: Any
+    match_len: int
+
+
+@dataclasses.dataclass
+class PrefillResult:
+    id: int
+    input: np.ndarray
+    next_token: jax.Array
+    cache_entry: Any
+    len: int
+
+
+@dataclasses.dataclass
+class DecodeWork:
+    curr_tokens: jax.Array  # [B, 1] to conform with the general forward fn expecting a sequence dimension
+    cache: KVCache
+    active_results: list[DecodeResult | None]
+
+
+@dataclasses.dataclass
+class PrefillWork:
+    requests: list[UserRequestPrompt]
+    to_prefill: list[UserRequestPrompt]
+    to_decode: list[PrefillResult]
+    pending_prefill: Future | None = None
+    pending_cache_retrievals: list[tuple[UserRequestPrompt, Future]] = dataclasses.field(default_factory=list)
+
+
+def return_request(resp: DecodeResult):
+    # an optional callback, invoked with finished results on decode nodes only;
+    # hook it up to push the response to a global output queue
+    # print(f"Finished request: {resp.id}")
+    pass
+
+
+########################################################################################################################
+# serving utilities ####################################################################################################
+########################################################################################################################
+
+next_power_of_2 = lambda x: 2 ** math.ceil(math.log2(max(x, 1)))
+like_spec = lambda z: jax.tree.map(lambda x: jax.typeof(x).sharding.spec, z)
+like_shard = lambda z, mesh: jax.tree.map(lambda x: NamedSharding(mesh, jax.typeof(x).sharding.spec), z)
+_make_empty = lambda x, mesh: jax.make_array_from_single_device_arrays(
+    x.shape, NamedSharding(mesh, jax.typeof(x).sharding.spec), [], dtype=x.dtype
+)
+
+
+def maybe_call(fn: Callable, mesh: Mesh):
+    """Only call the program if this host participates in the mesh; otherwise return (truly) empty arrays with the correct sharding."""
+    mesh_devices = set(d.id for d in mesh.devices.flat)
+    if any(d.id in mesh_devices for d in jax.local_devices()):  # host has some participating devices
+        return fn
+    return lambda *args, **kw: jax.tree.map(partial(_make_empty, mesh=mesh), jax.eval_shape(fn, *args, **kw))
+
+
+def _make_multistep_decode_fn(decode_fn):
+    @partial(jax.jit, static_argnames=("steps",), donate_argnames=("cache",))
+    def multistep_decode_fn(curr_tokens, decode_weights, cache, cfg, steps: int = 32):
+        def body(carry, _):
+            curr_tokens, cache = carry
+            next_tokens, cache = decode_fn(curr_tokens, decode_weights, cache, cfg)
+            return (next_tokens, cache), next_tokens
+
+        (curr_tokens, cache), output_tokens = jax.lax.scan(body, (curr_tokens, cache), length=steps)
+        return (curr_tokens, cache), output_tokens[..., 0].T
+
+    return multistep_decode_fn
+
+
+########################################################################################################################
+# serving ##############################################################################################################
+########################################################################################################################
+
+
+class ServingLoop:
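+    """The serving event loop: triages incoming requests, runs prefill and decode steps, and manages the prefix cache."""
+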
+    def __init__(
+        self,
+        serve_cfg: ServingConfig,
+        cfg: Config,
+        forward_fn: Callable,
+        prefill_weights: Weights,
+        prefill_cache: KVCache,
+        decode_weights: Weights,
+        decode_cache: KVCache,
+        is_server: bool = False,
+    ):
+        # self.init_cache = init_cache
+        self.prefill_cache = prefill_cache
+        if not SyncServer.broadcast("welcome", 0, is_server, is_server):
+            raise ValueError("Neither this process nor any other process is the main server; at least one must be.")
+        self.serve_cfg, self.cfg = serve_cfg, cfg
+
+        # setup decode
+        self.forward_fn, self.decode_weights = forward_fn, decode_weights
+        self.decode_mesh = [x for x in jax.tree.leaves(decode_weights) if hasattr(x, "sharding")][0].sharding.mesh
+        with set_mesh(self.decode_mesh):
+            self.decode_work = DecodeWork(None, decode_cache, [None for _ in range(serve_cfg.decode_batch_size)])
+            self.decode_work.curr_tokens = jax.device_put(
+                jnp.zeros((serve_cfg.decode_batch_size, 1), dtype=jnp.int32), P()
+            )
+            self.multistep_decode_fn = _make_multistep_decode_fn(self.forward_fn)
+            self._update_index = jax.jit(
+                lambda x, i, new: x.at[i, ...].set(new[:, None], mode="drop", out_sharding=jax.typeof(x).sharding)
+            )
+
+        def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens, kvs, batch_idxs, actual_lens):
+            # sort by length to minimize the number of compiled shape variants
+            length_sort = sorted(range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2])
+            sorted_args = [[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)]
+            new_cache = decode_cache.insert_sequences(cache, *sorted_args)
+            with set_mesh(self.decode_mesh):
+                new_curr_tokens = self._update_index(curr_tokens, np.array(batch_idxs), np.array(new_tokens))
+            return new_cache, new_curr_tokens
+
+        self._update_cache_and_index = _update_cache_and_index
+        self.decode_output = (None, None)
+
+        # setup prefill
+        self.prefill_weights = prefill_weights
+        self.prefill_mesh = [x for x in jax.tree.leaves(prefill_weights) if hasattr(x, "sharding")][0].sharding.mesh
+        self.prefill_work = PrefillWork([], [], [])
+        self._get_index = jax.jit(lambda z, idx: jax.tree.map(lambda x: x[:, idx, ...], z))
+        self._get_cache_entry = jax.jit(self.decode_work.cache.get_sequence)
+
+        # setup misc
+        self.pending_requests, self.state_lock, self.results = [], threading.Lock(), {}
+        self.pad_id, self.eos_tokens = 0, np.array(serve_cfg.eos_tokens)
+        self._background = ThreadPoolExecutor(max_workers=1024)
+
+        # setup profiling
+        self.profile_start_time, self.profiling = -1, False
+
+        # setup cache management
+        self.prefix_cache, self._retrieve_prefix, self._insert_prefix = None, None, None
+        self.new_prefix_cache()
+
+        # setup the sync server for multi-host
+        self._it, self.roles = 0, (("server",) if is_server else ())  # main server
+        if any(d.id in [d_.id for d_ in self.decode_mesh.devices.reshape(-1)] for d in jax.local_devices()):
+            self.roles += ("decode",)  # any node which has decode mesh devices
+        if any(d.id in [d_.id for d_ in self.prefill_mesh.devices.reshape(-1)] for d in jax.local_devices()):
+            self.roles += ("prefill",)  # any node which has prefill devices
+        if any(d.id == min([d_.id for d_ in self.decode_mesh.devices.reshape(-1)]) for d in jax.local_devices()):
+            self.roles += ("decode_coordinator",)  # the decode node which holds the smallest decode mesh device
+        if any(d.id == min([d_.id for d_ in self.prefill_mesh.devices.reshape(-1)]) for d in jax.local_devices()):
+            self.roles += ("prefill_coordinator",)  # the prefill node which holds the smallest prefill mesh device
+        self.total_requests = 0
+
+    def decode_step(self):
+        # TODO: a more intelligent decision between decode and prefill (adaptive strategies, prefill queue size)
+
+        # 1. add outstanding ready-to-decode prefill results to the active decode batch
+        #    - some cache entries require extra computation, so they're a callable
+        #    - some cache entries are not on the correct decode_mesh
+        if len(self.prefill_work.to_decode) > 0:
+            batch_cache_updates = []
+            for i, active_result in enumerate(self.decode_work.active_results):
+                if active_result is not None:
+                    continue
+                if len(self.prefill_work.to_decode) == 0:
+                    break
+                result: PrefillResult = self.prefill_work.to_decode.pop(0)
+                self.decode_work.active_results[i] = DecodeResult(result.id, result.input.tolist())
+                with set_mesh(self.decode_mesh):
+                    result.cache_entry = result.cache_entry() if callable(result.cache_entry) else result.cache_entry
+                self.results[result.id] = self.decode_work.active_results[i]
+                batch_cache_updates.append((result.cache_entry, i, result.len, result.next_token))
+                if len(self.prefill_work.to_decode) == 0:
+                    break
+            if "decode" in self.roles and len(batch_cache_updates) > 0:  # batch cache update
+                entries, batch_idxs, lens, next_tokens = map(list, zip(*batch_cache_updates))
+                entries = [entry.result() if hasattr(entry, "result") else entry for entry in entries]  # maybe collect
+                self.decode_work.cache, self.decode_work.curr_tokens = self._update_cache_and_index(
+                    self.decode_work.cache, self.decode_work.curr_tokens, next_tokens, entries, batch_idxs, lens
+                )
+
+        if all(x is None for x in self.decode_work.active_results):
+            return  # skip decoding if no decoding tasks are present
+
+        # 2. run N decode steps
+        output_tokens, output_mapping = [], []
+        with set_mesh(self.decode_mesh):
+            config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps)
+            decode_fn = maybe_call(self.multistep_decode_fn, self.decode_mesh)
+            (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = decode_fn(
+                self.decode_work.curr_tokens, self.decode_weights, self.decode_work.cache, **config
+            )
+        output_mapping = [
+            [getattr(result, "id", -1) for result in self.decode_work.active_results]
+        ] * self.serve_cfg.decode_steps
+        output_mapping = np.array(output_mapping).T
+        print(f"Decoding with fill rate: {np.mean([result is not None for result in self.decode_work.active_results])}")
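+
+        # note: the decode output is double-buffered; the swap below means we parse the *previous*
+        # iteration's tokens, so this iteration's device-to-host transfer can overlap with compute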
+        # 3. parse output tokens from the previous decode loop, allowing time for the tokens to arrive (delayed EOS detection)
+        self.decode_output, (output_tokens, output_mapping) = (output_tokens, output_mapping), self.decode_output
+        if output_tokens is not None:
+            SyncServer.barrier("output_tokens", self._it)
+            if "decode" in self.roles:
+                output_tokens = np.array(output_tokens)
+                done = np.any(output_tokens[..., None] == self.eos_tokens, (-1, -2)).tolist()  # check for done
+                done = [
+                    d or getattr(result, "tokens_decoded", 0) >= self.serve_cfg.max_decode_length
+                    for d, result in zip(done, self.decode_work.active_results)
+                ]
+                output_tokens_flat = output_tokens.reshape(-1).tolist()
+                output_mapping_flat = output_mapping.reshape(-1).tolist()
+            else:
+                output_tokens, done, output_tokens_flat, output_mapping_flat = None, None, None, None
+            output_tokens_flat, output_mapping_flat, done = SyncServer.broadcast(
+                "decode_output",
+                self._it,
+                (output_tokens_flat, output_mapping_flat, done),
+                is_source="decode_coordinator" in self.roles,
+            )
+            for token, id in zip(output_tokens_flat, output_mapping_flat):
+                if id > 0:
+                    self.results[id].token_list.append(token)
+                    self.results[id].tokens_decoded += 1
+            with set_mesh(self.decode_mesh):
+                for i, result in enumerate(self.decode_work.active_results):
+                    if result is None:
+                        continue
+                    # check for done sequences; evict them if done and return them
+                    if done[i]:
+                        return_request(result)
+                        result.done, self.decode_work.active_results[i] = True, None
+                        if self.serve_cfg.use_prefix_cache:  # store the results in the prefix cache buffer store
+                            sequence = np.array(result.token_list)
+                            cache_entry, _ = self._get_cache_entry(self.decode_work.cache, i)
+                            ns = math.ceil(sequence.size / self.serve_cfg.prefix_chunk_size)
+                            buffer_ids = BUFFER_STORE._get_unique_buffer_ids(ns)
+                            visited_ids, store_ids, del_ids = self._insert_prefix(sequence, buffer_ids)
+                            if len(store_ids) > 0:
+                                axis = self.serve_cfg.time_axis - 1 + 1  # batch missing (-1), layers concatenated (+1)
+                                chunked_cache_entry = _split(cache_entry, axis, self.serve_cfg.prefix_chunk_size, ns)
+                                vals = [chunked_cache_entry[buffer_ids.index(id)] for id in store_ids]
+                                BUFFER_STORE.store(store_ids, vals)
+                            BUFFER_STORE.delete(del_ids)
+                            BUFFER_STORE.mark_visited(visited_ids)
+
+    def prefill_step(self):
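+        """Prefill one batch of queued jobs, then triage newly arrived requests against the prefix cache."""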
+        # 1. prefill the queued requests (do this before triage to overlap host work)
+        prefill_input: list[PrefillJob] = self.prefill_work.to_prefill[: self.serve_cfg.prefill_batch_size]
+        self.prefill_work.to_prefill = self.prefill_work.to_prefill[len(prefill_input) :]
+        if len(prefill_input) > 0:
+            prefill_texts = [job.request.text[job.match_len :] for job in prefill_input]
+            max_len = max([len(text) for text in prefill_texts])
+            inputs = [text + [self.pad_id] * (max_len - len(text)) for text in prefill_texts]
+            inputs = np.stack([np.array(input) for input in inputs], 0)
+            row_pad = self.serve_cfg.prefill_batch_size - inputs.shape[0]
+            col_pad = max(next_power_of_2(inputs.shape[-1]), 64) - inputs.shape[-1]
+            inputs = np.pad(inputs, ((0, row_pad), (0, col_pad)), mode="constant", constant_values=self.pad_id)
+
+            with set_mesh(self.prefill_mesh):
+                # actual_cache_len = np.array(max(job.match_len for job in prefill_input), dtype=np.int32)
+                # self.prefill_cache.iter = actual_cache_len  # TODO: make this explicitly part of the cache's public interface
+                kvs = [job.cache_entry() if job.cache_entry is not None else None for job in prefill_input]
+                batch_idxs = np.array([i for i, kv in enumerate(kvs) if kv is not None])
+                actual_lens = np.array([job.match_len for kv, job in zip(kvs, prefill_input) if kv is not None])
+                kvs = [kv for kv in kvs if kv is not None]
+
+                # sort by length to minimize the number of compiled shape variants
+                length_sort = sorted(range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2])
+                sorted_args = [[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)]
+                insert_sequences = maybe_call(self.prefill_cache.insert_sequences, self.prefill_mesh)
+                self.prefill_cache = insert_sequences(self.prefill_cache, *sorted_args, erase=True)
+
+                cfg = dataclasses.replace(self.cfg, mesh=self.prefill_mesh)
+                forward_fn = maybe_call(self.forward_fn, self.prefill_mesh)
+                _, self.prefill_cache = forward_fn(inputs, self.prefill_weights, self.prefill_cache, cfg)
+
+            with set_mesh(self.prefill_mesh):
+                for i, job in enumerate(prefill_input):
+                    request = job.request
+                    cache_entry, _ = maybe_call(self._get_cache_entry, self.prefill_mesh)(self.prefill_cache, i)
+                    cache_entry = _ensure_all_args_on_mesh(cache_entry, self.decode_mesh)
+                    sequence = np.array(request.text)
+                    new_decode = PrefillResult(
+                        request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1
+                    )
+                    self.prefill_work.to_decode.append(new_decode)
+
+        # 2. triage requests: either they need a prefill pass, or a full cache match lets them go straight to decode
+        while len(self.prefill_work.requests) > 0:
+            request = self.prefill_work.requests.pop(0)
+            sequence = np.array(request.text)
+            (total_match, buffer_ids), visited_ids = self._retrieve_prefix(sequence)
+            assert total_match <= sequence.size
+            BUFFER_STORE.mark_visited(visited_ids)
+            _axis = self.serve_cfg.time_axis - 1 + 1  # batch missing (-1), layers concatenated (+1)
+            buffers = BUFFER_STORE.load(buffer_ids)
+            if total_match == sequence.size:
+                cache_entry = partial(_concat, _ensure_all_args_on_mesh(buffers, mesh=self.decode_mesh), _axis)
+                new_decode = PrefillResult(request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1)
+                self.prefill_work.to_decode.append(new_decode)
+                print("Found a full match")
+            else:
+                print(f"Need to prefill; only matched {total_match / (len(request.text) - 1):.2%} of the sequence")
+                print(f"That equals {len(buffer_ids)} buffers or {total_match=}")
+                if total_match > 0:
+                    cache_entry = partial(_concat, _ensure_all_args_on_mesh(buffers, mesh=self.prefill_mesh), _axis)
+                else:
+                    cache_entry = None
+                self.prefill_work.to_prefill.append(PrefillJob(request, cache_entry, total_match))
+
+    def serving_step(self):  # the event loop relies on determinism for multi-host/process computations (multi-process JAX)
+        # start/stop profiling if a profile request was received ###################################
+        is_server = "server" in self.roles
+        should_start_profile = self.profile_start_time > 0 and not self.profiling
+        should_start_profile = SyncServer.broadcast("profile", self._it, should_start_profile, is_source=is_server)
+        if should_start_profile:
+            self.profile_start_time, self.profiling = time.perf_counter(), True
+            jax.profiler.start_trace("/tmp/online")
+            print("STARTING TRACE")
+        should_stop_profile = self.profile_start_time > 0 and time.perf_counter() - self.profile_start_time > 5.0
+        should_stop_profile = SyncServer.broadcast("stop_profile", self._it, should_stop_profile, is_source=is_server)
+        if should_stop_profile:
+            self.profile_start_time, self.profiling = -1, False
+            print("STOPPING TRACE")
+            jax.profiler.stop_trace()
+        # start/stop profiling if a profile request was received ###################################
+
+        # sync on the server requests received #####################################################
+        SyncServer.barrier("serving_step", self._it)
+        self._it, requests = self._it + 1, None
+        if "server" in self.roles:
+            with self.state_lock:
+                self.pending_requests, requests = [], list(self.pending_requests)
+        serve_cfg, requests = SyncServer.broadcast(
+            "requests", self._it, (dataclasses.asdict(self.serve_cfg), requests), is_source="server" in self.roles
+        )
+        with self.state_lock:
+            self.serve_cfg = dataclasses.replace(self.serve_cfg, **serve_cfg)
+        for request in requests:
+            self.total_requests += 1
+            self.prefill_work.requests.append(UserRequestPrompt(**request))
+        # sync on the server requests received #####################################################
+
+        # main event loop work #####################################################################
+        self.decode_step()
+        self.prefill_step()
+        # main event loop work #####################################################################
+
+        # offload buffers to keep a max of N #######################################################
+        BUFFER_STORE.offload_buffers(max(0, BUFFER_STORE.livecount - self.serve_cfg.max_ondevice_buffers))
+        extra_buffer_count = 
max(len(BUFFER_STORE.usecount) - self.serve_cfg.max_buffers, 0) + if extra_buffer_count > 0: + refs_to_delete = sorted(BUFFER_STORE.usecount.keys())[:extra_buffer_count] + deleted_buffers = remove_prefix_nodes(self.prefix_cache, refs_to_delete) + BUFFER_STORE.delete(list(deleted_buffers)) + if len(BUFFER_STORE._store) > self.serve_cfg.max_buffers: + raise ValueError() + # offload buffers to keep a max of N ####################################################### + + def add_request(self, request: UserRequestPrompt): + with self.state_lock: + self.pending_requests.append(dataclasses.asdict(request)) + + def update_params(self, params: dict[str, Any]): + with self.state_lock: + self.serve_cfg = dataclasses.replace(self.serve_cfg, **params) + + def new_prefix_cache(self): + self.prefix_cache = TrieNode(None, lock=threading.Lock()) + self._retrieve_prefix = partial(retrieve_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) + self._insert_prefix = partial(insert_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) From 23cbafdf07c476c81c42da563aa477d141a57655 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 19 Aug 2025 18:15:39 -0700 Subject: [PATCH 09/11] further cleanup and common model serving --- serving/client_demo.py | 37 +-- serving/main_serving.py | 228 ------------------- serving/main_serving_ds_r1.py | 207 ++++++----------- serving/main_serving_gpt_oss.py | 147 ++++++++++++ serving/main_serving_llama3.py | 140 ++++++++++++ serving/pyproject.toml | 20 +- serving/serving_jax/__init__.py | 2 +- serving/serving_jax/attention_cache_utils.py | 61 ++++- serving/serving_jax/http_server.py | 26 ++- serving/serving_jax/serving_loop.py | 135 ++++++----- 10 files changed, 545 insertions(+), 458 deletions(-) delete mode 100644 serving/main_serving.py create mode 100644 serving/main_serving_gpt_oss.py create mode 100644 serving/main_serving_llama3.py diff --git a/serving/client_demo.py b/serving/client_demo.py index 7c28cd0..06bc956 100644 --- a/serving/client_demo.py +++ b/serving/client_demo.py @@ -8,6 +8,7 @@ import time from pathlib import Path from typing import List +from argparse import ArgumentParser import numpy as np import requests @@ -26,7 +27,7 @@ def fetch_stream(request_id: int, prompt_text: str): """ payload = {"id": request_id, "text": prompt_text} headers = {"accept": "application/json", "Content-Type": "application/json"} - global responses, responses_lock, responses_done + global responses, responses_lock, responses_done, SERVER_URL try: t_first, t_start = None, time.perf_counter() @@ -94,40 +95,56 @@ def generate_layout() -> Layout: def profile_issue(): headers = {"accept": "application/json", "Content-Type": "application/json"} + global SERVER_URL server_url = SERVER_URL + "/profile" requests.get(server_url, headers=headers) def set_generation_length(length: int): + global SERVER_URL headers = {"accept": "application/json", "Content-Type": "application/json"} server_url = SERVER_URL + "/set_generation_length" requests.get(server_url, headers=headers, params={"length": length}) def retrieve(id: int): + global SERVER_URL headers = {"accept": "application/json", "Content-Type": "application/json"} server_url = SERVER_URL + "/retrieve" print(requests.get(server_url, headers=headers, params={"id": id})) def investigate(id: int): + global SERVER_URL headers = {"accept": "application/json", "Content-Type": "application/json"} server_url = SERVER_URL + "/investigate" requests.get(server_url, headers=headers, params={"id": id}) def main(): - 
global responses, MAX_PANEL_LINES, responses_lock, responses_done + global responses, MAX_PANEL_LINES, responses_lock, responses_done, SERVER_URL all_prompts = get_prompts() - prompts_num = 18 - idxs = np.random.randint(0, len(all_prompts), prompts_num) - PROMPTS = [all_prompts[idx] for idx in idxs] # This controls the "scrolling" effect. It's the max number of lines # displayed in a panel. When text exceeds this, only the latest lines are shown. MAX_PANEL_LINES = 15 # --------------------- + parser = ArgumentParser() + parser.add_argument("--profile", action="store_true", default=False) + parser.add_argument("--decode-length", "-d", default=32, type=int) + parser.add_argument("--query-num", "-q", default=18, type=int) + parser.add_argument("--port", "-p", default=8081, type=int) + + args = parser.parse_args() + SERVER_URL = f"http://localhost:{args.port}" + idxs = np.random.randint(0, len(all_prompts), args.query_num) + PROMPTS = [all_prompts[idx] for idx in idxs] + + if args.profile: + profile_issue() + return + # A thread-safe dictionary to store the state of each streaming response. # The structure will be: { request_id: {"prompt": str, "response": str, "status": str} } GLOBAL_ID = time.time_ns() % 2**30 @@ -138,15 +155,7 @@ def main(): responses_lock = threading.Lock() console = Console() - if len(sys.argv) > 1 and sys.argv[1] == "profile": - profile_issue() - return - - if len(sys.argv) > 2 and sys.argv[1] == "investigate": - investigate(int(sys.argv[2])) - return - - set_generation_length(64 if len(sys.argv) <= 1 else int(sys.argv[1])) + set_generation_length(args.decode_length) threads = [] console.print("[bold cyan]Starting streaming requests... Press Ctrl+C to exit.[/]") time.sleep(1) # Give user time to read the message diff --git a/serving/main_serving.py b/serving/main_serving.py deleted file mode 100644 index d8d80cc..0000000 --- a/serving/main_serving.py +++ /dev/null @@ -1,228 +0,0 @@ -import dataclasses -import time -from pathlib import Path -import threading -import asyncio -import socket -import signal -import time -from typing import AsyncGenerator -from contextlib import asynccontextmanager -from argparse import ArgumentParser -from typing import Any - -import jax -from jax import random -from jax.sharding import PartitionSpec as P, AxisType, NamedSharding -import numpy as np -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse, Response -from pydantic import BaseModel -import uvicorn - -from llama3_jax import model as l3jax -import serving_jax as serving -from serving_jax import attention_cache_utils - -Config = Any - -TOKENIZER, SERVE_LOOP, SERVING_THREAD, ARGS = None, None, None, None - -jax.config.update("jax_explain_cache_misses", True) -#jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) - -try: # newer JAX only - my_id = int(socket.gethostname().split("-")[-1]) - my_ip = socket.getaddrinfo(socket.gethostname(), 80)[0][-1][0] - jax.config.update("jax_cross_host_transfer_socket_address", f"{my_ip}:{17007 + my_id}") - jax.config.update("jax_cross_host_transport_addresses", ",".join([f"{my_ip}:0"] * 8)) -except: # noqa: E722 - pass - -shutdown_signal = threading.Event() - -def encode_input(tokenizer, texts, pad_id: int = 0): - assert isinstance(texts, list) - inputs = [ - tokenizer.apply_chat_template([{"role": "user", "content": text}], add_generation_prompt=True) for text in texts - ] - max_len = max([len(x) for x in inputs]) - return np.array([(max_len - len(x)) * [pad_id] + x for x in inputs]) 
- - -def load_model(): - global SERVE_LOOP, SERVING_THREAD, TOKENIZER, ARGS - - parser = ArgumentParser() - parser.add_argument("--server", action="store_true", help="Make this node the main server.", default=False) - ARGS = parser.parse_args() - - #process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...) - #jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx) - jax.distributed.initialize() - print(jax.devices()) - print("-" * 80) - print(jax.local_devices()) - - #model_name = "Llama-3.1-8B-Instruct" - #ckpt_path = Path(f"~/{model_name}").expanduser() - #model_name = "Llama-3.1-8B-Instruct-quant" - model_name = "Llama-3.1-70B-Instruct-quant" - ckpt_path = Path(f"~/bucket/llama3_jax_old/{model_name}").expanduser() - cfg = l3jax.load_config(ckpt_path / "config.json") - TOKENIZER = l3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json") - assert ckpt_path.is_dir() - print("---> Model config loaded") - - # two hosts, different device and host meshes - #local_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3) - #local_mesh = jax.make_mesh((1, 1, 1), P("x", "y", "z"), devices=jax.local_devices(), axis_types=(AxisType.Explicit,) * 3) - #decode_mesh, prefill_mesh = local_mesh, local_mesh - decode_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[:8], axis_types=(AxisType.Explicit,) * 3) - prefill_mesh = jax.make_mesh((1, 8, 1), P("x", "y", "z"), devices=jax.devices()[8:], axis_types=(AxisType.Explicit,) * 3) - #decode_mesh = jax.make_mesh((1, 8, 2), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3) - #prefill_mesh = jax.make_mesh((1, 8, 2), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Explicit,) * 3) - cfg = dataclasses.replace(cfg, mesh=decode_mesh, quant_layer=True, quant_cache=True) - cfg = dataclasses.replace(cfg, use_prefill_attn_kernel=False, use_decode_attn_kernel=False, max_seq_len=2048) - cfg.quant_cache = False - - decode_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=decode_mesh))) - prefill_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(dataclasses.replace(cfg, mesh=prefill_mesh))) - - print("---> Weights loaded") - - serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64, prefix_chunk_size=64) - decode_cache = l3jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size) - decode_cache.get_sequence = attention_cache_utils.kvcache_get_entry - decode_cache.insert_sequences = attention_cache_utils.kvcache_update_cache - #decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32) - #decode_cache.get_sequence = attention_cache_utils.batch_paged_get_entry - #decode_cache.insert_sequences = attention_cache_utils.batch_paged_update_sequences - - def init_cache(cfg: Config, batch_size: int, actual_len: int): - cache = l3jax.KVCache.init(random.key(0), cfg, batch_size) - cache.get_sequence = attention_cache_utils.kvcache_get_entry - cache.insert_sequences = attention_cache_utils.kvcache_update_cache - cache.iter = actual_len - return cache - - with jax.sharding.set_mesh(prefill_mesh): - prefill_cache = init_cache(dataclasses.replace(cfg, mesh=prefill_mesh), serve_cfg.prefill_batch_size, 8192) - - forward_fn = l3jax.decode_step # TODO: the model file needs to call it forward explcitly - SERVE_LOOP = serving.ServingLoop( - #serve_cfg, cfg, 
init_cache, l3jax.decode_step, prefill_weights, decode_weights, decode_cache, ARGS.server - serve_cfg, cfg, forward_fn, prefill_weights, prefill_cache, decode_weights, decode_cache, ARGS.server - ) - print("---> Created the serving loop") - - def serve_forever(): - try: - while not shutdown_signal.is_set(): - SERVE_LOOP.serving_step() - except: # noqa: E722 - import traceback - print(traceback.format_exc(), flush=True) - finally: - print("Received a shutdown signal") - time.sleep(0.1) - signal.raise_signal(signal.SIGKILL) # shut down the web server - print("Exiting the serving loop") - - SERVING_THREAD = threading.Thread(target=serve_forever) - SERVING_THREAD.start() - - -######################################################################################################################## - - -@asynccontextmanager -async def lifespan(app: FastAPI): - yield - shutdown_signal.set() - - -_ = load_model() -APP = FastAPI(lifespan=lifespan) - - -class GenerateRequest(BaseModel): - id: int - text: str - - -#async def generate_generator(params: GenerateRequest, request: Request) -> AsyncGenerator[str, None]: -async def generate_generator(id: int, text: str, request: Request) -> AsyncGenerator[str, None]: - if id in SERVE_LOOP.results: - del SERVE_LOOP.results[id] - - input = encode_input(TOKENIZER, [text])[0].tolist() - iter = len(input) - SERVE_LOOP.add_request(serving.UserRequestPrompt(id, input)) - while id not in SERVE_LOOP.results: - await asyncio.sleep(0.1) - try: - result: serving.DecodeResult = SERVE_LOOP.results[id] - while not result.done: - if await request.is_disconnected(): # Check if client disconnected - print("Client disconnected.") - break - if len(result.token_list) > iter: - new_segment, iter = TOKENIZER.decode(result.token_list[iter:]), len(result.token_list) - yield f"{new_segment}" - await asyncio.sleep(0.1) # Stream a new message every 1 second - if len(result.token_list) > iter: - new_segment, iter = TOKENIZER.decode(result.token_list[iter:]), len(result.token_list) - yield f"{new_segment}" - except asyncio.CancelledError: - pass - finally: - pass - - -@APP.get("/stream") -async def stream_response(params: GenerateRequest, request: Request): - return StreamingResponse(generate_generator(params.id, params.text, request), media_type="text/event-stream") - - -@APP.get("/generate") -async def generate(id: int, text: str): # generate without output - print(f"Input text: {text}") - SERVE_LOOP.add_request(serving.UserRequestPrompt(id, encode_input(TOKENIZER, [text])[0].tolist())) - return Response("OK") - - -@APP.get("/retrieve") -async def retrieve(id: int): - if id in SERVE_LOOP.results: - return Response(TOKENIZER.decode(SERVE_LOOP.results[id].token_list)) - return Response("NO TEXT") - - -@APP.get("/set_generation_length") -async def set_generation_length(length: int): - SERVE_LOOP.serve_cfg.max_decode_length = max(length, 32) - return Response("OK") - - -@APP.get("/profile") -async def profile(request: Request): - del request - SERVE_LOOP.profile_start_time = time.perf_counter() - return Response("OK") - - -@APP.get("/") -async def root(): - return {"message": "Welcome! 
Try the /stream-text endpoint."} - - -if __name__ == "__main__": - if ARGS.server: - uvicorn.run(APP, host="0.0.0.0", port=8081, reload=False, server_header=False) - else: - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - shutdown_signal.set() diff --git a/serving/main_serving_ds_r1.py b/serving/main_serving_ds_r1.py index ab4cb80..06b65dc 100644 --- a/serving/main_serving_ds_r1.py +++ b/serving/main_serving_ds_r1.py @@ -1,39 +1,40 @@ -import asyncio +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import dataclasses -import signal -import socket -import sys import threading -import time from argparse import ArgumentParser -from contextlib import asynccontextmanager +from functools import partial from pathlib import Path -from typing import AsyncGenerator import jax +import jax.numpy as jnp import numpy as np import serving_jax as serving -import uvicorn -from fastapi import FastAPI, Request -from fastapi.responses import Response, StreamingResponse from jax import random from jax.sharding import AxisType from jax.sharding import PartitionSpec as P -from pydantic import BaseModel from serving_jax import attention_cache_utils from deepseek_r1_jax import chkpt_utils as dsjax_utils from deepseek_r1_jax import model as dsjax -TOKENIZER, SERVE_LOOP, SERVING_THREAD, ARGS = None, None, None, None jax.config.update("jax_explain_cache_misses", True) -jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) +# jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) # jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0) # jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0) -jax.config.update("jax_enable_empty_arrays", True) - -shutdown_signal = threading.Event() def encode_input(tokenizer, texts, pad_id: int = 0): @@ -45,152 +46,78 @@ def encode_input(tokenizer, texts, pad_id: int = 0): return np.array([(max_len - len(x)) * [pad_id] + x for x in inputs]) -def load_model(): - global SERVE_LOOP, SERVING_THREAD, TOKENIZER, ARGS +tokenizer_encode = lambda tokenizer, text: encode_input(tokenizer, [text])[0].tolist() +tokenizer_decode = lambda tokenizer, tokens: tokenizer.decode(tokens) + + +def distributed_init(): + # for TPU + jax.distributed.initialize() + + # for GPU/CPU + # process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...) 
+ # jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx) + # jax.distributed.initialize() + +def load_model(): parser = ArgumentParser() parser.add_argument("--server", action="store_true", help="Make this node the main server.", default=False) ARGS = parser.parse_args() - jax.distributed.initialize() - print(jax.devices()) - print("-" * 80) - print(jax.local_devices()) + distributed_init() + devices = jax.devices() # this helps catch distributed errors quickly ckpt_path = Path(f"~/bucket/deepseek-r1-jax-chkpt").expanduser() - TOKENIZER = dsjax.load_tokenizer() + tokenizer = dsjax.load_tokenizer() assert ckpt_path.is_dir() print("---> Model config loaded") - mesh = jax.make_mesh( - (1, 8, jax.device_count() // 8), P("x", "y", "z"), devices=jax.devices(), axis_types=(AxisType.Auto,) * 3 - ) - cfg = dataclasses.replace(dsjax.Config(), mesh=mesh) # , num_layers=4) + mesh = jax.make_mesh((1, 8, len(devices) // 8), P("x", "y", "z"), devices=devices, axis_types=(AxisType.Auto,) * 3) + cfg = dataclasses.replace(dsjax.Config(), max_seq_len=1024, mesh=mesh)#, num_layers=4) weights = dsjax_utils.load_model(ckpt_path, cfg) decode_weights, prefill_weights = weights, weights print("---> Weights loaded") serve_cfg = serving.ServingConfig( - decode_steps=32, max_decode_length=64, decode_batch_size=8, prefill_batch_size=1, prefix_chunk_size=64 + decode_steps=32, max_decode_length=64, decode_batch_size=8, prefill_batch_size=1, prefix_chunk_size=64, max_ondevice_buffers=16 ) - decode_cache = dsjax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, cfg.max_seq_len) - decode_cache.get_sequence = attention_cache_utils.kvcache_get_sequence - decode_cache.insert_sequences = attention_cache_utils.kvcache_insert_sequences - SERVE_LOOP = serving.ServingLoop( - serve_cfg, cfg, dsjax.prefill, prefill_weights, dsjax.decode_step, decode_weights, decode_cache, ARGS.server + decode_cache = serving.AttentionWrapper( + dsjax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, cfg.max_seq_len), + attention_cache_utils.kvcache_get_sequence, + attention_cache_utils.kvcache_insert_sequences + ) + prefill_cache = serving.AttentionWrapper( + dsjax.KVCache.init(random.key(0), cfg, serve_cfg.prefill_batch_size, cfg.max_seq_len), + attention_cache_utils.kvcache_get_sequence, + attention_cache_utils.kvcache_insert_sequences ) - print("---> Created the serving loop") - - def serve_forever(): - try: - while not shutdown_signal.is_set(): - SERVE_LOOP.serving_step() - except Exception as e: - import traceback - - print(traceback.format_exc(), flush=True) - print(f"Exception {e}", flush=True) - finally: - print("Received a shutdown signal") - time.sleep(0.1) - signal.raise_signal(signal.SIGKILL) # shut down the web server - print("Exiting the serving loop") - - SERVING_THREAD = threading.Thread(target=serve_forever) - SERVING_THREAD.start() - - -######################################################################################################################## - - -@asynccontextmanager -async def lifespan(app: FastAPI): - yield - shutdown_signal.set() - - -_ = load_model() -APP = FastAPI(lifespan=lifespan) - - -class GenerateRequest(BaseModel): - id: int - text: str - - -# async def generate_generator(params: GenerateRequest, request: Request) -> AsyncGenerator[str, None]: -async def generate_generator(id: int, text: str, request: Request) -> AsyncGenerator[str, None]: - if id in SERVE_LOOP.results: - del SERVE_LOOP.results[id] - - input = encode_input(TOKENIZER, [text])[0].tolist() - 
iter = len(input) - SERVE_LOOP.add_request(serving.UserRequestPrompt(id, input)) - while id not in SERVE_LOOP.results: - await asyncio.sleep(0.1) - try: - result: serving.DecodeResult = SERVE_LOOP.results[id] - while not result.done: - if await request.is_disconnected(): # Check if client disconnected - print("Client disconnected.") - break - if len(result.token_list) > iter: - new_segment, iter = TOKENIZER.decode(result.token_list[iter:]), len(result.token_list) - yield f"{new_segment}" - await asyncio.sleep(0.1) # Stream a new message every 1 second - if len(result.token_list) > iter: - new_segment, iter = TOKENIZER.decode(result.token_list[iter:]), len(result.token_list) - yield f"{new_segment}" - except asyncio.CancelledError: - pass - finally: - pass - - -@APP.get("/stream") -async def stream_response(params: GenerateRequest, request: Request): - return StreamingResponse(generate_generator(params.id, params.text, request), media_type="text/event-stream") - - -@APP.get("/generate") -async def generate(id: int, text: str): # generate without output - print(f"Input text: {text}") - SERVE_LOOP.add_request(serving.UserRequestPrompt(id, encode_input(TOKENIZER, [text])[0].tolist())) - return Response("OK") - - -@APP.get("/retrieve") -async def retrieve(id: int): - if id in SERVE_LOOP.results: - return Response(TOKENIZER.decode(SERVE_LOOP.results[id].token_list)) - return Response("NO TEXT") - -@APP.get("/set_generation_length") -async def set_generation_length(length: int): - SERVE_LOOP.update_params({"max_decode_length": max(length, 32)}) - return Response("OK") + sampler = partial(jnp.argmax, axis=-1) + @partial(jax.jit, donate_argnames=("cache",)) + def forward_fn(inputs, weights, cache, cfg): + logits, cache = dsjax.forward(inputs, (inputs != dsjax.PAD_ID).astype(np.int32), weights, cfg, cache) + return sampler(logits), cache -@APP.get("/profile") -async def profile(request: Request): - del request - SERVE_LOOP.profile_start_time = time.perf_counter() - return Response("OK") + serve_loop = serving.ServingLoop( + serve_cfg, cfg, forward_fn, prefill_weights, prefill_cache, decode_weights, decode_cache, ARGS.server + ) + print("---> Created the serving loop") + shutdown_signal = threading.Event() + serve_loop.serve_forever(shutdown_signal) -@APP.get("/") -async def root(): - return {"message": "Welcome! Try the /stream-text endpoint."} + serving.run_http_server( + serve_loop, + partial(tokenizer_encode, tokenizer), + partial(tokenizer_decode, tokenizer), + ARGS.server, + shutdown_signal=shutdown_signal, + ) if __name__ == "__main__": - if ARGS.server: - print(f"jax.process_idx() == {jax.process_index()}") - uvicorn.run(APP, host="0.0.0.0", port=8081, reload=False, server_header=False) - else: - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - shutdown_signal.set() + load_model() + +######################################################################################################################## diff --git a/serving/main_serving_gpt_oss.py b/serving/main_serving_gpt_oss.py new file mode 100644 index 0000000..d03b739 --- /dev/null +++ b/serving/main_serving_gpt_oss.py @@ -0,0 +1,147 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import socket +import threading +from argparse import ArgumentParser +from functools import partial +from pathlib import Path +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np +import serving_jax as serving +from gpt_oss_jax import model as gpt_jax +from jax import random +from jax.sharding import AxisType +from serving_jax import attention_cache_utils as attn_utils + +Config = Any + +jax.config.update("jax_explain_cache_misses", True) +# jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) +# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0) +# jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0) + +try: # newer JAX only + my_id = int(socket.gethostname().split("-")[-1]) + my_ip = socket.getaddrinfo(socket.gethostname(), 80)[0][-1][0] + jax.config.update("jax_cross_host_transfer_socket_address", f"{my_ip}:{17007 + my_id}") + jax.config.update("jax_cross_host_transport_addresses", ",".join([f"{my_ip}:0"] * 8)) +except: # noqa: E722 + pass + + +def encode_input(tokenizer, texts, pad_id: int = gpt_jax.PAD_ID): + assert isinstance(texts, list) + inputs = [ + tokenizer.apply_chat_template([{"role": "user", "content": text}], add_generation_prompt=True) for text in texts + ] + max_len = max([len(x) for x in inputs]) + return np.array([(max_len - len(x)) * [pad_id] + x for x in inputs]) + + +tokenizer_encode = lambda tokenizer, text: encode_input(tokenizer, [text])[0].tolist() +tokenizer_decode = lambda tokenizer, tokens: tokenizer.decode(tokens) + + +def distributed_init(): + # for TPU + jax.distributed.initialize() + + # for GPU/CPU + # process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...) 
+ # jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx) + # jax.distributed.initialize() + + +def main(): + parser = ArgumentParser() + parser.add_argument("--server", action="store_true", help="Make this node the main server.", default=False) + ARGS = parser.parse_args() + + distributed_init() + devices = jax.devices() # this helps catch distributed errors quickly + + model_name = "gpt_oss_20b-quant" + ckpt_path = Path(f"~/bucket/gpt_oss_jax/{model_name}").expanduser() + cfg = gpt_jax.load_config(ckpt_path / "config.json") + tokenizer = gpt_jax.load_tokenizer(ckpt_path) + assert ckpt_path.is_dir() + print("---> Model config loaded") + + # two hosts, different device and host meshes + decode_mesh = jax.make_mesh((1, 2, 2), ("x", "y", "z"), devices=devices, axis_types=(AxisType.Explicit,) * 3) + prefill_mesh = jax.make_mesh((1, 2, 2), ("x", "y", "z"), devices=devices, axis_types=(AxisType.Explicit,) * 3) + cfg = dataclasses.replace(cfg, mesh=decode_mesh, quant_moe=True, quant_cache=True) + cfg = dataclasses.replace(cfg, use_prefill_attn_kernel=False, use_decode_attn_kernel=False, max_seq_len=2048) + cfg.quant_cache = True + + weights_formats, decode_cache_formats = gpt_jax.optimal_formats(dataclasses.replace(cfg, mesh=decode_mesh)) + decode_weights = gpt_jax.load_pytree(ckpt_path, weights_formats) + weights_formats, prefill_cache_formats = gpt_jax.optimal_formats(dataclasses.replace(cfg, mesh=prefill_mesh)) + # prefill_weights = gpt_jax.load_pytree(ckpt_path, weights_formats) + prefill_weights = decode_weights + + print("---> Weights loaded") + + serve_cfg = serving.ServingConfig(decode_steps=32, max_decode_length=64, prefix_chunk_size=64) + decode_cache = gpt_jax.KVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, cfg.max_seq_len) + decode_cache = jax.tree.map(lambda x, sds: jax.device_put(x, sds.sharding), decode_cache, decode_cache_formats) + decode_cache = serving.AttentionWrapper( + decode_cache, attn_utils.kvcache_get_sequence, attn_utils.kvcache_insert_sequences + ) + # decode_cache = gpt_jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32) + # decode_cache = serving.AttentionWrapper( + # decode_cache, attn_utils.paged_kvcache_get_sequence, attn_utils.paged_kvcache_insert_sequences + # ) + + prefill_cache = gpt_jax.KVCache.init(random.key(0), dataclasses.replace(cfg, mesh=prefill_mesh), serve_cfg.prefill_batch_size, 2048) + prefill_cache = jax.tree.map(lambda x, sds: jax.device_put(x, sds.sharding), prefill_cache, prefill_cache_formats) + prefill_cache = serving.AttentionWrapper( + prefill_cache, attn_utils.kvcache_get_sequence, attn_utils.kvcache_insert_sequences + ) + + sampler = partial(jnp.argmax, axis=-1) + + @partial(jax.jit, donate_argnames=("cache",)) + def forward_fn(inputs, weights, cache, cfg): + logits, cache = gpt_jax.forward(inputs, (inputs != gpt_jax.PAD_ID).astype(jnp.int32), weights, cfg, cache) + return sampler(logits), cache + + serve_loop = serving.ServingLoop( + serve_cfg, cfg, forward_fn, prefill_weights, prefill_cache, decode_weights, decode_cache, ARGS.server + ) + print("---> Created the serving loop") + + shutdown_signal = threading.Event() + serve_loop.serve_forever(shutdown_signal) + + serving.run_http_server( + serve_loop, + partial(tokenizer_encode, tokenizer), + partial(tokenizer_decode, tokenizer), + ARGS.server, + shutdown_signal=shutdown_signal, + ) + + return + + +if __name__ == "__main__": + main() + 
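(For orientation: each serving entrypoint in this series ends by handing the loop to run_http_server, so a minimal client sketch looks like the following. The /generate and /retrieve GET routes and port 8081 come from http_server.py elsewhere in this patch; the host name, request id, and prompt are illustrative assumptions, and requests is assumed installed — any HTTP client works.)

import time
import requests  # assumption: requests is available; any HTTP client works

SERVER = "http://localhost:8081"  # assumption: client runs on or next to the serving host
requests.get(f"{SERVER}/generate", params={"id": 0, "text": "What is JAX?"})  # enqueue a prompt
while (resp := requests.get(f"{SERVER}/retrieve", params={"id": 0})).text == "NO TEXT":
    time.sleep(0.5)  # poll until the serving loop has decoded tokens for this request id
print(resp.text)  # decoded text so far; it keeps growing until the request is done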
+######################################################################################################################## diff --git a/serving/main_serving_llama3.py b/serving/main_serving_llama3.py new file mode 100644 index 0000000..769c20b --- /dev/null +++ b/serving/main_serving_llama3.py @@ -0,0 +1,140 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import socket +import threading +from argparse import ArgumentParser +from functools import partial +from pathlib import Path +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np +import serving_jax as serving +from jax import random +from jax.sharding import AxisType +from llama3_jax import model as l3jax +from serving_jax import attention_cache_utils as attn_utils + +Config = Any + +jax.config.update("jax_explain_cache_misses", True) +# jax.config.update("jax_compilation_cache_dir", str(Path("~/.cache/jax").expanduser())) + +try: # newer JAX only + my_id = int(socket.gethostname().split("-")[-1]) # a scheme where hosts end with -HOST_NUM (host-0, host-1, ...) + my_ip = socket.getaddrinfo(socket.gethostname(), 80)[0][-1][0] + jax.config.update("jax_cross_host_transfer_socket_address", f"{my_ip}:{17007 + my_id}") + jax.config.update("jax_cross_host_transport_addresses", ",".join([f"{my_ip}:0"] * 8)) +except: # noqa: E722 + pass + + +def encode_input(tokenizer, texts, pad_id: int = 0): + assert isinstance(texts, list) + inputs = [ + tokenizer.apply_chat_template([{"role": "user", "content": text}], add_generation_prompt=True) for text in texts + ] + max_len = max([len(x) for x in inputs]) + return np.array([(max_len - len(x)) * [pad_id] + x for x in inputs]) + + +tokenizer_encode = lambda tokenizer, text: encode_input(tokenizer, [text])[0].tolist() +tokenizer_decode = lambda tokenizer, tokens: tokenizer.decode(tokens) + + +def distributed_init(): + # for TPU + jax.distributed.initialize() + + # for GPU/CPU + # process_idx = int(socket.gethostname().split("-")[-1]) - 1 # a scheme where hosts are (host-1, host-2, ...) 
+ # jax.distributed.initialize(os.environ["COORDINATOR_ADDRESS"], 2, process_idx) + # jax.distributed.initialize() + + +def main(): + parser = ArgumentParser() + parser.add_argument("--server", action="store_true", help="Make this node the main server.", default=False) + ARGS = parser.parse_args() + + distributed_init() + devices = jax.devices() # this helps catch distributed errors quickly + + model_name = "Llama-3.1-8B-Instruct-quant" + ckpt_path = Path(f"~/bucket/llama3_jax/{model_name}").expanduser() + cfg = l3jax.load_config(ckpt_path / "config.json") + tokenizer = l3jax.load_tokenizer(ckpt_path / "tokenizer.json", ckpt_path / "tokenizer_config.json") + assert ckpt_path.is_dir() + print("---> Model config loaded") + + # two hosts, different device and host meshes + decode_mesh = jax.make_mesh((1, 8, 1), ("x", "y", "z"), devices=devices[:8], axis_types=(AxisType.Explicit,) * 3) + prefill_mesh = jax.make_mesh((1, 8, 1), ("x", "y", "z"), devices=devices[8:], axis_types=(AxisType.Explicit,) * 3) + cfg = dataclasses.replace(cfg, mesh=decode_mesh, quant_layer=True, quant_cache=True, max_seq_len=2048) + cfg_decode, cfg_prefill = dataclasses.replace(cfg, mesh=decode_mesh), dataclasses.replace(cfg, mesh=prefill_mesh) + + decode_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(cfg_decode)) + prefill_weights = l3jax.load_pytree(ckpt_path, l3jax.Weights.shardings(cfg_prefill)) + # prefill_weights = decode_weights + + print("---> Weights loaded") + + serve_cfg = serving.ServingConfig( + decode_steps=32, max_decode_length=64, prefix_chunk_size=64, max_ondevice_buffers=2048, max_buffers=2048 + ) + decode_cache = l3jax.KVCache.init(random.key(0), cfg_decode, serve_cfg.decode_batch_size) + decode_cache = attn_utils.AttentionInterface( + decode_cache, attn_utils.kvcache_get_sequence, attn_utils.kvcache_insert_sequences + ) + # decode_cache = l3jax.PagedKVCache.init(random.key(0), cfg, serve_cfg.decode_batch_size, 2048, 32) + # decode_cache = attn_utils.AttentionInterface( + # decode_cache, attn_utils.paged_kvcache_get_sequence, attn_utils.paged_kvcache_insert_sequences + # ) + + prefill_cache = l3jax.KVCache.init(random.key(0), cfg_prefill, serve_cfg.prefill_batch_size) + prefill_cache = attn_utils.AttentionInterface( + prefill_cache, attn_utils.kvcache_get_sequence, attn_utils.kvcache_insert_sequences + ) + + sampler = partial(jnp.argmax, axis=-1) + + @partial(jax.jit, donate_argnames=("cache",)) + def forward_fn(inputs, weights, cache, cfg): + logits, cache = l3jax.forward(inputs, (inputs != 0).astype(jnp.int32), weights, cfg, cache) + return sampler(logits), cache + + serve_loop = serving.ServingLoop( + serve_cfg, cfg, forward_fn, prefill_weights, prefill_cache, decode_weights, decode_cache, ARGS.server + ) + print("---> Created the serving loop") + + shutdown_signal = threading.Event() + serve_loop.serve_forever(shutdown_signal) + + serving.run_http_server( + serve_loop, + partial(tokenizer_encode, tokenizer), + partial(tokenizer_decode, tokenizer), + ARGS.server, + shutdown_signal=shutdown_signal, + ) + + +if __name__ == "__main__": + main() + +######################################################################################################################## diff --git a/serving/pyproject.toml b/serving/pyproject.toml index 0df6079..c2d422f 100644 --- a/serving/pyproject.toml +++ b/serving/pyproject.toml @@ -13,12 +13,28 @@ dependencies = [ "jax", "tqdm", "numpy", - #"orbax-checkpoint", - #"datasets", "gcsfs", "etils", ] +[tool.ruff] +line-length = 120 +indent-width = 4 + 
+[tool.ruff.lint] +select = [ + "E", + "W291", # trailing whitespace + "F821", # undefined variables +] +ignore = [ + "E731", # allow lambda expressions instead of def + "E741", # allow 'l' as a variable name + "E402", # allow module imports not at top of file + "F841", # allow unused variables + "E501", # ignore line-too-long +] + [build-system] requires = ["setuptools>=61.0.0", "wheel"] build-backend = "setuptools.build_meta" diff --git a/serving/serving_jax/__init__.py b/serving/serving_jax/__init__.py index f77e9b9..160ffa9 100644 --- a/serving/serving_jax/__init__.py +++ b/serving/serving_jax/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. from .http_server import run_http_server -from .serving_loop import DecodeResult, ServingConfig, ServingLoop, UserRequestPrompt +from .serving_loop import DecodeResult, ServingConfig, ServingLoop, UserRequestPrompt, AttentionWrapper diff --git a/serving/serving_jax/attention_cache_utils.py b/serving/serving_jax/attention_cache_utils.py index ba40d48..1c1d89c 100644 --- a/serving/serving_jax/attention_cache_utils.py +++ b/serving/serving_jax/attention_cache_utils.py @@ -15,13 +15,22 @@ import dataclasses import math from functools import partial -from typing import Any +from typing import Any, Callable import jax import jax.numpy as jnp +from jax.sharding import auto_axes QuantArray, PyTree, KVCache, PagedKVCache = Any, Any, Any, Any + +@dataclasses.dataclass +class AttentionInterface: + cache: KVCache + get_sequence: Callable + insert_sequences: Callable + + next_power_of_2 = lambda x: 2 ** math.ceil(math.log2(max(x, 1))) _pad_after = lambda x, l, axis: jnp.pad(x, [(0, 0) if i != axis else (0, l - x.shape[i]) for i in range(x.ndim)]) @@ -47,7 +56,7 @@ def _transpose_attention_tree(kv_list: list[PyTree], time_axis: int): @partial(jax.jit, donate_argnames=("cache",)) -def _kvcache_insert_sequences( +def __kvcache_insert_sequences( cache: KVCache, kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], @@ -67,20 +76,32 @@ def _kvcache_insert_sequences( def _update_element(x, u): update_permute = [0, cache.time_axis] + [i for i in range(u.ndim) if i not in (0, cache.time_axis)] - # time_dim, batch_dim = update_permute.pop(cache.time_axis), update_permute.pop(0) # first pop time_axis - # update_permute = [batch_dim, time_dim] + update_permute - return x.at[batch_idxs[:, None], :, time_indices, ...].set( - u.transpose(update_permute), mode="drop", out_sharding=jax.typeof(x).sharding - ) + return x.at[batch_idxs[:, None], :, time_indices, ...].set(u.transpose(update_permute), mode="drop") + buffers_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, cache.buffers) cache_kvs = jax.tree.map(_update_element, cache.buffers, kvs) - cache_starts = cache.starts.at[batch_idxs].set( - start_time, mode="drop", out_sharding=jax.typeof(cache.starts).sharding - ) + cache_starts = cache.starts.at[batch_idxs].set(start_time, mode="drop") cache_iter = jnp.where(uninitialized_cache, jnp.max(actual_lens), cache.iter) buffer_names = [field.name for field in dataclasses.fields(cache)][: len(cache_kvs)] - return dataclasses.replace(cache, **dict(zip(buffer_names, cache_kvs, strict=True)), iter=cache_iter, starts=cache_starts) + return dataclasses.replace( + cache, **dict(zip(buffer_names, cache_kvs, strict=True)), iter=cache_iter, starts=cache_starts + ) + + +@partial(jax.jit, donate_argnames=("cache",)) +def _kvcache_insert_sequences( + cache: KVCache, + kvs: list[tuple[jax.Array | QuantArray, ...]], + batch_idxs: list[jax.Array], + 
actual_lens: list[jax.Array], + update_mask: list[bool] | None = None, + erase: bool = False, +): + cache_shardings = jax.tree.map(lambda x: jax.typeof(x).sharding, cache) + return auto_axes(__kvcache_insert_sequences, out_sharding=cache_shardings)( + cache, kvs, batch_idxs, actual_lens, update_mask, erase + ) @partial(jax.jit, donate_argnames=("cache",)) @@ -133,7 +154,7 @@ def _find_empty_pages(free_pages: jax.Array, k: int, proposal_pages: jax.Array | @partial(jax.jit, donate_argnames=("cache",)) -def _paged_kvcache_insert_sequences( +def __paged_kvcache_insert_sequences( cache: PagedKVCache, kvs: list[tuple[jax.Array | QuantArray, ...]], batch_idxs: list[jax.Array], @@ -186,6 +207,20 @@ def _update_element(x, u): ) +@partial(jax.jit, donate_argnames=("cache",)) +def _paged_kvcache_insert_sequences( + cache: PagedKVCache, + kvs: list[tuple[jax.Array | QuantArray, ...]], + batch_idxs: list[jax.Array], + actual_lens: list[jax.Array], + update_mask: list[bool] | None = None, +) -> PagedKVCache: + cache_sharding = jax.tree.map(lambda x: jax.typeof(x).sharding, cache) + return auto_axes(__paged_kvcache_insert_sequences, out_sharding=cache_sharding)( + cache, kvs, batch_idxs, actual_lens, update_mask + ) + + def paged_kvcache_insert_sequences( cache: KVCache, kvs: list[tuple[jax.Array | QuantArray, ...]], @@ -194,6 +229,8 @@ def paged_kvcache_insert_sequences( erase: bool = False, ): del erase # inapplicable + if len(kvs) == 0: + return cache pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs) # an update of power of 2 and at least 4 update_mask = [i < len(kvs) for i in range(len(kvs) + pad_len)] kvs = kvs + [kvs[-1]] * pad_len diff --git a/serving/serving_jax/http_server.py b/serving/serving_jax/http_server.py index d4622ae..376fd31 100644 --- a/serving/serving_jax/http_server.py +++ b/serving/serving_jax/http_server.py @@ -13,10 +13,11 @@ # limitations under the License. 
import asyncio +import signal import threading import time from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, Callable +from typing import AsyncGenerator, Callable import uvicorn from fastapi import FastAPI, Request @@ -38,12 +39,27 @@ def run_http_server( is_server: bool = False, shutdown_signal: threading.Event | None = None, ) -> None: + # lifetime management: normalize the optional shutdown event so the watcher thread below never dereferences None + shutdown_signal = shutdown_signal if shutdown_signal is not None else threading.Event() + + def signal_listener(): + while not shutdown_signal.is_set(): + time.sleep(1) + signal.raise_signal(signal.SIGKILL) # hard-stop uvicorn, which otherwise blocks the main thread + + threading.Thread(target=signal_listener).start() + + def interrupt_handler(signum, frame): + shutdown_signal.set() + + signal.signal(signal.SIGINT, interrupt_handler) + @asynccontextmanager async def lifespan(app: FastAPI): yield if shutdown_signal is not None: shutdown_signal.set() + # the HTTP server APP = FastAPI(lifespan=lifespan) async def generate_generator(id: int, text: str, request: Request) -> AsyncGenerator[str, None]: @@ -107,9 +123,5 @@ async def root(): if is_server: uvicorn.run(APP, host="0.0.0.0", port=8081, reload=False, server_header=False) else: - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - if shutdown_signal is not None: - shutdown_signal.set() + while not shutdown_signal.is_set(): + time.sleep(0.1) diff --git a/serving/serving_jax/serving_loop.py b/serving/serving_jax/serving_loop.py index 46d4765..35b7924 100644 --- a/serving/serving_jax/serving_loop.py +++ b/serving/serving_jax/serving_loop.py @@ -16,14 +16,19 @@ import dataclasses import json import math +from pprint import pformat import threading import time +import traceback from concurrent.futures import Future, ThreadPoolExecutor from functools import partial -from typing import Any, Callable, Sequence +from typing import Any, Callable, Sequence, NamedTuple +import logging import jax import jax.numpy as jnp +import numpy as np +from jax._src import distributed from jax.sharding import Mesh, NamedSharding, set_mesh from jax.sharding import PartitionSpec as P @@ -31,16 +36,27 @@ from jax.experimental.shard import auto_axes except ModuleNotFoundError: from jax.sharding import auto_axes -import numpy as np -from jax._src import distributed -from jax._src.lib import xla_client as xc +try: + from jax.sharding import use_mesh + + set_mesh = use_mesh +except ImportError: + pass from .cross_host import transfer_tree_A2B KVCache, Weights, Config = Any, Any, Any PyTree, PyTreeStruct = Any, Any +AttentionWrapper = NamedTuple( + "AttentionWrapper", [("cache", KVCache), ("get_sequence", Callable), ("insert_sequences", Callable)] ) -is_type = lambda x, cls: (type(x).__name__ == cls.__name__) and (type(x).__module__ == cls.__module__) +logger = logging.getLogger("serving_jax") +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter("%(levelname)s:%(filename)s:%(lineno)d: %(message)s")) +logger.handlers = [handler] +logger.setLevel("INFO") +DEBUG, INFO, WARN = logger.debug, logger.info, logger.warning ######################################################################################################################## # device put for cross-process/hosts transfers ######################################################################### @@ -49,6 +65,8 @@ def unsafe_device_put(xs: PyTree, spec: PyTree, dest_mesh: Mesh): """Fastest, but local single-process JAX only for now.""" + from jax._src.lib import xla_client as xc + xs_flat, xs_struct = jax.tree.flatten(xs) shardings_list = [NamedSharding(dest_mesh, s) for s in jax.tree.leaves(spec)] 
devices_list = [s._internal_device_list for s in shardings_list] @@ -391,14 +409,14 @@ def _get_client(): @staticmethod def barrier(key: str, current_it: int) -> None: client = SyncServer._get_client() - if client is None: + if client is None or jax.process_count() == 1: return client.wait_at_barrier(key + str(current_it), timeout_in_ms=SyncServer.TIMEOUT_SEC * 1000) @staticmethod def broadcast(key: str, current_it: int, value: Any, is_source: bool = False, jsonify: bool = True) -> None: client = SyncServer._get_client() - if client is None: + if client is None or jax.process_count() == 1: return value if is_source: client.key_value_set(key + str(current_it), json.dumps(value) if jsonify else value) @@ -477,7 +495,7 @@ class PrefillWork: def return_request(resp: DecodeResult): # an optional callback called with results available on decode nodes only # something happens here to output the response to the global queue - # print(f"Finished request: {resp.id}") + # INFO(f"Finished request: {resp.id}") pass @@ -527,62 +545,65 @@ def __init__( cfg: Config, forward_fn: Callable, prefill_weights: Weights, - prefill_cache: KVCache, + prefill_cache_wrapper: AttentionWrapper, decode_weights: Weights, - decode_cache: KVCache, + decode_cache_wrapper: AttentionWrapper, is_server: bool = False, ): - # self.init_cache = init_cache - self.prefill_cache = prefill_cache if not SyncServer.broadcast("welcome", 0, is_server, is_server): raise ValueError("Neither this process nor any other process is the main server; at least one must be.") self.serve_cfg, self.cfg = serve_cfg, cfg - # setup decode + # setup decode ################################################################################################# self.forward_fn, self.decode_weights = forward_fn, decode_weights self.decode_mesh = [x for x in jax.tree.leaves(decode_weights) if hasattr(x, "sharding")][0].sharding.mesh + self.decode_work = DecodeWork( + None, decode_cache_wrapper.cache, [None for _ in range(serve_cfg.decode_batch_size)] + ) with set_mesh(self.decode_mesh): - self.decode_work = DecodeWork(None, decode_cache, [None for _ in range(serve_cfg.decode_batch_size)]) - self.decode_work.curr_tokens = jax.device_put( - jnp.zeros((serve_cfg.decode_batch_size, 1), dtype=jnp.int32), P() - ) - self.multistep_decode_fn = _make_multistep_decode_fn(self.forward_fn) - self._update_index = jax.jit( - lambda x, i, new: x.at[i, ...].set(new[:, None], mode="drop", out_sharding=jax.typeof(x).sharding) + self.decode_work.curr_tokens = jax.device_put(jnp.zeros((serve_cfg.decode_batch_size, 1), dtype=int), P()) + self.multistep_decode_fn = maybe_call(_make_multistep_decode_fn(self.forward_fn), self.decode_mesh) + _update_tokens = lambda x, i, new: x.at[i, ...].set(new[:, None], mode="drop") + self._update_tokens = jax.jit( + lambda x, i, new: auto_axes(_update_tokens, out_sharding=jax.typeof(x).sharding)(x, i, new) ) + self._get_decode_cache_entry = jax.jit(decode_cache_wrapper.get_sequence) + self.decode_output = (None, None) def _update_cache_and_index(cache: KVCache, curr_tokens: jax.Array, new_tokens, kvs, batch_idxs, actual_lens): # sort by length to minimize the number of compiled shape variants length_sort = sorted(range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2]) sorted_args = [[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)] - new_cache = decode_cache.insert_sequences(cache, *sorted_args) + new_cache = decode_cache_wrapper.insert_sequences(cache, *sorted_args) with set_mesh(self.decode_mesh): - new_curr_tokens = self._update_index(curr_tokens, np.array(batch_idxs), np.array(new_tokens)) + new_curr_tokens = self._update_tokens(curr_tokens, 
np.array(batch_idxs), np.array(new_tokens)) return new_cache, new_curr_tokens self._update_cache_and_index = _update_cache_and_index - self.decode_output = (None, None) - # setup prefill + # setup prefill ################################################################################################ self.prefill_weights = prefill_weights + self.prefill_cache = prefill_cache_wrapper.cache self.prefill_mesh = [x for x in jax.tree.leaves(prefill_weights) if hasattr(x, "sharding")][0].sharding.mesh self.prefill_work = PrefillWork([], [], []) self._get_index = jax.jit(lambda z, idx: jax.tree.map(lambda x: x[:, idx, ...], z)) - self._get_cache_entry = jax.jit(self.decode_work.cache.get_sequence) + self.prefill_fn = maybe_call(self.forward_fn, self.prefill_mesh) + self._get_prefill_cache_entry = maybe_call(jax.jit(prefill_cache_wrapper.get_sequence), self.prefill_mesh) + self._prefill_insert_sequences = maybe_call(prefill_cache_wrapper.insert_sequences, self.prefill_mesh) - # setup misc + # setup misc ################################################################################################### self.pending_requests, self.state_lock, self.results = [], threading.Lock(), {} self.pad_id, self.eos_tokens = 0, np.array(serve_cfg.eos_tokens) self._background = ThreadPoolExecutor(max_workers=1024) - - # setup profiling self.profile_start_time, self.profiling = -1, False - # setup cache management + # setup prefix cache management ################################################################################ self.prefix_cache, self._retrieve_prefix, self._insert_prefix = None, None, None - self.new_prefix_cache() + self.prefix_cache = TrieNode(None, lock=threading.Lock()) + self._retrieve_prefix = partial(retrieve_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) + self._insert_prefix = partial(insert_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) - # setup the sync server for multi-host + # setup the sync server for multi-host ######################################################################### self._it, self.roles = 0, (("server",) if is_server else ()) # main server if any(d.id in [d_.id for d_ in self.decode_mesh.devices.reshape(-1)] for d in jax.local_devices()): self.roles += ("decode",) # any node which has decode mesh devices @@ -629,15 +650,14 @@ def decode_step(self): output_tokens, output_mapping = [], [] with set_mesh(self.decode_mesh): config = dict(cfg=self.cfg, steps=self.serve_cfg.decode_steps) - decode_fn = maybe_call(self.multistep_decode_fn, self.decode_mesh) - (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = decode_fn( + (self.decode_work.curr_tokens, self.decode_work.cache), output_tokens = self.multistep_decode_fn( self.decode_work.curr_tokens, self.decode_weights, self.decode_work.cache, **config ) output_mapping = [ [getattr(result, "id", -1) for result in self.decode_work.active_results] ] * self.serve_cfg.decode_steps output_mapping = np.array(output_mapping).T - print(f"Decoding with fill rate: {np.mean([result is not None for result in self.decode_work.active_results])}") + INFO(f"Decoding with fill rate: {np.mean([result is not None for result in self.decode_work.active_results])}") # 3. 
parse output tokens from the previous decoding loop, allowing time for the tokens to arrive (delayed EOS detection) self.decode_output, (output_tokens, output_mapping) = (output_tokens, output_mapping), self.decode_output @@ -674,7 +694,7 @@ def decode_step(self): result.done, self.decode_work.active_results[i] = True, None if self.serve_cfg.use_prefix_cache: # store the results in the prefix cache buffer store sequence = np.array(result.token_list) - cache_entry, _ = self._get_cache_entry(self.decode_work.cache, i) + cache_entry, _ = self._get_decode_cache_entry(self.decode_work.cache, i) ns = math.ceil(sequence.size / self.serve_cfg.prefix_chunk_size) buffer_ids = BUFFER_STORE._get_unique_buffer_ids(ns) visited_ids, store_ids, del_ids = self._insert_prefix(sequence, buffer_ids) @@ -700,8 +720,6 @@ def prefill_step(self): inputs = np.pad(inputs, ((0, row_pad), (0, col_pad)), mode="constant", constant_values=self.pad_id) with set_mesh(self.prefill_mesh): - #actual_cache_len = np.array(max(job.match_len for job in prefill_input), dtype=np.int32) - #self.prefill_cache.iter = actual_cache_len # TODO: make this explictly cache public interface kvs = [job.cache_entry() if job.cache_entry is not None else None for job in prefill_input] batch_idxs = np.array([i for i, kv in enumerate(kvs) if kv is not None]) actual_lens = np.array([job.match_len for kv, job in zip(kvs, prefill_input) if kv is not None]) @@ -710,19 +728,16 @@ def prefill_step(self): # sort by length to minimize the number of compiled shape variants length_sort = sorted(range(len(kvs)), key=lambda i: jax.tree.leaves(kvs[i])[0].shape[-2]) sorted_args = [[x[i] for i in length_sort] for x in (kvs, batch_idxs, actual_lens)] - insert_sequences = maybe_call(self.prefill_cache.insert_sequences, self.prefill_mesh) - self.prefill_cache = insert_sequences(self.prefill_cache, *sorted_args, erase=True) + self.prefill_cache = self._prefill_insert_sequences(self.prefill_cache, *sorted_args, erase=True) cfg = dataclasses.replace(self.cfg, mesh=self.prefill_mesh) - forward_fn = maybe_call(self.forward_fn, self.prefill_mesh) - _, self.prefill_cache = forward_fn(inputs, self.prefill_weights, self.prefill_cache, cfg) + _, self.prefill_cache = self.prefill_fn(inputs, self.prefill_weights, self.prefill_cache, cfg) with set_mesh(self.prefill_mesh): for i, job in enumerate(prefill_input): - request = job.request - cache_entry, _ = maybe_call(self._get_cache_entry, self.prefill_mesh)(self.prefill_cache, i) + request, sequence = job.request, np.array(job.request.text) + cache_entry, _ = self._get_prefill_cache_entry(self.prefill_cache, i) cache_entry = _ensure_all_args_on_mesh(cache_entry, self.decode_mesh) - sequence = np.array(request.text) new_decode = PrefillResult( request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1 ) @@ -741,10 +756,10 @@ def prefill_step(self): cache_entry = partial(_concat, _ensure_all_args_on_mesh(buffers, mesh=self.decode_mesh), _axis) new_decode = PrefillResult(request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1) self.prefill_work.to_decode.append(new_decode) - print(f"Found a full match") + INFO("Found a full match") else: - print(f"Need to prefill, only found a match for length {total_match / (len(request.text) - 1):.2%}") - print(f"That equals {len(buffer_ids)} buffers or {total_match=}") + INFO(f"Need to prefill, only found a match for length {total_match / (len(request.text) - 1):.2%}") + INFO(f"That equals {len(buffer_ids)} buffers or {total_match=}") if total_match > 0: cache_entry = partial(_concat, 
_ensure_all_args_on_mesh(buffers, mesh=self.prefill_mesh), _axis) else: @@ -759,12 +774,12 @@ def serving_step(self): # event loop relies on determinism for multi-host/proce if should_start_profile: self.profile_start_time, self.profiling = time.perf_counter(), True jax.profiler.start_trace("/tmp/online") - print("STARTING TRACE") + DEBUG("STARTING TRACE") should_stop_profile = self.profile_start_time > 0 and time.perf_counter() - self.profile_start_time > 5.0 should_stop_profile = SyncServer.broadcast("stop_profile", self._it, should_stop_profile, is_source=is_server) if should_stop_profile: self.profile_start_time, self.profiling = -1, False - print("STOPPING TRACE") + DEBUG("STOPPING TRACE") jax.profiler.stop_trace() # potentially profile when received the request to ######################################### @@ -787,6 +802,7 @@ def serving_step(self): # event loop relies on determinism for multi-host/proce # main event loop work ##################################################################### self.decode_step() self.prefill_step() + [handler.flush() for handler in logger.handlers] # main event loop work ##################################################################### # offload buffers to keep a max of N ####################################################### @@ -796,7 +812,7 @@ def serving_step(self): # event loop relies on determinism for multi-host/proce refs_to_delete = sorted(BUFFER_STORE.usecount.keys())[:extra_buffer_count] deleted_buffers = remove_prefix_nodes(self.prefix_cache, refs_to_delete) BUFFER_STORE.delete(list(deleted_buffers)) - if len(BUFFER_STORE._store) > self.serve_cfg.max_buffers: + if len(BUFFER_STORE._store) > self.serve_cfg.max_buffers: # DEBUG raise ValueError() # offload buffers to keep a max of N ####################################################### @@ -808,7 +824,18 @@ def update_params(self, params: dict[str, Any]): with self.state_lock: self.serve_cfg = dataclasses.replace(self.serve_cfg, **params) - def new_prefix_cache(self): - self.prefix_cache = TrieNode(None, lock=threading.Lock()) - self._retrieve_prefix = partial(retrieve_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) - self._insert_prefix = partial(insert_prefix, self.prefix_cache, chunk_size=self.serve_cfg.prefix_chunk_size) + def serve_forever(self, shutdown_signal: threading.Event): + def serve_thread(): + try: + while not shutdown_signal.is_set(): + self.serving_step() + except Exception as e: + WARN(traceback.format_exc()) + WARN(f"Exception {e}") + finally: + shutdown_signal.set() + INFO("Received a shutdown signal") + INFO("Exiting the serving loop") + + serving_thread = threading.Thread(target=serve_thread) + serving_thread.start() From ce328b98132de130bba6e5837395488ea9e44d38 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 19 Aug 2025 18:20:08 -0700 Subject: [PATCH 10/11] updating readme and ci from main --- .github/workflows/tests.yaml | 2 +- README.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index d97960b..5b118cf 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -32,7 +32,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - model: ["deepseek_r1_jax", "kimi_k2", "llama3", "llama4", "qwen3"] + model: ["deepseek_r1_jax", "kimi_k2", "llama3", "llama4", "qwen3", "gpt_oss"] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1 diff 
--git a/README.md b/README.md index f853a77..f843e64 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ Current contents include: * [Llama 3](llama3/) * [Qwen 3](qwen3/) * [Kimi K2](kimi_k2/) +* [OpenAI GPT OSS](gpt_oss/) --- From 9bad388d7220f00d3782f621dfc90f5d5f6d2736 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 2 Sep 2025 14:15:15 -0700 Subject: [PATCH 11/11] TPU workaround --- serving/serving_jax/serving_loop.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/serving/serving_jax/serving_loop.py b/serving/serving_jax/serving_loop.py index 35b7924..ed50047 100644 --- a/serving/serving_jax/serving_loop.py +++ b/serving/serving_jax/serving_loop.py @@ -509,6 +509,7 @@ def return_request(resp: DecodeResult): _make_empty = lambda x, mesh: jax.make_array_from_single_device_arrays( x.shape, NamedSharding(mesh, jax.typeof(x).sharding.spec), [], dtype=x.dtype ) +which_platform = lambda cfg: cfg.mesh.devices.reshape(-1)[0].platform def maybe_call(fn: Callable, mesh: Mesh): @@ -752,6 +753,9 @@ def prefill_step(self): BUFFER_STORE.mark_visited(visited_ids) _axis = self.serve_cfg.time_axis - 1 + 1 # batch missing (-1) layers concatenated (+1) buffers = BUFFER_STORE.load(buffer_ids) + if which_platform(self.prefill_mesh) == "tpu" and total_match == sequence.size: + # skip full match on TPU, temporary workaround to ensure buffer consistency + total_match = max(sequence.size - 1, 0) if total_match == sequence.size: cache_entry = partial(_concat, _ensure_all_args_on_mesh(buffers, mesh=self.decode_mesh), _axis) new_decode = PrefillResult(request.id, sequence, request.text[-1], cache_entry, len(request.text) - 1)
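(A closing note on the cache-insertion helpers this series keeps touching: each batched update is padded up to a power-of-two size with a floor of 4 — pad_len = max(next_power_of_2(len(kvs)), 4) - len(kvs) — the padding entries repeat the last element, and update_mask marks which entries are real. This bounds how many distinct shapes the jitted insert functions ever see, so recompilation stops after a handful of variants. Below is a small sketch of that arithmetic, reusing next_power_of_2 from attention_cache_utils.py; padded_update_size is an illustrative helper, not part of the patch.)

import math

next_power_of_2 = lambda x: 2 ** math.ceil(math.log2(max(x, 1)))

def padded_update_size(n_seqs: int) -> int:
    # round the update batch up to a power of two, with a minimum of 4,
    # mirroring the pad_len computation in the kvcache/paged_kvcache insert helpers
    return max(next_power_of_2(n_seqs), 4)

# 1..4 requests -> size 4, 5..8 -> 8, 9..16 -> 16: only O(log batch) jit specializations in total
assert [padded_update_size(n) for n in (1, 3, 4, 5, 8, 9)] == [4, 4, 4, 8, 8, 16]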