[Feature] optimize prefix cache #3107

Open. Wants to merge 18 commits into base: develop.

212 changes: 186 additions & 26 deletions fastdeploy/cache_manager/cache_messager.py

Large diffs are not rendered by default.

107 changes: 23 additions & 84 deletions fastdeploy/cache_manager/cache_transfer_manager.py
@@ -28,7 +28,7 @@
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
)
from fastdeploy.utils import get_logger
@@ -39,26 +39,12 @@ def parse_args():
Parse arguments from the command line
"""
parser = argparse.ArgumentParser("Cache transfer manager")
parser.add_argument(
"--splitwise_role",
type=str,
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument(
"--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only surport ipc now",
)
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument(
@@ -68,7 +54,6 @@ def parse_args():
help="engine worker queue port",
)
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")

parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
@@ -109,7 +94,6 @@ def __init__(self, args):

device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
self.gpu_cache_kvs = {}
self.cpu_cache_kvs = {}
self.gpu_cache_k_tensors = []
@@ -138,40 +122,27 @@ def __init__(self, args):
self.num_cpu_blocks = args.num_cpu_blocks

cache_type = args.cache_dtype
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks

self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
cache_shape = [
args.num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
]

for i in range(args.num_hidden_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else self.num_extra_layer_gpu_blocks
cache_shape[0] = num_gpu_blocks
key_name = f"key_caches_{i}_rank{rank}.device{device}"
value_name = f"value_caches_{i}_rank{rank}.device{device}"
key_cache = paddle.empty(shape=[], dtype=cache_type)
value_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data(key_cache, key_name, cache_shape)
value_cache = share_external_data(value_cache, value_name, cache_shape)
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[value_name] = value_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])

set_data_ipc(
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
@@ -180,7 +151,7 @@ def __init__(self, args):
paddle.set_device("cpu")
self.k_dst_ptrs = []
self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
for i in range(args.num_hidden_layers + self.num_extra_layers):
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
args.num_cpu_blocks * args.bytes_per_layer_per_block
)
@@ -190,38 +161,6 @@ def __init__(self, args):
)
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])

cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False,
)
self.cache_ready_signal.value[self.rank] = 1

paddle.set_device(f"gpu:{device}")
if args.enable_splitwise:
logger.debug("create cache messager...")
logger.info(f"{args}")
from fastdeploy.cache_manager.cache_messager import CacheMessager

self.cache_messager = CacheMessager(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
pod_ip=args.pod_ip,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=self.gpu_cache_kvs,
rank=self.rank,
nranks=args.mp_num,
num_layers=args.num_layers + self.num_extra_layers,
gpu_id=self.device,
rdma_port=args.rdma_port,
)
logger.info("successfully create cache messager")
logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")

cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
name="cache_task_broadcast_signal",
104 changes: 84 additions & 20 deletions fastdeploy/cache_manager/prefix_cache_manager.py
@@ -141,6 +141,20 @@ def launch_cache_manager(
filename = "cache_transfer_manager.py"
py_path = os.path.join(current_dir_path, filename)

cache_messager_processes = []
if self.splitwise_role != "mixed":
cache_messager_processes = self.launch_cache_messager(
cache_config,
tensor_parallel_size,
device_ids,
pod_ip,
engine_worker_queue_port,
pid_suffix,
)
if cache_messager_processes is None:
raise RuntimeError("Launch cache messager failed")
return []

if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and hasattr(cache_config.model_cfg, "num_key_value_heads")
@@ -151,50 +165,32 @@ def launch_cache_manager(
else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size

cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=pid_suffix,
create=True,
)
log_dir = envs.FD_LOG_DIR
cache_manager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" {sys.executable} {py_path}"
f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --head_dim {cache_config.model_cfg.head_dim}"
+ f" --kv_num_head {kv_num_head}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_gpu_blocks {cache_config.total_block_num}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
+ f" --block_size {cache_config.block_size}"
+ f" --engine_pid {pid_suffix}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
# Wait for the cache to finish initializing
logger.info("Waiting for cache transfer manager ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
@@ -204,8 +200,76 @@ def launch_cache_manager(
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
self._enable_cpu_cache()
cache_manager_processes.extend(cache_messager_processes)
return cache_manager_processes

def launch_cache_messager(
self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
):
"""
launch_cache_messager function used to initialize the cache messager.
"""
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
filename = "cache_messager.py"
if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and hasattr(cache_config.model_cfg, "num_key_value_heads")
and cache_config.model_cfg.num_key_value_heads is not None
and int(cache_config.model_cfg.num_key_value_heads) > 0
):
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size

cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=pid_suffix,
create=True,
)

py_path = os.path.join(current_dir_path, filename)
log_dir = envs.FD_LOG_DIR
cache_messager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --head_dim {cache_config.model_cfg.head_dim}"
+ f" --kv_num_head {kv_num_head}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_gpu_blocks {cache_config.total_block_num}"
+ f" --block_size {cache_config.block_size}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --engine_pid {pid_suffix}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
logger.info("Waiting for cache ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_messager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache messager successful")
else:
logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
cache_messager_processes = None
return cache_messager_processes

def update_cache_config(self, cache_config):
"""
update cache config
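
Both launch_cache_manager and the new launch_cache_messager rely on the same cache_ready_signal handshake: the launcher creates the shared array (create=True) sized by tensor_parallel_size, each spawned process opens it (create=False) and sets its own rank slot to 1 once its caches are ready, and the launcher polls until every slot is set, then checks that the most recently launched child is still alive. A hedged sketch of the two sides; the child-side half mirrors the code removed from cache_transfer_manager.py and presumably now lives in cache_messager.py, whose diff is not rendered above:

```python
import time

import numpy as np

from fastdeploy.inter_communicator import IPCSignal


def create_ready_signal(tensor_parallel_size, pid_suffix):
    """Launcher side: create the shared readiness array before spawning children."""
    data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
    return IPCSignal(name="cache_ready_signal", array=data, dtype=np.int32,
                     suffix=pid_suffix, create=True)


def wait_until_ready(signal, processes, tensor_parallel_size):
    """Launcher side: block until every rank has reported ready, then check liveness."""
    while np.sum(signal.value) != tensor_parallel_size:
        time.sleep(1)
    return processes[-1].poll() is None  # None means the last child is still running


def mark_ready(rank, mp_num, engine_pid):
    """Child side (per rank): open the existing signal and flag this rank as ready."""
    data = np.zeros(shape=[mp_num], dtype=np.int32)
    signal = IPCSignal(name="cache_ready_signal", array=data, dtype=np.int32,
                       suffix=engine_pid, create=False)
    signal.value[rank] = 1
```
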
1 change: 1 addition & 0 deletions fastdeploy/engine/args_utils.py
@@ -820,6 +820,7 @@ def create_scheduler_config(self) -> SchedulerConfig:
"max_num_partial_prefills",
"max_long_partial_prefills",
"long_prefill_token_threshold",
"splitwise_role"
]

all = asdict(self)
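
The args_utils.py change adds splitwise_role to the list of fields that create_scheduler_config forwards to the scheduler. The hunk only shows the key list and `all = asdict(self)`, so the sketch below is a hypothetical illustration of the presumable key-based filter; the helper name extract_scheduler_params is not from the source:

```python
from dataclasses import asdict

SCHEDULER_PARAM_KEYS = [
    "max_num_partial_prefills",
    "max_long_partial_prefills",
    "long_prefill_token_threshold",
    "splitwise_role",  # newly forwarded so the scheduler knows the prefill/decode role
]


def extract_scheduler_params(engine_args, keys=SCHEDULER_PARAM_KEYS):
    """Hypothetical helper: pick only the scheduler-related fields from the engine args."""
    all_fields = asdict(engine_args)
    return {k: all_fields[k] for k in keys if k in all_fields}
```
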