[Bug]: text-only embedding functions can't be used typesafely #5241

@Pascal-So

What happened?

I'm trying to create a collection with a SentenceTransformerEmbeddingFunction. My use case is text-only; I don't need anything multimodal.

Here's the relevant code:

import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction


def main() -> None:
    client = chromadb.EphemeralClient()

    _collection: chromadb.Collection = client.get_or_create_collection(
        "asdf",
        embedding_function=SentenceTransformerEmbeddingFunction(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        ),
    )

Now when I run mypy on this (even without --strict), I get the following error:

main.py:10: error: Argument "embedding_function" to "get_or_create_collection" of "ClientAPI" has incompatible type "SentenceTransformerEmbeddingFunction"; expected "EmbeddingFunction[list[str] | list[ndarray[tuple[Any, ...], dtype[unsignedinteger[_32Bit | _64Bit] | signedinteger[_64Bit] | float64]]]] | None"  [arg-type]
main.py:10: note: Following member(s) of "SentenceTransformerEmbeddingFunction" have conflicts:
main.py:10: note:     Expected:
main.py:10: note:         def __call__(self, input: list[str] | list[ndarray[tuple[Any, ...], dtype[unsignedinteger[_32Bit | _64Bit] | signedinteger[_64Bit] | float64]]]) -> list[ndarray[tuple[Any, ...], dtype[signedinteger[_32Bit] | floating[_32Bit]]]]
main.py:10: note:     Got:
main.py:10: note:         def __call__(self, input: list[str]) -> list[ndarray[tuple[Any, ...], dtype[signedinteger[_32Bit] | floating[_32Bit]]]]
main.py:10: note:     Expected:
main.py:10: note:         def build_from_config(config: dict[str, Any]) -> EmbeddingFunction[list[str] | list[ndarray[tuple[Any, ...], dtype[unsignedinteger[_32Bit | _64Bit] | signedinteger[_64Bit] | float64]]]]
main.py:10: note:     Got:
main.py:10: note:         def build_from_config(config: dict[str, Any]) -> EmbeddingFunction[list[str]]
main.py:10: note:     <1 more conflict(s) not shown>
Found 1 error in 1 file (checked 1 source file)

The gist of it is that get_or_create_collection expects an EmbeddingFunction[Embeddable] where Embeddable = Union[Documents, Images], but I'm providing an EmbeddingFunction[Documents]. EmbeddingFunction is of course contravariant in its type parameter, since that parameter is the input to the function, so an EmbeddingFunction[Documents] is not a subtype of EmbeddingFunction[Embeddable]. This means that we can't use a text-only embedding function when creating a new collection this way.
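To illustrate the variance problem independently of chromadb, here is a minimal sketch. All names in it (the simplified Documents, Images and Embeddings aliases, the TextOnly and Multimodal classes, and the create_collection and embed_text helpers) are hypothetical stand-ins, not chromadb's actual definitions; in particular, images are modelled as bytes rather than numpy arrays.

```python
from typing import List, Protocol, TypeVar, Union

# Hypothetical, simplified stand-ins for chromadb's type aliases.
Documents = List[str]
Images = List[bytes]  # assumption: bytes instead of numpy arrays
Embeddable = Union[Documents, Images]
Embeddings = List[List[float]]

# The type parameter only appears in input position, so it is contravariant.
D = TypeVar("D", contravariant=True)


class EmbeddingFunction(Protocol[D]):
    def __call__(self, input: D) -> Embeddings: ...


class TextOnly:
    """Accepts only documents, i.e. an EmbeddingFunction[Documents]."""

    def __call__(self, input: Documents) -> Embeddings:
        return [[float(len(doc))] for doc in input]


class Multimodal:
    """Accepts documents or images, i.e. an EmbeddingFunction[Embeddable]."""

    def __call__(self, input: Embeddable) -> Embeddings:
        return [[float(len(item))] for item in input]


def create_collection(
    embedding_function: EmbeddingFunction[Embeddable],
) -> EmbeddingFunction[Embeddable]:
    # Stand-in for an API that demands the wide input type.
    return embedding_function


def embed_text(fn: EmbeddingFunction[Documents]) -> Embeddings:
    # Stand-in for an API that only ever passes text.
    return fn(["hello"])


# Accepted: Multimodal takes the full Embeddable union.
wide = create_collection(Multimodal())

# Accepted: contravariance lets a wider-input function stand in for a narrower one.
from_multimodal = embed_text(Multimodal())
from_text_only = embed_text(TextOnly())

# Rejected by a type checker (but fine at runtime): TextOnly's input is too narrow.
# create_collection(TextOnly())
```

The commented-out last call is the analogue of the error above: the function accepting the narrower input cannot be used where the wider input is promised.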

For embedding functions like OpenCLIPEmbeddingFunction this is not an issue, since they implement EmbeddingFunction[Embeddable], but it looks like the majority of the provided embedding functions are text-only.

I'm aware that the code type-checks when I change it to the following:

    _collection: chromadb.Collection = client.get_or_create_collection(
        "asdf",
        configuration={
            "embedding_function": SentenceTransformerEmbeddingFunction(
                model_name="sentence-transformers/all-MiniLM-L6-v2"
            )
        },
    )

But that's only due to a type: ignore inside CreateCollectionConfiguration, the type of the configuration argument.

Possible solutions

One possible fix would be to also add the same type: ignore to the embedding_function argument.

Another option would be to take the approach of e.g. MistralEmbeddingFunction and apply it everywhere. The Mistral embedder also only supports text, and it uses runtime checks to enforce that. My suggestion would be to change every embedding function to be of type EmbeddingFunction[Embeddable], remove the now-obsolete type: ignore annotations, and insert runtime checks wherever the embedder does not support images. This is preferable in my opinion because with the current approach we still get runtime errors when something doesn't match up, but they might surface much later, deep inside whatever third-party API is being used. Having all embedding functions take Embeddable as input would mean that the implementation of any text-only embedder most likely does not type-check unless it explicitly inserts a runtime check, which means that we can expect better runtime error messages in general.
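A rough sketch of what such a runtime check could look like, under the same simplifying assumptions as before: the type aliases are stand-ins for chromadb's, TextOnlyEmbeddingFunction is a hypothetical class, and the "embedding" is a dummy value. This is not how MistralEmbeddingFunction is actually written, just the general shape of the idea.

```python
from typing import List, Union

# Hypothetical, simplified stand-ins for chromadb's type aliases.
Documents = List[str]
Images = List[bytes]  # assumption: bytes instead of numpy arrays
Embeddable = Union[Documents, Images]
Embeddings = List[List[float]]


class TextOnlyEmbeddingFunction:
    """Hypothetical text-only embedder that still accepts the wide Embeddable type."""

    def __call__(self, input: Embeddable) -> Embeddings:
        documents: List[str] = []
        for item in input:
            # The isinstance check narrows the type for the checker and
            # fails early with a clear message at runtime.
            if not isinstance(item, str):
                raise ValueError(
                    "TextOnlyEmbeddingFunction only supports text documents, "
                    f"got {type(item).__name__}"
                )
            documents.append(item)
        # Dummy embedding: a real implementation would call the model here.
        return [[float(len(doc))] for doc in documents]


embed = TextOnlyEmbeddingFunction()
text_result = embed(["hello", "hi"])

try:
    embed([b"not text"])
    image_error = None
except ValueError as exc:
    image_error = str(exc)
```

Without the isinstance check, passing the items on to code that expects str would most likely be flagged by the type checker, which is the point made above: the wide signature pushes implementers towards an explicit check.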

A third option would be to make the Collection type generic, but that would probably be quite a large change to the public API.
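For completeness, a very rough sketch of the generic variant. The Collection class below is a hypothetical toy, not chromadb's Collection, and it only shows how the input type of the embedding function could flow through to the methods that accept data.

```python
from typing import Callable, Generic, List, TypeVar

# Hypothetical, simplified stand-ins for chromadb's type aliases.
Documents = List[str]
Embeddings = List[List[float]]

D = TypeVar("D")


class Collection(Generic[D]):
    """Toy collection parameterised by the input type of its embedding function."""

    def __init__(
        self, name: str, embedding_function: Callable[[D], Embeddings]
    ) -> None:
        self.name = name
        self._embed = embedding_function
        self._embeddings: Embeddings = []

    def add(self, items: D) -> None:
        # Only inputs the embedding function understands are accepted.
        self._embeddings.extend(self._embed(items))

    def count(self) -> int:
        return len(self._embeddings)


def embed_text(input: Documents) -> Embeddings:
    # Dummy text-only embedder.
    return [[float(len(doc))] for doc in input]


# The type checker infers Collection[List[str]] from the embedding function.
text_collection = Collection("asdf", embed_text)
text_collection.add(["hello", "hi"])

# A type checker would reject this, since the collection is text-only:
# text_collection.add([b"image bytes"])
```

Every signature that mentions Collection would then need a type argument, which is why this looks like the most invasive of the three options.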

Versions

chromadb==1.0.16
mypy==1.17.1

Python 3.12.8

Arch Linux (x86_64, 6.15.8-arch1-2)
