diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index e3f5062fe..432b86d4c 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -13,7 +13,28 @@ class MongoDBAtlasEmbeddingRetriever: """ Retrieves documents from the MongoDBAtlasDocumentStore by embedding similarity. - Needs to be connected to the MongoDBAtlasDocumentStore. + The similarity is dependent on the vector_search_index used in the MongoDBAtlasDocumentStore and the chosen metric + during the creation of the index (i.e. cosine, dot product, or euclidean). See MongoDBAtlasDocumentStore for more + information. + + Usage example: + ```python + import numpy as np + from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever + + store = MongoDBAtlasDocumentStore(database_name="haystack_integration_test", + collection_name="test_embeddings_collection", + vector_search_index="cosine_index") + retriever = MongoDBAtlasEmbeddingRetriever(document_store=store) + + results = retriever.run(query_embedding=np.random.random(768).tolist()) + print(results["documents"]) + ``` + + The example above retrieves the 10 most similar documents to a random query embedding from the + MongoDBAtlasDocumentStore. Note that dimensions of the query_embedding must match the dimensions of the embeddings + stored in the MongoDBAtlasDocumentStore. """ def __init__( @@ -29,6 +50,8 @@ def __init__( :param document_store: An instance of MongoDBAtlasDocumentStore. :param filters: Filters applied to the retrieved Documents. :param top_k: Maximum number of Documents to return. + + :raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`. """ if not isinstance(document_store, MongoDBAtlasDocumentStore): msg = "document_store must be an instance of MongoDBAtlasDocumentStore" @@ -40,7 +63,10 @@ def __init__( def to_dict(self) -> Dict[str, Any]: """ - Serializes this component into a dictionary. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ return default_to_dict( self, @@ -52,10 +78,12 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever": """ - Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever.to_dict()` into a - `MongoDBAtlasEmbeddingRetriever` instance. + Deserializes the component from a dictionary. - :param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever.to_dict()` + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. """ data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict( data["init_parameters"]["document_store"] @@ -70,18 +98,19 @@ def run( top_k: Optional[int] = None, ) -> Dict[str, List[Document]]: """ - Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings. + Retrieve documents from the MongoDBAtlasDocumentStore, based on the provided embedding similarity. :param query_embedding: Embedding of the query. :param filters: Filters applied to the retrieved Documents. Overrides the value specified at initialization. :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization. - :returns: List of Documents similar to `query_embedding`. + :returns: A dictionary with the following keys: + - `documents`: List of Documents most similar to the given `query_embedding` """ filters = filters or self.filters top_k = top_k or self.top_k - docs = self.document_store.embedding_retrieval( - query_embedding_np=query_embedding, + docs = self.document_store._embedding_retrieval( + query_embedding=query_embedding, filters=filters, top_k=top_k, ) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index f76a31eb0..0d7116b3a 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -19,6 +19,39 @@ class MongoDBAtlasDocumentStore: + """ + MongoDBAtlasDocumentStore is a DocumentStore implementation that uses [MongoDB Atlas](https://www.mongodb.com/atlas/database). + service that is easy to deploy, operate, and scale. + + To connect to MongoDB Atlas, you need to provide a connection string in the format: + "mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}". + + This connection string can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button, selecting + Python as the driver, and copying the connection string. The connection string can be provided as an environment + variable `MONGO_CONNECTION_STRING` or directly as a parameter to the `MongoDBAtlasDocumentStore` constructor. + + After providing the connection string, you'll need to specify the `database_name` and `collection_name` to use. + Most likely that you'll create these via the MongoDB Atlas web UI but one can also create them via the MongoDB + Python driver. Creating databases and collections is beyond the scope of MongoDBAtlasDocumentStore. The primary + purpose of this document store is to read and write documents to an existing collection. + + The last parameter users needs to provide is a `vector_search_index` - used for vector search operations. This index + can support a chosen metric (i.e. cosine, dot product, or euclidean) and can be created in the Atlas web UI. + + For more details on MongoDB Atlas, see the official + MongoDB Atlas [documentation](https://www.mongodb.com/docs/atlas/getting-started/) + + Usage example: + ```python + from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + store = MongoDBAtlasDocumentStore(database_name="your_existing_db", + collection_name="your_existing_collection", + vector_search_index="your_existing_index") + print(store.count_documents()) + ``` + """ + def __init__( self, *, @@ -30,8 +63,6 @@ def __init__( """ Creates a new MongoDBAtlasDocumentStore instance. - This Document Store uses MongoDB Atlas as a backend (https://www.mongodb.com/docs/atlas/getting-started/). - :param mongo_connection_string: MongoDB Atlas connection string in the format: "mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}". This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button. @@ -41,7 +72,10 @@ def __init__( this collection needs to have a vector search index set up on the `embedding` field. :param vector_search_index: The name of the vector search index to use for vector search operations. Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \ - See https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index + For more details refer to MongoDB + Atlas [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index) + + :raises ValueError: If the collection name contains invalid characters. """ if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)): msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' @@ -66,7 +100,10 @@ def __init__( def to_dict(self) -> Dict[str, Any]: """ - Utility function that serializes this Document Store's configuration into a dictionary. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ return default_to_dict( self, @@ -79,7 +116,12 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasDocumentStore": """ - Utility function that deserializes this Document Store's configuration from a dictionary. + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. """ deserialize_secrets_inplace(data["init_parameters"], keys=["mongo_connection_string"]) return default_from_dict(cls, data) @@ -87,6 +129,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasDocumentStore": def count_documents(self) -> int: """ Returns how many documents are present in the document store. + + :returns: The number of documents in the document store. """ return self.collection.count_documents({}) @@ -95,7 +139,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc Returns the documents that match the filters provided. For a detailed specification of the filters, - refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering). + refer to the Haystack [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering). :param filters: The filters to apply. It returns only the documents that match the filters. :returns: A list of Documents that match the given filters. @@ -108,12 +152,13 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: """ - Writes documents into to PgvectorDocumentStore. + Writes documents into the MongoDB Atlas collection. :param documents: A list of Documents to write to the document store. :param policy: The duplicate policy to use when writing documents. - :raises DuplicateDocumentError: If a document with the same id already exists in the document store + :raises DuplicateDocumentError: If a document with the same ID already exists in the document store and the policy is set to DuplicatePolicy.FAIL (or not specified). + :raises ValueError: If the documents are not of type Document. :returns: The number of documents written to the document store. """ @@ -156,7 +201,7 @@ def delete_documents(self, document_ids: List[str]) -> None: return self.collection.delete_many(filter={"id": {"$in": document_ids}}) - def embedding_retrieval( + def _embedding_retrieval( self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, @@ -168,6 +213,9 @@ def embedding_retrieval( :param query_embedding: Embedding of the query :param filters: Optional filters. :param top_k: How many documents to return. + :returns: A list of Documents that are most similar to the given `query_embedding` + :raises ValueError: If `query_embedding` is empty. + :raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails. """ if not query_embedding: msg = "Query embedding must not be empty" @@ -203,10 +251,10 @@ def embedding_retrieval( msg = f"Retrieval of documents from MongoDB Atlas failed: {e}" raise DocumentStoreError(msg) from e - documents = [self.mongo_doc_to_haystack_doc(doc) for doc in documents] + documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents] return documents - def mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document: + def _mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document: """ Converts the dictionary coming out of MongoDB into a Haystack document diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index aa7790bc7..54bbdedfd 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -21,7 +21,7 @@ def test_embedding_retrieval_cosine_similarity(self): vector_search_index="cosine_index", ) query_embedding = [0.1] * 768 - results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Document A" assert results[1].content == "Document B" @@ -34,7 +34,7 @@ def test_embedding_retrieval_dot_product(self): vector_search_index="dotProduct_index", ) query_embedding = [0.1] * 768 - results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Document A" assert results[1].content == "Document B" @@ -47,7 +47,7 @@ def test_embedding_retrieval_euclidean(self): vector_search_index="euclidean_index", ) query_embedding = [0.1] * 768 - results = document_store.embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 assert results[0].content == "Document C" assert results[1].content == "Document B" @@ -61,7 +61,7 @@ def test_empty_query_embedding(self): ) query_embedding: List[float] = [] with pytest.raises(ValueError): - document_store.embedding_retrieval(query_embedding=query_embedding) + document_store._embedding_retrieval(query_embedding=query_embedding) def test_query_embedding_wrong_dimension(self): document_store = MongoDBAtlasDocumentStore( @@ -71,4 +71,4 @@ def test_query_embedding_wrong_dimension(self): ) query_embedding = [0.1] * 4 with pytest.raises(DocumentStoreError): - document_store.embedding_retrieval(query_embedding=query_embedding) + document_store._embedding_retrieval(query_embedding=query_embedding) diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py index 887d22b76..ec44513e2 100644 --- a/integrations/mongodb_atlas/tests/test_retriever.py +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -97,11 +97,11 @@ def test_from_dict(self): def test_run(self): mock_store = Mock(spec=MongoDBAtlasDocumentStore) doc = Document(content="Test doc", embedding=[0.1, 0.2]) - mock_store.embedding_retrieval.return_value = [doc] + mock_store._embedding_retrieval.return_value = [doc] retriever = MongoDBAtlasEmbeddingRetriever(document_store=mock_store) res = retriever.run(query_embedding=[0.3, 0.5]) - mock_store.embedding_retrieval.assert_called_once_with(query_embedding_np=[0.3, 0.5], filters={}, top_k=10) + mock_store._embedding_retrieval.assert_called_once_with(query_embedding=[0.3, 0.5], filters={}, top_k=10) assert res == {"documents": [doc]}