diff --git a/.github/workflows/mongodb_atlas.yml b/.github/workflows/mongodb_atlas.yml index 3d1ad5101..3fd2a43ac 100644 --- a/.github/workflows/mongodb_atlas.yml +++ b/.github/workflows/mongodb_atlas.yml @@ -4,11 +4,11 @@ name: Test / mongodb_atlas on: schedule: - - cron: "0 0 * * *" + - cron: '0 0 * * *' pull_request: paths: - - "integrations/mongodb_atlas/**" - - ".github/workflows/mongodb_atlas.yml" + - 'integrations/mongodb_atlas/**' + - '.github/workflows/mongodb_atlas.yml' defaults: run: @@ -19,9 +19,10 @@ concurrency: cancel-in-progress: true env: - PYTHONUNBUFFERED: "1" - FORCE_COLOR: "1" + PYTHONUNBUFFERED: '1' + FORCE_COLOR: '1' MONGO_CONNECTION_STRING: ${{ secrets.MONGO_CONNECTION_STRING }} + MONGO_CONNECTION_STRING_2: ${{ secrets.MONGO_CONNECTION_STRING_2 }} jobs: run: @@ -31,7 +32,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.9", "3.10", "3.11"] + python-version: ['3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v4 diff --git a/integrations/mongodb_atlas/examples/example.py b/integrations/mongodb_atlas/examples/embedding_retrieval.py similarity index 80% rename from integrations/mongodb_atlas/examples/example.py rename to integrations/mongodb_atlas/examples/embedding_retrieval.py index 54fd569ce..d8a71c343 100644 --- a/integrations/mongodb_atlas/examples/example.py +++ b/integrations/mongodb_atlas/examples/embedding_retrieval.py @@ -19,6 +19,8 @@ # To use the MongoDBAtlasDocumentStore, you must have a running MongoDB Atlas database. # For details, see https://www.mongodb.com/docs/atlas/getting-started/ +# NOTE: you need to create manually the vector search index and the full text search +# index in your MongoDB Atlas database. # Once your database is set, set the environment variable `MONGO_CONNECTION_STRING` # with the connection string to your MongoDB Atlas database. @@ -29,12 +31,17 @@ database_name="haystack_test", collection_name="test_collection", vector_search_index="test_vector_search_index", + full_text_search_index="test_full_text_search_index", ) +# This is to avoid duplicates in the collection +print(f"Cleaning up collection {document_store.collection_name}") +document_store.collection.delete_many({}) + # Create the indexing Pipeline and index some documents file_paths = glob.glob("neural-search-pills/pills/*.md") - +print("Creating indexing pipeline") indexing = Pipeline() indexing.add_component("converter", MarkdownToDocument()) indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2)) @@ -44,17 +51,20 @@ indexing.connect("splitter", "embedder") indexing.connect("embedder", "writer") +print(f"Running indexing pipeline with {len(file_paths)} files") indexing.run({"converter": {"sources": file_paths}}) - -# Create the querying Pipeline and try a query +print("Creating querying pipeline") querying = Pipeline() querying.add_component("embedder", SentenceTransformersTextEmbedder()) querying.add_component("retriever", MongoDBAtlasEmbeddingRetriever(document_store=document_store, top_k=3)) querying.connect("embedder", "retriever") +query = "What is a cross-encoder?" +print(f"Running querying pipeline with query: '{query}'") results = querying.run({"embedder": {"text": "What is a cross-encoder?"}}) +print(f"Results: {results}") for doc in results["retriever"]["documents"]: print(doc) print("-" * 10) diff --git a/integrations/mongodb_atlas/examples/hybrid_retrieval.py b/integrations/mongodb_atlas/examples/hybrid_retrieval.py new file mode 100644 index 000000000..a165edf12 --- /dev/null +++ b/integrations/mongodb_atlas/examples/hybrid_retrieval.py @@ -0,0 +1,80 @@ +# Install required packages for this example, including mongodb-atlas-haystack and other libraries needed +# for Markdown conversion and embeddings generation. Use the following command: +# +# pip install mongodb-atlas-haystack markdown-it-py mdit_plain "sentence-transformers>=2.2.0" +# +# Download some Markdown files to index. +# git clone https://github.com/anakin87/neural-search-pills + +import glob + +from haystack import Pipeline +from haystack.components.converters import MarkdownToDocument +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.joiners import DocumentJoiner +from haystack.components.preprocessors import DocumentSplitter +from haystack.components.writers import DocumentWriter + +from haystack_integrations.components.retrievers.mongodb_atlas import ( + MongoDBAtlasEmbeddingRetriever, + MongoDBAtlasFullTextRetriever, +) +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + +# To use the MongoDBAtlasDocumentStore, you must have a running MongoDB Atlas database. +# For details, see https://www.mongodb.com/docs/atlas/getting-started/ +# NOTE: you need to create manually the vector search index and the full text search +# index in your MongoDB Atlas database. + +# Once your database is set, set the environment variable `MONGO_CONNECTION_STRING` +# with the connection string to your MongoDB Atlas database. +# format: "mongodb+srv://{mongo_atlas_username}:{mongo_atlas_password}@{mongo_atlas_host}/?{mongo_atlas_params_string}". + +# Initialize the document store +document_store = MongoDBAtlasDocumentStore( + database_name="haystack_test", + collection_name="test_collection", + vector_search_index="test_vector_search_index", + full_text_search_index="test_full_text_search_index", +) + +file_paths = glob.glob("neural-search-pills/pills/*.md") + +# This is to avoid duplicates in the collection +print(f"Cleaning up collection {document_store.collection_name}") +document_store.collection.delete_many({}) + +print("Creating indexing pipeline") +indexing = Pipeline() +indexing.add_component("converter", MarkdownToDocument()) +indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2)) +indexing.add_component("document_embedder", SentenceTransformersDocumentEmbedder()) +indexing.add_component("writer", DocumentWriter(document_store)) +indexing.connect("converter", "splitter") +indexing.connect("splitter", "document_embedder") +indexing.connect("document_embedder", "writer") + +print(f"Running indexing pipeline with {len(file_paths)} files") +indexing.run({"converter": {"sources": file_paths}}) + +print("Creating querying pipeline") +querying = Pipeline() +querying.add_component("text_embedder", SentenceTransformersTextEmbedder()) +querying.add_component("embedding_retriever", MongoDBAtlasEmbeddingRetriever(document_store=document_store, top_k=3)) +querying.add_component("full_text_retriever", MongoDBAtlasFullTextRetriever(document_store=document_store, top_k=3)) +querying.add_component( + "joiner", + DocumentJoiner(join_mode="reciprocal_rank_fusion", top_k=3), +) +querying.connect("text_embedder", "embedding_retriever") +querying.connect("embedding_retriever", "joiner") +querying.connect("full_text_retriever", "joiner") + +query = "cross-encoder" +print(f"Running querying pipeline with query '{query}'") +results = querying.run({"text_embedder": {"text": query}, "full_text_retriever": {"query": query}}) + +print(f"Results: {results}") +for doc in results["joiner"]["documents"]: + print(doc) + print("-" * 10) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py index fed0a4c28..bbeec63d1 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/__init__.py @@ -1,3 +1,4 @@ from haystack_integrations.components.retrievers.mongodb_atlas.embedding_retriever import MongoDBAtlasEmbeddingRetriever +from haystack_integrations.components.retrievers.mongodb_atlas.full_text_retriever import MongoDBAtlasFullTextRetriever -__all__ = ["MongoDBAtlasEmbeddingRetriever"] +__all__ = ["MongoDBAtlasEmbeddingRetriever", "MongoDBAtlasFullTextRetriever"] diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 3345f4f7c..4579a85bc 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -28,7 +28,8 @@ class MongoDBAtlasEmbeddingRetriever: store = MongoDBAtlasDocumentStore(database_name="haystack_integration_test", collection_name="test_embeddings_collection", - vector_search_index="cosine_index") + vector_search_index="cosine_index", + full_text_search_index="full_text_index") retriever = MongoDBAtlasEmbeddingRetriever(document_store=store) results = retriever.run(query_embedding=np.random.random(768).tolist()) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/full_text_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/full_text_retriever.py new file mode 100644 index 000000000..63348c6f3 --- /dev/null +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/full_text_retriever.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Literal, Optional, Union + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +@component +class MongoDBAtlasFullTextRetriever: + """ + Retrieves documents from the MongoDBAtlasDocumentStore by full-text search. + + The full-text search is dependent on the full_text_search_index used in the MongoDBAtlasDocumentStore. + See MongoDBAtlasDocumentStore for more information. + + Usage example: + ```python + from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasFullTextRetriever + + store = MongoDBAtlasDocumentStore(database_name="your_existing_db", + collection_name="your_existing_collection", + vector_search_index="your_existing_index", + full_text_search_index="your_existing_index") + retriever = MongoDBAtlasFullTextRetriever(document_store=store) + + results = retriever.run(query="Lorem ipsum") + print(results["documents"]) + ``` + + The example above retrieves the 10 most similar documents to the query "Lorem ipsum" from the + MongoDBAtlasDocumentStore. + """ + + def __init__( + self, + *, + document_store: MongoDBAtlasDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + :param document_store: An instance of MongoDBAtlasDocumentStore. + :param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are + included in the configuration of the `full_text_search_index`. The configuration must be done manually + in the Web UI of MongoDB Atlas. + :param top_k: Maximum number of Documents to return. + :param filter_policy: Policy to determine how filters are applied. + + :raises ValueError: If `document_store` is not an instance of MongoDBAtlasDocumentStore. + """ + + if not isinstance(document_store, MongoDBAtlasDocumentStore): + msg = "document_store must be an instance of MongoDBAtlasDocumentStore" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + filter_policy=self.filter_policy.value, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MongoDBAtlasFullTextRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = MongoDBAtlasDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query: Union[str, List[str]], + fuzzy: Optional[Dict[str, int]] = None, + match_criteria: Optional[Literal["any", "all"]] = None, + score: Optional[Dict[str, Dict]] = None, + synonyms: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + ) -> Dict[str, List[Document]]: + """ + Retrieve documents from the MongoDBAtlasDocumentStore by full-text search. + + :param query: The query string or a list of query strings to search for. + If the query contains multiple terms, Atlas Search evaluates each term separately for matches. + :param fuzzy: Enables finding strings similar to the search term(s). + Note, `fuzzy` cannot be used with `synonyms`. Configurable options include `maxEdits`, `prefixLength`, + and `maxExpansions`. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param match_criteria: Defines how terms in the query are matched. Supported options are `"any"` and `"all"`. + For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param score: Specifies the scoring method for matching results. Supported options include `boost`, `constant`, + and `function`. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param synonyms: The name of the synonym mapping definition in the index. This value cannot be an empty string. + Note, `synonyms` can not be used with `fuzzy`. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: Maximum number of Documents to return. Overrides the value specified at initialization. + :returns: A dictionary with the following keys: + - `documents`: List of Documents most similar to the given `query` + """ + filters = apply_filter_policy(self.filter_policy, self.filters, filters) + top_k = top_k or self.top_k + + docs = self.document_store._fulltext_retrieval( + query=query, + fuzzy=fuzzy, + match_criteria=match_criteria, + score=score, + synonyms=synonyms, + filters=filters, + top_k=top_k, + ) + + return {"documents": docs} diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 79caa15f8..f13924185 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging import re -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from haystack import default_from_dict, default_to_dict from haystack.dataclasses.document import Document @@ -37,8 +37,10 @@ class MongoDBAtlasDocumentStore: Python driver. Creating databases and collections is beyond the scope of MongoDBAtlasDocumentStore. The primary purpose of this document store is to read and write documents to an existing collection. - The last parameter users needs to provide is a `vector_search_index` - used for vector search operations. This index - can support a chosen metric (i.e. cosine, dot product, or euclidean) and can be created in the Atlas web UI. + Users must provide both a `vector_search_index` for vector search operations and a `full_text_search_index` + for full-text search operations. The `vector_search_index` supports a chosen metric + (e.g., cosine, dot product, or Euclidean), while the `full_text_search_index` enables efficient text-based searches. + Both indexes can be created through the Atlas web UI. For more details on MongoDB Atlas, see the official MongoDB Atlas [documentation](https://www.mongodb.com/docs/atlas/getting-started/). @@ -49,7 +51,8 @@ class MongoDBAtlasDocumentStore: store = MongoDBAtlasDocumentStore(database_name="your_existing_db", collection_name="your_existing_collection", - vector_search_index="your_existing_index") + vector_search_index="your_existing_index", + full_text_search_index="your_existing_index") print(store.count_documents()) ``` """ @@ -61,6 +64,7 @@ def __init__( database_name: str, collection_name: str, vector_search_index: str, + full_text_search_index: str, ): """ Creates a new MongoDBAtlasDocumentStore instance. @@ -76,6 +80,10 @@ def __init__( Create a vector_search_index in the Atlas web UI and specify the init params of MongoDBAtlasDocumentStore. \ For more details refer to MongoDB Atlas [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#std-label-avs-create-index). + :param full_text_search_index: The name of the search index to use for full-text search operations. + Create a full_text_search_index in the Atlas web UI and specify the init params of + MongoDBAtlasDocumentStore. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/create-index/). :raises ValueError: If the collection name contains invalid characters. """ @@ -88,6 +96,7 @@ def __init__( self.database_name = database_name self.collection_name = collection_name self.vector_search_index = vector_search_index + self.full_text_search_index = full_text_search_index self._connection: Optional[MongoClient] = None self._collection: Optional[Collection] = None @@ -124,6 +133,7 @@ def to_dict(self) -> Dict[str, Any]: database_name=self.database_name, collection_name=self.collection_name, vector_search_index=self.vector_search_index, + full_text_search_index=self.full_text_search_index, ) @classmethod @@ -285,6 +295,106 @@ def _embedding_retrieval( documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents] return documents + def _fulltext_retrieval( + self, + query: Union[str, List[str]], + fuzzy: Optional[Dict[str, int]] = None, + match_criteria: Optional[Literal["any", "all"]] = None, + score: Optional[Dict[str, Dict]] = None, + synonyms: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + ) -> List[Document]: + """ + Retrieve documents similar to the provided `query` using a full-text search. + + :param query: The query string or a list of query strings to search for. + If the query contains multiple terms, Atlas Search evaluates each term separately for matches. + :param fuzzy: Enables finding strings similar to the search term(s). + Note, `fuzzy` cannot be used with `synonyms`. Configurable options include `maxEdits`, `prefixLength`, + and `maxExpansions`. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param match_criteria: Defines how terms in the query are matched. Supported options are `"any"` and `"all"`. + For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param score: Specifies the scoring method for matching results. Supported options include `boost`, `constant`, + and `function`. For more details refer to MongoDB Atlas + [documentation](https://www.mongodb.com/docs/atlas/atlas-search/text/#fields). + :param synonyms: The name of the synonym mapping definition in the index. This value cannot be an empty string. + Note, `synonyms` can not be used with `fuzzy`. + :param filters: Optional filters. + :param top_k: How many documents to return. + :returns: A list of Documents that are most similar to the given `query` + :raises ValueError: If `query` or `synonyms` is empty. + :raises ValueError: If `synonyms` and `fuzzy` are used together. + :raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails. + """ + # Validate user input according to MongoDB Atlas Search requirements + if not query: + msg = "Argument query must not be empty." + raise ValueError(msg) + + if isinstance(synonyms, str) and not synonyms: + msg = "Argument synonyms cannot be an empty string." + raise ValueError(msg) + + if synonyms and fuzzy: + msg = "Cannot use both synonyms and fuzzy search together." + raise ValueError(msg) + + if synonyms and not match_criteria: + logger.warning( + "Specify matchCriteria when using synonyms. " + "Atlas Search matches terms in exact order by default, which may change in future versions." + ) + + filters = _normalize_filters(filters) if filters else {} + + # Build the text search options + text_search: Dict[str, Any] = {"path": "content", "query": query} + if match_criteria: + text_search["matchCriteria"] = match_criteria + if synonyms: + text_search["synonyms"] = synonyms + if fuzzy: + text_search["fuzzy"] = fuzzy + if score: + text_search["score"] = score + + # Define the pipeline for MongoDB aggregation + pipeline = [ + { + "$search": { + "index": self.full_text_search_index, + "compound": {"must": [{"text": text_search}]}, + } + }, + # TODO: Use compound filter. See: (https://www.mongodb.com/docs/atlas/atlas-search/performance/query-performance/#avoid--match-after--search) + {"$match": filters}, + {"$limit": top_k}, + { + "$project": { + "_id": 0, + "content": 1, + "dataframe": 1, + "blob": 1, + "meta": 1, + "embedding": 1, + "score": {"$meta": "searchScore"}, + } + }, + ] + + try: + documents = list(self.collection.aggregate(pipeline)) + except Exception as e: + error_msg = f"Failed to retrieve documents from MongoDB Atlas: {e}" + if filters: + error_msg += "\nEnsure fields in filters are included in the `full_text_search_index` configuration." + raise DocumentStoreError(error_msg) from e + + return [self._mongo_doc_to_haystack_doc(doc) for doc in documents] + def _mongo_doc_to_haystack_doc(self, mongo_doc: Dict[str, Any]) -> Document: """ Converts the dictionary coming out of MongoDB into a Haystack document diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py index 6d34b1ca0..6c0ac191e 100644 --- a/integrations/mongodb_atlas/tests/test_document_store.py +++ b/integrations/mongodb_atlas/tests/test_document_store.py @@ -25,6 +25,7 @@ def test_init_is_lazy(_mock_client): database_name="database_name", collection_name="collection_name", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) _mock_client.assert_not_called() @@ -53,6 +54,7 @@ def document_store(self): database_name=database_name, collection_name=collection_name, vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) yield store database[collection_name].drop() @@ -92,6 +94,7 @@ def test_to_dict(self, document_store): }, "database_name": "haystack_integration_test", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, } @@ -110,6 +113,7 @@ def test_from_dict(self): "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, } ) @@ -117,6 +121,7 @@ def test_from_dict(self): assert docstore.database_name == "haystack_integration_test" assert docstore.collection_name == "test_embeddings_collection" assert docstore.vector_search_index == "cosine_index" + assert docstore.full_text_search_index == "full_text_index" def test_complex_filter(self, document_store, filterable_docs): document_store.write_documents(filterable_docs) diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py index 143f6e106..306e59a98 100644 --- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py +++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py @@ -21,11 +21,12 @@ def test_embedding_retrieval_cosine_similarity(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) query_embedding = [0.1] * 768 results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) assert len(results) == 2 - assert results[0].content == "Document A" + assert results[0].content == "Document C" assert results[1].content == "Document B" assert results[0].score > results[1].score @@ -34,6 +35,7 @@ def test_embedding_retrieval_dot_product(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="dotProduct_index", + full_text_search_index="full_text_index", ) query_embedding = [0.1] * 768 results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) @@ -47,6 +49,7 @@ def test_embedding_retrieval_euclidean(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="euclidean_index", + full_text_search_index="full_text_index", ) query_embedding = [0.1] * 768 results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters={}) @@ -60,6 +63,7 @@ def test_empty_query_embedding(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) query_embedding: List[float] = [] with pytest.raises(ValueError): @@ -70,6 +74,7 @@ def test_query_embedding_wrong_dimension(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) query_embedding = [0.1] * 4 with pytest.raises(DocumentStoreError): @@ -98,6 +103,7 @@ def test_embedding_retrieval_with_filters(self): database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) query_embedding = [0.1] * 768 filters = {"field": "content", "operator": "!=", "value": "Document A"} diff --git a/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py new file mode 100644 index 000000000..aa0132f2c --- /dev/null +++ b/integrations/mongodb_atlas/tests/test_fulltext_retrieval.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from time import sleep +from typing import List, Union +from unittest.mock import MagicMock + +import pytest +from haystack import Document +from haystack.utils import Secret + +from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore + + +def get_document_store(): + return MongoDBAtlasDocumentStore( + mongo_connection_string=Secret.from_env_var("MONGO_CONNECTION_STRING_2"), + database_name="haystack_test", + collection_name="test_collection", + vector_search_index="cosine_index", + full_text_search_index="full_text_index", + ) + + +@pytest.mark.skipif( + "MONGO_CONNECTION_STRING_2" not in os.environ, + reason="No MongoDB Atlas connection string provided", +) +@pytest.mark.integration +class TestFullTextRetrieval: + @pytest.fixture(scope="class") + def document_store(self) -> MongoDBAtlasDocumentStore: + return get_document_store() + + @pytest.fixture(autouse=True, scope="class") + def setup_teardown(self, document_store): + document_store.collection.delete_many({}) + document_store.write_documents( + [ + Document(content="The quick brown fox chased the dog", meta={"meta_field": "right_value"}), + Document(content="The fox was brown", meta={"meta_field": "right_value"}), + Document(content="The lazy dog"), + Document(content="fox fox fox"), + ] + ) + + # Wait for documents to be indexed + sleep(5) + + yield + + def test_pipeline_correctly_passes_parameters(self): + document_store = get_document_store() + mock_collection = MagicMock() + document_store._collection = mock_collection + mock_collection.aggregate.return_value = [] + document_store._fulltext_retrieval( + query=["spam", "eggs"], + fuzzy={"maxEdits": 1}, + match_criteria="any", + score={"boost": {"value": 3}}, + filters={"field": "meta.meta_field", "operator": "==", "value": "right_value"}, + top_k=5, + ) + + # Assert aggregate was called with the correct pipeline + assert mock_collection.aggregate.called + actual_pipeline = mock_collection.aggregate.call_args[0][0] + expected_pipeline = [ + { + "$search": { + "compound": { + "must": [ + { + "text": { + "fuzzy": {"maxEdits": 1}, + "matchCriteria": "any", + "path": "content", + "query": ["spam", "eggs"], + "score": {"boost": {"value": 3}}, + } + } + ] + }, + "index": "full_text_index", + } + }, + {"$match": {"meta.meta_field": {"$eq": "right_value"}}}, + {"$limit": 5}, + { + "$project": { + "_id": 0, + "blob": 1, + "content": 1, + "dataframe": 1, + "embedding": 1, + "meta": 1, + "score": {"$meta": "searchScore"}, + } + }, + ] + + assert actual_pipeline == expected_pipeline + + def test_query_retrieval(self, document_store: MongoDBAtlasDocumentStore): + results = document_store._fulltext_retrieval(query="fox", top_k=2) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert results[0].score >= results[1].score + + def test_fuzzy_retrieval(self, document_store: MongoDBAtlasDocumentStore): + results = document_store._fulltext_retrieval(query="fax", fuzzy={"maxEdits": 1}, top_k=2) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert results[0].score >= results[1].score + + def test_filters_retrieval(self, document_store: MongoDBAtlasDocumentStore): + filters = {"field": "meta.meta_field", "operator": "==", "value": "right_value"} + + results = document_store._fulltext_retrieval(query="fox", top_k=3, filters=filters) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert doc.meta["meta_field"] == "right_value" + + def test_synonyms_retrieval(self, document_store: MongoDBAtlasDocumentStore): + results = document_store._fulltext_retrieval(query="reynard", synonyms="synonym_mapping", top_k=2) + assert len(results) == 2 + for doc in results: + assert "fox" in doc.content + assert results[0].score >= results[1].score + + @pytest.mark.parametrize("query", ["", []]) + def test_empty_query_raises_value_error(self, query: Union[str, List], document_store: MongoDBAtlasDocumentStore): + with pytest.raises(ValueError): + document_store._fulltext_retrieval(query=query) + + def test_empty_synonyms_raises_value_error(self, document_store: MongoDBAtlasDocumentStore): + with pytest.raises(ValueError): + document_store._fulltext_retrieval(query="fox", synonyms="") + + def test_synonyms_and_fuzzy_raises_value_error(self, document_store: MongoDBAtlasDocumentStore): + with pytest.raises(ValueError): + document_store._fulltext_retrieval(query="fox", synonyms="wolf", fuzzy={"maxEdits": 1}) diff --git a/integrations/mongodb_atlas/tests/test_retriever.py b/integrations/mongodb_atlas/tests/test_retriever.py index 832256ccd..26079d145 100644 --- a/integrations/mongodb_atlas/tests/test_retriever.py +++ b/integrations/mongodb_atlas/tests/test_retriever.py @@ -8,11 +8,14 @@ from haystack.document_stores.types import FilterPolicy from haystack.utils.auth import EnvVarSecret -from haystack_integrations.components.retrievers.mongodb_atlas import MongoDBAtlasEmbeddingRetriever +from haystack_integrations.components.retrievers.mongodb_atlas import ( + MongoDBAtlasEmbeddingRetriever, + MongoDBAtlasFullTextRetriever, +) from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore -class TestRetriever: +class TestEmbeddingRetriever: @pytest.fixture def mock_client(self): with patch( @@ -72,6 +75,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a database_name="haystack_integration_test", collection_name="test_embeddings_collection", vector_search_index="cosine_index", + full_text_search_index="full_text_index", ) retriever = MongoDBAtlasEmbeddingRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) @@ -90,6 +94,7 @@ def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client a "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, }, "filters": {"field": "value"}, @@ -115,6 +120,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, }, "filters": {"field": "value"}, @@ -131,6 +137,7 @@ def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client assert document_store.database_name == "haystack_integration_test" assert document_store.collection_name == "test_embeddings_collection" assert document_store.vector_search_index == "cosine_index" + assert document_store.full_text_search_index == "full_text_index" assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 assert retriever.filter_policy == FilterPolicy.REPLACE @@ -152,6 +159,7 @@ def test_from_dict_no_filter_policy(self, monkeypatch): # mock_client appears u "database_name": "haystack_integration_test", "collection_name": "test_embeddings_collection", "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", }, }, "filters": {"field": "value"}, @@ -167,6 +175,7 @@ def test_from_dict_no_filter_policy(self, monkeypatch): # mock_client appears u assert document_store.database_name == "haystack_integration_test" assert document_store.collection_name == "test_embeddings_collection" assert document_store.vector_search_index == "cosine_index" + assert document_store.full_text_search_index == "full_text_index" assert retriever.filters == {"field": "value"} assert retriever.top_k == 5 assert retriever.filter_policy == FilterPolicy.REPLACE # defaults to REPLACE @@ -204,3 +213,209 @@ def test_run_merge_policy_filter(self): ) assert res == {"documents": [doc]} + + +class TestFullTextRetriever: + @pytest.fixture + def mock_client(self): + with patch( + "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoClient" + ) as mock_mongo_client: + mock_connection = MagicMock() + mock_database = MagicMock() + mock_collection_names = MagicMock(return_value=["test_full_text_collection"]) + mock_database.list_collection_names = mock_collection_names + mock_connection.__getitem__.return_value = mock_database + mock_mongo_client.return_value = mock_connection + yield mock_mongo_client + + def test_init_default(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store) + assert retriever.document_store == mock_store + assert retriever.filters == {} + assert retriever.top_k == 10 + assert retriever.filter_policy == FilterPolicy.REPLACE + + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="merge") + assert retriever.filter_policy == FilterPolicy.MERGE + + with pytest.raises(ValueError): + MongoDBAtlasFullTextRetriever(document_store=mock_store, filter_policy="wrong_policy") + + def test_init(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasFullTextRetriever( + document_store=mock_store, + filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, + top_k=5, + ) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_init_filter_policy_merge(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + retriever = MongoDBAtlasFullTextRetriever( + document_store=mock_store, + filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, + top_k=5, + filter_policy=FilterPolicy.MERGE, + ) + assert retriever.document_store == mock_store + assert retriever.filters == {"field": "meta.some_field", "operator": "==", "value": "SomeValue"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.MERGE + + def test_to_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_full_text_collection", + vector_search_index="cosine_index", + full_text_search_index="full_text_index", + ) + + retriever = MongoDBAtlasFullTextRetriever(document_store=document_store, filters={"field": "value"}, top_k=5) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.mongodb_atlas.full_text_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 + "init_parameters": { + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, + "database_name": "haystack_integration_test", + "collection_name": "test_full_text_collection", + "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "filter_policy": "replace", + }, + } + + def test_from_dict(self, mock_client, monkeypatch): # noqa: ARG002 mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + data = { + "type": "haystack_integrations.components.retrievers.mongodb_atlas.full_text_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 + "init_parameters": { + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, + "database_name": "haystack_integration_test", + "collection_name": "test_full_text_collection", + "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "filter_policy": "replace", + }, + } + + retriever = MongoDBAtlasFullTextRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, MongoDBAtlasDocumentStore) + assert isinstance(document_store.mongo_connection_string, EnvVarSecret) + assert document_store.database_name == "haystack_integration_test" + assert document_store.collection_name == "test_full_text_collection" + assert document_store.vector_search_index == "cosine_index" + assert document_store.full_text_search_index == "full_text_index" + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_from_dict_no_filter_policy(self, monkeypatch): # mock_client appears unused but is required + monkeypatch.setenv("MONGO_CONNECTION_STRING", "test_conn_str") + + data = { + "type": "haystack_integrations.components.retrievers.mongodb_atlas.full_text_retriever.MongoDBAtlasFullTextRetriever", # noqa: E501 + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.mongodb_atlas.document_store.MongoDBAtlasDocumentStore", # noqa: E501 + "init_parameters": { + "mongo_connection_string": { + "env_vars": ["MONGO_CONNECTION_STRING"], + "strict": True, + "type": "env_var", + }, + "database_name": "haystack_integration_test", + "collection_name": "test_full_text_collection", + "vector_search_index": "cosine_index", + "full_text_search_index": "full_text_index", + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + }, + } + + retriever = MongoDBAtlasFullTextRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, MongoDBAtlasDocumentStore) + assert isinstance(document_store.mongo_connection_string, EnvVarSecret) + assert document_store.database_name == "haystack_integration_test" + assert document_store.collection_name == "test_full_text_collection" + assert document_store.vector_search_index == "cosine_index" + assert document_store.full_text_search_index == "full_text_index" + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.filter_policy == FilterPolicy.REPLACE + + def test_run(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + doc = Document(content="Lorem ipsum") + mock_store._fulltext_retrieval.return_value = [doc] + + retriever = MongoDBAtlasFullTextRetriever(document_store=mock_store) + res = retriever.run(query="Lorem ipsum") + + mock_store._fulltext_retrieval.assert_called_once_with( + query="Lorem ipsum", fuzzy=None, match_criteria=None, score=None, synonyms=None, filters={}, top_k=10 + ) + + assert res == {"documents": [doc]} + + def test_run_merge_policy_filter(self): + mock_store = Mock(spec=MongoDBAtlasDocumentStore) + doc = Document(content="Lorem ipsum") + mock_store._fulltext_retrieval.return_value = [doc] + + retriever = MongoDBAtlasFullTextRetriever( + document_store=mock_store, + filters={"field": "meta.some_field", "operator": "==", "value": "SomeValue"}, + filter_policy=FilterPolicy.MERGE, + ) + res = retriever.run( + query="Lorem ipsum", filters={"field": "meta.some_field", "operator": "==", "value": "Test"} + ) + # as the both init and run filters are filtering the same field, the run filter takes precedence + mock_store._fulltext_retrieval.assert_called_once_with( + query="Lorem ipsum", + fuzzy=None, + match_criteria=None, + score=None, + synonyms=None, + filters={"field": "meta.some_field", "operator": "==", "value": "Test"}, + top_k=10, + ) + + assert res == {"documents": [doc]}