diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py index c0742d8bb..ab6f9b908 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py @@ -2,10 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore logger = logging.getLogger(__name__) @@ -22,6 +24,7 @@ def __init__( top_k: int = 10, scale_score: bool = False, all_terms_must_match: bool = False, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, custom_query: Optional[Dict[str, Any]] = None, raise_on_failure: bool = True, ): @@ -36,6 +39,7 @@ def __init__( This is useful when comparing documents across different indexes. Defaults to False. :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents. This is useful when searching for short text where even one term can make a difference. Defaults to False. + :param filter_policy: Policy to determine how filters are applied. :param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder **An example custom_query:** @@ -76,6 +80,9 @@ def __init__( self._top_k = top_k self._scale_score = scale_score self._all_terms_must_match = all_terms_must_match + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) self._custom_query = custom_query self._raise_on_failure = raise_on_failure @@ -93,6 +100,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, scale_score=self._scale_score, document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, custom_query=self._custom_query, raise_on_failure=self._raise_on_failure, ) @@ -111,6 +119,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever": data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -128,7 +137,9 @@ def run( Retrieve documents using BM25 retrieval. :param query: The query string - :param filters: Optional filters to narrow down the search space. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at document store initialization. See init method docstring for more + details. :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents. :param top_k: Maximum number of Documents to return. :param fuzziness: Fuzziness parameter for full-text queries. @@ -164,6 +175,8 @@ def run( - documents: List of retrieved Documents. """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + if filters is None: filters = self._filters if all_terms_must_match is None: diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index 4283a9558..a9f418db8 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -2,10 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore logger = logging.getLogger(__name__) @@ -25,6 +27,7 @@ def __init__( document_store: OpenSearchDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, custom_query: Optional[Dict[str, Any]] = None, raise_on_failure: bool = True, ): @@ -35,6 +38,7 @@ def __init__( :param filters: Filters applied to the retrieved Documents. Defaults to None. Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. :param top_k: Maximum number of Documents to return, defaults to 10 + :param filter_policy: Policy to determine how filters are applied. :param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder **An example custom_query:** @@ -77,6 +81,9 @@ def __init__( self._document_store = document_store self._filters = filters or {} self._top_k = top_k + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) self._custom_query = custom_query self._raise_on_failure = raise_on_failure @@ -92,6 +99,7 @@ def to_dict(self) -> Dict[str, Any]: filters=self._filters, top_k=self._top_k, document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, custom_query=self._custom_query, raise_on_failure=self._raise_on_failure, ) @@ -110,6 +118,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever": data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) @@ -124,7 +133,9 @@ def run( Retrieve documents using a vector similarity metric. :param query_embedding: Embedding of the query. - :param filters: Optional filters to narrow down the search space. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at document store initialization. See init method docstring for more + details. :param top_k: Maximum number of Documents to return. :param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder @@ -161,6 +172,8 @@ def run( Dictionary with key "documents" containing the retrieved Documents. - documents: List of Document similar to `query_embedding`. """ + filters = apply_filter_policy(self._filter_policy, self._filters, filters) + top_k = top_k or self._top_k if filters is None: filters = self._filters if top_k is None: diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index d0a77733c..63d69b881 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -3,7 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock, patch +import pytest from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy from haystack_integrations.components.retrievers.opensearch import OpenSearchBM25Retriever from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES @@ -16,6 +18,13 @@ def test_init_default(): assert retriever._filters == {} assert retriever._top_k == 10 assert not retriever._scale_score + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = OpenSearchBM25Retriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + OpenSearchBM25Retriever(document_store=mock_store, filter_policy="unknown") @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") @@ -52,6 +61,7 @@ def test_to_dict(_mock_opensearch_client): "fuzziness": "AUTO", "top_k": 10, "scale_score": False, + "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": True, }, @@ -71,6 +81,7 @@ def test_from_dict(_mock_opensearch_client): "fuzziness": "AUTO", "top_k": 10, "scale_score": True, + "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": False, }, @@ -81,6 +92,7 @@ def test_from_dict(_mock_opensearch_client): assert retriever._fuzziness == "AUTO" assert retriever._top_k == 10 assert retriever._scale_score + assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._custom_query == {"some": "custom query"} assert retriever._raise_on_failure is False diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index 456abd180..cf5a5dded 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -3,7 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock, patch +import pytest from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy from haystack_integrations.components.retrievers.opensearch import OpenSearchEmbeddingRetriever from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES @@ -15,6 +17,13 @@ def test_init_default(): assert retriever._document_store == mock_store assert retriever._filters == {} assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="unknown") @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") @@ -65,6 +74,7 @@ def test_to_dict(_mock_opensearch_client): }, "filters": {}, "top_k": 10, + "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": True, }, @@ -83,6 +93,7 @@ def test_from_dict(_mock_opensearch_client): }, "filters": {}, "top_k": 10, + "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": False, }, @@ -93,6 +104,7 @@ def test_from_dict(_mock_opensearch_client): assert retriever._top_k == 10 assert retriever._custom_query == {"some": "custom query"} assert retriever._raise_on_failure is False + assert retriever._filter_policy == FilterPolicy.REPLACE def test_run():