Commit

fix empty string behaviour

markus583 committed Dec 14, 2024
1 parent 16f1a2c commit f09190d
Showing 3 changed files with 68 additions and 34 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name="wtpsplit",
-    version="2.1.1",
+    version="2.1.2",
     packages=find_packages(),
     description="Universal Robust, Efficient and Adaptable Sentence Segmentation",
     author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer",
15 changes: 14 additions & 1 deletion test.py
@@ -103,6 +103,19 @@ def test_split_paragraphs():
 
     assert paragraph1.startswith("Text segmentation is")
     assert paragraph2.startswith("Daniel Wroughton Craig CMG (born 2 March 1968) is")
+
+def test_split_empty_strings():
+    sat = SaT("segment-any-text/sat-3l", hub_prefix=None)
+
+    text = " "
+    splits = sat.split(text)
+    assert splits == [" "]
+    text = " \n"
+    splits = sat.split(text)
+    assert splits == [" ", ""]
+    text = ""
+    splits = sat.split(text)
+    assert splits == []
 
 
 def test_split_ort_wtp():
@@ -229,4 +242,4 @@ def test_split_threshold_wtp():
 
     splits = wtp.split("This is a test sentence. This is another test sentence.", threshold=-1e-3)
     # space might still be included in a character split
-    assert splits[:3] == list("Thi")
+    assert splits[:3] == list("Thi")
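The new test pins down the intended behaviour. As a quick reference, the same calls outside the test suite look like this — a minimal sketch, assuming the segment-any-text/sat-3l checkpoint can be downloaded or is already cached, loaded exactly as in the test above:

from wtpsplit import SaT

sat = SaT("segment-any-text/sat-3l", hub_prefix=None)

# whitespace-only input is returned unchanged as a single segment
print(sat.split(" "))    # -> [' ']
# a trailing newline still yields an (empty) trailing segment
print(sat.split(" \n"))  # -> [' ', '']
# a fully empty string now yields no segments at all
print(sat.split(""))     # -> []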
85 changes: 53 additions & 32 deletions wtpsplit/__init__.py
@@ -18,7 +18,7 @@
 from wtpsplit.extract import BertCharORTWrapper, SaTORTWrapper, PyTorchWrapper, extract
 from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs
 
-__version__ = "2.1.1"
+__version__ = "2.1.2"
 
 warnings.simplefilter("default", DeprecationWarning)  # show by default
 warnings.simplefilter("ignore", category=FutureWarning)  # for tranformers
@@ -233,20 +233,31 @@ def _predict_proba(
 
                 input_texts.append(input_text)
 
-            outer_batch_logits = extract(
-                input_texts,
-                self.model,
-                lang_code=lang_code,
-                stride=stride,
-                max_block_size=block_size,
-                batch_size=batch_size,
-                pad_last_batch=pad_last_batch,
-                verbose=verbose,
-            )[0]
+            empty_string_indices = [i for i, text in enumerate(input_texts) if not text.strip()]
+            # remove empty strings from input_texts
+            input_texts = [text for text in input_texts if text.strip()]
+
+            if input_texts:
+                outer_batch_logits = extract(
+                    input_texts,
+                    self.model,
+                    lang_code=lang_code,
+                    stride=stride,
+                    max_block_size=block_size,
+                    batch_size=batch_size,
+                    pad_last_batch=pad_last_batch,
+                    verbose=verbose,
+                )[0]
+            else:
+                outer_batch_logits = []
 
             def newline_probability_fn(logits):
                 return sigmoid(logits[:, Constants.NEWLINE_INDEX])
 
+            # add back empty strings
+            for i in empty_string_indices:
+                outer_batch_logits.insert(i, np.ones([1, 1]) * -np.inf)
+
             for i, (text, logits) in enumerate(zip(outer_batch_texts, outer_batch_logits)):
                 if style is not None:
                     sentence_probs = clf.predict_proba(logits)[:, 1]
@@ -635,28 +646,38 @@ def newline_probability_fn(logits):
 
                 input_texts.append(input_text)
 
-            outer_batch_logits, _, tokenizer, tokenizer_output = extract(
-                input_texts,
-                self.model,
-                stride=stride,
-                max_block_size=block_size,
-                batch_size=batch_size,
-                pad_last_batch=pad_last_batch,
-                verbose=verbose,
-                tokenizer=self.tokenizer,
-            )
-
-            # convert token probabilities to character probabilities for the entire array
-            outer_batch_logits = [
-                token_to_char_probs(
-                    input_texts[i],
-                    tokenizer_output["input_ids"][i],
-                    outer_batch_logits[i],
-                    tokenizer,
-                    tokenizer_output["offset_mapping"][i],
+            empty_string_indices = [i for i, text in enumerate(input_texts) if not text.strip()]
+            # remove empty strings from input_texts
+            input_texts = [text for text in input_texts if text.strip()]
+            if input_texts:
+                outer_batch_logits, _, tokenizer, tokenizer_output = extract(
+                    input_texts,
+                    self.model,
+                    stride=stride,
+                    max_block_size=block_size,
+                    batch_size=batch_size,
+                    pad_last_batch=pad_last_batch,
+                    verbose=verbose,
+                    tokenizer=self.tokenizer,
                 )
-                for i in range(len(input_texts))
-            ]
+
+                # convert token probabilities to character probabilities for the entire array
+                outer_batch_logits = [
+                    token_to_char_probs(
+                        input_texts[i],
+                        tokenizer_output["input_ids"][i],
+                        outer_batch_logits[i],
+                        tokenizer,
+                        tokenizer_output["offset_mapping"][i],
+                    )
+                    for i in range(len(input_texts))
+                ]
+            else:
+                outer_batch_logits = []
 
+            # add back empty strings
+            for i in empty_string_indices:
+                outer_batch_logits.insert(i, np.ones([1, 1]) * -np.inf)
+
             for i, (text, logits) in enumerate(zip(outer_batch_texts, outer_batch_logits)):
                 sentence_probs = newline_probs = newline_probability_fn(logits)
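Both hunks apply the same guard: empty or whitespace-only inputs are filtered out before extract() is called, and a 1x1 matrix of -inf logits is spliced back in at each original position, so downstream code still iterates over one logits entry per input while sigmoid(-inf) = 0 ensures no split point is ever predicted inside a blank text. A minimal, self-contained sketch of that idea (run_with_empty_handling and extract_fn are illustrative names, not part of the wtpsplit API):

import numpy as np

def run_with_empty_handling(input_texts, extract_fn):
    # indices of inputs that are empty or whitespace-only
    empty_string_indices = [i for i, text in enumerate(input_texts) if not text.strip()]
    # run the model only on the non-empty inputs
    non_empty = [text for text in input_texts if text.strip()]
    logits = list(extract_fn(non_empty)) if non_empty else []
    # splice a -inf sentinel back in at each original position
    for i in empty_string_indices:
        logits.insert(i, np.ones([1, 1]) * -np.inf)
    return logits

Because the sentinel indices are inserted in ascending order, each one lands back at its original position even when several blank inputs appear in a row.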
