@@ -60,6 +60,7 @@ def _validate_tool_inputs(self):
"columns": self.columns,
"workspace_client": self.workspace_client,
"include_score": self.include_score,
"reranker": self.reranker,
}
dbvs = DatabricksVectorSearch(**kwargs)
self._vector_store = dbvs
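For orientation, here is a minimal usage sketch of what the hunk above appears to enable. The file is not named in this view, so the tool class (`VectorSearchRetrieverTool` from `databricks_langchain`) and the index and column names below are assumptions for illustration, not part of the diff:

from databricks.vector_search.reranker import DatabricksReranker
from databricks_langchain import VectorSearchRetrieverTool

# Assumed usage: the hunk above forwards `self.reranker` into DatabricksVectorSearch,
# which suggests the tool can now be constructed with a reranker.
tool = VectorSearchRetrieverTool(
    index_name="catalog.schema.my_index",  # placeholder index name
    columns=["id", "text"],  # placeholder columns
    reranker=DatabricksReranker(columns_to_rerank=["text"]),
)
docs = tool.invoke("What is Databricks Vector Search?")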
36 changes: 36 additions & 0 deletions integrations/langchain/src/databricks_langchain/vectorstores.py
@@ -18,6 +18,7 @@

import numpy as np
from databricks.sdk import WorkspaceClient
from databricks.vector_search.reranker import DatabricksReranker, Reranker
from databricks_ai_bridge.utils.vector_search import (
IndexDetails,
RetrieverSchema,
@@ -74,6 +75,10 @@ class DatabricksVectorSearch(VectorStore):
Allows you to pass in values like ``service_principal_client_id``
and ``service_principal_client_secret`` to allow for
service principal authentication instead of personal access token authentication.
reranker: Optional reranker to apply to the top retrieved results. Pass an instance of
``databricks.vector_search.reranker.DatabricksReranker`` with
``columns_to_rerank=[...]``. The reranker reorders the initial results using
the specified text columns.

**Instantiate**:

@@ -107,6 +112,17 @@ class DatabricksVectorSearch(VectorStore):
text_column="document_content",
)

If you want Databricks to rerank the results, you can provide a reranker when initializing the vector store:

.. code-block:: python

from databricks.vector_search.reranker import DatabricksReranker

vector_store = DatabricksVectorSearch(
index_name="<your-index-name>",
reranker=DatabricksReranker(columns_to_rerank=["column1", "column2"]),
)

**Add Documents**:

.. code-block:: python
@@ -229,6 +245,7 @@ def __init__(
workspace_client: Optional[WorkspaceClient] = None,
client_args: Optional[Dict[str, Any]] = None,
include_score: bool = False,
reranker: Optional[DatabricksReranker] = None,
):
if not isinstance(index_name, str):
raise ValueError(
@@ -296,6 +313,7 @@ def __init__(
other_columns=self._columns,
)
self._include_score = include_score
self._reranker = reranker

@property
def embeddings(self) -> Optional[Embeddings]:
@@ -406,6 +424,7 @@ def similarity_search(
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
reranker: Optional[Reranker] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to query.
@@ -415,6 +434,7 @@
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
reranker: Optional reranker to apply to the results. If provided, it takes precedence over the reranker set on the vector store for this call. Defaults to None.
kwargs: Additional keyword arguments to pass to `databricks.vector_search.client.VectorSearchIndex.similarity_search`. `See
documentation <https://api-docs.databricks.com/python/vector-search/databricks.vector_search.html#databricks.vector_search.index.VectorSearchIndex.similarity_search>`_
to see the full set of supported keyword arguments
@@ -427,6 +447,7 @@
k=k,
filter=filter,
query_type=query_type,
reranker=reranker,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
@@ -445,6 +466,7 @@ def similarity_search_with_score(
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
reranker: Optional[Reranker] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query, along with scores.
@@ -454,6 +476,7 @@ def similarity_search_with_score(
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
reranker: Optional reranker to apply to the results. If provided, it takes precedence over the reranker set on the vector store for this call. Defaults to None.
kwargs: Additional keyword arguments to pass to `databricks.vector_search.client.VectorSearchIndex.similarity_search`. `See
documentation <https://api-docs.databricks.com/python/vector-search/databricks.vector_search.html#databricks.vector_search.index.VectorSearchIndex.similarity_search>`_
to see the full set of supported keyword arguments
@@ -482,6 +505,7 @@ def similarity_search_with_score(
"filters": filter,
"num_results": k,
"query_type": query_type,
"reranker": reranker or self._reranker,
}
)
search_resp = self.index.similarity_search(**kwargs)
@@ -516,6 +540,7 @@ def similarity_search_by_vector(
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
reranker: Optional[Reranker] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to embedding vector.
@@ -525,6 +550,7 @@ def similarity_search_by_vector(
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
reranker: Optional reranker to apply to the results. If provided, it takes precedence over the reranker set on the vector store for this call. Defaults to None.
kwargs: Additional keyword arguments to pass to `databricks.vector_search.client.VectorSearchIndex.similarity_search`. `See
documentation <https://api-docs.databricks.com/python/vector-search/databricks.vector_search.html#databricks.vector_search.index.VectorSearchIndex.similarity_search>`_
to see the full set of supported keyword arguments
@@ -541,6 +567,7 @@
filter=filter,
query_type=query_type,
query=query,
reranker=reranker,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
@@ -562,6 +589,7 @@ def similarity_search_by_vector_with_score(
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
reranker: Optional[Reranker] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector, along with scores.
@@ -575,6 +603,7 @@ def similarity_search_by_vector_with_score(
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
reranker: Optional reranker to apply to the results. If provided, it takes precedence over the reranker set on the vector store for this call. Defaults to None.
kwargs: Additional keyword arguments to pass to `databricks.vector_search.client.VectorSearchIndex.similarity_search`. `See
documentation <https://api-docs.databricks.com/python/vector-search/databricks.vector_search.html#databricks.vector_search.index.VectorSearchIndex.similarity_search>`_
to see the full set of supported keyword arguments
@@ -605,6 +634,7 @@ def similarity_search_by_vector_with_score(
filters=filter,
num_results=k,
query_type=query_type,
reranker=reranker or self._reranker,
**kwargs,
)
return parse_vector_search_response(
@@ -623,6 +653,7 @@ def max_marginal_relevance_search(
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
reranker: Optional[Reranker] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@@ -644,6 +675,7 @@ def max_marginal_relevance_search(
Defaults to 0.5.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
reranker: Optional reranker to apply to the results. If provided, it takes precedence over the reranker set on the vector store for this call. Defaults to None.
Returns:
List of Documents selected by maximal marginal relevance.
"""
@@ -665,6 +697,7 @@ def max_marginal_relevance_search(
lambda_mult=lambda_mult,
filter=filter,
query_type=query_type,
reranker=reranker,
)
return docs

@@ -699,6 +732,7 @@ def max_marginal_relevance_search_by_vector(
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
reranker: Optional[Reranker] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@@ -720,6 +754,7 @@ def max_marginal_relevance_search_by_vector(
Defaults to 0.5.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
reranker: Optional reranker to apply to the results. If provided, it takes precedence over the reranker set on the vector store for this call. Defaults to None.
Returns:
List of Documents selected by maximal marginal relevance.
"""
@@ -736,6 +771,7 @@ def max_marginal_relevance_search_by_vector(
filters=filter,
num_results=fetch_k,
query_type=query_type,
reranker=reranker or self._reranker,
**kwargs,
)

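The `reranker or self._reranker` fallback repeated in the hunks above implies a simple precedence rule: a reranker passed to an individual search call wins, otherwise the one supplied at construction time is used. A brief sketch of that behavior, assuming placeholder index and column names:

from databricks.vector_search.reranker import DatabricksReranker
from databricks_langchain import DatabricksVectorSearch

store = DatabricksVectorSearch(
    index_name="catalog.schema.my_index",  # placeholder
    reranker=DatabricksReranker(columns_to_rerank=["text"]),
)

# Uses the store-level reranker configured above.
store.similarity_search("quarterly revenue", k=4)

# A per-call reranker takes precedence for this query only.
store.similarity_search(
    "quarterly revenue",
    k=4,
    reranker=DatabricksReranker(columns_to_rerank=["title", "text"]),
)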
73 changes: 70 additions & 3 deletions integrations/langchain/tests/unit_tests/test_vectorstores.py
@@ -4,6 +4,7 @@

import pytest
from databricks.vector_search.client import VectorSearchIndex # type: ignore
from databricks.vector_search.reranker import DatabricksReranker, Reranker
from databricks_ai_bridge.test_utils.vector_search import ( # noqa: F401
ALL_INDEX_NAMES,
DELTA_SYNC_INDEX,
@@ -22,11 +23,14 @@


def init_vector_search(
index_name: str, columns: Optional[List[str]] = None
index_name: str,
columns: Optional[List[str]] = None,
reranker: Optional[Reranker] = None,
) -> DatabricksVectorSearch:
kwargs: Dict[str, Any] = {
"index_name": index_name,
"columns": columns,
"reranker": reranker,
}
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
@@ -256,6 +260,7 @@ def test_similarity_search(index_name: str, query_type: Optional[str]) -> None:
filters=filters,
num_results=limit,
query_type=query_type,
reranker=None,
)
else:
vectorsearch.index.similarity_search.assert_called_once_with(
@@ -265,6 +270,7 @@
filters=filters,
num_results=limit,
query_type=query_type,
reranker=None,
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
@@ -289,6 +295,7 @@ def test_similarity_search_hybrid(index_name: str) -> None:
filters=filters,
num_results=limit,
query_type="HYBRID",
reranker=None,
)
else:
vectorsearch.index.similarity_search.assert_called_once_with(
@@ -298,6 +305,7 @@
filters=filters,
num_results=limit,
query_type="HYBRID",
reranker=None,
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
@@ -309,6 +317,7 @@ def test_similarity_search_passing_kwargs() -> None:
query = "foo"
filters = {"some filter": True}
query_type = "ANN"
reranker = DatabricksReranker(columns_to_rerank=["id", "text", "text_vector"])

search_result = vectorsearch.similarity_search(
query,
@@ -318,6 +327,7 @@
score_threshold=0.5,
num_results=10,
random_parameters="not included",
reranker=reranker,
)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
@@ -327,6 +337,7 @@
query_type=query_type,
num_results=5, # maintained
score_threshold=0.5, # passed
reranker=reranker,
)


@@ -346,12 +357,60 @@ def test_mmr_search(
query = INPUT_TEXTS[0]
filters = {"some filter": True}
limit = 1
reranker = DatabricksReranker(columns_to_rerank=["id", "text", "text_vector"])

search_result = vectorsearch.max_marginal_relevance_search(query, k=limit, filters=filters)
search_result = vectorsearch.max_marginal_relevance_search(
query, k=limit, filters=filters, reranker=reranker
)
assert [doc.page_content for doc in search_result] == [INPUT_TEXTS[0]]
assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns]


@pytest.mark.parametrize(
"reranker", [None, DatabricksReranker(columns_to_rerank=["id", "text", "text_vector"])]
)
def test_reranker_similarity_search_with_score(reranker: Optional[DatabricksReranker]):
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX, reranker=reranker)

query = INPUT_TEXTS[0]
filters = {"some filter": True}
limit = 1
search_result = vectorsearch.similarity_search_with_score(query, k=limit, filter=filters)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=None,
query_vector=EMBEDDING_MODEL.embed_query(query),
filters=filters,
num_results=limit,
query_type=None,
reranker=reranker,
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for (d, _score) in search_result]) == sorted(INPUT_TEXTS)
assert all(["id" in d.metadata for (d, _score) in search_result])


@pytest.mark.parametrize(
"reranker", [None, DatabricksReranker(columns_to_rerank=["id", "text", "text_vector"])]
)
def test_reranker_backward_compatibility(reranker: Optional[DatabricksReranker]):
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)

query = INPUT_TEXTS[0]
filters = {"some filter": True}
limit = 1
vectorsearch.similarity_search_with_score(query, k=limit, filter=filters, reranker=reranker)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=None,
query_vector=EMBEDDING_MODEL.embed_query(query),
filters=filters,
num_results=limit,
query_type=None,
reranker=reranker,
)


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_mmr_parameters(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
@@ -437,6 +496,7 @@ def test_similarity_search_by_vector(index_name: str, query_type: Optional[str])
num_results=limit,
query_type=query_type,
query_text=None,
reranker=None,
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
@@ -449,9 +509,15 @@ def test_similarity_search_by_vector_hybrid(index_name: str) -> None:
query_embedding = EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
reranker = DatabricksReranker(columns_to_rerank=["id", "text"])

search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo"
query_embedding,
k=limit,
filter=filters,
query_type="HYBRID",
query="foo",
reranker=reranker,
)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
Expand All @@ -460,6 +526,7 @@ def test_similarity_search_by_vector_hybrid(index_name: str) -> None:
num_results=limit,
query_type="HYBRID",
query_text="foo",
reranker=reranker,
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
@@ -120,6 +120,7 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa
# Allow kwargs to override the default values upon invocation
num_results = kwargs.pop("num_results", self.num_results)
query_type = kwargs.pop("query_type", self.query_type)
reranker = kwargs.pop("reranker", self.reranker)

# Ensure that we don't have duplicate keys
kwargs.update(
Expand All @@ -130,6 +131,7 @@ def get_query_text_vector(query: str) -> Tuple[Optional[str], Optional[List[floa
"filters": combined_filters,
"num_results": num_results,
"query_type": query_type,
"reranker": reranker,
}
)
search_resp = self._index.similarity_search(**kwargs)