Skip to content

Commit a07d2a7

Browse files
committed
Merge pull request 'feat: support gemini-2.5-flash-image-preview' (#474) from gemini-image into main
Reviewed-on: https://git.biggo.com/Funmula/dive-mcp-host/pulls/474
2 parents adfa0df + 1fb5126 commit a07d2a7

File tree

9 files changed

+170
-54
lines changed

9 files changed

+170
-54
lines changed

dive_mcp_host/host/store/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ async def get_file(self, file_path: str) -> bytes:
5656
class StoreManagerProtocol(ContextProtocol, Protocol):
5757
"""Protocol for store manager operations."""
5858

59+
async def save_base64_image(self, data: str, extension: str = "png") -> list[str]:
60+
"""Save base64 image.
61+
62+
Args:
63+
data: Image in base64
64+
extension: File extension
65+
66+
Returns:
67+
List of paths / urls
68+
"""
69+
...
70+
5971
async def upload_files(
6072
self, files: list[UploadFile | str]
6173
) -> tuple[list[str], list[str]]:

dive_mcp_host/httpd/_main.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,13 @@ def main() -> None:
4343
if args.cors_origin:
4444
service_config_manager.current_setting.cors_origin = args.cors_origin
4545

46-
service_config_manager.current_setting.logging_config["root"]["level"] = (
47-
args.log_level
48-
)
49-
service_config_manager.current_setting.logging_config["loggers"]["dive_mcp_host"][
50-
"level"
51-
] = args.log_level
46+
if args.log_level:
47+
service_config_manager.current_setting.logging_config["root"]["level"] = (
48+
args.log_level
49+
)
50+
service_config_manager.current_setting.logging_config["loggers"][
51+
"dive_mcp_host"
52+
]["level"] = args.log_level
5253

5354
if args.log_dir:
5455
log_dir = Path(args.log_dir)

dive_mcp_host/httpd/conf/arguments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ class Arguments(BaseModel):
102102
description="Directory to write log files.",
103103
)
104104

105-
log_level: str = Field(
106-
default="INFO",
105+
log_level: str | None = Field(
106+
default=None,
107107
description="Log level to use.",
108108
)
109109

dive_mcp_host/httpd/conf/httpd_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class ServiceConfig(BaseModel):
7272
}
7373
},
7474
"root": {"level": "INFO", "handlers": ["default"]},
75-
"loggers": {"dive_mcp_host": {"level": "DEBUG"}},
75+
"loggers": {"dive_mcp_host": {"level": "INFO"}},
7676
}
7777

7878

dive_mcp_host/httpd/routers/utils.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine
77
from contextlib import AsyncExitStack, suppress
88
from dataclasses import asdict, dataclass, field
9+
from hashlib import md5
910
from itertools import batched
1011
from pathlib import Path
1112
from typing import TYPE_CHECKING, Any, Literal, Self
1213
from urllib.parse import urlparse
1314
from uuid import uuid4
1415

