Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add filter_policy to pinecone integration #821

Merged
merged 13 commits into from
Jul 5, 2024
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy

from haystack_integrations.document_stores.pinecone import PineconeDocumentStore

Expand Down Expand Up @@ -55,11 +57,13 @@ def __init__(
document_store: PineconeDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
):
"""
:param document_store: The Pinecone Document Store.
:param filters: Filters applied to the retrieved Documents.
:param top_k: Maximum number of Documents to return.
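:param filter_policy: Policy to determine how filters are applied. Accepts either a `FilterPolicy` member or
its string value (for example `"replace"`); defaults to `FilterPolicy.REPLACE`.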

:raises ValueError: If `document_store` is not an instance of `PineconeDocumentStore`.
"""
Expand All @@ -70,6 +74,9 @@ def __init__(
self.document_store = document_store
self.filters = filters or {}
self.top_k = top_k
self.filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -81,6 +88,7 @@ def to_dict(self) -> Dict[str, Any]:
self,
filters=self.filters,
top_k=self.top_k,
filter_policy=self.filter_policy.value,
document_store=self.document_store.to_dict(),
)

Expand All @@ -96,19 +104,34 @@ def from_dict(cls, data: Dict[str, Any]) -> "PineconeEmbeddingRetriever":
data["init_parameters"]["document_store"] = PineconeDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float]):
def run(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
):
"""
Retrieve documents from the `PineconeDocumentStore`, based on their dense embeddings.

:param query_embedding: Embedding of the query.
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
the `filter_policy` chosen at retriever initialization. See init method docstring for more
details.
:param top_k: Maximum number of `Document`s to return.

:returns: List of Documents similar to `query_embedding`.
"""
filters = apply_filter_policy(self.filter_policy, self.filters, filters)

top_k = top_k or self.top_k

docs = self.document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=self.filters,
top_k=self.top_k,
filters=filters,
top_k=top_k,
)
return {"documents": docs}
12 changes: 12 additions & 0 deletions integrations/pinecone/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch

import pytest
from haystack.dataclasses import Document
from haystack.document_stores.types import FilterPolicy
from haystack.utils import Secret

from haystack_integrations.components.retrievers.pinecone import PineconeEmbeddingRetriever
Expand All @@ -16,6 +18,13 @@ def test_init_default():
assert retriever.document_store == mock_store
assert retriever.filters == {}
assert retriever.top_k == 10
assert retriever.filter_policy == FilterPolicy.REPLACE

retriever = PineconeEmbeddingRetriever(document_store=mock_store, filter_policy="replace")
assert retriever.filter_policy == FilterPolicy.REPLACE

with pytest.raises(ValueError):
PineconeEmbeddingRetriever(document_store=mock_store, filter_policy="invalid")


@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone")
Expand Down Expand Up @@ -53,6 +62,7 @@ def test_to_dict(mock_pinecone, monkeypatch):
},
"filters": {},
"top_k": 10,
"filter_policy": "replace",
},
}

Expand Down Expand Up @@ -82,6 +92,7 @@ def test_from_dict(mock_pinecone, monkeypatch):
},
"filters": {},
"top_k": 10,
"filter_policy": "replace",
},
}

Expand All @@ -100,6 +111,7 @@ def test_from_dict(mock_pinecone, monkeypatch):

assert retriever.filters == {}
assert retriever.top_k == 10
assert retriever.filter_policy == FilterPolicy.REPLACE


def test_run():
Expand Down