From f170ab434711d24a66c9c1a63575c9968e54824d Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Wed, 26 Jun 2024 19:13:32 +0200 Subject: [PATCH] fix: serialization for custom_query in OpenSearch retrievers (#851) --- .../components/retrievers/opensearch/bm25_retriever.py | 1 + .../components/retrievers/opensearch/embedding_retriever.py | 1 + integrations/opensearch/tests/test_bm25_retriever.py | 5 ++++- integrations/opensearch/tests/test_embedding_retriever.py | 5 ++++- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py index 949826230..add55eb60 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/bm25_retriever.py @@ -86,6 +86,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, scale_score=self._scale_score, document_store=self._document_store.to_dict(), + custom_query=self._custom_query, ) @classmethod diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index 81688601f..a51028a96 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -85,6 +85,7 @@ def to_dict(self) -> Dict[str, Any]: filters=self._filters, top_k=self._top_k, document_store=self._document_store.to_dict(), + custom_query=self._custom_query, ) @classmethod diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index a3e3e3acd..2b6961210 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -21,7 +21,7 @@ def test_init_default(): @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") def test_to_dict(_mock_opensearch_client): document_store = OpenSearchDocumentStore(hosts="some fake host") - retriever = OpenSearchBM25Retriever(document_store=document_store) + retriever = OpenSearchBM25Retriever(document_store=document_store, custom_query={"some": "custom query"}) res = retriever.to_dict() assert res == { "type": "haystack_integrations.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever", @@ -52,6 +52,7 @@ def test_to_dict(_mock_opensearch_client): "fuzziness": "AUTO", "top_k": 10, "scale_score": False, + "custom_query": {"some": "custom query"}, }, } @@ -69,6 +70,7 @@ def test_from_dict(_mock_opensearch_client): "fuzziness": "AUTO", "top_k": 10, "scale_score": True, + "custom_query": {"some": "custom query"}, }, } retriever = OpenSearchBM25Retriever.from_dict(data) @@ -77,6 +79,7 @@ def test_from_dict(_mock_opensearch_client): assert retriever._fuzziness == "AUTO" assert retriever._top_k == 10 assert retriever._scale_score + assert retriever._custom_query == {"some": "custom query"} def test_run(): diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index 325ff16d3..0432ee9e3 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -20,7 +20,7 @@ def test_init_default(): @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") def test_to_dict(_mock_opensearch_client): document_store = OpenSearchDocumentStore(hosts="some fake host") - retriever = OpenSearchEmbeddingRetriever(document_store=document_store) + retriever = OpenSearchEmbeddingRetriever(document_store=document_store, custom_query={"some": "custom query"}) res = retriever.to_dict() type_s = "haystack_integrations.components.retrievers.opensearch.embedding_retriever.OpenSearchEmbeddingRetriever" assert res == { @@ -65,6 +65,7 @@ def test_to_dict(_mock_opensearch_client): }, "filters": {}, "top_k": 10, + "custom_query": {"some": "custom query"}, }, } @@ -81,12 +82,14 @@ def test_from_dict(_mock_opensearch_client): }, "filters": {}, "top_k": 10, + "custom_query": {"some": "custom query"}, }, } retriever = OpenSearchEmbeddingRetriever.from_dict(data) assert retriever._document_store assert retriever._filters == {} assert retriever._top_k == 10 + assert retriever._custom_query == {"some": "custom query"} def test_run():