39
39
40
40
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
41
41
from mcp .shared .message import SessionMessage
42
+ from mcp .shared .session import RequestResponder
42
43
43
44
from dive_mcp_host .host .conf import ServerConfig
44
45
@@ -141,6 +142,28 @@ def __init__(
141
142
else :
142
143
raise InvalidMcpServerError (self .config .name , "Invalid server config" )
143
144
145
+ async def _message_handler (
146
+ self ,
147
+ message : RequestResponder [types .ServerRequest , types .ClientResult ]
148
+ | types .ServerNotification
149
+ | Exception ,
150
+ ) -> None :
151
+ """Used for handling mcp special responses.
152
+
153
+ Such as:
154
+ - Exception (Literal python exception)
155
+ - ProgressResult (ServerNotification) ... etc
156
+ """
157
+ logger .info (
158
+ "handling message for %s, type: %s, content: %s" ,
159
+ self .name ,
160
+ type (message ).__name__ ,
161
+ message ,
162
+ )
163
+
164
+ if isinstance (message , Exception ):
165
+ raise message
166
+
144
167
async def _init_tool_info (self , session : ClientSession ) -> None :
145
168
"""Initialize the session."""
146
169
async with asyncio .timeout (10 ):
@@ -331,7 +354,9 @@ async def _stdio_client_watcher(self) -> None: # noqa: C901, PLR0915
331
354
),
332
355
errlog = self ._stderr_log_proxy ,
333
356
) as (stream_read , stream_send , pid ),
334
- ClientSession (stream_read , stream_send ) as session ,
357
+ ClientSession (
358
+ stream_read , stream_send , message_handler = self ._message_handler
359
+ ) as session ,
335
360
):
336
361
self ._session = session
337
362
self ._pid = pid
@@ -363,6 +388,7 @@ async def _stdio_client_watcher(self) -> None: # noqa: C901, PLR0915
363
388
httpx .ConnectError ,
364
389
httpx .InvalidURL ,
365
390
httpx .TooManyRedirects ,
391
+ httpx .ConnectTimeout ,
366
392
) as eg :
367
393
err_msg = (
368
394
f"Client initialization error for { self .name } : { eg .exceptions } "
@@ -550,7 +576,6 @@ def _http_get_client(
550
576
key : value .get_secret_value ()
551
577
for key , value in self .config .headers .items ()
552
578
},
553
- sse_read_timeout = 0.1 ,
554
579
)
555
580
if self .config .transport == "websocket" :
556
581
return websocket_client (
@@ -564,7 +589,7 @@ async def _http_init_client(self) -> None:
564
589
"""Initialize the HTTP client."""
565
590
async with (
566
591
self ._http_get_client () as streams ,
567
- ClientSession (* streams ) as session ,
592
+ ClientSession (* streams , message_handler = self . _message_handler ) as session ,
568
593
):
569
594
await self ._init_tool_info (session )
570
595
@@ -639,7 +664,9 @@ async def session_ctx() -> AsyncGenerator[ClientSession, None]:
639
664
"""
640
665
async with (
641
666
self ._http_get_client () as streams ,
642
- ClientSession (* streams ) as session ,
667
+ ClientSession (
668
+ * streams , message_handler = self ._message_handler
669
+ ) as session ,
643
670
self ._session_wrapper (),
644
671
):
645
672
await session .initialize ()
@@ -750,7 +777,9 @@ async def session_ctx() -> AsyncGenerator[ClientSession, None]:
750
777
"""
751
778
async with (
752
779
self ._http_get_client () as streams ,
753
- ClientSession (* streams ) as session ,
780
+ ClientSession (
781
+ * streams , message_handler = self ._message_handler
782
+ ) as session ,
754
783
self ._session_wrapper (
755
784
restart_client = lambda e : isinstance (e , httpx .ConnectError )
756
785
),
0 commit comments