From 8eda8ab3d28dd48c7f3e4af4ece8e3918344cb85 Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 28 Jul 2025 16:29:07 +0800 Subject: [PATCH 01/20] dp balancer abstract --- lightllm/server/api_cli.py | 8 +++++- .../server/router/req_queue/dp_base_queue.py | 28 +++---------------- 2 files changed, 11 insertions(+), 25 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index dfbf0c84b..4aac3841a 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -112,7 +112,7 @@ def make_argument_parser() -> argparse.ArgumentParser: help="tool call parser type", ) parser.add_argument( - "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time" + "--running_max_req_size", type=int, default=2048, help="the max size for forward requests in the same time" ) parser.add_argument("--nnodes", type=int, default=1, help="the number of nodes") parser.add_argument("--node_rank", type=int, default=0, help="the rank of the current node") @@ -137,6 +137,12 @@ def make_argument_parser() -> argparse.ArgumentParser: using the deepseekv2 model, set dp to be equal to the tp parameter. In other cases, please do not set it and keep the default value as 1.""", ) + parser.add_argument( + "--dp_balancer", + type=str, + default="round_robin", + help="the dp balancer type, default is round_robin", + ) parser.add_argument( "--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len" ) diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index 198495933..352f584fd 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -2,6 +2,7 @@ from typing import List from ..batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue +from lightllm.server.router.req_queue.dp_balancer import get_dp_balancer from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.log_utils import init_logger @@ -12,14 +13,13 @@ class DpQueue: def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: self.dp_size_in_node = dp_size_in_node self.base_queue_class = base_queue_class - self.pre_select_dp_index = self.dp_size_in_node - 1 from lightllm.server.router.manager import RouterManager self.router: RouterManager = router self.inner_queues: List[BaseQueue] = [ base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node) ] - + self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues) return def get_dp_queue(self, dp_index: int): @@ -49,8 +49,7 @@ def append(self, req: Req): suggested_dp_index = req.sample_params.suggested_dp_index if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0: logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid") - suggested_dp_index = self._get_suggest_dp_index() - self.pre_select_dp_index = suggested_dp_index + suggested_dp_index = self.dp_balancer.get_suggest_dp_index() req.sample_params.suggested_dp_index = suggested_dp_index self.inner_queues[suggested_dp_index].append(req) else: @@ -59,12 +58,11 @@ def append(self, req: Req): def extend(self, req_group: List[Req]): # 同一个组的,要分配在同一个 dp 上,效率最高 - index = self._get_suggest_dp_index() + index = self.dp_balancer.get_suggest_dp_index() for req in req_group: suggested_dp_index = req.sample_params.suggested_dp_index if suggested_dp_index >= 
self.dp_size_in_node or suggested_dp_index < 0: logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid") - self.pre_select_dp_index = index req.sample_params.suggested_dp_index = index self.inner_queues[index].append(req) else: @@ -87,21 +85,3 @@ def update_token_load(self, current_batch: Batch, force_update=False): self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index) self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index) return - - def _get_suggest_dp_index(self): - min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues) - select_dp_indexes = [ - i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length - ] - - # multi thread safe keep - if not select_dp_indexes: - return random.randint(0, self.dp_size_in_node - 1) - - # round_robin select. - for i in range(self.dp_size_in_node): - next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node - if next_dp_index in select_dp_indexes: - return next_dp_index - - return random.choice(select_dp_indexes) From 86df27cf636e5170eb55b40b8b8e70befd98d99c Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 28 Jul 2025 21:16:21 +0800 Subject: [PATCH 02/20] add dp balancer for dp --- lightllm/server/router/batch.py | 9 +++ .../router/req_queue/dp_balancer/__init__.py | 13 ++++ .../dp_balancer/dp_balancer_for_pd.py | 63 ++++++++++++++++++ .../req_queue/dp_balancer/dp_base_balancer.py | 65 +++++++++++++++++++ .../server/router/req_queue/dp_base_queue.py | 37 ++++++----- 5 files changed, 169 insertions(+), 18 deletions(-) create mode 100644 lightllm/server/router/req_queue/dp_balancer/__init__.py create mode 100644 lightllm/server/router/req_queue/dp_balancer/dp_balancer_for_pd.py create mode 100644 lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 1336cd1dc..40529a3f5 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -40,6 +40,15 @@ def get_req_list_for_dp(self, dp_index: int): req_list.append(req) return req_list + def get_all_dp_req_num(self) -> List[int]: + if self.dp_size_in_node == 1: + return [len(self.reqs)] + + all_dp_req_num = [0 for _ in range(self.dp_size_in_node)] + for req in self.reqs: + all_dp_req_num[req.sample_params.suggested_dp_index] += 1 + return all_dp_req_num + def filter_out_finished_req(self, shm_req_manager: ShmReqManager): unfinished_req_ids = [] for req in self.reqs: diff --git a/lightllm/server/router/req_queue/dp_balancer/__init__.py b/lightllm/server/router/req_queue/dp_balancer/__init__.py new file mode 100644 index 000000000..3b17c6293 --- /dev/null +++ b/lightllm/server/router/req_queue/dp_balancer/__init__.py @@ -0,0 +1,13 @@ +from .dp_base_balancer import RoundRobinDpBalancer +from typing import List +from lightllm.server.router.req_queue.base_queue import BaseQueue +from .dp_balancer_for_pd import DpBalancerForPd + + +def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]): + if args.dp_balancer == "round_robin": + return DpBalancerForPd(dp_size_in_node, inner_queues) + if args.run_mode in ["prefill", "decode"]: + return DpBalancerForPd(dp_size_in_node, inner_queues) + else: + raise ValueError(f"Invalid dp balancer: {args.dp_balancer}") diff --git a/lightllm/server/router/req_queue/dp_balancer/dp_balancer_for_pd.py b/lightllm/server/router/req_queue/dp_balancer/dp_balancer_for_pd.py new file mode 
100644 index 000000000..2f73e552c --- /dev/null +++ b/lightllm/server/router/req_queue/dp_balancer/dp_balancer_for_pd.py @@ -0,0 +1,63 @@ +from typing import List, Union +from lightllm.server.router.req_queue.base_queue import BaseQueue +from lightllm.server.router.batch import Batch, Req +from lightllm.utils.log_utils import init_logger +from .dp_base_balancer import DpBalancer + +logger = init_logger(__name__) + + +class DpBalancerForPd(DpBalancer): + """ + This balancer is main to balance the batch size of each dp rank. + Because, for dp mode, if it exists a dp rank without any request, it will + padding a request and cause the waste of GPU compute resource. + """ + + def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]): + super().__init__(dp_size_in_node, inner_queues) + + def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None: + if len(reqs_waiting_for_dp_index) == 0: + return + # calculate the total load of each dp rank + if current_batch is not None: + all_dp_req_num = current_batch.get_all_dp_req_num() + total_load_per_dp = [ + all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node) + ] + else: + total_load_per_dp = [len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)] + for req_group in reqs_waiting_for_dp_index: + # calculate the length of this request group + if isinstance(req_group, list): + req_length = len(req_group) + else: + req_length = 1 + + # find the dp rank with minimum load + min_load = min(total_load_per_dp) + select_dp_indexes = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load] + + # select the dp rank with the minimum load + if len(select_dp_indexes) == 1: + suggested_dp_index = select_dp_indexes[0] + else: + # if multiple dp ranks have the same minimum load, randomly select one + import random + + suggested_dp_index = random.choice(select_dp_indexes) + + # assign the request to the dp rank and update the load count + if not isinstance(req_group, list): + req_group = [req_group] + + for req in req_group: + req.sample_params.suggested_dp_index = suggested_dp_index + self.inner_queues[suggested_dp_index].append(req) + + # update the load count for this dp rank + total_load_per_dp[suggested_dp_index] += req_length + + reqs_waiting_for_dp_index.clear() + return diff --git a/lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py b/lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py new file mode 100644 index 000000000..2e564b01a --- /dev/null +++ b/lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py @@ -0,0 +1,65 @@ +import random +from abc import ABC, abstractmethod +from typing import List, Union +from lightllm.server.router.req_queue.base_queue import BaseQueue +from lightllm.server.router.batch import Batch, Req +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class DpBalancer(ABC): + """ + DP负载均衡器基类 + 定义了负载均衡策略的接口,子类可以实现不同的负载均衡算法 + """ + + def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]): + self.dp_size_in_node = dp_size_in_node + self.inner_queues = inner_queues + self.pre_select_dp_index = self.dp_size_in_node - 1 + + @abstractmethod + def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None: + pass + + +class RoundRobinDpBalancer(DpBalancer): + """ + 轮询负载均衡器 + 在队列长度最小的DP中进行轮询选择 + """ + + def get_suggest_dp_index( + self, + ) -> int: + 
min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues) + select_dp_indexes = [ + i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length + ] + + # 如果没有可选择的索引,随机选择一个 + if not select_dp_indexes: + self.pre_select_dp_index = random.randint(0, self.dp_size_in_node - 1) + return self.pre_select_dp_index + + # 轮询选择 + for i in range(self.dp_size_in_node): + next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node + if next_dp_index in select_dp_indexes: + self.pre_select_dp_index = next_dp_index + return self.pre_select_dp_index + + self.pre_select_dp_index = random.choice(select_dp_indexes) + return self.pre_select_dp_index + + def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None: + for req_group in reqs_waiting_for_dp_index: + suggested_dp_index = self.get_suggest_dp_index() + if not isinstance(req_group, list): + req_group = [req_group] + for req in req_group: + req.sample_params.suggested_dp_index = suggested_dp_index + self.inner_queues[suggested_dp_index].append(req) + reqs_waiting_for_dp_index.clear() + return diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index 352f584fd..f3b50625e 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -20,6 +20,7 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node) ] self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues) + self.reqs_waiting_for_dp_index = [] return def get_dp_queue(self, dp_index: int): @@ -31,10 +32,16 @@ def get_wait_req_num(self): # @calculate_time(show=True, min_cost_ms=10) def generate_new_batch(self, current_batch: Batch): - batches = [ - self.inner_queues[dp_index].generate_new_batch(current_batch) for dp_index in range(self.dp_size_in_node) - ] - return self._merge_batch(batches) + try: + self.dp_balancer.assign_reqs_to_dp(current_batch, self.reqs_waiting_for_dp_index) + batches = [ + self.inner_queues[dp_index].generate_new_batch(current_batch) + for dp_index in range(self.dp_size_in_node) + ] + return self._merge_batch(batches) + except Exception as e: + logger.error(f"generate new batch failed: {e}") + raise e def _merge_batch(self, dp_batches: List[Batch]): merged_batch: Batch = None @@ -48,26 +55,20 @@ def _merge_batch(self, dp_batches: List[Batch]): def append(self, req: Req): suggested_dp_index = req.sample_params.suggested_dp_index if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0: - logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid") - suggested_dp_index = self.dp_balancer.get_suggest_dp_index() - req.sample_params.suggested_dp_index = suggested_dp_index - self.inner_queues[suggested_dp_index].append(req) + # 在调度时,统一分配请求id + self.reqs_waiting_for_dp_index.append(req) else: self.inner_queues[suggested_dp_index].append(req) return def extend(self, req_group: List[Req]): - # 同一个组的,要分配在同一个 dp 上,效率最高 - index = self.dp_balancer.get_suggest_dp_index() - for req in req_group: - suggested_dp_index = req.sample_params.suggested_dp_index - if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0: - logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid") - req.sample_params.suggested_dp_index = index - 
self.inner_queues[index].append(req) - else: + suggested_dp_index = req_group[0].sample_params.suggested_dp_index + if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0: + # 同一个组的,要分配在同一个 dp 上 + self.reqs_waiting_for_dp_index.append(req_group) + else: + for req in req_group: self.inner_queues[suggested_dp_index].append(req) - return def is_busy(self): From 54cd9aca160458a48f827600c7af444c3d630379 Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 30 Jul 2025 00:48:41 +0800 Subject: [PATCH 03/20] fix test --- .../benchmark/static_inference/model_infer.py | 114 +++++++++--------- 1 file changed, 59 insertions(+), 55 deletions(-) diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 73a99ff28..dead7ea92 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -79,16 +79,16 @@ def overlap_prefill( _0_b_seq_len = b_seq_len[: batch_size // 2] _o_b_ready_cache_len = b_ready_cache_len[: batch_size // 2] micro_batch1 = ModelInput( - _0_batch_size, - _0_total_token_num, - _0_max_len_in_batch, - _0_input_ids, - _0_mem_indexes, - _0_b_req_idx, - _0_b_seq_len, - True, - _o_b_ready_cache_len, - {}, + batch_size=_0_batch_size, + total_token_num=_0_total_token_num, + max_len_in_batch=_0_max_len_in_batch, + input_ids=_0_input_ids, + b_req_idx=_0_b_req_idx, + b_seq_len=_0_b_seq_len, + is_prefill=True, + b_ready_cache_len=_o_b_ready_cache_len, + multimodal_params={}, + mem_indexes_cpu=_0_mem_indexes, ) _1_batch_size = batch_size - batch_size // 2 @@ -101,16 +101,16 @@ def overlap_prefill( _1_b_ready_cache_len = b_ready_cache_len[batch_size // 2 :] micro_batch2 = ModelInput( - _1_batch_size, - _1_total_token_num, - _1_max_len_in_batch, - _1_input_ids, - _1_mem_indexes, - _1_b_req_idx, - _1_b_seq_len, - True, - _1_b_ready_cache_len, - {}, + batch_size=_1_batch_size, + total_token_num=_1_total_token_num, + max_len_in_batch=_1_max_len_in_batch, + input_ids=_1_input_ids, + b_req_idx=_1_b_req_idx, + b_seq_len=_1_b_seq_len, + is_prefill=True, + b_ready_cache_len=_1_b_ready_cache_len, + multimodal_params={}, + mem_indexes_cpu=_1_mem_indexes, ) output, output1 = model_part.microbatch_overlap_prefill(micro_batch1, micro_batch2) @@ -130,13 +130,13 @@ def overlap_decode( _0_b_req_idx = b_req_idx[: batch_size // 2] _0_b_seq_len = b_seq_len[: batch_size // 2] micro_batch1 = ModelInput( - _0_batch_size, - _0_total_token_num, - _0_max_len_in_batch, - _0_input_ids, - _0_mem_indexes, - _0_b_req_idx, - _0_b_seq_len, + batch_size=_0_batch_size, + total_token_num=_0_total_token_num, + max_len_in_batch=_0_max_len_in_batch, + input_ids=_0_input_ids, + b_req_idx=_0_b_req_idx, + b_seq_len=_0_b_seq_len, + mem_indexes_cpu=_0_mem_indexes, ) _1_batch_size = batch_size - batch_size // 2 @@ -148,13 +148,13 @@ def overlap_decode( _1_b_seq_len = b_seq_len[batch_size // 2 :] micro_batch2 = ModelInput( - _1_batch_size, - _1_total_token_num, - _1_max_len_in_batch, - _1_input_ids, - _1_mem_indexes, - _1_b_req_idx, - _1_b_seq_len, + batch_size=_1_batch_size, + total_token_num=_1_total_token_num, + max_len_in_batch=_1_max_len_in_batch, + input_ids=_1_input_ids, + b_req_idx=_1_b_req_idx, + b_seq_len=_1_b_seq_len, + mem_indexes_cpu=_1_mem_indexes, ) output, output1 = model_part.microbatch_overlap_decode(micro_batch1, micro_batch2) @@ -174,30 +174,34 @@ def prefill( total_token_num, b_ready_cache_len, ): + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu") model_input = ModelInput( - 
batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, + batch_size=batch_size, + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_mtp_index=b_mtp_index, + mem_indexes_cpu=mem_indexes, is_prefill=True, - b_ready_cache_len=b_ready_cache_len, + b_ready_cache_len=b_ready_cache_len, # b_ready_cache_len ) model_output = model_part.forward(model_input) return model_output.logits def decode(model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_seq_len, total_token_num): + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu") model_input = ModelInput( - batch_size, - total_token_num, - max_len_in_batch, - input_ids, - mem_indexes, - b_req_idx, - b_seq_len, + batch_size=batch_size, + total_token_num=total_token_num, + max_len_in_batch=max_len_in_batch, + input_ids=input_ids, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + b_mtp_index=b_mtp_index, + mem_indexes_cpu=mem_indexes, is_prefill=False, ) model_output = model_part.forward(model_input) @@ -222,7 +226,7 @@ def run_forward_once( ): test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) test_data = test_data.reshape(-1) - test_data = torch.from_numpy(test_data).cuda() + test_data = torch.from_numpy(test_data) import torch.distributed as dist dist.barrier() @@ -234,15 +238,15 @@ def run_forward_once( prefill_start_time = time.time() b_req_idx = torch.tensor( - [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cpu" ) - b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu") + b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu") for i in range(batch_size): b_seq_len[i] = input_len total_token_num = batch_size * input_len - mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]) rank_id = model_kvargs["rank_id"] @@ -303,7 +307,7 @@ def run_forward_once( step_start = time.time() total_token_num += batch_size b_seq_len += 1 - mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda() + mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]) max_len_in_batch = input_len + i + 1 logits = decode_fn( model_part, From ea0ada4c74c960505b1dcf26fb783848b6a12289 Mon Sep 17 00:00:00 2001 From: baishihao Date: Thu, 31 Jul 2025 17:47:25 +0800 Subject: [PATCH 04/20] update router --- lightllm/server/router/manager.py | 10 ++-------- .../router/model_infer/mode_backend/base_backend.py | 7 +++++++ .../server/router/req_queue/chunked_prefill/impl.py | 1 + 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 24b8a9ddb..02b1004cd 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -197,10 +197,6 @@ async def wait_to_model_ready(self): return def _get_schedule_time_interval(self): - if self.running_batch is None: - # 没有运行中的 batch 时,每 10ms 触发一次请求调度 - return 0.01 - # dp 模式,为了更好的配平,需要更长的调度间隔,以便于能收到更多的请求 return self.schedule_time_interval @@ -370,9 +366,7 @@ def 
_add_req(self, group_req_indexes: GroupReqIndexes): def _generate_new_batch(self): # 调度的时候需要考虑当前运行的batch,和调度了但是暂时还没有推理的部分请求。 - new_batch = self.req_queue.generate_new_batch( - Batch.merge_two_batch(self.running_batch, self.schedule_new_batch) - ) + new_batch = self.req_queue.generate_new_batch(self.schedule_new_batch) self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch) return @@ -469,7 +463,7 @@ async def _recv_new_reqs_and_schedule(self): if self.is_multinode_tp: self._multinode_tp_generate_new_batch() else: - if self._get_paused_req_num() == 0: + if self._get_paused_req_num() == 0 and self.shm_reqs_io_buffer.is_empty(): self._generate_new_batch() return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index fd75afdbf..2fe415e2b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -75,6 +75,7 @@ def init_model(self, kvargs): self.chunked_prefill_size = self.args.chunked_prefill_size self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache + self.batch_max_tokens = self.args.batch_max_tokens self.eos_id: List[int] = kvargs.get("eos_id", [2]) self.disable_cudagraph = self.args.disable_cudagraph self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1 @@ -391,6 +392,7 @@ def _get_classed_reqs( # 请求,其逻辑是不适合的。 pause_max_req_num = 2 wait_pause_count = 0 + prefill_tokens = 0 # 因为会使用到 radix cache 和 mem_manager 的计数信息 # 所以需要加锁保护。 @@ -439,6 +441,11 @@ def _get_classed_reqs( wait_pause_count += 1 else: token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill) + if prefill_tokens + token_num > self.batch_max_tokens: + # 跳过等下次prefill,避免oom + prefill_tokens = 0 + break + prefill_tokens += token_num if token_num <= can_alloc_token_num: prefill_reqs.append(req_obj) can_alloc_token_num -= token_num diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index f1dae4cac..9bd3946e8 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -69,6 +69,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch_first_router_need_tokens = ( 0 if current_batch is None else current_batch.get_batch_decode_need_tokens()[self.dp_index] ) + print(f"new_batch_first_router_need_tokens: {new_batch_first_router_need_tokens}") self._init_cache_list(current_batch, is_busy) can_run_list = [] From d3f12d0a8a93d53f64ad240249dc69281f45db08 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 12 Aug 2025 15:11:41 +0800 Subject: [PATCH 05/20] rename --- lightllm/server/api_cli.py | 5 +++-- lightllm/server/router/req_queue/dp_balancer/__init__.py | 8 ++++---- .../{dp_balancer_for_pd.py => dp_bs_balancer.py} | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) rename lightllm/server/router/req_queue/dp_balancer/{dp_balancer_for_pd.py => dp_bs_balancer.py} (98%) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4aac3841a..b809b7765 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -140,8 +140,9 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--dp_balancer", type=str, - default="round_robin", - help="the dp balancer type, default is 
round_robin", + default="bs_balancer", + choices=["round_robin", "bs_balancer"], + help="the dp balancer type, default is bs_balancer", ) parser.add_argument( "--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len" diff --git a/lightllm/server/router/req_queue/dp_balancer/__init__.py b/lightllm/server/router/req_queue/dp_balancer/__init__.py index 3b17c6293..9ca642610 100644 --- a/lightllm/server/router/req_queue/dp_balancer/__init__.py +++ b/lightllm/server/router/req_queue/dp_balancer/__init__.py @@ -1,13 +1,13 @@ from .dp_base_balancer import RoundRobinDpBalancer from typing import List from lightllm.server.router.req_queue.base_queue import BaseQueue -from .dp_balancer_for_pd import DpBalancerForPd +from .dp_balancer_bs import DpBsBalancer def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]): if args.dp_balancer == "round_robin": - return DpBalancerForPd(dp_size_in_node, inner_queues) - if args.run_mode in ["prefill", "decode"]: - return DpBalancerForPd(dp_size_in_node, inner_queues) + return RoundRobinDpBalancer(dp_size_in_node, inner_queues) + elif args.dp_balancer == "bs_balancer": + return DpBsBalancer(dp_size_in_node, inner_queues) else: raise ValueError(f"Invalid dp balancer: {args.dp_balancer}") diff --git a/lightllm/server/router/req_queue/dp_balancer/dp_balancer_for_pd.py b/lightllm/server/router/req_queue/dp_balancer/dp_bs_balancer.py similarity index 98% rename from lightllm/server/router/req_queue/dp_balancer/dp_balancer_for_pd.py rename to lightllm/server/router/req_queue/dp_balancer/dp_bs_balancer.py index 2f73e552c..f6efa9bac 100644 --- a/lightllm/server/router/req_queue/dp_balancer/dp_balancer_for_pd.py +++ b/lightllm/server/router/req_queue/dp_balancer/dp_bs_balancer.py @@ -7,7 +7,7 @@ logger = init_logger(__name__) -class DpBalancerForPd(DpBalancer): +class DpBsBalancer(DpBalancer): """ This balancer is main to balance the batch size of each dp rank. 
Because, for dp mode, if it exists a dp rank without any request, it will From dc1e2f08ece58f55984975d2da9a55afb603e868 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 12 Aug 2025 16:25:14 +0800 Subject: [PATCH 06/20] fix --- lightllm/server/router/req_queue/dp_balancer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/req_queue/dp_balancer/__init__.py b/lightllm/server/router/req_queue/dp_balancer/__init__.py index 9ca642610..0f73aad97 100644 --- a/lightllm/server/router/req_queue/dp_balancer/__init__.py +++ b/lightllm/server/router/req_queue/dp_balancer/__init__.py @@ -1,7 +1,7 @@ from .dp_base_balancer import RoundRobinDpBalancer from typing import List from lightllm.server.router.req_queue.base_queue import BaseQueue -from .dp_balancer_bs import DpBsBalancer +from .dp_bs_balancer import DpBsBalancer def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]): From fc1d5c691a78f8aba8e96a5419a420faf4742f5d Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Mon, 25 Aug 2025 10:09:24 +0800 Subject: [PATCH 07/20] fix --- lightllm/server/api_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4eaa58520..ae9f7541d 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -119,7 +119,7 @@ def make_argument_parser() -> argparse.ArgumentParser: help="tool call parser type", ) parser.add_argument( - "--running_max_req_size", type=int, default=2048, help="the max size for forward requests in the same time" + "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time" ) parser.add_argument("--nnodes", type=int, default=1, help="the number of nodes") parser.add_argument("--node_rank", type=int, default=0, help="the rank of the current node") From d5725d5c58ce7c805cf477f9690e98c3613e72db Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Mon, 25 Aug 2025 10:47:54 +0800 Subject: [PATCH 08/20] fix --- lightllm/server/router/req_queue/chunked_prefill/impl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 9bd3946e8..f1dae4cac 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -69,7 +69,6 @@ def generate_new_batch(self, current_batch: Batch): new_batch_first_router_need_tokens = ( 0 if current_batch is None else current_batch.get_batch_decode_need_tokens()[self.dp_index] ) - print(f"new_batch_first_router_need_tokens: {new_batch_first_router_need_tokens}") self._init_cache_list(current_batch, is_busy) can_run_list = [] From a0a7526c465f7226e59c5cc918cfa7ae918030b7 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:05:05 +0800 Subject: [PATCH 09/20] fix --- .../server/router/model_infer/mode_backend/base_backend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 4daeea7b3..a3572ff09 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -446,9 +446,7 @@ def _get_classed_reqs( else: 
token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill) if prefill_tokens + token_num > self.batch_max_tokens: - # 跳过等下次prefill,避免oom - prefill_tokens = 0 - break + continue prefill_tokens += token_num if token_num <= can_alloc_token_num: prefill_reqs.append(req_obj) From 69c9bb813d4a94a2cb27bbeb7fd0d2556b0a3195 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:11:16 +0800 Subject: [PATCH 10/20] fix --- lightllm/server/router/model_infer/mode_backend/base_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a3572ff09..7c2311d56 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -447,8 +447,8 @@ def _get_classed_reqs( token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill) if prefill_tokens + token_num > self.batch_max_tokens: continue - prefill_tokens += token_num if token_num <= can_alloc_token_num: + prefill_tokens += token_num prefill_reqs.append(req_obj) can_alloc_token_num -= token_num else: From d038f96b4af859dd1ccabba74f20bf60f7322c68 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 25 Aug 2025 13:54:05 +0800 Subject: [PATCH 11/20] update --- lightllm/server/router/manager.py | 4 +++- .../server/router/req_queue/dp_balancer/dp_base_balancer.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index eb0a63d4e..62c61ff66 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -397,7 +397,9 @@ def _add_req(self, group_req_indexes: GroupReqIndexes): def _generate_new_batch(self): # 调度的时候需要考虑当前运行的batch,和调度了但是暂时还没有推理的部分请求。 - new_batch = self.req_queue.generate_new_batch(self.schedule_new_batch) + new_batch = self.req_queue.generate_new_batch( + Batch.merge_two_batch(self.running_batch, self.schedule_new_batch) + ) self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch) return diff --git a/lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py b/lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py index 2e564b01a..8f872d0a9 100644 --- a/lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py +++ b/lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py @@ -17,7 +17,6 @@ class DpBalancer(ABC): def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]): self.dp_size_in_node = dp_size_in_node self.inner_queues = inner_queues - self.pre_select_dp_index = self.dp_size_in_node - 1 @abstractmethod def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None: @@ -30,6 +29,10 @@ class RoundRobinDpBalancer(DpBalancer): 在队列长度最小的DP中进行轮询选择 """ + def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]): + super().__init__(dp_size_in_node, inner_queues) + self.pre_select_dp_index = self.dp_size_in_node - 1 + def get_suggest_dp_index( self, ) -> int: From f1e970770e33ebe469fa0e1ddf84d6eec0a7f6d2 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 25 Aug 2025 14:50:23 +0800 Subject: [PATCH 12/20] update router mananger --- lightllm/server/router/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 62c61ff66..60b86053a 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -496,7 +496,7 @@ async def _recv_new_reqs_and_schedule(self): if self.is_multinode_tp: self._multinode_tp_generate_new_batch() else: - if self._get_paused_req_num() == 0 and self.shm_reqs_io_buffer.is_empty(): + if self._get_paused_req_num() == 0: self._generate_new_batch() return From 0f517b0eb832f33015f723d001c5a74bc75075cc Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 25 Aug 2025 16:28:51 +0800 Subject: [PATCH 13/20] update --- .../router/req_queue/dp_balancer/__init__.py | 4 ++-- .../router/req_queue/dp_balancer/base.py | 23 +++++++++++++++++++ .../dp_balancer/{dp_bs_balancer.py => bs.py} | 2 +- .../{dp_base_balancer.py => roundrobin.py} | 17 +------------- 4 files changed, 27 insertions(+), 19 deletions(-) create mode 100644 lightllm/server/router/req_queue/dp_balancer/base.py rename lightllm/server/router/req_queue/dp_balancer/{dp_bs_balancer.py => bs.py} (98%) rename lightllm/server/router/req_queue/dp_balancer/{dp_base_balancer.py => roundrobin.py} (80%) diff --git a/lightllm/server/router/req_queue/dp_balancer/__init__.py b/lightllm/server/router/req_queue/dp_balancer/__init__.py index 0f73aad97..34f994f8a 100644 --- a/lightllm/server/router/req_queue/dp_balancer/__init__.py +++ b/lightllm/server/router/req_queue/dp_balancer/__init__.py @@ -1,7 +1,7 @@ -from .dp_base_balancer import RoundRobinDpBalancer +from .roundrobin import RoundRobinDpBalancer from typing import List from lightllm.server.router.req_queue.base_queue import BaseQueue -from .dp_bs_balancer import DpBsBalancer +from .bs import DpBsBalancer def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]): diff --git a/lightllm/server/router/req_queue/dp_balancer/base.py b/lightllm/server/router/req_queue/dp_balancer/base.py new file mode 100644 index 000000000..7f439947c --- /dev/null +++ b/lightllm/server/router/req_queue/dp_balancer/base.py @@ -0,0 +1,23 @@ +import random +from abc import ABC, abstractmethod +from typing import List, Union +from lightllm.server.router.req_queue.base_queue import BaseQueue +from lightllm.server.router.batch import Batch, Req +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class DpBalancer(ABC): + """ + DP负载均衡器基类 + 定义了负载均衡策略的接口,子类可以实现不同的负载均衡算法 + """ + + def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]): + self.dp_size_in_node = dp_size_in_node + self.inner_queues = inner_queues + + @abstractmethod + def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None: + pass diff --git a/lightllm/server/router/req_queue/dp_balancer/dp_bs_balancer.py b/lightllm/server/router/req_queue/dp_balancer/bs.py similarity index 98% rename from lightllm/server/router/req_queue/dp_balancer/dp_bs_balancer.py rename to lightllm/server/router/req_queue/dp_balancer/bs.py index f6efa9bac..730ed9b8a 100644 --- a/lightllm/server/router/req_queue/dp_balancer/dp_bs_balancer.py +++ b/lightllm/server/router/req_queue/dp_balancer/bs.py @@ -2,7 +2,7 @@ from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.server.router.batch import Batch, Req from lightllm.utils.log_utils import init_logger -from .dp_base_balancer import DpBalancer +from .base import DpBalancer logger = init_logger(__name__) diff --git a/lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py 
b/lightllm/server/router/req_queue/dp_balancer/roundrobin.py similarity index 80% rename from lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py rename to lightllm/server/router/req_queue/dp_balancer/roundrobin.py index 8f872d0a9..c2bda3461 100644 --- a/lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py +++ b/lightllm/server/router/req_queue/dp_balancer/roundrobin.py @@ -1,28 +1,13 @@ import random -from abc import ABC, abstractmethod from typing import List, Union from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.server.router.batch import Batch, Req from lightllm.utils.log_utils import init_logger +from .base import DpBalancer logger = init_logger(__name__) -class DpBalancer(ABC): - """ - DP负载均衡器基类 - 定义了负载均衡策略的接口,子类可以实现不同的负载均衡算法 - """ - - def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]): - self.dp_size_in_node = dp_size_in_node - self.inner_queues = inner_queues - - @abstractmethod - def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None: - pass - - class RoundRobinDpBalancer(DpBalancer): """ 轮询负载均衡器 From 89033a4b2886618317666449452fee4e2f3cfe80 Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 25 Aug 2025 16:48:55 +0800 Subject: [PATCH 14/20] update code --- .../server/router/req_queue/dp_balancer/bs.py | 28 +++++-------------- .../server/router/req_queue/dp_base_queue.py | 2 +- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/lightllm/server/router/req_queue/dp_balancer/bs.py b/lightllm/server/router/req_queue/dp_balancer/bs.py index 730ed9b8a..2d9d3ed50 100644 --- a/lightllm/server/router/req_queue/dp_balancer/bs.py +++ b/lightllm/server/router/req_queue/dp_balancer/bs.py @@ -1,3 +1,4 @@ +import random from typing import List, Union from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.server.router.batch import Batch, Req @@ -21,37 +22,22 @@ def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: Lis if len(reqs_waiting_for_dp_index) == 0: return # calculate the total load of each dp rank + all_dp_req_num = [0 for _ in range(self.dp_size_in_node)] if current_batch is not None: all_dp_req_num = current_batch.get_all_dp_req_num() - total_load_per_dp = [ - all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node) - ] - else: - total_load_per_dp = [len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)] + total_load_per_dp = [ + all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node) + ] for req_group in reqs_waiting_for_dp_index: # calculate the length of this request group - if isinstance(req_group, list): - req_length = len(req_group) - else: - req_length = 1 + req_length = len(req_group) # find the dp rank with minimum load min_load = min(total_load_per_dp) select_dp_indexes = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load] - - # select the dp rank with the minimum load - if len(select_dp_indexes) == 1: - suggested_dp_index = select_dp_indexes[0] - else: - # if multiple dp ranks have the same minimum load, randomly select one - import random - - suggested_dp_index = random.choice(select_dp_indexes) + suggested_dp_index = random.choice(select_dp_indexes) # assign the request to the dp rank and update the load count - if not isinstance(req_group, list): - req_group = [req_group] - for req in req_group: req.sample_params.suggested_dp_index = 
suggested_dp_index self.inner_queues[suggested_dp_index].append(req) diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index f3b50625e..35476b68e 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -56,7 +56,7 @@ def append(self, req: Req): suggested_dp_index = req.sample_params.suggested_dp_index if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0: # 在调度时,统一分配请求id - self.reqs_waiting_for_dp_index.append(req) + self.reqs_waiting_for_dp_index.append([req]) else: self.inner_queues[suggested_dp_index].append(req) return From d8869fc417dd1bb0c82099294a6d33a0afcb670a Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 25 Aug 2025 16:58:26 +0800 Subject: [PATCH 15/20] fix --- lightllm/server/router/req_queue/dp_balancer/base.py | 2 +- lightllm/server/router/req_queue/dp_balancer/bs.py | 2 +- lightllm/server/router/req_queue/dp_balancer/roundrobin.py | 4 +--- lightllm/server/router/req_queue/dp_base_queue.py | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/lightllm/server/router/req_queue/dp_balancer/base.py b/lightllm/server/router/req_queue/dp_balancer/base.py index 7f439947c..4c1ef96ea 100644 --- a/lightllm/server/router/req_queue/dp_balancer/base.py +++ b/lightllm/server/router/req_queue/dp_balancer/base.py @@ -19,5 +19,5 @@ def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]): self.inner_queues = inner_queues @abstractmethod - def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None: + def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[List[Req]]) -> None: pass diff --git a/lightllm/server/router/req_queue/dp_balancer/bs.py b/lightllm/server/router/req_queue/dp_balancer/bs.py index 2d9d3ed50..bb7a92845 100644 --- a/lightllm/server/router/req_queue/dp_balancer/bs.py +++ b/lightllm/server/router/req_queue/dp_balancer/bs.py @@ -18,7 +18,7 @@ class DpBsBalancer(DpBalancer): def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]): super().__init__(dp_size_in_node, inner_queues) - def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None: + def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[List[Req]]) -> None: if len(reqs_waiting_for_dp_index) == 0: return # calculate the total load of each dp rank diff --git a/lightllm/server/router/req_queue/dp_balancer/roundrobin.py b/lightllm/server/router/req_queue/dp_balancer/roundrobin.py index c2bda3461..79e493254 100644 --- a/lightllm/server/router/req_queue/dp_balancer/roundrobin.py +++ b/lightllm/server/router/req_queue/dp_balancer/roundrobin.py @@ -41,11 +41,9 @@ def get_suggest_dp_index( self.pre_select_dp_index = random.choice(select_dp_indexes) return self.pre_select_dp_index - def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None: + def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[List[Req]]) -> None: for req_group in reqs_waiting_for_dp_index: suggested_dp_index = self.get_suggest_dp_index() - if not isinstance(req_group, list): - req_group = [req_group] for req in req_group: req.sample_params.suggested_dp_index = suggested_dp_index self.inner_queues[suggested_dp_index].append(req) diff --git a/lightllm/server/router/req_queue/dp_base_queue.py 
b/lightllm/server/router/req_queue/dp_base_queue.py index 35476b68e..fd13e0553 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -20,7 +20,7 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node) ] self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues) - self.reqs_waiting_for_dp_index = [] + self.reqs_waiting_for_dp_index: List[List[Req]] = [] return def get_dp_queue(self, dp_index: int): From 65b07704f2793c1c1f424c3eefa0a4be71fa42e3 Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 25 Aug 2025 17:05:33 +0800 Subject: [PATCH 16/20] update --- lightllm/server/router/req_queue/dp_balancer/bs.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lightllm/server/router/req_queue/dp_balancer/bs.py b/lightllm/server/router/req_queue/dp_balancer/bs.py index bb7a92845..d175978c9 100644 --- a/lightllm/server/router/req_queue/dp_balancer/bs.py +++ b/lightllm/server/router/req_queue/dp_balancer/bs.py @@ -29,9 +29,6 @@ def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: Lis all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node) ] for req_group in reqs_waiting_for_dp_index: - # calculate the length of this request group - req_length = len(req_group) - # find the dp rank with minimum load min_load = min(total_load_per_dp) select_dp_indexes = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load] @@ -41,9 +38,8 @@ def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: Lis for req in req_group: req.sample_params.suggested_dp_index = suggested_dp_index self.inner_queues[suggested_dp_index].append(req) - # update the load count for this dp rank - total_load_per_dp[suggested_dp_index] += req_length + total_load_per_dp[suggested_dp_index] += len(req_group) reqs_waiting_for_dp_index.clear() return From 63ac326f2e12eadf629cfbb324adc45322db8fe9 Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 25 Aug 2025 17:19:40 +0800 Subject: [PATCH 17/20] fix --- .../server/router/req_queue/base_queue.py | 5 ---- .../server/router/req_queue/dp_base_queue.py | 24 ++++--------------- 2 files changed, 5 insertions(+), 24 deletions(-) diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index cf913e6bd..cf5435601 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -26,11 +26,6 @@ def __init__(self, args, router, dp_index, dp_size_in_node) -> None: self.router_token_ratio = args.router_token_ratio # ratio to determine whether the router is busy self.router_max_new_token_len = args.router_max_new_token_len - def append(self, req: Req): - req.sample_params.suggested_dp_index = self.dp_index - self.waiting_req_list.append(req) - return - def extend(self, req_group: List[Req]): for req in req_group: req.sample_params.suggested_dp_index = self.dp_index diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index fd13e0553..e39ea1209 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -32,16 +32,11 @@ def get_wait_req_num(self): # @calculate_time(show=True, min_cost_ms=10) def generate_new_batch(self, current_batch: Batch): - try: - 
self.dp_balancer.assign_reqs_to_dp(current_batch, self.reqs_waiting_for_dp_index) - batches = [ - self.inner_queues[dp_index].generate_new_batch(current_batch) - for dp_index in range(self.dp_size_in_node) - ] - return self._merge_batch(batches) - except Exception as e: - logger.error(f"generate new batch failed: {e}") - raise e + self.dp_balancer.assign_reqs_to_dp(current_batch, self.reqs_waiting_for_dp_index) + batches = [ + self.inner_queues[dp_index].generate_new_batch(current_batch) for dp_index in range(self.dp_size_in_node) + ] + return self._merge_batch(batches) def _merge_batch(self, dp_batches: List[Batch]): merged_batch: Batch = None @@ -52,15 +47,6 @@ def _merge_batch(self, dp_batches: List[Batch]): merged_batch = iter_batch return merged_batch - def append(self, req: Req): - suggested_dp_index = req.sample_params.suggested_dp_index - if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0: - # 在调度时,统一分配请求id - self.reqs_waiting_for_dp_index.append([req]) - else: - self.inner_queues[suggested_dp_index].append(req) - return - def extend(self, req_group: List[Req]): suggested_dp_index = req_group[0].sample_params.suggested_dp_index if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0: From 3f06f9bd3f6969876f04bd6fe2382df0e34f29cd Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 25 Aug 2025 17:25:05 +0800 Subject: [PATCH 18/20] fix --- lightllm/server/router/req_queue/dp_balancer/bs.py | 2 +- lightllm/server/router/req_queue/dp_balancer/roundrobin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/router/req_queue/dp_balancer/bs.py b/lightllm/server/router/req_queue/dp_balancer/bs.py index d175978c9..a1fc2195b 100644 --- a/lightllm/server/router/req_queue/dp_balancer/bs.py +++ b/lightllm/server/router/req_queue/dp_balancer/bs.py @@ -37,7 +37,7 @@ def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: Lis # assign the request to the dp rank and update the load count for req in req_group: req.sample_params.suggested_dp_index = suggested_dp_index - self.inner_queues[suggested_dp_index].append(req) + self.inner_queues[suggested_dp_index].extend(req_group) # update the load count for this dp rank total_load_per_dp[suggested_dp_index] += len(req_group) diff --git a/lightllm/server/router/req_queue/dp_balancer/roundrobin.py b/lightllm/server/router/req_queue/dp_balancer/roundrobin.py index 79e493254..8f954fafd 100644 --- a/lightllm/server/router/req_queue/dp_balancer/roundrobin.py +++ b/lightllm/server/router/req_queue/dp_balancer/roundrobin.py @@ -46,6 +46,6 @@ def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: Lis suggested_dp_index = self.get_suggest_dp_index() for req in req_group: req.sample_params.suggested_dp_index = suggested_dp_index - self.inner_queues[suggested_dp_index].append(req) + self.inner_queues[suggested_dp_index].extend(req_group) reqs_waiting_for_dp_index.clear() return From 9f08153deafda4fd64c7746b2cc9d37e327b2200 Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 25 Aug 2025 17:33:39 +0800 Subject: [PATCH 19/20] fix --- lightllm/server/router/req_queue/dp_base_queue.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index e39ea1209..30669ba0b 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -53,8 +53,7 @@ def extend(self, req_group: List[Req]): # 
Requests from the same group must be assigned to the same dp rank self.reqs_waiting_for_dp_index.append(req_group) else: - for req in req_group: - self.inner_queues[suggested_dp_index].append(req) + self.inner_queues[suggested_dp_index].extend(req_group) return def is_busy(self): From 58fb09114caee7b73b746856a93972904197725c Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 25 Aug 2025 17:59:21 +0800 Subject: [PATCH 20/20] fix --- lightllm/server/router/req_queue/dp_base_queue.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lightllm/server/router/req_queue/dp_base_queue.py b/lightllm/server/router/req_queue/dp_base_queue.py index 30669ba0b..a73823b8b 100644 --- a/lightllm/server/router/req_queue/dp_base_queue.py +++ b/lightllm/server/router/req_queue/dp_base_queue.py @@ -19,6 +19,10 @@ def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None: self.inner_queues: List[BaseQueue] = [ base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node) ] + # Relax the limit here at scheduling time and enforce it at inference time. + # This avoids the prefill-mode case where inference has finished but the scheduler has not yet picked up that information, which would make the scheduled batch size too small. + for queue in self.inner_queues: + queue.batch_max_tokens = int(args.batch_max_tokens * 2) self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues) self.reqs_waiting_for_dp_index: List[List[Req]] = [] return
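
Taken together, the series converges on the following shape. The sketch below is a minimal, self-contained restatement of the final balancer design, runnable on its own: Req, Batch, and BaseQueue here are simplified stand-ins, not the real lightllm classes (in particular, the real code stores the dp index at req.sample_params.suggested_dp_index and Batch carries far more state), but the assign_reqs_to_dp logic mirrors DpBsBalancer as of PATCH 20/20.

import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Req:
    # Simplified stand-in: the real Req keeps this on req.sample_params.
    request_id: int
    suggested_dp_index: int = -1


@dataclass
class BaseQueue:
    # Simplified stand-in for lightllm's BaseQueue.
    waiting_req_list: List[Req] = field(default_factory=list)

    def extend(self, req_group: List[Req]) -> None:
        self.waiting_req_list.extend(req_group)


@dataclass
class Batch:
    # Simplified stand-in: only the per-rank request count matters here.
    dp_size_in_node: int
    reqs: List[Req] = field(default_factory=list)

    def get_all_dp_req_num(self) -> List[int]:
        counts = [0] * self.dp_size_in_node
        for req in self.reqs:
            counts[req.suggested_dp_index] += 1
        return counts


class DpBalancer(ABC):
    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
        self.dp_size_in_node = dp_size_in_node
        self.inner_queues = inner_queues

    @abstractmethod
    def assign_reqs_to_dp(
        self, current_batch: Optional[Batch], reqs_waiting_for_dp_index: List[List[Req]]
    ) -> None:
        ...


class DpBsBalancer(DpBalancer):
    """Send each waiting request group to the dp rank with the fewest requests
    (running + queued), so no rank is left empty and padded with a dummy request."""

    def assign_reqs_to_dp(self, current_batch, reqs_waiting_for_dp_index):
        if not reqs_waiting_for_dp_index:
            return
        # total load per rank = requests in the running batch + requests already queued
        all_dp_req_num = [0] * self.dp_size_in_node
        if current_batch is not None:
            all_dp_req_num = current_batch.get_all_dp_req_num()
        total_load_per_dp = [
            all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list)
            for i in range(self.dp_size_in_node)
        ]
        for req_group in reqs_waiting_for_dp_index:
            min_load = min(total_load_per_dp)
            candidates = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load]
            dp_index = random.choice(candidates)  # break ties randomly
            for req in req_group:
                req.suggested_dp_index = dp_index
            self.inner_queues[dp_index].extend(req_group)  # whole group stays on one rank
            total_load_per_dp[dp_index] += len(req_group)
        reqs_waiting_for_dp_index.clear()


if __name__ == "__main__":
    queues = [BaseQueue() for _ in range(4)]
    balancer = DpBsBalancer(4, queues)
    # rank 0 already runs 5 requests, rank 2 runs 2, ranks 1 and 3 are idle
    running = Batch(4, reqs=[Req(i, 0) for i in range(5)] + [Req(5, 2), Req(6, 2)])
    waiting = [[Req(100), Req(101)], [Req(102)], [Req(103)]]
    balancer.assign_reqs_to_dp(running, waiting)
    print([len(q.waiting_req_list) for q in queues])  # -> [0, 2, 0, 2]

RoundRobinDpBalancer plugs into the same assign_reqs_to_dp slot, and get_dp_balancer in dp_balancer/__init__.py selects between the two implementations based on the --dp_balancer flag.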