8 changes: 7 additions & 1 deletion lightllm/server/api_cli.py
@@ -112,7 +112,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
help="tool call parser type",
)
parser.add_argument(
"--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
"--running_max_req_size", type=int, default=2048, help="the max size for forward requests in the same time"
)
parser.add_argument("--nnodes", type=int, default=1, help="the number of nodes")
parser.add_argument("--node_rank", type=int, default=0, help="the rank of the current node")
@@ -137,6 +137,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
                 using the deepseekv2 model, set dp to be equal to the tp parameter. In other cases, please
                 do not set it and keep the default value as 1.""",
     )
+    parser.add_argument(
+        "--dp_balancer",
+        type=str,
+        default="round_robin",
+        help="the dp balancer type, default is round_robin",
+    )
     parser.add_argument(
         "--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len"
     )
9 changes: 9 additions & 0 deletions lightllm/server/router/batch.py
@@ -40,6 +40,15 @@ def get_req_list_for_dp(self, dp_index: int):
             req_list.append(req)
         return req_list

+    def get_all_dp_req_num(self) -> List[int]:
+        if self.dp_size_in_node == 1:
+            return [len(self.reqs)]
+
+        all_dp_req_num = [0 for _ in range(self.dp_size_in_node)]
+        for req in self.reqs:
+            all_dp_req_num[req.sample_params.suggested_dp_index] += 1
+        return all_dp_req_num
+
     def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
         unfinished_req_ids = []
         for req in self.reqs:
13 changes: 13 additions & 0 deletions lightllm/server/router/req_queue/dp_balancer/__init__.py
@@ -0,0 +1,13 @@
from .dp_base_balancer import RoundRobinDpBalancer
from typing import List
from lightllm.server.router.req_queue.base_queue import BaseQueue
from .dp_balancer_for_pd import DpBalancerForPd


def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]):
    if args.dp_balancer == "round_robin":
        return DpBalancerForPd(dp_size_in_node, inner_queues)
    if args.run_mode in ["prefill", "decode"]:
        return DpBalancerForPd(dp_size_in_node, inner_queues)

high

The get_dp_balancer function returns DpBalancerForPd when args.dp_balancer is "round_robin". This seems incorrect as it should return RoundRobinDpBalancer in this case. This could lead to unexpected behavior. Consider swapping the return values for the round_robin case to ensure the correct balancer is used.

Suggested change
-    if args.dp_balancer == "round_robin":
-        return DpBalancerForPd(dp_size_in_node, inner_queues)
-    if args.run_mode in ["prefill", "decode"]:
-        return DpBalancerForPd(dp_size_in_node, inner_queues)
+def get_dp_balancer(args, dp_size_in_node: int, inner_queues: List[BaseQueue]):
+    if args.dp_balancer == "round_robin":
+        return RoundRobinDpBalancer(dp_size_in_node, inner_queues)
+    if args.run_mode in ["prefill", "decode"]:
+        return DpBalancerForPd(dp_size_in_node, inner_queues)

    else:
        raise ValueError(f"Invalid dp balancer: {args.dp_balancer}")
63 changes: 63 additions & 0 deletions lightllm/server/router/req_queue/dp_balancer/dp_balancer_for_pd.py
@@ -0,0 +1,63 @@
from typing import List, Union
from lightllm.server.router.req_queue.base_queue import BaseQueue
from lightllm.server.router.batch import Batch, Req
from lightllm.utils.log_utils import init_logger
from .dp_base_balancer import DpBalancer

logger = init_logger(__name__)


class DpBalancerForPd(DpBalancer):
    """
    This balancer mainly balances the batch size across dp ranks.
    In dp mode, if any dp rank has no request, a padding request is
    inserted for it, which wastes GPU compute.
    """

    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
        super().__init__(dp_size_in_node, inner_queues)

    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
        if len(reqs_waiting_for_dp_index) == 0:
            return
        # calculate the total load of each dp rank
        if current_batch is not None:
            all_dp_req_num = current_batch.get_all_dp_req_num()
            total_load_per_dp = [
                all_dp_req_num[i] + len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)
            ]
        else:
            total_load_per_dp = [len(self.inner_queues[i].waiting_req_list) for i in range(self.dp_size_in_node)]
        for req_group in reqs_waiting_for_dp_index:
            # calculate the length of this request group
            if isinstance(req_group, list):
                req_length = len(req_group)
            else:
                req_length = 1

            # find the dp rank with minimum load
            min_load = min(total_load_per_dp)
            select_dp_indexes = [i for i in range(self.dp_size_in_node) if total_load_per_dp[i] == min_load]

            # select the dp rank with the minimum load
            if len(select_dp_indexes) == 1:
                suggested_dp_index = select_dp_indexes[0]
            else:
                # if multiple dp ranks have the same minimum load, randomly select one
                import random

medium

The import random statement is located inside the assign_reqs_to_dp method. It's better to move this import statement to the top of the file for better code organization and readability.
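As a minimal sketch of what the suggestion amounts to (the helper name here is hypothetical, not code from this PR), the tie-breaking step behaves identically with the import hoisted to module scope:

```python
import random  # module-level import, per the review suggestion


def pick_least_loaded(total_load_per_dp: list) -> int:
    # gather every dp rank tied for the minimum load
    min_load = min(total_load_per_dp)
    candidates = [i for i, load in enumerate(total_load_per_dp) if load == min_load]
    # break ties randomly, mirroring the balancer's behavior
    return candidates[0] if len(candidates) == 1 else random.choice(candidates)


print(pick_least_loaded([3, 1, 1, 2]))  # prints 1 or 2
```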


                suggested_dp_index = random.choice(select_dp_indexes)

            # assign the request to the dp rank and update the load count
            if not isinstance(req_group, list):
                req_group = [req_group]

            for req in req_group:
                req.sample_params.suggested_dp_index = suggested_dp_index
                self.inner_queues[suggested_dp_index].append(req)

            # update the load count for this dp rank
            total_load_per_dp[suggested_dp_index] += req_length

        reqs_waiting_for_dp_index.clear()
        return
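For reference, a self-contained sketch of the policy above, with stub request/queue objects standing in for lightllm's real types (all names here are illustrative, not the project's API):

```python
import random
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class StubParams:
    suggested_dp_index: int = -1


@dataclass
class StubReq:
    sample_params: StubParams = field(default_factory=StubParams)


def assign_to_min_load(loads: List[int], queues: List[list],
                       groups: List[Union[StubReq, List[StubReq]]]) -> None:
    # mirrors the shape of DpBalancerForPd.assign_reqs_to_dp: each request
    # group goes to the dp rank that currently carries the least load
    for group in groups:
        group = group if isinstance(group, list) else [group]
        min_load = min(loads)
        candidates = [i for i, load in enumerate(loads) if load == min_load]
        idx = candidates[0] if len(candidates) == 1 else random.choice(candidates)
        for req in group:
            req.sample_params.suggested_dp_index = idx
            queues[idx].append(req)
        loads[idx] += len(group)


queues = [[], [], []]
loads = [2, 0, 1]  # pretend per-rank load before scheduling
assign_to_min_load(loads, queues, [StubReq(), [StubReq(), StubReq()], StubReq()])
print(loads)  # e.g. [2, 3, 2] — new work lands on the lightest ranks
```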
65 changes: 65 additions & 0 deletions lightllm/server/router/req_queue/dp_balancer/dp_base_balancer.py
@@ -0,0 +1,65 @@
import random
from abc import ABC, abstractmethod
from typing import List, Union
from lightllm.server.router.req_queue.base_queue import BaseQueue
from lightllm.server.router.batch import Batch, Req
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class DpBalancer(ABC):
    """
    Base class for DP load balancers.
    Defines the load-balancing strategy interface; subclasses can
    implement different balancing algorithms.
    """

    def __init__(self, dp_size_in_node: int, inner_queues: List[BaseQueue]):
        self.dp_size_in_node = dp_size_in_node
        self.inner_queues = inner_queues
        self.pre_select_dp_index = self.dp_size_in_node - 1

    @abstractmethod
    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
        pass


class RoundRobinDpBalancer(DpBalancer):
    """
    Round-robin load balancer.
    Performs round-robin selection among the dp ranks whose waiting
    queues are shortest.
    """

    def get_suggest_dp_index(
        self,
    ) -> int:
        min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
        select_dp_indexes = [
            i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
        ]

        # if there is no candidate index, pick one at random
        if not select_dp_indexes:
            self.pre_select_dp_index = random.randint(0, self.dp_size_in_node - 1)
            return self.pre_select_dp_index

        # round-robin selection
        for i in range(self.dp_size_in_node):
            next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
            if next_dp_index in select_dp_indexes:
                self.pre_select_dp_index = next_dp_index
                return self.pre_select_dp_index

        self.pre_select_dp_index = random.choice(select_dp_indexes)
        return self.pre_select_dp_index

    def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: List[Union[Req, List[Req]]]) -> None:
        for req_group in reqs_waiting_for_dp_index:
            suggested_dp_index = self.get_suggest_dp_index()
            if not isinstance(req_group, list):
                req_group = [req_group]
            for req in req_group:
                req.sample_params.suggested_dp_index = suggested_dp_index
                self.inner_queues[suggested_dp_index].append(req)
        reqs_waiting_for_dp_index.clear()
        return
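To make the rotation concrete, here is a stripped-down, runnable rendering of the selection loop above (a stand-in class, not the project's API): it round-robins only among the ranks whose waiting queues are shortest.

```python
import random


class RoundRobinPicker:
    # condensed stand-in for RoundRobinDpBalancer.get_suggest_dp_index
    def __init__(self, dp_size: int):
        self.dp_size = dp_size
        self.prev = dp_size - 1  # same starting point as pre_select_dp_index

    def pick(self, queue_lengths) -> int:
        shortest = min(queue_lengths)
        candidates = [i for i, n in enumerate(queue_lengths) if n == shortest]
        # walk forward from the previous pick until a candidate is hit
        for step in range(self.dp_size):
            nxt = (self.prev + step + 1) % self.dp_size
            if nxt in candidates:
                self.prev = nxt
                return nxt
        self.prev = random.choice(candidates)  # fallback kept for parity
        return self.prev


picker = RoundRobinPicker(4)
print([picker.pick([0, 0, 0, 0]) for _ in range(5)])  # [0, 1, 2, 3, 0]
```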
61 changes: 21 additions & 40 deletions lightllm/server/router/req_queue/dp_base_queue.py
@@ -2,6 +2,7 @@
 from typing import List
 from ..batch import Batch, Req
 from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.server.router.req_queue.dp_balancer import get_dp_balancer
 from lightllm.common.basemodel.infer_lock import g_router_lock
 from lightllm.utils.log_utils import init_logger

@@ -12,14 +13,14 @@ class DpQueue:
     def __init__(self, args, router, base_queue_class, dp_size_in_node) -> None:
         self.dp_size_in_node = dp_size_in_node
         self.base_queue_class = base_queue_class
-        self.pre_select_dp_index = self.dp_size_in_node - 1
         from lightllm.server.router.manager import RouterManager

         self.router: RouterManager = router
         self.inner_queues: List[BaseQueue] = [
             base_queue_class(args, router, dp_index, dp_size_in_node) for dp_index in range(self.dp_size_in_node)
         ]

+        self.dp_balancer = get_dp_balancer(args, dp_size_in_node, self.inner_queues)
+        self.reqs_waiting_for_dp_index = []
         return

     def get_dp_queue(self, dp_index: int):
@@ -31,10 +32,16 @@ def get_wait_req_num(self):

     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch):
-        batches = [
-            self.inner_queues[dp_index].generate_new_batch(current_batch) for dp_index in range(self.dp_size_in_node)
-        ]
-        return self._merge_batch(batches)
+        try:
+            self.dp_balancer.assign_reqs_to_dp(current_batch, self.reqs_waiting_for_dp_index)
+            batches = [
+                self.inner_queues[dp_index].generate_new_batch(current_batch)
+                for dp_index in range(self.dp_size_in_node)
+            ]
+            return self._merge_batch(batches)
+        except Exception as e:
+            logger.error(f"generate new batch failed: {e}")
+            raise e

medium

When an exception occurs during batch generation, using raise e can obscure the original stack trace. Using a bare raise will preserve the original traceback, making debugging easier.

Suggested change
-            logger.error(f"generate new batch failed: {e}")
-            raise e
+        except Exception as e:
+            logger.error(f"generate new batch failed: {e}")
+            raise


     def _merge_batch(self, dp_batches: List[Batch]):
         merged_batch: Batch = None
@@ -48,28 +55,20 @@ def _merge_batch(self, dp_batches: List[Batch]):
     def append(self, req: Req):
         suggested_dp_index = req.sample_params.suggested_dp_index
         if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
             logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
-            suggested_dp_index = self._get_suggest_dp_index()
-            self.pre_select_dp_index = suggested_dp_index
-            req.sample_params.suggested_dp_index = suggested_dp_index
-            self.inner_queues[suggested_dp_index].append(req)
+            # the dp index is assigned uniformly at scheduling time
+            self.reqs_waiting_for_dp_index.append(req)
         else:
             self.inner_queues[suggested_dp_index].append(req)
         return

     def extend(self, req_group: List[Req]):
-        # requests in the same group must go to the same dp rank; that is most efficient
-        index = self._get_suggest_dp_index()
-        for req in req_group:
-            suggested_dp_index = req.sample_params.suggested_dp_index
-            if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
-                logger.warning(f"input req {req.request_id} dp index {suggested_dp_index} is invalid")
-                self.pre_select_dp_index = index
-                req.sample_params.suggested_dp_index = index
-                self.inner_queues[index].append(req)
-            else:
+        suggested_dp_index = req_group[0].sample_params.suggested_dp_index
+        if suggested_dp_index >= self.dp_size_in_node or suggested_dp_index < 0:
+            # requests in the same group must be assigned to the same dp rank
+            self.reqs_waiting_for_dp_index.append(req_group)
+        else:
+            for req in req_group:
                 self.inner_queues[suggested_dp_index].append(req)

         return

     def is_busy(self):
@@ -87,21 +86,3 @@ def update_token_load(self, current_batch: Batch, force_update=False):
             self.router.shared_token_load.set_estimated_peak_token_count(estimated_peak_token_count, dp_index)
             self.router.shared_token_load.set_dynamic_max_load(dynamic_max_load, dp_index)
         return
-
-    def _get_suggest_dp_index(self):
-        min_length = min(len(queue.waiting_req_list) for queue in self.inner_queues)
-        select_dp_indexes = [
-            i for i, queue in enumerate(self.inner_queues) if len(queue.waiting_req_list) == min_length
-        ]
-
-        # multi thread safe keep
-        if not select_dp_indexes:
-            return random.randint(0, self.dp_size_in_node - 1)
-
-        # round_robin select.
-        for i in range(self.dp_size_in_node):
-            next_dp_index = (self.pre_select_dp_index + i + 1) % self.dp_size_in_node
-            if next_dp_index in select_dp_indexes:
-                return next_dp_index
-
-        return random.choice(select_dp_indexes)