Skip to content

Commit fc8e2c4

Browse files
authored
chore: turn on mypy, add py.typed (#90)
* Enable mypy * Add py.typed
1 parent 38566ea commit fc8e2c4

File tree

8 files changed

+55
-26
lines changed

8 files changed

+55
-26
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=langgraph_swarm/
2828
lint lint_diff:
2929
[ "$(PYTHON_FILES)" = "" ] || uv run ruff format $(PYTHON_FILES) --diff
3030
[ "$(PYTHON_FILES)" = "" ] || uv run ruff check $(PYTHON_FILES) --diff
31-
# [ "$(PYTHON_FILES)" = "" ] || uv run mypy $(PYTHON_FILES)
31+
[ "$(PYTHON_FILES)" = "" ] || uv run mypy $(PYTHON_FILES)
3232

3333
format format_diff:
3434
[ "$(PYTHON_FILES)" = "" ] || uv run ruff check --fix $(PYTHON_FILES)

langgraph_swarm/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
from langgraph_swarm.handoff import create_handoff_tool
22
from langgraph_swarm.swarm import SwarmState, add_active_agent_router, create_swarm
33

4-
__all__ = ["SwarmState", "add_active_agent_router", "create_handoff_tool", "create_swarm"]
4+
__all__ = [
5+
"SwarmState",
6+
"add_active_agent_router",
7+
"create_handoff_tool",
8+
"create_swarm",
9+
]

langgraph_swarm/handoff.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def create_handoff_tool(
4747
def handoff_to_agent(
4848
state: Annotated[dict, InjectedState],
4949
tool_call_id: Annotated[str, InjectedToolCallId],
50-
):
50+
) -> Command:
5151
tool_message = ToolMessage(
5252
content=f"Successfully transferred to {agent_name}",
5353
name=name,
@@ -56,14 +56,19 @@ def handoff_to_agent(
5656
return Command(
5757
goto=agent_name,
5858
graph=Command.PARENT,
59-
update={"messages": state["messages"] + [tool_message], "active_agent": agent_name},
59+
update={
60+
"messages": state["messages"] + [tool_message],
61+
"active_agent": agent_name,
62+
},
6063
)
6164

6265
handoff_to_agent.metadata = {METADATA_KEY_HANDOFF_DESTINATION: agent_name}
6366
return handoff_to_agent
6467

6568

66-
def get_handoff_destinations(agent: CompiledStateGraph, tool_node_name: str = "tools") -> list[str]:
69+
def get_handoff_destinations(
70+
agent: CompiledStateGraph, tool_node_name: str = "tools"
71+
) -> list[str]:
6772
"""Get a list of destinations from agent's handoff tools."""
6873
nodes = agent.get_graph().nodes
6974
if tool_node_name not in nodes:
@@ -77,5 +82,6 @@ def get_handoff_destinations(agent: CompiledStateGraph, tool_node_name: str = "t
7782
return [
7883
tool.metadata[METADATA_KEY_HANDOFF_DESTINATION]
7984
for tool in tools
80-
if tool.metadata is not None and METADATA_KEY_HANDOFF_DESTINATION in tool.metadata
85+
if tool.metadata is not None
86+
and METADATA_KEY_HANDOFF_DESTINATION in tool.metadata
8187
]

langgraph_swarm/swarm.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal, Optional, Union, get_args, get_origin
1+
from typing import Literal, Optional, Union, cast, get_args, get_origin
22

33
from langgraph.graph import START, MessagesState, StateGraph
44
from langgraph.pregel import Pregel
@@ -30,7 +30,8 @@ def _update_state_schema_agent_names(
3030
# Check if the annotation is str or Optional[str]
3131
is_str_type = active_agent_annotation is str
3232
is_optional_str = (
33-
get_origin(active_agent_annotation) is Union and get_args(active_agent_annotation)[0] is str
33+
get_origin(active_agent_annotation) is Union
34+
and get_args(active_agent_annotation)[0] is str
3435
)
3536

3637
# We only update if the 'active_agent' is a str or Optional[str]
@@ -48,7 +49,7 @@ def _update_state_schema_agent_names(
4849

4950
# If it was Optional[str], make it Optional[Literal[...]]
5051
if is_optional_str:
51-
updated_schema.__annotations__["active_agent"] = Optional[literal_type]
52+
updated_schema.__annotations__["active_agent"] = Optional[literal_type] # noqa: UP045
5253
else:
5354
updated_schema.__annotations__["active_agent"] = literal_type
5455

@@ -135,8 +136,8 @@ def add(a: int, b: int) -> int:
135136
msg,
136137
)
137138

138-
def route_to_active_agent(state: dict):
139-
return state.get("active_agent", default_active_agent)
139+
def route_to_active_agent(state: dict) -> str:
140+
return cast("str", state.get("active_agent", default_active_agent))
140141

141142
builder.add_conditional_edges(START, route_to_active_agent, path_map=route_to)
142143
return builder
@@ -223,8 +224,13 @@ def add(a: int, b: int) -> int:
223224
for agent in agents:
224225
builder.add_node(
225226
agent.name,
226-
agent,
227-
destinations=tuple(get_handoff_destinations(agent)),
227+
# We need to update the type signatures in add_node to match
228+
# the fact that more flexible Pregel objects are allowed.
229+
agent, # type: ignore[arg-type]
230+
destinations=tuple(
231+
# Need to update implementation to support Pregel objects
232+
get_handoff_destinations(agent) # type: ignore[arg-type]
233+
),
228234
)
229235

230236
return builder

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ python_files = ["test_*.py"]
3737
python_functions = ["test_*"]
3838

3939
[tool.ruff]
40-
line-length = 100
40+
line-length = 88
4141
target-version = "py310"
4242

4343
[tool.ruff.lint]

tests/test_import.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,2 @@
11
def test_import() -> None:
22
"""Test that the code can be imported."""
3-
from langgraph_swarm import ( # noqa: F401
4-
add_active_agent_router,
5-
create_handoff_tool,
6-
create_swarm,
7-
)

tests/test_swarm.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
from collections.abc import Callable, Sequence
2+
from typing import TYPE_CHECKING, Any
3+
14
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
25
from langchain_core.language_models.chat_models import BaseChatModel
36
from langchain_core.messages import AIMessage, BaseMessage
47
from langchain_core.outputs import ChatGeneration, ChatResult
8+
from langchain_core.tools import BaseTool
59
from langgraph.checkpoint.memory import MemorySaver
610
from langgraph.prebuilt import create_react_agent
711

812
from langgraph_swarm import create_handoff_tool, create_swarm
913

14+
if TYPE_CHECKING:
15+
from langchain_core.runnables.config import RunnableConfig
16+
1017

1118
class FakeChatModel(BaseChatModel):
1219
idx: int = 0
@@ -21,13 +28,19 @@ def _generate(
2128
messages: list[BaseMessage],
2229
stop: list[str] | None = None,
2330
run_manager: CallbackManagerForLLMRun | None = None,
24-
**kwargs,
31+
**kwargs: Any,
2532
) -> ChatResult:
2633
generation = ChatGeneration(message=self.responses[self.idx])
2734
self.idx += 1
2835
return ChatResult(generations=[generation])
2936

30-
def bind_tools(self, tools: list[any]) -> "FakeChatModel":
37+
def bind_tools(
38+
self,
39+
tools: Sequence[dict[str, Any] | type | Callable[..., Any] | BaseTool],
40+
*,
41+
tool_choice: str | None = None,
42+
**kwargs: Any,
43+
) -> "FakeChatModel":
3144
return self
3245

3346

@@ -80,7 +93,7 @@ def test_basic_swarm() -> None:
8093
),
8194
]
8295

83-
model = FakeChatModel(responses=recorded_messages)
96+
model = FakeChatModel(responses=recorded_messages) # type: ignore[arg-type]
8497

8598
def add(a: int, b: int) -> int:
8699
"""Add two numbers."""
@@ -109,9 +122,11 @@ def add(a: int, b: int) -> int:
109122
workflow = create_swarm([alice, bob], default_active_agent="Alice")
110123
app = workflow.compile(checkpointer=checkpointer)
111124

112-
config = {"configurable": {"thread_id": "1"}}
125+
config: RunnableConfig = {"configurable": {"thread_id": "1"}}
113126
turn_1 = app.invoke(
114-
{"messages": [{"role": "user", "content": "i'd like to speak to Bob"}]},
127+
{ # type: ignore[arg-type]
128+
"messages": [{"role": "user", "content": "i'd like to speak to Bob"}]
129+
},
115130
config,
116131
)
117132

@@ -122,7 +137,9 @@ def add(a: int, b: int) -> int:
122137
assert turn_1["active_agent"] == "Bob"
123138

124139
turn_2 = app.invoke(
125-
{"messages": [{"role": "user", "content": "what's 5 + 7?"}]},
140+
{ # type: ignore[arg-type]
141+
"messages": [{"role": "user", "content": "what's 5 + 7?"}]
142+
},
126143
config,
127144
)
128145

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)