make punctuation chars customizable
bminixhofer committed Feb 5, 2021
1 parent 3edc024 commit 3fe510a
Showing 3 changed files with 12 additions and 8 deletions.

train/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -58,7 +58,7 @@ def store_code(run):
     labeler = Labeler(
         [
             SpacySentenceTokenizer(
-                hparams.spacy_model, lower_start_prob=0.7, remove_end_punct_prob=0.7
+                hparams.spacy_model, lower_start_prob=0.7, remove_end_punct_prob=0.7, punctuation=".?!"
             ),
             SpacyWordTokenizer(hparams.spacy_model),
             WhitespaceTokenizer(),

train/evaluate.py (8 changes: 5 additions & 3 deletions)
@@ -35,6 +35,7 @@ def __init__(
         dataset,
         remove_end_punct_prob,
         lower_start_prob,
+        punctuation,
         lengths=[2, 3, 4],
         seed=1234,
     ):
@@ -54,7 +55,7 @@ def __init__(
             sentence = dataset[i]

             if gen.random() < remove_end_punct_prob:
-                sentence = remove_last_punct(sentence)
+                sentence = remove_last_punct(sentence, punctuation)

             if gen.random() < lower_start_prob:
                 sentence = sentence[0].lower() + sentence[1:]
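
Evaluator now takes the punctuation set as an explicit constructor argument and passes it through to remove_last_punct when corrupting sentence endings. A minimal construction sketch; the three-sentence dataset is stand-in data, and whether so few sentences suffice depends on sampling internals this diff does not show:

    from train.evaluate import Evaluator

    dataset = ["Das ist ein Satz.", "Noch einer!", "Geht das?"]  # stand-in data
    evaluator = Evaluator(
        dataset,
        remove_end_punct_prob=0.5,
        lower_start_prob=0.5,
        punctuation=".?!",
    )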
@@ -134,7 +135,8 @@ def split(self, texts):
 @click.option("--subtitle_path", help="Path to the OPUS OpenSubtitles raw text.")
 @click.option("--spacy_model", help="Name of the spacy model to compare against.")
 @click.option("--nnsplit_path", help="Path to the .onnx NNSplit model to use.")
-def evaluate(subtitle_path, spacy_model, nnsplit_path):
+@click.option("--punctuation", help="Which characters to consider punctuation.", default=".?!")
+def evaluate(subtitle_path, spacy_model, nnsplit_path, punctuation):
     # nnsplit must be installed to evaluate
     from nnsplit import NNSplit

@@ -164,7 +166,7 @@ def evaluate(subtitle_path, spacy_model, nnsplit_path):

     for eval_name, (remove_punct_prob, lower_start_prob) in eval_setups.items():
         result[eval_name] = {}
-        evaluator = Evaluator(dataset, remove_punct_prob, lower_start_prob)
+        evaluator = Evaluator(dataset, remove_punct_prob, lower_start_prob, punctuation)

         for target_name, interface in targets.items():
             correct = evaluator.evaluate(interface.split)
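
The new --punctuation flag defaults to ".?!". A hedged sketch of exercising it through click's testing runner; the import path and file paths below are hypothetical:

    from click.testing import CliRunner

    from train.evaluate import evaluate  # hypothetical import path

    runner = CliRunner()
    result = runner.invoke(
        evaluate,
        [
            "--subtitle_path", "data/opensubtitles.de.txt",  # hypothetical path
            "--spacy_model", "de_core_news_sm",
            "--nnsplit_path", "models/de.onnx",              # hypothetical path
            "--punctuation", ".?!",
        ],
    )
    print(result.output)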

train/labeler.py (10 changes: 6 additions & 4 deletions)
@@ -68,9 +68,9 @@ def tokenize(self, text: str) -> List[str]:
         pass


-def remove_last_punct(text: str) -> str:
+def remove_last_punct(text: str, punctuation) -> str:
     for i in range(len(text))[::-1]:
-        if text[i] in string.punctuation:
+        if text[i] in punctuation:
             return text[:i] + text[i + 1 :]
         elif not text[i].isspace():
             return text
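
remove_last_punct now consults only the caller-supplied characters instead of string.punctuation, so a trailing quote or comma is no longer stripped. A small behavior sketch, assuming the function is importable from train.labeler:

    from train.labeler import remove_last_punct

    print(remove_last_punct("Wie geht's?", ".?!"))    # -> "Wie geht's"
    print(remove_last_punct("Fertig!  ", ".?!"))      # -> "Fertig  " (trailing whitespace is skipped, not removed)
    print(remove_last_punct("kein Satzende", ".?!"))  # -> unchanged
    print(remove_last_punct("er sagte 'hi'", ".?!"))  # -> unchanged; "'" is no longer treated as punctuation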
@@ -84,13 +84,15 @@ def __init__(
         model_name: str,
         lower_start_prob: Fraction,
         remove_end_punct_prob: Fraction,
+        punctuation: str,
     ):
         super().__init__()
         self.nlp = get_model(model_name)
         self.nlp.add_pipe("sentencizer")

         self.lower_start_prob = lower_start_prob
         self.remove_end_punct_prob = remove_end_punct_prob
+        self.punctuation = punctuation

     def tokenize(self, text: str) -> List[str]:
         out_sentences = []
@@ -106,7 +108,7 @@ def tokenize(self, text: str) -> List[str]:

             if end_sentence and not text.isspace():
                 if self.training and random.random() < self.remove_end_punct_prob:
-                    current_sentence = remove_last_punct(current_sentence)
+                    current_sentence = remove_last_punct(current_sentence, self.punctuation)

                 out_sentences.append(current_sentence)

@@ -332,7 +334,7 @@ def visualize(self, text):
     labeler = Labeler(
         [
             SpacySentenceTokenizer(
-                "de_core_news_sm", lower_start_prob=0.7, remove_end_punct_prob=0.7
+                "de_core_news_sm", lower_start_prob=0.7, remove_end_punct_prob=0.7, punctuation=".?!"
            ),
             SpacyWordTokenizer("de_core_news_sm"),
             WhitespaceTokenizer(),
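
Since the punctuation set is now a constructor argument, it can be adapted per corpus or language. A sketch widening it with the ellipsis character; same hypothetical import path as above, and de_core_news_sm must be installed:

    from train.labeler import (
        Labeler,
        SpacySentenceTokenizer,
        SpacyWordTokenizer,
        WhitespaceTokenizer,
    )

    labeler = Labeler(
        [
            SpacySentenceTokenizer(
                "de_core_news_sm",
                lower_start_prob=0.7,
                remove_end_punct_prob=0.7,
                punctuation=".?!…",  # assumption: also treat the ellipsis as sentence-final
            ),
            SpacyWordTokenizer("de_core_news_sm"),
            WhitespaceTokenizer(),
        ]
    )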
