diff --git a/configs/xlmr_stratify_0.1_3layers.json b/configs/xlmr_stratify_0.1_3layers.json
index d8425261..3e5e4bed 100644
--- a/configs/xlmr_stratify_0.1_3layers.json
+++ b/configs/xlmr_stratify_0.1_3layers.json
@@ -37,5 +37,7 @@
     "use_auxiliary": true,
     "ddp_timeout": 3600,
     "use_subwords": true,
-    "num_hidden_layers": 3
+    "num_hidden_layers": 3,
+    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
+    "log_level": "info"
 }
\ No newline at end of file
diff --git a/configs/xlmr_stratify_0.1_3layers_nounks.json b/configs/xlmr_stratify_0.1_3layers_nounks.json
index 22f78160..f61c50b7 100644
--- a/configs/xlmr_stratify_0.1_3layers_nounks.json
+++ b/configs/xlmr_stratify_0.1_3layers_nounks.json
@@ -38,5 +38,6 @@
     "ddp_timeout": 3600,
     "use_subwords": true,
     "num_hidden_layers": 3,
-    "custom_punctuation_file": "punctuation_xlmr.txt"
+    "custom_punctuation_file": "punctuation_xlmr.txt",
+    "log_level": "info"
 }
\ No newline at end of file
diff --git a/utils/remove_unks.py b/utils/remove_unks.py
index bd8411e8..c7455e20 100644
--- a/utils/remove_unks.py
+++ b/utils/remove_unks.py
@@ -8,7 +8,7 @@ def get_subword_label_dict(label_args, tokenizer):
 
     n_unks = 0
     # Map auxiliary characters to token IDs with labels
-    for i, c in enumerate(label_args.auxiliary_chars):
+    for i, c in enumerate(Constants.PUNCTUATION_CHARS):
         token_id = tokenizer.convert_tokens_to_ids(c)
         label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
         # TODO: remove UNKs?
@@ -33,7 +33,7 @@ def get_subword_label_dict(label_args, tokenizer):
 tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
 tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
 
-label_dict = get_subword_label_dict(LabelArgs(custom_punctuation_file='punctuation_xlmr.txt'), tokenizer)
+label_dict = get_subword_label_dict(LabelArgs(), tokenizer)
 print(len(label_dict))
 
 def write_punctuation_file():
@@ -42,8 +42,21 @@ def write_punctuation_file():
             token_id = tokenizer.convert_tokens_to_ids(char)
             if token_id != tokenizer.unk_token_id:
                 file.write(char + '\n')
+
+def write_punctuation_file_unk():
+    added_unk = False
+    with open(os.path.join(Constants.ROOT_DIR, "punctuation_xlmr_unk.txt"), 'w', encoding='utf-8') as file:
+        for char in Constants.PUNCTUATION_CHARS:
+            token_id = tokenizer.convert_tokens_to_ids(char)
+            if token_id != tokenizer.unk_token_id:
+                file.write(char + '\n')
+            elif not added_unk:
+                print("added unk")
+                file.write('\n')
+                added_unk = True
 
 write_punctuation_file()
+write_punctuation_file_unk()
 
 label_args_default = LabelArgs()
 print(Constants.PUNCTUATION_CHARS, len(Constants.PUNCTUATION_CHARS))
