diff --git a/haystack/components/preprocessors/recursive_splitter.py b/haystack/components/preprocessors/recursive_splitter.py index a1fdd991a0..0caf63ab29 100644 --- a/haystack/components/preprocessors/recursive_splitter.py +++ b/haystack/components/preprocessors/recursive_splitter.py @@ -118,33 +118,100 @@ def _get_custom_sentence_tokenizer(sentence_splitter_params: Dict[str, Any]): def _apply_overlap(self, chunks: List[str]) -> List[str]: """ Applies an overlap between consecutive chunks if the chunk_overlap attribute is greater than zero. - - :param chunks: List of text chunks. - :returns: - The list of chunks with overlap applied. """ - overlapped_chunks = [] + overlapped_chunks: List[str] = [] + remaining_words: List[str] = [] + remaining_chars: str = "" for idx, chunk in enumerate(chunks): if idx == 0: overlapped_chunks.append(chunk) continue - overlap_start = max(0, self._chunk_length(chunks[idx - 1]) - self.split_overlap) - if self.split_units == "word": - word_chunks = chunks[idx - 1].split() - overlap = " ".join(word_chunks[overlap_start:]) - else: - overlap = chunks[idx - 1][overlap_start:] - if overlap == chunks[idx - 1]: + + overlap, prev_chunk = self._get_overlap(overlapped_chunks) + + if overlap == prev_chunk: logger.warning( "Overlap is the same as the previous chunk. " "Consider increasing the `split_length` parameter or decreasing the `split_overlap` parameter." ) - current_chunk = overlap + chunk + + # create new chunk starting with the overlap + if self.split_units == "word": + current_chunk = overlap + " " + chunk + else: + current_chunk = overlap + chunk + + # if the new chunk exceeds split_length, trim it and add the trimmed content to the next chunk + if self._chunk_length(current_chunk) > self.split_length: + if self.split_units == "word": + words = current_chunk.split() + current_chunk = " ".join(words[: self.split_length]) + remaining_words = words[self.split_length :] + if idx < len(chunks) - 1: + # add remaining words to the beginning of the next chunk + chunks[idx + 1] = " ".join(remaining_words) + " " + chunks[idx + 1] + elif remaining_words: + # if this is the last chunk, and we have remaining words + overlapped_chunks.append(current_chunk) + current_chunk = " ".join(remaining_words) + + else: # char-level splitting + text = current_chunk + current_chunk = text[: self.split_length] + remaining_chars = text[self.split_length :] + if idx < len(chunks) - 1: + # add remaining chars to the beginning of the next chunk + chunks[idx + 1] = remaining_chars + chunks[idx + 1] + elif remaining_chars: # if this is the last chunk and we have remaining chars + overlapped_chunks.append(current_chunk) + current_chunk = remaining_chars + + # if this is the last chunk, and we have remaining words or characters, add them to the current chunk + if idx == len(chunks) - 1 and (remaining_words or remaining_chars): + overlap, prev_chunk = self._get_overlap(overlapped_chunks) + if remaining_words: + current_chunk = overlap + " " + current_chunk + if remaining_chars: + current_chunk = overlap + current_chunk + overlapped_chunks.append(current_chunk) + # check if the last chunk exceeds split_length and split it + if idx == len(chunks) - 1 and self._chunk_length(current_chunk) > self.split_length: + # split the last chunk and add the first chunk to the list + last_chunk = overlapped_chunks.pop() + if self.split_units == "word": + words = last_chunk.split() + first_chunk = " ".join(words[: self.split_length]) + remaining_chunk = " ".join(words[self.split_length :]) + else: + first_chunk = last_chunk[: self.split_length] + remaining_chunk = last_chunk[self.split_length :] + overlapped_chunks.append(first_chunk) + + # add the remaining chunk with overlap from the previous chunk + if remaining_chunk: + overlap, prev_chunk = self._get_overlap(overlapped_chunks) + if self.split_units == "word": + remaining_chunk = overlap + " " + remaining_chunk + else: + remaining_chunk = overlap + remaining_chunk + overlapped_chunks.append(remaining_chunk) + return overlapped_chunks + def _get_overlap(self, overlapped_chunks): + """Get the previous overlapped chunk instead of the original chunk.""" + prev_chunk = overlapped_chunks[-1] + overlap_start = max(0, self._chunk_length(prev_chunk) - self.split_overlap) + if self.split_units == "word": + word_chunks = prev_chunk.split() + overlap = " ".join(word_chunks[overlap_start:]) + else: + overlap = prev_chunk[overlap_start:] + return overlap, prev_chunk + def _chunk_length(self, text: str) -> int: """ Get the length of the chunk in words or characters. diff --git a/test/components/preprocessors/test_recursive_splitter.py b/test/components/preprocessors/test_recursive_splitter.py index bc0049b6a5..0413d5538f 100644 --- a/test/components/preprocessors/test_recursive_splitter.py +++ b/test/components/preprocessors/test_recursive_splitter.py @@ -354,40 +354,47 @@ def test_run_split_by_sentence_count_page_breaks_split_unit_char() -> None: def test_run_split_document_with_overlap_character_unit(): - splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=11, separators=[".", " "], split_unit="char") - text = """A simple sentence1. A bright sentence2. A clever sentence3. A joyful sentence4""" + splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=10, separators=["."], split_unit="char") + text = """A simple sentence1. A bright sentence2. A clever sentence3""" doc = Document(content=text) doc_chunks = splitter.run([doc]) doc_chunks = doc_chunks["documents"] - assert len(doc_chunks) == 4 - + assert len(doc_chunks) == 5 assert doc_chunks[0].content == "A simple sentence1." assert doc_chunks[0].meta["split_id"] == 0 assert doc_chunks[0].meta["split_idx_start"] == text.index(doc_chunks[0].content) - assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 11)}] + assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 10)}] - assert doc_chunks[1].content == " sentence1. A bright sentence2." + assert doc_chunks[1].content == "sentence1. A bright " assert doc_chunks[1].meta["split_id"] == 1 assert doc_chunks[1].meta["split_idx_start"] == text.index(doc_chunks[1].content) assert doc_chunks[1].meta["_split_overlap"] == [ - {"doc_id": doc_chunks[0].id, "range": (8, 19)}, - {"doc_id": doc_chunks[2].id, "range": (0, 11)}, + {"doc_id": doc_chunks[0].id, "range": (9, 19)}, + {"doc_id": doc_chunks[2].id, "range": (0, 10)}, ] - assert doc_chunks[2].content == " sentence2. A clever sentence3." + assert doc_chunks[2].content == " A bright sentence2." assert doc_chunks[2].meta["split_id"] == 2 assert doc_chunks[2].meta["split_idx_start"] == text.index(doc_chunks[2].content) assert doc_chunks[2].meta["_split_overlap"] == [ - {"doc_id": doc_chunks[1].id, "range": (20, 31)}, - {"doc_id": doc_chunks[3].id, "range": (0, 11)}, + {"doc_id": doc_chunks[1].id, "range": (10, 20)}, + {"doc_id": doc_chunks[3].id, "range": (0, 10)}, ] - assert doc_chunks[3].content == " sentence3. A joyful sentence4" + assert doc_chunks[3].content == "sentence2. A clever " assert doc_chunks[3].meta["split_id"] == 3 assert doc_chunks[3].meta["split_idx_start"] == text.index(doc_chunks[3].content) - assert doc_chunks[3].meta["_split_overlap"] == [{"doc_id": doc_chunks[2].id, "range": (20, 31)}] + assert doc_chunks[3].meta["_split_overlap"] == [ + {"doc_id": doc_chunks[2].id, "range": (10, 20)}, + {"doc_id": doc_chunks[4].id, "range": (0, 10)}, + ] + + assert doc_chunks[4].content == " A clever sentence3" + assert doc_chunks[4].meta["split_id"] == 4 + assert doc_chunks[4].meta["split_idx_start"] == text.index(doc_chunks[4].content) + assert doc_chunks[4].meta["_split_overlap"] == [{"doc_id": doc_chunks[3].id, "range": (10, 20)}] def test_run_separator_exists_but_split_length_too_small_fall_back_to_character_chunking(): @@ -421,31 +428,38 @@ def test_run_fallback_to_word_chunking_by_default_length_too_short(): def test_run_custom_sentence_tokenizer_document_and_overlap_char_unit(): """Test that RecursiveDocumentSplitter works correctly with custom sentence tokenizer and overlap""" - splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=5, separators=["sentence"], split_unit="char") + splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=10, separators=["sentence"], split_unit="char") text = "This is sentence one. This is sentence two. This is sentence three." doc = Document(content=text) doc_chunks = splitter.run([doc])["documents"] - assert len(doc_chunks) == 3 - + assert len(doc_chunks) == 4 assert doc_chunks[0].content == "This is sentence one. " assert doc_chunks[0].meta["split_id"] == 0 assert doc_chunks[0].meta["split_idx_start"] == text.index(doc_chunks[0].content) - assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 5)}] + assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 10)}] - assert doc_chunks[1].content == "one. This is sentence two. " + assert doc_chunks[1].content == "ence one. This is sentenc" assert doc_chunks[1].meta["split_id"] == 1 assert doc_chunks[1].meta["split_idx_start"] == text.index(doc_chunks[1].content) assert doc_chunks[1].meta["_split_overlap"] == [ - {"doc_id": doc_chunks[0].id, "range": (17, 22)}, - {"doc_id": doc_chunks[2].id, "range": (0, 5)}, + {"doc_id": doc_chunks[0].id, "range": (12, 22)}, + {"doc_id": doc_chunks[2].id, "range": (0, 10)}, ] - assert doc_chunks[2].content == "two. This is sentence three." + assert doc_chunks[2].content == "is sentence two. This is " assert doc_chunks[2].meta["split_id"] == 2 assert doc_chunks[2].meta["split_idx_start"] == text.index(doc_chunks[2].content) - assert doc_chunks[2].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (22, 27)}] + assert doc_chunks[2].meta["_split_overlap"] == [ + {"doc_id": doc_chunks[1].id, "range": (15, 25)}, + {"doc_id": doc_chunks[3].id, "range": (0, 10)}, + ] + + assert doc_chunks[3].content == ". This is sentence three." + assert doc_chunks[3].meta["split_id"] == 3 + assert doc_chunks[3].meta["split_idx_start"] == text.index(doc_chunks[3].content) + assert doc_chunks[3].meta["_split_overlap"] == [{"doc_id": doc_chunks[2].id, "range": (15, 25)}] def test_run_split_by_dot_count_page_breaks_word_unit() -> None: @@ -673,15 +687,20 @@ def test_run_custom_sentence_tokenizer_document_and_overlap_word_unit_no_overlap assert chunks[2].content == " This is sentence three." -def test_run_custom_sentence_tokenizer_document_and_overlap_word_unit_overlap_2_words(): - splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=2, separators=["."], split_unit="word") +def test_run_custom_split_by_dot_and_overlap_1_word_unit(): + splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=1, separators=["."], split_unit="word") text = "This is sentence one. This is sentence two. This is sentence three. This is sentence four." chunks = splitter.run([Document(content=text)])["documents"] - assert len(chunks) == 4 + assert len(chunks) == 5 assert chunks[0].content == "This is sentence one." - assert chunks[1].content == "sentence one. This is sentence two." - assert chunks[2].content == "sentence two. This is sentence three." - assert chunks[3].content == "sentence three. This is sentence four." + assert chunks[1].content == "one. This is sentence" + assert chunks[2].content == "sentence two. This is" + assert chunks[3].content == "is sentence three. This" + assert chunks[4].content == "This is sentence four." + + +def test_run_custom_split_by_dot_and_overlap_3_char_unit(): + pass def test_run_serialization_in_pipeline():