
Commit a877c62

feat: support image input to LLM clients (#653)
* feat: support image as input in ExtractByLlm
* feat: add support for dynamic image MIME type detection
* refactor: update infer initialization in image MIME type detection
* feat: use serde for base64 image handling in Ollama
* feat: polish system prompt to support image input handling
* feat: update argument handling logic
* feat: refactor base64 encoding to use prelude for image handling
* feat: make image captioning support with Ollama integration optional
* feat: add optional envvar for checking Ollama model
* feat: update Ollama model handling to allow dynamic model specification
1 parent 03b472c commit a877c62

File tree

10 files changed: +690 / -545 lines changed

Cargo.lock

Lines changed: 498 additions & 509 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -115,3 +115,5 @@ aws-config = "1.6.2"
 aws-sdk-s3 = "1.85.0"
 aws-sdk-sqs = "1.67.0"
 numpy = "0.25.0"
+infer = "0.19.0"
+serde_with = { version = "3.13.0", features = ["base64"] }

examples/image_search/README.md

Lines changed: 11 additions & 1 deletion
@@ -13,6 +13,7 @@ We appreciate a star ⭐ at [CocoIndex Github](https://github.com/cocoindex-io/c
 - CLIP ViT-L/14 - Embeddings Model for images and query
 - Qdrant for Vector Storage
 - FastApi for backend
+- Ollama (Optional) for generating image captions using `gemma3`.
 
 ## Setup
 - Make sure Postgres and Qdrant are running
@@ -21,7 +22,16 @@ We appreciate a star ⭐ at [CocoIndex Github](https://github.com/cocoindex-io/c
 export COCOINDEX_DATABASE_URL="postgres://cocoindex:cocoindex@localhost/cocoindex"
 ```
 
-## Run
+## (Optional) Run Ollama
+
+- This enables automatic image captioning
+```
+ollama pull gemma3
+ollama serve
+export OLLAMA_MODEL="gemma3" # Optional, for caption generation
+```
+
+## Run the App
 - Install dependencies:
 ```
 pip install -e .

examples/image_search/main.py

Lines changed: 41 additions & 6 deletions
@@ -15,6 +15,7 @@
 from qdrant_client import QdrantClient
 from transformers import CLIPModel, CLIPProcessor
 
+OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/")
 QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6334/")
 QDRANT_COLLECTION = "ImageSearch"
 CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
@@ -69,12 +70,39 @@ def image_object_embedding_flow(
     )
     img_embeddings = data_scope.add_collector()
     with data_scope["images"].row() as img:
+        ollama_model_name = os.getenv("OLLAMA_MODEL")
+        if ollama_model_name is not None:
+            # If an Ollama model is specified, generate an image caption
+            img["caption"] = flow_builder.transform(
+                cocoindex.functions.ExtractByLlm(
+                    llm_spec=cocoindex.llm.LlmSpec(
+                        api_type=cocoindex.LlmApiType.OLLAMA, model=ollama_model_name
+                    ),
+                    instruction=(
+                        "Describe the image in one detailed sentence. "
+                        "Name all visible animal species, objects, and the main scene. "
+                        "Be specific about type, color, and notable features. "
+                        "Mention what each animal is doing."
+                    ),
+                    output_type=str,
+                ),
+                image=img["content"],
+            )
         img["embedding"] = img["content"].transform(embed_image)
-        img_embeddings.collect(
-            id=cocoindex.GeneratedField.UUID,
-            filename=img["filename"],
-            embedding=img["embedding"],
-        )
+
+        collect_fields = {
+            "id": cocoindex.GeneratedField.UUID,
+            "filename": img["filename"],
+            "embedding": img["embedding"],
+        }
+
+        if ollama_model_name is not None:
+            print(f"Using Ollama model '{ollama_model_name}' for captioning.")
+            collect_fields["caption"] = img["caption"]
+        else:
+            print(f"No Ollama model '{ollama_model_name}' found — skipping captioning.")
+
+        img_embeddings.collect(**collect_fields)
 
     img_embeddings.export(
         "img_embeddings",
@@ -126,11 +154,18 @@ def search(
         collection_name=QDRANT_COLLECTION,
         query_vector=("embedding", query_embedding),
         limit=limit,
+        with_payload=True,
     )
 
     return {
         "results": [
-            {"filename": result.payload["filename"], "score": result.score}
+            {
+                "filename": result.payload["filename"],
+                "score": result.score,
+                "caption": result.payload.get(
+                    "caption"
+                ),  # Include caption if available
+            }
            for result in search_results
         ]
     }

src/llm/anthropic.rs

Lines changed: 26 additions & 2 deletions
@@ -1,8 +1,10 @@
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat, ToJsonSchemaOptions,
+    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
+    ToJsonSchemaOptions, detect_image_mime_type,
 };
 use anyhow::{Context, Result, bail};
 use async_trait::async_trait;
+use base64::prelude::*;
 use json5;
 use serde_json::Value;
 
@@ -36,9 +38,31 @@ impl LlmGenerationClient for Client {
         &self,
         request: LlmGenerateRequest<'req>,
     ) -> Result<LlmGenerateResponse> {
+        let mut user_content_parts: Vec<serde_json::Value> = Vec::new();
+
+        // Add image part if present
+        if let Some(image_bytes) = &request.image {
+            let base64_image = BASE64_STANDARD.encode(image_bytes.as_ref());
+            let mime_type = detect_image_mime_type(image_bytes.as_ref())?;
+            user_content_parts.push(serde_json::json!({
+                "type": "image",
+                "source": {
+                    "type": "base64",
+                    "media_type": mime_type,
+                    "data": base64_image,
+                }
+            }));
+        }
+
+        // Add text part
+        user_content_parts.push(serde_json::json!({
+            "type": "text",
+            "text": request.user_prompt
+        }));
+
         let messages = vec![serde_json::json!({
             "role": "user",
-            "content": request.user_prompt
+            "content": user_content_parts
         })];
 
         let mut payload = serde_json::json!({

src/llm/gemini.rs

Lines changed: 21 additions & 3 deletions
@@ -2,8 +2,9 @@ use crate::prelude::*;
 
 use crate::llm::{
     LlmEmbeddingClient, LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
-    ToJsonSchemaOptions,
+    ToJsonSchemaOptions, detect_image_mime_type,
 };
+use base64::prelude::*;
 use phf::phf_map;
 use serde_json::Value;
 use urlencoding::encode;
@@ -70,10 +71,27 @@ impl LlmGenerationClient for Client {
         &self,
         request: LlmGenerateRequest<'req>,
     ) -> Result<LlmGenerateResponse> {
-        // Compose the prompt/messages
+        let mut user_parts: Vec<serde_json::Value> = Vec::new();
+
+        // Add text part first
+        user_parts.push(serde_json::json!({ "text": request.user_prompt }));
+
+        // Add image part if present
+        if let Some(image_bytes) = &request.image {
+            let base64_image = BASE64_STANDARD.encode(image_bytes.as_ref());
+            let mime_type = detect_image_mime_type(image_bytes.as_ref())?;
+            user_parts.push(serde_json::json!({
+                "inlineData": {
+                    "mimeType": mime_type,
+                    "data": base64_image
+                }
+            }));
+        }
+
+        // Compose the contents
         let contents = vec![serde_json::json!({
             "role": "user",
-            "parts": [{ "text": request.user_prompt }]
+            "parts": user_parts
         })];
 
         // Prepare payload

src/llm/mod.rs

Lines changed: 12 additions & 0 deletions
@@ -1,9 +1,12 @@
 use crate::prelude::*;
 
 use crate::base::json_schema::ToJsonSchemaOptions;
+use infer::Infer;
 use schemars::schema::SchemaObject;
 use std::borrow::Cow;
 
+static INFER: LazyLock<Infer> = LazyLock::new(Infer::new);
+
 #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
 pub enum LlmApiType {
     Ollama,
@@ -36,6 +39,7 @@ pub struct LlmGenerateRequest<'a> {
     pub model: &'a str,
     pub system_prompt: Option<Cow<'a, str>>,
     pub user_prompt: Cow<'a, str>,
+    pub image: Option<Cow<'a, [u8]>>,
     pub output_format: Option<OutputFormat<'a>>,
 }
 
@@ -141,3 +145,11 @@ pub fn new_llm_embedding_client(
     };
     Ok(client)
 }
+
+pub fn detect_image_mime_type(bytes: &[u8]) -> Result<&'static str> {
+    let infer = &*INFER;
+    match infer.get(bytes) {
+        Some(info) if info.mime_type().starts_with("image/") => Ok(info.mime_type()),
+        _ => bail!("Unknown or unsupported image format"),
+    }
+}
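Aside: the new helper caches a single `infer::Infer` matcher in a `LazyLock` (presumably re-exported through the crate's prelude) and accepts only `image/*` MIME types. Below is a self-contained sketch of the same detection logic for trying it outside the crate; the sample bytes and `main` function are illustrative only, not part of the commit.

```
use std::sync::LazyLock;

use anyhow::{Result, bail};
use infer::Infer;

// Reuse one matcher instance, mirroring the `static INFER: LazyLock<Infer>` above.
static INFER: LazyLock<Infer> = LazyLock::new(Infer::new);

fn detect_image_mime_type(bytes: &[u8]) -> Result<&'static str> {
    // `infer` sniffs magic bytes, so only the first few bytes of the buffer matter.
    match INFER.get(bytes) {
        Some(info) if info.mime_type().starts_with("image/") => Ok(info.mime_type()),
        _ => bail!("Unknown or unsupported image format"),
    }
}

fn main() -> Result<()> {
    // PNG magic number; any real image file would be detected the same way.
    let png_header = [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A];
    println!("{}", detect_image_mime_type(&png_header)?); // prints "image/png"
    Ok(())
}
```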

src/llm/ollama.rs

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,7 @@ use crate::prelude::*;
 
 use super::LlmGenerationClient;
 use schemars::schema::SchemaObject;
+use serde_with::{base64::Base64, serde_as};
 
 pub struct Client {
     generate_url: String,
@@ -14,10 +15,13 @@ enum OllamaFormat<'a> {
     JsonSchema(&'a SchemaObject),
 }
 
+#[serde_as]
 #[derive(Debug, Serialize)]
 struct OllamaRequest<'a> {
     pub model: &'a str,
     pub prompt: &'a str,
+    #[serde_as(as = "Option<Vec<Base64>>")]
+    pub images: Option<Vec<&'a [u8]>>,
     pub format: Option<OllamaFormat<'a>>,
     pub system: Option<&'a str>,
     pub stream: Option<bool>,
@@ -52,6 +56,7 @@ impl LlmGenerationClient for Client {
         let req = OllamaRequest {
             model: request.model,
            prompt: request.user_prompt.as_ref(),
+            images: request.image.as_deref().map(|img| vec![img.as_ref()]),
             format: request.output_format.as_ref().map(
                 |super::OutputFormat::JsonSchema { schema, .. }| {
                     OllamaFormat::JsonSchema(schema.as_ref())
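Aside: the `#[serde_as(as = "Option<Vec<Base64>>")]` attribute is what turns the raw byte slices into the base64 strings Ollama's `images` field expects. Below is a self-contained sketch of just that serialization step; the `ImageRequest` struct and sample bytes are illustrative stand-ins, not the crate's internal types.

```
use serde::Serialize;
use serde_with::{base64::Base64, serde_as};

// Illustrative stand-in for the request struct above; only the fields
// relevant to base64 encoding are kept.
#[serde_as]
#[derive(Serialize)]
struct ImageRequest<'a> {
    prompt: &'a str,
    #[serde_as(as = "Option<Vec<Base64>>")]
    images: Option<Vec<&'a [u8]>>,
}

fn main() {
    let png_header: &[u8] = &[0x89, 0x50, 0x4E, 0x47];
    let req = ImageRequest {
        prompt: "Describe the image in one detailed sentence.",
        images: Some(vec![png_header]),
    };
    // Each byte slice is emitted as a base64 string, e.g.
    // {"prompt":"...","images":["iVBORw=="]}
    println!("{}", serde_json::to_string(&req).unwrap());
}
```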

src/llm/openai.rs

Lines changed: 31 additions & 7 deletions
@@ -1,18 +1,21 @@
 use crate::api_bail;
 
-use super::{LlmEmbeddingClient, LlmGenerationClient};
+use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
 use anyhow::Result;
 use async_openai::{
     Client as OpenAIClient,
     config::OpenAIConfig,
     types::{
-        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
+        ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
+        ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
         ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
-        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
-        CreateEmbeddingRequest, EmbeddingInput, ResponseFormat, ResponseFormatJsonSchema,
+        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
+        CreateChatCompletionRequest, CreateEmbeddingRequest, EmbeddingInput, ImageDetail,
+        ResponseFormat, ResponseFormatJsonSchema,
     },
 };
 use async_trait::async_trait;
+use base64::prelude::*;
 use phf::phf_map;
 
 static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
@@ -64,11 +67,32 @@ impl LlmGenerationClient for Client {
         }
 
         // Add user message
+        let user_message_content = match request.image {
+            Some(img_bytes) => {
+                let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
+                let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
+                let image_url = format!("data:{};base64,{}", mime_type, base64_image);
+                ChatCompletionRequestUserMessageContent::Array(vec![
+                    ChatCompletionRequestUserMessageContentPart::Text(
+                        ChatCompletionRequestMessageContentPartText {
+                            text: request.user_prompt.into_owned(),
+                        },
+                    ),
+                    ChatCompletionRequestUserMessageContentPart::ImageUrl(
+                        ChatCompletionRequestMessageContentPartImage {
+                            image_url: async_openai::types::ImageUrl {
+                                url: image_url,
+                                detail: Some(ImageDetail::Auto),
+                            },
+                        },
+                    ),
+                ])
+            }
+            None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.into_owned()),
+        };
         messages.push(ChatCompletionRequestMessage::User(
             ChatCompletionRequestUserMessage {
-                content: ChatCompletionRequestUserMessageContent::Text(
-                    request.user_prompt.into_owned(),
-                ),
+                content: user_message_content,
                 ..Default::default()
             },
         ));
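Aside: unlike the Anthropic and Gemini clients, the OpenAI chat API has no dedicated inline-bytes field, so the image is passed as an `image_url` content part whose URL is a base64 data URL, built with the same `format!("data:{};base64,{}", ...)` pattern as above. A minimal sketch of that encoding step in isolation; the `to_data_url` helper and sample bytes are illustrative, not part of the commit.

```
use base64::prelude::*;

// Illustrative helper: wrap raw image bytes as a data URL suitable for an
// `image_url` content part.
fn to_data_url(mime_type: &str, image_bytes: &[u8]) -> String {
    format!("data:{};base64,{}", mime_type, BASE64_STANDARD.encode(image_bytes))
}

fn main() {
    let png_header = [0x89u8, 0x50, 0x4E, 0x47];
    // Prints: data:image/png;base64,iVBORw==
    println!("{}", to_data_url("image/png", &png_header));
}
```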
