13
13
HumanMessage ,
14
14
RemoveMessage ,
15
15
SystemMessage ,
16
+ ToolCall ,
16
17
ToolMessage ,
17
18
)
18
19
from langchain_core .messages .utils import count_tokens_approximately , trim_messages
19
20
from langchain_core .prompt_values import ChatPromptValue
20
21
from langchain_core .prompts import ChatPromptTemplate
21
- from langchain_core .runnables import Runnable , RunnableConfig
22
+ from langchain_core .runnables import Runnable , RunnableConfig , RunnablePassthrough
22
23
from langchain_core .tools import BaseTool
23
24
from langgraph .checkpoint .base import BaseCheckpointSaver , V
24
25
from langgraph .graph import END , StateGraph
31
32
from pydantic import BaseModel
32
33
33
34
from dive_mcp_host .host .agents .agent_factory import AgentFactory , initial_messages
35
+ from dive_mcp_host .host .agents .file_in_additional_kwargs import FileMsgConverter
34
36
from dive_mcp_host .host .agents .message_order import tool_call_order
35
37
from dive_mcp_host .host .agents .tools_in_prompt import (
36
38
convert_messages ,
37
39
extract_tool_calls ,
38
40
)
39
41
from dive_mcp_host .host .helpers import today_datetime
40
42
from dive_mcp_host .host .prompt import PromptType , tools_prompt
43
+ from dive_mcp_host .host .store .base import StoreManagerProtocol
41
44
42
- StructuredResponse = dict | BaseModel
43
- StructuredResponseSchema = dict | type [BaseModel ]
45
+ type StructuredResponse = dict | BaseModel
46
+ type StructuredResponseSchema = dict | type [BaseModel ]
44
47
45
48
46
49
class AgentState (MessagesState ):
@@ -92,6 +95,41 @@ def _func(state: AgentState | ChatPromptValue) -> list[BaseMessage]:
92
95
return prompt_runnable
93
96
94
97
98
+ class HackedToolNode (ToolNode ):
99
+ """hacked tool node to inject tool_call_id into the config.
100
+
101
+ This is a hack. If langgraph support tool_call_id, we will remove this class.
102
+ """
103
+
104
+ async def _arun_one (
105
+ self ,
106
+ call : ToolCall ,
107
+ input_type : Literal ["list" , "dict" , "tool_calls" ],
108
+ config : RunnableConfig ,
109
+ ) -> ToolMessage :
110
+ if "metadata" in config :
111
+ config ["metadata" ]["tool_call_id" ] = call ["id" ]
112
+ else :
113
+ config ["metadata" ] = {
114
+ "tool_call_id" : call ["id" ],
115
+ }
116
+ return await super ()._arun_one (call , input_type , config )
117
+
118
+ def _run_one (
119
+ self ,
120
+ call : ToolCall ,
121
+ input_type : Literal ["list" , "dict" , "tool_calls" ],
122
+ config : RunnableConfig ,
123
+ ) -> ToolMessage :
124
+ if "metadata" in config :
125
+ config ["metadata" ]["tool_call_id" ] = call ["id" ]
126
+ else :
127
+ config ["metadata" ] = {
128
+ "tool_call_id" : call ["id" ],
129
+ }
130
+ return super ()._run_one (call , input_type , config )
131
+
132
+
95
133
class ChatAgentFactory (AgentFactory [AgentState ]):
96
134
"""A factory for ChatAgents."""
97
135
@@ -100,6 +138,7 @@ def __init__(
100
138
model : BaseChatModel ,
101
139
tools : Sequence [BaseTool ] | ToolNode ,
102
140
tools_in_prompt : bool = False ,
141
+ store : StoreManagerProtocol | None = None ,
103
142
) -> None :
104
143
"""Initialize the chat agent factory."""
105
144
self ._model = model
@@ -110,6 +149,12 @@ def __init__(
110
149
StructuredResponseSchema | tuple [str , StructuredResponseSchema ] | None
111
150
) = None
112
151
152
+ self ._file_msg_converter = (
153
+ FileMsgConverter (model_provider = self ._model_class , store = store ).runnable
154
+ if store
155
+ else RunnablePassthrough ()
156
+ )
157
+
113
158
# changed when self._build_graph is called
114
159
self ._tool_classes : list [BaseTool ] = []
115
160
self ._should_return_direct : set [str ] = set ()
@@ -160,22 +205,28 @@ def _check_more_steps_needed(
160
205
)
161
206
)
162
207
163
- def _call_model (self , state : AgentState , config : RunnableConfig ) -> AgentState :
208
+ async def _call_model (
209
+ self , state : AgentState , config : RunnableConfig
210
+ ) -> AgentState :
211
+ # TODO: _validate_chat_history
164
212
if not self ._tools_in_prompt :
165
213
model = self ._model
166
214
if self ._tool_classes :
167
215
model = self ._model .bind_tools (self ._tool_classes )
168
- model_runnable = self ._prompt | drop_empty_messages | model
216
+ model_runnable = (
217
+ self ._prompt | self ._file_msg_converter | drop_empty_messages | model
218
+ )
169
219
else :
170
220
model_runnable = (
171
221
self ._prompt
172
222
| self ._tool_prompt
173
223
| convert_messages
224
+ | self ._file_msg_converter
174
225
| drop_empty_messages
175
226
| self ._model
176
227
)
177
228
178
- response = model_runnable .invoke (state , config )
229
+ response = await model_runnable .ainvoke (state , config )
179
230
if isinstance (response , AIMessage ):
180
231
response = extract_tool_calls (response )
181
232
if self ._check_more_steps_needed (state , response ):
@@ -254,7 +305,9 @@ def _build_graph(self) -> None:
254
305
graph .add_edge ("before_agent" , "agent" )
255
306
256
307
tool_node = (
257
- self ._tools if isinstance (self ._tools , ToolNode ) else ToolNode (self ._tools )
308
+ self ._tools
309
+ if isinstance (self ._tools , ToolNode )
310
+ else HackedToolNode (self ._tools )
258
311
)
259
312
self ._tool_classes = list (tool_node .tools_by_name .values ())
260
313
graph .add_node ("tools" , tool_node )
@@ -323,9 +376,15 @@ def get_chat_agent_factory(
323
376
model : BaseChatModel ,
324
377
tools : Sequence [BaseTool ] | ToolNode ,
325
378
tools_in_prompt : bool = False ,
379
+ store : StoreManagerProtocol | None = None ,
326
380
) -> ChatAgentFactory :
327
381
"""Get an agent factory."""
328
- return ChatAgentFactory (model , tools , tools_in_prompt )
382
+ return ChatAgentFactory (
383
+ model = model ,
384
+ tools = tools ,
385
+ tools_in_prompt = tools_in_prompt ,
386
+ store = store ,
387
+ )
329
388
330
389
331
390
@RunnableCallable
0 commit comments