From ea482ffd46df0844db018a2f7b922fc3413c53e8 Mon Sep 17 00:00:00 2001 From: alperkaya Date: Wed, 16 Oct 2024 15:40:40 +0200 Subject: [PATCH 1/3] initial version --- .../retrievers/mongodb_atlas/__init__.py | 3 +- .../mongodb_atlas/fulltext_retriever.py | 106 ++++++++++++++++++ .../mongodb_atlas/document_store.py | 47 ++++++++ 3 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py index fed0a4c28..b551eade8 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py @@ -1,3 +1,4 @@ from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever +from haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever import MongoDBAtlasFullTextRetriever -__all__ = ["MongoDBAtlasEmbeddingRetriever"] +__all__ = ["MongoDBAtlasEmbeddingRetriever", "MongoDBAtlasFullTextRetriever"] diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py new file mode 100644 index 000000000..373185f37 --- /dev/null +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py @@ -0,0 +1,106 @@ +from typing import Any, Dict, List, Optional, Union + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +@component +class MongoDBAtlasFullTextRetriever: + + def __init__( + self, + *, + document_store: MongoDBAtlasDocumentStore, + search_path: Union[str, List[str]] = "content", + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the MongoDBAtlasFullTextRetriever component. + + :param document_store: An instance of MongoDBAtlasDocumentStore. + :param search_path: Field(s) to search within, e.g., "content" or ["content", "title"]. + :param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are + included in the configuration of the `vector_search_index`. The configuration must be done manually + in the Web UI of MongoDB Atlas. + :param top_k: Maximum number of Documents to return. + :param filter_policy: Policy to determine how filters are applied. + :raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`. + """ + + if not isinstance(document_store, MongoDBAtlasDocumentStore): + msg = "document_store must be an instance of MongoDBAtlasDocumentStore" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.search_path = search_path + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + filter_policy=self.filter_policy.value, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasFullTextRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if filter_policy := data["init_parameters"].get("filter_policy"): + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + ) -> Dict[str, List[Document]]: + """ + Retrieve documents from the MongoDBAtlasDocumentStore, based on the provided query. + + :param query: Text query. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization. + :returns: A dictionary with the following keys: + - `documents`: List of Documents most similar to the given `query` + """ + filters = apply_filter_policy(self.filter_policy, self.filters, filters) + top_k = top_k or self.top_k + + docs = self.document_store._fulltext_retrieval( + query=query, filters=filters, top_k=top_k, search_path=self.search_path + ) + return {"documents": docs} diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 79caa15f8..3a4a240b6 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -226,6 +226,53 @@ def delete_documents(self, document_ids: List[str]) -> None: return self.collection.delete_many(filter={"id": {"$in": document_ids}}) + def _fulltext_retrieval( + self, + query: str, + search_path: Union[str, List[str]] = "content", + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + ) -> List[Document]: + """ + Find the documents that are exact match provided `query`. + + :param query: The text to search in the document store. + :param search_path: Field(s) to search within, e.g., "content" or ["content", "title"]. + :param filters: Optional filters. + :param top_k: How many documents to return. + :returns: A list of Documents matching the full-text search query. + :raises ValueError: If `query` is empty. + :raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails. + """ + if not query: + msg = "query must not be empty" + raise ValueError(msg) + + filters = _normalize_filters(filters) if filters else {} + + pipeline = [ + { + "$search": { + "index": self.vector_search_index, + "text": { + "query": query, + "path": search_path, + }, + } + }, + {"$match": filters if filters else {}}, + {"$limit": top_k}, + {"$project": {"_id": 0, "content": 1, "meta": 1, "score": {"$meta": "searchScore"}}}, + ] + try: + documents = list(self.collection.aggregate(pipeline)) + except Exception as e: + msg = f"Retrieval of documents from MongoDB Atlas failed: {e}" + raise DocumentStoreError(msg) from e + + documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents] + return documents + def _embedding_retrieval( self, query_embedding: List[float], From fb68881ebc80893dfdd419001902ba1db17bfc53 Mon Sep 17 00:00:00 2001 From: alperkaya Date: Wed, 16 Oct 2024 15:44:00 +0200 Subject: [PATCH 2/3] unit testcase added --- .../tests/test_full_text_retriever.py | 165 ++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 integrations/mongodb_atlas/tests/test_full_text_retriever.py diff --git a/integrations/mongodb_atlas/tests/test_full_text_retriever.py b/integrations/mongodb_atlas/tests/test_full_text_retriever.py new file mode 100644 index 000000000..b7b6c8710 --- /dev/null +++ b/integrations/mongodb_atlas/tests/test_full_text_retriever.py @@ -0,0 +1,165 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.utils.auth import EnvVarSecret + +from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasFullTextRetriever +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +class TestFullTextRetriever: + @pytest.fixture + def mock_client(self): + with patch( + "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient" + ) as mock_mongo_client: + mock_connection = MagicMock() + mock_database = MagicMock() + mock_collection_names = MagicMock(return_value=["test_collection"]) + mock_database.list_collection_names = mock_collection_names + mock_connection.__getitem__.return_value = mock_database + mock_mongo_client.return_value = mock_connection + yield mock_mongo_client + + def test_init_default(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store) + assert retriever.document_store == mock_store + assert retriever.filters == {} + assert retriever.top_k == 10 + assert retriever.filter_policy == FilterPolicy.REPLACE + + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="merge") + assert retriever.filter_policy == FilterPolicy.MERGE + + with pytest.raises(ValueError): + MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="wrong_policy") + + def test_init(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasFullTextRetriever( + document_store=mock_store, + filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, + top_k=5, + ) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_init_filter_policy_merge(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasFullTextRetriever( + document_store=mock_store, + filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, + top_k=5, + filter_policy=FilterPolicy.MERGE, + ) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.MERGE + + def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_collection", + vector_search_index="default", + ) + + retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 + "init_parameters": { + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, + "database_name": "haystack_integration_test", + "collection_name": "test_collection", + "vector_search_index": "default", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "filter_policy": "replace", + }, + } + + def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + data = { + "type": "haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 + "init_parameters": { + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, + "database_name": "haystack_integration_test", + "collection_name": "test_collection", + "vector_search_index": "default", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "filter_policy": "replace", + }, + } + + retriever = MongoDBAtlasFullTextRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, MongoDBAtlasDocumentStore) + assert isinstance(document_store.mongo_connection_string, EnvVarSecret) + assert document_store.database_name == "haystack_integration_test" + assert document_store.collection_name == "test_collection" + assert document_store.vector_search_index == "default" + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_run(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + doc = Document(content="Test doc") + mock_store._fulltext_retrieval.return_value = [doc] + + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, search_path="desc") + res = retriever.run(query="text") + + mock_store._fulltext_retrieval.assert_called_once_with(query="text", filters={}, top_k=10, search_path="desc") + + assert res == {"documents": [doc]} + + def test_run_merge_policy_filter(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + doc = Document(content="Test doc") + mock_store._fulltext_retrieval.return_value = [doc] + + retriever = MongoDBAtlasFullTextRetriever( + document_store=mock_store, + filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, + filter_policy=FilterPolicy.MERGE, + ) + res = retriever.run(query="text", filters={"field": "meta.some_field", "operator": "==", "value": "Test"}) + mock_store._fulltext_retrieval.assert_called_once_with( + query="text", + filters={"field": "meta.some_field", "operator": "==", "value": "Test"}, + top_k=10, + search_path="content", + ) + + assert res == {"documents": [doc]} From e749bdd85b292b1f553c97ab097941e4ce40d7ec Mon Sep 17 00:00:00 2001 From: alperkaya Date: Thu, 17 Oct 2024 11:31:50 +0200 Subject: [PATCH 3/3] add new testcases and remove filter --- .../mongodb_atlas/fulltext_retriever.py | 27 +---- .../mongodb_atlas/document_store.py | 5 - .../tests/test_full_text_retriever.py | 64 +---------- .../tests/test_fulltext_retrieval.py | 100 ++++++++++++++++++ 4 files changed, 104 insertions(+), 92 deletions(-) create mode 100644 integrations/mongodb_atlas/tests/test_fulltext_retrieval.py diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py index 373185f37..a98c8f1b4 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/fulltext_retriever.py @@ -2,8 +2,6 @@ from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document -from haystack.document_stores.types import FilterPolicy -from haystack.document_stores.types.filter_policy import apply_filter_policy from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore @@ -16,20 +14,14 @@ def __init__( *, document_store: MongoDBAtlasDocumentStore, search_path: Union[str, List[str]] = "content", - filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, ): """ Create the MongoDBAtlasFullTextRetriever component. :param document_store: An instance of MongoDBAtlasDocumentStore. :param search_path: Field(s) to search within, e.g., "content" or ["content", "title"]. - :param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are - included in the configuration of the `vector_search_index`. The configuration must be done manually - in the Web UI of MongoDB Atlas. :param top_k: Maximum number of Documents to return. - :param filter_policy: Policy to determine how filters are applied. :raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`. """ @@ -38,12 +30,8 @@ def __init__( raise ValueError(msg) self.document_store = document_store - self.filters = filters or {} self.top_k = top_k self.search_path = search_path - self.filter_policy = ( - filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) - ) def to_dict(self) -> Dict[str, Any]: """ @@ -54,9 +42,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - filters=self.filters, top_k=self.top_k, - filter_policy=self.filter_policy.value, document_store=self.document_store.to_dict(), ) @@ -73,34 +59,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasFullTextRetriever": data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict( data["init_parameters"]["document_store"] ) - # Pipelines serialized with old versions of the component might not - # have the filter_policy field. - if filter_policy := data["init_parameters"].get("filter_policy"): - data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy) return default_from_dict(cls, data) @component.output_types(documents=List[Document]) def run( self, query: str, - filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, ) -> Dict[str, List[Document]]: """ Retrieve documents from the MongoDBAtlasDocumentStore, based on the provided query. :param query: Text query. - :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on - the `filter_policy` chosen at retriever initialization. See init method docstring for more - details. :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization. :returns: A dictionary with the following keys: - `documents`: List of Documents most similar to the given `query` """ - filters = apply_filter_policy(self.filter_policy, self.filters, filters) top_k = top_k or self.top_k - docs = self.document_store._fulltext_retrieval( - query=query, filters=filters, top_k=top_k, search_path=self.search_path - ) + docs = self.document_store._fulltext_retrieval(query=query, top_k=top_k, search_path=self.search_path) return {"documents": docs} diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 3a4a240b6..080c15736 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -230,7 +230,6 @@ def _fulltext_retrieval( self, query: str, search_path: Union[str, List[str]] = "content", - filters: Optional[Dict[str, Any]] = None, top_k: int = 10, ) -> List[Document]: """ @@ -238,7 +237,6 @@ def _fulltext_retrieval( :param query: The text to search in the document store. :param search_path: Field(s) to search within, e.g., "content" or ["content", "title"]. - :param filters: Optional filters. :param top_k: How many documents to return. :returns: A list of Documents matching the full-text search query. :raises ValueError: If `query` is empty. @@ -248,8 +246,6 @@ def _fulltext_retrieval( msg = "query must not be empty" raise ValueError(msg) - filters = _normalize_filters(filters) if filters else {} - pipeline = [ { "$search": { @@ -260,7 +256,6 @@ def _fulltext_retrieval( }, } }, - {"$match": filters if filters else {}}, {"$limit": top_k}, {"$project": {"_id": 0, "content": 1, "meta": 1, "score": {"$meta": "searchScore"}}}, ] diff --git a/integrations/mongodb_atlas/tests/test_full_text_retriever.py b/integrations/mongodb_atlas/tests/test_full_text_retriever.py index b7b6c8710..41ab5e9c0 100644 --- a/integrations/mongodb_atlas/tests/test_full_text_retriever.py +++ b/integrations/mongodb_atlas/tests/test_full_text_retriever.py @@ -2,7 +2,6 @@ import pytest from haystack.dataclasses import Document -from haystack.document_stores.types import FilterPolicy from haystack.utils.auth import EnvVarSecret from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasFullTextRetriever @@ -27,40 +26,9 @@ def test_init_default(self): mock_store = Mock(spec=MongoDBAtlasDocumentStore) retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store) assert retriever.document_store == mock_store - assert retriever.filters == {} assert retriever.top_k == 10 - assert retriever.filter_policy == FilterPolicy.REPLACE - retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="merge") - assert retriever.filter_policy == FilterPolicy.MERGE - - with pytest.raises(ValueError): - MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="wrong_policy") - - def test_init(self): - mock_store = Mock(spec=MongoDBAtlasDocumentStore) - retriever = MongoDBAtlasFullTextRetriever( - document_store=mock_store, - filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, - top_k=5, - ) - assert retriever.document_store == mock_store - assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} - assert retriever.top_k == 5 - assert retriever.filter_policy == FilterPolicy.REPLACE - - def test_init_filter_policy_merge(self): - mock_store = Mock(spec=MongoDBAtlasDocumentStore) - retriever = MongoDBAtlasFullTextRetriever( - document_store=mock_store, - filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, - top_k=5, - filter_policy=FilterPolicy.MERGE, - ) - assert retriever.document_store == mock_store - assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} - assert retriever.top_k == 5 - assert retriever.filter_policy == FilterPolicy.MERGE + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store) def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") @@ -71,7 +39,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a vector_search_index="default", ) - retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) + retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, top_k=5) res = retriever.to_dict() assert res == { "type": "haystack_integrations.components.retrievers.mongodb_atlas.fulltext_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 @@ -89,9 +57,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a "vector_search_index": "default", }, }, - "filters": {"field": "value"}, "top_k": 5, - "filter_policy": "replace", }, } @@ -114,9 +80,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client "vector_search_index": "default", }, }, - "filters": {"field": "value"}, "top_k": 5, - "filter_policy": "replace", }, } @@ -128,9 +92,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client assert document_store.database_name == "haystack_integration_test" assert document_store.collection_name == "test_collection" assert document_store.vector_search_index == "default" - assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 - assert retriever.filter_policy == FilterPolicy.REPLACE def test_run(self): mock_store = Mock(spec=MongoDBAtlasDocumentStore) @@ -140,26 +102,6 @@ def test_run(self): retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, search_path="desc") res = retriever.run(query="text") - mock_store._fulltext_retrieval.assert_called_once_with(query="text", filters={}, top_k=10, search_path="desc") - - assert res == {"documents": [doc]} - - def test_run_merge_policy_filter(self): - mock_store = Mock(spec=MongoDBAtlasDocumentStore) - doc = Document(content="Test doc") - mock_store._fulltext_retrieval.return_value = [doc] - - retriever = MongoDBAtlasFullTextRetriever( - document_store=mock_store, - filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, - filter_policy=FilterPolicy.MERGE, - ) - res = retriever.run(query="text", filters={"field": "meta.some_field", "operator": "==", "value": "Test"}) - mock_store._fulltext_retrieval.assert_called_once_with( - query="text", - filters={"field": "meta.some_field", "operator": "==", "value": "Test"}, - top_k=10, - search_path="content", - ) + mock_store._fulltext_retrieval.assert_called_once_with(query="text", top_k=10, search_path="desc") assert res == {"documents": [doc]} diff --git a/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py new file mode 100644 index 000000000..757e16f46 --- /dev/null +++ b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest + +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +@pytest.mark.skipif( + "MONGO_CONNECTION_STRING" not in os.environ, + reason="No MongoDB Atlas connection string provided", +) +@pytest.mark.integration +class TestEmbeddingRetrieval: + def test_basic_fulltext_retrieval(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_fulltext_collection", + vector_search_index="default", + ) + query = "crime" + results = document_store._fulltext_retrieval(query=query) + assert len(results) == 1 + + def test_fulltext_retrieval_custom_path(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_fulltext_collection", + vector_search_index="default", + ) + query = "Godfather" + path = "title" + results = document_store._fulltext_retrieval(query=query, search_path=path) + assert len(results) == 1 + + def test_fulltext_retrieval_multi_paths_and_top_k(self): + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_fulltext_collection", + vector_search_index="default", + ) + query = "movie" + paths = ["title", "content"] + results = document_store._fulltext_retrieval(query=query, search_path=paths) + assert len(results) == 2 + + results = document_store._fulltext_retrieval(query=query, search_path=paths, top_k=1) + assert len(results) == 1 + + +""" +[ + { + "title": "The Matrix", + "content": "A hacker discovers that his reality is a simulation in this movie.", + "meta": { + "author": "Wachowskis", + "city": "San Francisco" + } + }, + { + "title": "Inception", + "content": "A thief who steals corporate secrets through the use of dream-sharing technology.", + "meta": { + "author": "Christopher Nolan", + "city": "Los Angeles" + } + }, + { + "title": "Interstellar", + "content": "A team of explorers travel through a wormhole in space in an attempt + to ensure humanity's survival.", + "meta": { + "author": "Christopher Nolan", + "city": "Houston" + } + }, + { + "title": "The Dark Knight", + "content": "When the menace known as the Joker emerges from his mysterious past, + he wreaks havoc on Gotham.", + "meta": { + "author": "Christopher Nolan", + "city": "Gotham" + } + }, + { + "title": "The Godfather Movie", + "content": "The aging patriarch of an organized crime dynasty transfers + control of his empire to his reluctant son.", + "meta": { + "author": "Mario Puzo", + "city": "New York" + } + } +] + +"""