Skip to content

Commit

Permalink
Fix the chunk-overlap bug in RecursiveDocumentSplitter: build each overlap from the previous *overlapped* chunk, and trim chunks that exceed split_length
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsbatista committed Dec 20, 2024
1 parent d292de6 commit b09154e
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 41 deletions.
93 changes: 80 additions & 13 deletions haystack/components/preprocessors/recursive_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,33 +118,100 @@ def _get_custom_sentence_tokenizer(sentence_splitter_params: Dict[str, Any]):
# NOTE(review): this span is a GitHub unified-diff rendering with the +/- markers
# stripped, so REMOVED (pre-commit) and ADDED (post-commit) lines are interleaved
# below (see the conflicting duplicate statements flagged inline). It is not
# executable as-is; recover the real method from the repository at commit b09154e.
def _apply_overlap(self, chunks: List[str]) -> List[str]:
"""
Applies an overlap between consecutive chunks if the chunk_overlap attribute is greater than zero.
:param chunks: List of text chunks.
:returns:
The list of chunks with overlap applied.
"""
# NOTE(review): old (untyped) and new (typed) initializations both present — diff artifact.
overlapped_chunks = []
overlapped_chunks: List[str] = []
remaining_words: List[str] = []
remaining_chars: str = ""

for idx, chunk in enumerate(chunks):
if idx == 0:
overlapped_chunks.append(chunk)
continue
# NOTE(review): the next six lines are the OLD inline overlap computation
# (reads the original `chunks[idx - 1]`); the commit replaces them with the
# `self._get_overlap(overlapped_chunks)` call below, which reads the already
# overlapped previous chunk instead.
overlap_start = max(0, self._chunk_length(chunks[idx - 1]) - self.split_overlap)
if self.split_units == "word":
word_chunks = chunks[idx - 1].split()
overlap = " ".join(word_chunks[overlap_start:])
else:
overlap = chunks[idx - 1][overlap_start:]
if overlap == chunks[idx - 1]:

overlap, prev_chunk = self._get_overlap(overlapped_chunks)

if overlap == prev_chunk:
logger.warning(
"Overlap is the same as the previous chunk. "
"Consider increasing the `split_length` parameter or decreasing the `split_overlap` parameter."
)
# NOTE(review): OLD concatenation (no word separator) — superseded by the
# word/char-aware branch that follows.
current_chunk = overlap + chunk

# create new chunk starting with the overlap
if self.split_units == "word":
current_chunk = overlap + " " + chunk
else:
current_chunk = overlap + chunk

# if the new chunk exceeds split_length, trim it and add the trimmed content to the next chunk
if self._chunk_length(current_chunk) > self.split_length:
if self.split_units == "word":
words = current_chunk.split()
current_chunk = " ".join(words[: self.split_length])
remaining_words = words[self.split_length :]
if idx < len(chunks) - 1:
# add remaining words to the beginning of the next chunk
chunks[idx + 1] = " ".join(remaining_words) + " " + chunks[idx + 1]
elif remaining_words:
# if this is the last chunk, and we have remaining words
overlapped_chunks.append(current_chunk)
current_chunk = " ".join(remaining_words)

else: # char-level splitting
text = current_chunk
current_chunk = text[: self.split_length]
remaining_chars = text[self.split_length :]
if idx < len(chunks) - 1:
# add remaining chars to the beginning of the next chunk
chunks[idx + 1] = remaining_chars + chunks[idx + 1]
elif remaining_chars: # if this is the last chunk and we have remaining chars
overlapped_chunks.append(current_chunk)
current_chunk = remaining_chars

# if this is the last chunk, and we have remaining words or characters, add them to the current chunk
if idx == len(chunks) - 1 and (remaining_words or remaining_chars):
overlap, prev_chunk = self._get_overlap(overlapped_chunks)
if remaining_words:
current_chunk = overlap + " " + current_chunk
if remaining_chars:
current_chunk = overlap + current_chunk

overlapped_chunks.append(current_chunk)

# check if the last chunk exceeds split_length and split it
if idx == len(chunks) - 1 and self._chunk_length(current_chunk) > self.split_length:
# split the last chunk and add the first chunk to the list
last_chunk = overlapped_chunks.pop()
if self.split_units == "word":
words = last_chunk.split()
first_chunk = " ".join(words[: self.split_length])
remaining_chunk = " ".join(words[self.split_length :])
else:
first_chunk = last_chunk[: self.split_length]
remaining_chunk = last_chunk[self.split_length :]
overlapped_chunks.append(first_chunk)

