Skip to content

Commit

Permalink
fix: OpenSearch - Fallback to default filter policy when deserializ…
Browse files Browse the repository at this point in the history
…ing retrievers without the init parameter (#895)
  • Loading branch information
shadeMe authored Jul 15, 2024
1 parent 20221ac commit b23ab15
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever":
data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])

# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if "filter_policy" in data["init_parameters"]:
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever":
data["init_parameters"]["document_store"] = OpenSearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])

# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if "filter_policy" in data["init_parameters"]:
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
19 changes: 19 additions & 0 deletions integrations/opensearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,25 @@ def test_from_dict(_mock_opensearch_client):
assert retriever._custom_query == {"some": "custom query"}
assert retriever._raise_on_failure is False

# For backwards compatibility with older versions of the retriever without a filter policy
data = {
"type": "haystack_integrations.components.retrievers.opensearch.bm25_retriever.OpenSearchBM25Retriever",
"init_parameters": {
"document_store": {
"init_parameters": {"hosts": "some fake host", "index": "default"},
"type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore",
},
"filters": {},
"fuzziness": "AUTO",
"top_k": 10,
"scale_score": True,
"custom_query": {"some": "custom query"},
"raise_on_failure": False,
},
}
retriever = OpenSearchBM25Retriever.from_dict(data)
assert retriever._filter_policy == FilterPolicy.REPLACE


def test_run():
mock_store = Mock(spec=OpenSearchDocumentStore)
Expand Down
17 changes: 17 additions & 0 deletions integrations/opensearch/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,23 @@ def test_from_dict(_mock_opensearch_client):
assert retriever._raise_on_failure is False
assert retriever._filter_policy == FilterPolicy.REPLACE

# For backwards compatibility with older versions of the retriever without a filter policy
data = {
"type": type_s,
"init_parameters": {
"document_store": {
"init_parameters": {"hosts": "some fake host", "index": "default"},
"type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore",
},
"filters": {},
"top_k": 10,
"custom_query": {"some": "custom query"},
"raise_on_failure": False,
},
}
retriever = OpenSearchEmbeddingRetriever.from_dict(data)
assert retriever._filter_policy == FilterPolicy.REPLACE


def test_run():
mock_store = Mock(spec=OpenSearchDocumentStore)
Expand Down

0 comments on commit b23ab15

Please sign in to comment.