diff --git a/integrations/weaviate/pydoc/config.yml b/integrations/weaviate/pydoc/config.yml
index fa59e6874..84334c2e6 100644
--- a/integrations/weaviate/pydoc/config.yml
+++ b/integrations/weaviate/pydoc/config.yml
@@ -1,9 +1,12 @@
 loaders:
   - type: haystack_pydoc_tools.loaders.CustomPythonLoader
     search_path: [../src]
-    modules: [
-      "haystack_integrations.document_stores.weaviate.document_store",
-    ]
+    modules:
+      [
+        "haystack_integrations.document_stores.weaviate.document_store",
+        "haystack_integrations.components.retrievers.weaviate.bm25_retriever",
+        "haystack_integrations.components.retrievers.weaviate.embedding_retriever",
+      ]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py
new file mode 100644
index 000000000..34bfd0c7d
--- /dev/null
+++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/__init__.py
@@ -0,0 +1,4 @@
+from .bm25_retriever import WeaviateBM25Retriever
+from .embedding_retriever import WeaviateEmbeddingRetriever
+
+__all__ = ["WeaviateBM25Retriever", "WeaviateEmbeddingRetriever"]
diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py
new file mode 100644
index 000000000..6c27378cf
--- /dev/null
+++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py
@@ -0,0 +1,68 @@
+from typing import Any, Dict, List, Optional
+
+from haystack import Document, component, default_from_dict, default_to_dict
+from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
+
+
+@component
+class WeaviateBM25Retriever:
+    """
+    Retriever that uses BM25 to find the most promising documents for a given query.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_store: WeaviateDocumentStore,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+    ):
+        """
+        Create a new instance of WeaviateBM25Retriever.
+
+        :param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
+        :param filters: Custom filters applied when running the retriever, defaults to None
+        :param top_k: Maximum number of documents to return, defaults to 10
+        """
+        self._document_store = document_store
+        self._filters = filters or {}
+        self._top_k = top_k
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this retriever, including its document store, to a dictionary.
+        """
+        return default_to_dict(
+            self,
+            filters=self._filters,
+            top_k=self._top_k,
+            document_store=self._document_store.to_dict(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever":
+        """
+        Rebuild a retriever from a dictionary produced by `to_dict`.
+
+        The nested document store entry of `data` is replaced in place by the deserialized store.
+        """
+        data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
+            data["init_parameters"]["document_store"]
+        )
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
+        """
+        Retrieve documents matching `query` with BM25.
+
+        :param query: Text to search for.
+        :param filters: Filters for this call; falls back to the filters given at init when empty or None.
+        :param top_k: Maximum number of documents for this call; falls back to the init value when 0 or None.
+        :returns: Dictionary with the retrieved documents under the `documents` key.
+        """
+        filters = filters or self._filters
+        top_k = top_k or self._top_k
+        documents = self._document_store._bm25_retrieval(query=query, filters=filters, top_k=top_k)
+        # Components must return a dict keyed by the names declared in `output_types`.
+        return {"documents": documents}
diff --git a/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py
new file mode 100644
index 000000000..b8a163b56
--- /dev/null
+++ b/integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py
@@ -0,0 +1,101 @@
+from typing import Any, Dict, List, Optional
+
+from haystack import Document, component, default_from_dict, default_to_dict
+from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
+
+
+@component
+class WeaviateEmbeddingRetriever:
+    """
+    A retriever that uses Weaviate's vector search to find similar documents based on the embeddings of the query.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_store: WeaviateDocumentStore,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+        distance: Optional[float] = None,
+        certainty: Optional[float] = None,
+    ):
+        """
+        Create a new instance of WeaviateEmbeddingRetriever.
+        Raises ValueError if both `distance` and `certainty` are provided.
+        See the official Weaviate documentation to learn more about the `distance` and `certainty` parameters:
+        https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables
+
+        :param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
+        :param filters: Custom filters applied when running the retriever, defaults to None
+        :param top_k: Maximum number of documents to return, defaults to 10
+        :param distance: The maximum allowed distance between Documents' embeddings, defaults to None
+        :param certainty: Normalized distance between the result item and the search vector, defaults to None
+        """
+        if distance is not None and certainty is not None:
+            msg = "Can't use 'distance' and 'certainty' parameters together"
+            raise ValueError(msg)
+
+        self._document_store = document_store
+        self._filters = filters or {}
+        self._top_k = top_k
+        self._distance = distance
+        self._certainty = certainty
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this retriever, including its document store, to a dictionary.
+        """
+        return default_to_dict(
+            self,
+            filters=self._filters,
+            top_k=self._top_k,
+            distance=self._distance,
+            certainty=self._certainty,
+            document_store=self._document_store.to_dict(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "WeaviateEmbeddingRetriever":
+        """
+        Rebuild a retriever from a dictionary produced by `to_dict`.
+
+        The nested document store entry of `data` is replaced in place by the deserialized store.
+        """
+        data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
+            data["init_parameters"]["document_store"]
+        )
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(
+        self,
+        query_embedding: List[float],
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: Optional[int] = None,
+        distance: Optional[float] = None,
+        certainty: Optional[float] = None,
+    ):
+        """
+        Retrieve the documents whose embeddings are closest to `query_embedding`.
+
+        :param query_embedding: Embedding of the query.
+        :param filters: Filters for this call; falls back to the filters given at init when empty or None.
+        :param top_k: Maximum number of documents for this call; falls back to the init value when 0 or None.
+        :param distance: Distance threshold for this call; falls back to the init value only when None.
+        :param certainty: Certainty threshold for this call; falls back to the init value only when None.
+        :returns: Dictionary with the retrieved documents under the `documents` key.
+        """
+        filters = filters or self._filters
+        top_k = top_k or self._top_k
+        # Compare against None: 0.0 is a meaningful threshold and must not be replaced by the default.
+        distance = distance if distance is not None else self._distance
+        certainty = certainty if certainty is not None else self._certainty
+        documents = self._document_store._embedding_retrieval(
+            query_embedding=query_embedding,
+            filters=filters,
+            top_k=top_k,
+            distance=distance,
+            certainty=certainty,
+        )
+        # Components must return a dict keyed by the names declared in `output_types`.
+        return {"documents": documents}
diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
index b7aba3716..38f0b38cd 100644
--- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
+++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py
@@ -419,3 +419,85 @@ def delete_documents(self, document_ids: List[str]) -> None:
                 "valueTextArray": [generate_uuid5(doc_id) for doc_id in document_ids],
             },
         )
