Skip to content

Commit 89437c8

Browse files
committed
Merge pull request 'feat: Select individual tools to disable / enable' (#398) from disable-tool into main
Reviewed-on: https://git.biggo.com/Funmula/dive-mcp-host/pulls/398
2 parents 00454b0 + f301547 commit 89437c8

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

84 files changed

+6016
-1648
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ jobs:
6060
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}
6161
AZURE_OPENAI_DEPLOYMENT_NAME: ${{ secrets.AZURE_OPENAI_DEPLOYMENT_NAME }}
6262
AZURE_OPENAI_API_VERSION: ${{ secrets.AZURE_OPENAI_API_VERSION }}
63+
OAP_TOKEN: ${{ vars.OAP_TOKEN }}
6364
run: |
6465
uv run --extra dev --frozen pytest
6566

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@ db.sqlite-wal
4444
cache/
4545
upload/
4646
cli_config.json
47+
logs/

dive_mcp_host/cli/cli.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import argparse
44
from pathlib import Path
55

6-
from langchain_core.messages import HumanMessage
6+
from langchain_core.messages import AIMessage, HumanMessage
7+
from langchain_core.output_parsers import StrOutputParser
78

89
from dive_mcp_host.cli.cli_types import CLIArgs
910
from dive_mcp_host.host.conf import HostConfig
@@ -67,6 +68,7 @@ async def run() -> None:
6768
with Path(args.prompt_file).open("r") as f:
6869
system_prompt = f.read()
6970

71+
output_parser = StrOutputParser()
7072
async with DiveMcpHost(config) as mcp_host:
7173
print("Waiting for tools to initialize...")
7274
await mcp_host.tools_initialized_event.wait()
@@ -75,7 +77,15 @@ async def run() -> None:
7577
current_chat_id = chat.chat_id
7678
async with chat:
7779
async for response in chat.query(query, stream_mode="messages"):
78-
print(response[0].content, end="") # type: ignore
80+
assert isinstance(response, tuple)
81+
msg = response[0]
82+
if isinstance(msg, AIMessage):
83+
content = output_parser.invoke(msg)
84+
print(content, end="")
85+
continue
86+
print(f"\n\n==== Start Of {type(msg)} ===")
87+
print(msg)
88+
print(f"==== End Of {type(msg)} ===\n")
7989

8090
print()
8191
print(f"Chat ID: {current_chat_id}")

dive_mcp_host/env.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from os import getenv
2+
from pathlib import Path
3+
4+
RESOURCE_DIR = Path(getenv("RESOURCE_DIR", Path.cwd()))
5+
DIVE_CONFIG_DIR = Path(getenv("DIVE_CONFIG_DIR", Path.cwd()))

dive_mcp_host/host/agents/agent_factory.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import StrEnum
12
from typing import Literal, Protocol
23

34
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage
@@ -11,6 +12,16 @@
1112
from dive_mcp_host.host.prompt import PromptType
1213

1314

15+
class ConfigurableKey(StrEnum):
16+
"""Enum for RunnableConfig.configurable keys."""
17+
18+
# Thread id is also known as chat_id
19+
THREAD_ID = "thread_id"
20+
USER_ID = "user_id"
21+
MAX_INPUT_TOKENS = "max_input_tokens"
22+
OVERSIZE_POLICY = "oversize_policy"
23+
24+
1425
# XXX is there any better way to do this?
1526
class AgentFactory[T: MessagesState](Protocol):
1627
"""A factory for creating agents.
@@ -62,10 +73,10 @@ def create_config(
6273
"""
6374
return {
6475
"configurable": {
65-
"thread_id": thread_id,
66-
"user_id": user_id,
67-
"max_input_tokens": max_input_tokens,
68-
"oversize_policy": oversize_policy,
76+
ConfigurableKey.THREAD_ID: thread_id,
77+
ConfigurableKey.USER_ID: user_id,
78+
ConfigurableKey.MAX_INPUT_TOKENS: max_input_tokens,
79+
ConfigurableKey.OVERSIZE_POLICY: oversize_policy,
6980
},
7081
"recursion_limit": 102,
7182
}

dive_mcp_host/host/agents/chat_agent.py

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
HumanMessage,
1414
RemoveMessage,
1515
SystemMessage,
16+
ToolCall,
1617
ToolMessage,
1718
)
1819
from langchain_core.messages.utils import count_tokens_approximately, trim_messages
1920
from langchain_core.prompt_values import ChatPromptValue
2021
from langchain_core.prompts import ChatPromptTemplate
21-
from langchain_core.runnables import Runnable, RunnableConfig
22+
from langchain_core.runnables import Runnable, RunnableConfig, RunnablePassthrough
2223
from langchain_core.tools import BaseTool
2324
from langgraph.checkpoint.base import BaseCheckpointSaver, V
2425
from langgraph.graph import END, StateGraph
@@ -31,16 +32,18 @@
3132
from pydantic import BaseModel
3233

3334
from dive_mcp_host.host.agents.agent_factory import AgentFactory, initial_messages
35+
from dive_mcp_host.host.agents.file_in_additional_kwargs import FileMsgConverter
3436
from dive_mcp_host.host.agents.message_order import tool_call_order
3537
from dive_mcp_host.host.agents.tools_in_prompt import (
3638
convert_messages,
3739
extract_tool_calls,
3840
)
3941
from dive_mcp_host.host.helpers import today_datetime
4042
from dive_mcp_host.host.prompt import PromptType, tools_prompt
43+
from dive_mcp_host.host.store.base import StoreManagerProtocol
4144

