-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pgvector - Embedding Retriever (#320)
* squash * squash * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi <[email protected]> * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi <[email protected]> * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi <[email protected]> * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi <[email protected]> * fix fmt * adjust docstrings * Update integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py Co-authored-by: Massimiliano Pippi <[email protected]> * Update integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py Co-authored-by: Massimiliano Pippi <[email protected]> * improve docstrings * fmt --------- Co-authored-by: Massimiliano Pippi <[email protected]>
- Loading branch information
Showing
3 changed files
with
222 additions
and
0 deletions.
There are no files selected for viewing
6 changes: 6 additions & 0 deletions
6
integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from .embedding_retriever import PgvectorEmbeddingRetriever | ||
|
||
__all__ = ["PgvectorEmbeddingRetriever"] |
104 changes: 104 additions & 0 deletions
104
.../pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import Any, Dict, List, Literal, Optional | ||
|
||
from haystack import component, default_from_dict, default_to_dict | ||
from haystack.dataclasses import Document | ||
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore | ||
from haystack_integrations.document_stores.pgvector.document_store import VALID_VECTOR_FUNCTIONS | ||
|
||
|
||
@component | ||
class PgvectorEmbeddingRetriever: | ||
""" | ||
Retrieves documents from the PgvectorDocumentStore, based on their dense embeddings. | ||
Needs to be connected to the PgvectorDocumentStore. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
document_store: PgvectorDocumentStore, | ||
filters: Optional[Dict[str, Any]] = None, | ||
top_k: int = 10, | ||
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, | ||
): | ||
""" | ||
Create the PgvectorEmbeddingRetriever component. | ||
:param document_store: An instance of PgvectorDocumentStore. | ||
:param filters: Filters applied to the retrieved Documents. Defaults to None. | ||
:param top_k: Maximum number of Documents to return, defaults to 10. | ||
:param vector_function: The similarity function to use when searching for similar embeddings. | ||
Defaults to the one set in the `document_store` instance. | ||
"cosine_similarity" and "inner_product" are similarity functions and | ||
higher scores indicate greater similarity between the documents. | ||
"l2_distance" returns the straight-line distance between vectors, | ||
and the most similar documents are the ones with the smallest score. | ||
Important: if the document store is using the "hnsw" search strategy, the vector function | ||
should match the one utilized during index creation to take advantage of the index. | ||
:type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] | ||
:raises ValueError: If `document_store` is not an instance of PgvectorDocumentStore. | ||
""" | ||
if not isinstance(document_store, PgvectorDocumentStore): | ||
msg = "document_store must be an instance of PgvectorDocumentStore" | ||
raise ValueError(msg) | ||
|
||
if vector_function and vector_function not in VALID_VECTOR_FUNCTIONS: | ||
msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}" | ||
raise ValueError(msg) | ||
|
||
self.document_store = document_store | ||
self.filters = filters or {} | ||
self.top_k = top_k | ||
self.vector_function = vector_function or document_store.vector_function | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
return default_to_dict( | ||
self, | ||
filters=self.filters, | ||
top_k=self.top_k, | ||
vector_function=self.vector_function, | ||
document_store=self.document_store.to_dict(), | ||
) | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever": | ||
data["init_parameters"]["document_store"] = default_from_dict( | ||
PgvectorDocumentStore, data["init_parameters"]["document_store"] | ||
) | ||
return default_from_dict(cls, data) | ||
|
||
@component.output_types(documents=List[Document]) | ||
def run( | ||
self, | ||
query_embedding: List[float], | ||
filters: Optional[Dict[str, Any]] = None, | ||
top_k: Optional[int] = None, | ||
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, | ||
): | ||
""" | ||
Retrieve documents from the PgvectorDocumentStore, based on their embeddings. | ||
:param query_embedding: Embedding of the query. | ||
:param filters: Filters applied to the retrieved Documents. | ||
:param top_k: Maximum number of Documents to return. | ||
:param vector_function: The similarity function to use when searching for similar embeddings. | ||
:type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] | ||
:return: List of Documents similar to `query_embedding`. | ||
""" | ||
filters = filters or self.filters | ||
top_k = top_k or self.top_k | ||
vector_function = vector_function or self.vector_function | ||
|
||
docs = self.document_store._embedding_retrieval( | ||
query_embedding=query_embedding, | ||
filters=filters, | ||
top_k=top_k, | ||
vector_function=vector_function, | ||
) | ||
return {"documents": docs} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from unittest.mock import Mock | ||
|
||
from haystack.dataclasses import Document | ||
from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever | ||
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore | ||
|
||
|
||
class TestRetriever: | ||
def test_init_default(self, document_store: PgvectorDocumentStore): | ||
retriever = PgvectorEmbeddingRetriever(document_store=document_store) | ||
assert retriever.document_store == document_store | ||
assert retriever.filters == {} | ||
assert retriever.top_k == 10 | ||
assert retriever.vector_function == document_store.vector_function | ||
|
||
def test_init(self, document_store: PgvectorDocumentStore): | ||
retriever = PgvectorEmbeddingRetriever( | ||
document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" | ||
) | ||
assert retriever.document_store == document_store | ||
assert retriever.filters == {"field": "value"} | ||
assert retriever.top_k == 5 | ||
assert retriever.vector_function == "l2_distance" | ||
|
||
def test_to_dict(self, document_store: PgvectorDocumentStore): | ||
retriever = PgvectorEmbeddingRetriever( | ||
document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" | ||
) | ||
res = retriever.to_dict() | ||
t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" | ||
assert res == { | ||
"type": t, | ||
"init_parameters": { | ||
"document_store": { | ||
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", | ||
"init_parameters": { | ||
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", | ||
"table_name": "haystack_test_to_dict", | ||
"embedding_dimension": 768, | ||
"vector_function": "cosine_similarity", | ||
"recreate_table": True, | ||
"search_strategy": "exact_nearest_neighbor", | ||
"hnsw_recreate_index_if_exists": False, | ||
"hnsw_index_creation_kwargs": {}, | ||
"hnsw_ef_search": None, | ||
}, | ||
}, | ||
"filters": {"field": "value"}, | ||
"top_k": 5, | ||
"vector_function": "l2_distance", | ||
}, | ||
} | ||
|
||
def test_from_dict(self): | ||
t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" | ||
data = { | ||
"type": t, | ||
"init_parameters": { | ||
"document_store": { | ||
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", | ||
"init_parameters": { | ||
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", | ||
"table_name": "haystack_test_to_dict", | ||
"embedding_dimension": 768, | ||
"vector_function": "cosine_similarity", | ||
"recreate_table": True, | ||
"search_strategy": "exact_nearest_neighbor", | ||
"hnsw_recreate_index_if_exists": False, | ||
"hnsw_index_creation_kwargs": {}, | ||
"hnsw_ef_search": None, | ||
}, | ||
}, | ||
"filters": {"field": "value"}, | ||
"top_k": 5, | ||
"vector_function": "l2_distance", | ||
}, | ||
} | ||
|
||
retriever = PgvectorEmbeddingRetriever.from_dict(data) | ||
document_store = retriever.document_store | ||
|
||
assert isinstance(document_store, PgvectorDocumentStore) | ||
assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres" | ||
assert document_store.table_name == "haystack_test_to_dict" | ||
assert document_store.embedding_dimension == 768 | ||
assert document_store.vector_function == "cosine_similarity" | ||
assert document_store.recreate_table | ||
assert document_store.search_strategy == "exact_nearest_neighbor" | ||
assert not document_store.hnsw_recreate_index_if_exists | ||
assert document_store.hnsw_index_creation_kwargs == {} | ||
assert document_store.hnsw_ef_search is None | ||
|
||
assert retriever.filters == {"field": "value"} | ||
assert retriever.top_k == 5 | ||
assert retriever.vector_function == "l2_distance" | ||
|
||
def test_run(self): | ||
mock_store = Mock(spec=PgvectorDocumentStore) | ||
doc = Document(content="Test doc", embedding=[0.1, 0.2]) | ||
mock_store._embedding_retrieval.return_value = [doc] | ||
|
||
retriever = PgvectorEmbeddingRetriever(document_store=mock_store, vector_function="l2_distance") | ||
res = retriever.run(query_embedding=[0.3, 0.5]) | ||
|
||
mock_store._embedding_retrieval.assert_called_once_with( | ||
query_embedding=[0.3, 0.5], filters={}, top_k=10, vector_function="l2_distance" | ||
) | ||
|
||
assert res == {"documents": [doc]} |