Description
What happened?
I'm trying to create a collection with a SentenceTransformerEmbeddingFunction, and my use case is for text only; I don't need anything multimodal.
Here's the relevant code:
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction


def main() -> None:
    client = chromadb.EphemeralClient()
    _collection: chromadb.Collection = client.get_or_create_collection(
        "asdf",
        embedding_function=SentenceTransformerEmbeddingFunction(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        ),
    )
Now when I run mypy on this (even without --strict), I get the following error:
main.py:10: error: Argument "embedding_function" to "get_or_create_collection" of "ClientAPI" has incompatible type "SentenceTransformerEmbeddingFunction"; expected "EmbeddingFunction[list[str] | list[ndarray[tuple[Any, ...], dtype[unsignedinteger[_32Bit | _64Bit] | signedinteger[_64Bit] | float64]]]] | None" [arg-type]
main.py:10: note: Following member(s) of "SentenceTransformerEmbeddingFunction" have conflicts:
main.py:10: note: Expected:
main.py:10: note: def __call__(self, input: list[str] | list[ndarray[tuple[Any, ...], dtype[unsignedinteger[_32Bit | _64Bit] | signedinteger[_64Bit] | float64]]]) -> list[ndarray[tuple[Any, ...], dtype[signedinteger[_32Bit] | floating[_32Bit]]]]
main.py:10: note: Got:
main.py:10: note: def __call__(self, input: list[str]) -> list[ndarray[tuple[Any, ...], dtype[signedinteger[_32Bit] | floating[_32Bit]]]]
main.py:10: note: Expected:
main.py:10: note: def build_from_config(config: dict[str, Any]) -> EmbeddingFunction[list[str] | list[ndarray[tuple[Any, ...], dtype[unsignedinteger[_32Bit | _64Bit] | signedinteger[_64Bit] | float64]]]]
main.py:10: note: Got:
main.py:10: note: def build_from_config(config: dict[str, Any]) -> EmbeddingFunction[list[str]]
main.py:10: note: <1 more conflict(s) not shown>
Found 1 error in 1 file (checked 1 source file)
The gist of it is that get_or_create_collection expects an EmbeddingFunction[Embeddable], where Embeddable = Union[Documents, Images], but I'm providing an EmbeddingFunction[Documents]. Since EmbeddingFunction is of course contravariant in its type parameter (the parameter is the input type of __call__), an EmbeddingFunction[Documents] is not a subtype of EmbeddingFunction[Embeddable]. This means that we can't use a text-only embedding function when creating a new collection this way.
For embedding functions like OpenCLIPEmbeddingFunction this is not an issue, since they implement EmbeddingFunction[Embeddable], but it looks like the majority of the provided embedding functions are text-only.
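To illustrate the variance problem without depending on chromadb, here is a minimal sketch. The aliases Documents, Images and Embeddable, the simplified EmbeddingFunction protocol, and the names TextOnly, Multimodal and embed_two_docs are stand-ins I made up for this example; they are not the real chromadb definitions.

```python
from typing import Protocol, TypeVar, Union

# Simplified stand-ins for chromadb's aliases (assumed shapes, not the real ones).
Documents = list[str]
Images = list[bytes]
Embeddable = Union[Documents, Images]

D = TypeVar("D", contravariant=True)


class EmbeddingFunction(Protocol[D]):
    """Hypothetical, stripped-down version of the protocol: D is only an input."""

    def __call__(self, input: D) -> list[list[float]]: ...


class TextOnly:
    """Structurally an EmbeddingFunction[Documents]."""

    def __call__(self, input: Documents) -> list[list[float]]:
        return [[float(len(doc))] for doc in input]


class Multimodal:
    """Structurally an EmbeddingFunction[Embeddable]."""

    def __call__(self, input: Embeddable) -> list[list[float]]:
        return [[float(len(item))] for item in input]


def embed_two_docs(embedding_function: EmbeddingFunction[Embeddable]) -> list[list[float]]:
    # Mirrors the shape of get_or_create_collection's parameter type.
    return embedding_function(["a", "bc"])


ok = embed_two_docs(Multimodal())  # accepted by a type checker
# Runs fine, but a type checker should reject it: TextOnly accepts fewer
# inputs than the parameter type promises to be able to pass in.
also_runs = embed_two_docs(TextOnly())
```

The second call only works at runtime because the sketch happens to pass strings; the declared parameter type would equally allow images to be passed to TextOnly.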
I'm aware that the code type-checks when I change it to the following:
    _collection: chromadb.Collection = client.get_or_create_collection(
        "asdf",
        configuration={
            "embedding_function": SentenceTransformerEmbeddingFunction(
                model_name="sentence-transformers/all-MiniLM-L6-v2"
            )
        },
    )
But that's only due to a type: ignore inside CreateCollectionConfiguration, the type of the configuration argument.
Possible solutions
One possible fix would be to add the same type: ignore to the embedding_function argument as well.
Another option would be to take what e.g. MistralEmbeddingFunction does and apply that concept everywhere. The Mistral embedder also only supports text, and it uses runtime checks to enforce that. My suggestion would be to change every embedding function to be of type EmbeddingFunction[Embeddable], remove the now-obsolete type: ignore annotations, and insert runtime checks wherever the embedder does not support images. This is preferable in my opinion because in the current solution we still get runtime errors when something doesn't match up, but these might appear much later, deep inside whatever third-party API is being used. Having all embedding functions take Embeddable as input would mean that the implementation of any text-only embedder most likely does not type-check unless it explicitly inserts a runtime check, which means that we can expect better runtime error messages in general.
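A rough sketch of what the runtime-check approach could look like for a text-only embedder follows. This is not the actual MistralEmbeddingFunction code; TextOnlyEmbedder, the alias definitions, and the error message are hypothetical and only meant to show the pattern.

```python
from typing import Union

# Simplified stand-ins for chromadb's aliases (assumed shapes, not the real ones).
Documents = list[str]
Images = list[bytes]
Embeddable = Union[Documents, Images]


class TextOnlyEmbedder:
    """Hypothetical text-only embedder that is declared to accept Embeddable."""

    def __call__(self, input: Embeddable) -> list[list[float]]:
        docs: list[str] = []
        for item in input:
            # The explicit check narrows the type for the checker and fails
            # early with a clear message instead of deep inside a backend.
            if not isinstance(item, str):
                raise ValueError("TextOnlyEmbedder only supports text documents")
            docs.append(item)
        # Placeholder "embedding": a real implementation would call a model here.
        return [[float(len(doc))] for doc in docs]


vectors = TextOnlyEmbedder()(["hello", "hi"])
```

Because the signature takes Embeddable, an instance of this class would be accepted wherever an EmbeddingFunction[Embeddable] is expected, and passing images fails right at the call boundary.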
A third option would be to make the Collection type generic, but that would probably be quite a large change to the public API.
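For completeness, this is roughly what I mean by a generic collection. It is a toy sketch under my own assumptions: the Collection class below, its add and count methods, and the simplified EmbeddingFunction protocol are invented for illustration and do not reflect chromadb's real classes.

```python
from typing import Generic, Protocol, TypeVar

Documents = list[str]  # simplified stand-in for chromadb's alias

D = TypeVar("D", contravariant=True)
T = TypeVar("T")


class EmbeddingFunction(Protocol[D]):
    """Hypothetical, stripped-down protocol: D is only an input."""

    def __call__(self, input: D) -> list[list[float]]: ...


class Collection(Generic[T]):
    """Toy collection whose accepted input type follows its embedding function."""

    def __init__(self, name: str, embedding_function: EmbeddingFunction[T]) -> None:
        self.name = name
        self._embedding_function = embedding_function
        self._embeddings: list[list[float]] = []

    def add(self, items: T) -> None:
        self._embeddings.extend(self._embedding_function(items))

    def count(self) -> int:
        return len(self._embeddings)


class TextOnly:
    def __call__(self, input: Documents) -> list[list[float]]:
        return [[float(len(doc))] for doc in input]


# Intended to be inferred as Collection[list[str]], so add() only accepts documents.
collection = Collection("asdf", TextOnly())
collection.add(["a", "bc"])
```

With this shape a text-only embedding function yields a text-only collection, and the mismatch would be caught at the add call by the type checker rather than at creation time.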
Versions
chromadb==1.0.16
mypy==1.17.1
Python 3.12.8
Arch Linux (x86_64, 6.15.8-arch1-2)