diff --git a/wtpsplit/data/punctuation_xlmr_unk.txt b/wtpsplit/data/punctuation_xlmr_unk.txt
new file mode 100644
index 00000000..916398b3
--- /dev/null
+++ b/wtpsplit/data/punctuation_xlmr_unk.txt
@@ -0,0 +1,109 @@
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+`
+{
+|
+}
+~
+¡
+£
+¤
+§
+
+©
+«
+¬
+®
+°
+±
+·
+»
+¿
+÷
+՛
+՝
+՞
+։
+־
+׳
+،
+؛
+؟
+۔
+।
+॥
+၊
+။
+၌
+၍
+၎
+၏
+፡
+።
+፣
+፤
+፥
+។
+៕
+៖
+–
+—
+‘
+’
+‚
+“
+”
+„
+•
+′
+‹
+›
+€
+↑
+→
+⇌
+∑
+√
+╛
+□
+▬
+☎
+➖
+、
+。
+《
+》
+「
+」
+『
+』
+【
+】
+・
+~
+💘
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index e13b5a2b..938c5db4 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -238,17 +238,6 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
             tokenizer=tokenizer if args.use_subwords else None,
         )
 
-        if input_ids[0] != tokenizer.cls_token_id:
-            logger.warn(input_ids)
-            logger.warn(len(input_ids))
-            logger.warn(tokenizer.cls_token_id)
-            # raise ValueError("CLS token not first token")
-        if input_ids[-1] != tokenizer.sep_token_id:
-            logger.warn(input_ids)
-            logger.warn(len(input_ids))
-            logger.warn(tokenizer.sep_token_id)
-            # raise ValueError("SEP token not last token")
-
         if len(input_ids) > args.block_size:
             if tokenizer:
                 # always include CLS
@@ -264,7 +253,7 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
                 # always include SEP
                 if input_ids[-1] != tokenizer.sep_token_id:
                     # also insert PAD token as long as len < block_size
-                    while len(input_ids) <= args.block_size - 1:
+                    while len(input_ids) < args.block_size - 1:
                         input_ids = input_ids + [tokenizer.pad_token_id]
                         labels = labels + [0]
                     input_ids = input_ids + [tokenizer.sep_token_id]
@@ -273,34 +262,33 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
                 start = np.random.randint(0, len(input_ids) - args.block_size)
                 input_ids = input_ids[start : start + args.block_size]
                 labels = labels[start : start + args.block_size]
-        elif len(input_ids) != args.block_size and args.use_subwords:
+        if len(input_ids) != args.block_size and tokenizer:
             del input_ids[-1]
             del labels[-1]
-            while len(input_ids) <= args.block_size - 1:
+            while len(input_ids) < args.block_size - 1:
                 # insert pad token at second-to-last position
-                logger.debug("second", len(input_ids))
+                logger.warn("second", len(input_ids))
                 input_ids = input_ids + [tokenizer.pad_token_id]
                 labels = labels + [0]
             input_ids = input_ids + [tokenizer.sep_token_id]
             labels = labels + [0]
+        if len(input_ids) != args.block_size:
+            logger.warn(len(input_ids))
 
         input_ids = torch.tensor(input_ids[: args.block_size], dtype=torch.long)
         labels = torch.tensor(labels[: args.block_size], dtype=torch.long)
         if input_ids[-1] != tokenizer.sep_token_id:
             logger.warn(input_ids)
             logger.warn(tokenizer.sep_token_id)
-            logger.warn(labels)
             # raise ValueError("SEP token not last token")
         if input_ids[0] != tokenizer.cls_token_id:
             logger.warn(input_ids)
             logger.warn(tokenizer.cls_token_id)
-            logger.warn(labels)
             # raise ValueError("CLS token not first token")
         # FIXME: check this - why does it occur in train split?
         if (input_ids == tokenizer.cls_token_id).sum() != 1:
             logger.warn(input_ids)
             logger.warn(tokenizer.cls_token_id)
-            logger.warn(labels)
             # raise ValueError("CLS token not unique")
 
         position_ids = torch.arange(len(input_ids), dtype=torch.long)
@@ -657,7 +645,7 @@ def maybe_pad(text):
             num_workers=args.preprocessing_num_workers,
             include_languages=args.include_languages,
             shuffle=args.shuffle,
-            split="train",
+            split="valid",
         )
         logger.info(f"Train dataset has {len(train_dataset)} examples.")
 
@@ -671,7 +659,6 @@ def maybe_pad(text):
             logger.info(f"Sample {index} of the training set: {sample}.")
             if tokenizer:
                 logger.info(tokenizer.decode(sample["input_ids"]))
-            logger.info()
             count += 1
 
     # dataset we use is in cached now
diff --git a/wtpsplit/utils.py b/wtpsplit/utils.py
index e6fc8f53..4bd3c38a 100644
--- a/wtpsplit/utils.py
+++ b/wtpsplit/utils.py
@@ -86,7 +86,7 @@ def __post_init__(self):
 def get_label_dict(label_args):
     label_dict = {}
 
-    for i, c in enumerate(label_args.auxiliary_chars):
+    for i, c in enumerate(Constants.PUNCTUATION_CHARS):
         label_dict[ord(c)] = 1 + Constants.AUX_OFFSET + i
 
     for c in label_args.newline_chars:
@@ -99,10 +99,11 @@ def get_subword_label_dict(label_args, tokenizer):
 
     n_unks = 0
     # Map auxiliary characters to token IDs with labels
-    for i, c in enumerate(label_args.auxiliary_chars):
+    logger.warn(f"Using {Constants.PUNCTUATION_CHARS} auxiliary characters.")
+    for i, c in enumerate(Constants.PUNCTUATION_CHARS):
         token_id = tokenizer.convert_tokens_to_ids(c)
         label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
-        logger.info(f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}")
+        logger.warn(f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}")
         if token_id == tokenizer.unk_token_id:
             n_unks += 1
 
@@ -112,8 +113,8 @@
     for c in label_args.newline_chars:
         token_id = tokenizer.convert_tokens_to_ids(c)
         label_dict[token_id] = 1 + Constants.NEWLINE_INDEX
-        logger.info(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
-        logger.info(r"{}".format(tokenizer.decode([token_id])))
+        logger.warn(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
+        logger.warn(r"{}".format(tokenizer.decode([token_id])))
 
     return label_dict
 