Skip to content

Commit

Permalink
fix: Mongo - Fallback to default filter policy when deserializing r…
Browse files Browse the repository at this point in the history
…etrievers without the init parameter (#899)

* Add defensive check for filter_policy deserialization

* black test

* Fix ruff

* Black tests
  • Loading branch information
vblagoje authored Jul 17, 2024
1 parent b1201e0 commit d943f4e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever":
data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"])
# Pipelines serialized with old versions of the component might not
# have the filter_policy field.
if filter_policy := data["init_parameters"].get("filter_policy"):
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
36 changes: 36 additions & 0 deletions integrations/mongodb_atlas/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,42 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client
assert retriever.top_k == 5
assert retriever.filter_policy == FilterPolicy.REPLACE

def test_from_dict_no_filter_policy(self, monkeypatch): # mock_client appears unused but is required
monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str")

data = {
"type": "haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever.MongoDBAtlasEmbeddingRetriever", # noqa: E501
"init_parameters": {
"document_store": {
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501
"init_parameters": {
"mongo_connection_string": {
"env_vars": ["MONGO_CONNECTION_STRING"],
"strict": True,
"type": "env_var",
},
"database_name": "haystack_integration_test",
"collection_name": "test_embeddings_collection",
"vector_search_index": "cosine_index",
},
},
"filters": {"field": "value"},
"top_k": 5,
},
}

retriever = MongoDBAtlasEmbeddingRetriever.from_dict(data)
document_store = retriever.document_store

assert isinstance(document_store, MongoDBAtlasDocumentStore)
assert isinstance(document_store.mongo_connection_string, EnvVarSecret)
assert document_store.database_name == "haystack_integration_test"
assert document_store.collection_name == "test_embeddings_collection"
assert document_store.vector_search_index == "cosine_index"
assert retriever.filters == {"field": "value"}
assert retriever.top_k == 5
assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE

def test_run(self):
mock_store = Mock(spec=MongoDBAtlasDocumentStore)
doc = Document(content="Test doc", embedding=[0.1, 0.2])
Expand Down

0 comments on commit d943f4e

Please sign in to comment.