Skip to content

Commit b4f17bc

Browse files
committed
Merge pull request 'Backport tool enable / disable from main' (#439) from backport-tool-selection into development
Reviewed-on: https://git.biggo.com/Funmula/dive-mcp-host/pulls/439
2 parents c29e5f3 + 89437c8 commit b4f17bc

File tree

10 files changed

+414
-67
lines changed

10 files changed

+414
-67
lines changed

dive_mcp_host/host/tools/mcp_server.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,23 @@
6565
logger = getLogger(__name__)
6666

6767

68+
class ToolInfo(types.Tool):
69+
"""Custom tool info with extra info."""
70+
71+
enable: bool
72+
73+
@classmethod
74+
def from_tool(cls, tool: types.Tool, enable: bool) -> Self:
75+
"""Create from mcp Tool type."""
76+
return cls(**tool.model_dump(), enable=enable)
77+
78+
6879
class McpServerInfo(BaseModel):
6980
"""MCP server capability and tool list."""
7081

7182
name: str
7283
"""The name of the MCP server."""
73-
tools: list[types.Tool]
84+
tools: list[ToolInfo]
7485
"""The tools provided by the MCP server."""
7586
initialize_result: types.InitializeResult | None
7687
"""The result of the initialize method.
@@ -239,19 +250,35 @@ def log_buffer(self) -> LogBuffer:
239250
@property
240251
def server_info(self) -> McpServerInfo:
241252
"""Get the server info."""
253+
tools: list[ToolInfo] = []
254+
if self._tool_results:
255+
for tool in self._tool_results.tools:
256+
enable: bool = True
257+
if tool.name in self.config.exclude_tools:
258+
enable = False
259+
tools.append(ToolInfo.from_tool(tool, enable))
260+
242261
return McpServerInfo(
243262
name=self.name,
244263
initialize_result=self._initialize_result,
245-
tools=self._tool_results.tools if self._tool_results is not None else [],
264+
tools=tools,
246265
client_status=self._client_status,
247266
error=self._exception,
248267
)
249268

269+
def _get_enabled_tools(self) -> list[McpTool]:
270+
result: list[McpTool] = []
271+
for tool in self._mcp_tools:
272+
if tool.name in self.config.exclude_tools:
273+
continue
274+
result.append(tool)
275+
return result
276+
250277
@property
251278
def mcp_tools(self) -> list[McpTool]:
252279
"""Get the tools."""
253280
if self._client_status == ClientState.RUNNING:
254-
return self._mcp_tools
281+
return self._get_enabled_tools()
255282
return []
256283

257284
def session(

dive_mcp_host/httpd/conf/mcp_servers.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,21 @@ class MCPServerConfig(BaseModel):
2727
) = "stdio"
2828
enabled: bool = True
2929
command: str | None = None
30-
args: list[str] | None = None
31-
env: dict[str, str] | None = None
30+
args: list[str] | None = Field(default_factory=list)
31+
env: dict[str, str] | None = Field(default_factory=dict)
3232
url: str | None = None
33-
headers: dict[str, SecretStr] | None = None
3433
extra_data: dict[str, Any] | None = Field(default=None, alias="extraData")
3534
proxy: ProxyUrl | None = None
35+
headers: dict[str, SecretStr] | None = Field(default_factory=dict)
36+
exclude_tools: list[str] = Field(default_factory=list)
37+
38+
def model_post_init(self, _: Any) -> None:
39+
"""Post-initialization hook."""
40+
if self.transport in ["sse", "websocket"]:
41+
if self.url is None:
42+
raise ValueError("url is required for sse and websocket transport")
43+
elif self.transport == "stdio" and self.command is None:
44+
raise ValueError("command is required for stdio transport")
3645

3746
@field_serializer("headers", when_used="json")
3847
def dump_headers(self, v: dict[str, SecretStr] | None) -> dict[str, str] | None:
@@ -43,7 +52,9 @@ def dump_headers(self, v: dict[str, SecretStr] | None) -> dict[str, str] | None:
4352
class Config(BaseModel):
4453
"""Model of mcp_config.json."""
4554

46-
mcp_servers: dict[str, MCPServerConfig] = Field(alias="mcpServers")
55+
mcp_servers: dict[str, MCPServerConfig] = Field(
56+
alias="mcpServers", default_factory=dict
57+
)
4758

4859

4960
type McpServerConfigCallback = Callable[[Config], Config | Coroutine[Any, Any, Config]]

dive_mcp_host/httpd/routers/config.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
from fastapi import APIRouter, Depends, Request
44
from pydantic import BaseModel, Field
55

6-
from dive_mcp_host.httpd.conf.mcp_servers import Config
6+
from dive_mcp_host.httpd.conf.mcp_servers import Config as McpServers
77
from dive_mcp_host.httpd.dependencies import get_app
88
from dive_mcp_host.httpd.server import DiveHostAPI
99

1010
from .models import (
1111
EmbedConfig,
1212
McpServerError,
13-
McpServers,
1413
ModelFullConfigs,
1514
ModelInterfaceDefinition,
1615
ModelSettingsDefinition,
@@ -79,24 +78,26 @@ async def get_mcp_server(
7978
)
8079

8180

81+
# Frontend prefers to use this API for all MCP server config interactions.
82+
# Doesn't matter if they only want to change a small thing in a single MCP server.
83+
# Just overwrite the entire config every time.
8284
@config.post("/mcpserver")
8385
async def post_mcp_server(
84-
servers: McpServers,
86+
new_config: McpServers,
8587
app: DiveHostAPI = Depends(get_app),
8688
force: bool = False,
8789
) -> SaveConfigResult:
8890
"""Save MCP server configurations.
8991
9092
Args:
91-
servers (McpServers): The server configurations to save.
93+
new_config (McpServers): The server configurations to save.
9294
app (DiveHostAPI): The DiveHostAPI instance.
9395
force (bool): If True, reload all mcp servers even if they are not changed.
9496
9597
Returns:
9698
SaveConfigResult: Result of the save operation with any errors.
9799
"""
98100
# Update conifg
99-
new_config = Config.model_validate(servers.model_dump(by_alias=True))
100101
if not await app.mcp_server_config_manager.update_all_configs(new_config):
101102
raise ValueError("Failed to update MCP server configurations")
102103

dive_mcp_host/httpd/routers/models.py

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from enum import StrEnum
2-
from typing import Annotated, Any, Literal, Self, TypeVar
2+
from typing import Any, Literal, Self, TypeVar
33

44
from pydantic import (
55
BaseModel,
6-
BeforeValidator,
76
ConfigDict,
87
Field,
98
RootModel,
@@ -29,45 +28,6 @@ class ResultResponse(BaseModel):
2928
message: str | None = None
3029

3130

32-
Transport = Literal["stdio", "sse", "websocket", "streamable"]
33-
34-
35-
class McpServerConfig(BaseModel):
36-
"""MCP server configuration with transport and connection settings."""
37-
38-
transport: Annotated[
39-
Transport, BeforeValidator(lambda v: "stdio" if v == "command" else v)
40-
]
41-
enabled: bool | None
42-
command: str | None = None
43-
args: list[str] | None = Field(default_factory=list)
44-
env: dict[str, str] | None = Field(default_factory=dict)
45-
url: str | None = None
46-
headers: dict[str, SecretStr] | None = Field(default_factory=dict)
47-
extra_data: dict[str, Any] | None = Field(default=None, alias="extraData")
48-
49-
def model_post_init(self, _: Any) -> None:
50-
"""Post-initialization hook."""
51-
if self.transport in ["sse", "websocket"]:
52-
if self.url is None:
53-
raise ValueError("url is required for sse and websocket transport")
54-
elif self.transport == "stdio" and self.command is None:
55-
raise ValueError("command is required for stdio transport")
56-
57-
@field_serializer("headers", when_used="json")
58-
def dump_api_key(self, v: dict[str, SecretStr] | None) -> dict[str, str] | None:
59-
"""Serialize the api_key field to plain text."""
60-
return {k: v.get_secret_value() for k, v in v.items()} if v else None
61-
62-
63-
class McpServers(BaseModel):
64-
"""Collection of MCP server configurations."""
65-
66-
mcp_servers: dict[str, McpServerConfig] = Field(
67-
alias="mcpServers", default_factory=dict
68-
)
69-
70-
7131
class McpServerError(BaseModel):
7232
"""Represents an error from an MCP server."""
7333

@@ -127,6 +87,7 @@ class SimpleToolInfo(BaseModel):
12787

12888
name: str
12989
description: str
90+
enabled: bool = True
13091

13192

13293
class McpTool(BaseModel):

dive_mcp_host/httpd/routers/tools.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic import ValidationError
66

77
from dive_mcp_host.host.tools.model_types import ClientState
8+
from dive_mcp_host.httpd.conf.mcp_servers import Config
89
from dive_mcp_host.httpd.dependencies import get_app
910
from dive_mcp_host.httpd.routers.models import (
1011
McpTool,
@@ -42,7 +43,7 @@ async def initialized(
4243

4344

4445
@tools.get("/")
45-
async def list_tools(
46+
async def list_tools( # noqa: PLR0912, C901
4647
app: DiveHostAPI = Depends(get_app),
4748
) -> ToolsResult:
4849
"""Lists all available MCP tools.
@@ -54,16 +55,20 @@ async def list_tools(
5455

5556
# get full list of servers from config
5657
if (config := await app.mcp_server_config_manager.get_current_config()) is not None:
57-
all_servers = set(config.mcp_servers.keys())
58+
all_server_configs = config
5859
else:
59-
all_servers = set()
60+
all_server_configs = Config()
6061

6162
# get tools from dive host
6263
for server_name, server_info in app.dive_host["default"].mcp_server_info.items():
6364
result[server_name] = McpTool(
6465
name=server_name,
6566
tools=[
66-
SimpleToolInfo(name=tool.name, description=tool.description or "")
67+
SimpleToolInfo(
68+
name=tool.name,
69+
description=tool.description or "",
70+
enabled=tool.enable,
71+
)
6772
for tool in server_info.tools
6873
],
6974
description="",
@@ -73,7 +78,7 @@ async def list_tools(
7378
)
7479
logger.debug("active mcp servers: %s", result.keys())
7580
# find missing servers
76-
missing_servers = all_servers - set(result.keys())
81+
missing_servers = set(all_server_configs.mcp_servers.keys()) - set(result.keys())
7782
logger.debug("disabled mcp servers: %s", missing_servers)
7883

7984
# get missing from local cache
@@ -90,6 +95,15 @@ async def list_tools(
9095
for server_name in missing_servers:
9196
if server_info := cached_tools.root.get(server_name, None):
9297
server_info.enabled = False
98+
99+
# Sync tool 'enabled' with 'exclude_tools'
100+
if mcp_config := all_server_configs.mcp_servers.get(server_name):
101+
for tool in server_info.tools:
102+
if tool.name in mcp_config.exclude_tools:
103+
tool.enabled = False
104+
else:
105+
tool.enabled = True
106+
93107
result[server_name] = server_info
94108
else:
95109
logger.warning("Server %s not found in cached tools", server_name)

dive_mcp_host/httpd/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ async def load_host_config(self) -> HostConfig:
235235
transport=server_config.transport or "stdio",
236236
headers=server_config.headers or {},
237237
proxy=server_config.proxy or None,
238+
exclude_tools=server_config.exclude_tools,
238239
)
239240

240241
logger.debug("got %s mcp servers in config", len(mcp_servers))

0 commit comments

Comments
 (0)