Skip to content

Commit c9c6fac

Browse files
committed
feat: add dummy tool result for aborted tool calls
1 parent faee3f1 commit c9c6fac

File tree

4 files changed

+144
-3
lines changed

4 files changed

+144
-3
lines changed

dive_mcp_host/host/agents/chat_agent.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from pydantic import BaseModel
3232

3333
from dive_mcp_host.host.agents.agent_factory import AgentFactory, initial_messages
34+
from dive_mcp_host.host.agents.message_order import tool_call_order
3435
from dive_mcp_host.host.agents.tools_in_prompt import (
3536
convert_messages,
3637
extract_tool_calls,
@@ -160,7 +161,6 @@ def _check_more_steps_needed(
160161
)
161162

162163
def _call_model(self, state: AgentState, config: RunnableConfig) -> AgentState:
163-
# TODO: _validate_chat_history
164164
if not self._tools_in_prompt:
165165
model = self._model
166166
if self._tool_classes:
@@ -204,8 +204,13 @@ def _before_agent(self, state: AgentState, config: RunnableConfig) -> AgentState
204204
configurable = config.get("configurable", {})
205205
max_input_tokens: int | None = configurable.get("max_input_tokens")
206206
oversize_policy: Literal["window"] | None = configurable.get("oversize_policy")
207+
208+
new_messages: list[BaseMessage] = []
209+
new_messages.extend(tool_call_order(state["messages"]))
210+
207211
if max_input_tokens is None or oversize_policy is None:
208-
return cast(AgentState, {"messages": []})
212+
return cast(AgentState, {"messages": new_messages})
213+
209214
if oversize_policy == "window":
210215
messages: list[BaseMessage] = trim_messages(
211216
state["messages"],
@@ -217,7 +222,8 @@ def _before_agent(self, state: AgentState, config: RunnableConfig) -> AgentState
217222
for m in state["messages"]
218223
if m not in messages
219224
]
220-
return cast(AgentState, {"messages": remove_messages})
225+
new_messages.extend(remove_messages)
226+
return cast(AgentState, {"messages": new_messages})
221227

222228
return cast(AgentState, {"messages": []})
223229

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from logging import getLogger
2+
from uuid import uuid4
3+
4+
from langchain_core.messages import (
5+
AIMessage,
6+
AnyMessage,
7+
BaseMessage,
8+
RemoveMessage,
9+
ToolMessage,
10+
)
11+
12+
from dive_mcp_host.log import TRACE
13+
14+
logger = getLogger(__name__)
15+
16+
17+
FAKE_TOOL_RESPONSE = "FAKE_TOOL_RESPONSE"
18+
19+
20+
def _has_tool_call(msg: AnyMessage | None) -> bool:
    """Return True when *msg* is an AIMessage carrying at least one tool call."""
    # isinstance() already rejects None, so no separate None check is needed.
    return isinstance(msg, AIMessage) and bool(msg.tool_calls)
22+
23+
24+
def _not_tool_result(msg: AnyMessage) -> bool:
    """Return True when *msg* is anything other than a ToolMessage."""
    if isinstance(msg, ToolMessage):
        return False
    return True
26+
27+
28+
def tool_call_order(messages: list[AnyMessage]) -> list[BaseMessage]:
    """Guarantee tool call tool result pair.

    Providers like Anthropic requires each tool call to have their
    corresponding tool result.

    Scans the history for an AIMessage whose tool calls are not immediately
    followed by a ToolMessage.  When one is found, fake ToolMessages (tagged
    with FAKE_TOOL_RESPONSE in response_metadata) are inserted for every
    dangling tool call, and every message from the violation onward is
    replaced: the originals are scheduled for removal via RemoveMessage and
    re-appended as fresh copies with new ids so the rebuilt ordering sticks.

    Returns the RemoveMessage batch followed by the patched tail; returns an
    empty list when the history is already well-formed.

    NOTE(review): only adjacent pairs are inspected — a dangling tool call on
    the very last message (nothing after it yet) is not patched here, and an
    AIMessage with several tool calls answered by only some ToolMessages is
    not detected.  Confirm callers always append the next user message before
    invoking this.
    """
    logger.log(TRACE, "Examine tool call order. msgs: %s", messages)

    rebuilt: list[BaseMessage] = []
    removals: list[RemoveMessage] = []
    needs_rewrite = False
    previous: BaseMessage | None = None

    for pos, message in enumerate(messages):
        if _has_tool_call(previous) and _not_tool_result(message):
            # _has_tool_call already proved this; assert narrows the type.
            assert isinstance(previous, AIMessage), "Could only be AIMessage"
            logger.warning(
                "Found tool call that doesn't have tool result as next message: %s",
                previous.model_dump_json(),
            )

            # Synthesize one placeholder result per dangling tool call.
            rebuilt.extend(
                ToolMessage(
                    content="Previous tool call was not processed",
                    tool_call_id=tool_call["id"],
                    response_metadata={FAKE_TOOL_RESPONSE: True},
                    id=uuid4().hex,
                )
                for tool_call in previous.tool_calls
            )

            # Everything from the first violation onward gets rebuilt, so
            # schedule the originals (only those with ids) for removal once.
            if not needs_rewrite:
                needs_rewrite = True
                removals = [
                    RemoveMessage(id=m.id) for m in messages[pos:] if m.id
                ]

        if needs_rewrite:
            # Clone the original message under a fresh id and re-append it
            # after the fake tool results.
            clone = type(message)(**message.model_dump())
            clone.id = uuid4().hex
            rebuilt.append(clone)

        previous = message

    result = removals + rebuilt
    logger.log(TRACE, "Tool call order result: %s", result)
    logger.debug(
        "tool call order result, fake tool result needed: %s"
        ", new_msgs: %s, remove_msgs: %s",
        needs_rewrite,
        len(rebuilt),
        len(removals),
    )
    return result

dive_mcp_host/httpd/routers/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pydantic import BaseModel
1616
from starlette.datastructures import State
1717

18+
from dive_mcp_host.host.agents.message_order import FAKE_TOOL_RESPONSE
1819
from dive_mcp_host.host.errors import LogBufferNotFoundError
1920
from dive_mcp_host.host.tools.log import LogEvent, LogManager, LogMsg
2021
from dive_mcp_host.host.tools.model_types import ClientState
@@ -489,6 +490,13 @@ async def _handle_response( # noqa: C901, PLR0912
489490
await self._stream_text_msg(message)
490491
elif isinstance(message, ToolMessage):
491492
logger.log(TRACE, "got tool message: %s", message.model_dump_json())
493+
if message.response_metadata.get(FAKE_TOOL_RESPONSE, False):
494+
logger.log(
495+
TRACE,
496+
"ignore fake tool response: %s",
497+
message.model_dump_json(),
498+
)
499+
continue
492500
await self._stream_tool_result_msg(message)
493501
else:
494502
# idk what is this

tests/test_message_order.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage, ToolMessage
2+
3+
from dive_mcp_host.host.agents.message_order import FAKE_TOOL_RESPONSE, tool_call_order
4+
5+
6+
def test_msg_order():
    """Test if message order correction is successful."""
    tool_call_id = "toolu_012N5cw28KM9QfRLeYdik5V6"
    messages = [
        HumanMessage(content="Hi, please generate a image of xxx for me.", id="1"),
        AIMessage(
            content="Sure, I will us xxx to generate and image of xxx for you.",
            tool_calls=[
                {
                    "name": "xxx",
                    "args": {
                        "prompt": "A xxx",
                    },
                    # Use the shared constant instead of repeating the literal,
                    # so the assertion below cannot silently drift.
                    "id": tool_call_id,
                    "type": "tool_call",
                }
            ],
            id="2",
        ),
        HumanMessage(content="Hi, again", id="3"),
    ]
    result = tool_call_order(messages)
    assert len(result) == 3

    # Remove messages after the tool call — only the message that followed
    # the dangling tool call (id "3") is scheduled for removal.
    assert isinstance(result[0], RemoveMessage)
    assert result[0].id == "3"

    # Insert ToolMessage
    assert isinstance(result[1], ToolMessage)
    assert result[1].tool_call_id == tool_call_id
    assert result[1].id
    assert result[1].response_metadata[FAKE_TOOL_RESPONSE]

    # Other messages behind ToolMessage are re-appended as copies with a
    # freshly generated id (not the original "3").
    assert isinstance(result[2], HumanMessage)
    assert result[2].content == messages[2].content
    assert result[2].id
    assert result[2].id != messages[2].id

0 commit comments

Comments
 (0)