Skip to content

Commit 394c391

Browse files
committed
Merge branch 'development' of ssh://git.biggo.com:222/Funmula/dive-mcp-host into development
2 parents 89ac5d5 + 4b61320 commit 394c391

File tree

11 files changed

+1322
-867
lines changed

11 files changed

+1322
-867
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,4 @@ cache/
4545
upload/
4646
cli_config.json
4747
logs/
48+
oap_config.json

dive_mcp_host/host/conf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class ServerConfig(BaseModel):
5858
ProxyUrl | None,
5959
BeforeValidator(_rewrite_socks),
6060
] = None
61+
initial_timeout: float = 10
6162

6263
@field_serializer("headers", when_used="json")
6364
def dump_headers(self, v: dict[str, SecretStr] | None) -> dict[str, str] | None:

dive_mcp_host/host/tools/hack/stdio_server.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,22 @@
44
import subprocess
55
import sys
66
from collections.abc import AsyncGenerator
7-
from contextlib import asynccontextmanager
7+
from contextlib import asynccontextmanager, suppress
88
from pathlib import Path
99

1010
import anyio
1111
import anyio.abc
1212
import anyio.lowlevel
13+
from anyio.abc import Process
1314
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1415
from anyio.streams.text import TextReceiveStream
1516
from mcp import types
1617
from mcp.client.stdio import StdioServerParameters, get_default_environment
17-
from mcp.client.stdio.win32 import (
18+
from mcp.os.posix.utilities import terminate_posix_process_tree
19+
from mcp.os.win32.utilities import (
20+
FallbackProcess,
1821
get_windows_executable_command,
19-
terminate_windows_process,
22+
terminate_windows_process_tree,
2023
)
2124
from mcp.shared.message import SessionMessage
2225

@@ -43,6 +46,8 @@
4346
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
4447
)
4548

49+
PROCESS_TERMINATION_TIMEOUT = 2.0
50+
4651

