
add basic Case Corruption
markus583 committed Jan 10, 2024
1 parent c931647 commit 9eaf51a
Showing 3 changed files with 168 additions and 2 deletions.
45 changes: 45 additions & 0 deletions configs/xlmr_stratify_0.1_3layers_p_v2_0.5Case.json
@@ -0,0 +1,45 @@
{
"model_name_or_path": "xlm-roberta-base",
"output_dir": "xlmr-normal-p-v2-Case0.5",
"train_text_path": "data/sentence/train.parquet",
"valid_text_path": "data/sentence/valid.parquet",
"block_size": 512,
"use_bert": true,
"do_train": true,
"do_eval": true,
"evaluation_strategy": "steps",
"per_device_train_batch_size": 32,
"per_device_eval_batch_size": 32,
"gradient_accumulation_steps": 2,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 4,
"preprocessing_num_workers": 32,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
"max_steps": 200000,
"save_steps": 100000,
"eval_steps": 5000,
"logging_steps": 50,
"report_to": "wandb",
"is_decoder": false,
"remove_unused_columns": false,
"lookahead": null,
"one_sample_per_line": false,
"do_sentence_training": true,
"do_auxiliary_training": true,
"aux_training_weight": 1.0,
"case_corruption_prob": 0.5,
"warmup_steps": 5000,
"adapter_warmup_steps": 0,
"adapter_lr_multiplier": 1,
"ngram_order": 1,
"non_punctuation_sample_ratio": 0.1,
"prediction_loss_only": true,
"use_auxiliary": true,
"ddp_timeout": 3600,
"use_subwords": true,
"num_hidden_layers": 3,
"custom_punctuation_file": "punctuation_xlmr_unk.txt",
"log_level": "warning"
}
45 changes: 45 additions & 0 deletions configs/xlmr_stratify_0.1_3layers_p_v2_0.9Case.json
@@ -0,0 +1,45 @@
{
"model_name_or_path": "xlm-roberta-base",
"output_dir": "xlmr-normal-p-v2-Case0.9",
"train_text_path": "data/sentence/train.parquet",
"valid_text_path": "data/sentence/valid.parquet",
"block_size": 512,
"use_bert": true,
"do_train": true,
"do_eval": true,
"evaluation_strategy": "steps",
"per_device_train_batch_size": 32,
"per_device_eval_batch_size": 32,
"gradient_accumulation_steps": 2,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 4,
"preprocessing_num_workers": 32,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
"max_steps": 200000,
"save_steps": 100000,
"eval_steps": 5000,
"logging_steps": 50,
"report_to": "wandb",
"is_decoder": false,
"remove_unused_columns": false,
"lookahead": null,
"one_sample_per_line": false,
"do_sentence_training": true,
"do_auxiliary_training": true,
"aux_training_weight": 1.0,
"case_corruption_prob": 0.9,
"warmup_steps": 5000,
"adapter_warmup_steps": 0,
"adapter_lr_multiplier": 1,
"ngram_order": 1,
"non_punctuation_sample_ratio": 0.1,
"prediction_loss_only": true,
"use_auxiliary": true,
"ddp_timeout": 3600,
"use_subwords": true,
"num_hidden_layers": 3,
"custom_punctuation_file": "punctuation_xlmr_unk.txt",
"log_level": "warning"
}
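
The two new configs are identical except for output_dir and case_corruption_prob (0.5 vs. 0.9). As a minimal sketch of how the new probability could be forwarded to the LabelArgs dataclass extended below in wtpsplit/utils.py (the training entrypoint is not part of this commit, so this wiring is an assumption, and the remaining LabelArgs fields are assumed to have defaults):

import json

from wtpsplit.utils import LabelArgs  # dataclass extended in this commit

# Read one of the new configs and pass its corruption probability on to LabelArgs.
# Only the fields visible in this diff are relied on; everything else in LabelArgs
# is assumed to carry a default value.
with open("configs/xlmr_stratify_0.1_3layers_p_v2_0.5Case.json") as f:
    config = json.load(f)

label_args = LabelArgs(case_corruption_prob=config["case_corruption_prob"])
print(label_args.case_corruption_prob)  # 0.5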
80 changes: 78 additions & 2 deletions wtpsplit/utils.py
@@ -78,6 +78,7 @@ class LabelArgs:
custom_punctuation_file: str = None
retain_first_consecutive_punctuation: bool = True
non_whitespace_remove_spaces: bool = True
case_corruption_prob: float = 0.5

def __post_init__(self):
if self.custom_punctuation_file:
@@ -219,7 +220,12 @@ def corrupt(
labels.insert(last_index_in_block, 0)
else:
del block_ids[i + 1]
if tokenizer and separator == "" and label_args.non_whitespace_remove_spaces and i + 1 < len(input_ids):
if (
tokenizer
and separator == ""
and label_args.non_whitespace_remove_spaces
and i + 1 < len(input_ids)
):
# tokenizer.decode() retains the space that leaks the information
# so we need to get the position within the tokenized text and then remove the space
# (so there is no more space when fed into the tokenizer call)
@@ -251,12 +257,47 @@ def corrupt(
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
if random.random() < label_args.case_corruption_prob and i + 1 < len(input_ids):
if not tokenizer:
raise NotImplementedError()
# corrupt case
token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
insert_ = False
if token.startswith("▁"):
insert_ = True
token = token[1:]
if token.istitle():
token = token.lower()
# re-tokenize
# token_ids = tokenizer.convert_tokens_to_ids(token if not insert_ else "▁" + token)
token_ids = tokenizer(token if not insert_ else "▁" + token, add_special_tokens=False)["input_ids"]
if len(token_ids) == 0 or input_ids[i + 1] == tokenizer.unk_token_id:
# UNK or whitespace token, remove it
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
else:
if token_ids[0] == tokenizer.convert_tokens_to_ids("▁"):
token_ids = token_ids[1:]
elif len(token_ids) > 1:
# replace the token with the remaining token
input_ids[i + 1] = token_ids[0]
for token_id in token_ids[1:]:
input_ids.insert(i + 2, token_id)
labels.insert(i + 2, 0)
block_ids.insert(i + 2, block_ids[i + 1])
elif len(token_ids) == 1:
input_ids[i + 1] = token_ids[0]
else:
print(token, token_ids, input_ids[i + 1], tokenizer.decode(input_ids[i + 1]))


elif label_args.use_auxiliary and labels[i] > Constants.AUX_OFFSET: # auxiliary
if pack_samples:
raise NotImplementedError()

if random.random() < label_args.auxiliary_remove_prob:
removed_aux_char = False
if label_args.retain_first_consecutive_punctuation:
# remove only if the next token is not a newline
# this retains the current auxiliary character, even though we decided to remove it
@@ -265,12 +306,47 @@ def corrupt(
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
removed_aux_char = True
else:
# in case of something like ".\n", this removes the "." and the \n label (=1)
# so the newline in the text is kept, but the label is removed!
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
removed_aux_char = True
if random.random() < label_args.case_corruption_prob and removed_aux_char and i + 1 < len(input_ids):
if not tokenizer:
raise NotImplementedError()
# corrupt case
token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
insert_ = False
if token.startswith("▁"):
insert_ = True
token = token[1:]
if token.istitle():
token = token.lower()
# re-tokenize
# token_ids = tokenizer.convert_tokens_to_ids(token if not insert_ else "▁" + token)
token_ids = tokenizer(token if not insert_ else "▁" + token, add_special_tokens=False)["input_ids"]
if len(token_ids) == 0 or input_ids[i + 1] == tokenizer.unk_token_id:
# UNK or whitespace token, remove it
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
else:
if token_ids[0] == tokenizer.convert_tokens_to_ids("▁"):
token_ids = token_ids[1:]
elif len(token_ids) > 1:
# replace the token with the remaining token
input_ids[i + 1] = token_ids[0]
for token_id in token_ids[1:]:
input_ids.insert(i + 2, token_id)
labels.insert(i + 2, 0)
block_ids.insert(i + 2, block_ids[i + 1])
elif len(token_ids) == 1:
input_ids[i + 1] = token_ids[0]
else:
print(token, token_ids, input_ids[i + 1], tokenizer.decode(input_ids[i + 1]))

try:
i = i + 1 + next(index for index, label in enumerate(labels[i + 1 :]) if label != 0)
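
For reference, here is a standalone sketch of the case-corruption step added in both branches above: take the title-cased SentencePiece token that follows a removed boundary, lowercase it, re-tokenize it, and splice the resulting ids back into input_ids, labels, and block_ids. It uses the xlm-roberta-base tokenizer named in the configs but is not a call into corrupt(); the example sentence, the index arithmetic, the ordered insertion of extra pieces, and the collapsing of the if/elif chain into a single splice are illustrative simplifications of the committed logic.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# Toy sequence: pretend the "." before "The" (and its label) were just removed,
# so "▁The" now sits mid-sequence and should no longer look sentence-initial.
input_ids = tokenizer("He left. The dog barked.", add_special_tokens=False)["input_ids"]
labels = [0] * len(input_ids)
block_ids = [0] * len(input_ids)

tokens = tokenizer.convert_ids_to_tokens(input_ids)
i = tokens.index("▁The") - 1  # assumes "▁The" is a single piece in the XLM-R vocab

token = tokens[i + 1]
insert_ = token.startswith("▁")
if insert_:
    token = token[1:]
if token.istitle():
    token = token.lower()  # "The" -> "the"

# Re-tokenize the lowercased piece; it may split into several subwords.
token_ids = tokenizer("▁" + token if insert_ else token, add_special_tokens=False)["input_ids"]
if token_ids and token_ids[0] == tokenizer.convert_tokens_to_ids("▁"):
    token_ids = token_ids[1:]  # drop a stray lone "▁" piece produced by the literal marker

if not token_ids or input_ids[i + 1] == tokenizer.unk_token_id:
    # UNK or whitespace-only token: drop it, as in the diff
    del input_ids[i + 1], labels[i + 1], block_ids[i + 1]
else:
    input_ids[i + 1] = token_ids[0]
    for offset, token_id in enumerate(token_ids[1:], start=2):
        input_ids.insert(i + offset, token_id)
        labels.insert(i + offset, 0)
        block_ids.insert(i + offset, block_ids[i + 1])

print(tokenizer.convert_ids_to_tokens(input_ids))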