42-
StructuredResponse = dict | BaseModel
43-
StructuredResponseSchema = dict | type[BaseModel]
45+
type StructuredResponse = dict | BaseModel
46+
type StructuredResponseSchema = dict | type[BaseModel]
4447

4548

4649
class AgentState(MessagesState):
@@ -92,6 +95,41 @@ def _func(state: AgentState | ChatPromptValue) -> list[BaseMessage]:
9295
return prompt_runnable
9396

9497

98+
class HackedToolNode(ToolNode):
99+
"""hacked tool node to inject tool_call_id into the config.
100+
101+
This is a hack. If langgraph support tool_call_id, we will remove this class.
102+
"""
103+
104+
async def _arun_one(
105+
self,
106+
call: ToolCall,
107+
input_type: Literal["list", "dict", "tool_calls"],
108+
config: RunnableConfig,
109+
) -> ToolMessage:
110+
if "metadata" in config:
111+
config["metadata"]["tool_call_id"] = call["id"]
112+
else:
113+
config["metadata"] = {
114+
"tool_call_id": call["id"],
115+
}
116+
return await super()._arun_one(call, input_type, config)
117+
118+
def _run_one(
119+
self,
120+
call: ToolCall,
121+
input_type: Literal["list", "dict", "tool_calls"],
122+
config: RunnableConfig,
123+
) -> ToolMessage:
124+
if "metadata" in config:
125+
config["metadata"]["tool_call_id"] = call["id"]
126+
else:
127+
config["metadata"] = {
128+
"tool_call_id": call["id"],
129+
}
130+
return super()._run_one(call, input_type, config)
131+
132+
95133
class ChatAgentFactory(AgentFactory[AgentState]):
96134
"""A factory for ChatAgents."""
97135

@@ -100,6 +138,7 @@ def __init__(
100138
model: BaseChatModel,
101139
tools: Sequence[BaseTool] | ToolNode,
102140
tools_in_prompt: bool = False,
141+
store: StoreManagerProtocol | None = None,
103142
) -> None:
104143
"""Initialize the chat agent factory."""
105144
self._model = model
@@ -110,6 +149,12 @@ def __init__(
110149
StructuredResponseSchema | tuple[str, StructuredResponseSchema] | None
111150
) = None
112151

152+
self._file_msg_converter = (
153+
FileMsgConverter(model_provider=self._model_class, store=store).runnable
154+
if store
155+
else RunnablePassthrough()
156+
)
157+
113158
# changed when self._build_graph is called
114159
self._tool_classes: list[BaseTool] = []
115160
self._should_return_direct: set[str] = set()
@@ -160,22 +205,28 @@ def _check_more_steps_needed(
160205
)
161206
)
162207

163-
def _call_model(self, state: AgentState, config: RunnableConfig) -> AgentState:
208+
async def _call_model(
209+
self, state: AgentState, config: RunnableConfig
210+
) -> AgentState:
211+
# TODO: _validate_chat_history
164212
if not self._tools_in_prompt:
165213
model = self._model
166214
if self._tool_classes:
167215
model = self._model.bind_tools(self._tool_classes)
168-
model_runnable = self._prompt | drop_empty_messages | model
216+
model_runnable = (
217+
self._prompt | self._file_msg_converter | drop_empty_messages | model
218+
)
169219
else:
170220
model_runnable = (
171221
self._prompt
172222
| self._tool_prompt
173223
| convert_messages
224+
| self._file_msg_converter
174225
| drop_empty_messages
175226
| self._model
176227
)
177228

178-
response = model_runnable.invoke(state, config)
229+
response = await model_runnable.ainvoke(state, config)
179230
if isinstance(response, AIMessage):
180231
response = extract_tool_calls(response)
181232
if self._check_more_steps_needed(state, response):
@@ -254,7 +305,9 @@ def _build_graph(self) -> None:
254305
graph.add_edge("before_agent", "agent")
255306

256307
tool_node = (
257-
self._tools if isinstance(self._tools, ToolNode) else ToolNode(self._tools)
308+
self._tools
309+
if isinstance(self._tools, ToolNode)
310+
else HackedToolNode(self._tools)
258311
)
259312
self._tool_classes = list(tool_node.tools_by_name.values())
260313
graph.add_node("tools", tool_node)
@@ -323,9 +376,15 @@ def get_chat_agent_factory(
323376
model: BaseChatModel,
324377
tools: Sequence[BaseTool] | ToolNode,
325378
tools_in_prompt: bool = False,
379+
store: StoreManagerProtocol | None = None,
326380
) -> ChatAgentFactory:
327381
"""Get an agent factory."""
328-
return ChatAgentFactory(model, tools, tools_in_prompt)
382+
return ChatAgentFactory(
383+
model=model,
384+
tools=tools,
385+
tools_in_prompt=tools_in_prompt,
386+
store=store,
387+
)
329388

330389

331390
@RunnableCallable

0 commit comments

Comments
 (0)