diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 409941f7d8..cee0288ea0 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -143,12 +143,18 @@ def __init__( self.gpu_id = gpu_id self.cache_info = dict() - self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks + self.rank_id = ( + self.rank + local_data_parallel_id * self.nranks + ) # align with engine worker rank (paddle.distributed.launch) layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread) layerwise_send_cache_thread.daemon = True layerwise_send_cache_thread.start() + connect_rdma_thread = threading.Thread(target=self._handle_connect_task) + connect_rdma_thread.daemon = True + connect_rdma_thread.start() + logger.info(f"cache messager init finished, use {transfer_protocol}") def _prefill_layerwise_send_cache_thread(self): @@ -161,14 +167,14 @@ def _prefill_layerwise_send_cache_thread(self): prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32) try: step_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}", + name=f"splitwise_complete_prefilled_step_{self.rank_id}", array=prefilled_step_idx_data, dtype=np.int32, suffix=self.gpu_id, create=True, ) layer_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}", + name=f"splitwise_complete_prefilled_layer_{self.rank_id}", array=prefilled_layer_idx_data, dtype=np.int32, suffix=self.gpu_id, @@ -176,14 +182,14 @@ def _prefill_layerwise_send_cache_thread(self): ) except: step_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}", + name=f"splitwise_complete_prefilled_step_{self.rank_id}", array=prefilled_step_idx_data, dtype=np.int32, suffix=self.gpu_id, create=False, ) layer_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}", + 
name=f"splitwise_complete_prefilled_layer_{self.rank_id}", array=prefilled_layer_idx_data, dtype=np.int32, suffix=self.gpu_id, @@ -196,6 +202,9 @@ def _prefill_layerwise_send_cache_thread(self): self.last_step_idx = -1 self.last_layer_idx = -1 # int32 + max_step_idx = 100003 + engine_recycled_count = 0 + while True: cache_info = self.engine_worker_queue.get_cache_info() @@ -215,7 +224,6 @@ def _prefill_layerwise_send_cache_thread(self): current_info["status"] = "init" logger.info(f"start cache_infos: {current_info}") self.cache_info[info["request_id"]] = current_info - self.last_step_idx = min(self.last_step_idx, current_info["current_id"]) else: self.cache_info[info["request_id"]] = info prefilled_layer_idx = layer_shm_value.value[0] @@ -231,7 +239,17 @@ def _prefill_layerwise_send_cache_thread(self): if not self.cache_info: time.sleep(0.001) continue - logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}") + if self.last_step_idx > prefilled_step_idx: + engine_recycled_count += 1 + self.last_step_idx = prefilled_step_idx # only copy value read from shm memory + prefilled_step_idx = ( + prefilled_step_idx + max_step_idx * engine_recycled_count + ) # remap prefilled_step_idx for comparison + + logger.debug( + f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx}," + f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}" + ) for req_id, item in list(self.cache_info.items()): if "status" not in item: continue @@ -305,9 +323,26 @@ def _prefill_layerwise_send_cache_thread(self): self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")]) logger.info(f"put write cache {item['request_id']}") del self.cache_info[req_id] - - self.last_step_idx = prefilled_step_idx - self.last_layer_idx = prefilled_layer_idx + self.last_layer_idx = prefilled_layer_idx except Exception as e: logger.error(f"prefill layerwise send cache thread has 
exception: {e}, {str(traceback.format_exc())}") + + def _handle_connect_task(self): + while True: + try: + task = self.engine_worker_queue.get_connect_rdma_task() + if task is None: + time.sleep(0.001) + continue + logger.info(f"_handle_connect_task recv task: {task}") + task_id = task["task_id"] + ip, rdma_port = task["ip"], task["rdma_port"] + status = self.messager["rdma"].connect(ip, rdma_port) + if not status: + response = {"task_id": task_id, "success": False} + else: + response = {"task_id": task_id, "success": True} + self.engine_worker_queue.put_connect_rdma_task_response(response) + except Exception as e: + logger.error(f"handle_connect_task has exception: {e}") diff --git a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py index 94abbb3b8e..6a0c0ac36e 100644 --- a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py +++ b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py @@ -61,18 +61,12 @@ def connect(self, ip, port): Connect to remote gpu and write cache. 
""" assert self.splitwise_role == "prefill", "only prefill can call this method" - addr = f"{ip}:{port!s}" - if addr in self.connected_rdma: - return True ret = self.messager.is_connected(ip, str(port)) if ret: - self.connected_rdma.add(addr) return True ret = self.messager.connect(ip, str(port)) logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}") - if ret == 0: - self.connected_rdma.add(addr) return ret == 0 def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx): diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 359ad6ba63..b83b7a3e22 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -884,6 +884,7 @@ def create_scheduler_config(self) -> SchedulerConfig: "max_num_partial_prefills", "max_long_partial_prefills", "long_prefill_token_threshold", + "splitwise_role", ] all = asdict(self) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 49f776577e..a308899298 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -47,12 +47,14 @@ EngineCacheQueue, EngineWorkerQueue, IPCSignal, - ZmqClient, + ZmqIpcServer, + ZmqTcpServer, ) from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.trace_util import start_span, start_span_request from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor +from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, console_logger, envs, llm_logger @@ -180,9 +182,19 @@ def start(self, api_server_pid=None): self.data_processor = self.input_processor.create_processor() if api_server_pid is not None: - self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL) - self.zmq_server.start_server() - self.zmq_server.create_router() + if 
envs.FD_ENABLE_INTERNAL_ADAPTER: + self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL) + self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER) + self.external_adapter = InternalAdapter( + cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node + ) + else: + self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) + self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) + self.recv_result_handle_thread = threading.Thread( + target=self.send_response_server.recv_result_handle, daemon=True + ) + self.recv_result_handle_thread.start() time.sleep(3) if self.do_profile == 0 and ( @@ -259,7 +271,7 @@ def _zmq_send_generated_tokens(self): time.sleep(0.005) continue for request_id, contents in results.items(): - self.zmq_server.send_multipart(request_id, contents) + self.send_response_server.send_response(request_id, contents) except Exception as e: llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") @@ -276,7 +288,7 @@ def _insert_task_to_worker(self): Insert task to engine thread, monitor scheduler request queue. 
if the engine has resource, insert task to engine """ - current_id = -1 + current_id = 0 while self.running: try: if self.resource_manager.available_batch() == 0: @@ -314,12 +326,15 @@ def _insert_task_to_worker(self): time.sleep(0.001) continue - current_id = (current_id + 1) % 100003 if self.cfg.splitwise_role != "mixed": llm_logger.info("Inserting splitwise tasks") self.split_connector.send_splitwise_tasks(tasks, current_id) - self.insert_tasks(tasks, current_id) + insert_successful = self.insert_tasks(tasks, current_id) + if insert_successful: + current_id = current_id + 1 + else: + continue main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks)) @@ -383,14 +398,18 @@ def _insert_zmq_task_to_scheduler(self): if self.api_server_pid is None: return + if envs.FD_ENABLE_INTERNAL_ADAPTER: + if self.cfg.splitwise_role == "decode": + return + added_requests: Dict[str, int] = dict() while self.running: try: block = True if len(added_requests) == 0 else False if not self.cfg.model_config.enable_mm: - err, data = self.zmq_server.receive_json_once(block) + err, data = self.recv_request_server.receive_json_once(block) else: - err, data = self.zmq_server.receive_pyobj_once(block) + err, data = self.recv_request_server.receive_pyobj_once(block) if err is not None: llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}") break @@ -438,7 +457,7 @@ def _insert_zmq_task_to_scheduler(self): ) # Since the request is not in scheduler # Send result by zmq directly - self.zmq_server.send_multipart(request_id, error_result) + self.send_response_server.send_response(request_id, [error_result]) except Exception as e: llm_logger.error( f"Error happend while receving new request from zmq, details={e}, " @@ -1003,8 +1022,12 @@ def _exit_sub_services(self): console_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}") self.engine_worker_queue.cleanup() - if hasattr(self, 
"zmq_server") and self.zmq_server is not None: - self.zmq_server.close() + if hasattr(self, "send_response_server") and self.send_response_server is not None: + self.send_response_server.close() + if hasattr(self, "recv_request_server") and self.recv_request_server is not None: + self.recv_request_server.close() + if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None: + self.recv_control_cmd_server.close() if hasattr(self, "dp_processed"): for p in self.dp_processed: p.join() diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 3b1e28c5df..9fcfb22c23 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -16,6 +16,7 @@ from __future__ import annotations +import copy import os import signal import threading @@ -29,8 +30,9 @@ from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.output.token_processor import TokenProcessor +from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector -from fastdeploy.utils import EngineError, console_logger, llm_logger +from fastdeploy.utils import EngineError, console_logger, envs, llm_logger class ExpertService: @@ -59,7 +61,7 @@ def __init__(self, cfg, local_data_parallel_id): self.cfg.disaggregate_info = None self.scheduler = cfg.scheduler_config.scheduler() - if cfg.splitwise_role != "mixed": + if self.cfg.scheduler_config.name == "splitwise": self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id @@ -111,8 +113,12 @@ def __init__(self, cfg, local_data_parallel_id): ) self._finalizer = weakref.finalize(self, self._exit_sub_services) + if envs.FD_ENABLE_INTERNAL_ADAPTER: + self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, 
dp_rank=local_data_parallel_id) - def start(self, ipc_signal_suffix, local_data_parallel_id): + def start( + self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None + ): """ Initializes the engine and starts its sub-services. If `api_server_pid` is defined, will launch a thread @@ -145,7 +151,11 @@ def start(self, ipc_signal_suffix, local_data_parallel_id): role = self.cfg.splitwise_role host_ip = self.cfg.host_ip disaggregate = self.cfg.disaggregate_info - self.scheduler.start(role, host_ip, disaggregate) + if self.cfg.scheduler_config.name == "dp": + assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None) + self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc) + elif self.cfg.scheduler_config.name == "splitwise": + self.scheduler.start(role, host_ip, disaggregate) self.cfg.print() launched_expert_service_signal_data = np.zeros( @@ -171,7 +181,7 @@ def _insert_task_to_worker(self): Insert task to engine thread, monitor scheduler request queue. 
if the engine has resource, insert task to engine """ - current_id = -1 + current_id = 0 while True: try: if self.resource_manager.available_batch() == 0: @@ -206,9 +216,11 @@ def _insert_task_to_worker(self): llm_logger.info("Inserting splitwise tasks") self.split_connector.send_splitwise_tasks(tasks, current_id) - current_id = (current_id + 1) % 100003 - - self.insert_tasks(tasks, current_id) + insert_successful = self.insert_tasks(tasks, current_id) + if insert_successful: + current_id = current_id + 1 + else: + continue main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks)) @@ -283,6 +295,9 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False): cur_task_idx = self.resource_manager.req_dict[task.request_id] del self.resource_manager.req_dict[task.request_id] cur_task = self.resource_manager.tasks_list[cur_task_idx] + cur_task.prompt_token_ids[0] = task.outputs.token_ids[0] + if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode": + cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids) if task.error_code != 200: self.resource_manager.stop_flags[cur_task_idx] = True self.resource_manager.tasks_list[cur_task_idx] = None @@ -369,13 +384,17 @@ def _exit_sub_services(self): self.zmq_server.close() -def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix): +def start_expert_service( + cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None +): """ Start expert service """ expert_service = ExpertService(cfg, local_data_parallel_id) try: - expert_service.start(ipc_signal_suffix, local_data_parallel_id) + expert_service.start( + ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc + ) expert_service.split_connector.start_receiver() except Exception as e: llm_logger.exception(f"Expert service failed to start: {e}, 
{str(traceback.format_exc())}") diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 67c0caa08f..aa17471ce1 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -72,6 +72,7 @@ def __init__( guided_json_object: Optional[bool] = None, enable_thinking: Optional[bool] = True, trace_carrier: dict = dict(), + dp_rank: Optional[int] = None, chat_template: Optional[str] = None, ) -> None: self.request_id = request_id @@ -123,6 +124,7 @@ def __init__( self.task_type = RequestType.PREFILL self.idx = None self.need_prefill_tokens = self.prompt_token_ids_len + self.dp_rank = dp_rank @classmethod def from_dict(cls, d: dict): @@ -155,6 +157,7 @@ def from_dict(cls, d: dict): guided_json_object=d.get("guided_json_object", None), enable_thinking=d.get("enable_thinking", True), trace_carrier=d.get("trace_carrier", {}), + dp_rank=d.get("dp_rank", None), chat_template=d.get("chat_template", None), ) diff --git a/fastdeploy/engine/resource_manager.py b/fastdeploy/engine/resource_manager.py index aad0a624d9..b7e011a294 100644 --- a/fastdeploy/engine/resource_manager.py +++ b/fastdeploy/engine/resource_manager.py @@ -317,8 +317,8 @@ def _delete_cached_data(self, task, cached_len): Delete cached data from the task's prompt token ids based on the cached length. 
""" if cached_len == len(task.prompt_token_ids): - task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :] - task.seq_lens_decoder = cached_len - 1 + task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :] + task.seq_lens_decoder = cached_len - self.cfg.block_size else: task.prompt_token_ids = task.prompt_token_ids[cached_len:] task.seq_lens_decoder = cached_len diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index c407a76633..a1c1d76040 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -26,7 +26,7 @@ from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS from fastdeploy.input.preprocess import InputPreprocessor -from fastdeploy.inter_communicator import IPCSignal, ZmqClient +from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform @@ -102,7 +102,7 @@ def create_zmq_client(self, model, mode): """ Create a ZMQ client. """ - self.zmq_client = ZmqClient(model, mode) + self.zmq_client = ZmqIpcClient(model, mode) self.zmq_client.connect() def format_and_add_data(self, prompts: dict): diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 0155e260f0..c57e970e0f 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -80,6 +80,14 @@ "EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"), # enable kv cache block scheduler v1 (no need for kv_cache_ratio) "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")), + # enable internal module to access LLMEngine. 
+ "FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")), + # LLMEngine receive requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"), + # LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"), + # LLMEngine receive control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), # Whether to use PLUGINS. "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","), # set trace attribute job_id. diff --git a/fastdeploy/inter_communicator/__init__.py b/fastdeploy/inter_communicator/__init__.py index 0c1cc0d9fc..ea08af31a4 100644 --- a/fastdeploy/inter_communicator/__init__.py +++ b/fastdeploy/inter_communicator/__init__.py @@ -17,6 +17,7 @@ from .engine_cache_queue import EngineCacheQueue from .engine_worker_queue import EngineWorkerQueue from .ipc_signal import IPCSignal -from .zmq_client import ZmqClient +from .zmq_client import ZmqIpcClient +from .zmq_server import ZmqIpcServer, ZmqTcpServer -__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue"] +__all__ = ["ZmqIpcClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "ZmqTcpServer", "ZmqIpcServer"] diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index da88265a26..e216f430d2 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -85,12 +85,15 @@ class QueueManager(BaseManager): ] self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)] self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)] + 
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)] + self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)] self.client_read_info_flag_init: List[List[int]] = [ [1] * self.num_client for _ in range(self.local_data_parallel_size) ] self.lock_info_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] + self.connect_task_lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)] self.finish_request_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) @@ -112,11 +115,26 @@ class QueueManager(BaseManager): callable=lambda idx: self.lock_init[idx], proxytype=AcquirerProxy, ) + QueueManager.register( + "get_connect_task_lock", + callable=lambda idx: self.connect_task_lock_init[idx], + proxytype=AcquirerProxy, + ) QueueManager.register( "get_read_finish_flag", callable=lambda idx: self.read_finish_flag_init[idx], proxytype=ValueProxy, ) + QueueManager.register( + "get_connect_rdma_tasks", + callable=lambda idx: self.connect_rdma_tasks_list[idx], + proxytype=ListProxy + ) + QueueManager.register( + "get_connect_rdma_tasks_responses", + callable=lambda idx: self.connect_rdma_tasks_response_list[idx], + proxytype=ListProxy + ) QueueManager.register( "get_connected_client_counter", callable=lambda idx: self.connected_client_counter_init[idx], @@ -180,6 +198,9 @@ class QueueManager(BaseManager): QueueManager.register("get_disaggregate_requests") QueueManager.register("get_available_prefill_instances") QueueManager.register("get_finish_request_barrier") + QueueManager.register("get_connect_rdma_tasks") + QueueManager.register("get_connect_rdma_tasks_responses") + QueueManager.register("get_connect_task_lock") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() @@ -200,6 +221,13 @@ class QueueManager(BaseManager): 
self.available_prefill_instances = self.manager.get_available_prefill_instances() self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id) self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id) + # prefill/decode (p/d) interconnection queues + self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id) + self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses( + self.local_data_parallel_id + ) + self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id) + assert self.num_client == len(self.client_read_flag) if is_server: @@ -280,6 +308,45 @@ def num_tasks(self) -> int: total_num: int = len(self.tasks) self.lock.release() return total_num + + def put_connect_rdma_task(self, connect_rdma_task): + self.connect_task_lock.acquire() + self.connect_rdma_task_queue.append(connect_rdma_task) + self.connect_task_lock.release() + + def get_connect_rdma_task(self): + result = None + self.connect_task_lock.acquire() + if len(self.connect_rdma_task_queue) == 0: + self.connect_task_lock.release() + return result + try: + result = self.connect_rdma_task_queue.pop(0) + except Exception as e: + llm_logger.info(f"get_connect_rdma_task got exception: {e}") + finally: + self.connect_task_lock.release() + return result + + def put_connect_rdma_task_response(self, connect_rdma_task_response): + self.connect_task_lock.acquire() + self.connect_rdma_task_response_queue.append(connect_rdma_task_response) + self.connect_task_lock.release() + + def get_connect_rdma_task_response(self): + result = None + self.connect_task_lock.acquire() + if len(self.connect_rdma_task_response_queue) == 0: + self.connect_task_lock.release() + return result + try: + result = self.connect_rdma_task_response_queue.pop(0) + except Exception as e: + llm_logger.info(f"get_connect_rdma_task_response got exception: {e}") + finally: + self.connect_task_lock.release() + return result + 
def get_prefill_instances(self): """ diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 6affcd8e7a..13242f2a20 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -14,209 +14,78 @@ # limitations under the License. """ -import os -import threading -import time -import traceback +from abc import ABC, abstractmethod -import msgpack import zmq -from fastdeploy import envs -from fastdeploy.utils import llm_logger - -class ZmqClient: +class ZmqClientBase(ABC): """ - ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ. + ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ. """ - def __init__(self, name, mode): - self.context = zmq.Context(4) - self.socket = self.context.socket(mode) - self.file_name = f"/dev/shm/{name}.socket" - self.router_path = f"/dev/shm/router_{name}.ipc" + def __init__(self): + pass - self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) - self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + @abstractmethod + def _create_socket(self): + """Abstract method to create and return a ZeroMQ socket.""" + pass - self.mutex = threading.Lock() - self.req_dict = dict() - self.router = None - self.poller = None - self.running = True + def _ensure_socket(self): + """Ensure the socket is created before use.""" + if self.socket is None: + self.socket = self._create_socket() + @abstractmethod def connect(self): """ Connect to the server using the file name specified in the constructor. """ - self.socket.connect(f"ipc://{self.file_name}") - - def start_server(self): - """ - Start the server using the file name specified in the constructor. 
- """ - self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - self.socket.setsockopt(zmq.SNDTIMEO, -1) - self.socket.bind(f"ipc://{self.file_name}") - self.poller = zmq.Poller() - self.poller.register(self.socket, zmq.POLLIN) - - def create_router(self): - """ - Create a ROUTER socket and bind it to the specified router path. - """ - self.router = self.context.socket(zmq.ROUTER) - self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - self.router.setsockopt(zmq.ROUTER_MANDATORY, 1) - self.router.setsockopt(zmq.SNDTIMEO, -1) - self.router.bind(f"ipc://{self.router_path}") + pass def send_json(self, data): """ Send a JSON-serializable object over the socket. """ + self._ensure_socket() self.socket.send_json(data) def recv_json(self): """ Receive a JSON-serializable object from the socket. """ + self._ensure_socket() return self.socket.recv_json() def send_pyobj(self, data): """ Send a Pickle-serializable object over the socket. """ + self._ensure_socket() self.socket.send_pyobj(data) def recv_pyobj(self): """ Receive a Pickle-serializable object from the socket. """ + self._ensure_socket() return self.socket.recv_pyobj() - def pack_aggregated_data(self, data): - """ - Aggregate multiple responses into one and send them to the client. - """ - result = data[0] - if len(data) > 1: - for response in data[1:]: - result.add(response) - result = msgpack.packb([result.to_dict()]) - return result - - def send_multipart(self, req_id, data): - """ - Send a multipart message to the router socket. - """ - if self.router is None: - raise RuntimeError("Router socket not created. 
Call create_router() first.") - - while self.running: - with self.mutex: - if req_id not in self.req_dict: - try: - client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK) - req_id_str = request_id.decode("utf-8") - self.req_dict[req_id_str] = client - except zmq.Again: - time.sleep(0.001) - continue - else: - break - - if self.req_dict[req_id] == -1: - if data[-1].finished: - with self.mutex: - self.req_dict.pop(req_id, None) - return - try: - start_send = time.time() - if self.aggregate_send: - result = self.pack_aggregated_data(data) - else: - result = msgpack.packb([response.to_dict() for response in data]) - self.router.send_multipart([self.req_dict[req_id], b"", result]) - llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") - except zmq.ZMQError as e: - llm_logger.error(f"[{req_id}] zmq error: {e}") - self.req_dict[req_id] = -1 - except Exception as e: - llm_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}") - - if data[-1].finished: - with self.mutex: - self.req_dict.pop(req_id, None) - llm_logger.info(f"send_multipart finished, req_id: {req_id}") - - def receive_json_once(self, block=False): - """ - Receive a single message from the socket. - """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None - try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_json(flags=flags) - except zmq.Again: - return None, None - except Exception as e: - self.close() - llm_logger.warning(f"{e}, {str(traceback.format_exc())}") - return str(e), None - - def receive_pyobj_once(self, block=False): - """ - Receive a single message from the socket. 
- """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None - try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_pyobj(flags=flags) - except zmq.Again: - return None, None - except Exception as e: - self.close() - llm_logger.warning(f"{e}, {str(traceback.format_exc())}") - return str(e), None - - def _clear_ipc(self, name): - """ - Remove the IPC file with the given name. - """ - if os.path.exists(name): - try: - os.remove(name) - except OSError as e: - llm_logger.warning(f"Failed to remove IPC file {name} - {e}") - - def close(self): - """ - Close the socket and context, and remove the IPC files. - """ - if not self.running: - return - - self.running = False - llm_logger.info("Closing ZMQ connection...") - try: - if hasattr(self, "socket") and not self.socket.closed: - self.socket.close() - if self.router is not None and not self.router.closed: - self.router.close() - - if not self.context.closed: - self.context.term() +class ZmqIpcClient(ZmqClientBase): + def __init__(self, name, mode): + self.name = name + self.mode = mode + self.file_name = f"/dev/shm/{name}.socket" + self.context = zmq.Context() + self.socket = self.context.socket(self.mode) - self._clear_ipc(self.file_name) - self._clear_ipc(self.router_path) - except Exception as e: - llm_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}") - return + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.context = zmq.Context() + return self.context.socket(self.mode) - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + def connect(self): + self._ensure_socket() + self.socket.connect(f"ipc://{self.file_name}") diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py new file mode 100644 index 0000000000..ab97e3bbd4 --- /dev/null +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -0,0 +1,303 @@ +""" +# Copyright (c) 2025 PaddlePaddle 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import threading +import time +from abc import ABC, abstractmethod +from collections import defaultdict + +import msgpack +import zmq + +from fastdeploy import envs +from fastdeploy.utils import llm_logger + + +class ZmqServerBase(ABC): + """ + ZmqServerBase + """ + + def __init__(self): + self.cached_results = defaultdict(list) + self.response_token_lock = threading.Lock() + + @abstractmethod + def _create_socket(self): + """Abstract method to create and return a ZeroMQ socket.""" + pass + + def _ensure_socket(self): + """Ensure the socket is created before use.""" + if self.socket is None: + self.socket = self._create_socket() + + def pack_aggregated_data(self, data): + """ + Aggregate multiple responses into one and send them to the client. + """ + result = data[0] + if len(data) > 1: + for response in data[1:]: + result.add(response) + result = msgpack.packb([result.to_dict()]) + return result + + def receive_json_once(self, block=False): + """ + Receive a single message from the socket. 
+ """ + self._ensure_socket() + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None + try: + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_json(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + self.close() + llm_logger.warning(f"{e}") + return str(e), None + + def receive_pyobj_once(self, block=False): + """ + Receive a single message from the socket. + """ + self._ensure_socket() + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None + try: + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_pyobj(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + self.close() + llm_logger.warning(f"{e}") + return str(e), None + + def recv_result_handle(self): + while True: + try: + with self.response_token_lock: + client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK) + req_id_str = request_id.decode("utf-8") + with self.mutex: + self.req_dict[req_id_str] = client + except zmq.Again: + time.sleep(0.001) + continue + except Exception as e: + llm_logger.error(f"recv_result_handle get unknown exception: {e}") + continue + + def send_response(self, req_id, data): + """ + Send generated token result to client. + """ + self._ensure_socket() + if self.socket is None: + raise RuntimeError("Router socket not created. 
Call create_router() first.") + new_data = [] + has_result_handle = False + with self.mutex: + if req_id not in self.req_dict: + self.cached_results[req_id].append(data) + else: + has_result_handle = True + if req_id in self.cached_results: + for history_data in self.cached_results[req_id]: + new_data.extend(history_data) + llm_logger.info( + f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}" + ) + del self.cached_results[req_id] + if has_result_handle: + try: + new_data.extend(data) + start_send = time.time() + if self.aggregate_send: + result = self.pack_aggregated_data(new_data) + else: + result = msgpack.packb([response.to_dict() for response in new_data]) + with self.response_token_lock: + self.socket.send_multipart([self.req_dict[req_id], b"", result]) + llm_logger.debug( + f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}" + ) + + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") + + if data[-1].finished: + with self.mutex: + if req_id not in self.req_dict: + llm_logger.warning(f"req_id {req_id} finished but no result handle, drop it") + if req_id in self.cached_results: + del self.cached_results[req_id] + else: + llm_logger.info(f"send_multipart finished, req_id: {req_id}") + self.req_dict.pop(req_id, None) + + @abstractmethod + def close(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class ZmqIpcServer(ZmqServerBase): + """ + ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0 + """ + + def __init__(self, name, mode): + self.name = name + self.mode = mode + self.cached_results = defaultdict(list) + if mode == zmq.PULL: + self.file_name = f"/dev/shm/{name}.socket" + elif mode == zmq.ROUTER: + self.file_name = f"/dev/shm/router_{name}.ipc" + self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + self.mutex = threading.Lock() + 
self.response_token_lock = threading.Lock() + self.req_dict = dict() + self.running = True + self.context = zmq.Context() + self._create_socket() + + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.socket = self.context.socket(self.mode) + self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.socket.setsockopt(zmq.SNDTIMEO, -1) + self.socket.bind(f"ipc://{self.file_name}") + return self.socket + + def _clear_ipc(self, name): + """ + Remove the IPC file with the given name. + """ + if os.path.exists(name): + try: + os.remove(name) + except OSError as e: + llm_logger.warning(f"Failed to remove IPC file {name} - {e}") + + def close(self): + """ + Close the socket and context, and remove the IPC files. + """ + if not self.running: + return + + self.running = False + llm_logger.info("Closing ZMQ connection...") + try: + if self.socket is not None and not self.socket.closed: + self.socket.close() + if not self.context.closed: + self.context.term() + self._clear_ipc(self.file_name) + except Exception as e: + llm_logger.warning(f"Failed to close ZMQ connection - {e}") + return + + +class ZmqTcpServer(ZmqServerBase): + """ + ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1 + """ + + def __init__(self, port, mode): + self.mode = mode + self.port = port + self.cached_results = defaultdict(list) + self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + + self.mutex = threading.Lock() + self.req_dict = dict() + self.running = True + self.context = zmq.Context() + self._create_socket() + self.mutex = threading.Lock() + self.response_token_lock = threading.Lock() + + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.socket = self.context.socket(self.mode) + self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.socket.setsockopt(zmq.SNDTIMEO, -1) + self.socket.bind(f"tcp://*:{self.port}") + return self.socket + + def recv_control_cmd(self): + """ + Recieve control command 
from client + """ + self._ensure_socket() + try: + client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK) + task = msgpack.unpackb(task_data) + task_id_str = task["task_id"] + except zmq.Again: + return None + with self.mutex: + self.req_dict[task_id_str] = client + return task + + def response_for_control_cmd(self, task_id, result): + """ + Send command result back to client. + """ + self._ensure_socket() + if self.socket is None: + raise RuntimeError("Router socket not created.") + try: + result = msgpack.packb(result) + self.socket.send_multipart([self.req_dict[task_id], b"", result]) + + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") + + with self.mutex: + self.req_dict.pop(task_id, None) + llm_logger.debug(f"response control cmd finished, task_id: {task_id}") + + def close(self): + """ + Close the socket and context. + """ + if not self.running: + return + + self.running = False + llm_logger.info("Closing ZMQ connection...") + try: + if self.socket is not None and not self.socket.closed: + self.socket.close() + if not self.context.closed: + self.context.term() + + except Exception as e: + llm_logger.warning(f"Failed to close ZMQ connection - {e}") + return diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 30b87d65b1..adb148cc6e 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -73,6 +73,7 @@ update_inputs, step_reschedule, update_inputs_v1, + speculate_step_reschedule, ) from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput @@ -374,12 +375,11 @@ def step_cuda( """ if speculative_config.method is not None: - if enable_prefix_caching: - speculate_step_system_cache( + if DISABLE_RECOVER: + speculate_step_reschedule( share_inputs["stop_flags"], share_inputs["seq_lens_this_time"], share_inputs["step_seq_lens_encoder"], - 
share_inputs["step_seq_lens_decoder"], share_inputs["seq_lens_encoder"], share_inputs["seq_lens_decoder"], share_inputs["block_tables"], @@ -405,64 +405,67 @@ def step_cuda( speculative_config.num_speculative_tokens, ) else: - speculate_step_paddle( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - share_inputs["accept_num"], - block_size, - enc_dec_block_num, - speculative_config.num_speculative_tokens, - ) + if enable_prefix_caching: + speculate_step_system_cache( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["step_seq_lens_decoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + share_inputs["accept_num"], + block_size, + enc_dec_block_num, + 
speculative_config.num_speculative_tokens, + ) + else: + speculate_step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + share_inputs["accept_num"], + block_size, + enc_dec_block_num, + speculative_config.num_speculative_tokens, + ) else: - if enable_prefix_caching: - step_system_cache( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["step_seq_lens_decoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) - elif DISABLE_RECOVER: + if DISABLE_RECOVER: step_reschedule( share_inputs["stop_flags"], share_inputs["seq_lens_this_time"], @@ -490,32 +493,61 @@ def step_cuda( enc_dec_block_num, ) else: - step_paddle( 
- share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) + if enable_prefix_caching: + step_system_cache( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["step_seq_lens_decoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) + else: + step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + 
share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) def rebuild_padding( diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index c72150a284..4260b5c196 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -156,7 +156,14 @@ def process_sampling_results(self): try: is_blocking = True if self.speculative_decoding: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if ( + self.cfg.parallel_config.enable_expert_parallel + and self.cfg.parallel_config.data_parallel_size > 1 + ): + speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + else: + + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) if self.output_tokens[0] == -2: continue @@ -229,13 +236,13 @@ def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False llm_logger.info(f"finished_task_id: {finished_task_id}") self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] if task_id in self.prefill_result_status: - self.split_connector.send_first_token(task.disaggregate_info, [result]) self.resource_manager.stop_flags[index] = True self.resource_manager.tasks_list[index] = None self.resource_manager._recycle_block_tables(task) if self.prefill_result_status[task_id] != "finished": result.error_code = 400 result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}" + self.split_connector.send_first_token(task.disaggregate_info, [result]) del self.resource_manager.req_dict[task_id] break else: @@ -314,16 +321,22 @@ def 
_process_batch_output(self): task_id = task.request_id if self.cfg.speculative_config.method: - token_ids = tokens[ - 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS : 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS - + accept_num[i] - ].tolist() - if len(token_ids) == 0 or token_ids[-1] <= 0: - continue + if accept_num[i] == -3: + recovery_stop = True + if recovery_stop: + llm_logger.info(f"recovery stop signal found at task {task_id}") + token_ids = [RECOVERY_STOP_SIGNAL] + else: + token_ids = tokens[ + 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS : 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS + + accept_num[i] + ].tolist() + if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0): + continue else: token_id = int(tokens[i, 0]) token_ids = [token_id] @@ -418,7 +431,11 @@ def _process_batch_output(self): self._record_completion_metrics(task, current_time) self._recycle_resources(task_id, i, task, result, is_prefill) break - if not is_prefill or self.cfg.scheduler_config.name == "splitwise": + if ( + not is_prefill + or self.cfg.scheduler_config.name == "splitwise" + or self.cfg.scheduler_config.name == "dp" + ): batch_result.append(result) self.postprocess(batch_result) @@ -458,7 +475,7 @@ def _record_speculative_decoding_mertics(self, accept_num): self.cfg.speculative_config.num_speculative_tokens, ) - real_accept_num = [x for x in accept_num if x != 0] + real_accept_num = [x for x in accept_num if x > 0] num_accepted_tokens = sum([x - 1 for x in real_accept_num]) self.num_accepted_tokens += num_accepted_tokens num_emitted_tokens = sum(real_accept_num) diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py index cd0a72af1a..c831b0f44a 100644 --- a/fastdeploy/scheduler/config.py +++ b/fastdeploy/scheduler/config.py @@ -18,6 +18,7 @@ from fastdeploy.utils import llm_logger +from .dp_scheduler import DPScheduler from .global_scheduler import GlobalScheduler from .local_scheduler import LocalScheduler from 
.splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig @@ -89,6 +90,57 @@ def print(self): llm_logger.info("=============================================================") +class DPLocalSchedulerConfig(LocalSchedulerConfig): + """ + Configuration class for DPLocalScheduler. + + Attributes: + max_size: Maximum number of concurrent requests (-1 for unlimited) + ttl: Time-to-live in seconds for request expiration + """ + + def __init__( + self, + max_size: int = -1, + ttl: int = 900, + max_model_len: int = 8192, + enable_chunked_prefill: bool = False, + max_num_partial_prefills: int = 1, + max_long_partial_prefills: int = 1, + long_prefill_token_threshold: int = 0, + splitwise_role: str = "prefill", + **kwargs, + ): + """ + Initialize LocalScheduler configuration. + + Args: + max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled) + ttl: Time-to-live in seconds for request expiration (default 900s) + max_model_len: Maximum model context length in tokens + enable_chunked_prefill: Whether to enable chunked prefill processing + max_num_partial_prefills: Max partial prefill operations allowed + max_long_partial_prefills: Max long-running partial prefill ops + long_prefill_token_threshold: Token count threshold for long prefill + **kwargs: Additional unused arguments (for forward compatibility) + + Note: + - If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len + - See LocalScheduler class for implementation details + """ + self.max_size = max_size + self.ttl = ttl + + self.max_model_len = max_model_len + self.enable_chunked_prefill = enable_chunked_prefill + self.max_num_partial_prefills = max_num_partial_prefills + self.max_long_partial_prefills = max_long_partial_prefills + self.long_prefill_token_threshold = long_prefill_token_threshold + if self.long_prefill_token_threshold == 0: + self.long_prefill_token_threshold = int(self.max_model_len * 0.04) + self.splitwise_role = splitwise_role + + class 
GlobalSchedulerConfig: """ Configuration class for GlobalScheduler (Redis-based). @@ -229,6 +281,9 @@ def __init__(self, name="local", **kwargs): if name == "splitwise": self.config = SplitWiseSchedulerConfig(**kwargs) + if name == "dp": + self.config = DPLocalSchedulerConfig(**kwargs) + def check(self): """ Validate the configuration. @@ -236,7 +291,7 @@ def check(self): Raises: Exception: If invalid scheduler type is specified """ - if self.name not in ["local", "global", "splitwise"]: + if self.name not in ["local", "global", "splitwise", "dp"]: raise Exception(f"Unknown scheduler type {self.name}") self.config.check() @@ -274,6 +329,17 @@ def scheduler(self): if self.name == "splitwise": return SplitWiseScheduler(self.config) + if self.name == "dp": + return DPScheduler( + max_size=self.config.max_size, + ttl=self.config.ttl, + enable_chunked_prefill=self.config.enable_chunked_prefill, + max_num_partial_prefills=self.config.max_num_partial_prefills, + max_long_partial_prefills=self.config.max_long_partial_prefills, + long_prefill_token_threshold=self.config.long_prefill_token_threshold, + splitwise_role=self.config.splitwise_role, + ) + return LocalScheduler( max_size=self.config.max_size, ttl=self.config.ttl, diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py new file mode 100644 index 0000000000..d5d1d39674 --- /dev/null +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -0,0 +1,258 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import logging +import threading +import time +from multiprocessing import Queue +from typing import Dict, List, Optional + +from fastdeploy.engine.request import Request, RequestOutput +from fastdeploy.scheduler.data import ScheduledResponse +from fastdeploy.scheduler.local_scheduler import LocalScheduler +from fastdeploy.utils import envs, get_logger + + +class DPLocalScheduler(LocalScheduler): + def __init__( + self, + max_size: int, + ttl: int, + enable_chunked_prefill: bool, + max_num_partial_prefills: int, + max_long_partial_prefills: int, + long_prefill_token_threshold: int, + splitwise_role: str = "prefill", + ): + super().__init__( + max_size, + ttl, + enable_chunked_prefill, + max_num_partial_prefills, + max_long_partial_prefills, + long_prefill_token_threshold, + ) + self.splitwise_role = splitwise_role + self.scheduler_logger = logging + + def put_results(self, results: List[RequestOutput]): + """ + Add processing results back to the scheduler. + Args: + results: List of RequestOutput objects containing results + """ + responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results] + + finished_responses = [response.request_id for response in responses if response.finished] + if len(finished_responses) > 0: + self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}") + + with self.mutex: + for response in responses: + if response.request_id not in self.responses: + self.responses[response.request_id] = [response] + continue + self.responses[response.request_id].append(response) + self.responses_not_empty.notify_all() + + def _recycle(self, request_id: Optional[str] = None): + """ + Clean up expired or completed requests to free memory. + Args: + request_id: Optional specific request ID to remove. + If None, removes all expired requests. 
+ """ + if request_id is not None: + self.requests.pop(request_id, None) + self.responses.pop(request_id, None) + if self.splitwise_role == "decode": + return + self.ids.pop(self.ids.index(request_id)) + self.ids_read_cursor -= 1 + return + + if self.max_size <= 0: + return + + if len(self.requests) <= self.max_size: + return + + now = time.time() + expired_ids = [] + for request_id in self.ids: + request = self.requests[request_id] + if now - request.schedule_time < self.ttl: + break + expired_ids.append(request.request_id) + + for i, expired_id in enumerate(expired_ids): + self.requests.pop(expired_id, None) + self.responses.pop(expired_id, None) + self.ids.pop(i) + + if len(expired_ids) > 0: + if len(expired_ids) - 1 >= self.ids_read_cursor: + self.ids_read_cursor = 0 + else: + self.ids_read_cursor -= len(expired_ids) + + def get_requests( + self, + available_blocks, + block_size, + reserved_output_blocks, + max_num_batched_tokens, + batch=1, + ) -> List[Request]: + """ + Retrieve requests from the scheduler based on available resources. 
+ + Args: + available_blocks: Number of available processing blocks + block_size: Size of each processing block + reserved_output_blocks: Blocks reserved for output + max_num_batched_tokens: Maximum tokens that can be batched + batch: Preferred batch size + + Returns: + List of Request objects ready for processing + """ + if available_blocks <= reserved_output_blocks or batch < 1: + self.scheduler_logger.debug( + f"Scheduler's resource are insufficient: available_blocks={available_blocks} " + f"reserved_output_blocks={reserved_output_blocks} batch={batch} " + f"max_num_batched_tokens={max_num_batched_tokens}" + ) + return [] + required_total_blocks = 0 + current_prefill_tokens = 0 + start_batch_time = time.time() + requests: List[Request] = [] + + with self.requests_not_empty: + while True: + batch_ids = self.requests_not_empty.wait_for( + lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], + 0.005, + ) + if batch_ids: + for request_id in batch_ids: + request = self.requests[request_id] + required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) + current_prefill_tokens += request.prompt_tokens_ids_len + required_total_blocks += required_input_blocks + reserved_output_blocks + if required_total_blocks > available_blocks: + break + if current_prefill_tokens > max_num_batched_tokens: + break + + requests.append(request.raw) + self.ids_read_cursor += 1 + start_batch_time = time.time() + if len(requests) >= batch: + break + if ( + (current_prefill_tokens > max_num_batched_tokens) + or (len(requests) >= batch) + or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT) + ): + break + if batch_ids: + if len(batch_ids) > 0 and len(requests) == 0: + self.scheduler_logger.debug( + f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}" + ) + + if len(requests) > 0: + self.scheduler_logger.info( + f"Scheduler has pulled some request: {[request.request_id for request in requests]}" + ) + 
+ return requests + + +class DPScheduler: + def __init__( + self, + max_size: int, + ttl: int, + enable_chunked_prefill: bool, + max_num_partial_prefills: int, + max_long_partial_prefills: int, + long_prefill_token_threshold: int, + splitwise_role: str = "prefill", + ): + self._scheduler = DPLocalScheduler( + max_size, + ttl, + enable_chunked_prefill, + max_num_partial_prefills, + max_long_partial_prefills, + long_prefill_token_threshold, + splitwise_role, + ) + + def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue): + self.dp_rank = dp_rank + self.request_queues = request_queues + self.result_queue = result_queue + self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log") + self._scheduler.scheduler_logger = self.scheduler_logger + threading.Thread(target=self._put_requests_to_local).start() + threading.Thread(target=self._get_response_from_local).start() + + def put_requests(self, requests: List[Dict]): + results = [] + for request in requests: + if not hasattr(request, "dp_rank"): + raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}") + self.request_queues[request.dp_rank].put(request) + results.append((request.request_id, None)) + return results + + def _put_requests_to_local(self): + while True: + request = self.request_queues[self.dp_rank].get() + self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}") + self._scheduler.put_requests([request]) + + def _get_response_from_local(self): + while True: + results = self._scheduler.get_results() + if len(results) == 0: + continue + self.result_queue.put(results) + + def get_requests( + self, + available_blocks, + block_size, + reserved_output_blocks, + max_num_batched_tokens, + batch=1, + ) -> List[Request]: + return self._scheduler.get_requests( + available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch + ) + + def get_unhandled_request_num(self): + return 
len(self._scheduler.requests) + + def put_results(self, results: List[RequestOutput]): + self._scheduler.put_results(results) + + def get_results(self) -> Dict[str, List[RequestOutput]]: + return self.result_queue.get() diff --git a/fastdeploy/scheduler/local_scheduler.py b/fastdeploy/scheduler/local_scheduler.py index 5d79e50090..20e53317b7 100644 --- a/fastdeploy/scheduler/local_scheduler.py +++ b/fastdeploy/scheduler/local_scheduler.py @@ -208,6 +208,9 @@ def calc_required_blocks(self, token_num, block_size): """ return (token_num + block_size - 1) // block_size + def get_unhandled_request_num(self): + return len(self.requests) + def get_requests( self, available_blocks, diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py new file mode 100644 index 0000000000..d52edf897a --- /dev/null +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -0,0 +1,117 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import threading +import time +import traceback + +# **Note**: Just for internal use +import zmq + +from fastdeploy.inter_communicator import ZmqTcpServer +from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics +from fastdeploy.utils import envs, get_logger + +logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log") + + +class InternalAdapter: + def __init__(self, cfg, engine, dp_rank): + self.cfg = cfg + self.engine = engine + self.dp_rank = dp_rank + recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") + self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently + self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER) + self.recv_external_instruct_thread = threading.Thread( + target=self._recv_external_module_control_instruct, daemon=True + ) + self.recv_external_instruct_thread.start() + self.response_external_instruct_thread = threading.Thread( + target=self._response_external_module_control_instruct, daemon=True + ) + self.response_external_instruct_thread.start() + + def _get_current_server_info(self): + """ + Get resources information + """ + available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch()) + + available_block_num = self.engine.resource_manager.available_block_num() + server_info = { + "splitwise_role": self.cfg.splitwise_role, + "block_size": int(self.cfg.cache_config.block_size), + "block_num": int(available_block_num), + "max_block_num": int(self.cfg.cache_config.total_block_num), + "dec_token_num": int(self.cfg.cache_config.dec_token_num), + "available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num), + "max_batch_size": int(available_batch_size), + "max_input_token_num": self.cfg.max_num_batched_tokens, + "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(), + "available_batch": 
int(self.engine.resource_manager.available_batch()), + } + return server_info + + def _recv_external_module_control_instruct(self): + """ + Receive a multipart message from the control cmd socket. + """ + while True: + try: + with self.response_lock: + task = self.recv_control_cmd_server.recv_control_cmd() + if task is None: + time.sleep(0.001) + continue + logger.info(f"Recieve control task: {task}") + task_id_str = task["task_id"] + if task["cmd"] == "get_payload": + payload_info = self._get_current_server_info() + result = {"task_id": task_id_str, "result": payload_info} + logger.debug(f"Response for task: {task_id_str}") + with self.response_lock: + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + + elif task["cmd"] == "get_metrics": + metrics_text = get_filtered_metrics( + [], + extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1), + ) + result = {"task_id": task_id_str, "result": metrics_text} + logger.debug(f"Response for task: {task_id_str}") + with self.response_lock: + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + elif task["cmd"] == "connect_rdma": + self.engine.engine_worker_queue.put_connect_rdma_task(task) + + except Exception as e: + logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") + + def _response_external_module_control_instruct(self): + while True: + try: + result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response() + if result_data: + task_id_str = result_data["task_id"] + result = {"task_id": task_id_str, "result": result_data} + logger.info(f"Response for task: {task_id_str}") + with self.response_lock: + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + else: + time.sleep(0.001) + except Exception as e: + logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}") diff --git a/fastdeploy/splitwise/splitwise_connector.py 
b/fastdeploy/splitwise/splitwise_connector.py index d60ab8ad87..ae4ae834c1 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -457,8 +457,11 @@ def _handle_decode(self, payload): index=task["outputs"]["index"], send_idx=0, token_ids=task["outputs"]["token_ids"], + draft_token_ids=task["outputs"]["draft_token_ids"], ), finished=True, + error_code=task["error_code"], + error_msg=task["error_msg"], ) ) self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))