Skip to content

Commit

Permalink
MongoDB Atlas: filters (#542)
Browse files Browse the repository at this point in the history
* wip

* progress

* more tests

* improvements

* ignore missing imports in pyproject

* fix mypy

* show coverage

* rm code duplication
  • Loading branch information
anakin87 authored Mar 9, 2024
1 parent 5e1d8b0 commit 7283a5c
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 55 deletions.
13 changes: 6 additions & 7 deletions integrations/mongodb_atlas/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,27 +156,26 @@ ban-relative-imports = "parents"
"examples/**/*" = ["T201"]

[tool.coverage.run]
source_pkgs = ["src", "tests"]
source = ["haystack_integrations"]
branch = true
parallel = true
parallel = false


[tool.coverage.paths]
tests = ["tests", "*/mongodb-atlas-haystack/tests"]

[tool.coverage.report]
omit = ["*/tests/*", "*/__init__.py"]
show_missing=true
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]


[[tool.mypy.overrides]]
module = [
"haystack.*",
"haystack_integrations.*",
"mongodb_atlas.*",
"psycopg.*",
"pymongo.*",
"pytest.*"
]
ignore_missing_imports = true
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def __init__(
Create the MongoDBAtlasDocumentStore component.
:param document_store: An instance of MongoDBAtlasDocumentStore.
:param filters: Filters applied to the retrieved Documents.
:param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are
included in the configuration of the `vector_search_index`. The configuration must be done manually
in the Web UI of MongoDB Atlas.
:param top_k: Maximum number of Documents to return.
:raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo
from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore
from pymongo.driver_info import DriverInfo # type: ignore
from pymongo.errors import BulkWriteError # type: ignore
from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters
from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne
from pymongo.driver_info import DriverInfo
from pymongo.errors import BulkWriteError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -144,8 +144,8 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
:param filters: The filters to apply. It returns only the documents that match the filters.
:returns: A list of Documents that match the given filters.
"""
mongo_filters = haystack_filters_to_mongo(filters)
documents = list(self.collection.find(mongo_filters))
filters = _normalize_filters(filters) if filters else None
documents = list(self.collection.find(filters))
for doc in documents:
doc.pop("_id", None) # MongoDB's internal id doesn't belong into a Haystack document, so we remove it.
return [Document.from_dict(doc) for doc in documents]
Expand All @@ -170,7 +170,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
if policy == DuplicatePolicy.NONE:
policy = DuplicatePolicy.FAIL

mongo_documents = [doc.to_dict() for doc in documents]
mongo_documents = [doc.to_dict(flatten=False) for doc in documents]
operations: List[Union[UpdateOne, InsertOne, ReplaceOne]]
written_docs = len(documents)

Expand Down Expand Up @@ -221,7 +221,8 @@ def _embedding_retrieval(
msg = "Query embedding must not be empty"
raise ValueError(msg)

filters = haystack_filters_to_mongo(filters)
filters = _normalize_filters(filters) if filters else None

pipeline = [
{
"$vectorSearch": {
Expand All @@ -230,7 +231,7 @@ def _embedding_retrieval(
"queryVector": query_embedding,
"numCandidates": 100,
"limit": top_k,
# "filter": filters,
"filter": filters,
}
},
{
Expand All @@ -249,6 +250,11 @@ def _embedding_retrieval(
documents = list(self.collection.aggregate(pipeline))
except Exception as e:
msg = f"Retrieval of documents from MongoDB Atlas failed: {e}"
if filters:
msg += (
"\nMake sure that the fields used in the filters are included "
"in the `vector_search_index` configuration"
)
raise DocumentStoreError(msg) from e

documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents]
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,9 +1,152 @@
from typing import Any, Dict, Optional
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from typing import Any, Dict

from haystack.errors import FilterError
from haystack.utils.filters import convert
from pandas import DataFrame

def haystack_filters_to_mongo(filters: Optional[Dict[str, Any]]):
# TODO
if filters:
msg = "Filtering not yet implemented for MongoDBAtlasDocumentStore"
raise ValueError(msg)
return {}
# Types that cannot be used with the ordering operators ('>', '>=', '<', '<='):
# lists have no total order in MongoDB queries, and DataFrames are serialized
# to JSON strings before storage, so ordering them is meaningless.
UNSUPPORTED_TYPES_FOR_COMPARISON = (list, DataFrame)


def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
    """
    Translate a Haystack filter specification into the equivalent MongoDB query dictionary.

    :param filters: A Haystack filter dictionary (current or legacy format).
    :returns: The corresponding MongoDB filter dictionary.
    :raises FilterError: If `filters` is not a dictionary.
    """
    if not isinstance(filters, dict):
        msg = "Filters must be a dictionary"
        raise FilterError(msg)

    # Legacy-style filters have neither "operator" nor "conditions";
    # upgrade them to the current format before translating.
    if "operator" not in filters and "conditions" not in filters:
        filters = convert(filters)

    # A "field" key marks a leaf comparison; anything else is a logical condition.
    if "field" in filters:
        return _parse_comparison_condition(filters)
    return _parse_logical_condition(filters)


def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a Haystack logical condition ("AND"/"OR"/"NOT") into its MongoDB form.

    :param condition: A dictionary with "operator" and "conditions" keys.
    :returns: The corresponding MongoDB filter dictionary.
    :raises FilterError: If a required key is missing or the operator is unknown.
    """
    for required_key in ("operator", "conditions"):
        if required_key not in condition:
            msg = f"'{required_key}' key missing in {condition}"
            raise FilterError(msg)

    # Logical conditions can nest arbitrarily, so sub-conditions are
    # translated recursively; a "field" key marks a leaf comparison.
    parsed_conditions = [
        _parse_comparison_condition(sub) if "field" in sub else _parse_logical_condition(sub)
        for sub in condition["conditions"]
    ]

    operator = condition["operator"]
    if operator == "AND":
        return {"$and": parsed_conditions}
    if operator == "OR":
        return {"$or": parsed_conditions}
    if operator == "NOT":
        # Haystack's NOT is a logical NAND; MongoDB has no direct equivalent,
        # so we negate the conjunction by wrapping $and in $nor.
        return {"$nor": [{"$and": parsed_conditions}]}

    msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR', 'NOT'"
    raise FilterError(msg)


def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a Haystack comparison condition into its MongoDB form.

    :param condition: A dictionary with "field", "operator" and "value" keys.
    :returns: The corresponding MongoDB filter dictionary.
    :raises FilterError: If a required key is missing or the operator is unknown.
    """
    field: str = condition["field"]
    if "operator" not in condition:
        msg = f"'operator' key missing in {condition}"
        raise FilterError(msg)
    if "value" not in condition:
        msg = f"'value' key missing in {condition}"
        raise FilterError(msg)
    operator: str = condition["operator"]
    # Fix: an unknown operator previously escaped as a bare KeyError from the
    # dict lookup below; raise FilterError instead, consistent with
    # _parse_logical_condition's handling of unknown logical operators.
    if operator not in COMPARISON_OPERATORS:
        msg = f"Unknown comparison operator '{operator}'. Valid operators are: {list(COMPARISON_OPERATORS)}"
        raise FilterError(msg)
    value: Any = condition["value"]

    # DataFrames are not BSON-serializable; they are stored as JSON strings,
    # so they must be compared in that form as well.
    if isinstance(value, DataFrame):
        value = value.to_json()

    return COMPARISON_OPERATORS[operator](field, value)


def _equal(field: str, value: Any) -> Dict[str, Any]:
return {field: {"$eq": value}}


def _not_equal(field: str, value: Any) -> Dict[str, Any]:
return {field: {"$ne": value}}


def _validate_type_for_comparison(value: Any) -> None:
    """
    Ensure `value` can be used with the ordering operators ('>', '>=', '<', '<=').

    Lists and DataFrames are rejected outright; strings are accepted only when
    they are ISO-formatted dates, since arbitrary strings have no meaningful
    ordering for filtering purposes.

    :raises FilterError: If the value cannot be used in an ordering comparison.
    """
    # Fix: typo in the user-facing error message ("Cant" -> "Can't").
    msg = f"Can't compare {type(value)} using operators '>', '>=', '<', '<='."
    if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON):
        raise FilterError(msg)
    elif isinstance(value, str):
        try:
            datetime.fromisoformat(value)
        except (ValueError, TypeError) as exc:
            msg += "\nStrings are only comparable if they are ISO formatted dates."
            raise FilterError(msg) from exc


def _greater_than(field: str, value: Any) -> Dict[str, Any]:
    """Build a MongoDB `$gt` (strictly greater) clause for `field`."""
    # Reject values that have no meaningful ordering before emitting the clause.
    _validate_type_for_comparison(value)
    clause = {"$gt": value}
    return {field: clause}


def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
if value is None:
# we want {field: {"$gte": null}} to return an empty result
# $gte with null values in MongoDB returns a non-empty result, while $gt aligns with our expectations
return {field: {"$gt": value}}

_validate_type_for_comparison(value)
return {field: {"$gte": value}}


def _less_than(field: str, value: Any) -> Dict[str, Any]:
    """Build a MongoDB `$lt` (strictly less) clause for `field`."""
    # Reject values that have no meaningful ordering before emitting the clause.
    _validate_type_for_comparison(value)
    clause = {"$lt": value}
    return {field: clause}


def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
if value is None:
# we want {field: {"$lte": null}} to return an empty result
# $lte with null values in MongoDB returns a non-empty result, while $lt aligns with our expectations
return {field: {"$lt": value}}
_validate_type_for_comparison(value)

return {field: {"$lte": value}}


def _not_in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone"
raise FilterError(msg)

return {field: {"$nin": value}}


def _in(field: str, value: Any) -> Dict[str, Any]:
if not isinstance(value, list):
msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone"
raise FilterError(msg)

return {field: {"$in": value}}


# Maps each Haystack comparison operator to the function that builds the
# corresponding MongoDB clause. Used by _parse_comparison_condition.
COMPARISON_OPERATORS = {
    "==": _equal,
    "!=": _not_equal,
    ">": _greater_than,
    ">=": _greater_than_equal,
    "<": _less_than,
    "<=": _less_than_equal,
    "in": _in,
    "not in": _not_in,
}
89 changes: 62 additions & 27 deletions integrations/mongodb_atlas/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,43 @@
from haystack.dataclasses.document import ByteStream, Document
from haystack.document_stores.errors import DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest
from haystack.testing.document_store import DocumentStoreBaseTests
from haystack.utils import Secret
from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
from pandas import DataFrame
from pymongo import MongoClient # type: ignore
from pymongo.driver_info import DriverInfo # type: ignore


@pytest.fixture
def document_store():
database_name = "haystack_integration_test"
collection_name = "test_collection_" + str(uuid4())

connection: MongoClient = MongoClient(
os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
)
database = connection[database_name]
if collection_name in database.list_collection_names():
database[collection_name].drop()
database.create_collection(collection_name)
database[collection_name].create_index("id", unique=True)

store = MongoDBAtlasDocumentStore(
database_name=database_name,
collection_name=collection_name,
vector_search_index="cosine_index",
)
yield store
database[collection_name].drop()
from pymongo import MongoClient
from pymongo.driver_info import DriverInfo


@pytest.mark.skipif(
"MONGO_CONNECTION_STRING" not in os.environ,
reason="No MongoDB Atlas connection string provided",
)
class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):
@pytest.mark.integration
class TestDocumentStore(DocumentStoreBaseTests):

@pytest.fixture
def document_store(self):
    # Each test run gets its own uniquely-named collection so that repeated
    # or concurrent runs cannot interfere with one another.
    database_name = "haystack_integration_test"
    collection_name = "test_collection_" + str(uuid4())

    # Connects to a live MongoDB Atlas cluster; requires MONGO_CONNECTION_STRING
    # in the environment (the class-level skipif guards this).
    connection: MongoClient = MongoClient(
        os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
    )
    database = connection[database_name]
    # Start from a clean slate in case a previous run left the collection behind.
    if collection_name in database.list_collection_names():
        database[collection_name].drop()
    database.create_collection(collection_name)
    # Enforce uniqueness of the Haystack document id at the database level.
    database[collection_name].create_index("id", unique=True)

    store = MongoDBAtlasDocumentStore(
        database_name=database_name,
        collection_name=collection_name,
        vector_search_index="cosine_index",
    )
    yield store
    # Teardown: drop the collection created for this test.
    database[collection_name].drop()

def test_write_documents(self, document_store: MongoDBAtlasDocumentStore):
docs = [Document(content="some text")]
assert document_store.write_documents(docs) == 1
Expand Down Expand Up @@ -104,3 +105,37 @@ def test_from_dict(self):
assert docstore.database_name == "haystack_integration_test"
assert docstore.collection_name == "test_embeddings_collection"
assert docstore.vector_search_index == "cosine_index"

def test_complex_filter(self, document_store, filterable_docs):
    """A nested OR-of-ANDs filter is applied correctly by filter_documents."""
    document_store.write_documents(filterable_docs)

    intro_branch = {
        "operator": "AND",
        "conditions": [
            {"field": "meta.number", "operator": "==", "value": 100},
            {"field": "meta.chapter", "operator": "==", "value": "intro"},
        ],
    }
    conclusion_branch = {
        "operator": "AND",
        "conditions": [
            {"field": "meta.page", "operator": "==", "value": "90"},
            {"field": "meta.chapter", "operator": "==", "value": "conclusion"},
        ],
    }
    filters = {"operator": "OR", "conditions": [intro_branch, conclusion_branch]}

    retrieved = document_store.filter_documents(filters=filters)

    expected = [
        doc
        for doc in filterable_docs
        if (doc.meta.get("number") == 100 and doc.meta.get("chapter") == "intro")
        or (doc.meta.get("page") == "90" and doc.meta.get("chapter") == "conclusion")
    ]
    self.assert_documents_are_equal(retrieved, expected)
Loading

0 comments on commit 7283a5c

Please sign in to comment.