Commit 2c9fd18

Merge branch 'development' of ssh://git.biggo.com:222/Funmula/dive-mcp-host into development

2 parents 9d2f2ad + cab43db
File tree: 8 files changed, +276 -16 lines

dive_mcp_host/host/conf/llm.py

Lines changed: 23 additions & 1 deletion
@@ -82,7 +82,7 @@ class LLMConfiguration(BaseModel):

     base_url: str | None = Field(default=None, alias="baseURL")
     skip_tls_verify: bool | None = Field(default=None)
-    temperature: float | None = Field(default=0)
+    temperature: float | None = Field(default=None)
     top_p: float | None = Field(default=None)

     model_config = pydantic_model_config
@@ -150,6 +150,28 @@ def dump_api_key(self, v: SecretStr | None) -> str | None:
         """Serialize the api_key field to plain text."""
         return v.get_secret_value() if v else None

+    @model_validator(mode="after")
+    def temperature_top_p(self) -> Self:
+        """Normalize temperature and top_p for specific models."""
+        if (
+            "claude-opus-4-1" in self.model
+            and self.configuration
+            and self.configuration.temperature
+            and self.configuration.top_p
+        ):
+            self.configuration.top_p = None
+
+        if (
+            "gpt-5" in self.model
+            and self.configuration
+            and (temperature := self.configuration.temperature)
+        ):
+            if temperature > 0:
+                self.configuration.temperature = 1
+            else:
+                self.configuration.temperature = None
+        return self
+

 class LLMBedrockConfig(BaseLLMConfig):
     """Configuration for Bedrock LLM models."""

dive_mcp_host/host/tools/log.py

Lines changed: 19 additions & 9 deletions
@@ -276,17 +276,27 @@ def __init__(self, name: str, log_dir: Path, rotation_files: int = 5) -> None:
         self._logger = getLogger(self._name)
         self._logger.setLevel(INFO)
         self._logger.propagate = False
-        handler = TimedRotatingFileHandler(
-            self._path,
-            when="D",
-            interval=1,
-            backupCount=rotation_files,
-            encoding="utf-8",
-        )
-        self._logger.addHandler(handler)
+        self._error: Exception | None = None
+        try:
+            handler = TimedRotatingFileHandler(
+                self._path,
+                when="D",
+                interval=1,
+                backupCount=rotation_files,
+                encoding="utf-8",
+            )
+            self._logger.addHandler(handler)
+        except FileNotFoundError as e:
+            logger.exception(
+                "Create TimedRotatingFileHandler for %s failed, "
+                "subsequent calls will be ignored",
+                self._name,
+            )
+            self._error = e

     async def __call__(self, log: LogMsg) -> None:
-        self._logger.info(log.model_dump_json())
+        if not self._error:
+            self._logger.info(log.model_dump_json())


 class LogManager:
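
The try/except matters because TimedRotatingFileHandler opens its target file at construction time (delay is False by default), so a missing log directory raises immediately rather than on first write. A minimal reproduction of the failure mode the new code guards against (the path is made up):

    from logging.handlers import TimedRotatingFileHandler

    try:
        TimedRotatingFileHandler(
            "/no/such/dir/tool.log",  # hypothetical missing directory
            when="D",
            interval=1,
            backupCount=5,
            encoding="utf-8",
        )
    except FileNotFoundError:
        # mirrors the new behavior: record the failure once and turn
        # later log calls into no-ops instead of crashing the caller
        pass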

dive_mcp_host/httpd/routers/tools.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ async def list_tools(  # noqa: PLR0912, C901
     return ToolsResult(success=True, message=None, tools=list(result.values()))


-@tools.get("/{server_name}/logs/stream")
+@tools.get("/logs/stream")
 async def stream_server_logs(
     server_name: str,
     stream_until: ClientState | None = None,
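
Moving server_name out of the path and into the query string means server names containing slashes (for example "name/with/slash") no longer collide with URL routing. A sketch of a client call against the relocated endpoint; the base URL and port are assumptions, and the parameters mirror the tests below:

    import httpx

    with httpx.stream(
        "GET",
        "http://localhost:8000/api/tools/logs/stream",  # base URL assumed
        params={"server_name": "name/with/slash", "stream_until": "running"},
    ) as resp:
        for line in resp.iter_lines():
            if line.startswith("data: "):
                print(line.removeprefix("data: "))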

tests/conftest.py

Lines changed: 39 additions & 0 deletions
@@ -136,6 +136,45 @@ async def echo_tool_streamable_server(
     await proc.wait()


+@pytest_asyncio.fixture
+@asynccontextmanager
+async def echo_with_slash_tool_streamable_server(
+    unused_tcp_port_factory: Callable[[], int],
+) -> AsyncGenerator[tuple[int, dict[str, ServerConfig]], None]:
+    """Start the echo tool SSE server."""
+    port = unused_tcp_port_factory()
+    proc = await asyncio.create_subprocess_exec(
+        "python3",
+        "-m",
+        "dive_mcp_host.host.tools.echo",
+        "--transport=streamable",
+        "--host=localhost",
+        f"--port={port}",
+    )
+    while True:
+        try:
+            _ = await httpx.AsyncClient().get(f"http://localhost:{port}/xxxx")
+            break
+        except httpx.HTTPStatusError:
+            break
+        except:  # noqa: E722
+            await asyncio.sleep(0.1)
+    try:
+        yield (
+            port,
+            {
+                "echo/aaa/bbb/ccc": ServerConfig(
+                    name="echo/aaa/bbb/ccc",
+                    url=f"http://localhost:{port}/mcp",
+                    transport="streamable",
+                )
+            },
+        )
+    finally:
+        proc.send_signal(signal.SIGKILL)
+        await proc.wait()
+
+
 @pytest.fixture
 def log_config() -> LogConfig:
     """Fixture for log Config."""

tests/httpd/routers/test_config.py

Lines changed: 3 additions & 3 deletions
@@ -285,7 +285,7 @@ def test_get_model(test_client):
         "configs": {
             "dive": {
                 "configuration": {
-                    "temperature": 0.0,
+                    "temperature": None,
                     "topP": None,
                 },
                 "model": "fake",
@@ -369,7 +369,7 @@ def test_post_model(test_client: tuple[TestClient, "DiveHostAPI"]):
         "modelProvider": "dive",
         "maxTokens": None,
         "configuration": {
-            "temperature": 0.0,
+            "temperature": None,
             "topP": None,
         },
     },
@@ -456,7 +456,7 @@ def test_post_model_embedding(test_client):
         "apiKey": None,
         "configuration": {
             "baseURL": None,
-            "temperature": 0.0,
+            "temperature": None,
             "topP": None,
         },
         "active": True,

tests/httpd/routers/test_tools.py

Lines changed: 59 additions & 2 deletions
@@ -373,7 +373,9 @@ def test_tools_cache_after_update(test_client):
 def test_stream_logs_notfound(test_client: tuple[TestClient, DiveHostAPI]):
     """Test stream_logs function with not found server."""
     client, _ = test_client
-    response = client.get("/api/tools/missing_server/logs/stream")
+    response = client.get(
+        "/api/tools/logs/stream", params={"server_name": "missing_server"}
+    )
     for line in response.iter_lines():
         content = line.removeprefix("data: ")
         if content in ("[DONE]", ""):
@@ -413,8 +415,63 @@ def update_tools():
     with ThreadPoolExecutor(1) as executor:
         executor.submit(update_tools)
         response = client.get(
-            "/api/tools/missing_server/logs/stream",
+            "/api/tools/logs/stream",
+            params={
+                "server_name": "missing_server",
+                "stop_on_notfound": False,
+                "max_retries": 5,
+                "stream_until": "running",
+            },
+        )
+        responses: list[LogMsg] = []
+        for line in response.iter_lines():
+            content = line.removeprefix("data: ")
+            if content in ("[DONE]", ""):
+                continue
+
+            data = LogMsg.model_validate_json(content)
+            responses.append(data)
+
+        assert len(responses) >= 3
+        assert responses[-3].event == LogEvent.STREAMING_ERROR
+
+        assert responses[-2].event == LogEvent.STDERR
+        assert responses[-2].client_state == ClientState.INIT
+
+        assert responses[-1].event == LogEvent.STATUS_CHANGE
+        assert responses[-1].client_state == ClientState.RUNNING
+
+
+def test_stream_logs_name_with_slash(test_client: tuple[TestClient, DiveHostAPI]):
+    """Test stream_logs before log buffer is registered."""
+    client, _ = test_client
+
+    def update_tools():
+        response = client.post(
+            "/api/config/mcpserver",
+            json={
+                "mcpServers": {
+                    "name/with/slash": {
+                        "transport": "stdio",
+                        "enabled": True,
+                        "command": "python",
+                        "args": [
+                            "-m",
+                            "dive_mcp_host.host.tools.echo",
+                            "--transport=stdio",
+                        ],
+                    }
+                }
+            },
+        )
+        assert response.status_code == status.HTTP_200_OK
+
+    with ThreadPoolExecutor(1) as executor:
+        executor.submit(update_tools)
+        response = client.get(
+            "/api/tools/logs/stream",
             params={
+                "server_name": "name/with/slash",
                 "stop_on_notfound": False,
                 "max_retries": 5,
                 "stream_until": "running",
tests/test_models.py

Lines changed: 101 additions & 0 deletions
@@ -230,3 +230,104 @@ def test_ollama_skip_tls_verify() -> None:
     )

     assert model._client._client._transport._pool._ssl_context.verify_mode == CERT_NONE  # type: ignore
+
+
+def test_opus_4_1_temperature_top_p() -> None:
+    """Test the temperature and top_p."""
+    raw_config: dict = {
+        "modelProvider": "anthropic",
+        "model": "claude-opus-4-1",
+    }
+
+    # test opus 4.1
+    simple_1 = raw_config.copy()
+    llm_config = LLMConfig.model_validate(simple_1)
+
+    # with temperature and top_p
+    simple_2 = raw_config.copy()
+    simple_2["configuration"] = {"temperature": 0.5, "top_p": 0.5}
+    llm_config = LLMConfig.model_validate(simple_2)
+    # should be subset of llm_config.to_load_model_kwargs()
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "top_p" not in kwargs
+    assert "temperature" in kwargs
+    assert kwargs["temperature"] == 0.5
+
+    # test only temperature
+    simple_3 = raw_config.copy()
+    simple_3["configuration"] = {"temperature": 0.5}
+    llm_config = LLMConfig.model_validate(simple_3)
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "top_p" not in kwargs
+    assert "temperature" in kwargs
+    assert kwargs["temperature"] == 0.5
+
+    # test only top_p
+    simple_4 = raw_config.copy()
+    simple_4["configuration"] = {"top_p": 0.5}
+    llm_config = LLMConfig.model_validate(simple_4)
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "top_p" in kwargs
+    assert "temperature" not in kwargs
+    assert kwargs["top_p"] == 0.5
+
+    # oap provider
+    simple_5 = simple_2.copy()
+    simple_5["modelProvider"] = "oap"
+    llm_config = LLMConfig.model_validate(simple_5)
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "top_p" not in kwargs
+    assert "temperature" in kwargs
+    assert kwargs["temperature"] == 0.5
+
+    # test general llm config
+    simple_6 = simple_2.copy()
+    simple_6["model"] = "gpt-4o"
+    llm_config = LLMConfig.model_validate(simple_6)
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "top_p" in kwargs
+    assert "temperature" in kwargs
+    assert kwargs["temperature"] == 0.5
+    assert kwargs["top_p"] == 0.5
+
+
+def test_gpt_5_temperature() -> None:
+    """Test the GPT-5 temperature."""
+    raw_config: dict = {
+        "modelProvider": "openai",
+        "model": "gpt-5",
+    }
+
+    simple_1 = raw_config.copy()
+    llm_config = LLMConfig.model_validate(simple_1)
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "temperature" not in kwargs
+
+    simple_2 = raw_config.copy()
+    simple_2["configuration"] = {"temperature": 0.5}
+    llm_config = LLMConfig.model_validate(simple_2)
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "temperature" in kwargs
+    assert kwargs["temperature"] == 1
+
+    simple_3 = raw_config.copy()
+    simple_3["configuration"] = {"temperature": 0}
+    llm_config = LLMConfig.model_validate(simple_3)
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "temperature" not in kwargs
+
+    # test oap provider
+    simple_4 = simple_2.copy()
+    simple_4["modelProvider"] = "oap"
+    llm_config = LLMConfig.model_validate(simple_4)
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "temperature" in kwargs
+    assert kwargs["temperature"] == 1
+
+    # test general llm config
+    simple_5 = simple_2.copy()
+    simple_5["model"] = "gpt-4o"
+    llm_config = LLMConfig.model_validate(simple_5)
+    kwargs = llm_config.to_load_model_kwargs()
+    assert "temperature" in kwargs
+    assert kwargs["temperature"] == 0.5

tests/test_tools.py

Lines changed: 31 additions & 0 deletions
@@ -130,6 +130,37 @@ async def test_tool_manager_streamable(
     assert json.loads(str(result.content)) == []


+@pytest.mark.asyncio
+async def test_tool_manager_tool_name_with_slash(
+    echo_with_slash_tool_streamable_server: AbstractAsyncContextManager[
+        tuple[int, dict[str, ServerConfig]]
+    ],
+    log_config: LogConfig,
+) -> None:
+    """Test the tool manager."""
+    async with (
+        echo_with_slash_tool_streamable_server as (port, configs),
+        ToolManager(configs, log_config) as tool_manager,
+    ):
+        await tool_manager.initialized_event.wait()
+        tools = tool_manager.langchain_tools()
+        assert sorted([i.name for i in tools]) == ["echo", "ignore"]
+        for tool in tools:
+            result = await tool.ainvoke(
+                ToolCall(
+                    name=tool.name,
+                    id="123",
+                    args={"message": "Hello, world!"},
+                    type="tool_call",
+                ),
+            )
+            assert isinstance(result, ToolMessage)
+            if tool.name == "echo":
+                assert json.loads(str(result.content))[0]["text"] == "Hello, world!"
+            else:
+                assert json.loads(str(result.content)) == []
+
+
 @pytest.mark.asyncio
 async def test_tool_manager_reload(
     echo_tool_stdio_config: dict[str, ServerConfig],