+
+    def _bm25_retrieval(
+        self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
+    ) -> List[Document]:
+        """
+        Run a BM25 keyword search over the `content` property of the collection.
+
+        :param query: Text to search for.
+        :param filters: Filters applied to the search; skipped when empty or None.
+        :param top_k: Maximum number of documents to return; no limit is set when 0 or None.
+        :returns: Matching documents converted with `_to_document`.
+        """
+        collection_name = self._collection_settings["class"]
+        properties = self._client.schema.get(collection_name).get("properties", [])
+        properties = [prop["name"] for prop in properties]
+
+        query_builder = (
+            self._client.query.get(collection_name, properties=properties)
+            .with_bm25(query=query, properties=["content"])
+            .with_additional(["vector"])
+        )
+
+        if filters:
+            query_builder = query_builder.with_where(convert_filters(filters))
+
+        if top_k:
+            query_builder = query_builder.with_limit(top_k)
+
+        result = query_builder.do()
+
+        return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]
+
+    def _embedding_retrieval(
+        self,
+        query_embedding: List[float],
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: Optional[int] = None,
+        distance: Optional[float] = None,
+        certainty: Optional[float] = None,
+    ) -> List[Document]:
+        """
+        Run a near-vector search for `query_embedding` in the collection.
+
+        :param query_embedding: Vector to search with.
+        :param filters: Filters applied to the search; skipped when empty or None.
+        :param top_k: Maximum number of documents to return; no limit is set when 0 or None.
+        :param distance: Forwarded as the `distance` of the near-vector search when not None.
+        :param certainty: Forwarded as the `certainty` of the near-vector search when not None.
+        :returns: Matching documents converted with `_to_document`.
+        :raises ValueError: If both `distance` and `certainty` are provided.
+        """
+        if distance is not None and certainty is not None:
+            msg = "Can't use 'distance' and 'certainty' parameters together"
+            raise ValueError(msg)
+
+        collection_name = self._collection_settings["class"]
+        properties = self._client.schema.get(collection_name).get("properties", [])
+        properties = [prop["name"] for prop in properties]
+
+        near_vector: Dict[str, Union[float, List[float]]] = {
+            "vector": query_embedding,
+        }
+        if distance is not None:
+            near_vector["distance"] = distance
+
+        if certainty is not None:
+            near_vector["certainty"] = certainty
+
+        query_builder = (
+            self._client.query.get(collection_name, properties=properties)
+            .with_near_vector(near_vector)
+            .with_additional(["vector"])
+        )
+
+        if filters:
+            query_builder = query_builder.with_where(convert_filters(filters))
+
+        if top_k:
+            query_builder = query_builder.with_limit(top_k)
+
+        result = query_builder.do()
+        return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]
diff --git a/integrations/weaviate/tests/test_bm25_retriever.py b/integrations/weaviate/tests/test_bm25_retriever.py
new file mode 100644
index 000000000..83f90735b
--- /dev/null
+++ b/integrations/weaviate/tests/test_bm25_retriever.py
@@ -0,0 +1,103 @@
+from unittest.mock import Mock, patch
+
+from haystack_integrations.components.retrievers.weaviate import WeaviateBM25Retriever
+from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
+
+
+def test_init_default():
+    mock_document_store = Mock(spec=WeaviateDocumentStore)
+    retriever = WeaviateBM25Retriever(document_store=mock_document_store)
+    assert retriever._document_store == mock_document_store
+    assert retriever._filters == {}
+    assert retriever._top_k == 10
+
+
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
+def test_to_dict(_mock_weaviate):
+    document_store = WeaviateDocumentStore()
+    retriever = WeaviateBM25Retriever(document_store=document_store)
+    assert retriever.to_dict() == {
+        "type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever",
+        "init_parameters": {
+            "filters": {},
+            "top_k": 10,
+            "document_store": {
+                "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
+                "init_parameters": {
+                    "url": None,
+                    "collection_settings": {
+                        "class": "Default",
+                        "invertedIndexConfig": {"indexNullState": True},
+                        "properties": [
+                            {"name": "_original_id", "dataType": ["text"]},
+                            {"name": "content", "dataType": ["text"]},
+                            {"name": "dataframe", "dataType": ["text"]},
+                            {"name": "blob_data", "dataType": ["blob"]},
+                            {"name": "blob_mime_type", "dataType": ["text"]},
+                            {"name": "score", "dataType": ["number"]},
+                        ],
+                    },
+                    "auth_client_secret": None,
+                    "timeout_config": (10, 60),
+                    "proxies": None,
+                    "trust_env": False,
+                    "additional_headers": None,
+                    "startup_period": 5,
+                    "embedded_options": None,
+                    "additional_config": None,
+                },
+            },
+        },
+    }
+
+
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
+def test_from_dict(_mock_weaviate):
+    retriever = WeaviateBM25Retriever.from_dict(
+        {
+            "type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever",
+            "init_parameters": {
+                "filters": {},
+                "top_k": 10,
+                "document_store": {
+                    "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
+                    "init_parameters": {
+                        "url": None,
+                        "collection_settings": {
+                            "class": "Default",
+                            "invertedIndexConfig": {"indexNullState": True},
+                            "properties": [
+                                {"name": "_original_id", "dataType": ["text"]},
+                                {"name": "content", "dataType": ["text"]},
+                                {"name": "dataframe", "dataType": ["text"]},
+                                {"name": "blob_data", "dataType": ["blob"]},
+                                {"name": "blob_mime_type", "dataType": ["text"]},
+                                {"name": "score", "dataType": ["number"]},
+                            ],
+                        },
+                        "auth_client_secret": None,
+                        "timeout_config": (10, 60),
+                        "proxies": None,
+                        "trust_env": False,
+                        "additional_headers": None,
+                        "startup_period": 5,
+                        "embedded_options": None,
+                        "additional_config": None,
+                    },
+                },
+            },
+        }
+    )
+    assert retriever._document_store
+    assert retriever._filters == {}
+    assert retriever._top_k == 10
+
+
+@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore")
+def test_run(mock_document_store):
+    retriever = WeaviateBM25Retriever(document_store=mock_document_store)
+    query = "some query"
+    filters = {"field": "content", "operator": "==", "value": "Some text"}
+    result = retriever.run(query=query, filters=filters, top_k=5)
+    mock_document_store._bm25_retrieval.assert_called_once_with(query=query, filters=filters, top_k=5)
+    assert result == {"documents": mock_document_store._bm25_retrieval.return_value}
diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py
index 2322a9484..359af3670 100644
--- a/integrations/weaviate/tests/test_document_store.py
+++ b/integrations/weaviate/tests/test_document_store.py
@@ -481,3 +481,141 @@ def test_comparison_less_than_equal_with_iso_date(self, document_store, filterab
     @pytest.mark.skip(reason="Weaviate for some reason is not returning what we expect")
     def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs):
         return super().test_comparison_not_equal_with_dataframe(document_store, filterable_docs)