4752
@asynccontextmanager
4853
async def stdio_client( # noqa: C901, PLR0915
@@ -156,12 +161,29 @@ async def stdin_writer() -> None:
156161
logger.error("Error, closing process %s: %s", process.pid, exc)
157162
raise
158163
finally:
159-
# Clean up process to prevent any dangling orphaned processes
160-
logger.info("Terminated process %s", process.pid)
161-
# Some process never terminates, so we need to kill it.
162-
await terminate_windows_process(process)
163-
status = await process.wait()
164-
logger.info("Process %s exited with status %s", process.pid, status)
164+
# MCP spec: stdio shutdown sequence
165+
# 1. Close input stream to server
166+
# 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time
167+
# 3. Send SIGKILL if still not exited
168+
if process.stdin:
169+
with suppress(Exception):
170+
await process.stdin.aclose()
171+
try:
172+
# Give the process time to exit gracefully after stdin closes
173+
with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
174+
await process.wait()
175+
except TimeoutError:
176+
# Process didn't exit from stdin closure, use platform-specific
177+
# termination
178+
# which handles SIGTERM -> SIGKILL escalation
179+
await _terminate_process_tree(process)
180+
except ProcessLookupError:
181+
# Process already exited, which is fine
182+
pass
183+
await read_stream.aclose()
184+
await write_stream.aclose()
185+
await read_stream_writer.aclose()
186+
await write_stream_reader.aclose()
165187
logger.error("Process %s closed", "xx")
166188

167189

@@ -189,3 +211,23 @@ async def _create_platform_compatible_process(
189211
logger.info("launched process: %s, pid: %s", command, process.pid)
190212

191213
return process
214+
215+
216+
async def _terminate_process_tree(
217+
process: Process | FallbackProcess, timeout_seconds: float = 2.0
218+
) -> None:
219+
"""Terminate a process and all its children using platform-specific methods.
220+
221+
Unix: Uses os.killpg() for atomic process group termination
222+
Windows: Uses Job Objects via pywin32 for reliable child process cleanup
223+
224+
Args:
225+
process: The process to terminate
226+
timeout_seconds: Timeout in seconds before force killing (default: 2.0)
227+
"""
228+
if sys.platform == "win32":
229+
await terminate_windows_process_tree(process, timeout_seconds)
230+
else:
231+
# FallbackProcess should only be used for Windows compatibility
232+
assert isinstance(process, Process)
233+
await terminate_posix_process_tree(process, timeout_seconds)

dive_mcp_host/host/tools/local_http_server.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55
from collections.abc import AsyncGenerator
66
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
7+
from time import time
78
from typing import Any
89

910
import httpx
@@ -30,7 +31,6 @@ async def local_http_server( # noqa: C901, PLR0913, PLR0915
3031
command: str | None = None,
3132
args: list[str] | None = None,
3233
env: dict[str, str] | None = None,
33-
max_connection_retries: int = 10,
3434
headers: dict[str, Any] | None = None,
3535
) -> AsyncGenerator[tuple[InitializeResult, ListToolsResult, int], None]:
3636
"""Create a local MCP server client.
@@ -65,7 +65,13 @@ def _sse_client(
6565
value = headers[key]
6666
if isinstance(value, SecretStr):
6767
headers[key] = value.get_secret_value()
68-
return sse_client(url=url, headers=headers)
68+
logger.debug("Connecting sse client with timeout: %s", config.initial_timeout)
69+
return sse_client(
70+
url=url,
71+
headers=headers,
72+
sse_read_timeout=config.initial_timeout,
73+
timeout=config.initial_timeout,
74+
)
6975

7076
get_client = _sse_client if config.transport == "sse" else websocket_client
7177
logger.debug("Starting local MCP server %s with command: %s", config.name, command)
@@ -80,7 +86,6 @@ def _sse_client(
8086
):
8187
logger.error("Failed to start subprocess for %s", config.name)
8288
raise RuntimeError("failed to start subprocess")
83-
retried = 0
8489

8590
# it tooks time to start the server, so we need to retry
8691
async def _read_stdout(
@@ -114,38 +119,40 @@ async def _read_stderr(
114119
name="read-stdout",
115120
)
116121

122+
start_time = time()
117123
try:
118-
while retried < max_connection_retries:
119-
await asyncio.sleep(0.3 if retried == 0 else 1)
120-
logger.debug(
121-
"Attempting to connect to server %s (attempt %d/%d)",
122-
config.name,
123-
retried + 1,
124-
max_connection_retries,
125-
)
124+
logger.debug(
125+
"Server %s initalizing with timeout: %s",
126+
config.name,
127+
config.initial_timeout,
128+
)
129+
while (time() - start_time) < config.initial_timeout:
126130
with suppress(TimeoutError, httpx.HTTPError):
127131
async with (
128132
get_client(url=config.url) as streams,
129133
ClientSession(*streams) as session,
130134
):
131-
async with asyncio.timeout(10):
135+
async with asyncio.timeout(config.initial_timeout):
132136
initialize_result = await session.initialize()
133137
tools = await session.list_tools()
134138
logger.info(
135139
"Successfully connected to server %s, got tools: %s",
136140
config.name,
137141
tools,
138142
)
139-
break
140-
retried += 1
143+
logger.info("Connected to the server %s", config.name)
144+
yield initialize_result, tools, subprocess.pid
145+
break
146+
await asyncio.sleep(0.3)
141147
else:
142-
raise InvalidMcpServerError(config.name)
143-
logger.info(
144-
"Connected to the server %s after %d attempts", config.name, retried
145-
)
146-
yield initialize_result, tools, subprocess.pid
148+
logger.warning(
149+
"Connected to the server %s failed after %s seconds",
150+
config.name,
151+
config.initial_timeout,
152+
)
153+
raise InvalidMcpServerError(config.name, reason="failed to initalize")
147154
finally:
148-
with suppress(TimeoutError):
155+
with suppress(TimeoutError, ProcessLookupError, asyncio.CancelledError):
149156
logger.debug("Terminating subprocess for %s", config.name)
150157
read_stderr_task.cancel()
151158
read_stdout_task.cancel()
@@ -158,7 +165,7 @@ async def _read_stderr(
158165
subprocess = None
159166
if subprocess:
160167
logger.info("Timeout to terminate mcp-server %s. Kill it.", config.name)
161-
with suppress(TimeoutError):
168+
with suppress(TimeoutError, ProcessLookupError, asyncio.CancelledError):
162169
read_stderr_task.cancel()
163170
read_stdout_task.cancel()
164171
subprocess.kill()

dive_mcp_host/host/tools/mcp_server.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ class McpServer(ContextProtocol):
111111

112112
RETRY_LIMIT: int = 3
113113
KEEP_ALIVE_INTERVAL: float = 60
114-
RESTART_INTERVAL: float = 3
114+
RESTART_INTERVAL: float = 1
115115

116116
def __init__(
117117
self,
@@ -214,11 +214,16 @@ async def _message_handler(
214214

215215
async def _init_tool_info(self, session: ClientSession) -> None:
216216
"""Initialize the session."""
217-
async with asyncio.timeout(10):
217+
logger.debug(
218+
"Client %s initalizing with timeout: %s",
219+
self.name,
220+
self.config.initial_timeout,
221+
)
222+
async with asyncio.timeout(self.config.initial_timeout):
218223
# When using stdio, the initialize call may block indefinitely
219224
self._initialize_result = await session.initialize()
220225
logger.debug(
221-
"Client %s initializing, result: %s",
226+
"Client %s initialize result: %s",
222227
self.name,
223228
self._initialize_result,
224229
)
@@ -428,15 +433,19 @@ async def _session_ctx_mgr_wrapper(
428433
async with self._cond:
429434
self._cond.notify_all()
430435

431-
async def _stdio_client_watcher(self) -> None: # noqa: C901, PLR0915
436+
async def _stdio_client_watcher(self) -> None: # noqa: C901, PLR0915, PLR0912
432437
"""Client watcher task.
433438
434439
Restart the client if need.
435440
Only this watcher can set the client status to RUNNING / FAILED.
436441
"""
437442
env = os.environ.copy()
438443
env.update(self.config.env)
439-
while True:
444+
start_time = time.time()
445+
while (
446+
self._retries == 0
447+
or (time.time() - start_time) < self.config.initial_timeout
448+
):
440449
should_break = False
441450
try:
442451
logger.debug("Attempting to initialize client %s", self.name)
@@ -466,6 +475,11 @@ async def _stdio_client_watcher(self) -> None: # noqa: C901, PLR0915
466475
self.name,
467476
self._client_status,
468477
)
478+
if self._client_status == ClientState.RESTARTING:
479+
self._retries = 0
480+
start_time = time.time()
481+
continue
482+
return
469483
except* ProcessLookupError as eg:
470484
# this raised when a stdio process is exited
471485
# and the initialize call is timeout
@@ -522,25 +536,25 @@ async def _stdio_client_watcher(self) -> None: # noqa: C901, PLR0915
522536
if self._client_status == ClientState.CLOSED:
523537
logger.info("Client %s closed, stopping watcher", self.name)
524538
return
525-
if self._retries >= self.RETRY_LIMIT or should_break:
526-
logger.warning(
527-
"client for [%s] failed after %d retries %s",
528-
self.name,
529-
self._retries,
530-
self._exception,
531-
)
532-
async with self._cond:
533-
if self._client_status != ClientState.CLOSED:
534-
await self.__change_state(ClientState.FAILED, None, False)
535-
return
539+
if should_break:
540+
break
536541
logger.debug(
537-
"Retrying client initialization for %s (attempt %d/%d)",
542+
"Retrying client initialization for %s (attempt %d)",
538543
self.name,
539544
self._retries,
540-
self.RETRY_LIMIT,
541545
)
542546
await asyncio.sleep(self.RESTART_INTERVAL)
543547

548+
logger.warning(
549+
"client for [%s] failed after %d retries %s",
550+
self.name,
551+
self._retries,
552+
self._exception,
553+
)
554+
async with self._cond:
555+
if self._client_status != ClientState.CLOSED:
556+
await self.__change_state(ClientState.FAILED, None, False)
557+
544558
async def _stdio_setup(self) -> None:
545559
"""Setup the stdio client."""
546560
self._server_task = asyncio.create_task(
@@ -714,7 +728,10 @@ def _http_get_client(
714728
async def _http_init_client(self) -> None:
715729
"""Initialize the HTTP client."""
716730
async with (
717-
self._http_get_client(sse_read_timeout=30) as streams,
731+
self._http_get_client(
732+
sse_read_timeout=self.config.initial_timeout,
733+
timeout=self.config.initial_timeout,
734+
) as streams,
718735
ClientSession(
719736
*[streams[0], streams[1]], message_handler=self._message_handler
720737
) as session,
@@ -724,7 +741,11 @@ async def _http_init_client(self) -> None:
724741
async def _http_setup(self) -> None:
725742
"""Setup the http client."""
726743
self._retries = 0
727-
for _ in range(self.RETRY_LIMIT):
744+
start_time = time.time()
745+
while (
746+
self._retries == 0
747+
or (time.time() - start_time) < self.config.initial_timeout
748+
):
728749
should_break = False
729750
try:
730751
await self._http_init_client()

dive_mcp_host/httpd/conf/mcp_servers.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
from pathlib import Path
77
from typing import Annotated, Any, Literal
88

9-
from pydantic import BaseModel, BeforeValidator, Field, SecretStr, field_serializer
9+
from pydantic import (
10+
BaseModel,
11+
BeforeValidator,
12+
ConfigDict,
13+
Field,
14+
SecretStr,
15+
field_serializer,
16+
)
1017

1118
from dive_mcp_host.env import DIVE_CONFIG_DIR
1219
from dive_mcp_host.host.conf import ProxyUrl
@@ -34,6 +41,13 @@ class MCPServerConfig(BaseModel):
3441
proxy: ProxyUrl | None = None
3542
headers: dict[str, SecretStr] | None = Field(default_factory=dict)
3643
exclude_tools: list[str] = Field(default_factory=list)
44+
initial_timeout: float = Field(default=10, ge=10, alias="initialTimeout")
45+
46+
model_config = ConfigDict(
47+
validate_by_name=True,
48+
validate_by_alias=True,
49+
serialize_by_alias=True,
50+
)
3751

3852
def model_post_init(self, _: Any) -> None:
3953
"""Post-initialization hook."""
@@ -56,6 +70,12 @@ class Config(BaseModel):
5670
alias="mcpServers", default_factory=dict
5771
)
5872

73+
model_config = ConfigDict(
74+
validate_by_name=True,
75+
validate_by_alias=True,
76+
serialize_by_alias=True,
77+
)
78+
5979

6080
type McpServerConfigCallback = Callable[[Config], Config | Coroutine[Any, Any, Config]]
6181
UpdateAllConfigsHookName = "httpd.config.mcp_servers.update_all_configs"

dive_mcp_host/httpd/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ async def load_host_config(self) -> HostConfig:
236236
headers=server_config.headers or {},
237237
proxy=server_config.proxy or None,
238238
exclude_tools=server_config.exclude_tools,
239+
initial_timeout=server_config.initial_timeout,
239240
)
240241

241242
logger.debug("got %s mcp servers in config", len(mcp_servers))

0 commit comments

Comments
 (0)