1516
from fastapi.responses import StreamingResponse
16-
from langchain_core.language_models import BaseChatModel
1717
from langchain_core.messages import (
1818
AIMessage,
1919
BaseMessage,
@@ -33,7 +33,7 @@
3333
from dive_mcp_host.host.agents.message_order import FAKE_TOOL_RESPONSE
3434
from dive_mcp_host.host.custom_events import ToolCallProgress
3535
from dive_mcp_host.host.errors import LogBufferNotFoundError
36-
from dive_mcp_host.host.store.base import FileType
36+
from dive_mcp_host.host.store.base import FileType, StoreManagerProtocol
3737
from dive_mcp_host.host.tools.log import LogEvent, LogManager, LogMsg
3838
from dive_mcp_host.host.tools.model_types import ClientState
3939
from dive_mcp_host.httpd.conf.prompt import PromptKey
@@ -58,7 +58,6 @@
5858

5959
if TYPE_CHECKING:
6060
from dive_mcp_host.host.host import DiveMcpHost
61-
from dive_mcp_host.host.store.base import StoreManagerProtocol
6261
from dive_mcp_host.httpd.middlewares.general import DiveUser
6362

6463
title_prompt = """You are a title generator from the user input.
@@ -185,31 +184,54 @@ class ContentHandler:
185184

186185
def __init__(
187186
self,
188-
model: BaseChatModel,
189-
str_output_parser: StrOutputParser,
187+
store: StoreManagerProtocol,
190188
) -> None:
191-
"""Initialize ContentHandler
192-
193-
Args:
194-
- model: To verify which model it is.
195-
- str_output_parser: Used for extracting text content from AIMessage.
196-
"""
197-
self._model = model
198-
self._str_output_parser = str_output_parser
189+
"""Initialize ContentHandler."""
190+
self._store = store
191+
self._str_output_parser = StrOutputParser()
192+
# Cache that contains the md5 hash and file path / urls for the file.
193+
# Prevents duplicate save / uploads.
194+
self._cache: dict[str, list[str]] = {}
199195

200-
def invoke(self, msg: AIMessage) -> str:
201-
"""Extract content from AIMessage."""
196+
async def invoke(self, msg: AIMessage) -> str:
197+
"""Extract various types of content."""
202198
result = self._text_content(msg)
199+
model_name = msg.response_metadata.get("model_name")
203200

204-
if self._model.name in {"gemini-2.5-flash-image-preview"}:
205-
result = f"{result} {self._gemini_25_image(msg)}"
201+
if model_name in {"gemini-2.5-flash-image-preview"}:
202+
result = f"{result} {await self._gemini_25_image(msg)}"
206203

207204
return result
208205

209206
def _text_content(self, msg: AIMessage) -> str:
210207
return self._str_output_parser.invoke(msg)
211208

212-
def _gemini_25_image(self, msg: AIMessage) -> str:
209+
async def _save_with_cache(self, data: str) -> list[str]:
210+
"""Prevents duplicate save and uploads.
211+
212+
Returns:
213+
Saved locations, 'local file path' or 'url'
214+
"""
215+
md5_hash = md5(data.encode(), usedforsecurity=False).hexdigest()
216+
locations = self._cache.get(md5_hash)
217+
if not locations:
218+
locations = await self._store.save_base64_image(data)
219+
self._cache[md5_hash] = locations
220+
return locations
221+
222+
def _retrive_optimal_location(self, locations: list[str]) -> str:
223+
"""Prioritize urls, prevents broken image in case we need to sync
224+
user chat history someday.
225+
""" # noqa: D205
226+
url = locations[0]
227+
for item in locations[1:]:
228+
if self._store.is_url(item):
229+
url = item
230+
if self._store.is_local_file(url):
231+
url = f"file://{url}"
232+
return url
233+
234+
async def _gemini_25_image(self, msg: AIMessage) -> str:
213235
"""Gemini will return base64 image content.
214236
215237
{
@@ -230,10 +252,14 @@ def _gemini_25_image(self, msg: AIMessage) -> str:
230252
if (
231253
isinstance(content, dict)
232254
and (image_url := content.get("image_url"))
233-
and (url := image_url.get("url"))
255+
and (inline_base64 := image_url.get("url"))
234256
):
235-
markdown_image_tag = f"![image]({url})"
236-
result = f"{result} {markdown_image_tag}"
257+
base64_data: str = inline_base64.split(",")[-1]
258+
assert isinstance(base64_data, str), "base64_data must be string"
259+
locations = await self._save_with_cache(base64_data)
260+
url = self._retrive_optimal_location(locations)
261+
image_tag = f"![image]({url})"
262+
result = f"{result} {image_tag}"
237263

238264
return result
239265

@@ -254,9 +280,7 @@ def __init__(
254280
self.store: StoreManagerProtocol = app.store
255281
self.dive_host: DiveMcpHost = app.dive_host["default"]
256282
self._str_output_parser = StrOutputParser()
257-
self._content_handler = ContentHandler(
258-
self.dive_host.model, self._str_output_parser
259-
)
283+
self._content_handler = ContentHandler(self.store)
260284
self.disable_dive_system_prompt = (
261285
app.model_config_manager.full_config.disable_dive_system_prompt
262286
if app.model_config_manager.full_config
@@ -395,7 +419,7 @@ async def handle_chat( # noqa: C901, PLR0912, PLR0915
395419
total_run_time=duration,
396420
)
397421
result = (
398-
self._str_output_parser.invoke(message)
422+
await self._content_handler.invoke(message)
399423
if message.content
400424
else ""
401425
)
@@ -550,7 +574,7 @@ def _prompt_cb(_: Any) -> list[BaseMessage]:
550574
raise RuntimeError("Unreachable")
551575

552576
async def _stream_text_msg(self, message: AIMessage) -> None:
553-
content = self._content_handler.invoke(message)
577+
content = await self._content_handler.invoke(message)
554578
if content:
555579
await self.stream.write(StreamMessage(type="text", content=content))
556580
if message.response_metadata.get("stop_reason") == "max_tokens":

dive_mcp_host/httpd/store/local.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,26 @@ def __init__(self, root_dir: Path = RESOURCE_DIR) -> None:
2222
upload_dir.mkdir(parents=True, exist_ok=True)
2323
self.upload_dir = upload_dir
2424

25+
def save_base64_image(self, base64_str: str, extension: str = "png") -> str:
26+
"""Save base64 image to file.
27+
28+
Args:
29+
base64_str: Image in base64
30+
extension: File extension
31+
32+
Returns:
33+
File path to image file
34+
"""
35+
base64_bytes = BytesIO(base64.b64decode(base64_str))
36+
pil_image = Image.open(base64_bytes)
37+
file_name = f"{self._gen_rand_str()}.{extension}"
38+
file_path = self.upload_dir / file_name
39+
pil_image.save(file_path)
40+
return str(file_path)
41+
42+
def _gen_rand_str(self) -> str:
43+
return f"{int(time.time() * 1000)}-{randint(0, int(1e9))}" # noqa: S311
44+
2545
async def save_file(
2646
self,
2747
file: UploadFile | str,
@@ -35,7 +55,7 @@ async def save_file(
3555

3656
ext = Path(file.filename).suffix
3757

38-
tmp_name = f"{int(time.time() * 1000)}-{randint(0, int(1e9))}{ext}" # noqa: S311
58+
tmp_name = f"{self._gen_rand_str()}{ext}"
3959
tmp_file = self.upload_dir.joinpath(tmp_name)
4060

4161
hash_md5 = md5(usedforsecurity=False)

dive_mcp_host/httpd/store/manager.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,30 @@ async def _run_in_context(self) -> AsyncGenerator[Self, None]:
7676
self._storages.append(store)
7777
yield self
7878

79+
async def save_base64_image(self, data: str, extension: str = "png") -> list[str]:
80+
"""Save base64 image.
81+
82+
Args:
83+
data: Image in base64
84+
extension: File extension
85+
86+
Returns:
87+
List of paths / urls
88+
"""
89+
path = self._local_store.save_base64_image(data, extension)
90+
additional_paths = await self._run_storage_callbacks(path)
91+
return [path, *additional_paths]
92+
93+
async def _run_storage_callbacks(self, file: UploadFile | str) -> list[str]:
94+
tasks: list[asyncio.Task] = []
95+
async with asyncio.TaskGroup() as tg:
96+
for store in self._storages:
97+
tasks.append(tg.create_task(store.save_file(file)))
98+
return [i.result() for i in tasks if i.result()]
99+
79100
async def save_files(
80-
self, files: list[UploadFile | str]
101+
self,
102+
files: list[UploadFile | str],
81103
) -> list[tuple[FileType, list[str]]]:
82104
"""Save files to the stores.
83105
@@ -95,11 +117,7 @@ async def save_files(
95117
continue
96118
paths = [path]
97119
if self._storage_callbacks:
98-
tasks: list[asyncio.Task] = []
99-
async with asyncio.TaskGroup() as tg:
100-
for store in self._storages:
101-
tasks.append(tg.create_task(store.save_file(file)))
102-
additional_paths = [i.result() for i in tasks if i.result()]
120+
additional_paths = await self._run_storage_callbacks(file)
103121
paths.extend(additional_paths)
104122
all_paths.append((FileType.from_file_path(path), paths))
105123
return all_paths

tests/httpd/test_chat_processor.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
1+
import tempfile
12
import uuid
23
from collections.abc import AsyncGenerator
3-
from typing import Any, cast
4+
from hashlib import md5
5+
from typing import TYPE_CHECKING, Any, cast
6+
from unittest.mock import AsyncMock
47

58
import pytest
69
import pytest_asyncio
710
from langchain_core.messages import AIMessage, HumanMessage
8-
from langchain_core.output_parsers import StrOutputParser
911

1012
from dive_mcp_host.httpd.conf.httpd_service import ServiceManager
1113
from dive_mcp_host.httpd.conf.mcp_servers import Config
1214
from dive_mcp_host.httpd.conf.prompt import PromptKey
1315
from dive_mcp_host.httpd.routers.utils import ChatProcessor, ContentHandler
1416
from dive_mcp_host.httpd.server import DiveHostAPI
15-
from dive_mcp_host.models.fake import FakeMessageToolModel, load_model
17+
from dive_mcp_host.httpd.store.manager import StoreManager
1618
from tests.httpd.routers.conftest import config_files # noqa: F401
1719

20+
if TYPE_CHECKING:
21+
from dive_mcp_host.models.fake import FakeMessageToolModel
22+
1823

1924
@pytest_asyncio.fixture
2025
async def server(config_files) -> AsyncGenerator[DiveHostAPI, None]: # noqa: F811
@@ -113,22 +118,59 @@ async def test_generate_title(processor: ChatProcessor):
113118
assert r == "Simple Greeting 2"
114119

115120

116-
def test_content_handler_gemini_image():
121+
@pytest.mark.asyncio
122+
async def test_content_handler_gemini_image_with_url():
117123
"""Check if content handler can extract what is needed."""
118-
model = load_model()
119-
model.name = "gemini-2.5-flash-image-preview"
120-
content_handler = ContentHandler(model, StrOutputParser())
124+
store = StoreManager()
125+
store.save_base64_image = AsyncMock(
126+
return_value=["/some/path", "http://someurl.com"]
127+
)
128+
content_handler = ContentHandler(store)
121129
message = AIMessage(
122130
content=[
123131
"Here is a cuddly cat wearing a hat! ",
124132
{
125133
"type": "image_url",
126134
"image_url": {"url": "data:image/png;base64,XXXXXXXX"},
127135
},
128-
]
136+
],
137+
response_metadata={"model_name": "gemini-2.5-flash-image-preview"},
129138
)
130-
content = content_handler.invoke(message)
139+
content = await content_handler.invoke(message)
131140
assert (
132-
content
133-
== "Here is a cuddly cat wearing a hat! ![image](data:image/png;base64,XXXXXXXX)" # noqa: E501
141+
content == "Here is a cuddly cat wearing a hat! ![image](http://someurl.com)"
134142
)
143+
144+
# Cache should exist
145+
md5_hash = md5(b"XXXXXXXX", usedforsecurity=False).hexdigest()
146+
assert md5_hash in content_handler._cache
147+
assert content_handler._cache[md5_hash] == ["/some/path", "http://someurl.com"]
148+
149+
150+
@pytest.mark.asyncio
151+
async def test_content_handler_gemini_image_with_local_file():
152+
"""Make sure local file also works."""
153+
with tempfile.NamedTemporaryFile(prefix="dummyfile-") as f:
154+
store = StoreManager()
155+
store.save_base64_image = AsyncMock(return_value=[f.name])
156+
content_handler = ContentHandler(store)
157+
message = AIMessage(
158+
content=[
159+
"Here is a cuddly cat wearing a hat! ",
160+
{
161+
"type": "image_url",
162+
"image_url": {"url": "data:image/png;base64,XXXXXXXX"},
163+
},
164+
],
165+
response_metadata={"model_name": "gemini-2.5-flash-image-preview"},
166+
)
167+
content = await content_handler.invoke(message)
168+
assert (
169+
content
170+
== f"Here is a cuddly cat wearing a hat! ![image](file://{f.name})"
171+
)
172+
173+
# Cache should exist
174+
md5_hash = md5(b"XXXXXXXX", usedforsecurity=False).hexdigest()
175+
assert md5_hash in content_handler._cache
176+
assert content_handler._cache[md5_hash] == [f.name]

tests/httpd/test_content_handler.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)