Skip to content

Commit 440c98e

Browse files
authored
Fix/issue 2695 (langchain-ai#3608)
## Background fixes langchain-ai#2695 ## Changes The `add_texts` method uses the internal embedding function if one was passed to the `Weaviate` constructor. NOTE: the latest merge on the `Weaviate` class made the specification of a `weaviate_api_key` mandatory, which might not be desirable for all users and connection methods (for example, Weaviate also supports Embedded Weaviate, which I am happy to add support for here if people think it's desirable). I wrapped the fetching of the API key in a try/except in order to allow the `weaviate_api_key` to be unspecified. Do let me know if this is unsatisfactory. ## Test Plan Added a test for the `add_texts` method.
1 parent 6158125 commit 440c98e

8 files changed

+2021
-1988
lines changed

langchain/vectorstores/weaviate.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,18 @@ def _default_schema(index_name: str) -> Dict:
2727

2828
def _create_weaviate_client(**kwargs: Any) -> Any:
2929
client = kwargs.get("client")
30-
3130
if client is not None:
3231
return client
3332

3433
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
35-
weaviate_api_key = get_from_dict_or_env(
36-
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
37-
)
34+
35+
try:
36+
# the weaviate api key param should not be mandatory
37+
weaviate_api_key = get_from_dict_or_env(
38+
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
39+
)
40+
except ValueError:
41+
weaviate_api_key = None
3842

3943
try:
4044
import weaviate
@@ -117,9 +121,21 @@ def add_texts(
117121
data_properties[key] = metadatas[i][key]
118122

119123
_id = get_valid_uuid(uuid4())
120-
batch.add_data_object(
121-
data_object=data_properties, class_name=self._index_name, uuid=_id
122-
)
124+
125+
if self._embedding is not None:
126+
embeddings = self._embedding.embed_documents(list(doc))
127+
batch.add_data_object(
128+
data_object=data_properties,
129+
class_name=self._index_name,
130+
uuid=_id,
131+
vector=embeddings[0],
132+
)
133+
else:
134+
batch.add_data_object(
135+
data_object=data_properties,
136+
class_name=self._index_name,
137+
uuid=_id,
138+
)
123139
ids.append(_id)
124140
return ids
125141

tests/integration_tests/vectorstores/cassettes/test_weaviate/TestWeaviate.test_max_marginal_relevance_search.yaml

Lines changed: 336 additions & 336 deletions
Large diffs are not rendered by default.

tests/integration_tests/vectorstores/cassettes/test_weaviate/TestWeaviate.test_max_marginal_relevance_search_by_vector.yaml

Lines changed: 330 additions & 331 deletions
Large diffs are not rendered by default.

tests/integration_tests/vectorstores/cassettes/test_weaviate/TestWeaviate.test_max_marginal_relevance_search_with_filter.yaml

Lines changed: 342 additions & 343 deletions
Large diffs are not rendered by default.

tests/integration_tests/vectorstores/cassettes/test_weaviate/TestWeaviate.test_similarity_search_with_metadata.yaml

Lines changed: 323 additions & 323 deletions
Large diffs are not rendered by default.

tests/integration_tests/vectorstores/cassettes/test_weaviate/TestWeaviate.test_similarity_search_with_metadata_and_filter.yaml

Lines changed: 327 additions & 326 deletions
Large diffs are not rendered by default.

tests/integration_tests/vectorstores/cassettes/test_weaviate/TestWeaviate.test_similarity_search_without_metadata.yaml

Lines changed: 322 additions & 322 deletions
Large diffs are not rendered by default.

tests/integration_tests/vectorstores/test_weaviate.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from langchain.docstore.document import Document
1010
from langchain.embeddings.openai import OpenAIEmbeddings
1111
from langchain.vectorstores.weaviate import Weaviate
12+
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
1213

1314
logging.basicConfig(level=logging.DEBUG)
1415

@@ -163,3 +164,20 @@ def test_max_marginal_relevance_search_with_filter(
163164
assert output == [
164165
Document(page_content="foo", metadata={"page": 0}),
165166
]
167+
168+
def test_add_texts_with_given_embedding(self, weaviate_url: str) -> None:
169+
texts = ["foo", "bar", "baz"]
170+
embedding = FakeEmbeddings()
171+
172+
docsearch = Weaviate.from_texts(
173+
texts, embedding=embedding, weaviate_url=weaviate_url
174+
)
175+
176+
docsearch.add_texts(["foo"])
177+
output = docsearch.similarity_search_by_vector(
178+
embedding.embed_query("foo"), k=2
179+
)
180+
assert output == [
181+
Document(page_content="foo"),
182+
Document(page_content="foo"),
183+
]

0 commit comments

Comments
 (0)