Skip to content

Commit 0d285bb

Browse files
committed
Merge pull request 'fix: update message handling for resend functionality' (#410) from fix-abort-then-resend into development
Reviewed-on: https://git.biggo.com/Funmula/dive-mcp-host/pulls/410
2 parents 78f95a1 + e56bfba commit 0d285bb

File tree

2 files changed

+119
-9
lines changed

2 files changed

+119
-9
lines changed

dive_mcp_host/host/chat.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,16 @@ async def _run_in_context(self) -> AsyncGenerator[Self, None]:
115115
finally:
116116
self._agent = None
117117

118-
async def _get_updates_for_resend(
118+
async def _remove_messages_for_resend(
119119
self,
120120
resend: list[BaseMessage],
121121
update: list[BaseMessage],
122-
) -> list[BaseMessage]:
123-
if not self._checkpointer:
124-
return update
122+
) -> None:
123+
if not self._checkpointer or not (
124+
self.active_agent.checkpointer
125+
and isinstance(self.active_agent.checkpointer, BaseCheckpointSaver)
126+
):
127+
return
125128
resend_map = {msg.id: msg for msg in resend}
126129
to_update = [i for i in update if i.id not in resend_map]
127130
if state := await self.active_agent.aget_state(
@@ -134,15 +137,25 @@ async def _get_updates_for_resend(
134137
):
135138
drop_after = False
136139
if not state.values:
137-
return to_update
140+
return
138141

139142
for msg in cast(MessagesState, state.values)["messages"]:
140143
assert msg.id is not None # all messages from the agent have an ID
141144
if msg.id in resend_map:
142145
drop_after = True
143146
elif drop_after:
144147
to_update.append(RemoveMessage(msg.id))
145-
return to_update
148+
if to_update:
149+
await self.active_agent.aupdate_state(
150+
RunnableConfig(
151+
configurable={
152+
"thread_id": self._chat_id,
153+
"user_id": self._user_id,
154+
},
155+
),
156+
{"messages": to_update},
157+
)
158+
return
146159
raise ThreadNotFoundError(self._chat_id)
147160

148161
def query(
@@ -180,9 +193,7 @@ async def _stream_response() -> AsyncGenerator[dict[str, Any] | Any, None]:
180193
isinstance(msg, BaseMessage) and msg.id for msg in query_msgs
181194
):
182195
raise MessageTypeError("Resending messages must has an ID")
183-
query_msgs += await self._get_updates_for_resend(
184-
query_msgs, modify or []
185-
)
196+
await self._remove_messages_for_resend(query_msgs, modify or [])
186197
elif modify:
187198
query_msgs = [*query_msgs, *modify]
188199
signal = asyncio.Event()

tests/test_host.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33
import os
4+
import time
45
from contextlib import AbstractAsyncContextManager
56
from typing import Any, cast
67
from unittest import mock
@@ -15,8 +16,10 @@
1516
ToolCall,
1617
ToolMessage,
1718
)
19+
from langgraph.checkpoint.memory import InMemorySaver
1820
from pydantic import AnyUrl, SecretStr
1921

22+
from dive_mcp_host.host.chat import Chat
2023
from dive_mcp_host.host.conf import CheckpointerConfig, HostConfig
2124
from dive_mcp_host.host.conf.llm import LLMConfig
2225
from dive_mcp_host.host.custom_events import ToolCallProgress
@@ -561,3 +564,99 @@ async def test_custom_event(
561564
assert isinstance(i[1][1], ToolCallProgress)
562565
done = True
563566
assert done
567+
568+
569+
@pytest.mark.asyncio
570+
async def test_resend_after_abort( # noqa: C901
571+
echo_tool_stdio_config: dict[str, ServerConfig],
572+
) -> None:
573+
"""Test that resend after abort works."""
574+
config = HostConfig(
575+
llm=LLMConfig(
576+
model="fake",
577+
model_provider="dive",
578+
),
579+
mcp_servers=echo_tool_stdio_config,
580+
)
581+
582+
async with DiveMcpHost(config) as mcp_host:
583+
mcp_host._checkpointer = InMemorySaver()
584+
fake_responses = [
585+
AIMessage(
586+
content="Call echo tool",
587+
tool_calls=[
588+
ToolCall(
589+
name="echo",
590+
args={"message": "Hello, world!", "delay_ms": 2000},
591+
id="phase1-tool-1",
592+
type="tool_call",
593+
),
594+
],
595+
id="phase1-1",
596+
),
597+
AIMessage(
598+
content="Bye",
599+
id="phase1-2",
600+
),
601+
]
602+
603+
ts = time.time()
604+
got_msg = False
605+
606+
async def _abort_task(chat: Chat) -> None:
607+
while True:
608+
if got_msg and time.time() - ts > 0.5:
609+
chat.abort()
610+
break
611+
await asyncio.sleep(0.1)
612+
613+
fake_model = cast("FakeMessageToolModel", mcp_host.model)
614+
fake_model.responses = fake_responses
615+
await mcp_host.tools_initialized_event.wait()
616+
chat = mcp_host.chat(chat_id="chat_id")
617+
async with chat:
618+
task = asyncio.create_task(_abort_task(chat))
619+
async for r in chat.query(
620+
HumanMessage(content="Hello, world!", id="H1"),
621+
stream_mode=["messages", "values", "updates"],
622+
):
623+
if r[0] == "messages": # type: ignore
624+
got_msg = True
625+
ts = time.time()
626+
await task
627+
628+
resend = [HumanMessage(content="Resend message!", id="H1")]
629+
fake_responses = [
630+
AIMessage(
631+
content="Call echo tool",
632+
tool_calls=[
633+
ToolCall(
634+
name="echo",
635+
args={"message": "Hello, world!"},
636+
id="phase2-tool-1",
637+
type="tool_call",
638+
),
639+
],
640+
id="phase2-1",
641+
),
642+
AIMessage(
643+
content="Bye",
644+
id="phase2-2",
645+
),
646+
]
647+
fake_model.i = 0
648+
fake_model.responses = fake_responses
649+
chat = mcp_host.chat(chat_id="chat_id")
650+
async with chat:
651+
async for r in chat.query(
652+
resend, # type: ignore
653+
is_resend=True,
654+
):
655+
r = cast(tuple[str, list[BaseMessage]], r)
656+
if r[0] == "messages":
657+
for m in r[1]:
658+
assert m.id
659+
if isinstance(m, HumanMessage):
660+
assert m.id == "H1"
661+
elif isinstance(m, AIMessage):
662+
assert m.id.startswith("phase2-")

0 commit comments

Comments
 (0)