Commit
add logging, no unk aux chars
markus583 committed Dec 27, 2023
1 parent d55d547 commit 469dd13
Showing 6 changed files with 303 additions and 51 deletions.
42 changes: 42 additions & 0 deletions configs/xlmr_stratify_0.1_3layers_nounks.json
@@ -0,0 +1,42 @@
{
"model_name_or_path": "xlm-roberta-base",
"output_dir": "xlmr-normal-no_unks",
"train_text_path": "data/sentence/train.parquet",
"valid_text_path": "data/sentence/valid.parquet",
"block_size": 512,
"use_bert": true,
"do_train": true,
"do_eval": true,
"evaluation_strategy": "steps",
"per_device_train_batch_size": 32,
"per_device_eval_batch_size": 32,
"gradient_accumulation_steps": 2,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 4,
"preprocessing_num_workers": 32,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
"max_steps": 2000000,
"save_steps": 100000,
"eval_steps": 5000,
"logging_steps": 50,
"report_to": "wandb",
"is_decoder": false,
"remove_unused_columns": false,
"lookahead": null,
"one_sample_per_line": false,
"do_sentence_training": true,
"do_auxiliary_training": true,
"warmup_steps": 5000,
"adapter_warmup_steps": 0,
"adapter_lr_multiplier": 1,
"ngram_order": 1,
"non_punctuation_sample_ratio": 0.1,
"prediction_loss_only": true,
"use_auxiliary": true,
"ddp_timeout": 3600,
"use_subwords": true,
"num_hidden_layers": 3,
"custom_punctuation_file": "punctuation_xlmr.txt"
}
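
For context, a config like this is typically consumed through Hugging Face's HfArgumentParser. Below is a minimal sketch of loading the file; the Args dataclass is a hypothetical subset for illustration, not the project's actual argument class:

from dataclasses import dataclass
from typing import Optional

from transformers import HfArgumentParser, TrainingArguments

@dataclass
class Args:
    # hypothetical subset of the training script's argument dataclass
    model_name_or_path: str = "xlm-roberta-base"
    use_subwords: bool = False
    num_hidden_layers: int = 12
    custom_punctuation_file: Optional[str] = None

parser = HfArgumentParser([Args, TrainingArguments])
# allow_extra_keys tolerates config keys this sketch does not model
args, training_args = parser.parse_json_file(
    "configs/xlmr_stratify_0.1_3layers_nounks.json", allow_extra_keys=True
)
print(args.custom_punctuation_file)  # punctuation_xlmr.txt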
52 changes: 52 additions & 0 deletions utils/remove_unks.py
@@ -0,0 +1,52 @@
import os
from transformers import AutoTokenizer
from tokenizers import AddedToken
from wtpsplit.utils import Constants, LabelArgs

def get_subword_label_dict(label_args, tokenizer):
    # map token IDs of auxiliary and newline characters to their training labels
    label_dict = {}

    n_unks = 0
    # Map auxiliary characters to token IDs with labels
    for i, c in enumerate(label_args.auxiliary_chars):
        token_id = tokenizer.convert_tokens_to_ids(c)
        label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
        # TODO: remove UNKs?
        print(
            f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, "
            f"decoded: {tokenizer.decode([token_id])}"
        )
        if token_id == tokenizer.unk_token_id:
            n_unks += 1

    print(f"found {n_unks} UNK tokens in auxiliary characters")

    # Map newline characters to token IDs with labels
    for c in label_args.newline_chars:
        token_id = tokenizer.convert_tokens_to_ids(c)
        label_dict[token_id] = 1 + Constants.NEWLINE_INDEX
        print(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
        print(r"{}".format(tokenizer.decode([token_id])))

    return label_dict


tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
# register "\n" as a special token so newlines survive tokenization intact
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})

label_dict = get_subword_label_dict(LabelArgs(custom_punctuation_file='punctuation_xlmr.txt'), tokenizer)
print(len(label_dict))

def write_punctuation_file():
    # keep only the punctuation characters that exist in the tokenizer
    # vocabulary, i.e. drop every character that would map to <unk>
    with open(os.path.join(Constants.ROOT_DIR, "punctuation_xlmr.txt"), 'w', encoding='utf-8') as file:
        for char in Constants.PUNCTUATION_CHARS:
            token_id = tokenizer.convert_tokens_to_ids(char)
            if token_id != tokenizer.unk_token_id:
                file.write(char + '\n')

write_punctuation_file()

# compare the punctuation set before and after applying the custom file
label_args_default = LabelArgs()
print(Constants.PUNCTUATION_CHARS, len(Constants.PUNCTUATION_CHARS))

label_args_custom = LabelArgs(custom_punctuation_file='punctuation_xlmr.txt')
print(Constants.PUNCTUATION_CHARS, len(Constants.PUNCTUATION_CHARS))
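
Why the UNK filtering matters: every character missing from the tokenizer vocabulary maps to the same unk_token_id, so the auxiliary labels of all such characters would collide on a single label_dict key (and attach an auxiliary label to <unk> itself). A minimal illustration, with arbitrary character choices:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("xlm-roberta-base")
ids = {}
for c in ["!", "¤", "՞", "\u2060"]:
    token_id = tok.convert_tokens_to_ids(c)
    # every out-of-vocabulary character shares tok.unk_token_id, so a
    # dict keyed on token_id silently keeps only the last label written
    ids.setdefault(token_id, []).append(c)
print(tok.unk_token_id, ids)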
4 changes: 3 additions & 1 deletion wtpsplit/extract.py
@@ -1,5 +1,6 @@
import math
import sys
+import logging

import numpy as np
from tqdm.auto import tqdm
@@ -8,6 +9,7 @@

from wtpsplit.utils import Constants, hash_encode

+logger = logging.getLogger(__name__)

class ORTWrapper:
    def __init__(self, config, ort_session):
@@ -222,7 +224,7 @@ def extract(
)["logits"]
if use_subwords:
logits = logits[:, 1:-1, :] # remove CLS and SEP tokens
print(np.max(logits[0, :, 0]))
logger.debug(np.max(logits[0, :, 0]))

for i in range(start, end):
original_idx, start_char_idx, end_char_idx = locs[i]
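
Since the chunk-level max logit is now emitted via logger.debug rather than print, it is silent by default; a caller can opt back in with the standard logging module (a sketch, not part of this commit):

import logging

# route DEBUG records from wtpsplit.extract to stderr
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("wtpsplit.extract").setLevel(logging.DEBUG)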
108 changes: 108 additions & 0 deletions wtpsplit/punctuation_xlmr.txt
@@ -0,0 +1,108 @@
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
^
_
`
{
|
}
~
¡
£
¤
§
©
«
¬
®
°
±
·
»
¿
÷
՛
՝
՞
։
־
׳
،
؛
؟
۔
💘
