From 7b6752d91f053ac2152bcf37081903fc84dda958 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 28 May 2024 21:30:51 +0200 Subject: [PATCH 1/4] feat: defer the database connection to when it's needed --- .../mongodb_atlas/document_store.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 4cb5b8659..9003aa064 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -81,22 +81,28 @@ def __init__( msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' raise ValueError(msg) - resolved_connection_string = mongo_connection_string.resolve_value() + self.resolved_connection_string = mongo_connection_string.resolve_value() self.mongo_connection_string = mongo_connection_string self.database_name = database_name self.collection_name = collection_name self.vector_search_index = vector_search_index + self._connection: Optional[MongoClient] = None + + @property + def connection(self) -> MongoClient: + if self._connection is None: + self._connection = MongoClient( + self.resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") + ) + database = self.connection[self.database_name] + + if self.collection_name not in database.list_collection_names(): + msg = f"Collection '{self.collection_name}' does not exist in database '{self.database_name}'." + raise ValueError(msg) + self.collection = database[self.collection_name] - self.connection: MongoClient = MongoClient( - resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") - ) - database = self.connection[self.database_name] - - if collection_name not in database.list_collection_names(): - msg = f"Collection '{collection_name}' does not exist in database '{database_name}'." - raise ValueError(msg) - self.collection = database[self.collection_name] + return self._connection def to_dict(self) -> Dict[str, Any]: """ From e7264f24729331693b2dcaf848f8e9e168e52bce Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 28 May 2024 21:41:02 +0200 Subject: [PATCH 2/4] lazy collection too --- .../document_stores/mongodb_atlas/document_store.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 9003aa064..83e93e269 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -12,6 +12,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne +from pymongo.collection import Collection from pymongo.driver_info import DriverInfo from pymongo.errors import BulkWriteError @@ -88,6 +89,7 @@ def __init__( self.collection_name = collection_name self.vector_search_index = vector_search_index self._connection: Optional[MongoClient] = None + self._collection: Optional[Collection] = None @property def connection(self) -> MongoClient: @@ -95,14 +97,19 @@ def connection(self) -> MongoClient: self._connection = MongoClient( self.resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") ) + + return self._connection + + @property + def collection(self) -> Collection: + if self._collection is None: database = self.connection[self.database_name] if self.collection_name not in database.list_collection_names(): msg = f"Collection '{self.collection_name}' does not exist in database '{self.database_name}'." raise ValueError(msg) - self.collection = database[self.collection_name] - - return self._connection + self._collection = database[self.collection_name] + return self._collection def to_dict(self) -> Dict[str, Any]: """ From 7cbdaf74345e5ad110f625cb379e937885a8b074 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 28 May 2024 22:00:24 +0200 Subject: [PATCH 3/4] add test --- .../mongodb_atlas/tests/test_document_store.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 89810ec8b..a94227b93 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os from uuid import uuid4 +from unittest.mock import patch import pytest from haystack.dataclasses.document import ByteStream, Document @@ -16,13 +17,23 @@ from pymongo.driver_info import DriverInfo +@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient") +def test_init_is_lazy(_mock_client): + MongoDBAtlasDocumentStore( + mongo_connection_string=Secret.from_token("test"), + database_name="database_name", + collection_name="collection_name", + vector_search_index="cosine_index", + ) + _mock_client.assert_not_called() + + @pytest.mark.skipif( "MONGO_CONNECTION_STRING" not in os.environ, reason="No MongoDB Atlas connection string provided", ) @pytest.mark.integration class TestDocumentStore(DocumentStoreBaseTests): - @pytest.fixture def document_store(self): database_name = "haystack_integration_test" From 862176b8d9f7009708291c42b254af633a4ff31c Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 29 May 2024 07:48:21 +0200 Subject: [PATCH 4/4] linting --- integrations/mongodb_atlas/tests/test_document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index a94227b93..453d9d16c 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -2,8 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 import os -from uuid import uuid4 from unittest.mock import patch +from uuid import uuid4 import pytest from haystack.dataclasses.document import ByteStream, Document