From 7283a5c9009e68cb6199be0100f7d89868b5be8c Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Sat, 9 Mar 2024 11:28:19 +0100
Subject: [PATCH] MongoDB Atlas: filters (#542)

* wip
* progress
* more tests
* improvements
* ignore missing imports in pyproject
* fix mypy
* show coverage
* rm code duplication
---
 integrations/mongodb_atlas/pyproject.toml    |  13 +-
 .../mongodb_atlas/embedding_retriever.py     |   4 +-
 .../mongodb_atlas/document_store.py          |  24 ++-
 .../document_stores/mongodb_atlas/errors.py  |   4 -
 .../document_stores/mongodb_atlas/filters.py | 157 +++++++++++++++++-
 .../tests/test_document_store.py             |  89 +++++++---
 .../tests/test_embedding_retrieval.py        |  32 ++++
 7 files changed, 268 insertions(+), 55 deletions(-)
 delete mode 100644 integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py

diff --git a/integrations/mongodb_atlas/pyproject.toml b/integrations/mongodb_atlas/pyproject.toml
index 0021884ad..6e6b55dfe 100644
--- a/integrations/mongodb_atlas/pyproject.toml
+++ b/integrations/mongodb_atlas/pyproject.toml
@@ -156,27 +156,26 @@ ban-relative-imports = "parents"
 "examples/**/*" = ["T201"]
 
 [tool.coverage.run]
-source_pkgs = ["src", "tests"]
+source = ["haystack_integrations"]
 branch = true
-parallel = true
+parallel = false
 
-[tool.coverage.paths]
-tests = ["tests", "*/mongodb-atlas-haystack/tests"]
-
 [tool.coverage.report]
+omit = ["*/tests/*", "*/__init__.py"]
+show_missing=true
 exclude_lines = [
   "no cov",
   "if __name__ == .__main__.:",
   "if TYPE_CHECKING:",
 ]
 
+
 [[tool.mypy.overrides]]
 module = [
   "haystack.*",
   "haystack_integrations.*",
-  "mongodb_atlas.*",
-  "psycopg.*",
+  "pymongo.*",
   "pytest.*"
 ]
 ignore_missing_imports = true

diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py
index 432b86d4c..ffad97789 100644
--- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py
+++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py
@@ -48,7 +48,9 @@ def __init__(
         Create the MongoDBAtlasDocumentStore component.
 
         :param document_store: An instance of MongoDBAtlasDocumentStore.
-        :param filters: Filters applied to the retrieved Documents.
+        :param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are
+            included in the configuration of the `vector_search_index`. The configuration must be done manually
+            in the Web UI of MongoDB Atlas.
         :param top_k: Maximum number of Documents to return.
         :raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`.
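For reference, a minimal sketch of the Haystack filter syntax the documented `filters` parameter accepts. This is illustrative only and not part of the patch: the `meta.category` and `meta.number` fields, the pre-existing `document_store` variable, and the `MongoDBAtlasEmbeddingRetriever` class name are assumptions, and any field used in the filter must be declared as a filter field in the `vector_search_index`.

# Illustrative sketch, not part of the patch.
from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever

filters = {
    "operator": "AND",
    "conditions": [
        {"field": "meta.category", "operator": "==", "value": "news"},
        {"field": "meta.number", "operator": ">=", "value": 100},
    ],
}
# `document_store` is assumed to be an existing MongoDBAtlasDocumentStore instance.
retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store, filters=filters, top_k=10)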
diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py
index 27cb853db..c9e8f1dae 100644
--- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py
+++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py
@@ -10,10 +10,10 @@
 from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
 from haystack.document_stores.types import DuplicatePolicy
 from haystack.utils import Secret, deserialize_secrets_inplace
-from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo
-from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne  # type: ignore
-from pymongo.driver_info import DriverInfo  # type: ignore
-from pymongo.errors import BulkWriteError  # type: ignore
+from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters
+from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne
+from pymongo.driver_info import DriverInfo
+from pymongo.errors import BulkWriteError
 
 logger = logging.getLogger(__name__)
 
@@ -144,8 +144,8 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
         :param filters: The filters to apply. It returns only the documents that match the filters.
         :returns: A list of Documents that match the given filters.
         """
-        mongo_filters = haystack_filters_to_mongo(filters)
-        documents = list(self.collection.find(mongo_filters))
+        filters = _normalize_filters(filters) if filters else None
+        documents = list(self.collection.find(filters))
         for doc in documents:
             doc.pop("_id", None)  # MongoDB's internal id doesn't belong into a Haystack document, so we remove it.
         return [Document.from_dict(doc) for doc in documents]
@@ -170,7 +170,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
         if policy == DuplicatePolicy.NONE:
             policy = DuplicatePolicy.FAIL
 
-        mongo_documents = [doc.to_dict() for doc in documents]
+        mongo_documents = [doc.to_dict(flatten=False) for doc in documents]
         operations: List[Union[UpdateOne, InsertOne, ReplaceOne]]
         written_docs = len(documents)
 
@@ -221,7 +221,8 @@ def _embedding_retrieval(
             msg = "Query embedding must not be empty"
             raise ValueError(msg)
 
-        filters = haystack_filters_to_mongo(filters)
+        filters = _normalize_filters(filters) if filters else None
+
         pipeline = [
             {
                 "$vectorSearch": {
@@ -230,7 +231,7 @@ def _embedding_retrieval(
                     "queryVector": query_embedding,
                     "numCandidates": 100,
                     "limit": top_k,
-                    # "filter": filters,
+                    "filter": filters,
                 }
             },
             {
@@ -249,6 +250,11 @@ def _embedding_retrieval(
             documents = list(self.collection.aggregate(pipeline))
         except Exception as e:
             msg = f"Retrieval of documents from MongoDB Atlas failed: {e}"
+            if filters:
+                msg += (
+                    "\nMake sure that the fields used in the filters are included "
+                    "in the `vector_search_index` configuration"
+                )
             raise DocumentStoreError(msg) from e
 
         documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents]
diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py
deleted file mode 100644
index 132156bd0..000000000
--- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py
+++ /dev/null
@@ -1,4 +0,0 @@
-class MongoDBAtlasDocumentStoreError(Exception):
-    """Exception for issues that occur in a MongoDBAtlas document store"""
-
-    pass
diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py
index f03ca88c0..4583d6cd3 100644
--- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py
+++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py
@@ -1,9 +1,152 @@
-from typing import Any, Dict, Optional
+# SPDX-FileCopyrightText: 2024-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+from datetime import datetime
+from typing import Any, Dict
 
+from haystack.errors import FilterError
+from haystack.utils.filters import convert
+from pandas import DataFrame
 
-def haystack_filters_to_mongo(filters: Optional[Dict[str, Any]]):
-    # TODO
-    if filters:
-        msg = "Filtering not yet implemented for MongoDBAtlasDocumentStore"
-        raise ValueError(msg)
-    return {}
+UNSUPPORTED_TYPES_FOR_COMPARISON = (list, DataFrame)
+
+
+def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Converts Haystack filters to MongoDB filters.
+ """ + if not isinstance(filters, dict): + msg = "Filters must be a dictionary" + raise FilterError(msg) + + if "operator" not in filters and "conditions" not in filters: + filters = convert(filters) + + if "field" in filters: + return _parse_comparison_condition(filters) + return _parse_logical_condition(filters) + + +def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "conditions" not in condition: + msg = f"'conditions' key missing in {condition}" + raise FilterError(msg) + + # logical conditions can be nested, so we need to parse them recursively + conditions = [] + for c in condition["conditions"]: + if "field" in c: + conditions.append(_parse_comparison_condition(c)) + else: + conditions.append(_parse_logical_condition(c)) + + operator = condition["operator"] + if operator == "AND": + return {"$and": conditions} + elif operator == "OR": + return {"$or": conditions} + elif operator == "NOT": + # MongoDB doesn't support our NOT operator (logical NAND) directly. + # we combine $nor and $and to achieve the same effect. + return {"$nor": [{"$and": conditions}]} + + msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR', 'NOT'" + raise FilterError(msg) + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + field: str = condition["field"] + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "value" not in condition: + msg = f"'value' key missing in {condition}" + raise FilterError(msg) + operator: str = condition["operator"] + value: Any = condition["value"] + + if isinstance(value, DataFrame): + value = value.to_json() + + return COMPARISON_OPERATORS[operator](field, value) + + +def _equal(field: str, value: Any) -> Dict[str, Any]: + return {field: {"$eq": value}} + + +def _not_equal(field: str, value: Any) -> Dict[str, Any]: + return {field: {"$ne": value}} + + +def _validate_type_for_comparison(value: Any) -> None: + msg = f"Cant compare {type(value)} using operators '>', '>=', '<', '<='." + if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): + raise FilterError(msg) + elif isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg += "\nStrings are only comparable if they are ISO formatted dates." 
+            raise FilterError(msg) from exc
+
+
+def _greater_than(field: str, value: Any) -> Dict[str, Any]:
+    _validate_type_for_comparison(value)
+    return {field: {"$gt": value}}
+
+
+def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
+    if value is None:
+        # we want {field: {"$gte": null}} to return an empty result
+        # $gte with null values in MongoDB returns a non-empty result, while $gt aligns with our expectations
+        return {field: {"$gt": value}}
+
+    _validate_type_for_comparison(value)
+    return {field: {"$gte": value}}
+
+
+def _less_than(field: str, value: Any) -> Dict[str, Any]:
+    _validate_type_for_comparison(value)
+    return {field: {"$lt": value}}
+
+
+def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
+    if value is None:
+        # we want {field: {"$lte": null}} to return an empty result
+        # $lte with null values in MongoDB returns a non-empty result, while $lt aligns with our expectations
+        return {field: {"$lt": value}}
+    _validate_type_for_comparison(value)
+
+    return {field: {"$lte": value}}
+
+
+def _not_in(field: str, value: Any) -> Dict[str, Any]:
+    if not isinstance(value, list):
+        msg = f"{field}'s value must be a list when using 'not in' comparator in MongoDB Atlas"
+        raise FilterError(msg)
+
+    return {field: {"$nin": value}}
+
+
+def _in(field: str, value: Any) -> Dict[str, Any]:
+    if not isinstance(value, list):
+        msg = f"{field}'s value must be a list when using 'in' comparator in MongoDB Atlas"
+        raise FilterError(msg)
+
+    return {field: {"$in": value}}
+
+
+COMPARISON_OPERATORS = {
+    "==": _equal,
+    "!=": _not_equal,
+    ">": _greater_than,
+    ">=": _greater_than_equal,
+    "<": _less_than,
+    "<=": _less_than_equal,
+    "in": _in,
+    "not in": _not_in,
+}
diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py
index 39a4465c1..89810ec8b 100644
--- a/integrations/mongodb_atlas/tests/test_document_store.py
+++ b/integrations/mongodb_atlas/tests/test_document_store.py
@@ -8,42 +8,43 @@
 from haystack.dataclasses.document import ByteStream, Document
 from haystack.document_stores.errors import DuplicateDocumentError
 from haystack.document_stores.types import DuplicatePolicy
-from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest
+from haystack.testing.document_store import DocumentStoreBaseTests
 from haystack.utils import Secret
 from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
 from pandas import DataFrame
-from pymongo import MongoClient  # type: ignore
-from pymongo.driver_info import DriverInfo  # type: ignore
-
-
-@pytest.fixture
-def document_store():
-    database_name = "haystack_integration_test"
-    collection_name = "test_collection_" + str(uuid4())
-
-    connection: MongoClient = MongoClient(
-        os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
-    )
-    database = connection[database_name]
-    if collection_name in database.list_collection_names():
-        database[collection_name].drop()
-    database.create_collection(collection_name)
-    database[collection_name].create_index("id", unique=True)
-
-    store = MongoDBAtlasDocumentStore(
-        database_name=database_name,
-        collection_name=collection_name,
-        vector_search_index="cosine_index",
-    )
-    yield store
-    database[collection_name].drop()
+from pymongo import MongoClient
+from pymongo.driver_info import DriverInfo
 
 
 @pytest.mark.skipif(
     "MONGO_CONNECTION_STRING" not in os.environ,
     reason="No MongoDB Atlas connection string provided",
 )
-class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):
+@pytest.mark.integration
+class TestDocumentStore(DocumentStoreBaseTests):
+
+    @pytest.fixture
+    def document_store(self):
+        database_name = "haystack_integration_test"
+        collection_name = "test_collection_" + str(uuid4())
+
+        connection: MongoClient = MongoClient(
+            os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
+        )
+        database = connection[database_name]
+        if collection_name in database.list_collection_names():
+            database[collection_name].drop()
+        database.create_collection(collection_name)
+        database[collection_name].create_index("id", unique=True)
+
+        store = MongoDBAtlasDocumentStore(
+            database_name=database_name,
+            collection_name=collection_name,
+            vector_search_index="cosine_index",
+        )
+        yield store
+        database[collection_name].drop()
+
     def test_write_documents(self, document_store: MongoDBAtlasDocumentStore):
         docs = [Document(content="some text")]
         assert document_store.write_documents(docs) == 1
@@ -104,3 +105,37 @@ def test_from_dict(self):
         assert docstore.database_name == "haystack_integration_test"
         assert docstore.collection_name == "test_embeddings_collection"
         assert docstore.vector_search_index == "cosine_index"
+
+    def test_complex_filter(self, document_store, filterable_docs):
+        document_store.write_documents(filterable_docs)
+        filters = {
+            "operator": "OR",
+            "conditions": [
+                {
+                    "operator": "AND",
+                    "conditions": [
+                        {"field": "meta.number", "operator": "==", "value": 100},
+                        {"field": "meta.chapter", "operator": "==", "value": "intro"},
+                    ],
+                },
+                {
+                    "operator": "AND",
+                    "conditions": [
+                        {"field": "meta.page", "operator": "==", "value": "90"},
+                        {"field": "meta.chapter", "operator": "==", "value": "conclusion"},
+                    ],
+                },
+            ],
+        }
+
+        result = document_store.filter_documents(filters=filters)
+
+        self.assert_documents_are_equal(
+            result,
+            [
+                d
+                for d in filterable_docs
+                if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro")
+                or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion")
+            ],
+        )
diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py
index 54bbdedfd..a03c735e0 100644
--- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py
+++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py
@@ -13,6 +13,7 @@
     "MONGO_CONNECTION_STRING" not in os.environ,
     reason="No MongoDB Atlas connection string provided",
 )
+@pytest.mark.integration
 class TestEmbeddingRetrieval:
     def test_embedding_retrieval_cosine_similarity(self):
         document_store = MongoDBAtlasDocumentStore(
@@ -72,3 +73,34 @@ def test_query_embedding_wrong_dimension(self):
         query_embedding = [0.1] * 4
         with pytest.raises(DocumentStoreError):
             document_store._embedding_retrieval(query_embedding=query_embedding)
+
+    def test_embedding_retrieval_with_filters(self):
+        """
+        Note: we can combine embedding retrieval with filters
+        because the `cosine_index` vector_search_index was created with the `content` field as the filter field.
+        {
+            "fields": [
+                {
+                    "type": "vector",
+                    "path": "embedding",
+                    "numDimensions": 768,
+                    "similarity": "cosine"
+                },
+                {
+                    "type": "filter",
+                    "path": "content"
+                }
+            ]
+        }
+        """
+        document_store = MongoDBAtlasDocumentStore(
+            database_name="haystack_integration_test",
+            collection_name="test_embeddings_collection",
+            vector_search_index="cosine_index",
+        )
+        query_embedding = [0.1] * 768
+        filters = {"field": "content", "operator": "!=", "value": "Document A"}
+        results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters=filters)
+        assert len(results) == 2
+        for doc in results:
+            assert doc.content != "Document A"
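To make the conversion implemented in the new filters.py concrete, here is a small sketch (not part of the patch) of the MongoDB filter that `_normalize_filters` produces for a typical Haystack filter; the field names and values are illustrative.

# Illustrative sketch, not part of the patch.
from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters

haystack_filter = {
    "operator": "AND",
    "conditions": [
        {"field": "meta.chapter", "operator": "==", "value": "intro"},
        {"field": "meta.number", "operator": ">=", "value": 100},
    ],
}

mongo_filter = _normalize_filters(haystack_filter)
# Per the operator mapping introduced in filters.py, this yields:
# {"$and": [{"meta.chapter": {"$eq": "intro"}}, {"meta.number": {"$gte": 100}}]}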