Skip to content

Commit e007244

Browse files
committed
Merge pull request 'Fix: could not handle mcp name that includes '/'' (#457) from mcp-name-with-slash into development
Reviewed-on: https://git.biggo.com/Funmula/dive-mcp-host/pulls/457
2 parents e18a23b + 6fed160 commit e007244

File tree

5 files changed

+149
-12
lines changed

5 files changed

+149
-12
lines changed

dive_mcp_host/host/tools/log.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,17 +276,27 @@ def __init__(self, name: str, log_dir: Path, rotation_files: int = 5) -> None:
276276
self._logger = getLogger(self._name)
277277
self._logger.setLevel(INFO)
278278
self._logger.propagate = False
279-
handler = TimedRotatingFileHandler(
280-
self._path,
281-
when="D",
282-
interval=1,
283-
backupCount=rotation_files,
284-
encoding="utf-8",
285-
)
286-
self._logger.addHandler(handler)
279+
self._error: Exception | None = None
280+
try:
281+
handler = TimedRotatingFileHandler(
282+
self._path,
283+
when="D",
284+
interval=1,
285+
backupCount=rotation_files,
286+
encoding="utf-8",
287+
)
288+
self._logger.addHandler(handler)
289+
except FileNotFoundError as e:
290+
logger.exception(
291+
"Create TimedRotatingFileHandler for %s failed, "
292+
"subsquent calls will be ignored",
293+
self._name,
294+
)
295+
self._error = e
287296

288297
async def __call__(self, log: LogMsg) -> None:
289-
self._logger.info(log.model_dump_json())
298+
if not self._error:
299+
self._logger.info(log.model_dump_json())
290300

291301

292302
class LogManager:

dive_mcp_host/httpd/routers/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ async def list_tools( # noqa: PLR0912, C901
124124
return ToolsResult(success=True, message=None, tools=list(result.values()))
125125

126126

127-
@tools.get("/{server_name}/logs/stream")
127+
@tools.get("/logs/stream")
128128
async def stream_server_logs(
129129
server_name: str,
130130
stream_until: ClientState | None = None,

tests/conftest.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,45 @@ async def echo_tool_streamable_server(
136136
await proc.wait()
137137

138138

139+
@pytest_asyncio.fixture
140+
@asynccontextmanager
141+
async def echo_with_slash_tool_streamable_server(
142+
unused_tcp_port_factory: Callable[[], int],
143+
) -> AsyncGenerator[tuple[int, dict[str, ServerConfig]], None]:
144+
"""Start the echo tool SSE server."""
145+
port = unused_tcp_port_factory()
146+
proc = await asyncio.create_subprocess_exec(
147+
"python3",
148+
"-m",
149+
"dive_mcp_host.host.tools.echo",
150+
"--transport=streamable",
151+
"--host=localhost",
152+
f"--port={port}",
153+
)
154+
while True:
155+
try:
156+
_ = await httpx.AsyncClient().get(f"http://localhost:{port}/xxxx")
157+
break
158+
except httpx.HTTPStatusError:
159+
break
160+
except: # noqa: E722
161+
await asyncio.sleep(0.1)
162+
try:
163+
yield (
164+
port,
165+
{
166+
"echo/aaa/bbb/ccc": ServerConfig(
167+
name="echo/aaa/bbb/ccc",
168+
url=f"http://localhost:{port}/mcp",
169+
transport="streamable",
170+
)
171+
},
172+
)
173+
finally:
174+
proc.send_signal(signal.SIGKILL)
175+
await proc.wait()
176+
177+
139178
@pytest.fixture
140179
def log_config() -> LogConfig:
141180
"""Fixture for log Config."""

tests/httpd/routers/test_tools.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,9 @@ def test_tools_cache_after_update(test_client):
373373
def test_stream_logs_notfound(test_client: tuple[TestClient, DiveHostAPI]):
374374
"""Test stream_logs function with not found server."""
375375
client, _ = test_client
376-
response = client.get("/api/tools/missing_server/logs/stream")
376+
response = client.get(
377+
"/api/tools/logs/stream", params={"server_name": "missing_server"}
378+
)
377379
for line in response.iter_lines():
378380
content = line.removeprefix("data: ")
379381
if content in ("[DONE]", ""):
@@ -413,8 +415,63 @@ def update_tools():
413415
with ThreadPoolExecutor(1) as executor:
414416
executor.submit(update_tools)
415417
response = client.get(
416-
"/api/tools/missing_server/logs/stream",
418+
"/api/tools/logs/stream",
419+
params={
420+
"server_name": "missing_server",
421+
"stop_on_notfound": False,
422+
"max_retries": 5,
423+
"stream_until": "running",
424+
},
425+
)
426+
responses: list[LogMsg] = []
427+
for line in response.iter_lines():
428+
content = line.removeprefix("data: ")
429+
if content in ("[DONE]", ""):
430+
continue
431+
432+
data = LogMsg.model_validate_json(content)
433+
responses.append(data)
434+
435+
assert len(responses) >= 3
436+
assert responses[-3].event == LogEvent.STREAMING_ERROR
437+
438+
assert responses[-2].event == LogEvent.STDERR
439+
assert responses[-2].client_state == ClientState.INIT
440+
441+
assert responses[-1].event == LogEvent.STATUS_CHANGE
442+
assert responses[-1].client_state == ClientState.RUNNING
443+
444+
445+
def test_stream_logs_name_with_slash(test_client: tuple[TestClient, DiveHostAPI]):
446+
"""Test stream_logs before log buffer is registered."""
447+
client, _ = test_client
448+
449+
def update_tools():
450+
_ = client.post(
451+
"/api/config/mcpserver",
452+
json={
453+
"mcpServers": {
454+
"name/with/slash": {
455+
"transport": "stdio",
456+
"enabled": True,
457+
"command": "python",
458+
"args": [
459+
"-m",
460+
"dive_mcp_host.host.tools.echo",
461+
"--transport=stdio",
462+
],
463+
}
464+
}
465+
},
466+
)
467+
assert response.status_code == status.HTTP_200_OK
468+
469+
with ThreadPoolExecutor(1) as executor:
470+
executor.submit(update_tools)
471+
response = client.get(
472+
"/api/tools/logs/stream",
417473
params={
474+
"server_name": "name/with/slash",
418475
"stop_on_notfound": False,
419476
"max_retries": 5,
420477
"stream_until": "running",

tests/test_tools.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,37 @@ async def test_tool_manager_streamable(
130130
assert json.loads(str(result.content)) == []
131131

132132

133+
@pytest.mark.asyncio
134+
async def test_tool_manager_tool_name_with_slash(
135+
echo_with_slash_tool_streamable_server: AbstractAsyncContextManager[
136+
tuple[int, dict[str, ServerConfig]]
137+
],
138+
log_config: LogConfig,
139+
) -> None:
140+
"""Test the tool manager."""
141+
async with (
142+
echo_with_slash_tool_streamable_server as (port, configs),
143+
ToolManager(configs, log_config) as tool_manager,
144+
):
145+
await tool_manager.initialized_event.wait()
146+
tools = tool_manager.langchain_tools()
147+
assert sorted([i.name for i in tools]) == ["echo", "ignore"]
148+
for tool in tools:
149+
result = await tool.ainvoke(
150+
ToolCall(
151+
name=tool.name,
152+
id="123",
153+
args={"message": "Hello, world!"},
154+
type="tool_call",
155+
),
156+
)
157+
assert isinstance(result, ToolMessage)
158+
if tool.name == "echo":
159+
assert json.loads(str(result.content))[0]["text"] == "Hello, world!"
160+
else:
161+
assert json.loads(str(result.content)) == []
162+
163+
133164
@pytest.mark.asyncio
134165
async def test_tool_manager_reload(
135166
echo_tool_stdio_config: dict[str, ServerConfig],

0 commit comments

Comments
 (0)