Skip to content

Commit 8d515cb

Browse files
committed
feat: Add max_rpm configuration and rate limiting to LLMClient for improved request management
1 parent ae23432 commit 8d515cb

File tree

5 files changed

+127
-2
lines changed

5 files changed

+127
-2
lines changed

sources/gc-qa-rag-etl/.config.development.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
"llm": {
88
"api_key": "",
99
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
10-
"model_name": "qwen-plus"
10+
"model_name": "qwen-plus",
11+
"max_rpm": 100
1112
},
1213
"embedding": {
1314
"api_key": ""

sources/gc-qa-rag-etl/.config.production.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
"llm": {
88
"api_key": "",
99
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
10-
"model_name": "qwen-plus"
10+
"model_name": "qwen-plus",
11+
"max_rpm": 100
1112
},
1213
"embedding": {
1314
"api_key": ""

sources/gc-qa-rag-etl/etlapp/common/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class LlmConfig:
1919
api_key: str
2020
api_base: str
2121
model_name: str
22+
max_rpm: int = 100 # 每分钟最大请求数,默认100
2223

2324

2425
@dataclass
@@ -65,6 +66,7 @@ def from_environment(cls, environment: str) -> "Config":
6566
api_key=config_raw["llm"]["api_key"],
6667
api_base=config_raw["llm"]["api_base"],
6768
model_name=config_raw["llm"]["model_name"],
69+
max_rpm=config_raw["llm"].get("max_rpm", 60), # 默认60 RPM
6870
),
6971
embedding=EmbeddingConfig(api_key=config_raw["embedding"]["api_key"]),
7072
vector_db=VectorDbConfig(host=config_raw["vector_db"]["host"]),

sources/gc-qa-rag-etl/etlapp/common/llm.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from openai import OpenAI
22
from etlapp.common.config import app_config
3+
from etlapp.common.rate_limiter import RateLimiter
34
from typing import List, Dict
45

56

@@ -9,6 +10,7 @@ def __init__(
910
api_key: str = app_config.llm.api_key,
1011
api_base: str = app_config.llm.api_base,
1112
model_name: str = app_config.llm.model_name,
13+
max_rpm: int = app_config.llm.max_rpm,
1214
system_prompt: str = "你是一个乐于解答各种问题的助手。",
1315
temperature: float = 0.7,
1416
top_p: float = 0.7,
@@ -18,8 +20,13 @@ def __init__(
1820
self.system_prompt = system_prompt
1921
self.temperature = temperature
2022
self.top_p = top_p
23+
# 初始化限流器
24+
self.rate_limiter = RateLimiter(max_requests=max_rpm, window_seconds=60)
2125

2226
def _create_completion(self, messages: List[Dict[str, str]]) -> str:
27+
# 在发送请求前进行限流
28+
self.rate_limiter.wait_and_acquire()
29+
2330
completion = self.client.chat.completions.create(
2431
model=self.model_name,
2532
messages=messages,
@@ -37,6 +44,23 @@ def chat(self, content: str) -> str:
3744

3845
def chat_with_messages(self, messages: List[Dict[str, str]]) -> str:
3946
return self._create_completion(messages)
47+
48+
def get_rate_limit_status(self) -> dict:
49+
"""
50+
获取当前限流状态
51+
52+
Returns:
53+
dict: 包含剩余请求数和重置时间的状态信息
54+
"""
55+
remaining = self.rate_limiter.get_remaining_requests()
56+
reset_time = self.rate_limiter.get_reset_time()
57+
58+
return {
59+
"remaining_requests": remaining,
60+
"reset_time": reset_time,
61+
"max_rpm": self.rate_limiter.max_requests,
62+
"window_seconds": self.rate_limiter.window_seconds
63+
}
4064

4165

4266
# Create a default instance
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import threading
2+
import time
3+
from collections import deque
4+
from typing import Optional
5+
6+
7+
class RateLimiter:
8+
"""
9+
线程安全的速率限制器,支持RPM(每分钟请求数)限制
10+
"""
11+
12+
def __init__(self, max_requests: int, window_seconds: int = 60):
13+
"""
14+
初始化速率限制器
15+
16+
Args:
17+
max_requests: 在指定时间窗口内允许的最大请求数
18+
window_seconds: 时间窗口大小(秒),默认60秒(1分钟)
19+
"""
20+
self.max_requests = max_requests
21+
self.window_seconds = window_seconds
22+
self.requests = deque()
23+
self._lock = threading.Lock()
24+
25+
def acquire(self, timeout: Optional[float] = None) -> bool:
26+
"""
27+
尝试获取请求许可
28+
29+
Args:
30+
timeout: 超时时间(秒),None表示无限等待
31+
32+
Returns:
33+
bool: 是否成功获取许可
34+
"""
35+
start_time = time.time()
36+
37+
while True:
38+
with self._lock:
39+
current_time = time.time()
40+
41+
# 清理过期的请求记录
42+
while self.requests and current_time - self.requests[0] > self.window_seconds:
43+
self.requests.popleft()
44+
45+
# 检查是否可以发送请求
46+
if len(self.requests) < self.max_requests:
47+
self.requests.append(current_time)
48+
return True
49+
50+
# 检查超时
51+
if timeout is not None and time.time() - start_time >= timeout:
52+
return False
53+
54+
# 等待一小段时间再重试
55+
time.sleep(0.1)
56+
57+
def wait_and_acquire(self) -> None:
58+
"""
59+
等待直到可以获取请求许可(阻塞式)
60+
"""
61+
self.acquire(timeout=None)
62+
63+
def get_remaining_requests(self) -> int:
64+
"""
65+
获取当前时间窗口内剩余的请求数
66+
67+
Returns:
68+
int: 剩余请求数
69+
"""
70+
with self._lock:
71+
current_time = time.time()
72+
73+
# 清理过期的请求记录
74+
while self.requests and current_time - self.requests[0] > self.window_seconds:
75+
self.requests.popleft()
76+
77+
return max(0, self.max_requests - len(self.requests))
78+
79+
def get_reset_time(self) -> Optional[float]:
80+
"""
81+
获取下次可以发送请求的时间戳
82+
83+
Returns:
84+
Optional[float]: 下次可发送请求的时间戳,None表示立即可发送
85+
"""
86+
with self._lock:
87+
current_time = time.time()
88+
89+
# 清理过期的请求记录
90+
while self.requests and current_time - self.requests[0] > self.window_seconds:
91+
self.requests.popleft()
92+
93+
if len(self.requests) < self.max_requests:
94+
return None
95+
96+
# 返回最早请求过期的时间
97+
return self.requests[0] + self.window_seconds

0 commit comments

Comments
 (0)