diff --git a/README.md b/README.md index 6afc03e8..fb178c00 100644 --- a/README.md +++ b/README.md @@ -36,15 +36,15 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b The latest version of the package contains the following experiments: -| Name | Type | Expected experiment end date | Dependencies | | --------------------------- | -------------------------- | ---------------------------- | ------------ | | [`EvaluationHarness`][1] | Evaluation orchestrator | October 2024 | None | | [`OpenAIFunctionCaller`][2] | Function Calling Component | October 2024 | None | | [`OpenAPITool`][3] | OpenAPITool component | October 2024 | jsonref | | [`Tool`][4] | Tool dataclass | November 2024 | jsonschema | -| [`ChatMessageWriter`][5] | Memory Component | November 2024 | None | +| [`ChatMessageWriter`][5] | Memory Component | November 2024 | None | | [`ChatMessageRetriever`][6] | Memory Component | November 2024 | None | | [`InMemoryChatMessageStore`][7] | Memory Store | November 2024 | None | +| [`AutoMergingRetriever`][8] | Retrieval Technique | November 2024 | None | [1]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/evaluation/harness [2]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/components/tools/openai @@ -53,6 +53,8 @@ The latest version of the package contains the following experiments: [5]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/writers/chat_message_writer.py [6]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/chat_message_retriever.py [7]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/chat_message_stores/in_memory.py +[8]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/retrievers/auto_merging_retriever.py + ## Usage @@ -111,11 +113,11 @@ from haystack.core.pipeline 
import Pipeline as HaystackPipeline class Pipeline(HaystackPipeline): - # Any new experimental method that doesn't exist in the original class - def run_async(self, inputs) -> Dict[str, Dict[str, Any]]: - ... + # Any new experimental method that doesn't exist in the original class + def run_async(self, inputs) -> Dict[str, Dict[str, Any]]: + ... - # Existing methods with breaking changes to their signature, like adding a new mandatory param + # Existing methods with breaking changes to their signature, like adding a new mandatory param def to_dict(new_param: str) -> Dict[str, Any]: # do something with the new parameter print(new_param) diff --git a/docs/pydoc/config/auto_merging_retriever.yml b/docs/pydoc/config/auto_merging_retriever.yml new file mode 100644 index 00000000..dca389ec --- /dev/null +++ b/docs/pydoc/config/auto_merging_retriever.yml @@ -0,0 +1,27 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../../../] + modules: ["haystack_experimental.components.retrievers.auto_merging_retriever"] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer + excerpt: Auto Merging Retriever for Haystack. 
+ category_slug: experiments-api + title: Auto Merge Retriever + slug: auto-merge-retriever + order: 10 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: auto_merging_retriever.md diff --git a/docs/pydoc/config/hierarchical_document_splitter.yml b/docs/pydoc/config/hierarchical_document_splitter.yml new file mode 100644 index 00000000..9522c631 --- /dev/null +++ b/docs/pydoc/config/hierarchical_document_splitter.yml @@ -0,0 +1,27 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../../../] + modules: ["haystack_experimental.components.splitters.hierarchical_doc_splitter"] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer + excerpt: Hierarchical Document Splitter for Haystack. + category_slug: experiments-api + title: Hierarchical Document Splitter 
+ slug: hierarchical-document-splitter + order: 70 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: hierarchical_document_splitter.md diff --git a/haystack_experimental/components/retrievers/__init__.py b/haystack_experimental/components/retrievers/__init__.py index 9bcad180..c7dbd7a5 100644 --- a/haystack_experimental/components/retrievers/__init__.py +++ b/haystack_experimental/components/retrievers/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from haystack_experimental.components.retrievers.auto_merging_retriever import AutoMergingRetriever from haystack_experimental.components.retrievers.chat_message_retriever import ChatMessageRetriever -_all_ = ["ChatMessageRetriever"] +__all__ = ["AutoMergingRetriever", "ChatMessageRetriever"] diff --git a/haystack_experimental/components/retrievers/auto_merging_retriever.py b/haystack_experimental/components/retrievers/auto_merging_retriever.py new file mode 100644 index 00000000..b177974c --- /dev/null +++ b/haystack_experimental/components/retrievers/auto_merging_retriever.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from typing import Any, Dict, List + +from haystack import Document, component, default_to_dict +from haystack.core.serialization import default_from_dict +from haystack.document_stores.types import DocumentStore +from haystack.utils import deserialize_document_store_in_init_parameters + + +@component +class AutoMergingRetriever: + """ + A retriever which returns parent documents of the matched leaf nodes documents, based on a threshold setting. + + The AutoMergingRetriever assumes you have a hierarchical tree structure of documents, where the leaf nodes + are indexed in a document store. 
See the HierarchicalDocumentSplitter for more information on how to create + such a structure. During retrieval, if the share of matched leaf documents below the same parent is + at or above the defined threshold, the retriever will return the parent document instead of the individual leaf + documents. + + The rationale is, given that a paragraph is split into multiple chunks represented as leaf documents, and if for + a given query, multiple chunks are matched, the whole paragraph might be more informative than the individual + chunks alone. + + Currently the AutoMergingRetriever can only be used by the following DocumentStores: + - [ElasticSearch](https://haystack.deepset.ai/docs/latest/documentstore/elasticsearch) + - [OpenSearch](https://haystack.deepset.ai/docs/latest/documentstore/opensearch) + - [PGVector](https://haystack.deepset.ai/docs/latest/documentstore/pgvector) + - [Qdrant](https://haystack.deepset.ai/docs/latest/documentstore/qdrant) + + ```python + from haystack import Document + from haystack_experimental.components.splitters import HierarchicalDocumentSplitter + from haystack_experimental.components.retrievers.auto_merging_retriever import AutoMergingRetriever + from haystack.document_stores.in_memory import InMemoryDocumentStore + + # create a hierarchical document structure with 2 levels, where the parent document has 3 children + text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing." 
+ original_document = Document(content=text) + builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word") + docs = builder.run([original_document]) + + # store level-1 parent documents and initialize the retriever + doc_store_parents = InMemoryDocumentStore() + for doc in docs["documents"]: + if doc.meta["__children_ids"] and doc.meta["__level"] == 1: + doc_store_parents.write_documents([doc]) + retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5) + + # assume we retrieved 2 leaf docs from the same parent, the parent document should be returned, + # since it has 3 children and the threshold=0.5, and we retrieved 2 children (2/3 ~ 0.67 >= 0.5) + leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]] + docs = retriever.run(leaf_docs[4:6]) + >> {'documents': [Document(id=538..), + >> content: 'warm glow over the trees. Birds began to sing.', + >> meta: {'__block_size': 10, '__parent_id': '835..', '__children_ids': ['c17...', '3ff...', '352...'], '__level': 1, 'source_id': '835...', + >> 'page_number': 1, 'split_id': 1, 'split_idx_start': 45})]} + ``` + """ # noqa: E501 + + def __init__(self, document_store: DocumentStore, threshold: float = 0.5): + """ + Initialize the AutoMergingRetriever. + + :param document_store: DocumentStore from which to retrieve the parent documents + :param threshold: Threshold to decide whether the parent instead of the individual documents is returned + """ + + if not 0 < threshold < 1: + raise ValueError("The threshold parameter must be between 0 and 1.") + + self.document_store = document_store + self.threshold = threshold + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. 
+ """ + docstore = self.document_store.to_dict() + return default_to_dict(self, document_store=docstore, threshold=self.threshold) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AutoMergingRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary with serialized data. + :returns: + An instance of the component. + """ + data = deserialize_document_store_in_init_parameters(data) + return default_from_dict(cls, data) + + @staticmethod + def _check_valid_documents(matched_leaf_documents: List[Document]): + # check if the matched leaf documents have the required meta fields + if not all(doc.meta.get("__parent_id") for doc in matched_leaf_documents): + raise ValueError("The matched leaf documents do not have the required meta field '__parent_id'") + + if not all(doc.meta.get("__level") for doc in matched_leaf_documents): + raise ValueError("The matched leaf documents do not have the required meta field '__level'") + + if not all(doc.meta.get("__block_size") for doc in matched_leaf_documents): + raise ValueError("The matched leaf documents do not have the required meta field '__block_size'") + + @component.output_types(documents=List[Document]) + def run(self, matched_leaf_documents: List[Document]): + """ + Run the AutoMergingRetriever. + + Groups the matched leaf documents by their parent documents and returns the parent documents if the number of + matched leaf documents below the same parent is higher than the defined threshold. Otherwise, returns the + matched leaf documents. 
+ + :param matched_leaf_documents: List of leaf documents that were matched by a retriever + :returns: + List of parent documents or matched leaf documents based on the threshold value + """ + + docs_to_return = [] + + # group the matched leaf documents by their parent documents + parent_documents: Dict[str, List[Document]] = defaultdict(list) + for doc in matched_leaf_documents: + parent_documents[doc.meta["__parent_id"]].append(doc) + + # find total number of children for each parent document + for doc_id, retrieved_child_docs in parent_documents.items(): + parent_doc = self.document_store.filter_documents({"field": "id", "operator": "==", "value": doc_id}) + if len(parent_doc) == 0: + raise ValueError(f"Parent document with id {doc_id} not found in the document store.") + if len(parent_doc) > 1: + raise ValueError(f"Multiple parent documents found with id {doc_id} in the document store.") + if not parent_doc[0].meta.get("__children_ids"): + raise ValueError(f"Parent document with id {doc_id} does not have any children.") + parent_children_count = len(parent_doc[0].meta["__children_ids"]) + + # return either the parent document or the matched leaf documents based on the threshold value + score = len(retrieved_child_docs) / parent_children_count + if score >= self.threshold: + # return the parent document + docs_to_return.append(parent_doc[0]) + else: + # return all the matched leaf documents which are child of this parent document + leafs_ids = {doc.id for doc in retrieved_child_docs} + docs_to_return.extend([doc for doc in matched_leaf_documents if doc.id in leafs_ids]) + + return {"documents": docs_to_return} diff --git a/haystack_experimental/components/splitters/__init__.py b/haystack_experimental/components/splitters/__init__.py new file mode 100644 index 00000000..5159c944 --- /dev/null +++ b/haystack_experimental/components/splitters/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + 
+from haystack_experimental.components.splitters.hierarchical_doc_splitter import HierarchicalDocumentSplitter + +__all__ = ["HierarchicalDocumentSplitter"] diff --git a/haystack_experimental/components/splitters/hierarchical_doc_splitter.py b/haystack_experimental/components/splitters/hierarchical_doc_splitter.py new file mode 100644 index 00000000..2efdd5f2 --- /dev/null +++ b/haystack_experimental/components/splitters/hierarchical_doc_splitter.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Literal, Set + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.components.preprocessors import DocumentSplitter + + +@component +class HierarchicalDocumentSplitter: + """ + Splits a document into different block sizes building a hierarchical tree structure of blocks of different sizes. + + The root node of the tree is the original document, the leaf nodes are the smallest blocks. The blocks in between + are connected such that the smaller blocks are children of the parent-larger blocks. 
+ + ## Usage example + ```python + from haystack import Document + from haystack_experimental.components.splitters import HierarchicalDocumentSplitter + + doc = Document(content="This is a simple test document") + builder = HierarchicalDocumentSplitter(block_sizes={3, 2}, split_overlap=0, split_by="word") + builder.run([doc]) + >> {'documents': [Document(id=3f7..., content: 'This is a simple test document', meta: {'__block_size': 0, '__parent_id': None, '__children_ids': ['5ff..', '8dc..'], '__level': 0}), + >> Document(id=5ff.., content: 'This is a ', meta: {'__block_size': 3, '__parent_id': '3f7..', '__children_ids': ['f19..', '52c..'], '__level': 1, 'source_id': '3f7..', 'page_number': 1, 'split_id': 0, 'split_idx_start': 0}), + >> Document(id=8dc.., content: 'simple test document', meta: {'__block_size': 3, '__parent_id': '3f7..', '__children_ids': ['39d..', 'e23..'], '__level': 1, 'source_id': '3f7..', 'page_number': 1, 'split_id': 1, 'split_idx_start': 10}), + >> Document(id=f19.., content: 'This is ', meta: {'__block_size': 2, '__parent_id': '5ff..', '__children_ids': [], '__level': 2, 'source_id': '5ff..', 'page_number': 1, 'split_id': 0, 'split_idx_start': 0}), + >> Document(id=52c.., content: 'a ', meta: {'__block_size': 2, '__parent_id': '5ff..', '__children_ids': [], '__level': 2, 'source_id': '5ff..', 'page_number': 1, 'split_id': 1, 'split_idx_start': 8}), + >> Document(id=39d.., content: 'simple test ', meta: {'__block_size': 2, '__parent_id': '8dc..', '__children_ids': [], '__level': 2, 'source_id': '8dc..', 'page_number': 1, 'split_id': 0, 'split_idx_start': 0}), + >> Document(id=e23.., content: 'document', meta: {'__block_size': 2, '__parent_id': '8dc..', '__children_ids': [], '__level': 2, 'source_id': '8dc..', 'page_number': 1, 'split_id': 1, 'split_idx_start': 12})]} + ``` + """ # noqa: E501 + + def __init__( + self, + block_sizes: Set[int], + split_overlap: int = 0, + split_by: Literal["word", "sentence", "page", "passage"] = "word", + ): + """ + Initialize HierarchicalDocumentSplitter. 
+ + :param block_sizes: Set of block sizes to split the document into. The blocks are split in descending order. + :param split_overlap: The number of overlapping units for each split. + :param split_by: The unit for splitting your documents. + """ + + self.block_sizes = sorted(set(block_sizes), reverse=True) + self.splitters: Dict[int, DocumentSplitter] = {} + self.split_overlap = split_overlap + self.split_by = split_by + self._build_block_sizes() + + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + """ + Builds a hierarchical document structure for each document in a list of documents. + + :param documents: List of Documents to split into hierarchical blocks. + :returns: List of HierarchicalDocument + """ + hierarchical_docs = [] + for doc in documents: + hierarchical_docs.extend(self.build_hierarchy_from_doc(doc)) + return {"documents": hierarchical_docs} + + def _build_block_sizes(self): + for block_size in self.block_sizes: + self.splitters[block_size] = DocumentSplitter( + split_length=block_size, split_overlap=self.split_overlap, split_by=self.split_by + ) + + @staticmethod + def _add_meta_data(document: Document): + document.meta["__block_size"] = 0 + document.meta["__parent_id"] = None + document.meta["__children_ids"] = [] + document.meta["__level"] = 0 + return document + + def build_hierarchy_from_doc(self, document: Document) -> List[Document]: + """ + Build a hierarchical tree document structure from a single document. + + Given a document, this function splits the document into hierarchical blocks of different sizes represented + as HierarchicalDocument objects. + + :param document: Document to split into hierarchical blocks. 
+ :returns: + List of HierarchicalDocument + """ + + root = self._add_meta_data(document) + current_level_nodes = [root] + all_docs = [] + + for block in self.block_sizes: + next_level_nodes = [] + for doc in current_level_nodes: + splitted_docs = self.splitters[block].run([doc]) + child_docs = splitted_docs["documents"] + # if it's only one document skip + if len(child_docs) == 1: + next_level_nodes.append(doc) + continue + for child_doc in child_docs: + child_doc = self._add_meta_data(child_doc) + child_doc.meta["__level"] = doc.meta["__level"] + 1 + child_doc.meta["__block_size"] = block + child_doc.meta["__parent_id"] = doc.id + all_docs.append(child_doc) + doc.meta["__children_ids"].append(child_doc.id) + next_level_nodes.append(child_doc) + current_level_nodes = next_level_nodes + + return [root] + all_docs + + def to_dict(self) -> Dict[str, Any]: + """ + Returns a dictionary representation of the component. + + :returns: + Serialized dictionary representation of the component. + """ + return default_to_dict( + self, block_sizes=self.block_sizes, split_overlap=self.split_overlap, split_by=self.split_by + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "HierarchicalDocumentSplitter": + """ + Deserialize this component from a dictionary. + + :param data: + The dictionary to deserialize and create the component. + + :returns: + The deserialized component. 
+ """ + return default_from_dict(cls, data) diff --git a/test/components/retrievers/test_auto_merging_retriever.py b/test/components/retrievers/test_auto_merging_retriever.py new file mode 100644 index 00000000..24e20c48 --- /dev/null +++ b/test/components/retrievers/test_auto_merging_retriever.py @@ -0,0 +1,115 @@ +import pytest + +from haystack import Document, Pipeline +from haystack.components.retrievers import InMemoryBM25Retriever +from haystack_experimental.components.splitters import HierarchicalDocumentSplitter +from haystack_experimental.components.retrievers.auto_merging_retriever import AutoMergingRetriever +from haystack.document_stores.in_memory import InMemoryDocumentStore + + +class TestAutoMergingRetriever: + def test_init_default(self): + retriever = AutoMergingRetriever(InMemoryDocumentStore()) + assert retriever.threshold == 0.5 + + def test_init_with_parameters(self): + retriever = AutoMergingRetriever(InMemoryDocumentStore(), threshold=0.7) + assert retriever.threshold == 0.7 + + def test_init_with_invalid_threshold(self): + with pytest.raises(ValueError): + AutoMergingRetriever(InMemoryDocumentStore(), threshold=-2) + + def test_to_dict(self): + retriever = AutoMergingRetriever(InMemoryDocumentStore(), threshold=0.7) + expected = retriever.to_dict() + assert expected['type'] == 'haystack_experimental.components.retrievers.auto_merging_retriever.AutoMergingRetriever' + assert expected['init_parameters']['threshold'] == 0.7 + assert expected['init_parameters']['document_store']['type'] == 'haystack.document_stores.in_memory.document_store.InMemoryDocumentStore' + + def test_from_dict(self): + data = { + 'type': 'haystack_experimental.components.retrievers.auto_merging_retriever.AutoMergingRetriever', + 'init_parameters': { + 'document_store': { + 'type': 'haystack.document_stores.in_memory.document_store.InMemoryDocumentStore', + 'init_parameters': { + 'bm25_tokenization_regex': '(?u)\\b\\w\\w+\\b', + 'bm25_algorithm': 'BM25L', + 
'bm25_parameters': {}, + 'embedding_similarity_function': 'dot_product', + 'index': '6b122bb4-211b-465e-804d-77c5857bf4c5'}}, + 'threshold': 0.7}} + retriever = AutoMergingRetriever.from_dict(data) + assert retriever.threshold == 0.7 + + def test_run_return_parent_document(self): + text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing." + + docs = [Document(content=text)] + builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word") + docs = builder.run(docs) + + # store level-1 parent documents and initialize the retriever + doc_store_parents = InMemoryDocumentStore() + for doc in docs["documents"]: + if doc.meta["__children_ids"] and doc.meta["__level"] == 1: + doc_store_parents.write_documents([doc]) + retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5) + + # assume we retrieved 2 leaf docs from the same parent, the parent document should be returned, + # since it has 3 children and the threshold=0.5, and we retrieved 2 children (2/3 > 0.66(6)) + leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]] + docs = retriever.run(leaf_docs[4:6]) + assert len(docs["documents"]) == 1 + assert docs["documents"][0].content == "warm glow over the trees. Birds began to sing." 
+ assert len(docs["documents"][0].meta["__children_ids"]) == 3 + + def test_run_return_leafs_document(self): + docs = [Document(content="The monarch of the wild blue yonder rises from the eastern side of the horizon.")] + builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word") + docs = builder.run(docs) + + doc_store_parents = InMemoryDocumentStore() + for doc in docs["documents"]: + if doc.meta["__level"] == 1: + doc_store_parents.write_documents([doc]) + + leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]] + retriever = AutoMergingRetriever(doc_store_parents, threshold=0.6) + result = retriever.run([leaf_docs[4]]) + + assert len(result['documents']) == 1 + assert result['documents'][0].content == 'eastern side of ' + assert result['documents'][0].meta["__parent_id"] == docs["documents"][2].id + + def test_run_return_leafs_document_different_parents(self): + docs = [Document(content="The monarch of the wild blue yonder rises from the eastern side of the horizon.")] + builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word") + docs = builder.run(docs) + + doc_store_parents = InMemoryDocumentStore() + for doc in docs["documents"]: + if doc.meta["__level"] == 1: + doc_store_parents.write_documents([doc]) + + leaf_docs = [doc for doc in docs["documents"] if not doc.meta["__children_ids"]] + retriever = AutoMergingRetriever(doc_store_parents, threshold=0.6) + result = retriever.run([leaf_docs[4], leaf_docs[3]]) + + assert len(result['documents']) == 2 + assert result['documents'][0].meta["__parent_id"] != result['documents'][1].meta["__parent_id"] + + def test_serialization_deserialization_pipeline(self): + pipeline = Pipeline() + doc_store_parents = InMemoryDocumentStore() + bm_25_retriever = InMemoryBM25Retriever(doc_store_parents) + auto_merging_retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5) + + pipeline.add_component(name="bm_25_retriever", 
instance=bm_25_retriever) + pipeline.add_component(name="auto_merging_retriever", instance=auto_merging_retriever) + pipeline.connect("bm_25_retriever.documents", "auto_merging_retriever.matched_leaf_documents") + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline diff --git a/test/components/splitters/test_hierarchical_doc_splitter.py b/test/components/splitters/test_hierarchical_doc_splitter.py new file mode 100644 index 00000000..b5fbc83b --- /dev/null +++ b/test/components/splitters/test_hierarchical_doc_splitter.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from haystack import Document, Pipeline +from haystack_experimental.components.splitters import HierarchicalDocumentSplitter +from haystack.components.writers import DocumentWriter +from haystack.document_stores.in_memory import InMemoryDocumentStore + + +class TestHierarchicalDocumentSplitter: + def test_init_with_default_params(self): + builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300}) + assert builder.block_sizes == [300, 200, 100] + assert builder.split_overlap == 0 + assert builder.split_by == "word" + + def test_init_with_custom_params(self): + builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300}, split_overlap=25, split_by="word") + assert builder.block_sizes == [300, 200, 100] + assert builder.split_overlap == 25 + assert builder.split_by == "word" + + def test_to_dict(self): + builder = HierarchicalDocumentSplitter(block_sizes={100, 200, 300}, split_overlap=25, split_by="word") + expected = builder.to_dict() + assert expected == { + "type": "haystack_experimental.components.splitters.hierarchical_doc_splitter.HierarchicalDocumentSplitter", + "init_parameters": {"block_sizes": [300, 200, 100], "split_overlap": 25, "split_by": "word"}, + } + + def test_from_dict(self): + data = { + "type": 
"haystack_experimental.components.splitters.hierarchical_doc_splitter.HierarchicalDocumentSplitter", + "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"}, + } + + builder = HierarchicalDocumentSplitter.from_dict(data) + assert builder.block_sizes == [10, 5, 2] + assert builder.split_overlap == 0 + assert builder.split_by == "word" + + def test_run(self): + builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word") + text = "one two three four five six seven eight nine ten" + doc = Document(content=text) + output = builder.run([doc]) + docs = output["documents"] + builder.run([doc]) + + assert len(docs) == 9 + assert docs[0].content == "one two three four five six seven eight nine ten" + + # level 1 - root node + assert docs[0].meta["__level"] == 0 + assert len(docs[0].meta["__children_ids"]) == 2 + + # level 2 -left branch + assert docs[1].meta["__parent_id"] == docs[0].id + assert docs[1].meta["__level"] == 1 + assert len(docs[1].meta["__children_ids"]) == 3 + + # level 2 - right branch + assert docs[2].meta["__parent_id"] == docs[0].id + assert docs[2].meta["__level"] == 1 + assert len(docs[2].meta["__children_ids"]) == 3 + + # level 3 - left branch - leaf nodes + assert docs[3].meta["__parent_id"] == docs[1].id + assert docs[4].meta["__parent_id"] == docs[1].id + assert docs[5].meta["__parent_id"] == docs[1].id + assert docs[3].meta["__level"] == 2 + assert docs[4].meta["__level"] == 2 + assert docs[5].meta["__level"] == 2 + assert len(docs[3].meta["__children_ids"]) == 0 + assert len(docs[4].meta["__children_ids"]) == 0 + assert len(docs[5].meta["__children_ids"]) == 0 + + # level 3 - right branch - leaf nodes + assert docs[6].meta["__parent_id"] == docs[2].id + assert docs[7].meta["__parent_id"] == docs[2].id + assert docs[8].meta["__parent_id"] == docs[2].id + assert docs[6].meta["__level"] == 2 + assert docs[7].meta["__level"] == 2 + assert docs[8].meta["__level"] == 2 + assert 
len(docs[6].meta["__children_ids"]) == 0 + assert len(docs[7].meta["__children_ids"]) == 0 + assert len(docs[8].meta["__children_ids"]) == 0 + + def test_to_dict_in_pipeline(self): + pipeline = Pipeline() + hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}) + doc_store = InMemoryDocumentStore() + doc_writer = DocumentWriter(document_store=doc_store) + pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder) + pipeline.add_component(name="doc_writer", instance=doc_writer) + pipeline.connect("hierarchical_doc_splitter", "doc_writer") + expected = pipeline.to_dict() + + assert expected.keys() == {"metadata", "max_loops_allowed", "components", "connections"} + assert expected["components"].keys() == {"hierarchical_doc_splitter", "doc_writer"} + assert expected["components"]["hierarchical_doc_splitter"] == { + "type": "haystack_experimental.components.splitters.hierarchical_doc_splitter.HierarchicalDocumentSplitter", + "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"}, + } + + def test_from_dict_in_pipeline(self): + data = { + "metadata": {}, + "max_loops_allowed": 100, + "components": { + "hierarchical_doc_splitter": { + "type": "haystack_experimental.components.splitters.hierarchical_doc_splitter.HierarchicalDocumentSplitter", + "init_parameters": {"block_sizes": [10, 5, 2], "split_overlap": 0, "split_by": "word"}, + }, + "doc_writer": { + "type": "haystack.components.writers.document_writer.DocumentWriter", + "init_parameters": { + "document_store": { + "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore", + "init_parameters": { + "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b", + "bm25_algorithm": "BM25L", + "bm25_parameters": {}, + "embedding_similarity_function": "dot_product", + "index": "f32ad5bf-43cb-4035-9823-1de1ae9853c1", + }, + }, + "policy": "NONE", + }, + }, + }, + "connections": [{"sender": 
"hierarchical_doc_splitter.documents", "receiver": "doc_writer.documents"}], + } + + assert Pipeline.from_dict(data) + + @pytest.mark.integration + def test_example_in_pipeline(self): + pipeline = Pipeline() + hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word") + doc_store = InMemoryDocumentStore() + doc_writer = DocumentWriter(document_store=doc_store) + + pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder) + pipeline.add_component(name="doc_writer", instance=doc_writer) + pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer") + + text = "one two three four five six seven eight nine ten" + doc = Document(content=text) + docs = pipeline.run({"hierarchical_doc_splitter": {"documents": [doc]}}) + + assert docs["doc_writer"]["documents_written"] == 9 + assert len(doc_store.storage.values()) == 9 + + def test_serialization_deserialization_pipeline(self): + pipeline = Pipeline() + hierarchical_doc_builder = HierarchicalDocumentSplitter(block_sizes={10, 5, 2}, split_overlap=0, split_by="word") + doc_store = InMemoryDocumentStore() + doc_writer = DocumentWriter(document_store=doc_store) + + pipeline.add_component(name="hierarchical_doc_splitter", instance=hierarchical_doc_builder) + pipeline.add_component(name="doc_writer", instance=doc_writer) + pipeline.connect("hierarchical_doc_splitter.documents", "doc_writer") + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline