feat: unify NLTKDocumentSplitter and DocumentSplitter (#8617)
* wip: initial import

* wip: refactoring

* wip: refactoring tests

* wip: refactoring tests

* making all NLTKSplitter related tests work

* refactoring

* docstrings

* refactoring and removing NLTKDocumentSplitter

* fixing tests for custom sentence tokenizer

* fixing tests for custom sentence tokenizer

* cleaning up

* adding release notes

* reverting some changes

* cleaning up tests

* fixing serialisation and adding tests

* cleaning up

* wip

* renaming and cleaning

* adding NLTK files

* updating docstring

* adding import to init

* Update haystack/components/preprocessors/document_splitter.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* updating tests

* wip

* adding sentence/period change warning

* fixing LICENSE header

* Update haystack/components/preprocessors/document_splitter.py

Co-authored-by: Stefano Fiorucci <[email protected]>

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
davidsbatista and anakin87 authored Dec 12, 2024
1 parent 6cceaac commit 3f77d3a
Showing 6 changed files with 635 additions and 53 deletions.
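
For illustration, a minimal usage sketch of the unified splitter this commit describes. It assumes `DocumentSplitter` is exported from `haystack.components.preprocessors` (the commit message mentions adding the import to init) and that NLTK and its tokenizer data are installed, since `split_by="sentence"` now uses the NLTK-based sentence tokenizer; the sample text is invented and the exact chunk boundaries depend on the tokenizer:

```python
from haystack import Document
from haystack.components.preprocessors import DocumentSplitter

# Group two NLTK-detected sentences per chunk
splitter = DocumentSplitter(split_by="sentence", split_length=2, split_overlap=0)
result = splitter.run(
    documents=[Document(content="Moonlight shimmered. Wolves howled nearby. The night was still.")]
)
for doc in result["documents"]:
    # each split carries a source_id pointing back to the original document
    print(doc.content, doc.meta["source_id"])
```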
274 changes: 239 additions & 35 deletions haystack/components/preprocessors/document_splitter.py
@@ -2,29 +2,29 @@
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

from more_itertools import windowed

from haystack import Document, component, logging
from haystack.components.preprocessors.sentence_tokenizer import Language, SentenceSplitter, nltk_imports
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.utils import deserialize_callable, serialize_callable

logger = logging.getLogger(__name__)

# Maps the 'split_by' argument to the actual char used to split the Documents.
# 'function' is not in the mapping cause it doesn't split on chars.
_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "sentence": ".", "word": " ", "line": "\n"}
# Mapping from the 'split_by' argument to the character used for splitting; 'function' and 'sentence' don't split by character
_CHARACTER_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "period": ".", "word": " ", "line": "\n"}


@component
class DocumentSplitter:
"""
Splits long documents into smaller chunks.
This is a common preprocessing step during indexing.
It helps Embedders create meaningful semantic representations
This is a common preprocessing step during indexing. It helps Embedders create meaningful semantic representations
and prevents exceeding language model context limits.
The DocumentSplitter is compatible with the following DocumentStores:
@@ -54,40 +54,115 @@ class DocumentSplitter:

def __init__( # pylint: disable=too-many-positional-arguments
self,
split_by: Literal["function", "page", "passage", "sentence", "word", "line"] = "word",
split_by: Literal["function", "page", "passage", "period", "word", "line", "sentence"] = "word",
split_length: int = 200,
split_overlap: int = 0,
split_threshold: int = 0,
splitting_function: Optional[Callable[[str], List[str]]] = None,
respect_sentence_boundary: bool = False,
language: Language = "en",
use_split_rules: bool = True,
extend_abbreviations: bool = True,
):
"""
Initialize DocumentSplitter.
:param split_by: The unit for splitting your documents. Choose from `word` for splitting by spaces (" "),
`sentence` for splitting by periods ("."), `page` for splitting by form feed ("\\f"),
`passage` for splitting by double line breaks ("\\n\\n") or `line` for splitting each line ("\\n").
:param split_by: The unit for splitting your documents. Choose from:
- `word` for splitting by spaces (" ")
- `period` for splitting by periods (".")
- `page` for splitting by form feed ("\\f")
- `passage` for splitting by double line breaks ("\\n\\n")
- `line` for splitting each line ("\\n")
- `sentence` for splitting by NLTK sentence tokenizer
:param split_length: The maximum number of units in each split.
:param split_overlap: The number of overlapping units for each split.
:param split_threshold: The minimum number of units per split. If a split has fewer units
than the threshold, it's attached to the previous split.
:param splitting_function: Necessary when `split_by` is set to "function".
This is a function which must accept a single `str` as input and return a `list` of `str` as output,
representing the chunks after splitting.
:param respect_sentence_boundary: Choose whether to respect sentence boundaries when splitting by "word".
If True, uses NLTK to detect sentence boundaries, ensuring splits occur only between sentences.
:param language: Choose the language for the NLTK tokenizer. The default is English ("en").
:param use_split_rules: Choose whether to use additional split rules when splitting by `sentence`.
:param extend_abbreviations: Choose whether to extend NLTK's PunktTokenizer abbreviations with a list
of curated abbreviations, if available. This is currently supported for English ("en") and German ("de").
"""

self.split_by = split_by
if split_by not in ["function", "page", "passage", "sentence", "word", "line"]:
raise ValueError("split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'.")
self.split_length = split_length
self.split_overlap = split_overlap
self.split_threshold = split_threshold
self.splitting_function = splitting_function
self.respect_sentence_boundary = respect_sentence_boundary
self.language = language
self.use_split_rules = use_split_rules
self.extend_abbreviations = extend_abbreviations

self._init_checks(
split_by=split_by,
split_length=split_length,
split_overlap=split_overlap,
splitting_function=splitting_function,
respect_sentence_boundary=respect_sentence_boundary,
)

if split_by == "sentence" or (respect_sentence_boundary and split_by == "word"):
nltk_imports.check()
self.sentence_splitter = SentenceSplitter(
language=language,
use_split_rules=use_split_rules,
extend_abbreviations=extend_abbreviations,
keep_white_spaces=True,
)

if split_by == "sentence":
# ToDo: remove this warning in the next major release
msg = (
"The `split_by='sentence'` option no longer splits by '.' and now relies on a custom sentence "
"tokenizer based on NLTK. To achieve the previous behaviour, use `split_by='period'`."
)
warnings.warn(msg)
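
A small sketch of the behaviour change the warning above describes. The sample text is invented and the exact NLTK sentence boundaries may differ; the point is that `split_by="period"` reproduces the old naive split on "." while `split_by="sentence"` now goes through the NLTK tokenizer:

```python
from haystack import Document
from haystack.components.preprocessors import DocumentSplitter

doc = Document(content="Dr. Smith arrived at 9 a.m. He left soon after.")

# Old behaviour, now reachable via split_by="period": a plain split on "."
period_splitter = DocumentSplitter(split_by="period", split_length=1)
print([d.content for d in period_splitter.run(documents=[doc])["documents"]])

# New behaviour of split_by="sentence": NLTK-based tokenizer (construction emits the warning above)
sentence_splitter = DocumentSplitter(split_by="sentence", split_length=1)
print([d.content for d in sentence_splitter.run(documents=[doc])["documents"]])
```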

def _init_checks(
self,
*,
split_by: str,
split_length: int,
split_overlap: int,
splitting_function: Optional[Callable],
respect_sentence_boundary: bool,
) -> None:
"""
Validates initialization parameters for DocumentSplitter.
:param split_by: The unit for splitting documents
:param split_length: The maximum number of units in each split
:param split_overlap: The number of overlapping units for each split
:param splitting_function: Custom function for splitting when split_by="function"
:param respect_sentence_boundary: Whether to respect sentence boundaries when splitting
:raises ValueError: If any parameter is invalid
"""
valid_split_by = ["function", "page", "passage", "period", "word", "line", "sentence"]
if split_by not in valid_split_by:
raise ValueError(f"split_by must be one of {', '.join(valid_split_by)}.")

if split_by == "function" and splitting_function is None:
raise ValueError("When 'split_by' is set to 'function', a valid 'splitting_function' must be provided.")

if split_length <= 0:
raise ValueError("split_length must be greater than 0.")
self.split_length = split_length

if split_overlap < 0:
raise ValueError("split_overlap must be greater than or equal to 0.")
self.split_overlap = split_overlap
self.split_threshold = split_threshold
self.splitting_function = splitting_function

if respect_sentence_boundary and split_by != "word":
logger.warning(
"The 'respect_sentence_boundary' option is only supported for `split_by='word'`. "
"The option `respect_sentence_boundary` will be set to `False`."
)
self.respect_sentence_boundary = False
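
A quick sketch of the validation behaviour implemented in `_init_checks` above; the error messages in the comments are quoted from this diff:

```python
from haystack.components.preprocessors import DocumentSplitter

# 'function' without a splitting_function is rejected
try:
    DocumentSplitter(split_by="function")
except ValueError as err:
    print(err)  # When 'split_by' is set to 'function', a valid 'splitting_function' must be provided.

# split_length must be greater than 0
try:
    DocumentSplitter(split_by="word", split_length=0)
except ValueError as err:
    print(err)  # split_length must be greater than 0.
```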

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
@@ -98,7 +173,6 @@ def run(self, documents: List[Document]):
and an overlap of `split_overlap`.
:param documents: The documents to split.
:returns: A dictionary with the following key:
- `documents`: List of documents with the split texts. Each document includes:
- A metadata field `source_id` to track the original document.
@@ -121,39 +195,69 @@ def run(self, documents: List[Document]):
if doc.content == "":
logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id)
continue
split_docs += self._split(doc)

split_docs += self._split_document(doc)
return {"documents": split_docs}

def _split(self, to_split: Document) -> List[Document]:
# We already check this before calling _split but
# we need to make linters happy
if to_split.content is None:
return []
def _split_document(self, doc: Document) -> List[Document]:
if self.split_by == "sentence" or self.respect_sentence_boundary:
return self._split_by_nltk_sentence(doc)

if self.split_by == "function" and self.splitting_function is not None:
splits = self.splitting_function(to_split.content)
docs: List[Document] = []
for s in splits:
meta = deepcopy(to_split.meta)
meta["source_id"] = to_split.id
docs.append(Document(content=s, meta=meta))
return docs

split_at = _SPLIT_BY_MAPPING[self.split_by]
units = to_split.content.split(split_at)
return self._split_by_function(doc)

return self._split_by_character(doc)

def _split_by_nltk_sentence(self, doc: Document) -> List[Document]:
split_docs = []

result = self.sentence_splitter.split_sentences(doc.content) # type: ignore # None check is done in run()
units = [sentence["sentence"] for sentence in result]

if self.respect_sentence_boundary:
text_splits, splits_pages, splits_start_idxs = self._concatenate_sentences_based_on_word_amount(
sentences=units, split_length=self.split_length, split_overlap=self.split_overlap
)
else:
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
elements=units,
split_length=self.split_length,
split_overlap=self.split_overlap,
split_threshold=self.split_threshold,
)
metadata = deepcopy(doc.meta)
metadata["source_id"] = doc.id
split_docs += self._create_docs_from_splits(
text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
)

return split_docs

def _split_by_character(self, doc) -> List[Document]:
split_at = _CHARACTER_SPLIT_BY_MAPPING[self.split_by]
units = doc.content.split(split_at)
# Add the delimiter back to all units except the last one
for i in range(len(units) - 1):
units[i] += split_at

text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
units, self.split_length, self.split_overlap, self.split_threshold
)
metadata = deepcopy(to_split.meta)
metadata["source_id"] = to_split.id
metadata = deepcopy(doc.meta)
metadata["source_id"] = doc.id
return self._create_docs_from_splits(
text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
)
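
A rough sketch of the character-based path above for `split_by="period"`. The sample text is invented; what the code guarantees is that the delimiter is added back to every unit except the last, so the periods stay at the end of each chunk:

```python
from haystack import Document
from haystack.components.preprocessors import DocumentSplitter

# chunks keep their trailing "." (roughly "One." / " Two." / " Three.")
splitter = DocumentSplitter(split_by="period", split_length=1)
result = splitter.run(documents=[Document(content="One. Two. Three.")])
print([d.content for d in result["documents"]])
```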

def _split_by_function(self, doc) -> List[Document]:
# the check for None is done already in the run method
splits = self.splitting_function(doc.content) # type: ignore
docs: List[Document] = []
for s in splits:
meta = deepcopy(doc.meta)
meta["source_id"] = doc.id
docs.append(Document(content=s, meta=meta))
return docs
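
A minimal sketch of the `split_by="function"` path above, using a hypothetical custom splitter (the function name and the "---" separator are invented for illustration):

```python
from typing import List

from haystack import Document
from haystack.components.preprocessors import DocumentSplitter

# Hypothetical splitting_function: one chunk per section separated by "---"
def split_on_rulers(text: str) -> List[str]:
    return [part.strip() for part in text.split("---") if part.strip()]

splitter = DocumentSplitter(split_by="function", splitting_function=split_on_rulers)
result = splitter.run(documents=[Document(content="intro --- details --- summary")])
print([d.content for d in result["documents"]])  # each chunk also carries meta["source_id"]
```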

def _concatenate_units(
self, elements: List[str], split_length: int, split_overlap: int, split_threshold: int
) -> Tuple[List[str], List[int], List[int]]:
Expand Down Expand Up @@ -265,6 +369,10 @@ def to_dict(self) -> Dict[str, Any]:
split_length=self.split_length,
split_overlap=self.split_overlap,
split_threshold=self.split_threshold,
respect_sentence_boundary=self.respect_sentence_boundary,
language=self.language,
use_split_rules=self.use_split_rules,
extend_abbreviations=self.extend_abbreviations,
)
if self.splitting_function:
serialized["init_parameters"]["splitting_function"] = serialize_callable(self.splitting_function)
@@ -282,3 +390,99 @@ def from_dict(cls, data: Dict[str, Any]) -> "DocumentSplitter":
init_params["splitting_function"] = deserialize_callable(splitting_function)

return default_from_dict(cls, data)
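
A sketch of the serialization round trip with the new parameters that `to_dict` now includes (see the hunk above); it assumes NLTK is installed, since `respect_sentence_boundary=True` triggers the NLTK import check at construction time:

```python
from haystack.components.preprocessors import DocumentSplitter

splitter = DocumentSplitter(
    split_by="word", split_length=150, split_overlap=20, respect_sentence_boundary=True, language="en"
)

data = splitter.to_dict()
# init_parameters now also carries respect_sentence_boundary, language,
# use_split_rules and extend_abbreviations
print(data["init_parameters"]["respect_sentence_boundary"])  # True

restored = DocumentSplitter.from_dict(data)
print(restored.respect_sentence_boundary)  # True
```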

@staticmethod
def _concatenate_sentences_based_on_word_amount(
sentences: List[str], split_length: int, split_overlap: int
) -> Tuple[List[str], List[int], List[int]]:
"""
Groups the sentences into chunks of `split_length` words while respecting sentence boundaries.
This function is only used when splitting by `word` and `respect_sentence_boundary` is set to `True`, i.e. when using the NLTK sentence tokenizer.
:param sentences: The list of sentences to split.
:param split_length: The maximum number of words in each split.
:param split_overlap: The number of overlapping words in each split.
:returns: A tuple containing the concatenated sentences, the start page numbers, and the start indices.
"""
# chunk information
chunk_word_count = 0
chunk_starting_page_number = 1
chunk_start_idx = 0
current_chunk: List[str] = []
# output lists
split_start_page_numbers = []
list_of_splits: List[List[str]] = []
split_start_indices = []

for sentence_idx, sentence in enumerate(sentences):
current_chunk.append(sentence)
chunk_word_count += len(sentence.split())
next_sentence_word_count = (
len(sentences[sentence_idx + 1].split()) if sentence_idx < len(sentences) - 1 else 0
)

# Number of words in the current chunk plus the next sentence is larger than the split_length,
# or we reached the last sentence
if (chunk_word_count + next_sentence_word_count) > split_length or sentence_idx == len(sentences) - 1:
# Save current chunk and start a new one
list_of_splits.append(current_chunk)
split_start_page_numbers.append(chunk_starting_page_number)
split_start_indices.append(chunk_start_idx)

# Get the number of sentences that overlap with the next chunk
num_sentences_to_keep = DocumentSplitter._number_of_sentences_to_keep(
sentences=current_chunk, split_length=split_length, split_overlap=split_overlap
)
# Set up information for the new chunk
if num_sentences_to_keep > 0:
# Processed sentences are the ones that are not overlapping with the next chunk
processed_sentences = current_chunk[:-num_sentences_to_keep]
chunk_starting_page_number += sum(sent.count("\f") for sent in processed_sentences)
chunk_start_idx += len("".join(processed_sentences))
# Next chunk starts with the sentences that were overlapping with the previous chunk
current_chunk = current_chunk[-num_sentences_to_keep:]
chunk_word_count = sum(len(s.split()) for s in current_chunk)
else:
# Here processed_sentences is the same as current_chunk since there is no overlap
chunk_starting_page_number += sum(sent.count("\f") for sent in current_chunk)
chunk_start_idx += len("".join(current_chunk))
current_chunk = []
chunk_word_count = 0

# Concatenate the sentences together within each split
text_splits = []
for split in list_of_splits:
text = "".join(split)
if len(text) > 0:
text_splits.append(text)

return text_splits, split_start_page_numbers, split_start_indices
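
A sketch, driven through the public API, of the sentence-boundary grouping implemented above. The sample text is invented and the exact chunking depends on the NLTK sentence boundaries; the idea is that whole sentences are accumulated until the next one would push the chunk past `split_length` words, with the trailing sentence(s) reused as overlap:

```python
from haystack import Document
from haystack.components.preprocessors import DocumentSplitter

splitter = DocumentSplitter(
    split_by="word", split_length=10, split_overlap=3, respect_sentence_boundary=True
)
doc = Document(content="The moon rose over the hills. Wolves began to howl. The camp fell silent.")
for chunk in splitter.run(documents=[doc])["documents"]:
    print(repr(chunk.content))
```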

@staticmethod
def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int:
"""
Returns the number of sentences to keep in the next chunk based on the `split_overlap` and `split_length`.
:param sentences: The list of sentences to split.
:param split_length: The maximum number of words in each split.
:param split_overlap: The number of overlapping words in each split.
:returns: The number of sentences to keep in the next chunk.
"""
# If the split_overlap is 0, we don't need to keep any sentences
if split_overlap == 0:
return 0

num_sentences_to_keep = 0
num_words = 0
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
for sent in reversed(sentences[1:]):
num_words += len(sent.split())
# If the number of words is larger than the split_length then don't add any more sentences
if num_words > split_length:
break
num_sentences_to_keep += 1
if num_words > split_overlap:
break
return num_sentences_to_keep
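
A small worked example of the helper above, with values traced by hand from the code; it calls the private static method directly and assumes each sample sentence has exactly four words, so treat it as an illustration rather than a contract:

```python
from haystack.components.preprocessors import DocumentSplitter

sentences = ["One two three four. ", "Five six seven eight. ", "Nine ten eleven twelve. "]

# No overlap requested: nothing is carried over into the next chunk.
print(DocumentSplitter._number_of_sentences_to_keep(sentences, split_length=10, split_overlap=0))  # 0

# Overlap of 3 words: walking backwards (the first sentence is always skipped),
# the last sentence alone already contributes 4 words > 3, so one sentence is kept.
print(DocumentSplitter._number_of_sentences_to_keep(sentences, split_length=10, split_overlap=3))  # 1
```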