Commit

fixing serialisation and adding tests
davidsbatista committed Dec 11, 2024
1 parent 9d682f5 commit 09e67fa
Showing 2 changed files with 70 additions and 17 deletions.
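For context, a minimal sketch of the serialization roundtrip this commit targets, based on the `to_dict`/`from_dict` calls exercised in the new tests below. The parameter values are illustrative, not part of the commit:

```python
from haystack.components.preprocessors import DocumentSplitter

# illustrative configuration; before this fix, to_dict() did not serialize the
# sentence-boundary related parameters, so a roundtrip lost them
splitter = DocumentSplitter(split_by="word", respect_sentence_boundary=True, language="de")

data = splitter.to_dict()
restored = DocumentSplitter.from_dict(data)

# with the fix, these settings survive serialization
assert restored.respect_sentence_boundary is True
assert restored.language == "de"
```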
40 changes: 23 additions & 17 deletions haystack/components/preprocessors/document_splitter.py
@@ -88,24 +88,25 @@ def __init__( # pylint: disable=too-many-positional-arguments
of curated abbreviations, if available. This is currently supported for English ("en") and German ("de").
"""

self._init_checks(
split_by=split_by,
split_length=split_length,
split_overlap=split_overlap,
splitting_function=splitting_function,
respect_sentence_boundary=respect_sentence_boundary,
)

self.split_by = split_by
self.split_length = split_length
self.split_overlap = split_overlap
self.split_threshold = split_threshold
self.splitting_function = splitting_function
self.respect_sentence_boundary = respect_sentence_boundary
self.language = (language,)
self.use_split_rules = use_split_rules
self.extend_abbreviations = extend_abbreviations

if split_by == "nltk_sentence" or respect_sentence_boundary and split_by == "word":
self._init_checks(
split_by=split_by,
split_length=split_length,
split_overlap=split_overlap,
splitting_function=splitting_function,
respect_sentence_boundary=respect_sentence_boundary,
)

if split_by == "nltk_sentence" or (respect_sentence_boundary and split_by == "word"):
nltk_imports.check()
self.sentence_splitter = SentenceSplitter(
language=language,
@@ -115,8 +116,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
)
self.language = language

@staticmethod
def _init_checks(
self,
split_by: str,
split_length: int,
split_overlap: int,
@@ -151,6 +152,7 @@ def _init_checks(
"The 'respect_sentence_boundary' option is only supported for `split_by='word'`. "
"The option `respect_sentence_boundary` will be set to `False`."
)
self.respect_sentence_boundary = False

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
@@ -197,13 +199,12 @@ def _split_document(self, doc: Document) -> List[Document]:
return self._split_by_character(doc)

def _split_by_nltk_sentence(self, doc: Document) -> List[Document]:
if doc.content is None:
return []

split_docs = []

# whitespace is preserved while splitting text into sentences when using keep_white_spaces=True
# so split_at is set to an empty string
self.split_at = ""
# self.split_at = ""

result = self.sentence_splitter.split_sentences(doc.content)
units = [sentence["sentence"] for sentence in result]

@@ -362,6 +363,10 @@ def to_dict(self) -> Dict[str, Any]:
split_length=self.split_length,
split_overlap=self.split_overlap,
split_threshold=self.split_threshold,
respect_sentence_boundary=self.respect_sentence_boundary,
language=self.language,
use_split_rules=self.use_split_rules,
extend_abbreviations=self.extend_abbreviations,
)
if self.splitting_function:
serialized["init_parameters"]["splitting_function"] = serialize_callable(self.splitting_function)
@@ -380,8 +385,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "DocumentSplitter":

return default_from_dict(cls, data)

@staticmethod
def _concatenate_sentences_based_on_word_amount(
self, sentences: List[str], split_length: int, split_overlap: int
sentences: List[str], split_length: int, split_overlap: int
) -> Tuple[List[str], List[int], List[int]]:
"""
Groups the sentences into chunks of `split_length` words while respecting sentence boundaries.
@@ -394,12 +400,12 @@ def _concatenate_sentences_based_on_word_amount(
:param split_overlap: The number of overlapping words in each split.
:returns: A tuple containing the concatenated sentences, the start page numbers, and the start indices.
"""
# Chunk information
# chunk information
chunk_word_count = 0
chunk_starting_page_number = 1
chunk_start_idx = 0
current_chunk: List[str] = []
# Output lists
# output lists
split_start_page_numbers = []
list_of_splits: List[List[str]] = []
split_start_indices = []
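As a reading aid for the `_concatenate_sentences_based_on_word_amount` docstring above, here is a simplified, standalone sketch of the grouping logic it describes: word-count chunks that respect sentence boundaries, with word overlap between consecutive chunks. It is not the actual Haystack implementation and it omits page-number tracking:

```python
from typing import List, Tuple


def group_sentences_by_word_count(
    sentences: List[str], split_length: int, split_overlap: int
) -> Tuple[List[str], List[int]]:
    """Concatenate sentences into chunks of ~`split_length` words; return chunks and their start indices."""
    chunks: List[str] = []
    start_indices: List[int] = []

    current: List[str] = []   # sentences in the chunk being built
    current_words = 0         # word count of the current chunk
    chunk_start = 0           # character offset of the current chunk in the original text
    consumed_chars = 0        # characters of all sentences processed so far

    for sentence in sentences:
        n_words = len(sentence.split())
        # close the chunk if adding this sentence would exceed the word budget
        if current and current_words + n_words > split_length:
            chunks.append("".join(current))
            start_indices.append(chunk_start)

            # carry over trailing sentences until they cover at least `split_overlap` words
            overlap: List[str] = []
            overlap_words = 0
            for s in reversed(current):
                if split_overlap <= 0 or overlap_words >= split_overlap:
                    break
                overlap.insert(0, s)
                overlap_words += len(s.split())

            current, current_words = overlap, overlap_words
            chunk_start = consumed_chars - sum(len(s) for s in overlap)

        current.append(sentence)
        current_words += n_words
        consumed_chars += len(sentence)

    if current:
        chunks.append("".join(current))
        start_indices.append(chunk_start)
    return chunks, start_indices
```

For example, with `split_length=10` and `split_overlap=2`, a chunk is closed before it would exceed ten words, and the next chunk restarts with the trailing sentence(s) that cover at least two words.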
47 changes: 47 additions & 0 deletions test/components/preprocessors/test_document_splitter.py
@@ -749,3 +749,50 @@ def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page
assert documents[5].meta["page_number"] == 3
assert documents[5].meta["split_id"] == 5
assert documents[5].meta["split_idx_start"] == text.index(documents[5].content)

def test_respect_sentence_boundary_checks(self):
# this combination triggers the warning
splitter = DocumentSplitter(split_by="sentence", split_length=10, respect_sentence_boundary=True)
assert splitter.respect_sentence_boundary == False

def test_nltk_sentence_serialization(self):
"""Test serialization with NLTK sentence splitting configuration and using non-default values"""
splitter = DocumentSplitter(
split_by="nltk_sentence",
language="de",
use_split_rules=False,
extend_abbreviations=False,
respect_sentence_boundary=False,
)
serialized = splitter.to_dict()
deserialized = DocumentSplitter.from_dict(serialized)

assert deserialized.split_by == "nltk_sentence"
assert hasattr(deserialized, "sentence_splitter")
assert deserialized.language == "de"
assert deserialized.use_split_rules == False
assert deserialized.extend_abbreviations == False
assert deserialized.respect_sentence_boundary == False

def test_nltk_serialization_roundtrip(self):
"""Test complete serialization roundtrip with actual document splitting"""
splitter = DocumentSplitter(
split_by="nltk_sentence",
language="de",
use_split_rules=False,
extend_abbreviations=False,
respect_sentence_boundary=False,
)
serialized = splitter.to_dict()
deserialized_splitter = DocumentSplitter.from_dict(serialized)
assert splitter.split_by == deserialized_splitter.split_by

def test_respect_sentence_boundary_serialization(self):
"""Test serialization with respect_sentence_boundary option"""
splitter = DocumentSplitter(split_by="word", respect_sentence_boundary=True, language="de")
serialized = splitter.to_dict()
deserialized = DocumentSplitter.from_dict(serialized)

assert deserialized.respect_sentence_boundary == True
assert hasattr(deserialized, "sentence_splitter")
assert deserialized.language == "de"
