Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add filter_policy to opensearch integration #822

Merged
merged 13 commits into from
Jul 5, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore

logger = logging.getLogger(__name__)
Expand All @@ -22,6 +24,7 @@ def __init__(
top_k: int = 10,
scale_score: bool = False,
all_terms_must_match: bool = False,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
custom_query: Optional[Dict[str, Any]] = None,
raise_on_failure: bool = True,
):
Expand All @@ -36,6 +39,7 @@ def __init__(
This is useful when comparing documents across different indexes. Defaults to False.
:param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents.
This is useful when searching for short text where even one term can make a difference. Defaults to False.
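:param filter_policy: Policy that determines how filters passed to `run()` are combined with the `filters` set
at initialization. Accepts a `FilterPolicy` or its string value (e.g. `"replace"`). Defaults to `FilterPolicy.REPLACE`.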
:param custom_query: The query containing a mandatory `$query` and an optional `$filters` placeholder

**An example custom_query:**
Expand Down Expand Up @@ -76,6 +80,9 @@ def __init__(
self._top_k = top_k
self._scale_score = scale_score
self._all_terms_must_match = all_terms_must_match
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)
self._custom_query = custom_query
self._raise_on_failure = raise_on_failure

Expand All @@ -93,6 +100,7 @@ def to_dict(self) -> Dict[str, Any]:
top_k=self._top_k,
scale_score=self._scale_score,
document_store=self._document_store.to_dict(),
filter_policy=self._filter_policy.value,
custom_query=self._custom_query,
raise_on_failure=self._raise_on_failure,
)
Expand All @@ -111,6 +119,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever":
data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand All @@ -128,7 +137,9 @@ def run(
Retrieve documents using BM25 retrieval.

:param query: The query string
:param filters: Optional filters to narrow down the search space.
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
the `filter_policy` chosen at retriever initialization. See init method docstring for more
details.
:param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents.
:param top_k: Maximum number of Documents to return.
:param fuzziness: Fuzziness parameter for full-text queries.
Expand Down Expand Up @@ -164,6 +175,8 @@ def run(
- documents: List of retrieved Documents.

"""
filters = apply_filter_policy(self._filter_policy, self._filters, filters)

if filters is None:
filters = self._filters
if all_terms_must_match is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.types import FilterPolicy
from haystack.document_stores.types.filter_policy import apply_filter_policy
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore

logger = logging.getLogger(__name__)
Expand All @@ -25,6 +27,7 @@ def __init__(
document_store: OpenSearchDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
custom_query: Optional[Dict[str, Any]] = None,
raise_on_failure: bool = True,
):
Expand All @@ -35,6 +38,7 @@ def __init__(
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
:param top_k: Maximum number of Documents to return, defaults to 10
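:param filter_policy: Policy that determines how filters passed to `run()` are combined with the `filters` set
at initialization. Accepts a `FilterPolicy` or its string value (e.g. `"replace"`). Defaults to `FilterPolicy.REPLACE`.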
:param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder

**An example custom_query:**
Expand Down Expand Up @@ -77,6 +81,9 @@ def __init__(
self._document_store = document_store
self._filters = filters or {}
self._top_k = top_k
self._filter_policy = (
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
)
self._custom_query = custom_query
self._raise_on_failure = raise_on_failure

Expand All @@ -92,6 +99,7 @@ def to_dict(self) -> Dict[str, Any]:
filters=self._filters,
top_k=self._top_k,
document_store=self._document_store.to_dict(),
filter_policy=self._filter_policy.value,
custom_query=self._custom_query,
raise_on_failure=self._raise_on_failure,
)
Expand All @@ -110,6 +118,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever":
data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand All @@ -124,7 +133,9 @@ def run(
Retrieve documents using a vector similarity metric.

:param query_embedding: Embedding of the query.
:param filters: Optional filters to narrow down the search space.
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
the `filter_policy` chosen at retriever initialization. See init method docstring for more
details.
:param top_k: Maximum number of Documents to return.
:param custom_query: The query containing a mandatory `$query_embedding` and an optional `$filters` placeholder

Expand Down Expand Up @@ -161,6 +172,8 @@ def run(
Dictionary with key "documents" containing the retrieved Documents.
- documents: List of Document similar to `query_embedding`.
"""
filters = apply_filter_policy(self._filter_policy, self._filters, filters)
top_k = top_k or self._top_k
if filters is None:
filters = self._filters
if top_k is None:
Expand Down
12 changes: 12 additions & 0 deletions integrations/opensearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch

import pytest
from haystack.dataclasses import Document
from haystack.document_stores.types import FilterPolicy
from haystack_integrations.components.retrievers.opensearch import OpenSearchBM25Retriever
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore
from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES
Expand All @@ -16,6 +18,13 @@ def test_init_default():
assert retriever._filters == {}
assert retriever._top_k == 10
assert not retriever._scale_score
assert retriever._filter_policy == FilterPolicy.REPLACE

retriever = OpenSearchBM25Retriever(document_store=mock_store, filter_policy="replace")
assert retriever._filter_policy == FilterPolicy.REPLACE

with pytest.raises(ValueError):
OpenSearchBM25Retriever(document_store=mock_store, filter_policy="unknown")


@patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
Expand Down Expand Up @@ -52,6 +61,7 @@ def test_to_dict(_mock_opensearch_client):
"fuzziness": "AUTO",
"top_k": 10,
"scale_score": False,
"filter_policy": "replace",
"custom_query": {"some": "custom query"},
"raise_on_failure": True,
},
Expand All @@ -71,6 +81,7 @@ def test_from_dict(_mock_opensearch_client):
"fuzziness": "AUTO",
"top_k": 10,
"scale_score": True,
"filter_policy": "replace",
"custom_query": {"some": "custom query"},
"raise_on_failure": False,
},
Expand All @@ -81,6 +92,7 @@ def test_from_dict(_mock_opensearch_client):
assert retriever._fuzziness == "AUTO"
assert retriever._top_k == 10
assert retriever._scale_score
assert retriever._filter_policy == FilterPolicy.REPLACE
assert retriever._custom_query == {"some": "custom query"}
assert retriever._raise_on_failure is False

Expand Down
12 changes: 12 additions & 0 deletions integrations/opensearch/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch

import pytest
from haystack.dataclasses import Document
from haystack.document_stores.types import FilterPolicy
from haystack_integrations.components.retrievers.opensearch import OpenSearchEmbeddingRetriever
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore
from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES
Expand All @@ -15,6 +17,13 @@ def test_init_default():
assert retriever._document_store == mock_store
assert retriever._filters == {}
assert retriever._top_k == 10
assert retriever._filter_policy == FilterPolicy.REPLACE

retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace")
assert retriever._filter_policy == FilterPolicy.REPLACE

with pytest.raises(ValueError):
OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="unknown")


@patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
Expand Down Expand Up @@ -65,6 +74,7 @@ def test_to_dict(_mock_opensearch_client):
},
"filters": {},
"top_k": 10,
"filter_policy": "replace",
"custom_query": {"some": "custom query"},
"raise_on_failure": True,
},
Expand All @@ -83,6 +93,7 @@ def test_from_dict(_mock_opensearch_client):
},
"filters": {},
"top_k": 10,
"filter_policy": "replace",
"custom_query": {"some": "custom query"},
"raise_on_failure": False,
},
Expand All @@ -93,6 +104,7 @@ def test_from_dict(_mock_opensearch_client):
assert retriever._top_k == 10
assert retriever._custom_query == {"some": "custom query"}
assert retriever._raise_on_failure is False
assert retriever._filter_policy == FilterPolicy.REPLACE


def test_run():
Expand Down