Skip to content

Commit

Permalink
feat: PreProcessor split by token (tiktoken & Hugging Face) (#5276)
Browse files Browse the repository at this point in the history
* #4983 implemented split by token for tiktoken tokenizer

* #4983 added unit test for tiktoken splitting

* #4983 implemented and added a test for splitting documents with HuggingFace tokenizer

* #4983 added support for passing HF model names (instead of objects) and added an example to the HF token splitting test

* mocked HTTP model loading in unit tests, fixed pylint error

* fix lossy tokenizers splitting, use LazyImport, ignore UnicodeEncodeError for tiktoken

* reno

* rename reno file

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
Co-authored-by: ZanSara <[email protected]>
  • Loading branch information
3 people authored Nov 23, 2023
1 parent e04a1f1 commit a492771
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 32 deletions.
14 changes: 10 additions & 4 deletions haystack/nodes/preprocessor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from abc import abstractmethod

from transformers import PreTrainedTokenizerBase

from haystack.nodes.base import BaseComponent
from haystack.schema import Document

Expand All @@ -17,10 +19,11 @@ def process(
clean_header_footer: Optional[bool] = False,
clean_empty_lines: Optional[bool] = True,
remove_substrings: Optional[List[str]] = None,
split_by: Literal["word", "sentence", "passage", None] = "word",
split_by: Literal["token", "word", "sentence", "passage", None] = "word",
split_length: Optional[int] = 1000,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = True,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
"""
Expand All @@ -44,10 +47,11 @@ def clean(
def split(
self,
document: Union[dict, Document],
split_by: Literal["word", "sentence", "passage", None],
split_by: Literal["token", "word", "sentence", "passage", None],
split_length: int,
split_overlap: int,
split_respect_sentence_boundary: bool,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
) -> List[Document]:
raise NotImplementedError

Expand All @@ -57,10 +61,11 @@ def run( # type: ignore
clean_whitespace: Optional[bool] = None,
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_by: Literal["token", "word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
):
processed_documents = self.process(
Expand All @@ -83,10 +88,11 @@ def run_batch( # type: ignore
clean_whitespace: Optional[bool] = None,
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_by: Literal["token", "word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
):
return self.run(
Expand Down
112 changes: 84 additions & 28 deletions haystack/nodes/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal
from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal, Callable, Any

import logging
import re
Expand All @@ -17,14 +17,18 @@
from haystack.schema import Document
from haystack.lazy_imports import LazyImport

with LazyImport("Run 'pip install transformers'") as transformers_import:
from transformers import PreTrainedTokenizerBase
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)
with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
import tiktoken

logger = logging.getLogger(__name__)

with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import:
import nltk


iso639_to_nltk = {
"ru": "russian",
"sl": "slovene",
Expand Down Expand Up @@ -55,11 +59,12 @@ def __init__(
clean_header_footer: bool = False,
clean_empty_lines: bool = True,
remove_substrings: Optional[List[str]] = None,
split_by: Optional[Literal["word", "sentence", "passage"]] = "word",
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = "word",
split_length: int = 200,
split_overlap: int = 0,
split_respect_sentence_boundary: bool = True,
tokenizer_model_folder: Optional[Union[str, Path]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
language: str = "en",
id_hash_keys: Optional[List[str]] = None,
progress_bar: bool = True,
Expand All @@ -86,6 +91,9 @@ def __init__(
:param split_respect_sentence_boundary: Whether to split in partial sentences if split_by -> `word`. If set
to True, the individual split will always have complete sentences &
the number of words will be <= split_length.
:param tokenizer: Specifies the tokenizer to use if split_by="token". Supported options are "tiktoken"
(for OpenAI's GPT-3.5 and GPT-4) and any HuggingFace tokenizer (e.g. 'bert-base-uncased').
HuggingFace tokenizers can also be passed directly as an PreTrainedTokenizerBase object.
:param language: The language used by "nltk.tokenize.sent_tokenize" in iso639 format.
Available options: "ru","sl","es","sv","tr","cs","da","nl","en","et","fi","fr","de","el","it","no","pl","pt","ml"
:param tokenizer_model_folder: Path to the folder containing the NTLK PunktSentenceTokenizer models, if loading a model from a local path. Leave empty otherwise.
Expand Down Expand Up @@ -124,6 +132,7 @@ def __init__(
self.split_length = split_length
self.split_overlap = split_overlap
self.split_respect_sentence_boundary = split_respect_sentence_boundary
self.tokenizer = tokenizer
self.language = language
self.tokenizer_model_folder = tokenizer_model_folder
self.print_log: Set[str] = set()
Expand All @@ -139,10 +148,11 @@ def process(
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
remove_substrings: Optional[List[str]] = None,
split_by: Optional[Literal["word", "sentence", "passage"]] = None,
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
"""
Expand All @@ -167,6 +177,7 @@ def process(
"split_length": split_length,
"split_overlap": split_overlap,
"split_respect_sentence_boundary": split_respect_sentence_boundary,
"tokenizer": tokenizer,
}

if id_hash_keys is None:
Expand Down Expand Up @@ -219,10 +230,11 @@ def _process_single(
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
remove_substrings: Optional[List[str]] = None,
split_by: Optional[Literal["word", "sentence", "passage"]] = None,
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
if remove_substrings is None:
Expand All @@ -243,6 +255,8 @@ def _process_single(
split_overlap = self.split_overlap
if split_respect_sentence_boundary is None:
split_respect_sentence_boundary = self.split_respect_sentence_boundary
if tokenizer is None:
tokenizer = self.tokenizer

cleaned_document = self.clean(
document=document,
Expand All @@ -258,6 +272,7 @@ def _process_single(
split_length=split_length,
split_overlap=split_overlap,
split_respect_sentence_boundary=split_respect_sentence_boundary,
tokenizer=tokenizer,
id_hash_keys=id_hash_keys,
)

Expand Down Expand Up @@ -332,10 +347,11 @@ def clean(
def split(
self,
document: Union[dict, Document],
split_by: Optional[Literal["word", "sentence", "passage"]],
split_by: Optional[Literal["token", "word", "sentence", "passage"]],
split_length: int,
split_overlap: int,
split_respect_sentence_boundary: bool,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
"""Perform document splitting on a single document. This method can split on different units, at different lengths,
Expand All @@ -359,8 +375,10 @@ def split(
if not split_length:
raise Exception("split_length needs be set when using split_by.")

if split_respect_sentence_boundary and split_by != "word":
raise NotImplementedError("'split_respect_sentence_boundary=True' is only compatible with split_by='word'.")
if split_respect_sentence_boundary and split_by not in ["word", "token"]:
raise NotImplementedError(
"'split_respect_sentence_boundary=True' is only compatible with split_by='word' or 'token'."
)

if type(document.content) is not str:
logger.error("Document content is not of type str. Nothing to split.")
Expand All @@ -369,13 +387,17 @@ def split(
text = document.content
headlines = document.meta["headlines"] if "headlines" in document.meta else []

if split_respect_sentence_boundary and split_by == "word":
text_splits, splits_pages, splits_start_idxs = self._split_by_word_respecting_sent_boundary(
text=text, split_length=split_length, split_overlap=split_overlap
if split_respect_sentence_boundary and split_by in ["word", "token"]:

def split_function(text):
return self._split_tokens(text, tokenizer=tokenizer) if split_by == "token" else text.split()

text_splits, splits_pages, splits_start_idxs = self._split_into_units_respecting_sent_boundary(
text=text, split_length=split_length, split_overlap=split_overlap, split_function=split_function
)
else:
# create individual "elements" of passage, sentence, or word
elements, split_at = self._split_into_units(text=text, split_by=split_by)
elements, split_at = self._split_into_units(text=text, split_by=split_by, tokenizer=tokenizer)

# concatenate individual elements based on split_length & split_stride
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
Expand Down Expand Up @@ -467,47 +489,47 @@ def _remove_substring(text: str, substring: str, headlines: List[Dict]) -> Tuple
cleaned_text = text.replace(substring, "")
return cleaned_text, headlines

def _split_by_word_respecting_sent_boundary(
self, text: str, split_length: int, split_overlap: int
def _split_into_units_respecting_sent_boundary(
self, text: str, split_length: int, split_overlap: int, split_function: Callable
) -> Tuple[List[str], List[int], List[int]]:
"""
Splits the text into parts of split_length words while respecting sentence boundaries.
"""
sentences = self._split_sentences(text)

word_count_slice = 0
unit_count_slice = 0
cur_page = 1
cur_start_idx = 0
splits_pages = []
list_splits = []
splits_start_idxs = []
current_slice: List[str] = []
for sen in sentences:
word_count_sen = len(sen.split())
unit_count_sen = len(split_function(sen))

if word_count_sen > split_length:
if unit_count_sen > split_length:
long_sentence_message = (
"We found one or more sentences whose word count is higher than the split length."
"We found one or more sentences whose split count is higher than the split length."
)
if long_sentence_message not in self.print_log:
self.print_log.add(long_sentence_message)
logger.warning(long_sentence_message)

if word_count_slice + word_count_sen > split_length:
if unit_count_slice + unit_count_sen > split_length:
# Number of words exceeds split_length -> save current slice and start a new one
if current_slice:
list_splits.append(current_slice)
splits_pages.append(cur_page)
splits_start_idxs.append(cur_start_idx)

if split_overlap:
processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
current_slice, split_length, split_overlap
processed_sents, current_slice, unit_count_slice = self._get_overlap_from_slice(
current_slice, split_length, split_overlap, split_function
)
else:
processed_sents = current_slice
current_slice = []
word_count_slice = 0
unit_count_slice = 0

cur_start_idx += len("".join(processed_sents))

Expand All @@ -522,7 +544,7 @@ def _split_by_word_respecting_sent_boundary(
cur_page += num_page_breaks

current_slice.append(sen)
word_count_slice += word_count_sen
unit_count_slice += unit_count_sen

if current_slice:
list_splits.append(current_slice)
Expand All @@ -539,7 +561,7 @@ def _split_by_word_respecting_sent_boundary(

@staticmethod
def _get_overlap_from_slice(
current_slice: List[str], split_length: int, split_overlap: int
current_slice: List[str], split_length: int, split_overlap: int, split_function: Callable
) -> Tuple[List[str], List[str], int]:
"""
Returns a tuple with the following elements:
Expand All @@ -553,7 +575,7 @@ def _get_overlap_from_slice(
current_slice_copy = deepcopy(current_slice)
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
for idx, s in reversed(list(enumerate(current_slice))[1:]):
sen_len = len(s.split())
sen_len = len(split_function(s))
if word_count_overlap < split_overlap and sen_len < split_length:
overlap.append(s)
word_count_overlap += sen_len
Expand All @@ -566,7 +588,7 @@ def _get_overlap_from_slice(

return processed_sents, next_slice, word_count_slice

def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
def _split_into_units(self, text: str, split_by: str, tokenizer: Any) -> Tuple[List[str], str]:
if split_by == "passage":
elements = text.split("\n\n")
split_at = "\n\n"
Expand All @@ -576,8 +598,13 @@ def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
elif split_by == "word":
elements = text.split(" ")
split_at = " "
elif split_by == "token":
elements = self._split_tokens(text, tokenizer)
split_at = ""
else:
raise NotImplementedError("PreProcessor only supports 'passage', 'sentence' or 'word' split_by options.")
raise NotImplementedError(
"PreProcessor only supports 'passage', 'sentence', 'word' or 'token' split_by options."
)

return elements, split_at

Expand Down Expand Up @@ -823,6 +850,35 @@ def _split_sentences(self, text: str) -> List[str]:
sentences = sentence_tokenizer.tokenize(text)
return sentences

def _split_tokens(self, text: str, tokenizer: Any) -> List[str]:
if tokenizer == "tiktoken":
tiktoken_import.check()
enc = tiktoken.get_encoding("cl100k_base") # tiktoken is reversible and lossless
integer_tokens = enc.encode(text, disallowed_special=())
elements = [enc.decode_single_token_bytes(token).decode(errors="ignore") for token in integer_tokens]
return elements
if isinstance(tokenizer, str):
transformers_import.check()
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
except Exception:
raise ValueError(
f"Could not load tokenizer '{tokenizer}' from HuggingFace model hub. "
f"Please make sure that the tokenizer is correct and exists."
)
if isinstance(tokenizer, PreTrainedTokenizerBase):
encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
elements = []
for i in range(l := len(encoded.offset_mapping)):
start_current = encoded.offset_mapping[i][0]
start_next = encoded.offset_mapping[i + 1][0] if i < l - 1 else len(text)
elements.append(text[start_current:start_next])
return elements
raise ValueError(
f"Unsupported tokenizer specification {tokenizer}. "
f"Please provide either the string 'tiktoken' or a HuggingFace tokenizer (PreTrainedTokenizerBase)."
)

def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer":
# Try to load a custom model from 'tokenizer_model_path'
if self.tokenizer_model_folder is not None:
Expand Down
2 changes: 2 additions & 0 deletions releasenotes/notes/split-by-token-b9a4f954d4077ecc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
features:
- Add `split_length` by token in PreProcessor
Loading

0 comments on commit a492771

Please sign in to comment.