+
+    def test_bm25_retrieval(self, document_store):
+        document_store.write_documents(
+            [
+                Document(content="Haskell is a functional programming language"),
+                Document(content="Lisp is a functional programming language"),
+                Document(content="Exilir is a functional programming language"),
+                Document(content="F# is a functional programming language"),
+                Document(content="C# is a functional programming language"),
+                Document(content="C++ is an object oriented programming language"),
+                Document(content="Dart is an object oriented programming language"),
+                Document(content="Go is an object oriented programming language"),
+                Document(content="Python is a object oriented programming language"),
+                Document(content="Ruby is a object oriented programming language"),
+                Document(content="PHP is a object oriented programming language"),
+            ]
+        )
+        result = document_store._bm25_retrieval("functional Haskell")
+        assert len(result) == 5
+        assert "functional" in result[0].content
+        assert "functional" in result[1].content
+        assert "functional" in result[2].content
+        assert "functional" in result[3].content
+        assert "functional" in result[4].content
+
+    def test_bm25_retrieval_with_filters(self, document_store):
+        document_store.write_documents(
+            [
+                Document(content="Haskell is a functional programming language"),
+                Document(content="Lisp is a functional programming language"),
+                Document(content="Exilir is a functional programming language"),
+                Document(content="F# is a functional programming language"),
+                Document(content="C# is a functional programming language"),
+                Document(content="C++ is an object oriented programming language"),
+                Document(content="Dart is an object oriented programming language"),
+                Document(content="Go is an object oriented programming language"),
+                Document(content="Python is a object oriented programming language"),
+                Document(content="Ruby is a object oriented programming language"),
+                Document(content="PHP is a object oriented programming language"),
+            ]
+        )
+        filters = {"field": "content", "operator": "==", "value": "Haskell"}
+        result = document_store._bm25_retrieval("functional Haskell", filters=filters)
+        assert len(result) == 1
+        assert "Haskell is a functional programming language" == result[0].content
+
+    def test_bm25_retrieval_with_topk(self, document_store):
+        document_store.write_documents(
+            [
+                Document(content="Haskell is a functional programming language"),
+                Document(content="Lisp is a functional programming language"),
+                Document(content="Exilir is a functional programming language"),
+                Document(content="F# is a functional programming language"),
+                Document(content="C# is a functional programming language"),
+                Document(content="C++ is an object oriented programming language"),
+                Document(content="Dart is an object oriented programming language"),
+                Document(content="Go is an object oriented programming language"),
+                Document(content="Python is a object oriented programming language"),
+                Document(content="Ruby is a object oriented programming language"),
+                Document(content="PHP is a object oriented programming language"),
+            ]
+        )
+        result = document_store._bm25_retrieval("functional Haskell", top_k=3)
+        assert len(result) == 3
+        assert "functional" in result[0].content
+        assert "functional" in result[1].content
+        assert "functional" in result[2].content
+
+    def test_embedding_retrieval(self, document_store):
+        document_store.write_documents(
+            [
+                Document(
+                    content="Yet another document",
+                    embedding=[0.00001, 0.00001, 0.00001, 0.00002],
+                ),
+                Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
+                Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            ]
+        )
+        result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0])
+        assert len(result) == 3
+        assert "The document" == result[0].content
+        assert "Another document" == result[1].content
+        assert "Yet another document" == result[2].content
+
+    def test_embedding_retrieval_with_filters(self, document_store):
+        document_store.write_documents(
+            [
+                Document(
+                    content="Yet another document",
+                    embedding=[0.00001, 0.00001, 0.00001, 0.00002],
+                ),
+                Document(content="The document I want", embedding=[1.0, 1.0, 1.0, 1.0]),
+                Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            ]
+        )
+        filters = {"field": "content", "operator": "==", "value": "The document I want"}
+        result = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], filters=filters)
+        assert len(result) == 1
+        assert "The document I want" == result[0].content
+
+    def test_embedding_retrieval_with_topk(self, document_store):
+        docs = [
+            Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
+            Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
+        ]
+        document_store.write_documents(docs)
+        results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], top_k=2)
+        assert len(results) == 2
+        assert results[0].content == "The document"
+        assert results[1].content == "Another document"
+
+    def test_embedding_retrieval_with_distance(self, document_store):
+        docs = [
+            Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
+            Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
+        ]
+        document_store.write_documents(docs)
+        results = document_store._embedding_retrieval(query_embedding=[1.0, 1.0, 1.0, 1.0], distance=0.0)
+        assert len(results) == 1
+        assert results[0].content == "The document"
+
+    def test_embedding_retrieval_with_certainty(self, document_store):
+        docs = [
+            Document(content="The document", embedding=[1.0, 1.0, 1.0, 1.0]),
+            Document(content="Another document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            Document(content="Yet another document", embedding=[0.00001, 0.00001, 0.00001, 0.00002]),
+        ]
+        document_store.write_documents(docs)
+        results = document_store._embedding_retrieval(query_embedding=[0.8, 0.8, 0.8, 1.0], certainty=1.0)
+        assert len(results) == 1
+        assert results[0].content == "Another document"
+
+    def test_embedding_retrieval_with_distance_and_certainty(self, document_store):
+        with pytest.raises(ValueError):
+            document_store._embedding_retrieval(query_embedding=[], distance=0.1, certainty=0.1)
diff --git a/integrations/weaviate/tests/test_embedding_retriever.py b/integrations/weaviate/tests/test_embedding_retriever.py
new file mode 100644
index 000000000..7f07d8a24
--- /dev/null
+++ b/integrations/weaviate/tests/test_embedding_retriever.py
@@ -0,0 +1,129 @@
+from unittest.mock import Mock, patch
+
+import pytest
+from haystack_integrations.components.retrievers.weaviate import WeaviateEmbeddingRetriever
+from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore
+
+
+def test_init_default():
+    mock_document_store = Mock(spec=WeaviateDocumentStore)
+    retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store)
+    assert retriever._document_store == mock_document_store
+    assert retriever._filters == {}
+    assert retriever._top_k == 10
+    assert retriever._distance is None
+    assert retriever._certainty is None
+
+
+def test_init_with_distance_and_certainty():
+    mock_document_store = Mock(spec=WeaviateDocumentStore)
+    with pytest.raises(ValueError):
+        WeaviateEmbeddingRetriever(document_store=mock_document_store, distance=0.1, certainty=0.8)
+
+
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
+def test_to_dict(_mock_weaviate):
+    document_store = WeaviateDocumentStore()
+    retriever = WeaviateEmbeddingRetriever(document_store=document_store)
+    assert retriever.to_dict() == {
+        "type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever",
+        "init_parameters": {
+            "filters": {},
+            "top_k": 10,
+            "distance": None,
+            "certainty": None,
+            "document_store": {
+                "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
+                "init_parameters": {
+                    "url": None,
+                    "collection_settings": {
+                        "class": "Default",
+                        "invertedIndexConfig": {"indexNullState": True},
+                        "properties": [
+                            {"name": "_original_id", "dataType": ["text"]},
+                            {"name": "content", "dataType": ["text"]},
+                            {"name": "dataframe", "dataType": ["text"]},
+                            {"name": "blob_data", "dataType": ["blob"]},
+                            {"name": "blob_mime_type", "dataType": ["text"]},
+                            {"name": "score", "dataType": ["number"]},
+                        ],
+                    },
+                    "auth_client_secret": None,
+                    "timeout_config": (10, 60),
+                    "proxies": None,
+                    "trust_env": False,
+                    "additional_headers": None,
+                    "startup_period": 5,
+                    "embedded_options": None,
+                    "additional_config": None,
+                },
+            },
+        },
+    }
+
+
+@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
+def test_from_dict(_mock_weaviate):
+    retriever = WeaviateEmbeddingRetriever.from_dict(
+        {
+            "type": "haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateEmbeddingRetriever",  # noqa: E501
+            "init_parameters": {
+                "filters": {},
+                "top_k": 10,
+                "distance": None,
+                "certainty": None,
+                "document_store": {
+                    "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
+                    "init_parameters": {
+                        "url": None,
+                        "collection_settings": {
+                            "class": "Default",
+                            "invertedIndexConfig": {"indexNullState": True},
+                            "properties": [
+                                {"name": "_original_id", "dataType": ["text"]},
+                                {"name": "content", "dataType": ["text"]},
+                                {"name": "dataframe", "dataType": ["text"]},
+                                {"name": "blob_data", "dataType": ["blob"]},
+                                {"name": "blob_mime_type", "dataType": ["text"]},
+                                {"name": "score", "dataType": ["number"]},
+                            ],
+                        },
+                        "auth_client_secret": None,
+                        "timeout_config": (10, 60),
+                        "proxies": None,
+                        "trust_env": False,
+                        "additional_headers": None,
+                        "startup_period": 5,
+                        "embedded_options": None,
+                        "additional_config": None,
+                    },
+                },
+            },
+        }
+    )
+    assert retriever._document_store
+    assert retriever._filters == {}
+    assert retriever._top_k == 10
+    assert retriever._distance is None
+    assert retriever._certainty is None
+
+
+@patch("haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateDocumentStore")
+def test_run(mock_document_store):
+    retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store)
+    query_embedding = [0.1, 0.1, 0.1, 0.1]
+    filters = {"field": "content", "operator": "==", "value": "Some text"}
+    result = retriever.run(query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1)
+    mock_document_store._embedding_retrieval.assert_called_once_with(
+        query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1
+    )
+    assert result == {"documents": mock_document_store._embedding_retrieval.return_value}
+
+
+@patch("haystack_integrations.components.retrievers.weaviate.embedding_retriever.WeaviateDocumentStore")
+def test_run_keeps_zero_distance(mock_document_store):
+    retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store, distance=0.5)
+    retriever.run(query_embedding=[0.1, 0.1], distance=0.0)
+    mock_document_store._embedding_retrieval.assert_called_once_with(
+        query_embedding=[0.1, 0.1], filters={}, top_k=10, distance=0.0, certainty=None
+    )