Skip to content

Commit

Permalink
fix: ElasticSearch - Fallback to default filter policy when deseria…
Browse files Browse the repository at this point in the history
…lizing retrievers without the init parameter (#898)

* Add defensive check for filter_policy deserialization

* Add defensive check for filter_policy deserialization

* Add unit test

* Revert change in chroma

* Linter fix
  • Loading branch information
vblagoje authored Jul 16, 2024
1 parent 6349d15 commit b33505a
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchBM25Retriever":
data["init_parameters"]["document_store"] = ElasticsearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchEmbeddingRetriever":
data["init_parameters"]["document_store"] = ElasticsearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
24 changes: 24 additions & 0 deletions integrations/elasticsearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,30 @@ def test_from_dict(_mock_elasticsearch_client):
assert retriever._filter_policy == FilterPolicy.REPLACE


@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_from_dict_no_filter_policy(_mock_elasticsearch_client):
data = {
"type": "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever.ElasticsearchBM25Retriever",
"init_parameters": {
"document_store": {
"init_parameters": {"hosts": "some fake host", "index": "default"},
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
},
"filters": {},
"fuzziness": "AUTO",
"top_k": 10,
"scale_score": True,
},
}
retriever = ElasticsearchBM25Retriever.from_dict(data)
assert retriever._document_store
assert retriever._filters == {}
assert retriever._fuzziness == "AUTO"
assert retriever._top_k == 10
assert retriever._scale_score
assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE


def test_run():
mock_store = Mock(spec=ElasticsearchDocumentStore)
mock_store._bm25_retrieval.return_value = [Document(content="Test doc")]
Expand Down
23 changes: 23 additions & 0 deletions integrations/elasticsearch/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,29 @@ def test_from_dict(_mock_elasticsearch_client):
assert retriever._num_candidates is None


@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_from_dict_no_filter_policy(_mock_elasticsearch_client):
t = "haystack_integrations.components.retrievers.elasticsearch.embedding_retriever.ElasticsearchEmbeddingRetriever"
data = {
"type": t,
"init_parameters": {
"document_store": {
"init_parameters": {"hosts": "some fake host", "index": "default"},
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
},
"filters": {},
"top_k": 10,
"num_candidates": None,
},
}
retriever = ElasticsearchEmbeddingRetriever.from_dict(data)
assert retriever._document_store
assert retriever._filters == {}
assert retriever._top_k == 10
assert retriever._num_candidates is None
assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE


def test_run():
mock_store = Mock(spec=ElasticsearchDocumentStore)
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
Expand Down

0 comments on commit b33505a

Please sign in to comment.