From b23ab153c86c8724edfb341fce680f5a38004162 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Mon, 15 Jul 2024 17:13:15 +0200 Subject: [PATCH] fix: `OpenSearch` - Fallback to default filter policy when deserializing retrievers without the init parameter (#895) --- .../retrievers/opensearch/bm25_retriever.py | 6 +++++- .../opensearch/embedding_retriever.py | 6 +++++- .../opensearch/tests/test_bm25_retriever.py | 19 +++++++++++++++++++ .../tests/test_embedding_retriever.py | 17 +++++++++++++++++ 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py index 64aa57499..06a7afe85 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py @@ -119,7 +119,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever": data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index 08afce44b..eba5596f2 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -118,7 +118,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever": data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict( data["init_parameters"]["document_store"] ) - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index 63d69b881..1cce2961c 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -96,6 +96,25 @@ def test_from_dict(_mock_opensearch_client): assert retriever._custom_query == {"some": "custom query"} assert retriever._raise_on_failure is False + # For backwards compatibility with older versions of the retriever without a filter policy + data = { + "type": "haystack_integrations.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever", + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "fuzziness": "AUTO", + "top_k": 10, + "scale_score": True, + "custom_query": {"some": "custom query"}, + "raise_on_failure": False, + }, + } + retriever = OpenSearchBM25Retriever.from_dict(data) + assert retriever._filter_policy == FilterPolicy.REPLACE + def test_run(): mock_store = Mock(spec=OpenSearchDocumentStore) diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index cf5a5dded..38be08698 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -106,6 +106,23 @@ def test_from_dict(_mock_opensearch_client): assert retriever._raise_on_failure is False assert retriever._filter_policy == FilterPolicy.REPLACE + # For backwards compatibility with older versions of the retriever without a filter policy + data = { + "type": type_s, + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + }, + "filters": {}, + "top_k": 10, + "custom_query": {"some": "custom query"}, + "raise_on_failure": False, + }, + } + retriever = OpenSearchEmbeddingRetriever.from_dict(data) + assert retriever._filter_policy == FilterPolicy.REPLACE + def test_run(): mock_store = Mock(spec=OpenSearchDocumentStore)