# add the remaining chunk with overlap from the previous chunk
if remaining_chunk:
overlap, prev_chunk = self._get_overlap(overlapped_chunks)
if self.split_units == "word":
remaining_chunk = overlap + " " + remaining_chunk
else:
remaining_chunk = overlap + remaining_chunk
overlapped_chunks.append(remaining_chunk)

return overlapped_chunks

def _get_overlap(self, overlapped_chunks):
"""Get the previous overlapped chunk instead of the original chunk."""
prev_chunk = overlapped_chunks[-1]
overlap_start = max(0, self._chunk_length(prev_chunk) - self.split_overlap)
if self.split_units == "word":
word_chunks = prev_chunk.split()
overlap = " ".join(word_chunks[overlap_start:])
else:
overlap = prev_chunk[overlap_start:]
return overlap, prev_chunk

def _chunk_length(self, text: str) -> int:
"""
Get the length of the chunk in words or characters.
Expand Down
75 changes: 47 additions & 28 deletions test/components/preprocessors/test_recursive_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,40 +354,47 @@ def test_run_split_by_sentence_count_page_breaks_split_unit_char() -> None:


# NOTE(review): diff rendering with +/- markers stripped — OLD and NEW test lines
# are interleaved (e.g. the conflicting `len(doc_chunks) == 4` vs `== 5` asserts
# and the duplicate `_split_overlap` expectations). Not executable as-is.
def test_run_split_document_with_overlap_character_unit():
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=11, separators=[".", " "], split_unit="char")
text = """A simple sentence1. A bright sentence2. A clever sentence3. A joyful sentence4"""
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=10, separators=["."], split_unit="char")
text = """A simple sentence1. A bright sentence2. A clever sentence3"""

doc = Document(content=text)
doc_chunks = splitter.run([doc])
doc_chunks = doc_chunks["documents"]

assert len(doc_chunks) == 4

assert len(doc_chunks) == 5
assert doc_chunks[0].content == "A simple sentence1."
assert doc_chunks[0].meta["split_id"] == 0
assert doc_chunks[0].meta["split_idx_start"] == text.index(doc_chunks[0].content)
assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 11)}]
assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 10)}]

assert doc_chunks[1].content == " sentence1. A bright sentence2."
assert doc_chunks[1].content == "sentence1. A bright "
assert doc_chunks[1].meta["split_id"] == 1
assert doc_chunks[1].meta["split_idx_start"] == text.index(doc_chunks[1].content)
assert doc_chunks[1].meta["_split_overlap"] == [
{"doc_id": doc_chunks[0].id, "range": (8, 19)},
{"doc_id": doc_chunks[2].id, "range": (0, 11)},
{"doc_id": doc_chunks[0].id, "range": (9, 19)},
{"doc_id": doc_chunks[2].id, "range": (0, 10)},
]

assert doc_chunks[2].content == " sentence2. A clever sentence3."
assert doc_chunks[2].content == " A bright sentence2."
assert doc_chunks[2].meta["split_id"] == 2
assert doc_chunks[2].meta["split_idx_start"] == text.index(doc_chunks[2].content)
assert doc_chunks[2].meta["_split_overlap"] == [
{"doc_id": doc_chunks[1].id, "range": (20, 31)},
{"doc_id": doc_chunks[3].id, "range": (0, 11)},
{"doc_id": doc_chunks[1].id, "range": (10, 20)},
{"doc_id": doc_chunks[3].id, "range": (0, 10)},
]

assert doc_chunks[3].content == " sentence3. A joyful sentence4"
assert doc_chunks[3].content == "sentence2. A clever "
assert doc_chunks[3].meta["split_id"] == 3
assert doc_chunks[3].meta["split_idx_start"] == text.index(doc_chunks[3].content)
assert doc_chunks[3].meta["_split_overlap"] == [{"doc_id": doc_chunks[2].id, "range": (20, 31)}]
assert doc_chunks[3].meta["_split_overlap"] == [
{"doc_id": doc_chunks[2].id, "range": (10, 20)},
{"doc_id": doc_chunks[4].id, "range": (0, 10)},
]

assert doc_chunks[4].content == " A clever sentence3"
assert doc_chunks[4].meta["split_id"] == 4
assert doc_chunks[4].meta["split_idx_start"] == text.index(doc_chunks[4].content)
assert doc_chunks[4].meta["_split_overlap"] == [{"doc_id": doc_chunks[3].id, "range": (10, 20)}]


def test_run_separator_exists_but_split_length_too_small_fall_back_to_character_chunking():
Expand Down Expand Up @@ -421,31 +428,38 @@ def test_run_fallback_to_word_chunking_by_default_length_too_short():

# NOTE(review): diff rendering with +/- markers stripped — OLD and NEW test lines
# are interleaved (conflicting `len(doc_chunks) == 3` vs `== 4` asserts, duplicate
# splitter construction and `_split_overlap` expectations). Not executable as-is.
def test_run_custom_sentence_tokenizer_document_and_overlap_char_unit():
"""Test that RecursiveDocumentSplitter works correctly with custom sentence tokenizer and overlap"""
splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=5, separators=["sentence"], split_unit="char")
splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=10, separators=["sentence"], split_unit="char")
text = "This is sentence one. This is sentence two. This is sentence three."

doc = Document(content=text)
doc_chunks = splitter.run([doc])["documents"]

assert len(doc_chunks) == 3

assert len(doc_chunks) == 4
assert doc_chunks[0].content == "This is sentence one. "
assert doc_chunks[0].meta["split_id"] == 0
assert doc_chunks[0].meta["split_idx_start"] == text.index(doc_chunks[0].content)
assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 5)}]
assert doc_chunks[0].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (0, 10)}]

assert doc_chunks[1].content == "one. This is sentence two. "
assert doc_chunks[1].content == "ence one. This is sentenc"
assert doc_chunks[1].meta["split_id"] == 1
assert doc_chunks[1].meta["split_idx_start"] == text.index(doc_chunks[1].content)
assert doc_chunks[1].meta["_split_overlap"] == [
{"doc_id": doc_chunks[0].id, "range": (17, 22)},
{"doc_id": doc_chunks[2].id, "range": (0, 5)},
{"doc_id": doc_chunks[0].id, "range": (12, 22)},
{"doc_id": doc_chunks[2].id, "range": (0, 10)},
]

assert doc_chunks[2].content == "two. This is sentence three."
assert doc_chunks[2].content == "is sentence two. This is "
assert doc_chunks[2].meta["split_id"] == 2
assert doc_chunks[2].meta["split_idx_start"] == text.index(doc_chunks[2].content)
assert doc_chunks[2].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (22, 27)}]
assert doc_chunks[2].meta["_split_overlap"] == [
{"doc_id": doc_chunks[1].id, "range": (15, 25)},
{"doc_id": doc_chunks[3].id, "range": (0, 10)},
]

assert doc_chunks[3].content == ". This is sentence three."
assert doc_chunks[3].meta["split_id"] == 3
assert doc_chunks[3].meta["split_idx_start"] == text.index(doc_chunks[3].content)
assert doc_chunks[3].meta["_split_overlap"] == [{"doc_id": doc_chunks[2].id, "range": (15, 25)}]


def test_run_split_by_dot_count_page_breaks_word_unit() -> None:
Expand Down Expand Up @@ -673,15 +687,20 @@ def test_run_custom_sentence_tokenizer_document_and_overlap_word_unit_no_overlap
assert chunks[2].content == " This is sentence three."


# NOTE(review): diff rendering with +/- markers stripped — the OLD test (renamed,
# split_overlap=2, 4 chunks) and the NEW test (split_overlap=1, 5 chunks) are
# interleaved, including two consecutive `def` lines. Not executable as-is.
def test_run_custom_sentence_tokenizer_document_and_overlap_word_unit_overlap_2_words():
splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=2, separators=["."], split_unit="word")
def test_run_custom_split_by_dot_and_overlap_1_word_unit():
splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=1, separators=["."], split_unit="word")
text = "This is sentence one. This is sentence two. This is sentence three. This is sentence four."
chunks = splitter.run([Document(content=text)])["documents"]
assert len(chunks) == 4
assert len(chunks) == 5
assert chunks[0].content == "This is sentence one."
assert chunks[1].content == "sentence one. This is sentence two."
assert chunks[2].content == "sentence two. This is sentence three."
assert chunks[3].content == "sentence three. This is sentence four."
assert chunks[1].content == "one. This is sentence"
assert chunks[2].content == "sentence two. This is"
assert chunks[3].content == "is sentence three. This"
assert chunks[4].content == "This is sentence four."


def test_run_custom_split_by_dot_and_overlap_3_char_unit():
    # TODO(review): empty placeholder committed with only `pass` — the char-unit
    # overlap-3 scenario still needs a test body.
    pass


def test_run_serialization_in_pipeline():
Expand Down

0 comments on commit b09154e

Please sign in to comment.