diff --git a/configs/xlmr_stratify_0.1_3layers_nounks.json b/configs/xlmr_stratify_0.1_3layers_nounks.json
new file mode 100644
index 00000000..22f78160
--- /dev/null
+++ b/configs/xlmr_stratify_0.1_3layers_nounks.json
@@ -0,0 +1,42 @@
+{
+    "model_name_or_path": "xlm-roberta-base",
+    "output_dir": "xlmr-normal-no_unks",
+    "train_text_path": "data/sentence/train.parquet",
+    "valid_text_path": "data/sentence/valid.parquet",
+    "block_size": 512,
+    "use_bert": true,
+    "do_train": true,
+    "do_eval": true,
+    "evaluation_strategy": "steps",
+    "per_device_train_batch_size": 32,
+    "per_device_eval_batch_size": 32,
+    "gradient_accumulation_steps": 2,
+    "eval_accumulation_steps": 8,
+    "dataloader_num_workers": 4,
+    "preprocessing_num_workers": 32,
+    "learning_rate": 1e-4,
+    "save_strategy": "steps",
+    "fp16": false,
+    "max_steps": 2000000,
+    "save_steps": 100000,
+    "eval_steps": 5000,
+    "logging_steps": 50,
+    "report_to": "wandb",
+    "is_decoder": false,
+    "remove_unused_columns": false,
+    "lookahead": null,
+    "one_sample_per_line": false,
+    "do_sentence_training": true,
+    "do_auxiliary_training": true,
+    "warmup_steps": 5000,
+    "adapter_warmup_steps": 0,
+    "adapter_lr_multiplier": 1,
+    "ngram_order": 1,
+    "non_punctuation_sample_ratio": 0.1,
+    "prediction_loss_only": true,
+    "use_auxiliary": true,
+    "ddp_timeout": 3600,
+    "use_subwords": true,
+    "num_hidden_layers": 3,
+    "custom_punctuation_file": "punctuation_xlmr.txt"
+}
\ No newline at end of file
diff --git a/utils/remove_unks.py b/utils/remove_unks.py
new file mode 100644
index 00000000..bd8411e8
--- /dev/null
+++ b/utils/remove_unks.py
@@ -0,0 +1,52 @@
+import os
+from transformers import AutoTokenizer
+from tokenizers import AddedToken
+from wtpsplit.utils import Constants, LabelArgs
+
+def get_subword_label_dict(label_args, tokenizer):
+    label_dict = {}
+
+    n_unks = 0
+    # Map auxiliary characters to token IDs with labels
+    for i, c in enumerate(label_args.auxiliary_chars):
+        token_id = tokenizer.convert_tokens_to_ids(c)
+        label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
+        # TODO: remove UNKs?
+        print(
+            f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}"
+        )
+        if token_id == tokenizer.unk_token_id:
+            n_unks += 1
+
+    print(f"found {n_unks} UNK tokens in auxiliary characters")
+
+    # Map newline characters to token IDs with labels
+    for c in label_args.newline_chars:
+        token_id = tokenizer.convert_tokens_to_ids(c)
+        label_dict[token_id] = 1 + Constants.NEWLINE_INDEX
+        print(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
+        print(r"{}".format(tokenizer.decode([token_id])))
+
+    return label_dict
+
+
+tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
+tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
+
+label_dict = get_subword_label_dict(LabelArgs(custom_punctuation_file='punctuation_xlmr.txt'), tokenizer)
+print(len(label_dict))
+
+def write_punctuation_file():
+    with open(os.path.join(Constants.ROOT_DIR, "punctuation_xlmr.txt"), 'w', encoding='utf-8') as file:
+        for char in Constants.PUNCTUATION_CHARS:
+            token_id = tokenizer.convert_tokens_to_ids(char)
+            if token_id != tokenizer.unk_token_id:
+                file.write(char + '\n')
+
+write_punctuation_file()
+
+label_args_default = LabelArgs()
+print(Constants.PUNCTUATION_CHARS, len(Constants.PUNCTUATION_CHARS))
+
+label_args_custom = LabelArgs(custom_punctuation_file='punctuation_xlmr.txt')
+print(Constants.PUNCTUATION_CHARS, len(Constants.PUNCTUATION_CHARS))
\ No newline at end of file
diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py
index 662ffb82..6e061983 100644
--- a/wtpsplit/extract.py
+++ b/wtpsplit/extract.py
@@ -1,5 +1,6 @@
 import math
 import sys
+import logging
 
 import numpy as np
 from tqdm.auto import tqdm
@@ -8,6 +9,7 @@
 from wtpsplit.utils import Constants, hash_encode
 
+logger = logging.getLogger(__name__)
 
 
 class ORTWrapper:
     def __init__(self, config, ort_session):
@@ -222,7 +224,7 @@ def extract(
             )["logits"]
             if use_subwords:
                 logits = logits[:, 1:-1, :]  # remove CLS and SEP tokens
-            print(np.max(logits[0, :, 0]))
+            logger.debug(np.max(logits[0, :, 0]))
 
             for i in range(start, end):
                 original_idx, start_char_idx, end_char_idx = locs[i]
diff --git a/wtpsplit/punctuation_xlmr.txt b/wtpsplit/punctuation_xlmr.txt
new file mode 100644
index 00000000..93cf26ab
--- /dev/null
+++ b/wtpsplit/punctuation_xlmr.txt
@@ -0,0 +1,108 @@
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+`
+{
+|
+}
+~
+¡
+£
+¤
+§
+©
+«
+¬
+®
+°
+±
+·
+»
+¿
+÷
+՛
+՝
+՞
+։
+־
+׳
+،
+؛
+؟
+۔
+।
+॥
+၊
+။
+၌
+၍
+၎
+၏
+፡
+።
+፣
+፤
+፥
+។
+៕
+៖
+–
+—
+‘
+’
+‚
+“
+”
+„
+•
+′
+‹
+›
+€
+↑
+→
+⇌
+∑
+√
+╛
+□
+▬
+☎
+➖
+、
+。
+《
+》
+「
+」
+『
+』
+【
+】
+・
+~
+💘
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index 6c25ea9d..e13b5a2b 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -1,3 +1,4 @@
+import logging
 import math
 import os
 import sys
@@ -18,6 +19,8 @@
 from transformers import HfArgumentParser, TrainingArguments, AutoTokenizer, set_seed
 from torchinfo import summary
 from tokenizers import AddedToken
+import transformers
+import datasets
 
 from wtpsplit.models import (
     BertCharConfig,
@@ -31,13 +34,32 @@
 from wtpsplit.train.trainer import Trainer
 from wtpsplit.utils import Constants, LabelArgs, corrupt, get_label_dict, get_subword_label_dict
 
-# TODO: set logger (see ScaLearn?)
-# TODO: double-check checkpointing and saving (also to txt)
+logger = logging.getLogger(__name__)
+
 
-os.environ["PJRT_DEVICE"] = "None"
+# TODO: double-check checkpointing and saving (also to txt)
+# os.environ["PJRT_DEVICE"] = "None"
 
 
+def setup_logging(training_args: transformers.TrainingArguments) -> None:
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+    # Log on each process the small summary:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+    logger.info(f"Training/evaluation parameters {training_args}")
 
 
 class Model(nn.Module):
     def __init__(
         self,
@@ -216,16 +238,16 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
             tokenizer=tokenizer if args.use_subwords else None,
         )
 
-        # if input_ids[0] != tokenizer.cls_token_id:
-        #     print(input_ids)
-        #     print(len(input_ids))
-        #     print(tokenizer.cls_token_id)
-        #     raise ValueError("CLS token not first token")
-        # if input_ids[-1] != tokenizer.sep_token_id:
-        #     print(input_ids)
-        #     print(len(input_ids))
-        #     print(tokenizer.sep_token_id)
-        #     raise ValueError("SEP token not last token")
+        if input_ids[0] != tokenizer.cls_token_id:
+            logger.warning(input_ids)
+            logger.warning(len(input_ids))
+            logger.warning(tokenizer.cls_token_id)
+            # raise ValueError("CLS token not first token")
+        if input_ids[-1] != tokenizer.sep_token_id:
+            logger.warning(input_ids)
+            logger.warning(len(input_ids))
+            logger.warning(tokenizer.sep_token_id)
+            # raise ValueError("SEP token not last token")
 
         if len(input_ids) > args.block_size:
             if tokenizer:
@@ -243,7 +265,6 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
             if input_ids[-1] != tokenizer.sep_token_id:
                 # also insert PAD token as long as len < block_size
                 while len(input_ids) <= args.block_size - 1:
-                    print("first", len(input_ids))
                     input_ids = input_ids + [tokenizer.pad_token_id]
                     labels = labels + [0]
                 input_ids = input_ids + [tokenizer.sep_token_id]
@@ -257,7 +278,7 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
                 del labels[-1]
                 while len(input_ids) <= args.block_size - 1:
                     # insert pad token at second-to-last position
-                    print("second", len(input_ids))
+                    logger.debug("second: %d", len(input_ids))
                     input_ids = input_ids + [tokenizer.pad_token_id]
                     labels = labels + [0]
                 input_ids = input_ids + [tokenizer.sep_token_id]
@@ -265,22 +286,22 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
         input_ids = torch.tensor(input_ids[: args.block_size], dtype=torch.long)
         labels = torch.tensor(labels[: args.block_size], dtype=torch.long)
 
-        # if input_ids[-1] != tokenizer.sep_token_id:
-        #     print(input_ids)
-        #     print(tokenizer.sep_token_id)
-        #     print(labels)
-        #     raise ValueError("SEP token not last token")
-        # if input_ids[0] != tokenizer.cls_token_id:
-        #     print(input_ids)
-        #     print(tokenizer.cls_token_id)
-        #     print(labels)
-        #     raise ValueError("CLS token not first token")
-        # TODO: check this - why does it occur in train split?
-        # if (input_ids == tokenizer.cls_token_id).sum() != 1:
-        #     print(input_ids)
-        #     print(tokenizer.cls_token_id)
-        #     print(labels)
-        #     raise ValueError("CLS token not unique")
+        if input_ids[-1] != tokenizer.sep_token_id:
+            logger.warning(input_ids)
+            logger.warning(tokenizer.sep_token_id)
+            logger.warning(labels)
+            # raise ValueError("SEP token not last token")
+        if input_ids[0] != tokenizer.cls_token_id:
+            logger.warning(input_ids)
+            logger.warning(tokenizer.cls_token_id)
+            logger.warning(labels)
+            # raise ValueError("CLS token not first token")
+        # FIXME: check this - why does it occur in train split?
+        if (input_ids == tokenizer.cls_token_id).sum() != 1:
+            logger.warning(input_ids)
+            logger.warning(tokenizer.cls_token_id)
+            logger.warning(labels)
+            # raise ValueError("CLS token not unique")
 
         position_ids = torch.arange(len(input_ids), dtype=torch.long)
         label_weights = torch.ones(args.block_size, dtype=torch.float32)
@@ -318,7 +339,8 @@ def main():
     else:
         (args, training_args, label_args) = parser.parse_args_into_dataclasses()
     wandb_name = None
-
+
+    setup_logging(training_args)
     set_seed(training_args.seed)
 
     num_labels = Constants.AUX_OFFSET + ((1 + len(Constants.PUNCTUATION_CHARS)) if args.do_auxiliary_training else 0)
@@ -342,7 +364,6 @@
         backbone.config.base_model = args.model_name_or_path
         tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
         # needed since we create labels in collate_fn based on tokens
-        # TODO: problematic for tokens!
        tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
 
     else:
@@ -383,7 +404,7 @@
     )
 
     with training_args.main_process_first():
-        print(summary(model, depth=4))
+        logger.info(summary(model, depth=4))
         # backbone.push_to_hub("markus583/xlm-token-untrained", private=True)
 
     def prepare_dataset(
@@ -395,6 +416,7 @@
         with training_args.main_process_first():
             dlconf = DownloadConfig(cache_dir="/home/Markus/.cache/huggingface/datasets")
             dataset = load_dataset("markus583/mC4-TEST", split=split, download_config=dlconf)
+            logger.info(f"Loaded {split} dataset.")
             # optional: delete downloaded dataset, it is stored in cache_dir now (but we delete it later)
             # ~40GB on disk
             # os.system("rm -rf /home/Markus/.cache/huggingface/datasets")
@@ -406,9 +428,11 @@
                 lambda example: example["lang"] in include_languages,
                 num_proc=args.preprocessing_num_workers,
             )
+            logger.info(f"Filtered to {len(dataset)} examples.")
 
         if shuffle:
             dataset = dataset.shuffle(seed=42)
+            logger.info("Shuffled dataset.")
 
         # very likely not relevant / used only for the compound part
         if args.ignore_non_hyphen:
@@ -417,6 +441,7 @@
                 lambda sample: any(c in sample[args.text_column] for c in label_args.hyphen_chars),
                 num_proc=args.preprocessing_num_workers,
            )
+            logger.info(f"Filtered to {len(dataset)} examples.")
 
         # "punctuation-specific sampling" in the paper
         if args.non_punctuation_sample_ratio is not None:
@@ -596,13 +621,14 @@ def maybe_pad(text):
             # this is no longer used and would cause an error otherwise
             with training_args.main_process_first():
                 dataset = dataset.remove_columns([args.text_column])
-
+            logger.info(f"Tokenized {split} dataset.")
+
         if split == "train":
             with training_args.main_process_first():
                 for root, dirs, files in os.walk(os.environ.get("HF_DATASETS_CACHE")):
                     for file in files:
                         if file.startswith("m_c4-test-train"):
-                            print(f"Removing {os.path.join(root, file)}")
+                            logger.info(f"Removing {os.path.join(root, file)}")
                             os.remove(os.path.join(root, file))
 
         if not args.one_sample_per_line:
@@ -616,6 +642,7 @@ def maybe_pad(text):
                 if args.text_column == "text"
                 else [],
             )
+            logger.info(f"Grouped {split} dataset.")
 
         return dataset
 
@@ -625,12 +652,14 @@
         shuffle=False,
         split="valid",
     )
+    logger.info(f"Valid dataset has {len(valid_dataset)} examples.")
     train_dataset = prepare_dataset(
         num_workers=args.preprocessing_num_workers,
         include_languages=args.include_languages,
         shuffle=args.shuffle,
         split="train",
     )
+    logger.info(f"Train dataset has {len(train_dataset)} examples.")
 
     # print some samples from the dataset
     count = 0
@@ -639,10 +668,10 @@ def maybe_pad(text):
         sample = train_dataset[index]
 
         if sample.get('lang') == "de":
-            print(f"Sample {index} of the training set: {sample}.")
+            logger.info(f"Sample {index} of the training set: {sample}.")
             if tokenizer:
-                print(tokenizer.decode(sample["input_ids"]))
-            print()
+                logger.info(tokenizer.decode(sample["input_ids"]))
+            logger.info("")
             count += 1
 
     # dataset we use is in cached now
@@ -700,6 +729,7 @@ def compute_metrics(trainer):
             wandb.save(os.path.abspath(file), policy="now")
 
     label_dict = get_subword_label_dict(label_args, tokenizer) if args.use_subwords else get_label_dict(label_args)
+    logger.info(f"Label dict has {len(label_dict)} entries.")
 
     # needed in the trainer
     training_args.adapter_warmup_steps = args.adapter_warmup_steps
diff --git a/wtpsplit/utils.py b/wtpsplit/utils.py
index a731c376..e6fc8f53 100644
--- a/wtpsplit/utils.py
+++ b/wtpsplit/utils.py
@@ -5,6 +5,7 @@
 from cached_property import cached_property
 from pathlib import Path
 from typing import List
+import logging
 
 import numpy as np
 import pandas as pd
@@ -12,10 +13,17 @@
 # same as in CANINE
 PRIMES = [31, 43, 59, 61, 73, 97, 103, 113, 137, 149, 157, 173, 181, 193, 211, 223]
 
+logger = logging.getLogger(__name__)
 
 class ConstantsClass:
     NEWLINE_INDEX = 0
     AUX_OFFSET = 1
+    DEFAULT_PUNCTUATION_FILE = "punctuation.txt"
+    _PUNCTUATION_FILE = "punctuation.txt"
+
+    @classmethod
+    def set_punctuation_file(cls, file_name):
+        cls._PUNCTUATION_FILE = file_name
 
     @cached_property
     def ROOT_DIR(self):
@@ -25,16 +33,19 @@ def ROOT_DIR(self):
     def CACHE_DIR(self):
         CACHE_DIR = self.ROOT_DIR / ".cache"
         CACHE_DIR.mkdir(exist_ok=True)
-
         return CACHE_DIR
 
     @cached_property
     def LANGINFO(self):
         return pd.read_csv(os.path.join(self.ROOT_DIR, "data", "language_info.csv"), index_col=0)
 
-    @cached_property
+    @property
     def PUNCTUATION_CHARS(self):
-        return [x.strip() for x in open(os.path.join(self.ROOT_DIR, "data", "punctuation.txt")).readlines()]
+        punctuation_path = os.path.join(self.ROOT_DIR, "data", self._PUNCTUATION_FILE)
+        if os.path.exists(punctuation_path):
+            return [x.strip() for x in open(punctuation_path).readlines()]
+        else:
+            raise FileNotFoundError(f"The file {punctuation_path} does not exist.")
 
     @cached_property
     def PUNCTUATION_MAP(self):
@@ -42,13 +53,12 @@ def PUNCTUATION_MAP(self):
 
     @cached_property
     def LANG_CODE_TO_INDEX(self):
-        return {lang: i for i, lang in enumerate(Constants.LANGINFO.index)}
+        return {lang: i for i, lang in enumerate(self.LANGINFO.index)}
 
     @cached_property
     def SEPARATORS(self):
         return {lang: ("" if row["no_whitespace"] else " ") for lang, row in Constants.LANGINFO.iterrows()}
 
-
 Constants = ConstantsClass()
@@ -60,9 +70,17 @@ class LabelArgs:
     newline_whitespace_prob: float = 0.99
     hyphen_smooth_prob: float = 0.9
     newline_chars: List[str] = field(default_factory=lambda: ["\n"])
-    auxiliary_chars: List[str] = field(default_factory=lambda: Constants.PUNCTUATION_CHARS.copy())
+    auxiliary_chars: List[str] = field(default_factory=lambda: [])
     hyphen_chars: List[str] = field(default_factory=lambda: ["-", "‐"])
     use_auxiliary: bool = False
+    custom_punctuation_file: str = None
+
+    def __post_init__(self):
+        if self.custom_punctuation_file:
+            Constants.set_punctuation_file(self.custom_punctuation_file)
+        else:
+            Constants.set_punctuation_file("punctuation.txt")
+        self.auxiliary_chars = Constants.PUNCTUATION_CHARS.copy()
 
 
 def get_label_dict(label_args):
@@ -84,19 +102,18 @@ def get_subword_label_dict(label_args, tokenizer):
 
     for i, c in enumerate(label_args.auxiliary_chars):
         token_id = tokenizer.convert_tokens_to_ids(c)
         label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
-        # TODO: remove UNKs?
-        print(f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}")
+        logger.info(f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}")
         if token_id == tokenizer.unk_token_id:
             n_unks += 1
 
-    print(f"found {n_unks} UNK tokens in auxiliary characters")
+    logger.warning(f"found {n_unks} UNK tokens in auxiliary characters")
 
     # Map newline characters to token IDs with labels
     for c in label_args.newline_chars:
         token_id = tokenizer.convert_tokens_to_ids(c)
         label_dict[token_id] = 1 + Constants.NEWLINE_INDEX
-        print(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
-        print(r"{}".format(tokenizer.decode([token_id])))
+        logger.info(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
+        logger.info(r"{}".format(tokenizer.decode([token_id])))
 
     return label_dict
@@ -289,3 +306,4 @@ def reconstruct_sentences(text, partial_sentences):
             fixed_sentences.append(text[i:])
 
     return fixed_sentences
+
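
Usage sketch (not part of the diff): the changes above route a configurable punctuation file through ConstantsClass and LabelArgs so that auxiliary labels can be restricted to characters the XLM-R tokenizer maps to real (non-UNK) tokens. The snippet below is a minimal sketch of how the pieces are intended to fit together; it only uses names introduced in this diff (Constants, LabelArgs, get_subword_label_dict), but the output location under Constants.ROOT_DIR / "data" is an assumption, chosen because that is where PUNCTUATION_CHARS looks after set_punctuation_file is called.

# Sketch only: regenerate a tokenizer-specific punctuation file and build labels with it.
# The "data/" output path is an assumption, not something this diff enforces.
import os

from tokenizers import AddedToken
from transformers import AutoTokenizer

from wtpsplit.utils import Constants, LabelArgs, get_subword_label_dict

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})

# Keep only the default punctuation characters that XLM-R knows as single tokens.
known_chars = [
    c for c in Constants.PUNCTUATION_CHARS
    if tokenizer.convert_tokens_to_ids(c) != tokenizer.unk_token_id
]
with open(os.path.join(Constants.ROOT_DIR, "data", "punctuation_xlmr.txt"), "w", encoding="utf-8") as f:
    f.write("\n".join(known_chars) + "\n")

# LabelArgs.__post_init__ switches Constants to the custom file, so the auxiliary
# label dict used during training no longer contains the UNK token ID.
label_args = LabelArgs(custom_punctuation_file="punctuation_xlmr.txt")
label_dict = get_subword_label_dict(label_args, tokenizer)
assert tokenizer.unk_token_id not in label_dict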