Skip to content

Commit 51f6374

Browse files
authored
add support for pydantic state for the swarm state (#91)
* Adds support for using pydantic base model for the swarm state.
1 parent fc8e2c4 commit 51f6374

File tree

2 files changed

+153
-3
lines changed

2 files changed

+153
-3
lines changed

langgraph_swarm/handoff.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,35 @@
11
import re
2-
from typing import Annotated
2+
from dataclasses import is_dataclass
3+
from typing import Annotated, Any
34

45
from langchain_core.messages import ToolMessage
56
from langchain_core.tools import BaseTool, InjectedToolCallId, tool
67
from langgraph.graph.state import CompiledStateGraph
78
from langgraph.prebuilt import InjectedState, ToolNode
89
from langgraph.types import Command
10+
from pydantic import BaseModel
11+
12+
13+
def _get_field(obj: Any, key: str) -> Any:
14+
"""Get a field from an object.
15+
16+
This function retrieves a field from a dictionary, dataclass, or Pydantic model.
17+
18+
Args:
19+
obj: The object from which to retrieve the field.
20+
key: The key or attribute name of the field to retrieve.
21+
22+
Returns:
23+
The value of the specified field.
24+
25+
"""
26+
if isinstance(obj, dict):
27+
return obj[key]
28+
if is_dataclass(obj) or isinstance(obj, BaseModel):
29+
return getattr(obj, key)
30+
msg = f"Unsupported type for state: {type(obj)}"
31+
raise TypeError(msg)
32+
933

1034
WHITESPACE_RE = re.compile(r"\s+")
1135
METADATA_KEY_HANDOFF_DESTINATION = "__handoff_destination"
@@ -45,7 +69,10 @@ def create_handoff_tool(
4569

4670
@tool(name, description=description)
4771
def handoff_to_agent(
48-
state: Annotated[dict, InjectedState],
72+
# Annotation is typed as Any instead of StateLike. StateLike
73+
# trigger validation issues from Pydantic / langchain_core interaction.
74+
# https://github.com/langchain-ai/langchain/issues/32067
75+
state: Annotated[Any, InjectedState],
4976
tool_call_id: Annotated[str, InjectedToolCallId],
5077
) -> Command:
5178
tool_message = ToolMessage(
@@ -57,7 +84,7 @@ def handoff_to_agent(
5784
goto=agent_name,
5885
graph=Command.PARENT,
5986
update={
60-
"messages": state["messages"] + [tool_message],
87+
"messages": [*_get_field(state, "messages"), tool_message],
6188
"active_agent": agent_name,
6289
},
6390
)

tests/test_swarm.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from langchain_core.tools import BaseTool
99
from langgraph.checkpoint.memory import MemorySaver
1010
from langgraph.prebuilt import create_react_agent
11+
from langgraph.prebuilt.chat_agent_executor import AgentStatePydantic
1112

1213
from langgraph_swarm import create_handoff_tool, create_swarm
1314

@@ -149,3 +150,125 @@ def add(a: int, b: int) -> int:
149150
assert turn_2["messages"][-2].content == "12"
150151
assert turn_2["messages"][-1].content == recorded_messages[4].content
151152
assert turn_2["active_agent"] == "Alice"
153+
154+
155+
def test_basic_swarm_pydantic() -> None:
156+
"""Test a basic swarm with Pydantic state schema."""
157+
158+
class SwarmState(AgentStatePydantic):
159+
"""State schema for the multi-agent swarm."""
160+
161+
# NOTE: this state field is optional and is not expected to be provided by the
162+
# user.
163+
# If a user does provide it, the graph will start from the specified active
164+
# agent.
165+
# If active agent is typed as a `str`, we turn it into enum of all active agent
166+
# names.
167+
active_agent: str | None = None
168+
169+
recorded_messages = [
170+
AIMessage(
171+
content="",
172+
name="Alice",
173+
tool_calls=[
174+
{
175+
"name": "transfer_to_bob",
176+
"args": {},
177+
"id": "call_1LlFyjm6iIhDjdn7juWuPYr4",
178+
},
179+
],
180+
),
181+
AIMessage(
182+
content="Ahoy, matey! Bob the pirate be at yer service. What be ye needin' "
183+
"help with today on the high seas? Arrr!",
184+
name="Bob",
185+
),
186+
AIMessage(
187+
content="",
188+
name="Bob",
189+
tool_calls=[
190+
{
191+
"name": "transfer_to_alice",
192+
"args": {},
193+
"id": "call_T6pNmo2jTfZEK3a9avQ14f8Q",
194+
},
195+
],
196+
),
197+
AIMessage(
198+
content="",
199+
name="Alice",
200+
tool_calls=[
201+
{
202+
"name": "add",
203+
"args": {
204+
"a": 5,
205+
"b": 7,
206+
},
207+
"id": "call_4kLYO1amR2NfhAxfECkALCr1",
208+
},
209+
],
210+
),
211+
AIMessage(
212+
content="The sum of 5 and 7 is 12.",
213+
name="Alice",
214+
),
215+
]
216+
217+
model = FakeChatModel(responses=recorded_messages) # type: ignore[arg-type]
218+
219+
def add(a: int, b: int) -> int:
220+
"""Add two numbers."""
221+
return a + b
222+
223+
alice = create_react_agent(
224+
model,
225+
[add, create_handoff_tool(agent_name="Bob")],
226+
prompt="You are Alice, an addition expert.",
227+
name="Alice",
228+
state_schema=SwarmState,
229+
)
230+
231+
bob = create_react_agent(
232+
model,
233+
[
234+
create_handoff_tool(
235+
agent_name="Alice",
236+
description="Transfer to Alice, she can help with math",
237+
),
238+
],
239+
prompt="You are Bob, you speak like a pirate.",
240+
name="Bob",
241+
state_schema=SwarmState,
242+
)
243+
244+
checkpointer = MemorySaver()
245+
workflow = create_swarm([alice, bob], default_active_agent="Alice")
246+
app = workflow.compile(checkpointer=checkpointer)
247+
248+
config: RunnableConfig = {"configurable": {"thread_id": "1"}}
249+
turn_1 = app.invoke(
250+
{ # type: ignore[arg-type]
251+
"messages": [{"role": "user", "content": "i'd like to speak to Bob"}]
252+
},
253+
config,
254+
)
255+
256+
# Verify turn 1 results
257+
assert len(turn_1["messages"]) == 4
258+
assert turn_1["messages"][-2].content == "Successfully transferred to Bob"
259+
assert turn_1["messages"][-1].content == recorded_messages[1].content
260+
assert turn_1["active_agent"] == "Bob"
261+
262+
turn_2 = app.invoke(
263+
{ # type: ignore[arg-type]
264+
"messages": [{"role": "user", "content": "what's 5 + 7?"}]
265+
},
266+
config,
267+
)
268+
269+
# Verify turn 2 results
270+
assert len(turn_2["messages"]) == 10
271+
assert turn_2["messages"][-4].content == "Successfully transferred to Alice"
272+
assert turn_2["messages"][-2].content == "12"
273+
assert turn_2["messages"][-1].content == recorded_messages[4].content
274+
assert turn_2["active_agent"] == "Alice"

0 commit comments

Comments
 (0)