Skip to content

Commit 0f1ccbf

Browse files
committed
Merge pull request 'chore: support gemini-2.5-flash-image-preview' (#473) from gemini-image into development
Reviewed-on: https://git.biggo.com/Funmula/dive-mcp-host/pulls/473
2 parents c898906 + df9a62a commit 0f1ccbf

File tree

4 files changed

+117
-3
lines changed

4 files changed

+117
-3
lines changed

dive_mcp_host/httpd/routers/utils.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from uuid import uuid4
1414

1515
from fastapi.responses import StreamingResponse
16+
from langchain_core.language_models import BaseChatModel
1617
from langchain_core.messages import (
1718
AIMessage,
1819
BaseMessage,
@@ -176,6 +177,67 @@ class ImageAndDocuments:
176177
documents: list[str] = field(default_factory=list)
177178

178179

180+
class ContentHandler:
181+
"""Some models will return more then just pure text in content response.
182+
183+
We need to have a customized handler for those special models.
184+
"""
185+
186+
def __init__(
187+
self,
188+
model: BaseChatModel,
189+
str_output_parser: StrOutputParser,
190+
) -> None:
191+
"""Initialize ContentHandler
192+
193+
Args:
194+
- model: To verify which model it is.
195+
- str_output_parser: Used for extracting text content from AIMessage.
196+
"""
197+
self._model = model
198+
self._str_output_parser = str_output_parser
199+
200+
def invoke(self, msg: AIMessage) -> str:
201+
"""Extract content from AIMessage."""
202+
result = self._text_content(msg)
203+
204+
if self._model.name in {"gemini-2.5-flash-image-preview"}:
205+
result = f"{result} {self._gemini_25_image(msg)}"
206+
207+
return result
208+
209+
def _text_content(self, msg: AIMessage) -> str:
210+
return self._str_output_parser.invoke(msg)
211+
212+
def _gemini_25_image(self, msg: AIMessage) -> str:
213+
"""Gemini will return base64 image content.
214+
215+
{
216+
"content": [
217+
"Here is a cuddly cat wearing a hat! ",
218+
{
219+
"type": "image_url",
220+
"image_url": {
221+
"url": ""
222+
}
223+
}
224+
]
225+
}
226+
227+
"""
228+
result = ""
229+
for content in msg.content:
230+
if (
231+
isinstance(content, dict)
232+
and (image_url := content.get("image_url"))
233+
and (url := image_url.get("url"))
234+
):
235+
markdown_image_tag = f"![image]({url})"
236+
result = f"{result} {markdown_image_tag}"
237+
238+
return result
239+
240+
179241
class ChatProcessor:
180242
"""Chat processor."""
181243

@@ -192,6 +254,9 @@ def __init__(
192254
self.store: StoreManagerProtocol = app.store
193255
self.dive_host: DiveMcpHost = app.dive_host["default"]
194256
self._str_output_parser = StrOutputParser()
257+
self._content_handler = ContentHandler(
258+
self.dive_host.model, self._str_output_parser
259+
)
195260
self.disable_dive_system_prompt = (
196261
app.model_config_manager.full_config.disable_dive_system_prompt
197262
if app.model_config_manager.full_config
@@ -485,7 +550,7 @@ def _prompt_cb(_: Any) -> list[BaseMessage]:
485550
raise RuntimeError("Unreachable")
486551

487552
async def _stream_text_msg(self, message: AIMessage) -> None:
488-
content = self._str_output_parser.invoke(message)
553+
content = self._content_handler.invoke(message)
489554
if content:
490555
await self.stream.write(StreamMessage(type="text", content=content))
491556
if message.response_metadata.get("stop_reason") == "max_tokens":

tests/httpd/test_chat_processor.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import pytest
66
import pytest_asyncio
77
from langchain_core.messages import AIMessage, HumanMessage
8+
from langchain_core.output_parsers import StrOutputParser
89

910
from dive_mcp_host.httpd.conf.httpd_service import ServiceManager
1011
from dive_mcp_host.httpd.conf.mcp_servers import Config
1112
from dive_mcp_host.httpd.conf.prompt import PromptKey
12-
from dive_mcp_host.httpd.routers.utils import ChatProcessor
13+
from dive_mcp_host.httpd.routers.utils import ChatProcessor, ContentHandler
1314
from dive_mcp_host.httpd.server import DiveHostAPI
14-
from dive_mcp_host.models.fake import FakeMessageToolModel # noqa: TC001
15+
from dive_mcp_host.models.fake import FakeMessageToolModel, load_model
1516
from tests.httpd.routers.conftest import config_files # noqa: F401
1617

1718

@@ -110,3 +111,24 @@ async def test_generate_title(processor: ChatProcessor):
110111
assert r == "Simple Greeting"
111112
r = await processor._generate_title("Hello, how are you?")
112113
assert r == "Simple Greeting 2"
114+
115+
116+
def test_content_handler_gemini_image():
117+
"""Check if content handler can extract what is needed."""
118+
model = load_model()
119+
model.name = "gemini-2.5-flash-image-preview"
120+
content_handler = ContentHandler(model, StrOutputParser())
121+
message = AIMessage(
122+
content=[
123+
"Here is a cuddly cat wearing a hat! ",
124+
{
125+
"type": "image_url",
126+
"image_url": {"url": ""},
127+
},
128+
]
129+
)
130+
content = content_handler.invoke(message)
131+
assert (
132+
content
133+
== "Here is a cuddly cat wearing a hat! ![image]()" # noqa: E501
134+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

tests/test_providers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,32 @@ async def test_host_google(echo_tool_stdio_config: dict[str, ServerConfig]) -> N
186186
await _run_the_test(config)
187187

188188

189+
@pytest.mark.asyncio
190+
async def test_host_google_image_gen(
191+
echo_tool_stdio_config: dict[str, ServerConfig],
192+
) -> None:
193+
"""Test the host context initialization."""
194+
if api_key := environ.get("GOOGLE_API_KEY"):
195+
config = HostConfig(
196+
llm=LLMConfig(
197+
model="gemini-2.5-flash-image-preview",
198+
model_provider="google-genai",
199+
api_key=SecretStr(api_key),
200+
configuration=LLMConfiguration(
201+
temperature=0.0,
202+
top_p=0,
203+
),
204+
tools_in_prompt=True,
205+
disable_streaming=True,
206+
),
207+
mcp_servers=echo_tool_stdio_config,
208+
)
209+
else:
210+
pytest.skip("need environment variable GOOGLE_API_KEY to run this test")
211+
212+
await _run_the_test(config)
213+
214+
189215
@pytest.mark.asyncio
190216
async def test_bedrock(echo_tool_stdio_config: dict[str, ServerConfig]) -> None:
191217
"""Test the host context initialization."""

0 commit comments

Comments
 (0)