From 60115fcb724c772bb23d4649a734f33cf02e5350 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 12:13:35 +0800 Subject: [PATCH 01/19] Support external module --- fastdeploy/inter_communicator/zmq_server.py | 209 ++++++++++++++++++++ fastdeploy/scheduler/dp_scheduler.py | 175 ++++++++++++++++ 2 files changed, 384 insertions(+) create mode 100644 fastdeploy/inter_communicator/zmq_server.py create mode 100644 fastdeploy/scheduler/dp_scheduler.py diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py new file mode 100644 index 0000000000..0f90148abc --- /dev/null +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -0,0 +1,209 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import threading +import time + +import msgpack +import zmq + +from fastdeploy import envs +from fastdeploy.utils import llm_logger + + +class ZmqTcpServer: + """ + ZmqTcpServer, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 + """ + + def __init__(self, port, mode): + self.context = zmq.Context() + self.socket = self.context.socket(mode) + self.mode = mode + self.port = port + self.socket.setsockopt(zmq.SNDTIMEO, -1) + self.socket.bind(f"tcp://*:{self.port}") + self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + + self.mutex = threading.Lock() + self.req_dict = dict() + self.poller = None + self.running = True + if self.mode == zmq.PULL: + self.poller = zmq.Poller() + self.poller.register(self.socket, zmq.POLLIN) + + def send_json(self, data): + """ + Send a JSON-serializable object over the socket. + """ + self.socket.send_json(data) + + def recv_json(self): + """ + Receive a JSON-serializable object from the socket. + """ + return self.socket.recv_json() + + def send_pyobj(self, data): + """ + Send a Pickle-serializable object over the socket. + """ + self.socket.send_pyobj(data) + + def recv_pyobj(self): + """ + Receive a Pickle-serializable object from the socket. + """ + return self.socket.recv_pyobj() + + def pack_aggregated_data(self, data): + """ + Aggregate multiple responses into one and send them to the client. 
+ """ + result = data[0] + if len(data) > 1: + for response in data[1:]: + result.add(response) + result = msgpack.packb([result.to_dict()]) + return result + + def recv_control_cmd(self): + while self.running: + try: + client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK) + task = msgpack.unpackb(task_data) + task_id_str = task["task_id"] + except zmq.Again: + time.sleep(0.001) + continue + with self.mutex: + self.req_dict[task_id_str] = client + return task + + def response_for_control_cmd(self, task_id, result): + """ + Send a multipart message to the control cmd socket. + """ + if self.socket is None: + raise RuntimeError("Router socket not created.") + try: + result = msgpack.packb(result) + self.socket.send_multipart([self.req_dict[task_id], b"", result]) + + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") + + with self.mutex: + self.req_dict.pop(task_id, None) + llm_logger.info(f"response contrl cmd finished, task_id: {task_id}") + + def send_multipart(self, req_id, data): + """ + Send a multipart message to the router socket. + """ + if self.socket is None: + raise RuntimeError("Router socket not created. 
Call create_router() first.") + + while self.running: + with self.mutex: + if req_id not in self.req_dict: + try: + client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK) + req_id_str = request_id.decode("utf-8") + self.req_dict[req_id_str] = client + except zmq.Again: + time.sleep(0.001) + continue + else: + break + + try: + start_send = time.time() + if self.aggregate_send: + result = self.pack_aggregated_data(data) + else: + result = msgpack.packb([response.to_dict() for response in data]) + self.socket.send_multipart([self.req_dict[req_id], b"", result]) + llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") + + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") + + if data[-1].finished: + with self.mutex: + self.req_dict.pop(req_id, None) + llm_logger.info(f"send_multipart finished, req_id: {req_id}") + + def receive_json_once(self, block=False): + """ + Receive a single message from the socket. + """ + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None + try: + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_json(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + self.close() + llm_logger.warning(f"{e}") + return str(e), None + + def receive_pyobj_once(self, block=False): + """ + Receive a single message from the socket. + """ + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None + try: + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_pyobj(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + self.close() + llm_logger.warning(f"{e}") + return str(e), None + + def close(self): + """ + Close the socket and context, and remove the IPC files. 
+ """ + if not self.running: + return + + self.running = False + llm_logger.info("Closing ZMQ connection...") + try: + if hasattr(self, "socket") and not self.socket.closed: + self.socket.close() + + if self.socket is not None and not self.socket.closed: + self.socket.close() + + if not self.context.closed: + self.context.term() + + except Exception as e: + llm_logger.warning(f"Failed to close ZMQ connection - {e}") + return + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() \ No newline at end of file diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py new file mode 100644 index 0000000000..93c0549d80 --- /dev/null +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -0,0 +1,175 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import threading +from multiprocessing import Queue +from typing import Dict, List +from typing import Dict, List, Optional, Tuple + +from fastdeploy.engine.request import Request, RequestOutput +from fastdeploy.scheduler.data import ScheduledResponse +from fastdeploy.scheduler.local_scheduler import LocalScheduler +from fastdeploy.utils import scheduler_logger + + +class DPLocalScheduler(LocalScheduler): + def __init__( + self, + max_size: int, + ttl: int, + enable_chunked_prefill: bool, + max_num_partial_prefills: int, + max_long_partial_prefills: int, + long_prefill_token_threshold: int, + splitwise_role: str = 'prefill' + ): + super().__init__( + max_size, + ttl, + enable_chunked_prefill, + max_num_partial_prefills, + max_long_partial_prefills, + long_prefill_token_threshold, + ) + self.splitwise_role = splitwise_role + + def put_results(self, results: List[RequestOutput]): + """ + Add processing results back to the scheduler. + Args: + results: List of RequestOutput objects containing results + """ + responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results] + + finished_responses = [response.request_id for response in responses if response.finished] + if len(finished_responses) > 0: + scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}") + + with self.mutex: + for response in responses: + if response.request_id not in self.responses: + self.responses[response.request_id] = [response] + continue + self.responses[response.request_id].append(response) + self.responses_not_empty.notify_all() + + def _recycle(self, request_id: Optional[str] = None): + """ + Clean up expired or completed requests to free memory. + Args: + request_id: Optional specific request ID to remove. + If None, removes all expired requests. 
+ """ + if request_id is not None: + self.requests.pop(request_id, None) + self.responses.pop(request_id, None) + if self.splitwise_role == 'decode': + return + self.ids.pop(self.ids.index(request_id)) + self.ids_read_cursor -= 1 + return + + if self.max_size <= 0: + return + + if len(self.requests) <= self.max_size: + return + + now = time.time() + expired_ids = [] + for request_id in self.ids: + request = self.requests[request_id] + if now - request.schedule_time < self.ttl: + break + expired_ids.append(request.request_id) + + for i, expired_id in enumerate(expired_ids): + self.requests.pop(expired_id, None) + self.responses.pop(expired_id, None) + self.ids.pop(i) + + if len(expired_ids) > 0: + if len(expired_ids) - 1 >= self.ids_read_cursor: + self.ids_read_cursor = 0 + else: + self.ids_read_cursor -= len(expired_ids) + + + + +class DPScheduler: + def __init__( + self, + max_size: int, + ttl: int, + enable_chunked_prefill: bool, + max_num_partial_prefills: int, + max_long_partial_prefills: int, + long_prefill_token_threshold: int, + splitwise_role: str = 'prefill' + ): + self._scheduler = DPLocalScheduler( + max_size, + ttl, + enable_chunked_prefill, + max_num_partial_prefills, + max_long_partial_prefills, + long_prefill_token_threshold, + splitwise_role + ) + + def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue): + self.dp_rank = dp_rank + self.request_queues = request_queues + self.result_queue = result_queue + threading.Thread(target=self._put_requests_to_local).start() + threading.Thread(target=self._get_response_from_local).start() + + def put_requests(self, requests: List[Dict]): + results = [] + for request in requests: + self.request_queues[request.dp_rank].put(request) + results.append((request.request_id, None)) + return results + + def _put_requests_to_local(self): + while True: + request = self.request_queues[self.dp_rank].get() + self._scheduler.put_requests([request]) + + def _get_response_from_local(self): + while True: 
+ results = self._scheduler.get_results() + if len(results) == 0: + continue + self.result_queue.put(results) + + def get_requests( + self, + available_blocks, + block_size, + reserved_output_blocks, + max_num_batched_tokens, + batch=1, + ) -> List[Request]: + return self._scheduler.get_requests( + available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch + ) + + def put_results(self, results: List[RequestOutput]): + self._scheduler.put_results(results) + + def get_results(self) -> Dict[str, List[RequestOutput]]: + return self.result_queue.get() \ No newline at end of file From 793a7d109ecdfc0849fd7403529911a3ddb9044f Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 12:13:44 +0800 Subject: [PATCH 02/19] Support external module --- fastdeploy/cache_manager/cache_messager.py | 37 +++++- .../cache_manager/cache_transfer_manager.py | 2 + fastdeploy/engine/args_utils.py | 1 + fastdeploy/engine/engine.py | 119 +++++++++++++++++- fastdeploy/engine/expert_service.py | 92 ++++++++++++-- fastdeploy/envs.py | 8 ++ fastdeploy/inter_communicator/__init__.py | 2 +- .../inter_communicator/engine_worker_queue.py | 67 ++++++++++ fastdeploy/scheduler/config.py | 60 +++++++++ 9 files changed, 370 insertions(+), 18 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index e06d05a67f..316f8cdd1b 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -40,6 +40,7 @@ def __init__( pod_ip, engine_worker_queue_port, local_data_parallel_id, + data_parallel_size, gpu_cache_kvs, rank, nranks, @@ -143,11 +144,16 @@ def __init__( self.gpu_id = gpu_id self.cache_info = dict() self.dp_rank_id = local_data_parallel_id + self.data_parallel_size = data_parallel_size layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread) layerwise_send_cache_thread.daemon = True 
layerwise_send_cache_thread.start() + connect_rdma_thread = threading.Thread(target=self._handle_connect_task) + connect_rdma_thread.daemon = True + connect_rdma_thread.start() + logger.info(f"cache messager init finished, use {transfer_protocol}") def _prefill_layerwise_send_cache_thread(self): @@ -158,16 +164,20 @@ def _prefill_layerwise_send_cache_thread(self): try: prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32) prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32) + if self.data_parallel_size > 1: + shm_rank_id = self.dp_rank_id + else: + shm_rank_id = self.rank try: step_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}", + name=f"splitwise_complete_prefilled_step_{shm_rank_id}", array=prefilled_step_idx_data, dtype=np.int32, suffix=self.gpu_id, create=True, ) layer_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}", + name=f"splitwise_complete_prefilled_layer_{shm_rank_id}", array=prefilled_layer_idx_data, dtype=np.int32, suffix=self.gpu_id, @@ -175,14 +185,14 @@ def _prefill_layerwise_send_cache_thread(self): ) except: step_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}", + name=f"splitwise_complete_prefilled_step_{shm_rank_id}", array=prefilled_step_idx_data, dtype=np.int32, suffix=self.gpu_id, create=False, ) layer_shm_value = IPCSignal( - name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}", + name=f"splitwise_complete_prefilled_layer_{shm_rank_id}", array=prefilled_layer_idx_data, dtype=np.int32, suffix=self.gpu_id, @@ -310,3 +320,22 @@ def _prefill_layerwise_send_cache_thread(self): except Exception as e: logger.error(f"prefill layerwise send cache thread has exception: {e}") + + def _handle_connect_task(self): + while True: + try: + task = self.engine_worker_queue.get_connect_rdma_task() + if task is None: + time.sleep(0.001) + continue + logger.info(f"_handle_connect_task recv task: {task}") + task_id = 
task["task_id"] + ip, rdma_port = task["ip"], task["rdma_port"] + status = self.messager["rdma"].connect(ip, rdma_port) + if not status: + response = {"task_id": task_id, "success": False} + else: + response = {"task_id": task_id, "success": True} + self.engine_worker_queue.put_connect_rdma_task_response(response) + except Exception as e: + logger.error(f"handle_connect_task has exception: {e}") diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 34ccf144ca..a9c613bd7d 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -92,6 +92,7 @@ def parse_args(): help="speculative config", ) parser.add_argument("--local_data_parallel_id", type=int, default=0) + parser.add_argument("--data_parallel_size", type=int, default=1) args = parser.parse_args() return args @@ -212,6 +213,7 @@ def __init__(self, args): pod_ip=args.pod_ip, engine_worker_queue_port=args.engine_worker_queue_port, local_data_parallel_id=args.local_data_parallel_id, + data_parallel_size=args.data_parallel_size, gpu_cache_kvs=self.gpu_cache_kvs, rank=self.rank, nranks=args.mp_num, diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 4a2414304d..d620a7edaf 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -820,6 +820,7 @@ def create_scheduler_config(self) -> SchedulerConfig: "max_num_partial_prefills", "max_long_partial_prefills", "long_prefill_token_threshold", + "splitwise_role" ] all = asdict(self) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index e7443bc1db..692bb0e43e 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -49,7 +49,7 @@ IPCSignal, ZmqClient, ) -from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics from fastdeploy.metrics.trace_util import 
start_span, start_span_request from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor @@ -161,6 +161,7 @@ def __init__(self, cfg): self.cfg.guided_decoding_backend, disable_any_whitespace=self.cfg.disable_any_whitespace, ) + def start(self, api_server_pid=None): """ @@ -183,6 +184,18 @@ def start(self, api_server_pid=None): self.zmq_server.start_server() self.zmq_server.create_router() time.sleep(3) + + if envs.ENABLE_EXTERNAL_MODULE_ACCESS: + recv_request_port = envs.ZMQ_RECV_REQUEST_SERVER_PORT + self.recv_request_server = ZmqTcpServer(port=recv_request_port, mode=zmq.PULL) + send_response_port = envs.ZMQ_SEND_RESPONSE_SERVER_PORT + self.send_response_server = ZmqTcpServer(port=send_response_port, mode=zmq.ROUTER) + recv_control_cmd_ports = envs.ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") + self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[0], mode=zmq.ROUTER) + self.handle_control_cmd_thread = threading.Thread(target=self._handle_control_cmd, daemon=True) + self.handle_control_cmd_thread.start() + self.handle_control_cmd_result_thread = threading.Thread(target=self._handle_connect_rdma_results, daemon=True) + self.handle_control_cmd_result_thread.start() if self.do_profile == 0 and ( self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed" @@ -248,8 +261,16 @@ def start(self, api_server_pid=None): role = self.cfg.splitwise_role host_ip = self.cfg.host_ip disaggregate = self.cfg.disaggregate_info + request_queues = None + result_queue = None if self.cfg.scheduler_config.name == "splitwise": self.scheduler.start(role, host_ip, disaggregate) + elif self.cfg.scheduler_config.name == 'dp': + request_queues = [] + result_queue = multiprocessing.Queue() + for i in range(self.cfg.parallel_config.data_parallel_size): + request_queues.append(multiprocessing.Queue()) + self.scheduler.start(self.cfg.node_rank * 
self.cfg.worker_num_per_node, request_queues, result_queue) time.sleep(1) @@ -267,6 +288,8 @@ def start(self, api_server_pid=None): self.cfg, i + self.cfg.node_rank * self.cfg.worker_num_per_node, self.ipc_signal_suffix, + request_queues, + result_queue, ), ) ) @@ -278,6 +301,67 @@ def start(self, api_server_pid=None): console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") return True + + def _get_current_server_info(self): + """ + 获取服务当前资源信息 + """ + available_batch_size = min(self.cfg.max_prefill_batch, self.resource_manager.available_batch()) + + available_block_num = self.resource_manager.available_block_num() + server_info = { + "splitwise_role": self.cfg.splitwise_role, + "block_size": int(self.cfg.cache_config.block_size), + "block_num": int(available_block_num), + "dec_token_num": int(self.cfg.cache_config.dec_token_num), + "available_resource": 1.0 * available_block_num / self.cfg.cache_config.total_block_num, + "max_batch_size": int(available_batch_size), + "max_input_token_num": self.cfg.max_num_batched_tokens, + } + return server_info + + def _handle_control_cmd(self): + """ + Receive a multipart message from the control cmd socket. 
+ """ + while self.running: + try: + task = self.recv_control_cmd_server.recv_control_cmd() + llm_logger.info(f"Recieve control task: {task}") + task_id_str = task["task_id"] + if task["cmd"] == "get_payload": + payload_info = self._get_current_server_info() + result = {"task_id": task_id_str, "result": payload_info} + llm_logger.info(f"Response for task: {task_id_str}") + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + + elif task["cmd"] == "get_metrics": + metrics_text = get_filtered_metrics( + [], + extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1), + ) + result = {"task_id": task_id_str, "result": metrics_text} + llm_logger.info(f"Response for task: {task_id_str}") + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + elif task["cmd"] == "connect_rdma": + self.engine_worker_queue.put_connect_rdma_task(task) + + except Exception as e: + llm_logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") + + def _handle_connect_rdma_results(self): + while True: + try: + result_data = self.engine_worker_queue.get_connect_rdma_task_response() + if result_data: + task_id_str = result_data["task_id"] + result = {"task_id": task_id_str, "result": result_data} + llm_logger.info(f"Response for task: {task_id_str}") + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + else: + time.sleep(0.001) + except Exception as e: + llm_logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}") def _zmq_send_generated_tokens(self): """ @@ -291,7 +375,10 @@ def _zmq_send_generated_tokens(self): time.sleep(0.005) continue for request_id, contents in results.items(): - self.zmq_server.send_multipart(request_id, contents) + if envs.ENABLE_EXTERNAL_MODULE_ACCESS: + self.send_response_server.send_multipart(request_id, contents) + else: + self.zmq_server.send_multipart(request_id, contents) except Exception as e: 
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") @@ -414,15 +501,25 @@ def _fetch_request(): def _insert_zmq_task_to_scheduler(self): if self.api_server_pid is None: return + + if envs.ENABLE_EXTERNAL_MODULE_ACCESS: + if self.cfg.splitwise_role == "decode": + return added_requests: Dict[str, int] = dict() while self.running: try: block = True if len(added_requests) == 0 else False - if not self.cfg.enable_mm: - err, data = self.zmq_server.receive_json_once(block) + if envs.ENABLE_EXTERNAL_MODULE_ACCESS: + if not self.cfg.enable_mm: + err, data = self.recv_request_server.receive_json_once(block) + else: + err, data = self.recv_request_server.receive_pyobj_once(block) else: - err, data = self.zmq_server.receive_pyobj_once(block) + if not self.cfg.enable_mm: + err, data = self.zmq_server.receive_json_once(block) + else: + err, data = self.zmq_server.receive_pyobj_once(block) if err is not None: llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}") break @@ -470,7 +567,10 @@ def _insert_zmq_task_to_scheduler(self): ) # Since the request is not in scheduler # Send result by zmq directly - self.zmq_server.send_multipart(request_id, error_result) + if envs.ENABLE_EXTERNAL_MODULE_ACCESS: + self.send_response_server.send_multipart(request_id, error_result) + else: + self.zmq_server.send_multipart(request_id, error_result) except Exception as e: llm_logger.error( f"Error happend while receving new request from zmq, details={e}, " @@ -991,6 +1091,13 @@ def _exit_sub_services(self): self.engine_worker_queue.cleanup() if hasattr(self, "zmq_server") and self.zmq_server is not None: self.zmq_server.close() + if envs.ENABLE_EXTERNAL_MODULE_ACCESS: + if hasattr(self, "send_response_server") and self.send_response_server is not None: + self.send_response_server.close() + if hasattr(self, "recv_request_server") and self.recv_request_server is not None: + self.recv_request_server.close() + if hasattr(self, 
"recv_control_cmd_server") and self.recv_control_cmd_server is not None: + self.recv_control_cmd_server.close() if hasattr(self, "dp_processed"): for p in self.dp_processed: p.join() diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 63b1b15beb..be296a5acd 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -22,15 +22,17 @@ import time import traceback import weakref +import zmq import numpy as np from fastdeploy.engine.resource_manager import ResourceManager -from fastdeploy.inter_communicator import EngineWorkerQueue +from fastdeploy.inter_communicator import EngineWorkerQueue, ZmqTcpServer from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.output.token_processor import TokenProcessor from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector -from fastdeploy.utils import EngineError, console_logger, llm_logger +from fastdeploy.utils import EngineError, console_logger, envs, llm_logger +from fastdeploy.metrics.metrics import EXCLUDE_LABELS, get_filtered_metrics, main_process_metrics class ExpertService: @@ -60,7 +62,8 @@ def __init__(self, cfg, local_data_parallel_id): self.scheduler = cfg.scheduler_config.scheduler() - self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") + if self.cfg.scheduler_config.name == 'splitwise': + self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id @@ -111,8 +114,17 @@ def __init__(self, cfg, local_data_parallel_id): ) self._finalizer = weakref.finalize(self, self._exit_sub_services) + if envs.ENABLE_ENGINE_ZMQ_REMOTE_ACCESS: + recv_control_cmd_ports = envs.ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") + self.recv_control_cmd_server = ZmqTcpServer( + port=recv_control_cmd_ports[start_pos:end_pos], mode=zmq.ROUTER + ) + self.handle_control_cmd_thread = 
threading.Thread(target=self._handle_control_cmd, daemon=True) + self.handle_control_cmd_thread.start() + self.handle_control_cmd_result_thread = threading.Thread(target=self._handle_connect_rdma_results, daemon=True) + self.handle_control_cmd_result_thread.start() - def start(self, ipc_signal_suffix, local_data_parallel_id): + def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues=None, result_queue=None): """ Initializes the engine and starts its sub-services. If `api_server_pid` is defined, will launch a thread @@ -147,11 +159,77 @@ def start(self, ipc_signal_suffix, local_data_parallel_id): role = self.cfg.splitwise_role host_ip = self.cfg.host_ip disaggregate = self.cfg.disaggregate_info - self.scheduler.start(role, host_ip, disaggregate) + if self.cfg.scheduler_config.name == 'dp': + assert (request_queues is not None) and (result_queue is not None) + self.scheduler.start(local_data_parallel_id, request_queues, result_queue) + elif self.cfg.scheduler_config.name == 'splitwise': + self.scheduler.start(role, host_ip, disaggregate) self.cfg.print() console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") return True + + def _get_current_server_info(self): + """ + 获取服务当前资源信息 + """ + available_batch_size = min(self.cfg.max_prefill_batch, self.resource_manager.available_batch()) + + available_block_num = self.resource_manager.available_block_num() + server_info = { + "splitwise_role": self.cfg.splitwise_role, + "block_size": int(self.cfg.cache_config.block_size), + "block_num": int(available_block_num), + "dec_token_num": int(self.cfg.cache_config.dec_token_num), + "available_resource": 1.0 * available_block_num / self.cfg.cache_config.total_block_num, + "max_batch_size": int(available_batch_size), + "max_input_token_num": self.cfg.max_num_batched_tokens, + } + return server_info + + def _handle_control_cmd(self): + """ + Receive a multipart message from the control cmd socket. 
+ """ + while True: + try: + task = self.recv_control_cmd_server.recv_control_cmd() + llm_logger.info(f"Recieve control task: {task}") + task_id_str = task["task_id"] + if task["cmd"] == "get_payload": + payload_info = self._get_current_server_info() + result = {"task_id": task_id_str, "result": payload_info} + llm_logger.info(f"Response for task: {task_id_str}") + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + + elif task["cmd"] == "get_metrics": + metrics_text = get_filtered_metrics( + EXCLUDE_LABELS, + extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1), + ) + result = {"task_id": task_id_str, "result": metrics_text} + llm_logger.info(f"Response for task: {task_id_str}") + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + elif task["cmd"] == "connect_rdma": + self.engine_worker_queue.put_connect_rdma_task(task) + + except Exception as e: + llm_logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") + + def _handle_connect_rdma_results(self): + while True: + try: + result_data = self.engine_worker_queue.get_connect_rdma_task_response() + if result_data: + task_id_str = result_data["task_id"] + result = {"task_id": task_id_str, "result": result_data} + llm_logger.info(f"Response for task: {task_id_str}") + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + else: + time.sleep(0.001) + except Exception as e: + llm_logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}") + def _insert_task_to_worker(self): """ @@ -356,13 +434,13 @@ def _exit_sub_services(self): self.zmq_server.close() -def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix): +def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix, request_queues=None, result_queue=None): """ Start expert service """ expert_service = ExpertService(cfg, local_data_parallel_id) try: - 
expert_service.start(ipc_signal_suffix, local_data_parallel_id) + expert_service.start(ipc_signal_suffix, local_data_parallel_id, request_queues, result_queue) expert_service.split_connector.start_receiver() except Exception as e: llm_logger.exception(f"Expert service failed to start: {e}") diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 901ef3f5a0..881fb86d0f 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -80,6 +80,14 @@ "EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"), # enable kv cache block scheduler v1 (no need for kv_cache_ratio) "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")), + # enable external module to access LLMEngine. + "ENABLE_EXTERNAL_MODULE_ACCESS": lambda: int(os.getenv("ENABLE_EXTERNAL_MODULE_ACCESS", "0")), + # LLMEngine recieve requests port, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 + "ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("ZMQ_RECV_REQUEST_SERVER_PORT", "8200"), + # LLMEngine send response port, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 + "ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"), + # LLMEngine recieve control command port, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 + "ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), } diff --git a/fastdeploy/inter_communicator/__init__.py b/fastdeploy/inter_communicator/__init__.py index 0c1cc0d9fc..ddffbbbd65 100644 --- a/fastdeploy/inter_communicator/__init__.py +++ b/fastdeploy/inter_communicator/__init__.py @@ -19,4 +19,4 @@ from .ipc_signal import IPCSignal from .zmq_client import ZmqClient -__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue"] +__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "ZmqTcpServer"] diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index da88265a26..e216f430d2 100644 --- 
a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -85,12 +85,15 @@ class QueueManager(BaseManager): ] self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)] self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)] + self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)] + self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)] self.client_read_info_flag_init: List[List[int]] = [ [1] * self.num_client for _ in range(self.local_data_parallel_size) ] self.lock_info_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] + self.connect_task_lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)] self.finish_request_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) @@ -112,11 +115,26 @@ class QueueManager(BaseManager): callable=lambda idx: self.lock_init[idx], proxytype=AcquirerProxy, ) + QueueManager.register( + "get_connect_task_lock", + callable=lambda idx: self.connect_task_lock_init[idx], + proxytype=AcquirerProxy, + ) QueueManager.register( "get_read_finish_flag", callable=lambda idx: self.read_finish_flag_init[idx], proxytype=ValueProxy, ) + QueueManager.register( + "get_connect_rdma_tasks", + callable=lambda idx: self.connect_rdma_tasks_list[idx], + proxytype=ListProxy + ) + QueueManager.register( + "get_connect_rdma_tasks_responses", + callable=lambda idx: self.connect_rdma_tasks_response_list[idx], + proxytype=ListProxy + ) QueueManager.register( "get_connected_client_counter", callable=lambda idx: self.connected_client_counter_init[idx], @@ -180,6 +198,9 @@ class QueueManager(BaseManager): QueueManager.register("get_disaggregate_requests") QueueManager.register("get_available_prefill_instances") 
QueueManager.register("get_finish_request_barrier") + QueueManager.register("get_connect_rdma_tasks") + QueueManager.register("get_connect_rdma_tasks_responses") + QueueManager.register("get_connect_task_lock") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() @@ -200,6 +221,13 @@ class QueueManager(BaseManager): self.available_prefill_instances = self.manager.get_available_prefill_instances() self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id) self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id) + # p/d互联 + self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id) + self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses( + self.local_data_parallel_id + ) + self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id) + assert self.num_client == len(self.client_read_flag) if is_server: @@ -280,6 +308,45 @@ def num_tasks(self) -> int: total_num: int = len(self.tasks) self.lock.release() return total_num + + def put_connect_rdma_task(self, connect_rdma_task): + self.connect_task_lock.acquire() + self.connect_rdma_task_queue.append(connect_rdma_task) + self.connect_task_lock.release() + + def get_connect_rdma_task(self): + result = None + self.connect_task_lock.acquire() + if len(self.connect_rdma_task_queue) == 0: + self.connect_task_lock.release() + return result + try: + result = self.connect_rdma_task_queue.pop(0) + except Exception as e: + llm_logger.info(f"get_connect_rdma_task got exception: {e}") + finally: + self.connect_task_lock.release() + return result + + def put_connect_rdma_task_response(self, connect_rdma_task_response): + self.connect_task_lock.acquire() + self.connect_rdma_task_response_queue.append(connect_rdma_task_response) + self.connect_task_lock.release() + + def get_connect_rdma_task_response(self): + result = None 
+ self.connect_task_lock.acquire() + if len(self.connect_rdma_task_response_queue) == 0: + self.connect_task_lock.release() + return result + try: + result = self.connect_rdma_task_response_queue.pop(0) + except Exception as e: + llm_logger.info(f"get_connect_rdma_task_response got exception: {e}") + finally: + self.connect_task_lock.release() + return result + def get_prefill_instances(self): """ diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py index cd0a72af1a..a12e2349b5 100644 --- a/fastdeploy/scheduler/config.py +++ b/fastdeploy/scheduler/config.py @@ -18,6 +18,7 @@ from fastdeploy.utils import llm_logger +from .dp_scheduler import DPLocalScheduler from .global_scheduler import GlobalScheduler from .local_scheduler import LocalScheduler from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig @@ -89,6 +90,59 @@ def print(self): llm_logger.info("=============================================================") + + +class DPLocalSchedulerConfig(LocalSchedulerConfig): + """ + Configuration class for DPLocalScheduler. + + Attributes: + max_size: Maximum number of concurrent requests (-1 for unlimited) + ttl: Time-to-live in seconds for request expiration + """ + + def __init__( + self, + max_size: int = -1, + ttl: int = 900, + max_model_len: int = 8192, + enable_chunked_prefill: bool = False, + max_num_partial_prefills: int = 1, + max_long_partial_prefills: int = 1, + long_prefill_token_threshold: int = 0, + splitwise_role: str = 'prefill', + **kwargs, + ): + """ + Initialize LocalScheduler configuration. 
+ + Args: + max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled) + ttl: Time-to-live in seconds for request expiration (default 900s) + max_model_len: Maximum model context length in tokens + enable_chunked_prefill: Whether to enable chunked prefill processing + max_num_partial_prefills: Max partial prefill operations allowed + max_long_partial_prefills: Max long-running partial prefill ops + long_prefill_token_threshold: Token count threshold for long prefill + **kwargs: Additional unused arguments (for forward compatibility) + + Note: + - If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len + - See LocalScheduler class for implementation details + """ + self.max_size = max_size + self.ttl = ttl + + self.max_model_len = max_model_len + self.enable_chunked_prefill = enable_chunked_prefill + self.max_num_partial_prefills = max_num_partial_prefills + self.max_long_partial_prefills = max_long_partial_prefills + self.long_prefill_token_threshold = long_prefill_token_threshold + if self.long_prefill_token_threshold == 0: + self.long_prefill_token_threshold = int(self.max_model_len * 0.04) + self.splitwise_role = splitwise_role + + class GlobalSchedulerConfig: """ Configuration class for GlobalScheduler (Redis-based). 
@@ -228,6 +282,9 @@ def __init__(self, name="local", **kwargs): if name == "splitwise": self.config = SplitWiseSchedulerConfig(**kwargs) + + if name == "dp": + self.config = DPLocalSchedulerConfig(**kwargs) def check(self): """ @@ -273,6 +330,9 @@ def scheduler(self): if self.name == "splitwise": return SplitWiseScheduler(self.config) + + if self.name == "dp": + return DPLocalScheduler(self.config) return LocalScheduler( max_size=self.config.max_size, From 1f2c43b593491cd434f7b2303a5bed4ec2a6f336 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 12:29:49 +0800 Subject: [PATCH 03/19] Support external module --- fastdeploy/engine/engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 692bb0e43e..5782f40ade 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -48,6 +48,7 @@ EngineWorkerQueue, IPCSignal, ZmqClient, + ZmqTcpServer ) from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics from fastdeploy.metrics.trace_util import start_span, start_span_request From e8083e0a4ee60eb362a7042abf5f49a25d8a99c1 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 12:41:23 +0800 Subject: [PATCH 04/19] Support external module --- fastdeploy/scheduler/dp_scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py index 93c0549d80..62ed25b742 100644 --- a/fastdeploy/scheduler/dp_scheduler.py +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -17,6 +17,7 @@ from multiprocessing import Queue from typing import Dict, List from typing import Dict, List, Optional, Tuple +import time from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledResponse From 64b7726e6f23d47352c6e6bf5d58aace095faf13 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 16:22:46 
+0800 Subject: [PATCH 05/19] refactor code to make it more clear --- fastdeploy/engine/engine.py | 123 ++------ fastdeploy/engine/expert_service.py | 70 +---- fastdeploy/entrypoints/engine_client.py | 4 +- fastdeploy/inter_communicator/__init__.py | 5 +- fastdeploy/inter_communicator/zmq_client.py | 187 +++---------- fastdeploy/inter_communicator/zmq_server.py | 262 +++++++++++------- .../splitwise/internal_adapter_utils.py | 97 +++++++ 7 files changed, 321 insertions(+), 427 deletions(-) create mode 100644 fastdeploy/splitwise/internal_adapter_utils.py diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 5782f40ade..ee4544b9ef 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -47,7 +47,7 @@ EngineCacheQueue, EngineWorkerQueue, IPCSignal, - ZmqClient, + ZmqIpcServer, ZmqTcpServer ) from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics @@ -55,6 +55,7 @@ from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector +from fastdeploy.splitwise.internal_adapter_utils import ExternalModuleAdapter from fastdeploy.utils import EngineError, console_logger, envs, llm_logger @@ -181,23 +182,15 @@ def start(self, api_server_pid=None): self.data_processor = self.input_processor.create_processor() if api_server_pid is not None: - self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL) - self.zmq_server.start_server() - self.zmq_server.create_router() + if envs.ENABLE_EXTERNAL_MODULE_ACCESS: + self.recv_request_server = ZmqTcpServer(port=envs.ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL) + self.send_response_server = ZmqTcpServer(port=envs.ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER) + self.external_adapter = ExternalModuleAdapter(cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node) + else: + 
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) + self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) time.sleep(3) - if envs.ENABLE_EXTERNAL_MODULE_ACCESS: - recv_request_port = envs.ZMQ_RECV_REQUEST_SERVER_PORT - self.recv_request_server = ZmqTcpServer(port=recv_request_port, mode=zmq.PULL) - send_response_port = envs.ZMQ_SEND_RESPONSE_SERVER_PORT - self.send_response_server = ZmqTcpServer(port=send_response_port, mode=zmq.ROUTER) - recv_control_cmd_ports = envs.ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") - self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[0], mode=zmq.ROUTER) - self.handle_control_cmd_thread = threading.Thread(target=self._handle_control_cmd, daemon=True) - self.handle_control_cmd_thread.start() - self.handle_control_cmd_result_thread = threading.Thread(target=self._handle_connect_rdma_results, daemon=True) - self.handle_control_cmd_result_thread.start() - if self.do_profile == 0 and ( self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed" ): @@ -303,67 +296,6 @@ def start(self, api_server_pid=None): console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") return True - def _get_current_server_info(self): - """ - 获取服务当前资源信息 - """ - available_batch_size = min(self.cfg.max_prefill_batch, self.resource_manager.available_batch()) - - available_block_num = self.resource_manager.available_block_num() - server_info = { - "splitwise_role": self.cfg.splitwise_role, - "block_size": int(self.cfg.cache_config.block_size), - "block_num": int(available_block_num), - "dec_token_num": int(self.cfg.cache_config.dec_token_num), - "available_resource": 1.0 * available_block_num / self.cfg.cache_config.total_block_num, - "max_batch_size": int(available_batch_size), - "max_input_token_num": self.cfg.max_num_batched_tokens, - } - return server_info - - def _handle_control_cmd(self): - """ - Receive a multipart message from the 
control cmd socket. - """ - while self.running: - try: - task = self.recv_control_cmd_server.recv_control_cmd() - llm_logger.info(f"Recieve control task: {task}") - task_id_str = task["task_id"] - if task["cmd"] == "get_payload": - payload_info = self._get_current_server_info() - result = {"task_id": task_id_str, "result": payload_info} - llm_logger.info(f"Response for task: {task_id_str}") - self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) - - elif task["cmd"] == "get_metrics": - metrics_text = get_filtered_metrics( - [], - extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1), - ) - result = {"task_id": task_id_str, "result": metrics_text} - llm_logger.info(f"Response for task: {task_id_str}") - self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) - elif task["cmd"] == "connect_rdma": - self.engine_worker_queue.put_connect_rdma_task(task) - - except Exception as e: - llm_logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") - - def _handle_connect_rdma_results(self): - while True: - try: - result_data = self.engine_worker_queue.get_connect_rdma_task_response() - if result_data: - task_id_str = result_data["task_id"] - result = {"task_id": task_id_str, "result": result_data} - llm_logger.info(f"Response for task: {task_id_str}") - self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) - else: - time.sleep(0.001) - except Exception as e: - llm_logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}") - def _zmq_send_generated_tokens(self): """ Recieve output for zmq @@ -376,10 +308,7 @@ def _zmq_send_generated_tokens(self): time.sleep(0.005) continue for request_id, contents in results.items(): - if envs.ENABLE_EXTERNAL_MODULE_ACCESS: - self.send_response_server.send_multipart(request_id, contents) - else: - self.zmq_server.send_multipart(request_id, contents) + self.send_response_server.send_response(request_id, 
contents) except Exception as e: llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") @@ -511,16 +440,10 @@ def _insert_zmq_task_to_scheduler(self): while self.running: try: block = True if len(added_requests) == 0 else False - if envs.ENABLE_EXTERNAL_MODULE_ACCESS: - if not self.cfg.enable_mm: - err, data = self.recv_request_server.receive_json_once(block) - else: - err, data = self.recv_request_server.receive_pyobj_once(block) + if not self.cfg.enable_mm: + err, data = self.recv_request_server.receive_json_once(block) else: - if not self.cfg.enable_mm: - err, data = self.zmq_server.receive_json_once(block) - else: - err, data = self.zmq_server.receive_pyobj_once(block) + err, data = self.recv_request_server.receive_pyobj_once(block) if err is not None: llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}") break @@ -568,10 +491,7 @@ def _insert_zmq_task_to_scheduler(self): ) # Since the request is not in scheduler # Send result by zmq directly - if envs.ENABLE_EXTERNAL_MODULE_ACCESS: - self.send_response_server.send_multipart(request_id, error_result) - else: - self.zmq_server.send_multipart(request_id, error_result) + self.send_response_server.send_response(request_id, error_result) except Exception as e: llm_logger.error( f"Error happend while receving new request from zmq, details={e}, " @@ -1090,15 +1010,12 @@ def _exit_sub_services(self): print(f"Error extracting sub services: {e}") self.engine_worker_queue.cleanup() - if hasattr(self, "zmq_server") and self.zmq_server is not None: - self.zmq_server.close() - if envs.ENABLE_EXTERNAL_MODULE_ACCESS: - if hasattr(self, "send_response_server") and self.send_response_server is not None: - self.send_response_server.close() - if hasattr(self, "recv_request_server") and self.recv_request_server is not None: - self.recv_request_server.close() - if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None: - 
self.recv_control_cmd_server.close() + if hasattr(self, "send_response_server") and self.send_response_server is not None: + self.send_response_server.close() + if hasattr(self, "recv_request_server") and self.recv_request_server is not None: + self.recv_request_server.close() + if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None: + self.recv_control_cmd_server.close() if hasattr(self, "dp_processed"): for p in self.dp_processed: p.join() diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index be296a5acd..a8e3b4f46a 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -115,14 +115,7 @@ def __init__(self, cfg, local_data_parallel_id): self._finalizer = weakref.finalize(self, self._exit_sub_services) if envs.ENABLE_ENGINE_ZMQ_REMOTE_ACCESS: - recv_control_cmd_ports = envs.ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") - self.recv_control_cmd_server = ZmqTcpServer( - port=recv_control_cmd_ports[start_pos:end_pos], mode=zmq.ROUTER - ) - self.handle_control_cmd_thread = threading.Thread(target=self._handle_control_cmd, daemon=True) - self.handle_control_cmd_thread.start() - self.handle_control_cmd_result_thread = threading.Thread(target=self._handle_connect_rdma_results, daemon=True) - self.handle_control_cmd_result_thread.start() + self.external_adapter = ExternalModuleAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id) def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues=None, result_queue=None): """ @@ -169,67 +162,6 @@ def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues=None, console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") return True - def _get_current_server_info(self): - """ - 获取服务当前资源信息 - """ - available_batch_size = min(self.cfg.max_prefill_batch, self.resource_manager.available_batch()) - - available_block_num = 
self.resource_manager.available_block_num() - server_info = { - "splitwise_role": self.cfg.splitwise_role, - "block_size": int(self.cfg.cache_config.block_size), - "block_num": int(available_block_num), - "dec_token_num": int(self.cfg.cache_config.dec_token_num), - "available_resource": 1.0 * available_block_num / self.cfg.cache_config.total_block_num, - "max_batch_size": int(available_batch_size), - "max_input_token_num": self.cfg.max_num_batched_tokens, - } - return server_info - - def _handle_control_cmd(self): - """ - Receive a multipart message from the control cmd socket. - """ - while True: - try: - task = self.recv_control_cmd_server.recv_control_cmd() - llm_logger.info(f"Recieve control task: {task}") - task_id_str = task["task_id"] - if task["cmd"] == "get_payload": - payload_info = self._get_current_server_info() - result = {"task_id": task_id_str, "result": payload_info} - llm_logger.info(f"Response for task: {task_id_str}") - self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) - - elif task["cmd"] == "get_metrics": - metrics_text = get_filtered_metrics( - EXCLUDE_LABELS, - extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1), - ) - result = {"task_id": task_id_str, "result": metrics_text} - llm_logger.info(f"Response for task: {task_id_str}") - self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) - elif task["cmd"] == "connect_rdma": - self.engine_worker_queue.put_connect_rdma_task(task) - - except Exception as e: - llm_logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") - - def _handle_connect_rdma_results(self): - while True: - try: - result_data = self.engine_worker_queue.get_connect_rdma_task_response() - if result_data: - task_id_str = result_data["task_id"] - result = {"task_id": task_id_str, "result": result_data} - llm_logger.info(f"Response for task: {task_id_str}") - self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) - 
else: - time.sleep(0.001) - except Exception as e: - llm_logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}") - def _insert_task_to_worker(self): """ diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 9be9eccb4a..d41bfd3064 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -20,7 +20,7 @@ import numpy as np from fastdeploy.input.preprocess import InputPreprocessor -from fastdeploy.inter_communicator import IPCSignal, ZmqClient +from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient from fastdeploy.metrics.work_metrics import work_process_metrics from fastdeploy.platforms import current_platform from fastdeploy.utils import EngineError, api_server_logger @@ -80,7 +80,7 @@ def create_zmq_client(self, model, mode): """ Create a ZMQ client. """ - self.zmq_client = ZmqClient(model, mode) + self.zmq_client = ZmqIpcClient(model, mode) self.zmq_client.connect() def format_and_add_data(self, prompts: dict): diff --git a/fastdeploy/inter_communicator/__init__.py b/fastdeploy/inter_communicator/__init__.py index ddffbbbd65..ea08af31a4 100644 --- a/fastdeploy/inter_communicator/__init__.py +++ b/fastdeploy/inter_communicator/__init__.py @@ -17,6 +17,7 @@ from .engine_cache_queue import EngineCacheQueue from .engine_worker_queue import EngineWorkerQueue from .ipc_signal import IPCSignal -from .zmq_client import ZmqClient +from .zmq_client import ZmqIpcClient +from .zmq_server import ZmqIpcServer, ZmqTcpServer -__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "ZmqTcpServer"] +__all__ = ["ZmqIpcClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "ZmqTcpServer", "ZmqIpcServer"] diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 05e55929dd..28ee8f7f2b 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ 
b/fastdeploy/inter_communicator/zmq_client.py @@ -17,6 +17,7 @@ import os import threading import time +from abc import ABC, abstractmethod import msgpack import zmq @@ -25,189 +26,77 @@ from fastdeploy.utils import llm_logger -class ZmqClient: +class ZmqClientBase(ABC): """ - ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ. + ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ. """ - def __init__(self, name, mode): - self.context = zmq.Context() - self.socket = self.context.socket(mode) - self.file_name = f"/dev/shm/{name}.socket" - self.router_path = f"/dev/shm/router_{name}.ipc" - - self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) - self.aggregate_send = envs.FD_USE_AGGREGATE_SEND - - self.mutex = threading.Lock() - self.req_dict = dict() - self.router = None - self.poller = None - self.running = True - + def __init__(self): + pass + + @abstractmethod + def _create_socket(self): + """Abstract method to create and return a ZeroMQ socket.""" + pass + + def _ensure_socket(self): + """Ensure the socket is created before use.""" + if self.socket is None: + self.socket = self._create_socket() + + @abstractmethod def connect(self): """ Connect to the server using the file name specified in the constructor. """ - self.socket.connect(f"ipc://{self.file_name}") - - def start_server(self): - """ - Start the server using the file name specified in the constructor. - """ - self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - self.socket.setsockopt(zmq.SNDTIMEO, -1) - self.socket.bind(f"ipc://{self.file_name}") - self.poller = zmq.Poller() - self.poller.register(self.socket, zmq.POLLIN) + pass - def create_router(self): - """ - Create a ROUTER socket and bind it to the specified router path. 
- """ - self.router = self.context.socket(zmq.ROUTER) - self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - self.router.setsockopt(zmq.SNDTIMEO, -1) - self.router.bind(f"ipc://{self.router_path}") def send_json(self, data): """ Send a JSON-serializable object over the socket. """ + self._ensure_socket() self.socket.send_json(data) def recv_json(self): """ Receive a JSON-serializable object from the socket. """ + self._ensure_socket() return self.socket.recv_json() def send_pyobj(self, data): """ Send a Pickle-serializable object over the socket. """ + self._ensure_socket() self.socket.send_pyobj(data) def recv_pyobj(self): """ Receive a Pickle-serializable object from the socket. """ + self._ensure_socket() return self.socket.recv_pyobj() + - def pack_aggregated_data(self, data): - """ - Aggregate multiple responses into one and send them to the client. - """ - result = data[0] - if len(data) > 1: - for response in data[1:]: - result.add(response) - result = msgpack.packb([result.to_dict()]) - return result - - def send_multipart(self, req_id, data): - """ - Send a multipart message to the router socket. - """ - if self.router is None: - raise RuntimeError("Router socket not created. 
Call create_router() first.") - - while self.running: - with self.mutex: - if req_id not in self.req_dict: - try: - client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK) - req_id_str = request_id.decode("utf-8") - self.req_dict[req_id_str] = client - except zmq.Again: - time.sleep(0.001) - continue - else: - break - - try: - start_send = time.time() - if self.aggregate_send: - result = self.pack_aggregated_data(data) - else: - result = msgpack.packb([response.to_dict() for response in data]) - self.router.send_multipart([self.req_dict[req_id], b"", result]) - llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") - - except Exception as e: - llm_logger.error(f"Send result to zmq client failed: {e}") - - if data[-1].finished: - with self.mutex: - self.req_dict.pop(req_id, None) - llm_logger.info(f"send_multipart finished, req_id: {req_id}") - - def receive_json_once(self, block=False): - """ - Receive a single message from the socket. - """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None - try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_json(flags=flags) - except zmq.Again: - return None, None - except Exception as e: - self.close() - llm_logger.warning(f"{e}") - return str(e), None - - def receive_pyobj_once(self, block=False): - """ - Receive a single message from the socket. - """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None - try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_pyobj(flags=flags) - except zmq.Again: - return None, None - except Exception as e: - self.close() - llm_logger.warning(f"{e}") - return str(e), None - - def _clear_ipc(self, name): - """ - Remove the IPC file with the given name. 
- """ - if os.path.exists(name): - try: - os.remove(name) - except OSError as e: - llm_logger.warning(f"Failed to remove IPC file {name} - {e}") - def close(self): - """ - Close the socket and context, and remove the IPC files. - """ - if not self.running: - return - - self.running = False - llm_logger.info("Closing ZMQ connection...") - try: - if hasattr(self, "socket") and not self.socket.closed: - self.socket.close() - - if self.router is not None and not self.router.closed: - self.router.close() +class ZmqIpcClient(ZmqClientBase): + def __init__(self, name, mode): + self.name = name + self.mode = mode + self.file_name = f"/dev/shm/{name}.socket" + + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.context = zmq.Context() + return self.context.socket(self.mode) + + def connect(self): + self._ensure_socket() + self.socket.connect(f"ipc://{self.file_name}") - if not self.context.closed: - self.context.term() + - self._clear_ipc(self.file_name) - self._clear_ipc(self.router_path) - except Exception as e: - llm_logger.warning(f"Failed to close ZMQ connection - {e}") - return - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index 0f90148abc..68a44a735e 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -16,6 +16,7 @@ import threading import time +from abc import ABC, abstractmethod import msgpack import zmq @@ -24,54 +25,23 @@ from fastdeploy.utils import llm_logger -class ZmqTcpServer: +class ZmqServerBase(ABC): """ - ZmqTcpServer, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 + ZmqServerBase """ + def __init__(self): + pass - def __init__(self, port, mode): - self.context = zmq.Context() - self.socket = self.context.socket(mode) - self.mode = mode - self.port = port - self.socket.setsockopt(zmq.SNDTIMEO, -1) - self.socket.bind(f"tcp://*:{self.port}") - self.ZMQ_SNDHWM 
= int(envs.FD_ZMQ_SNDHWM) - self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - self.aggregate_send = envs.FD_USE_AGGREGATE_SEND - - self.mutex = threading.Lock() - self.req_dict = dict() - self.poller = None - self.running = True - if self.mode == zmq.PULL: - self.poller = zmq.Poller() - self.poller.register(self.socket, zmq.POLLIN) - - def send_json(self, data): - """ - Send a JSON-serializable object over the socket. - """ - self.socket.send_json(data) - - def recv_json(self): - """ - Receive a JSON-serializable object from the socket. - """ - return self.socket.recv_json() - - def send_pyobj(self, data): - """ - Send a Pickle-serializable object over the socket. - """ - self.socket.send_pyobj(data) - - def recv_pyobj(self): - """ - Receive a Pickle-serializable object from the socket. - """ - return self.socket.recv_pyobj() - + @abstractmethod + def _create_socket(self): + """Abstract method to create and return a ZeroMQ socket.""" + pass + + def _ensure_socket(self): + """Ensure the socket is created before use.""" + if self.socket is None: + self.socket = self._create_socket() + def pack_aggregated_data(self, data): """ Aggregate multiple responses into one and send them to the client. @@ -82,41 +52,46 @@ def pack_aggregated_data(self, data): result.add(response) result = msgpack.packb([result.to_dict()]) return result - - def recv_control_cmd(self): - while self.running: - try: - client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK) - task = msgpack.unpackb(task_data) - task_id_str = task["task_id"] - except zmq.Again: - time.sleep(0.001) - continue - with self.mutex: - self.req_dict[task_id_str] = client - return task - - def response_for_control_cmd(self, task_id, result): + + def receive_json_once(self, block=False): """ - Send a multipart message to the control cmd socket. + Receive a single message from the socket. 
""" - if self.socket is None: - raise RuntimeError("Router socket not created.") + self._ensure_socket() + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None try: - result = msgpack.packb(result) - self.socket.send_multipart([self.req_dict[task_id], b"", result]) - + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_json(flags=flags) + except zmq.Again: + return None, None except Exception as e: - llm_logger.error(f"Send result to zmq client failed: {e}") - - with self.mutex: - self.req_dict.pop(task_id, None) - llm_logger.info(f"response contrl cmd finished, task_id: {task_id}") + self.close() + llm_logger.warning(f"{e}") + return str(e), None - def send_multipart(self, req_id, data): + def receive_pyobj_once(self, block=False): + """ + Receive a single message from the socket. + """ + self._ensure_socket() + if self.socket is None or self.socket.closed: + return "zmp socket has closed", None + try: + flags = zmq.NOBLOCK if not block else 0 + return None, self.socket.recv_pyobj(flags=flags) + except zmq.Again: + return None, None + except Exception as e: + self.close() + llm_logger.warning(f"{e}") + return str(e), None + + def send_response(self, req_id, data): """ - Send a multipart message to the router socket. + Send generated token result to client. """ + self._ensure_socket() if self.socket is None: raise RuntimeError("Router socket not created. 
Call create_router() first.") @@ -149,42 +124,131 @@ def send_multipart(self, req_id, data): with self.mutex: self.req_dict.pop(req_id, None) llm_logger.info(f"send_multipart finished, req_id: {req_id}") + + @abstractmethod + def close(self): + pass - def receive_json_once(self, block=False): + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class ZmqIpcServer(ZmqServerBase): + """ + ZmqIpcServer, used when ENABLE_EXTERNAL_MODULE_ACCESS=0 + """ + + def __init__(self, name, mode): + self.name = name + self.mode = mode + if mode == zmq.PULL: + self.file_name = f"/dev/shm/{name}.socket" + elif mode == zmq.ROUTER: + self.file_name = f"/dev/shm/router_{name}.ipc" + self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + self.mutex = threading.Lock() + self.req_dict = dict() + self.running = True + + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.context = zmq.Context() + self.context.socket(self.mode) + self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.router.setsockopt(zmq.SNDTIMEO, -1) + self.socket.bind(f"ipc://{self.file_name}") + + + def _clear_ipc(self, name): """ - Receive a single message from the socket. + Remove the IPC file with the given name. """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None + if os.path.exists(name): + try: + os.remove(name) + except OSError as e: + llm_logger.warning(f"Failed to remove IPC file {name} - {e}") + + def close(self): + """ + Close the socket and context, and remove the IPC files. 
+ """ + if not self.running: + return + + self.running = False + llm_logger.info("Closing ZMQ connection...") try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_json(flags=flags) - except zmq.Again: - return None, None + if self.socket is not None and not self.socket.closed: + self.socket.close() + if not self.context.closed: + self.context.term() + self._clear_ipc(self.file_name) except Exception as e: - self.close() - llm_logger.warning(f"{e}") - return str(e), None + llm_logger.warning(f"Failed to close ZMQ connection - {e}") + return - def receive_pyobj_once(self, block=False): + +class ZmqTcpServer(ZmqServerBase): + """ + ZmqTcpServer, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 + """ + + def __init__(self, port, mode): + self.mode = mode + self.port = port + self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) + self.aggregate_send = envs.FD_USE_AGGREGATE_SEND + + self.mutex = threading.Lock() + self.req_dict = dict() + self.running = True + + def _create_socket(self): + """create and return a ZeroMQ socket.""" + self.context = zmq.Context() + self.socket = self.context.socket(mode) + self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.socket.setsockopt(zmq.SNDTIMEO, -1) + self.socket.bind(f"tcp://*:{self.port}") + + def recv_control_cmd(self): """ - Receive a single message from the socket. + Recieve control command from client """ - if self.socket is None or self.socket.closed: - return "zmp socket has closed", None + while self.running: + try: + client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK) + task = msgpack.unpackb(task_data) + task_id_str = task["task_id"] + except zmq.Again: + time.sleep(0.001) + continue + with self.mutex: + self.req_dict[task_id_str] = client + return task + + def response_for_control_cmd(self, task_id, result): + """ + Send command result back to client. 
+ """ + if self.socket is None: + raise RuntimeError("Router socket not created.") try: - flags = zmq.NOBLOCK if not block else 0 - return None, self.socket.recv_pyobj(flags=flags) - except zmq.Again: - return None, None + result = msgpack.packb(result) + self.socket.send_multipart([self.req_dict[task_id], b"", result]) + except Exception as e: - self.close() - llm_logger.warning(f"{e}") - return str(e), None + llm_logger.error(f"Send result to zmq client failed: {e}") + with self.mutex: + self.req_dict.pop(task_id, None) + llm_logger.info(f"response contrl cmd finished, task_id: {task_id}") + def close(self): """ - Close the socket and context, and remove the IPC files. + Close the socket and context. """ if not self.running: return @@ -192,12 +256,8 @@ def close(self): self.running = False llm_logger.info("Closing ZMQ connection...") try: - if hasattr(self, "socket") and not self.socket.closed: - self.socket.close() - if self.socket is not None and not self.socket.closed: self.socket.close() - if not self.context.closed: self.context.term() @@ -205,5 +265,3 @@ def close(self): llm_logger.warning(f"Failed to close ZMQ connection - {e}") return - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() \ No newline at end of file diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py new file mode 100644 index 0000000000..940acf4d69 --- /dev/null +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -0,0 +1,97 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +# **Note**: Just for internal use +import zmq +import threading + +from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics +from fastdeploy.inter_communicator import ZmqTcpServer +from fastdeploy.utils import envs, llm_logger + +class ExternalModuleAdapter: + def __int__(self, cfg, engine, dp_rank): + self.cfg = cfg + self.engine = engine + self.dp_rank = dp_rank + recv_control_cmd_ports = envs.ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") + self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER) + self.recv_external_instruct_thread = threading.Thread(target=self._recv_external_module_control_instruct, daemon=True) + self.recv_external_instruct_thread.start() + self.response_external_instruct_thread = threading.Thread(target=self._response_external_module_control_instruct, daemon=True) + self.response_external_instruct_thread.start() + + + def get_current_server_info(self): + """ + 获取服务当前资源信息 + """ + available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch()) + + available_block_num = self.engine.resource_manager.available_block_num() + server_info = { + "splitwise_role": self.cfg.splitwise_role, + "block_size": int(self.cfg.cache_config.block_size), + "block_num": int(available_block_num), + "dec_token_num": int(self.cfg.cache_config.dec_token_num), + "available_resource": 1.0 * available_block_num / self.cfg.cache_config.total_block_num, + "max_batch_size": int(available_batch_size), + "max_input_token_num": self.cfg.max_num_batched_tokens, + 
} + return server_info + + def _recv_external_module_control_instruct(self): + """ + Receive a multipart message from the control cmd socket. + """ + while True: + try: + task = self.recv_control_cmd_server.recv_control_cmd() + llm_logger.info(f"Recieve control task: {task}") + task_id_str = task["task_id"] + if task["cmd"] == "get_payload": + payload_info = self._get_current_server_info() + result = {"task_id": task_id_str, "result": payload_info} + llm_logger.info(f"Response for task: {task_id_str}") + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + + elif task["cmd"] == "get_metrics": + metrics_text = get_filtered_metrics( + [], + extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1), + ) + result = {"task_id": task_id_str, "result": metrics_text} + llm_logger.info(f"Response for task: {task_id_str}") + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + elif task["cmd"] == "connect_rdma": + self.engine_worker_queue.put_connect_rdma_task(task) + + except Exception as e: + llm_logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") + + def _response_external_module_control_instruct(self): + while True: + try: + result_data = self.engine_worker_queue.get_connect_rdma_task_response() + if result_data: + task_id_str = result_data["task_id"] + result = {"task_id": task_id_str, "result": result_data} + llm_logger.info(f"Response for task: {task_id_str}") + self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) + else: + time.sleep(0.001) + except Exception as e: + llm_logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}") From 1df42fe4537b8904b736d9d63799f3b3908823d5 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 16:31:34 +0800 Subject: [PATCH 06/19] refactor code to make it more clear --- fastdeploy/engine/engine.py | 16 ++++++++-------- fastdeploy/engine/expert_service.py | 
11 ++++++----- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index ee4544b9ef..7fa662d120 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -255,16 +255,16 @@ def start(self, api_server_pid=None): role = self.cfg.splitwise_role host_ip = self.cfg.host_ip disaggregate = self.cfg.disaggregate_info - request_queues = None - result_queue = None + request_queues_for_dp_ipc = None # Different dp has its own process, use multiprocessing.Queue to deliver requests for each dp + result_queue_for_dp_ipc = None if self.cfg.scheduler_config.name == "splitwise": self.scheduler.start(role, host_ip, disaggregate) elif self.cfg.scheduler_config.name == 'dp': - request_queues = [] - result_queue = multiprocessing.Queue() + request_queues_for_dp_ipc = [] + result_queue_for_dp_ipc = multiprocessing.Queue() for i in range(self.cfg.parallel_config.data_parallel_size): - request_queues.append(multiprocessing.Queue()) - self.scheduler.start(self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues, result_queue) + request_queues_for_dp_ipc.append(multiprocessing.Queue()) + self.scheduler.start(self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc) time.sleep(1) @@ -282,8 +282,8 @@ def start(self, api_server_pid=None): self.cfg, i + self.cfg.node_rank * self.cfg.worker_num_per_node, self.ipc_signal_suffix, - request_queues, - result_queue, + request_queues_for_dp_ipc, + result_queue_for_dp_ipc, ), ) ) diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index a8e3b4f46a..9f4585d2d0 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -27,11 +27,12 @@ import numpy as np from fastdeploy.engine.resource_manager import ResourceManager -from fastdeploy.inter_communicator import EngineWorkerQueue, ZmqTcpServer +from fastdeploy.inter_communicator import 
EngineWorkerQueue from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.output.token_processor import TokenProcessor from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, console_logger, envs, llm_logger +from fastdeploy.splitwise.internal_adapter_utils import ExternalModuleAdapter from fastdeploy.metrics.metrics import EXCLUDE_LABELS, get_filtered_metrics, main_process_metrics @@ -117,7 +118,7 @@ def __init__(self, cfg, local_data_parallel_id): if envs.ENABLE_ENGINE_ZMQ_REMOTE_ACCESS: self.external_adapter = ExternalModuleAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id) - def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues=None, result_queue=None): + def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None): """ Initializes the engine and starts its sub-services. If `api_server_pid` is defined, will launch a thread @@ -154,7 +155,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues=None, disaggregate = self.cfg.disaggregate_info if self.cfg.scheduler_config.name == 'dp': assert (request_queues is not None) and (result_queue is not None) - self.scheduler.start(local_data_parallel_id, request_queues, result_queue) + self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc) elif self.cfg.scheduler_config.name == 'splitwise': self.scheduler.start(role, host_ip, disaggregate) self.cfg.print() @@ -366,13 +367,13 @@ def _exit_sub_services(self): self.zmq_server.close() -def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix, request_queues=None, result_queue=None): +def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None): """ Start expert service """ expert_service = ExpertService(cfg, local_data_parallel_id) try: - 
expert_service.start(ipc_signal_suffix, local_data_parallel_id, request_queues, result_queue) + expert_service.start(ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc) expert_service.split_connector.start_receiver() except Exception as e: llm_logger.exception(f"Expert service failed to start: {e}") From e169959b12b656c8395217e6209869096fbf48b5 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 16:36:10 +0800 Subject: [PATCH 07/19] refactor code to make it more clear --- fastdeploy/engine/expert_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 9f4585d2d0..47c606329f 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -154,7 +154,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp host_ip = self.cfg.host_ip disaggregate = self.cfg.disaggregate_info if self.cfg.scheduler_config.name == 'dp': - assert (request_queues is not None) and (result_queue is not None) + assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None) self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc) elif self.cfg.scheduler_config.name == 'splitwise': self.scheduler.start(role, host_ip, disaggregate) From 23d570eacec50480e58dba2b29ce4cfbad47a36c Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 16:42:38 +0800 Subject: [PATCH 08/19] refactor code to make it more clear --- fastdeploy/inter_communicator/zmq_server.py | 3 ++- fastdeploy/splitwise/internal_adapter_utils.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index 68a44a735e..81571c8f51 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ 
b/fastdeploy/inter_communicator/zmq_server.py @@ -17,6 +17,7 @@ import threading import time from abc import ABC, abstractmethod +import os import msgpack import zmq @@ -208,7 +209,7 @@ def __init__(self, port, mode): def _create_socket(self): """create and return a ZeroMQ socket.""" self.context = zmq.Context() - self.socket = self.context.socket(mode) + self.socket = self.context.socket(self.mode) self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) self.socket.setsockopt(zmq.SNDTIMEO, -1) self.socket.bind(f"tcp://*:{self.port}") diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index 940acf4d69..7fc5cd4810 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -21,6 +21,7 @@ from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics from fastdeploy.inter_communicator import ZmqTcpServer from fastdeploy.utils import envs, llm_logger +import traceback class ExternalModuleAdapter: def __int__(self, cfg, engine, dp_rank): From a12214a5f0a32cc8d60132290cb005ee775c4cfc Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 17:25:26 +0800 Subject: [PATCH 09/19] fix according to review --- fastdeploy/engine/engine.py | 12 ++++++------ fastdeploy/engine/expert_service.py | 6 +++--- fastdeploy/envs.py | 14 +++++++------- fastdeploy/inter_communicator/zmq_server.py | 12 +++++++----- fastdeploy/splitwise/internal_adapter_utils.py | 9 +++++---- 5 files changed, 28 insertions(+), 25 deletions(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 7fa662d120..1a010cc464 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -55,7 +55,7 @@ from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector -from 
fastdeploy.splitwise.internal_adapter_utils import ExternalModuleAdapter +from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.utils import EngineError, console_logger, envs, llm_logger @@ -182,10 +182,10 @@ def start(self, api_server_pid=None): self.data_processor = self.input_processor.create_processor() if api_server_pid is not None: - if envs.ENABLE_EXTERNAL_MODULE_ACCESS: - self.recv_request_server = ZmqTcpServer(port=envs.ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL) - self.send_response_server = ZmqTcpServer(port=envs.ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER) - self.external_adapter = ExternalModuleAdapter(cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node) + if envs.FD_ENABLE_INTERNAL_ADAPTER: + self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL) + self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER) + self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node) else: self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) @@ -432,7 +432,7 @@ def _insert_zmq_task_to_scheduler(self): if self.api_server_pid is None: return - if envs.ENABLE_EXTERNAL_MODULE_ACCESS: + if envs.FD_ENABLE_INTERNAL_ADAPTER: if self.cfg.splitwise_role == "decode": return diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 47c606329f..c1a2dda1e5 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -32,7 +32,7 @@ from fastdeploy.output.token_processor import TokenProcessor from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, console_logger, envs, llm_logger -from fastdeploy.splitwise.internal_adapter_utils import 
ExternalModuleAdapter +from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.metrics.metrics import EXCLUDE_LABELS, get_filtered_metrics, main_process_metrics @@ -115,8 +115,8 @@ def __init__(self, cfg, local_data_parallel_id): ) self._finalizer = weakref.finalize(self, self._exit_sub_services) - if envs.ENABLE_ENGINE_ZMQ_REMOTE_ACCESS: - self.external_adapter = ExternalModuleAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id) + if envs.FD_ENABLE_INTERNAL_ADAPTER: + self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id) def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None): """ diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 881fb86d0f..7ae016d303 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -81,13 +81,13 @@ # enable kv cache block scheduler v1 (no need for kv_cache_ratio) "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")), # enable external module to access LLMEngine. 
- "ENABLE_EXTERNAL_MODULE_ACCESS": lambda: int(os.getenv("ENABLE_EXTERNAL_MODULE_ACCESS", "0")), - # LLMEngine recieve requests port, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 - "ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("ZMQ_RECV_REQUEST_SERVER_PORT", "8200"), - # LLMEngine send response port, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 - "ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"), - # LLMEngine recieve control command port, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 - "ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), + "FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")), + # LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"), + # LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"), + # LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1 + "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), } diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index 81571c8f51..50930a1f2a 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -136,7 +136,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class ZmqIpcServer(ZmqServerBase): """ - ZmqIpcServer, used when ENABLE_EXTERNAL_MODULE_ACCESS=0 + ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0 """ def __init__(self, name, mode): @@ -155,10 +155,11 @@ def __init__(self, name, mode): def _create_socket(self): """create and return a ZeroMQ socket.""" self.context = zmq.Context() - self.context.socket(self.mode) - self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) - 
self.router.setsockopt(zmq.SNDTIMEO, -1) + self.socket = self.context.socket(self.mode) + self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.socket.setsockopt(zmq.SNDTIMEO, -1) self.socket.bind(f"ipc://{self.file_name}") + return self.socket def _clear_ipc(self, name): @@ -193,7 +194,7 @@ def close(self): class ZmqTcpServer(ZmqServerBase): """ - ZmqTcpServer, used when ENABLE_EXTERNAL_MODULE_ACCESS=1 + ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1 """ def __init__(self, port, mode): @@ -213,6 +214,7 @@ def _create_socket(self): self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) self.socket.setsockopt(zmq.SNDTIMEO, -1) self.socket.bind(f"tcp://*:{self.port}") + return self.socket def recv_control_cmd(self): """ diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index 7fc5cd4810..98be0c1420 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -17,18 +17,19 @@ # **Note**: Just for internal use import zmq import threading +import time from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics from fastdeploy.inter_communicator import ZmqTcpServer from fastdeploy.utils import envs, llm_logger import traceback -class ExternalModuleAdapter: - def __int__(self, cfg, engine, dp_rank): +class InternalAdapter: + def __init__(self, cfg, engine, dp_rank): self.cfg = cfg self.engine = engine self.dp_rank = dp_rank - recv_control_cmd_ports = envs.ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") + recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER) self.recv_external_instruct_thread = threading.Thread(target=self._recv_external_module_control_instruct, daemon=True) self.recv_external_instruct_thread.start() @@ -36,7 +37,7 @@ def __int__(self, cfg, engine, dp_rank): self.response_external_instruct_thread.start() - 
def get_current_server_info(self): + def _get_current_server_info(self): """ 获取服务当前资源信息 """ From 7b0e78597a077edf117b0bca384edddb51ac437e Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 17:27:17 +0800 Subject: [PATCH 10/19] fix according to review --- fastdeploy/envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 7ae016d303..bf318b47e0 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -80,7 +80,7 @@ "EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"), # enable kv cache block scheduler v1 (no need for kv_cache_ratio) "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")), - # enable external module to access LLMEngine. + # enable internal module to access LLMEngine. "FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")), # LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1 "FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"), From d86518104aa02cadcb546e1b78ca85bfd474de6d Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 17:49:03 +0800 Subject: [PATCH 11/19] fix according to review --- fastdeploy/inter_communicator/zmq_server.py | 2 +- fastdeploy/scheduler/dp_scheduler.py | 5 +++-- fastdeploy/splitwise/internal_adapter_utils.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index 50930a1f2a..f1de4c66df 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -247,7 +247,7 @@ def response_for_control_cmd(self, task_id, result): with self.mutex: self.req_dict.pop(task_id, None) - llm_logger.info(f"response contrl cmd finished, task_id: {task_id}") + llm_logger.info(f"response control cmd finished, task_id: {task_id}") 
def close(self): """ diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py index 62ed25b742..74c2ff2975 100644 --- a/fastdeploy/scheduler/dp_scheduler.py +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -15,8 +15,7 @@ """ import threading from multiprocessing import Queue -from typing import Dict, List -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional import time from fastdeploy.engine.request import Request, RequestOutput @@ -141,6 +140,8 @@ def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue): def put_requests(self, requests: List[Dict]): results = [] for request in requests: + if not hasattr(request, 'dp_rank'): + raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}") self.request_queues[request.dp_rank].put(request) results.append((request.request_id, None)) return results diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index 98be0c1420..093702338f 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -79,7 +79,7 @@ def _recv_external_module_control_instruct(self): llm_logger.info(f"Response for task: {task_id_str}") self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) elif task["cmd"] == "connect_rdma": - self.engine_worker_queue.put_connect_rdma_task(task) + self.engine.engine_worker_queue.put_connect_rdma_task(task) except Exception as e: llm_logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") @@ -87,7 +87,7 @@ def _recv_external_module_control_instruct(self): def _response_external_module_control_instruct(self): while True: try: - result_data = self.engine_worker_queue.get_connect_rdma_task_response() + result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response() if result_data: task_id_str = result_data["task_id"] result = {"task_id": task_id_str, 
"result": result_data} From 3bfa50234324e8395bdda9fa425b2d44edc338c3 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 1 Aug 2025 17:56:51 +0800 Subject: [PATCH 12/19] fix according to review --- fastdeploy/engine/request.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index acf717547a..f88d24152b 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -71,6 +71,7 @@ def __init__( guided_json_object: Optional[bool] = None, enable_thinking: Optional[bool] = True, trace_carrier: dict = dict(), + dp_rank: Optional[int] = None ) -> None: self.request_id = request_id self.prompt = prompt @@ -119,6 +120,7 @@ def __init__( self.task_type = RequestType.PREFILL self.idx = None self.need_prefill_tokens = self.prompt_token_ids_len + self.dp_rank = dp_rank @classmethod def from_dict(cls, d: dict): @@ -151,6 +153,7 @@ def from_dict(cls, d: dict): guided_json_object=d.get("guided_json_object", None), enable_thinking=d.get("enable_thinking", True), trace_carrier=d.get("trace_carrier", {}), + dp_rank=d.get("dp_rank", None) ) @property From 48b081f0994247a7504658fe71065e45d5586a25 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Sun, 3 Aug 2025 20:55:42 +0800 Subject: [PATCH 13/19] fix according to review --- fastdeploy/engine/engine.py | 107 ++++++++++---------- fastdeploy/engine/expert_service.py | 25 +++-- fastdeploy/inter_communicator/zmq_client.py | 23 ++--- fastdeploy/inter_communicator/zmq_server.py | 31 +++--- fastdeploy/scheduler/config.py | 22 ++-- 5 files changed, 107 insertions(+), 101 deletions(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 1a010cc464..ecce9c5058 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -48,14 +48,14 @@ EngineWorkerQueue, IPCSignal, ZmqIpcServer, - ZmqTcpServer + ZmqTcpServer, ) -from fastdeploy.metrics.metrics import get_filtered_metrics, 
main_process_metrics +from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.trace_util import start_span, start_span_request from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor -from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter +from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, console_logger, envs, llm_logger @@ -163,7 +163,6 @@ def __init__(self, cfg): self.cfg.guided_decoding_backend, disable_any_whitespace=self.cfg.disable_any_whitespace, ) - def start(self, api_server_pid=None): """ @@ -185,12 +184,61 @@ def start(self, api_server_pid=None): if envs.FD_ENABLE_INTERNAL_ADAPTER: self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL) self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER) - self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node) + self.external_adapter = InternalAdapter( + cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node + ) else: self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) time.sleep(3) - + + self.cfg.init_cache_info() + + role = self.cfg.splitwise_role + host_ip = self.cfg.host_ip + disaggregate = self.cfg.disaggregate_info + request_queues_for_dp_ipc = ( + None # Different dp has its own process, use multiprocessing.Queue to deliver requests for each dp + ) + result_queue_for_dp_ipc = None + if self.cfg.scheduler_config.name == "splitwise": + self.scheduler.start(role, host_ip, disaggregate) + elif self.cfg.scheduler_config.name == "dp": + 
request_queues_for_dp_ipc = [] + result_queue_for_dp_ipc = multiprocessing.Queue() + for i in range(self.cfg.parallel_config.data_parallel_size): + request_queues_for_dp_ipc.append(multiprocessing.Queue()) + self.scheduler.start( + self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc + ) + + time.sleep(1) + + if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: + self.dp_processed = [] + for i in range( + 1, + self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, + ): + time.sleep(1) + self.dp_processed.append( + multiprocessing.Process( + target=start_expert_service, + args=( + self.cfg, + i + self.cfg.node_rank * self.cfg.worker_num_per_node, + self.ipc_signal_suffix, + request_queues_for_dp_ipc, + result_queue_for_dp_ipc, + ), + ) + ) + llm_logger.info( + f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" + + f" data parallel id {i}" + ) + self.dp_processed[-1].start() + if self.do_profile == 0 and ( self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed" ): @@ -250,52 +298,9 @@ def start(self, api_server_pid=None): self.splitwise_receive_thread.daemon = True self.splitwise_receive_thread.start() - self.cfg.init_cache_info() - - role = self.cfg.splitwise_role - host_ip = self.cfg.host_ip - disaggregate = self.cfg.disaggregate_info - request_queues_for_dp_ipc = None # Different dp has its own process, use multiprocessing.Queue to deliver requests for each dp - result_queue_for_dp_ipc = None - if self.cfg.scheduler_config.name == "splitwise": - self.scheduler.start(role, host_ip, disaggregate) - elif self.cfg.scheduler_config.name == 'dp': - request_queues_for_dp_ipc = [] - result_queue_for_dp_ipc = multiprocessing.Queue() - for i in range(self.cfg.parallel_config.data_parallel_size): - request_queues_for_dp_ipc.append(multiprocessing.Queue()) - self.scheduler.start(self.cfg.node_rank * 
self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc) - - time.sleep(1) - - if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: - self.dp_processed = [] - for i in range( - 1, - self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, - ): - time.sleep(1) - self.dp_processed.append( - multiprocessing.Process( - target=start_expert_service, - args=( - self.cfg, - i + self.cfg.node_rank * self.cfg.worker_num_per_node, - self.ipc_signal_suffix, - request_queues_for_dp_ipc, - result_queue_for_dp_ipc, - ), - ) - ) - llm_logger.info( - f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" - + f" data parallel id {i}" - ) - self.dp_processed[-1].start() - console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") return True - + def _zmq_send_generated_tokens(self): """ Recieve output for zmq @@ -431,7 +436,7 @@ def _fetch_request(): def _insert_zmq_task_to_scheduler(self): if self.api_server_pid is None: return - + if envs.FD_ENABLE_INTERNAL_ADAPTER: if self.cfg.splitwise_role == "decode": return diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index c1a2dda1e5..6b3c014760 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -22,7 +22,6 @@ import time import traceback import weakref -import zmq import numpy as np @@ -30,10 +29,9 @@ from fastdeploy.inter_communicator import EngineWorkerQueue from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.output.token_processor import TokenProcessor +from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.utils import EngineError, console_logger, envs, llm_logger -from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter -from fastdeploy.metrics.metrics import EXCLUDE_LABELS, 
get_filtered_metrics, main_process_metrics class ExpertService: @@ -63,7 +61,7 @@ def __init__(self, cfg, local_data_parallel_id): self.scheduler = cfg.scheduler_config.scheduler() - if self.cfg.scheduler_config.name == 'splitwise': + if self.cfg.scheduler_config.name == "splitwise": self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}") self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id @@ -118,7 +116,9 @@ def __init__(self, cfg, local_data_parallel_id): if envs.FD_ENABLE_INTERNAL_ADAPTER: self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id) - def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None): + def start( + self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None + ): """ Initializes the engine and starts its sub-services. If `api_server_pid` is defined, will launch a thread @@ -133,7 +133,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp cache_config=self.cfg.cache_config, tensor_parallel_size=self.cfg.tensor_parallel_size, device_ids=self.cfg.local_device_ids, - pod_ip=self.cfg.pod_ips[0], + pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}", ) @@ -153,16 +153,15 @@ def start(self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp role = self.cfg.splitwise_role host_ip = self.cfg.host_ip disaggregate = self.cfg.disaggregate_info - if self.cfg.scheduler_config.name == 'dp': + if self.cfg.scheduler_config.name == "dp": assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None) self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc) - elif self.cfg.scheduler_config.name == 'splitwise': + elif self.cfg.scheduler_config.name == 
"splitwise": self.scheduler.start(role, host_ip, disaggregate) self.cfg.print() console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") return True - def _insert_task_to_worker(self): """ @@ -367,13 +366,17 @@ def _exit_sub_services(self): self.zmq_server.close() -def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None): +def start_expert_service( + cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None +): """ Start expert service """ expert_service = ExpertService(cfg, local_data_parallel_id) try: - expert_service.start(ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc) + expert_service.start( + ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc + ) expert_service.split_connector.start_receiver() except Exception as e: llm_logger.exception(f"Expert service failed to start: {e}") diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 28ee8f7f2b..13242f2a20 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -14,17 +14,10 @@ # limitations under the License. """ -import os -import threading -import time from abc import ABC, abstractmethod -import msgpack import zmq -from fastdeploy import envs -from fastdeploy.utils import llm_logger - class ZmqClientBase(ABC): """ @@ -33,12 +26,12 @@ class ZmqClientBase(ABC): def __init__(self): pass - + @abstractmethod def _create_socket(self): """Abstract method to create and return a ZeroMQ socket.""" pass - + def _ensure_socket(self): """Ensure the socket is created before use.""" if self.socket is None: @@ -51,7 +44,6 @@ def connect(self): """ pass - def send_json(self, data): """ Send a JSON-serializable object over the socket. 
@@ -79,7 +71,6 @@ def recv_pyobj(self): """ self._ensure_socket() return self.socket.recv_pyobj() - class ZmqIpcClient(ZmqClientBase): @@ -87,16 +78,14 @@ def __init__(self, name, mode): self.name = name self.mode = mode self.file_name = f"/dev/shm/{name}.socket" - + self.context = zmq.Context() + self.socket = self.context.socket(self.mode) + def _create_socket(self): """create and return a ZeroMQ socket.""" self.context = zmq.Context() return self.context.socket(self.mode) - + def connect(self): self._ensure_socket() self.socket.connect(f"ipc://{self.file_name}") - - - - diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index f1de4c66df..c22c044ea1 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -14,10 +14,10 @@ # limitations under the License. """ +import os import threading import time from abc import ABC, abstractmethod -import os import msgpack import zmq @@ -30,19 +30,20 @@ class ZmqServerBase(ABC): """ ZmqServerBase """ + def __init__(self): - pass + self.socket = None @abstractmethod def _create_socket(self): """Abstract method to create and return a ZeroMQ socket.""" pass - + def _ensure_socket(self): """Ensure the socket is created before use.""" if self.socket is None: self.socket = self._create_socket() - + def pack_aggregated_data(self, data): """ Aggregate multiple responses into one and send them to the client. @@ -53,7 +54,7 @@ def pack_aggregated_data(self, data): result.add(response) result = msgpack.packb([result.to_dict()]) return result - + def receive_json_once(self, block=False): """ Receive a single message from the socket. @@ -87,7 +88,7 @@ def receive_pyobj_once(self, block=False): self.close() llm_logger.warning(f"{e}") return str(e), None - + def send_response(self, req_id, data): """ Send generated token result to client. 
@@ -125,14 +126,14 @@ def send_response(self, req_id, data): with self.mutex: self.req_dict.pop(req_id, None) llm_logger.info(f"send_multipart finished, req_id: {req_id}") - + @abstractmethod def close(self): pass def __exit__(self, exc_type, exc_val, exc_tb): self.close() - + class ZmqIpcServer(ZmqServerBase): """ @@ -151,17 +152,17 @@ def __init__(self, name, mode): self.mutex = threading.Lock() self.req_dict = dict() self.running = True - + self.context = zmq.Context() + self._create_socket() + def _create_socket(self): """create and return a ZeroMQ socket.""" - self.context = zmq.Context() self.socket = self.context.socket(self.mode) self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) self.socket.setsockopt(zmq.SNDTIMEO, -1) self.socket.bind(f"ipc://{self.file_name}") return self.socket - def _clear_ipc(self, name): """ Remove the IPC file with the given name. @@ -206,10 +207,11 @@ def __init__(self, port, mode): self.mutex = threading.Lock() self.req_dict = dict() self.running = True + self.context = zmq.Context() + self._create_socket() def _create_socket(self): """create and return a ZeroMQ socket.""" - self.context = zmq.Context() self.socket = self.context.socket(self.mode) self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) self.socket.setsockopt(zmq.SNDTIMEO, -1) @@ -220,6 +222,7 @@ def recv_control_cmd(self): """ Recieve control command from client """ + self._ensure_socket() while self.running: try: client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK) @@ -236,6 +239,7 @@ def response_for_control_cmd(self, task_id, result): """ Send command result back to client. """ + self._ensure_socket() if self.socket is None: raise RuntimeError("Router socket not created.") try: @@ -248,7 +252,7 @@ def response_for_control_cmd(self, task_id, result): with self.mutex: self.req_dict.pop(task_id, None) llm_logger.info(f"response control cmd finished, task_id: {task_id}") - + def close(self): """ Close the socket and context. 
@@ -267,4 +271,3 @@ def close(self): except Exception as e: llm_logger.warning(f"Failed to close ZMQ connection - {e}") return - diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py index a12e2349b5..c831b0f44a 100644 --- a/fastdeploy/scheduler/config.py +++ b/fastdeploy/scheduler/config.py @@ -18,7 +18,7 @@ from fastdeploy.utils import llm_logger -from .dp_scheduler import DPLocalScheduler +from .dp_scheduler import DPScheduler from .global_scheduler import GlobalScheduler from .local_scheduler import LocalScheduler from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig @@ -90,8 +90,6 @@ def print(self): llm_logger.info("=============================================================") - - class DPLocalSchedulerConfig(LocalSchedulerConfig): """ Configuration class for DPLocalScheduler. @@ -110,7 +108,7 @@ def __init__( max_num_partial_prefills: int = 1, max_long_partial_prefills: int = 1, long_prefill_token_threshold: int = 0, - splitwise_role: str = 'prefill', + splitwise_role: str = "prefill", **kwargs, ): """ @@ -282,7 +280,7 @@ def __init__(self, name="local", **kwargs): if name == "splitwise": self.config = SplitWiseSchedulerConfig(**kwargs) - + if name == "dp": self.config = DPLocalSchedulerConfig(**kwargs) @@ -293,7 +291,7 @@ def check(self): Raises: Exception: If invalid scheduler type is specified """ - if self.name not in ["local", "global", "splitwise"]: + if self.name not in ["local", "global", "splitwise", "dp"]: raise Exception(f"Unknown scheduler type {self.name}") self.config.check() @@ -330,9 +328,17 @@ def scheduler(self): if self.name == "splitwise": return SplitWiseScheduler(self.config) - + if self.name == "dp": - return DPLocalScheduler(self.config) + return DPScheduler( + max_size=self.config.max_size, + ttl=self.config.ttl, + enable_chunked_prefill=self.config.enable_chunked_prefill, + max_num_partial_prefills=self.config.max_num_partial_prefills, + 
max_long_partial_prefills=self.config.max_long_partial_prefills, + long_prefill_token_threshold=self.config.long_prefill_token_threshold, + splitwise_role=self.config.splitwise_role, + ) return LocalScheduler( max_size=self.config.max_size, From b8d342f67f4b21511f8fe168e71f9a0a23c21981 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 3 Aug 2025 22:58:22 +0800 Subject: [PATCH 14/19] fix according to review --- .../cache_manager/prefix_cache_manager.py | 2 + fastdeploy/inter_communicator/zmq_server.py | 2 +- fastdeploy/scheduler/dp_scheduler.py | 20 +++++----- .../splitwise/internal_adapter_utils.py | 38 +++++++++++-------- 4 files changed, 37 insertions(+), 25 deletions(-) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 0ac34ad6ac..198925cd6f 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -60,6 +60,7 @@ def __init__( self.enable_splitwise = 0 self.splitwise_role = splitwise_role + self.config = config self.cache_config = config.cache_config self.speculative_config = config.speculative_config self.local_data_parallel_id = local_data_parallel_id @@ -185,6 +186,7 @@ def launch_cache_manager( + f" --engine_pid {pid_suffix}" + f" --protocol {cache_config.cache_transfer_protocol}" + f" --local_data_parallel_id {self.local_data_parallel_id}" + + f" --data_parallel_size {self.config.parallel_config.data_parallel_size}" + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + f" --speculative_config '{self.speculative_config.to_json_string()}'" + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1" diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index c22c044ea1..f4ee8be313 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -32,7 +32,7 @@ class 
ZmqServerBase(ABC): """ def __init__(self): - self.socket = None + pass @abstractmethod def _create_socket(self): diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py index 74c2ff2975..d55a687905 100644 --- a/fastdeploy/scheduler/dp_scheduler.py +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ + import threading +import time from multiprocessing import Queue from typing import Dict, List, Optional -import time from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledResponse @@ -33,7 +34,7 @@ def __init__( max_num_partial_prefills: int, max_long_partial_prefills: int, long_prefill_token_threshold: int, - splitwise_role: str = 'prefill' + splitwise_role: str = "prefill", ): super().__init__( max_size, @@ -75,7 +76,7 @@ def _recycle(self, request_id: Optional[str] = None): if request_id is not None: self.requests.pop(request_id, None) self.responses.pop(request_id, None) - if self.splitwise_role == 'decode': + if self.splitwise_role == "decode": return self.ids.pop(self.ids.index(request_id)) self.ids_read_cursor -= 1 @@ -107,8 +108,6 @@ def _recycle(self, request_id: Optional[str] = None): self.ids_read_cursor -= len(expired_ids) - - class DPScheduler: def __init__( self, @@ -118,7 +117,7 @@ def __init__( max_num_partial_prefills: int, max_long_partial_prefills: int, long_prefill_token_threshold: int, - splitwise_role: str = 'prefill' + splitwise_role: str = "prefill", ): self._scheduler = DPLocalScheduler( max_size, @@ -127,7 +126,7 @@ def __init__( max_num_partial_prefills, max_long_partial_prefills, long_prefill_token_threshold, - splitwise_role + splitwise_role, ) def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue): @@ -140,7 +139,7 @@ def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue): def 
put_requests(self, requests: List[Dict]): results = [] for request in requests: - if not hasattr(request, 'dp_rank'): + if not hasattr(request, "dp_rank"): raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}") self.request_queues[request.dp_rank].put(request) results.append((request.request_id, None)) @@ -170,8 +169,11 @@ def get_requests( available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch ) + def get_unhandled_request_num(self): + return len(self._scheduler.requests) + def put_results(self, results: List[RequestOutput]): self._scheduler.put_results(results) def get_results(self) -> Dict[str, List[RequestOutput]]: - return self.result_queue.get() \ No newline at end of file + return self.result_queue.get() diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index 093702338f..27bf49e22c 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -14,15 +14,19 @@ # limitations under the License. 
""" -# **Note**: Just for internal use -import zmq import threading import time +import traceback + +# **Note**: Just for internal use +import zmq -from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics from fastdeploy.inter_communicator import ZmqTcpServer -from fastdeploy.utils import envs, llm_logger -import traceback +from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics +from fastdeploy.utils import envs, get_logger + +logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log") + class InternalAdapter: def __init__(self, cfg, engine, dp_rank): @@ -31,12 +35,15 @@ def __init__(self, cfg, engine, dp_rank): self.dp_rank = dp_rank recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",") self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER) - self.recv_external_instruct_thread = threading.Thread(target=self._recv_external_module_control_instruct, daemon=True) + self.recv_external_instruct_thread = threading.Thread( + target=self._recv_external_module_control_instruct, daemon=True + ) self.recv_external_instruct_thread.start() - self.response_external_instruct_thread = threading.Thread(target=self._response_external_module_control_instruct, daemon=True) + self.response_external_instruct_thread = threading.Thread( + target=self._response_external_module_control_instruct, daemon=True + ) self.response_external_instruct_thread.start() - def _get_current_server_info(self): """ 获取服务当前资源信息 @@ -52,9 +59,10 @@ def _get_current_server_info(self): "available_resource": 1.0 * available_block_num / self.cfg.cache_config.total_block_num, "max_batch_size": int(available_batch_size), "max_input_token_num": self.cfg.max_num_batched_tokens, + "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(), } return server_info - + def _recv_external_module_control_instruct(self): """ Receive a multipart message from the control cmd socket. 
@@ -62,12 +70,12 @@ def _recv_external_module_control_instruct(self): while True: try: task = self.recv_control_cmd_server.recv_control_cmd() - llm_logger.info(f"Recieve control task: {task}") + logger.info(f"Recieve control task: {task}") task_id_str = task["task_id"] if task["cmd"] == "get_payload": payload_info = self._get_current_server_info() result = {"task_id": task_id_str, "result": payload_info} - llm_logger.info(f"Response for task: {task_id_str}") + logger.info(f"Response for task: {task_id_str}") self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) elif task["cmd"] == "get_metrics": @@ -76,13 +84,13 @@ def _recv_external_module_control_instruct(self): extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1), ) result = {"task_id": task_id_str, "result": metrics_text} - llm_logger.info(f"Response for task: {task_id_str}") + logger.info(f"Response for task: {task_id_str}") self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) elif task["cmd"] == "connect_rdma": self.engine.engine_worker_queue.put_connect_rdma_task(task) except Exception as e: - llm_logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") + logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}") def _response_external_module_control_instruct(self): while True: @@ -91,9 +99,9 @@ def _response_external_module_control_instruct(self): if result_data: task_id_str = result_data["task_id"] result = {"task_id": task_id_str, "result": result_data} - llm_logger.info(f"Response for task: {task_id_str}") + logger.info(f"Response for task: {task_id_str}") self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result) else: time.sleep(0.001) except Exception as e: - llm_logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}") + logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}") From 
58d62e38ab7f1c55fb3707ecd335961103de762a Mon Sep 17 00:00:00 2001 From: root Date: Sun, 3 Aug 2025 23:34:13 +0800 Subject: [PATCH 15/19] fix bug --- fastdeploy/output/token_processor.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 27fda99870..3938db3624 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -412,7 +412,11 @@ def _process_sampling_with_logprob_batch_output(self): self._record_completion_metrics(task, current_time) self._recycle_resources(task_id, i, task, result, is_prefill) break - if not is_prefill or self.cfg.scheduler_config.name == "splitwise": + if ( + not is_prefill + or self.cfg.scheduler_config.name == "splitwise" + or self.cfg.scheduler_config.name == "dp" + ): batch_result.append(result) self.postprocess(batch_result) @@ -531,7 +535,11 @@ def _process_batch_output(self): self._record_completion_metrics(task, current_time) self._recycle_resources(task_id, i, task, result, is_prefill) break - if not is_prefill or self.cfg.scheduler_config.name == "splitwise": + if ( + not is_prefill + or self.cfg.scheduler_config.name == "splitwise" + or self.cfg.scheduler_config.name == "dp" + ): batch_result.append(result) self.postprocess(batch_result) From 228941a854d3d659b25dabcda714c717f15127f3 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 4 Aug 2025 14:21:51 +0800 Subject: [PATCH 16/19] fix bug --- fastdeploy/engine/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index ecce9c5058..6ed5505094 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -293,7 +293,7 @@ def start(self, api_server_pid=None): # 单机逻辑 self.engine_worker_queue.available_prefill_instances.put(1) self.split_mode_get_tasks() - if self.cfg.scheduler_config.name == "splitwise": + if self.cfg.scheduler_config.name == 
"splitwise" or self.cfg.scheduler_config.name == "dp": self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=()) self.splitwise_receive_thread.daemon = True self.splitwise_receive_thread.start() From 9a4eb5460652164cda66768e78d2c7270a7e8e52 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Mon, 4 Aug 2025 15:13:46 +0800 Subject: [PATCH 17/19] fix bug --- fastdeploy/splitwise/internal_adapter_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index 27bf49e22c..db3ea520d4 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -46,7 +46,7 @@ def __init__(self, cfg, engine, dp_rank): def _get_current_server_info(self): """ - 获取服务当前资源信息 + Get resources information """ available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch()) From fec7657fa339dc3a8ce0bc2ec9e1854801e3bb6f Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Mon, 4 Aug 2025 19:30:55 +0800 Subject: [PATCH 18/19] merge --- fastdeploy/cache_manager/prefix_cache_manager.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 198925cd6f..0ac34ad6ac 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -60,7 +60,6 @@ def __init__( self.enable_splitwise = 0 self.splitwise_role = splitwise_role - self.config = config self.cache_config = config.cache_config self.speculative_config = config.speculative_config self.local_data_parallel_id = local_data_parallel_id @@ -186,7 +185,6 @@ def launch_cache_manager( + f" --engine_pid {pid_suffix}" + f" --protocol {cache_config.cache_transfer_protocol}" + f" --local_data_parallel_id {self.local_data_parallel_id}" - + 
f" --data_parallel_size {self.config.parallel_config.data_parallel_size}" + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + f" --speculative_config '{self.speculative_config.to_json_string()}'" + f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1" From 0c16f8351d3fd3b7d1f4b51194b845825ac11393 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Fri, 22 Aug 2025 13:15:33 +0800 Subject: [PATCH 19/19] Fix merge --- fastdeploy/engine/engine.py | 47 ------------------------------------- 1 file changed, 47 deletions(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 49241f83d2..6966815f9f 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -197,53 +197,6 @@ def start(self, api_server_pid=None): self.recv_result_handle_thread.start() time.sleep(3) - self.cfg.init_cache_info() - - role = self.cfg.splitwise_role - host_ip = self.cfg.host_ip - disaggregate = self.cfg.disaggregate_info - request_queues_for_dp_ipc = ( - None # Different dp has its own process, use multiprocessing.Queue to deliver requests for each dp - ) - result_queue_for_dp_ipc = None - if self.cfg.scheduler_config.name == "splitwise": - self.scheduler.start(role, host_ip, disaggregate) - elif self.cfg.scheduler_config.name == "dp": - request_queues_for_dp_ipc = [] - result_queue_for_dp_ipc = multiprocessing.Queue() - for i in range(self.cfg.parallel_config.data_parallel_size): - request_queues_for_dp_ipc.append(multiprocessing.Queue()) - self.scheduler.start( - self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc - ) - - time.sleep(1) - - if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: - self.dp_processed = [] - for i in range( - 1, - self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, - ): - time.sleep(1) - self.dp_processed.append( - 
multiprocessing.Process( - target=start_expert_service, - args=( - self.cfg, - i + self.cfg.node_rank * self.cfg.worker_num_per_node, - self.ipc_signal_suffix, - request_queues_for_dp_ipc, - result_queue_for_dp_ipc, - ), - ) - ) - llm_logger.info( - f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" - + f" data parallel id {i}" - ) - self.dp_processed[-1].start() - if self.do_profile == 0 and ( self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed" ):