 from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine
 from contextlib import AsyncExitStack, suppress
 from dataclasses import asdict, dataclass, field
+from hashlib import md5
 from itertools import batched
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal, Self
 from urllib.parse import urlparse
 from uuid import uuid4

 from fastapi.responses import StreamingResponse
-from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -33,7 +33,7 @@
 from dive_mcp_host.host.agents.message_order import FAKE_TOOL_RESPONSE
 from dive_mcp_host.host.custom_events import ToolCallProgress
 from dive_mcp_host.host.errors import LogBufferNotFoundError
-from dive_mcp_host.host.store.base import FileType
+from dive_mcp_host.host.store.base import FileType, StoreManagerProtocol
 from dive_mcp_host.host.tools.log import LogEvent, LogManager, LogMsg
 from dive_mcp_host.host.tools.model_types import ClientState
 from dive_mcp_host.httpd.conf.prompt import PromptKey
@@ -58,7 +58,6 @@

 if TYPE_CHECKING:
     from dive_mcp_host.host.host import DiveMcpHost
-    from dive_mcp_host.host.store.base import StoreManagerProtocol
     from dive_mcp_host.httpd.middlewares.general import DiveUser

 title_prompt = """You are a title generator from the user input.
@@ -185,31 +184,54 @@ class ContentHandler:

     def __init__(
         self,
-        model: BaseChatModel,
-        str_output_parser: StrOutputParser,
+        store: StoreManagerProtocol,
     ) -> None:
-        """Initialize ContentHandler
-
-        Args:
-        - model: To verify which model it is.
-        - str_output_parser: Used for extracting text content from AIMessage.
-        """
-        self._model = model
-        self._str_output_parser = str_output_parser
+ """Initialize ContentHandler."""
190
+ self ._store = store
191
+ self ._str_output_parser = StrOutputParser ()
192
+ # Cache that contains the md5 hash and file path / urls for the file.
193
+ # Prevents dupicate save / uploads.
194
+ self ._cache : dict [str , list [str ]] = {}

-    def invoke(self, msg: AIMessage) -> str:
-        """Extract content from AIMessage."""
+    async def invoke(self, msg: AIMessage) -> str:
+        """Extract various types of content."""
         result = self._text_content(msg)
+        model_name = msg.response_metadata.get("model_name")

-        if self._model.name in {"gemini-2.5-flash-image-preview"}:
-            result = f"{result}{self._gemini_25_image(msg)}"
+        if model_name in {"gemini-2.5-flash-image-preview"}:
+            result = f"{result}{await self._gemini_25_image(msg)}"

         return result

     def _text_content(self, msg: AIMessage) -> str:
         return self._str_output_parser.invoke(msg)

-    def _gemini_25_image(self, msg: AIMessage) -> str:
+    async def _save_with_cache(self, data: str) -> list[str]:
+        """Prevent duplicate saves and uploads.
+
+        Returns:
+            Saved locations: local file paths and/or URLs.
+        """
+        md5_hash = md5(data.encode(), usedforsecurity=False).hexdigest()
+        locations = self._cache.get(md5_hash)
+        if not locations:
+            locations = await self._store.save_base64_image(data)
+            self._cache[md5_hash] = locations
+        return locations
+
+    def _retrieve_optimal_location(self, locations: list[str]) -> str:
+        """Prefer URLs over local paths, so images don't break if user
+        chat history is synced across devices some day.
+        """  # noqa: D205
+        url = locations[0]
+        for item in locations[1:]:
+            if self._store.is_url(item):
+                url = item
+        if self._store.is_local_file(url):
+            url = f"file://{url}"
+        return url
+
+    async def _gemini_25_image(self, msg: AIMessage) -> str:
         """Gemini will return base64 image content.

         {
@@ -230,10 +252,14 @@ def _gemini_25_image(self, msg: AIMessage) -> str:
             if (
                 isinstance(content, dict)
                 and (image_url := content.get("image_url"))
-                and (url := image_url.get("url"))
+                and (inline_base64 := image_url.get("url"))
             ):
-                markdown_image_tag = f"![image]({url})"
-                result = f"{result}{markdown_image_tag}"
+                base64_data: str = inline_base64.split(",")[-1]
+                assert isinstance(base64_data, str), "base64_data must be string"
+                locations = await self._save_with_cache(base64_data)
+                url = self._retrieve_optimal_location(locations)
+                image_tag = f"![image]({url})"
+                result = f"{result}{image_tag}"

         return result

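For reviewers: a minimal, hypothetical sketch of how the reworked `ContentHandler` behaves end to end, assuming the diff above is applied. `FakeStore` is an illustrative stand-in for `StoreManagerProtocol` that stubs only the three members the handler touches (`save_base64_image`, `is_url`, `is_local_file`); the data URL payload is dummy data.

```python
import asyncio

from langchain_core.messages import AIMessage


class FakeStore:
    """Hypothetical StoreManagerProtocol stub, for illustration only."""

    async def save_base64_image(self, data: str) -> list[str]:
        # Pretend the image was persisted locally and uploaded.
        return ["/tmp/image.png", "https://cdn.example.com/image.png"]

    def is_url(self, location: str) -> bool:
        return location.startswith(("http://", "https://"))

    def is_local_file(self, location: str) -> bool:
        return not self.is_url(location)


async def main() -> None:
    handler = ContentHandler(FakeStore())
    msg = AIMessage(
        content=[
            {"type": "text", "text": "Here is your image."},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}},
        ],
        response_metadata={"model_name": "gemini-2.5-flash-image-preview"},
    )
    # Prints the text block followed by a markdown image tag.
    print(await handler.invoke(msg))


asyncio.run(main())
```

Because the stubbed `save_base64_image` reports both a local path and a URL, `_retrieve_optimal_location` picks the URL, so the emitted markdown tag stays valid even when the local file is not reachable.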
@@ -254,9 +280,7 @@ def __init__(
         self.store: StoreManagerProtocol = app.store
         self.dive_host: DiveMcpHost = app.dive_host["default"]
         self._str_output_parser = StrOutputParser()
-        self._content_handler = ContentHandler(
-            self.dive_host.model, self._str_output_parser
-        )
+        self._content_handler = ContentHandler(self.store)
         self.disable_dive_system_prompt = (
             app.model_config_manager.full_config.disable_dive_system_prompt
             if app.model_config_manager.full_config
@@ -395,7 +419,7 @@ async def handle_chat( # noqa: C901, PLR0912, PLR0915
             total_run_time=duration,
         )
         result = (
-            self._str_output_parser.invoke(message)
+            await self._content_handler.invoke(message)
             if message.content
             else ""
         )
@@ -550,7 +574,7 @@ def _prompt_cb(_: Any) -> list[BaseMessage]:
         raise RuntimeError("Unreachable")

     async def _stream_text_msg(self, message: AIMessage) -> None:
-        content = self._content_handler.invoke(message)
+        content = await self._content_handler.invoke(message)
         if content:
             await self.stream.write(StreamMessage(type="text", content=content))
         if message.response_metadata.get("stop_reason") == "max_tokens":
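The md5 cache added in `_save_with_cache` exists to keep identical Gemini image payloads from being saved or uploaded twice within a handler's lifetime. A test-style sketch of that property, again assuming the diff above is applied and using a hypothetical counting stub rather than the real store:

```python
import asyncio


class CountingStore:
    """Hypothetical stub that counts writes, for illustration only."""

    def __init__(self) -> None:
        self.writes = 0

    async def save_base64_image(self, data: str) -> list[str]:
        self.writes += 1
        return [f"/tmp/image-{self.writes}.png"]

    def is_url(self, location: str) -> bool:
        return location.startswith(("http://", "https://"))

    def is_local_file(self, location: str) -> bool:
        return not self.is_url(location)


async def main() -> None:
    store = CountingStore()
    handler = ContentHandler(store)
    # Identical payloads hash to the same md5 key, so the store is hit once.
    first = await handler._save_with_cache("iVBORw0KGgo=")
    second = await handler._save_with_cache("iVBORw0KGgo=")
    assert first == second
    assert store.writes == 1


asyncio.run(main())
```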