1
+ from collections .abc import Callable , Sequence
2
+ from typing import TYPE_CHECKING , Any
3
+
1
4
from langchain_core .callbacks .manager import CallbackManagerForLLMRun
2
5
from langchain_core .language_models .chat_models import BaseChatModel
3
6
from langchain_core .messages import AIMessage , BaseMessage
4
7
from langchain_core .outputs import ChatGeneration , ChatResult
8
+ from langchain_core .tools import BaseTool
5
9
from langgraph .checkpoint .memory import MemorySaver
6
10
from langgraph .prebuilt import create_react_agent
7
11
8
12
from langgraph_swarm import create_handoff_tool , create_swarm
9
13
14
+ if TYPE_CHECKING :
15
+ from langchain_core .runnables .config import RunnableConfig
16
+
10
17
11
18
class FakeChatModel (BaseChatModel ):
12
19
idx : int = 0
@@ -21,13 +28,19 @@ def _generate(
21
28
messages : list [BaseMessage ],
22
29
stop : list [str ] | None = None ,
23
30
run_manager : CallbackManagerForLLMRun | None = None ,
24
- ** kwargs ,
31
+ ** kwargs : Any ,
25
32
) -> ChatResult :
26
33
generation = ChatGeneration (message = self .responses [self .idx ])
27
34
self .idx += 1
28
35
return ChatResult (generations = [generation ])
29
36
30
- def bind_tools (self , tools : list [any ]) -> "FakeChatModel" :
37
+ def bind_tools (
38
+ self ,
39
+ tools : Sequence [dict [str , Any ] | type | Callable [..., Any ] | BaseTool ],
40
+ * ,
41
+ tool_choice : str | None = None ,
42
+ ** kwargs : Any ,
43
+ ) -> "FakeChatModel" :
31
44
return self
32
45
33
46
@@ -80,7 +93,7 @@ def test_basic_swarm() -> None:
80
93
),
81
94
]
82
95
83
- model = FakeChatModel (responses = recorded_messages )
96
+ model = FakeChatModel (responses = recorded_messages ) # type: ignore[arg-type]
84
97
85
98
def add (a : int , b : int ) -> int :
86
99
"""Add two numbers."""
@@ -109,9 +122,11 @@ def add(a: int, b: int) -> int:
109
122
workflow = create_swarm ([alice , bob ], default_active_agent = "Alice" )
110
123
app = workflow .compile (checkpointer = checkpointer )
111
124
112
- config = {"configurable" : {"thread_id" : "1" }}
125
+ config : RunnableConfig = {"configurable" : {"thread_id" : "1" }}
113
126
turn_1 = app .invoke (
114
- {"messages" : [{"role" : "user" , "content" : "i'd like to speak to Bob" }]},
127
+ { # type: ignore[arg-type]
128
+ "messages" : [{"role" : "user" , "content" : "i'd like to speak to Bob" }]
129
+ },
115
130
config ,
116
131
)
117
132
@@ -122,7 +137,9 @@ def add(a: int, b: int) -> int:
122
137
assert turn_1 ["active_agent" ] == "Bob"
123
138
124
139
turn_2 = app .invoke (
125
- {"messages" : [{"role" : "user" , "content" : "what's 5 + 7?" }]},
140
+ { # type: ignore[arg-type]
141
+ "messages" : [{"role" : "user" , "content" : "what's 5 + 7?" }]
142
+ },
126
143
config ,
127
144
)
128
145
0 commit comments