Skip to content

Commit

Permalink
feat: MongoDBAtlasEmbeddingRetriever (#427)
Browse files Browse the repository at this point in the history
* initial implementation

* vector index seems non functional

* tests are green for docstore

* tests green

* no parallel tests

* lint

* use different collections for write tests

* fix tests

* black

* docstring

* add doc fields

* black

* Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py

Co-authored-by: Madeesh Kannan <[email protected]>

* Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py

* Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py

* Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py

* Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py

* Update integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py

* Update integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py

Co-authored-by: Madeesh Kannan <[email protected]>

* Update integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py

Co-authored-by: Madeesh Kannan <[email protected]>

* Update integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py

* black

* docstring

* black

* mypy

* rename

* deserialization

* unused import

* Update integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py

Co-authored-by: Silvano Cerza <[email protected]>

* change import

* ruff

* remove numpy conversion

---------

Co-authored-by: Madeesh Kannan <[email protected]>
Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
3 people authored Feb 23, 2024
1 parent c8db92f commit 8f73a8b
Show file tree
Hide file tree
Showing 7 changed files with 376 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever

__all__ = ["MongoDBAtlasEmbeddingRetriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore


@component
class MongoDBAtlasEmbeddingRetriever:
    """
    Retrieves documents from the MongoDBAtlasDocumentStore by embedding similarity.

    Needs to be connected to the MongoDBAtlasDocumentStore.
    """

    def __init__(
        self,
        *,
        document_store: MongoDBAtlasDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
    ):
        """
        Create the MongoDBAtlasEmbeddingRetriever component.

        :param document_store: An instance of MongoDBAtlasDocumentStore.
        :param filters: Filters applied to the retrieved Documents. Defaults to no filters.
        :param top_k: Maximum number of Documents to return.
        :raises ValueError: If `document_store` is not an instance of MongoDBAtlasDocumentStore.
        """
        if not isinstance(document_store, MongoDBAtlasDocumentStore):
            msg = "document_store must be an instance of MongoDBAtlasDocumentStore"
            raise ValueError(msg)

        self.document_store = document_store
        self.filters = filters or {}
        self.top_k = top_k

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this component into a dictionary.

        The wrapped document store is serialized through its own `to_dict()` method.

        :return: A dictionary that `from_dict()` can turn back into an equivalent retriever.
        """
        return default_to_dict(
            self,
            filters=self.filters,
            top_k=self.top_k,
            document_store=self.document_store.to_dict(),
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever":
        """
        Deserializes a dictionary created with `MongoDBAtlasEmbeddingRetriever.to_dict()` into a
        `MongoDBAtlasEmbeddingRetriever` instance.

        Note that `data` is modified in place: the serialized document store under
        `data["init_parameters"]["document_store"]` is replaced by the deserialized instance.

        :param data: the dictionary returned by `MongoDBAtlasEmbeddingRetriever.to_dict()`
        :return: The deserialized retriever.
        """
        data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict(
            data["init_parameters"]["document_store"]
        )
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
    ) -> Dict[str, List[Document]]:
        """
        Retrieve documents from the MongoDBAtlasDocumentStore, based on their embeddings.

        :param query_embedding: Embedding of the query.
        :param filters: Filters applied to the retrieved Documents. Overrides the value specified at initialization.
            A falsy value (`None` or an empty dict) falls back to the value specified at initialization.
        :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization.
            A falsy value (`None` or `0`) falls back to the value specified at initialization.
        :return: A dictionary with the key `documents` holding the list of Documents similar to `query_embedding`.
        """
        filters = filters or self.filters
        top_k = top_k or self.top_k

        # The keyword must match the document store's `embedding_retrieval` signature,
        # whose parameter is named `query_embedding`.
        docs = self.document_store.embedding_retrieval(
            query_embedding=query_embedding,
            filters=filters,
            top_k=top_k,
        )
        return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from haystack import default_from_dict, default_to_dict
from haystack.dataclasses.document import Document
from haystack.document_stores.errors import DuplicateDocumentError
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo
Expand All @@ -25,7 +25,7 @@ def __init__(
mongo_connection_string: Secret = Secret.from_env_var("MONGO_CONNECTION_STRING"), # noqa: B008
database_name: str,
collection_name: str,
recreate_collection: bool = False,
vector_search_index: str,
):
"""
Creates a new MongoDBAtlasDocumentStore instance.
Expand All @@ -37,8 +37,11 @@ def __init__(
This can be obtained on the MongoDB Atlas Dashboard by clicking on the `CONNECT` button.
This value will be read automatically from the env var "MONGO_CONNECTION_STRING".
:param database_name: Name of the database to use.
:param collection_name: Name of the collection to use.
:param recreate_collection: Whether to recreate the collection when initializing the document store.
:param collection_name: Name of the collection to use. To use this document store for embedding retrieval,
this collection needs to have a vector search index set up on the `embedding` field.
:param vector_search_index: The name of the vector search index to use for vector search operations.
Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \
See https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index
"""
if collection_name and not bool(re.match(r"^[a-zA-Z0-9\-_]+$", collection_name)):
msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.'
Expand All @@ -49,21 +52,16 @@ def __init__(

self.database_name = database_name
self.collection_name = collection_name
self.recreate_collection = recreate_collection
self.vector_search_index = vector_search_index

self.connection: MongoClient = MongoClient(
resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)
database = self.connection[self.database_name]

if self.recreate_collection and self.collection_name in database.list_collection_names():
database[self.collection_name].drop()

# Implicitly create the collection if it doesn't exist
if collection_name not in database.list_collection_names():
database.create_collection(self.collection_name)
database[self.collection_name].create_index("id", unique=True)

msg = f"Collection '{collection_name}' does not exist in database '{database_name}'."
raise ValueError(msg)
self.collection = database[self.collection_name]

def to_dict(self) -> Dict[str, Any]:
Expand All @@ -75,7 +73,7 @@ def to_dict(self) -> Dict[str, Any]:
mongo_connection_string=self.mongo_connection_string.to_dict(),
database_name=self.database_name,
collection_name=self.collection_name,
recreate_collection=self.recreate_collection,
vector_search_index=self.vector_search_index,
)

@classmethod
Expand Down Expand Up @@ -157,3 +155,63 @@ def delete_documents(self, document_ids: List[str]) -> None:
if not document_ids:
return
self.collection.delete_many(filter={"id": {"$in": document_ids}})

def embedding_retrieval(
    self,
    query_embedding: List[float],
    filters: Optional[Dict[str, Any]] = None,
    top_k: int = 10,
) -> List[Document]:
    """
    Return the stored documents whose embeddings are most similar to `query_embedding`.

    The search runs as a MongoDB aggregation made of a `$vectorSearch` stage on the `embedding`
    field, using the index named by `self.vector_search_index`, followed by a `$project` stage
    that keeps the document fields and adds the vector search score as `score`.

    NOTE: `filters` are converted with `haystack_filters_to_mongo` but are currently NOT passed
    to the `$vectorSearch` stage, so they do not restrict the results.

    :param query_embedding: Embedding of the query. Must not be empty.
    :param filters: Optional filters (see the note above).
    :param top_k: How many documents to return.
    :return: The matching documents converted to Haystack Documents.
    :raises ValueError: If `query_embedding` is empty.
    :raises DocumentStoreError: If the aggregation fails for any reason.
    """
    if not query_embedding:
        error_message = "Query embedding must not be empty"
        raise ValueError(error_message)

    # Converted, but intentionally not added to the stage below: the `filter` entry is disabled.
    filters = haystack_filters_to_mongo(filters)

    vector_search_stage = {
        "$vectorSearch": {
            "index": self.vector_search_index,
            "path": "embedding",
            "queryVector": query_embedding,
            "numCandidates": 100,
            "limit": top_k,
        }
    }
    projection_stage = {
        "$project": {
            "_id": 0,
            "content": 1,
            "dataframe": 1,
            "blob": 1,
            "meta": 1,
            "embedding": 1,
            "score": {"$meta": "vectorSearchScore"},
        }
    }

    try:
        cursor = self.collection.aggregate([vector_search_stage, projection_stage])
        raw_documents = list(cursor)
    except Exception as e:
        error_message = f"Retrieval of documents from MongoDB Atlas failed: {e}"
        raise DocumentStoreError(error_message) from e

    return [self.mongo_doc_to_haystack_doc(raw_document) for raw_document in raw_documents]

def mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document:
    """
    Build a Haystack Document from a dictionary coming out of MongoDB.

    The `_id` key, if present, is removed from `mongo_doc` in place before the conversion.

    :param mongo_doc: A dictionary representing a document as stored in MongoDB
    :return: A Haystack Document object
    """
    if "_id" in mongo_doc:
        del mongo_doc["_id"]
    return Document.from_dict(mongo_doc)
56 changes: 32 additions & 24 deletions integrations/mongodb_atlas/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import patch
from uuid import uuid4

import pytest
Expand All @@ -13,24 +12,38 @@
from haystack.utils import Secret
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
from pandas import DataFrame
from pymongo import MongoClient # type: ignore
from pymongo.driver_info import DriverInfo # type: ignore


@pytest.fixture
def document_store(request):
def document_store():
database_name = "haystack_integration_test"
collection_name = "test_collection_" + str(uuid4())

connection: MongoClient = MongoClient(
os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)
database = connection[database_name]
if collection_name in database.list_collection_names():
database[collection_name].drop()
database.create_collection(collection_name)
database[collection_name].create_index("id", unique=True)

store = MongoDBAtlasDocumentStore(
database_name="haystack_integration_test",
collection_name=request.node.name + str(uuid4()),
database_name=database_name,
collection_name=collection_name,
vector_search_index="cosine_index",
)
yield store
store.collection.drop()
database[collection_name].drop()


@pytest.mark.skipif(
"MONGO_CONNECTION_STRING" not in os.environ,
reason="No MongoDB Atlas connection string provided",
)
class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):

def test_write_documents(self, document_store: MongoDBAtlasDocumentStore):
docs = [Document(content="some text")]
assert document_store.write_documents(docs) == 1
Expand All @@ -51,13 +64,10 @@ def test_write_dataframe(self, document_store: MongoDBAtlasDocumentStore):
retrieved_docs = document_store.filter_documents()
assert retrieved_docs == docs

@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient")
def test_to_dict(self, _):
document_store = MongoDBAtlasDocumentStore(
database_name="database_name",
collection_name="collection_name",
)
assert document_store.to_dict() == {
def test_to_dict(self, document_store):
serialized_store = document_store.to_dict()
assert serialized_store["init_parameters"].pop("collection_name").startswith("test_collection_")
assert serialized_store == {
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore",
"init_parameters": {
"mongo_connection_string": {
Expand All @@ -67,14 +77,12 @@ def test_to_dict(self, _):
"strict": True,
"type": "env_var",
},
"database_name": "database_name",
"collection_name": "collection_name",
"recreate_collection": False,
"database_name": "haystack_integration_test",
"vector_search_index": "cosine_index",
},
}

@patch("haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient")
def test_from_dict(self, _):
def test_from_dict(self):
docstore = MongoDBAtlasDocumentStore.from_dict(
{
"type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore",
Expand All @@ -86,13 +94,13 @@ def test_from_dict(self, _):
"strict": True,
"type": "env_var",
},
"database_name": "database_name",
"collection_name": "collection_name",
"recreate_collection": True,
"database_name": "haystack_integration_test",
"collection_name": "test_embeddings_collection",
"vector_search_index": "cosine_index",
},
}
)
assert docstore.mongo_connection_string == Secret.from_env_var("MONGO_CONNECTION_STRING")
assert docstore.database_name == "database_name"
assert docstore.collection_name == "collection_name"
assert docstore.recreate_collection
assert docstore.database_name == "haystack_integration_test"
assert docstore.collection_name == "test_embeddings_collection"
assert docstore.vector_search_index == "cosine_index"
Loading

0 comments on commit 8f73a8b

Please sign in to comment.