|
1 | 1 | import asyncio
|
2 | 2 | import json
|
3 | 3 | import os
|
| 4 | +import time |
4 | 5 | from contextlib import AbstractAsyncContextManager
|
5 | 6 | from typing import Any, cast
|
6 | 7 | from unittest import mock
|
|
15 | 16 | ToolCall,
|
16 | 17 | ToolMessage,
|
17 | 18 | )
|
| 19 | +from langgraph.checkpoint.memory import InMemorySaver |
18 | 20 | from pydantic import AnyUrl, SecretStr
|
19 | 21 |
|
| 22 | +from dive_mcp_host.host.chat import Chat |
20 | 23 | from dive_mcp_host.host.conf import CheckpointerConfig, HostConfig
|
21 | 24 | from dive_mcp_host.host.conf.llm import LLMConfig
|
22 | 25 | from dive_mcp_host.host.custom_events import ToolCallProgress
|
@@ -561,3 +564,99 @@ async def test_custom_event(
|
561 | 564 | assert isinstance(i[1][1], ToolCallProgress)
|
562 | 565 | done = True
|
563 | 566 | assert done
|
| 567 | + |
| 568 | + |
| 569 | +@pytest.mark.asyncio |
| 570 | +async def test_resend_after_abort( # noqa: C901 |
| 571 | + echo_tool_stdio_config: dict[str, ServerConfig], |
| 572 | +) -> None: |
| 573 | + """Test that resend after abort works.""" |
| 574 | + config = HostConfig( |
| 575 | + llm=LLMConfig( |
| 576 | + model="fake", |
| 577 | + model_provider="dive", |
| 578 | + ), |
| 579 | + mcp_servers=echo_tool_stdio_config, |
| 580 | + ) |
| 581 | + |
| 582 | + async with DiveMcpHost(config) as mcp_host: |
| 583 | + mcp_host._checkpointer = InMemorySaver() |
| 584 | + fake_responses = [ |
| 585 | + AIMessage( |
| 586 | + content="Call echo tool", |
| 587 | + tool_calls=[ |
| 588 | + ToolCall( |
| 589 | + name="echo", |
| 590 | + args={"message": "Hello, world!", "delay_ms": 2000}, |
| 591 | + id="phase1-tool-1", |
| 592 | + type="tool_call", |
| 593 | + ), |
| 594 | + ], |
| 595 | + id="phase1-1", |
| 596 | + ), |
| 597 | + AIMessage( |
| 598 | + content="Bye", |
| 599 | + id="phase1-2", |
| 600 | + ), |
| 601 | + ] |
| 602 | + |
| 603 | + ts = time.time() |
| 604 | + got_msg = False |
| 605 | + |
| 606 | + async def _abort_task(chat: Chat) -> None: |
| 607 | + while True: |
| 608 | + if got_msg and time.time() - ts > 0.5: |
| 609 | + chat.abort() |
| 610 | + break |
| 611 | + await asyncio.sleep(0.1) |
| 612 | + |
| 613 | + fake_model = cast("FakeMessageToolModel", mcp_host.model) |
| 614 | + fake_model.responses = fake_responses |
| 615 | + await mcp_host.tools_initialized_event.wait() |
| 616 | + chat = mcp_host.chat(chat_id="chat_id") |
| 617 | + async with chat: |
| 618 | + task = asyncio.create_task(_abort_task(chat)) |
| 619 | + async for r in chat.query( |
| 620 | + HumanMessage(content="Hello, world!", id="H1"), |
| 621 | + stream_mode=["messages", "values", "updates"], |
| 622 | + ): |
| 623 | + if r[0] == "messages": # type: ignore |
| 624 | + got_msg = True |
| 625 | + ts = time.time() |
| 626 | + await task |
| 627 | + |
| 628 | + resend = [HumanMessage(content="Resend message!", id="H1")] |
| 629 | + fake_responses = [ |
| 630 | + AIMessage( |
| 631 | + content="Call echo tool", |
| 632 | + tool_calls=[ |
| 633 | + ToolCall( |
| 634 | + name="echo", |
| 635 | + args={"message": "Hello, world!"}, |
| 636 | + id="phase2-tool-1", |
| 637 | + type="tool_call", |
| 638 | + ), |
| 639 | + ], |
| 640 | + id="phase2-1", |
| 641 | + ), |
| 642 | + AIMessage( |
| 643 | + content="Bye", |
| 644 | + id="phase2-2", |
| 645 | + ), |
| 646 | + ] |
| 647 | + fake_model.i = 0 |
| 648 | + fake_model.responses = fake_responses |
| 649 | + chat = mcp_host.chat(chat_id="chat_id") |
| 650 | + async with chat: |
| 651 | + async for r in chat.query( |
| 652 | + resend, # type: ignore |
| 653 | + is_resend=True, |
| 654 | + ): |
| 655 | + r = cast(tuple[str, list[BaseMessage]], r) |
| 656 | + if r[0] == "messages": |
| 657 | + for m in r[1]: |
| 658 | + assert m.id |
| 659 | + if isinstance(m, HumanMessage): |
| 660 | + assert m.id == "H1" |
| 661 | + elif isinstance(m, AIMessage): |
| 662 | + assert m.id.startswith("phase2-") |
0 commit comments