|
8 | 8 | from langchain_core.tools import BaseTool
|
9 | 9 | from langgraph.checkpoint.memory import MemorySaver
|
10 | 10 | from langgraph.prebuilt import create_react_agent
|
| 11 | +from langgraph.prebuilt.chat_agent_executor import AgentStatePydantic |
11 | 12 |
|
12 | 13 | from langgraph_swarm import create_handoff_tool, create_swarm
|
13 | 14 |
|
@@ -149,3 +150,125 @@ def add(a: int, b: int) -> int:
|
149 | 150 | assert turn_2["messages"][-2].content == "12"
|
150 | 151 | assert turn_2["messages"][-1].content == recorded_messages[4].content
|
151 | 152 | assert turn_2["active_agent"] == "Alice"
|
| 153 | + |
| 154 | + |
| 155 | +def test_basic_swarm_pydantic() -> None: |
| 156 | + """Test a basic swarm with Pydantic state schema.""" |
| 157 | + |
| 158 | + class SwarmState(AgentStatePydantic): |
| 159 | + """State schema for the multi-agent swarm.""" |
| 160 | + |
| 161 | + # NOTE: this state field is optional and is not expected to be provided by the |
| 162 | + # user. |
| 163 | + # If a user does provide it, the graph will start from the specified active |
| 164 | + # agent. |
| 165 | + # If active agent is typed as a `str`, we turn it into enum of all active agent |
| 166 | + # names. |
| 167 | + active_agent: str | None = None |
| 168 | + |
| 169 | + recorded_messages = [ |
| 170 | + AIMessage( |
| 171 | + content="", |
| 172 | + name="Alice", |
| 173 | + tool_calls=[ |
| 174 | + { |
| 175 | + "name": "transfer_to_bob", |
| 176 | + "args": {}, |
| 177 | + "id": "call_1LlFyjm6iIhDjdn7juWuPYr4", |
| 178 | + }, |
| 179 | + ], |
| 180 | + ), |
| 181 | + AIMessage( |
| 182 | + content="Ahoy, matey! Bob the pirate be at yer service. What be ye needin' " |
| 183 | + "help with today on the high seas? Arrr!", |
| 184 | + name="Bob", |
| 185 | + ), |
| 186 | + AIMessage( |
| 187 | + content="", |
| 188 | + name="Bob", |
| 189 | + tool_calls=[ |
| 190 | + { |
| 191 | + "name": "transfer_to_alice", |
| 192 | + "args": {}, |
| 193 | + "id": "call_T6pNmo2jTfZEK3a9avQ14f8Q", |
| 194 | + }, |
| 195 | + ], |
| 196 | + ), |
| 197 | + AIMessage( |
| 198 | + content="", |
| 199 | + name="Alice", |
| 200 | + tool_calls=[ |
| 201 | + { |
| 202 | + "name": "add", |
| 203 | + "args": { |
| 204 | + "a": 5, |
| 205 | + "b": 7, |
| 206 | + }, |
| 207 | + "id": "call_4kLYO1amR2NfhAxfECkALCr1", |
| 208 | + }, |
| 209 | + ], |
| 210 | + ), |
| 211 | + AIMessage( |
| 212 | + content="The sum of 5 and 7 is 12.", |
| 213 | + name="Alice", |
| 214 | + ), |
| 215 | + ] |
| 216 | + |
| 217 | + model = FakeChatModel(responses=recorded_messages) # type: ignore[arg-type] |
| 218 | + |
| 219 | + def add(a: int, b: int) -> int: |
| 220 | + """Add two numbers.""" |
| 221 | + return a + b |
| 222 | + |
| 223 | + alice = create_react_agent( |
| 224 | + model, |
| 225 | + [add, create_handoff_tool(agent_name="Bob")], |
| 226 | + prompt="You are Alice, an addition expert.", |
| 227 | + name="Alice", |
| 228 | + state_schema=SwarmState, |
| 229 | + ) |
| 230 | + |
| 231 | + bob = create_react_agent( |
| 232 | + model, |
| 233 | + [ |
| 234 | + create_handoff_tool( |
| 235 | + agent_name="Alice", |
| 236 | + description="Transfer to Alice, she can help with math", |
| 237 | + ), |
| 238 | + ], |
| 239 | + prompt="You are Bob, you speak like a pirate.", |
| 240 | + name="Bob", |
| 241 | + state_schema=SwarmState, |
| 242 | + ) |
| 243 | + |
| 244 | + checkpointer = MemorySaver() |
| 245 | + workflow = create_swarm([alice, bob], default_active_agent="Alice") |
| 246 | + app = workflow.compile(checkpointer=checkpointer) |
| 247 | + |
| 248 | + config: RunnableConfig = {"configurable": {"thread_id": "1"}} |
| 249 | + turn_1 = app.invoke( |
| 250 | + { # type: ignore[arg-type] |
| 251 | + "messages": [{"role": "user", "content": "i'd like to speak to Bob"}] |
| 252 | + }, |
| 253 | + config, |
| 254 | + ) |
| 255 | + |
| 256 | + # Verify turn 1 results |
| 257 | + assert len(turn_1["messages"]) == 4 |
| 258 | + assert turn_1["messages"][-2].content == "Successfully transferred to Bob" |
| 259 | + assert turn_1["messages"][-1].content == recorded_messages[1].content |
| 260 | + assert turn_1["active_agent"] == "Bob" |
| 261 | + |
| 262 | + turn_2 = app.invoke( |
| 263 | + { # type: ignore[arg-type] |
| 264 | + "messages": [{"role": "user", "content": "what's 5 + 7?"}] |
| 265 | + }, |
| 266 | + config, |
| 267 | + ) |
| 268 | + |
| 269 | + # Verify turn 2 results |
| 270 | + assert len(turn_2["messages"]) == 10 |
| 271 | + assert turn_2["messages"][-4].content == "Successfully transferred to Alice" |
| 272 | + assert turn_2["messages"][-2].content == "12" |
| 273 | + assert turn_2["messages"][-1].content == recorded_messages[4].content |
| 274 | + assert turn_2["active_agent"] == "Alice" |
0 commit comments