From 3454815095b539558cdda083c6d51f76ed2b12ea Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Thu, 1 Feb 2024 17:01:26 +0100
Subject: [PATCH] Pgvector - Embedding Retriever (#320)

* squash

* squash

* Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py

Co-authored-by: Massimiliano Pippi

* Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py

Co-authored-by: Massimiliano Pippi

* Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py

Co-authored-by: Massimiliano Pippi

* Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py

Co-authored-by: Massimiliano Pippi

* fix fmt

* adjust docstrings

* Update integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py

Co-authored-by: Massimiliano Pippi

* Update integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py

Co-authored-by: Massimiliano Pippi

* improve docstrings

* fmt

---------

Co-authored-by: Massimiliano Pippi
---
Notes (review):
    Purpose: this patch creates three new files (see the diffstat below):
    a package __init__.py that re-exports PgvectorEmbeddingRetriever, the
    retriever component itself, and a test module for it.

    PgvectorEmbeddingRetriever.__init__
      * Keyword-only arguments: document_store, filters, top_k, vector_function.
      * Raises ValueError when document_store is not a PgvectorDocumentStore
        instance, or when a truthy vector_function is not in
        VALID_VECTOR_FUNCTIONS.
      * Stores `filters or {}` and
        `vector_function or document_store.vector_function`, so an omitted
        vector_function is taken from the document store.

    PgvectorEmbeddingRetriever.to_dict / from_dict
      * to_dict() serializes filters, top_k, vector_function and the nested
        document_store.to_dict() through default_to_dict.
      * from_dict() rebuilds the nested store with
        default_from_dict(PgvectorDocumentStore, ...) and assigns it back into
        data["init_parameters"]["document_store"], i.e. the dict passed by the
        caller is modified in place before default_from_dict(cls, data) runs.
        NOTE(review): callers that reuse `data` afterwards will see a store
        object instead of the serialized dict - confirm this is acceptable.

    PgvectorEmbeddingRetriever.run
      * Returns {"documents": docs}, where docs is whatever
        document_store._embedding_retrieval(...) returns.
      * Per-call filters, top_k and vector_function are resolved with `or`
        against the values stored by __init__. Any falsy per-call value
        (None, {}, 0) therefore falls back to the configured one.
        NOTE(review): this means a caller cannot pass empty filters or
        top_k=0 at run time - confirm this is intended.
      * The per-call vector_function is not validated in run(); presumably the
        document store validates it - verify against _embedding_retrieval.

    Tests
      * test_init_default, test_init and test_to_dict take a `document_store`
        fixture that is not defined in this patch; presumably it is provided
        by a conftest.py elsewhere - verify.
      * test_run uses Mock(spec=PgvectorDocumentStore) and asserts that
        _embedding_retrieval is called once with filters={}, top_k=10 and
        vector_function="l2_distance".

    Formatting: the hunks below are laid out one source line per patch line,
    so that each hunk has the number of added lines declared in its
    `@@ -0,0 +1,N @@` header (6, 104 and 112).

 .../retrievers/pgvector/__init__.py           |   6 +
 .../pgvector/embedding_retriever.py           | 104 ++++++++++++++++
 integrations/pgvector/tests/test_retriever.py | 112 ++++++++++++++++++
 3 files changed, 222 insertions(+)
 create mode 100644 integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py
 create mode 100644 integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py
 create mode 100644 integrations/pgvector/tests/test_retriever.py

diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py
new file mode 100644
index 000000000..ec0cf0dc4
--- /dev/null
+++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py
@@ -0,0 +1,6 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from .embedding_retriever import PgvectorEmbeddingRetriever
+
+__all__ = ["PgvectorEmbeddingRetriever"]
diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py
new file mode 100644
index 000000000..26807e9bd
--- /dev/null
+++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py
@@ -0,0 +1,104 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict, List, Literal, Optional
+
+from haystack import component, default_from_dict, default_to_dict
+from haystack.dataclasses import Document
+from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
+from haystack_integrations.document_stores.pgvector.document_store import VALID_VECTOR_FUNCTIONS
+
+
+@component
+class PgvectorEmbeddingRetriever:
+    """
+    Retrieves documents from the PgvectorDocumentStore, based on their dense embeddings.
+
+    Needs to be connected to the PgvectorDocumentStore.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_store: PgvectorDocumentStore,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+        vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
+    ):
+        """
+        Create the PgvectorEmbeddingRetriever component.
+
+        :param document_store: An instance of PgvectorDocumentStore.
+        :param filters: Filters applied to the retrieved Documents. Defaults to None.
+        :param top_k: Maximum number of Documents to return, defaults to 10.
+        :param vector_function: The similarity function to use when searching for similar embeddings.
+            Defaults to the one set in the `document_store` instance.
+            "cosine_similarity" and "inner_product" are similarity functions and
+            higher scores indicate greater similarity between the documents.
+            "l2_distance" returns the straight-line distance between vectors,
+            and the most similar documents are the ones with the smallest score.
+
+            Important: if the document store is using the "hnsw" search strategy, the vector function
+            should match the one utilized during index creation to take advantage of the index.
+        :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"]
+
+        :raises ValueError: If `document_store` is not an instance of PgvectorDocumentStore.
+        """
+        if not isinstance(document_store, PgvectorDocumentStore):
+            msg = "document_store must be an instance of PgvectorDocumentStore"
+            raise ValueError(msg)
+
+        if vector_function and vector_function not in VALID_VECTOR_FUNCTIONS:
+            msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}"
+            raise ValueError(msg)
+
+        self.document_store = document_store
+        self.filters = filters or {}
+        self.top_k = top_k
+        self.vector_function = vector_function or document_store.vector_function
+
+    def to_dict(self) -> Dict[str, Any]:
+        return default_to_dict(
+            self,
+            filters=self.filters,
+            top_k=self.top_k,
+            vector_function=self.vector_function,
+            document_store=self.document_store.to_dict(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever":
+        data["init_parameters"]["document_store"] = default_from_dict(
+            PgvectorDocumentStore, data["init_parameters"]["document_store"]
+        )
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(
+        self,
+        query_embedding: List[float],
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: Optional[int] = None,
+        vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
+    ):
+        """
+        Retrieve documents from the PgvectorDocumentStore, based on their embeddings.
+
+        :param query_embedding: Embedding of the query.
+        :param filters: Filters applied to the retrieved Documents.
+        :param top_k: Maximum number of Documents to return.
+        :param vector_function: The similarity function to use when searching for similar embeddings.
+        :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"]
+        :return: List of Documents similar to `query_embedding`.
+        """
+        filters = filters or self.filters
+        top_k = top_k or self.top_k
+        vector_function = vector_function or self.vector_function
+
+        docs = self.document_store._embedding_retrieval(
+            query_embedding=query_embedding,
+            filters=filters,
+            top_k=top_k,
+            vector_function=vector_function,
+        )
+        return {"documents": docs}
diff --git a/integrations/pgvector/tests/test_retriever.py b/integrations/pgvector/tests/test_retriever.py
new file mode 100644
index 000000000..cca6bbc9f
--- /dev/null
+++ b/integrations/pgvector/tests/test_retriever.py
@@ -0,0 +1,112 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from unittest.mock import Mock
+
+from haystack.dataclasses import Document
+from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever
+from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
+
+
+class TestRetriever:
+    def test_init_default(self, document_store: PgvectorDocumentStore):
+        retriever = PgvectorEmbeddingRetriever(document_store=document_store)
+        assert retriever.document_store == document_store
+        assert retriever.filters == {}
+        assert retriever.top_k == 10
+        assert retriever.vector_function == document_store.vector_function
+
+    def test_init(self, document_store: PgvectorDocumentStore):
+        retriever = PgvectorEmbeddingRetriever(
+            document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance"
+        )
+        assert retriever.document_store == document_store
+        assert retriever.filters == {"field": "value"}
+        assert retriever.top_k == 5
+        assert retriever.vector_function == "l2_distance"
+
+    def test_to_dict(self, document_store: PgvectorDocumentStore):
+        retriever = PgvectorEmbeddingRetriever(
+            document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance"
+        )
+        res = retriever.to_dict()
+        t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever"
+        assert res == {
+            "type": t,
+            "init_parameters": {
+                "document_store": {
+                    "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
+                    "init_parameters": {
+                        "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
+                        "table_name": "haystack_test_to_dict",
+                        "embedding_dimension": 768,
+                        "vector_function": "cosine_similarity",
+                        "recreate_table": True,
+                        "search_strategy": "exact_nearest_neighbor",
+                        "hnsw_recreate_index_if_exists": False,
+                        "hnsw_index_creation_kwargs": {},
+                        "hnsw_ef_search": None,
+                    },
+                },
+                "filters": {"field": "value"},
+                "top_k": 5,
+                "vector_function": "l2_distance",
+            },
+        }
+
+    def test_from_dict(self):
+        t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever"
+        data = {
+            "type": t,
+            "init_parameters": {
+                "document_store": {
+                    "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
+                    "init_parameters": {
+                        "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
+                        "table_name": "haystack_test_to_dict",
+                        "embedding_dimension": 768,
+                        "vector_function": "cosine_similarity",
+                        "recreate_table": True,
+                        "search_strategy": "exact_nearest_neighbor",
+                        "hnsw_recreate_index_if_exists": False,
+                        "hnsw_index_creation_kwargs": {},
+                        "hnsw_ef_search": None,
+                    },
+                },
+                "filters": {"field": "value"},
+                "top_k": 5,
+                "vector_function": "l2_distance",
+            },
+        }
+
+        retriever = PgvectorEmbeddingRetriever.from_dict(data)
+        document_store = retriever.document_store
+
+        assert isinstance(document_store, PgvectorDocumentStore)
+        assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres"
+        assert document_store.table_name == "haystack_test_to_dict"
+        assert document_store.embedding_dimension == 768
+        assert document_store.vector_function == "cosine_similarity"
+        assert document_store.recreate_table
+        assert document_store.search_strategy == "exact_nearest_neighbor"
+        assert not document_store.hnsw_recreate_index_if_exists
+        assert document_store.hnsw_index_creation_kwargs == {}
+        assert document_store.hnsw_ef_search is None
+
+        assert retriever.filters == {"field": "value"}
+        assert retriever.top_k == 5
+        assert retriever.vector_function == "l2_distance"
+
+    def test_run(self):
+        mock_store = Mock(spec=PgvectorDocumentStore)
+        doc = Document(content="Test doc", embedding=[0.1, 0.2])
+        mock_store._embedding_retrieval.return_value = [doc]
+
+        retriever = PgvectorEmbeddingRetriever(document_store=mock_store, vector_function="l2_distance")
+        res = retriever.run(query_embedding=[0.3, 0.5])
+
+        mock_store._embedding_retrieval.assert_called_once_with(
+            query_embedding=[0.3, 0.5], filters={}, top_k=10, vector_function="l2_distance"
+        )
+
+        assert res == {"documents": [doc]}