Skip to content

Commit

Permalink
fix: serialization for custom_query in OpenSearch retrievers (#851)
Browse files Browse the repository at this point in the history
  • Loading branch information
tstadel authored Jun 26, 2024
1 parent 49e323f commit f170ab4
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def to_dict(self) -> Dict[str, Any]:
top_k=self._top_k,
scale_score=self._scale_score,
document_store=self._document_store.to_dict(),
custom_query=self._custom_query,
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def to_dict(self) -> Dict[str, Any]:
filters=self._filters,
top_k=self._top_k,
document_store=self._document_store.to_dict(),
custom_query=self._custom_query,
)

@classmethod
Expand Down
5 changes: 4 additions & 1 deletion integrations/opensearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_init_default():
@patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
def test_to_dict(_mock_opensearch_client):
document_store = OpenSearchDocumentStore(hosts="some fake host")
retriever = OpenSearchBM25Retriever(document_store=document_store)
retriever = OpenSearchBM25Retriever(document_store=document_store, custom_query={"some": "custom query"})
res = retriever.to_dict()
assert res == {
"type": "haystack_integrations.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever",
Expand Down Expand Up @@ -52,6 +52,7 @@ def test_to_dict(_mock_opensearch_client):
"fuzziness": "AUTO",
"top_k": 10,
"scale_score": False,
"custom_query": {"some": "custom query"},
},
}

Expand All @@ -69,6 +70,7 @@ def test_from_dict(_mock_opensearch_client):
"fuzziness": "AUTO",
"top_k": 10,
"scale_score": True,
"custom_query": {"some": "custom query"},
},
}
retriever = OpenSearchBM25Retriever.from_dict(data)
Expand All @@ -77,6 +79,7 @@ def test_from_dict(_mock_opensearch_client):
assert retriever._fuzziness == "AUTO"
assert retriever._top_k == 10
assert retriever._scale_score
assert retriever._custom_query == {"some": "custom query"}


def test_run():
Expand Down
5 changes: 4 additions & 1 deletion integrations/opensearch/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_init_default():
@patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch")
def test_to_dict(_mock_opensearch_client):
document_store = OpenSearchDocumentStore(hosts="some fake host")
retriever = OpenSearchEmbeddingRetriever(document_store=document_store)
retriever = OpenSearchEmbeddingRetriever(document_store=document_store, custom_query={"some": "custom query"})
res = retriever.to_dict()
type_s = "haystack_integrations.components.retrievers.opensearch.embedding_retriever.OpenSearchEmbeddingRetriever"
assert res == {
Expand Down Expand Up @@ -65,6 +65,7 @@ def test_to_dict(_mock_opensearch_client):
},
"filters": {},
"top_k": 10,
"custom_query": {"some": "custom query"},
},
}

Expand All @@ -81,12 +82,14 @@ def test_from_dict(_mock_opensearch_client):
},
"filters": {},
"top_k": 10,
"custom_query": {"some": "custom query"},
},
}
retriever = OpenSearchEmbeddingRetriever.from_dict(data)
assert retriever._document_store
assert retriever._filters == {}
assert retriever._top_k == 10
assert retriever._custom_query == {"some": "custom query"}


def test_run():
Expand Down

0 comments on commit f170ab4

Please sign in to comment.