feat: support image input to LLM clients #653


Merged
merged 10 commits on Jul 7, 2025
1,007 changes: 498 additions & 509 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
@@ -115,3 +115,5 @@ aws-config = "1.6.2"
aws-sdk-s3 = "1.85.0"
aws-sdk-sqs = "1.67.0"
numpy = "0.25.0"
infer = "0.19.0"
serde_with = { version = "3.13.0", features = ["base64"] }
12 changes: 11 additions & 1 deletion examples/image_search/README.md
@@ -13,6 +13,7 @@ We appreciate a star ⭐ at [CocoIndex Github](https://github.com/cocoindex-io/c
- CLIP ViT-L/14 - Embeddings Model for images and query
- Qdrant for Vector Storage
- FastApi for backend
- Ollama (Optional) for generating image captions using `gemma3`.

## Setup
- Make sure Postgres and Qdrant are running
@@ -21,7 +22,16 @@ We appreciate a star ⭐ at [CocoIndex Github](https://github.com/cocoindex-io/c
export COCOINDEX_DATABASE_URL="postgres://cocoindex:cocoindex@localhost/cocoindex"
```

## Run
## (Optional) Run Ollama

- This enables automatic image captioning
```
ollama pull gemma3
ollama serve
export OLLAMA_MODEL="gemma3" # Optional, for caption generation
```

## Run the App
- Install dependencies:
```
pip install -e .
47 changes: 41 additions & 6 deletions examples/image_search/main.py
@@ -15,6 +15,7 @@
from qdrant_client import QdrantClient
from transformers import CLIPModel, CLIPProcessor

OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/")
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6334/")
QDRANT_COLLECTION = "ImageSearch"
CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
@@ -69,12 +70,39 @@ def image_object_embedding_flow(
)
img_embeddings = data_scope.add_collector()
with data_scope["images"].row() as img:
ollama_model_name = os.getenv("OLLAMA_MODEL")
if ollama_model_name is not None:
# If an Ollama model is specified, generate an image caption
img["caption"] = flow_builder.transform(
cocoindex.functions.ExtractByLlm(
llm_spec=cocoindex.llm.LlmSpec(
api_type=cocoindex.LlmApiType.OLLAMA, model=ollama_model_name
),
instruction=(
"Describe the image in one detailed sentence. "
"Name all visible animal species, objects, and the main scene. "
"Be specific about type, color, and notable features. "
"Mention what each animal is doing."
),
output_type=str,
),
image=img["content"],
)
img["embedding"] = img["content"].transform(embed_image)
img_embeddings.collect(
id=cocoindex.GeneratedField.UUID,
filename=img["filename"],
embedding=img["embedding"],
)

collect_fields = {
"id": cocoindex.GeneratedField.UUID,
"filename": img["filename"],
"embedding": img["embedding"],
}

if ollama_model_name is not None:
print(f"Using Ollama model '{ollama_model_name}' for captioning.")
collect_fields["caption"] = img["caption"]
else:
print(f"No Ollama model '{ollama_model_name}' found — skipping captioning.")

img_embeddings.collect(**collect_fields)

img_embeddings.export(
"img_embeddings",
@@ -126,11 +154,18 @@ def search(
collection_name=QDRANT_COLLECTION,
query_vector=("embedding", query_embedding),
limit=limit,
with_payload=True,
)

return {
"results": [
{"filename": result.payload["filename"], "score": result.score}
{
"filename": result.payload["filename"],
"score": result.score,
"caption": result.payload.get(
"caption"
), # Include caption if available
}
for result in search_results
]
}
28 changes: 26 additions & 2 deletions src/llm/anthropic.rs
@@ -1,8 +1,10 @@
use crate::llm::{
LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat, ToJsonSchemaOptions,
LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
ToJsonSchemaOptions, detect_image_mime_type,
};
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use base64::prelude::*;
use json5;
use serde_json::Value;

@@ -36,9 +38,31 @@ impl LlmGenerationClient for Client {
&self,
request: LlmGenerateRequest<'req>,
) -> Result<LlmGenerateResponse> {
let mut user_content_parts: Vec<serde_json::Value> = Vec::new();

// Add image part if present
if let Some(image_bytes) = &request.image {
let base64_image = BASE64_STANDARD.encode(image_bytes.as_ref());
let mime_type = detect_image_mime_type(image_bytes.as_ref())?;
user_content_parts.push(serde_json::json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_image,
}
}));
}

// Add text part
user_content_parts.push(serde_json::json!({
"type": "text",
"text": request.user_prompt
}));

let messages = vec![serde_json::json!({
"role": "user",
"content": request.user_prompt
"content": user_content_parts
})];

let mut payload = serde_json::json!({
24 changes: 21 additions & 3 deletions src/llm/gemini.rs
@@ -2,8 +2,9 @@ use crate::prelude::*;

use crate::llm::{
LlmEmbeddingClient, LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
ToJsonSchemaOptions,
ToJsonSchemaOptions, detect_image_mime_type,
};
use base64::prelude::*;
use phf::phf_map;
use serde_json::Value;
use urlencoding::encode;
@@ -70,10 +71,27 @@ impl LlmGenerationClient for Client {
&self,
request: LlmGenerateRequest<'req>,
) -> Result<LlmGenerateResponse> {
// Compose the prompt/messages
let mut user_parts: Vec<serde_json::Value> = Vec::new();

// Add text part first
user_parts.push(serde_json::json!({ "text": request.user_prompt }));

// Add image part if present
if let Some(image_bytes) = &request.image {
let base64_image = BASE64_STANDARD.encode(image_bytes.as_ref());
let mime_type = detect_image_mime_type(image_bytes.as_ref())?;
user_parts.push(serde_json::json!({
"inlineData": {
"mimeType": mime_type,
"data": base64_image
}
}));
}

// Compose the contents
let contents = vec![serde_json::json!({
"role": "user",
"parts": [{ "text": request.user_prompt }]
"parts": user_parts
})];

// Prepare payload
12 changes: 12 additions & 0 deletions src/llm/mod.rs
@@ -1,9 +1,12 @@
use crate::prelude::*;

use crate::base::json_schema::ToJsonSchemaOptions;
use infer::Infer;
use schemars::schema::SchemaObject;
use std::borrow::Cow;

static INFER: LazyLock<Infer> = LazyLock::new(Infer::new);

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum LlmApiType {
Ollama,
@@ -36,6 +39,7 @@ pub struct LlmGenerateRequest<'a> {
pub model: &'a str,
pub system_prompt: Option<Cow<'a, str>>,
pub user_prompt: Cow<'a, str>,
pub image: Option<Cow<'a, [u8]>>,
pub output_format: Option<OutputFormat<'a>>,
}

@@ -141,3 +145,11 @@ pub fn new_llm_embedding_client(
};
Ok(client)
}

pub fn detect_image_mime_type(bytes: &[u8]) -> Result<&'static str> {
let infer = &*INFER;
match infer.get(bytes) {
Some(info) if info.mime_type().starts_with("image/") => Ok(info.mime_type()),
_ => bail!("Unknown or unsupported image format"),
}
}
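
For reference, a minimal standalone sketch of the detect-then-encode pattern the clients build on top of this helper: `infer` sniffs the MIME type, the bytes are base64-encoded, and the OpenAI path additionally assembles a `data:` URL. The `fake_png` literal, the inline `Infer::new()` (instead of the shared `LazyLock`), and the `main` wrapper are illustrative assumptions, not CocoIndex code.

```rust
use base64::prelude::*; // provides BASE64_STANDARD
use infer::Infer;

// Mirrors the new helper in src/llm/mod.rs: accept only "image/*" MIME types.
fn detect_image_mime_type(bytes: &[u8]) -> Result<&'static str, String> {
    match Infer::new().get(bytes) {
        Some(info) if info.mime_type().starts_with("image/") => Ok(info.mime_type()),
        _ => Err("Unknown or unsupported image format".to_string()),
    }
}

fn main() {
    // Illustrative bytes: the PNG signature alone is enough for `infer`
    // to report "image/png"; a real caller passes the whole file.
    let fake_png: &[u8] = &[0x89, b'P', b'N', b'G', 0x0D, 0x0A, 0x1A, 0x0A];

    let mime = detect_image_mime_type(fake_png).expect("looks like an image");
    // The OpenAI client additionally packs this into a data: URL.
    let data_url = format!("data:{};base64,{}", mime, BASE64_STANDARD.encode(fake_png));
    println!("{data_url}"); // data:image/png;base64,iVBORw0KGgo=
}
```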
5 changes: 5 additions & 0 deletions src/llm/ollama.rs
@@ -2,6 +2,7 @@ use crate::prelude::*;

use super::LlmGenerationClient;
use schemars::schema::SchemaObject;
use serde_with::{base64::Base64, serde_as};

pub struct Client {
generate_url: String,
@@ -14,10 +15,13 @@ enum OllamaFormat<'a> {
JsonSchema(&'a SchemaObject),
}

#[serde_as]
#[derive(Debug, Serialize)]
struct OllamaRequest<'a> {
pub model: &'a str,
pub prompt: &'a str,
#[serde_as(as = "Option<Vec<Base64>>")]
pub images: Option<Vec<&'a [u8]>>,
pub format: Option<OllamaFormat<'a>>,
pub system: Option<&'a str>,
pub stream: Option<bool>,
@@ -52,6 +56,7 @@ impl LlmGenerationClient for Client {
let req = OllamaRequest {
model: request.model,
prompt: request.user_prompt.as_ref(),
images: request.image.as_deref().map(|img| vec![img.as_ref()]),
format: request.output_format.as_ref().map(
|super::OutputFormat::JsonSchema { schema, .. }| {
OllamaFormat::JsonSchema(schema.as_ref())
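
A small self-contained sketch of what the `serde_as` attribute does to the request body: raw image bytes are serialized as base64 strings, which is the shape Ollama's `/api/generate` expects in its `images` array. `MiniRequest` is a simplified stand-in for the real `OllamaRequest` (format/system/stream omitted), and the byte literal is illustrative; the sketch assumes `serde`, `serde_json`, and `serde_with` with the `base64` feature added above in Cargo.toml.

```rust
use serde::Serialize;
use serde_with::{base64::Base64, serde_as};

// Simplified stand-in for OllamaRequest, keeping only the fields
// relevant to image input.
#[serde_as]
#[derive(Serialize)]
struct MiniRequest<'a> {
    model: &'a str,
    prompt: &'a str,
    #[serde_as(as = "Option<Vec<Base64>>")]
    images: Option<Vec<&'a [u8]>>,
}

fn main() {
    // Illustrative bytes only; a real caller passes the full image file.
    let img: &[u8] = &[0x89, b'P', b'N', b'G'];
    let req = MiniRequest {
        model: "gemma3",
        prompt: "Describe the image.",
        images: Some(vec![img]),
    };
    // The bytes come out as a base64 string in the JSON body:
    // {"model":"gemma3","prompt":"Describe the image.","images":["iVBORw=="]}
    println!("{}", serde_json::to_string(&req).unwrap());
}
```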
38 changes: 31 additions & 7 deletions src/llm/openai.rs
@@ -1,18 +1,21 @@
use crate::api_bail;

use super::{LlmEmbeddingClient, LlmGenerationClient};
use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
use anyhow::Result;
use async_openai::{
Client as OpenAIClient,
config::OpenAIConfig,
types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
CreateEmbeddingRequest, EmbeddingInput, ResponseFormat, ResponseFormatJsonSchema,
ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
CreateChatCompletionRequest, CreateEmbeddingRequest, EmbeddingInput, ImageDetail,
ResponseFormat, ResponseFormatJsonSchema,
},
};
use async_trait::async_trait;
use base64::prelude::*;
use phf::phf_map;

static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
@@ -64,11 +67,32 @@ impl LlmGenerationClient for Client {
}

// Add user message
let user_message_content = match request.image {
Some(img_bytes) => {
let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
let image_url = format!("data:{};base64,{}", mime_type, base64_image);
ChatCompletionRequestUserMessageContent::Array(vec![
ChatCompletionRequestUserMessageContentPart::Text(
ChatCompletionRequestMessageContentPartText {
text: request.user_prompt.into_owned(),
},
),
ChatCompletionRequestUserMessageContentPart::ImageUrl(
ChatCompletionRequestMessageContentPartImage {
image_url: async_openai::types::ImageUrl {
url: image_url,
detail: Some(ImageDetail::Auto),
},
},
),
])
}
None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.into_owned()),
};
messages.push(ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(
request.user_prompt.into_owned(),
),
content: user_message_content,
..Default::default()
},
));