Skip to content

Commit

Permalink
customising overlap function for word and adding a few tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsbatista committed Dec 17, 2024
1 parent a418f73 commit 71ce15b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 18 deletions.
46 changes: 29 additions & 17 deletions haystack/components/preprocessors/recursive_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
self,
split_length: int = 200,
split_overlap: int = 0,
split_units: Literal["words", "char"] = "char",
split_units: Literal["word", "char"] = "char",
separators: Optional[List[str]] = None,
sentence_splitter_params: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -114,8 +114,12 @@ def _apply_overlap(self, chunks: List[str]) -> List[str]:
if idx == 0:
overlapped_chunks.append(chunk)
continue
overlap_start = max(0, len(chunks[idx - 1]) - self.split_overlap)
overlap = chunks[idx - 1][overlap_start:]
overlap_start = max(0, self._chunk_length(chunks[idx - 1]) - self.split_overlap)
if self.split_units == "word":
word_chunks = chunks[idx - 1].split()
overlap = " ".join(word_chunks[overlap_start:])
else:
overlap = chunks[idx - 1][overlap_start:]
if overlap == chunks[idx - 1]:
logger.warning(
"Overlap is the same as the previous chunk. "
Expand All @@ -134,7 +138,7 @@ def _chunk_length(self, text: str) -> int:
:returns:
The length of the chunk in words or characters.
"""
if self.split_units == "words":
if self.split_units == "word":
return len(text.split())
else:
return len(text)
Expand Down Expand Up @@ -214,36 +218,44 @@ def _chunk_text(self, text: str) -> List[str]:
if chunks:
return chunks

# if no separator worked, fall back to character-level chunking
# if no separator worked, fall back to character- or word-level chunking
return [
text[i : i + self.split_length]
for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
]

def _add_overlap_info(self, curr_pos: int, new_doc: "Document", new_docs: "List[Document]") -> None:
    """
    Record the overlap between the most recently created split and the split being created.

    When the two splits overlap, one entry is appended to the `_split_overlap` list stored in the metadata of
    each of the two documents. When they do not overlap, nothing is modified.

    :param curr_pos:
        Position at which `new_doc` starts; it is compared against the previous split's `split_idx_start`.
    :param new_doc:
        The split being created. Its `meta["_split_overlap"]` must already be a list.
    :param new_docs:
        The splits created so far. Must not be empty; only the last one is inspected (and updated).
    """
    prev_doc = new_docs[-1]
    # Measure the previous split once (in whatever unit `_chunk_length` reports) and reuse the value below.
    prev_length = self._chunk_length(prev_doc.content)  # type: ignore
    overlap_length = prev_length - (curr_pos - prev_doc.meta["split_idx_start"])
    if overlap_length <= 0:
        # the new split starts at or after the end of the previous one: nothing to record
        return

    prev_doc.meta["_split_overlap"].append({"doc_id": new_doc.id, "range": (0, overlap_length)})
    new_doc.meta["_split_overlap"].append(
        {"doc_id": prev_doc.id, "range": (prev_length - overlap_length, prev_length)}
    )

def _run_one(self, doc: Document) -> List[Document]:
new_docs: List[Document] = []
chunks = self._chunk_text(doc.content) # type: ignore # the caller already check for a non-empty doc.content
chunks = chunks[:-1] if len(chunks[-1]) == 0 else chunks # remove last empty chunk
chunks = chunks[:-1] if len(chunks[-1]) == 0 else chunks # remove last empty chunk if it exists
current_position = 0
current_page = 1

new_docs: List[Document] = []

for split_nr, chunk in enumerate(chunks):
new_doc = Document(content=chunk, meta=deepcopy(doc.meta))
new_doc.meta["split_id"] = split_nr
new_doc.meta["split_idx_start"] = current_position
new_doc.meta["_split_overlap"] = [] if self.split_overlap > 0 else None

# add overlap information to the previous and current doc
if split_nr > 0 and self.split_overlap > 0:
previous_doc = new_docs[-1]
overlap_length = len(previous_doc.content) - (current_position - previous_doc.meta["split_idx_start"]) # type: ignore
if overlap_length > 0:
previous_doc.meta["_split_overlap"].append({"doc_id": new_doc.id, "range": (0, overlap_length)})
new_doc.meta["_split_overlap"].append(
{
"doc_id": previous_doc.id,
"range": (len(previous_doc.content) - overlap_length, len(previous_doc.content)), # type: ignore
}
)
self._add_overlap_info(current_position, new_doc, new_docs)

# count page breaks in the chunk
current_page += chunk.count("\f")
Expand Down
23 changes: 22 additions & 1 deletion test/components/preprocessors/test_recursive_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,28 @@ def test_recursive_splitter_custom_sentence_tokenizer_document_and_overlap():
assert doc_chunks[2].meta["_split_overlap"] == [{"doc_id": doc_chunks[1].id, "range": (22, 27)}]


def test_run_split_document_with_overlap():
def test_recursive_splitter_custom_sentence_tokenizer_document_and_overlap_word_unit_no_overlap():
    """With word units, a "." separator and no overlap, each sentence is expected to come back as its own chunk."""
    source_doc = Document(content="This is sentence one. This is sentence two. This is sentence three.")
    word_splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=0, separators=["."], split_units="word")

    result = word_splitter.run([source_doc])
    produced = [piece.content for piece in result["documents"]]

    assert len(produced) == 3
    # every chunk after the first keeps the space that followed the previous "."
    assert produced == [
        "This is sentence one.",
        " This is sentence two.",
        " This is sentence three.",
    ]


def test_recursive_splitter_custom_sentence_tokenizer_document_and_overlap_word_unit_overlap_2_words():
    """With word units and an overlap of 2, each later chunk is expected to start with the previous two words."""
    source_doc = Document(
        content="This is sentence one. This is sentence two. This is sentence three. This is sentence four."
    )
    word_splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=2, separators=["."], split_units="word")

    result = word_splitter.run([source_doc])
    produced = [piece.content for piece in result["documents"]]

    assert len(produced) == 4
    assert produced == [
        "This is sentence one.",
        "sentence one. This is sentence two.",
        "sentence two. This is sentence three.",
        "sentence three. This is sentence four.",
    ]


def test_run_split_document_with_overlap_character_unit():
splitter = RecursiveDocumentSplitter(split_length=20, split_overlap=11, separators=[".", " "])
text = """A simple sentence1. A bright sentence2. A clever sentence3. A joyful sentence4"""

Expand Down

0 comments on commit 71ce15b

Please sign in to comment.