13
13
from uuid import uuid4
14
14
15
15
from fastapi .responses import StreamingResponse
16
+ from langchain_core .language_models import BaseChatModel
16
17
from langchain_core .messages import (
17
18
AIMessage ,
18
19
BaseMessage ,
@@ -176,6 +177,67 @@ class ImageAndDocuments:
176
177
documents : list [str ] = field (default_factory = list )
177
178
178
179
180
+ class ContentHandler :
181
+ """Some models will return more then just pure text in content response.
182
+
183
+ We need to have a customized handler for those special models.
184
+ """
185
+
186
+ def __init__ (
187
+ self ,
188
+ model : BaseChatModel ,
189
+ str_output_parser : StrOutputParser ,
190
+ ) -> None :
191
+ """Initialize ContentHandler
192
+
193
+ Args:
194
+ - model: To verify which model it is.
195
+ - str_output_parser: Used for extracting text content from AIMessage.
196
+ """
197
+ self ._model = model
198
+ self ._str_output_parser = str_output_parser
199
+
200
+ def invoke (self , msg : AIMessage ) -> str :
201
+ """Extract content from AIMessage."""
202
+ result = self ._text_content (msg )
203
+
204
+ if self ._model .name in {"gemini-2.5-flash-image-preview" }:
205
+ result = f"{ result } { self ._gemini_25_image (msg )} "
206
+
207
+ return result
208
+
209
+ def _text_content (self , msg : AIMessage ) -> str :
210
+ return self ._str_output_parser .invoke (msg )
211
+
212
+ def _gemini_25_image (self , msg : AIMessage ) -> str :
213
+ """Gemini will return base64 image content.
214
+
215
+ {
216
+ "content": [
217
+ "Here is a cuddly cat wearing a hat! ",
218
+ {
219
+ "type": "image_url",
220
+ "image_url": {
221
+ "url": "data:image/png;base64,XXXXXXXX"
222
+ }
223
+ }
224
+ ]
225
+ }
226
+
227
+ """
228
+ result = ""
229
+ for content in msg .content :
230
+ if (
231
+ isinstance (content , dict )
232
+ and (image_url := content .get ("image_url" ))
233
+ and (url := image_url .get ("url" ))
234
+ ):
235
+ markdown_image_tag = f""
236
+ result = f"{ result } { markdown_image_tag } "
237
+
238
+ return result
239
+
240
+
179
241
class ChatProcessor :
180
242
"""Chat processor."""
181
243
@@ -192,6 +254,9 @@ def __init__(
192
254
self .store : StoreManagerProtocol = app .store
193
255
self .dive_host : DiveMcpHost = app .dive_host ["default" ]
194
256
self ._str_output_parser = StrOutputParser ()
257
+ self ._content_handler = ContentHandler (
258
+ self .dive_host .model , self ._str_output_parser
259
+ )
195
260
self .disable_dive_system_prompt = (
196
261
app .model_config_manager .full_config .disable_dive_system_prompt
197
262
if app .model_config_manager .full_config
@@ -485,7 +550,7 @@ def _prompt_cb(_: Any) -> list[BaseMessage]:
485
550
raise RuntimeError ("Unreachable" )
486
551
487
552
async def _stream_text_msg (self , message : AIMessage ) -> None :
488
- content = self ._str_output_parser .invoke (message )
553
+ content = self ._content_handler .invoke (message )
489
554
if content :
490
555
await self .stream .write (StreamMessage (type = "text" , content = content ))
491
556
if message .response_metadata .get ("stop_reason" ) == "max_tokens" :
0 commit comments