From fbc2c5d535704f4521881c5e068f76f365c8b746 Mon Sep 17 00:00:00 2001 From: markus583 Date: Tue, 2 Jan 2024 07:43:15 +0000 Subject: [PATCH] fix model saving during training --- configs/xlmr_stratify_0.1_3layers_100k.json | 43 +++++++ wtpsplit/train/train.py | 114 +------------------ wtpsplit/train/trainer.py | 39 +++++++ wtpsplit/train/utils.py | 117 +++++++++++++++++++- 4 files changed, 197 insertions(+), 116 deletions(-) create mode 100644 configs/xlmr_stratify_0.1_3layers_100k.json diff --git a/configs/xlmr_stratify_0.1_3layers_100k.json b/configs/xlmr_stratify_0.1_3layers_100k.json new file mode 100644 index 00000000..4ad3fb2e --- /dev/null +++ b/configs/xlmr_stratify_0.1_3layers_100k.json @@ -0,0 +1,43 @@ +{ + "model_name_or_path": "xlm-roberta-base", + "output_dir": "xlmr-normal-100k", + "train_text_path": "data/sentence/train.parquet", + "valid_text_path": "data/sentence/valid.parquet", + "block_size": 512, + "use_bert": true, + "do_train": true, + "do_eval": true, + "evaluation_strategy": "steps", + "per_device_train_batch_size": 32, + "per_device_eval_batch_size": 32, + "gradient_accumulation_steps": 2, + "eval_accumulation_steps": 8, + "dataloader_num_workers": 4, + "preprocessing_num_workers": 32, + "learning_rate": 1e-4, + "save_strategy": "steps", + "fp16": false, + "max_steps": 100000, + "save_steps": 50000, + "eval_steps": 5000, + "logging_steps": 50, + "report_to": "wandb", + "is_decoder": false, + "remove_unused_columns": false, + "lookahead": null, + "one_sample_per_line": false, + "do_sentence_training": true, + "do_auxiliary_training": true, + "warmup_steps": 5000, + "adapter_warmup_steps": 0, + "adapter_lr_multiplier": 1, + "ngram_order": 1, + "non_punctuation_sample_ratio": 0.1, + "prediction_loss_only": true, + "use_auxiliary": true, + "ddp_timeout": 3600, + "use_subwords": true, + "num_hidden_layers": 3, + "custom_punctuation_file": "punctuation_xlmr_unk.txt", + "log_level": "info" +} \ No newline at end of file diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py index d3cea66e..d2eb838e 100644 --- a/wtpsplit/train/train.py +++ b/wtpsplit/train/train.py @@ -17,7 +17,6 @@ from datasets import load_dataset from datasets.download import DownloadConfig from tokenizers import AddedToken -from torch import nn from torchinfo import summary from tqdm.auto import tqdm from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed @@ -33,8 +32,8 @@ ) from wtpsplit.train.evaluate import evaluate_sentence from wtpsplit.train.trainer import Trainer -from wtpsplit.train.utils import cleanup_cache_files from wtpsplit.utils import Constants, LabelArgs, corrupt, get_label_dict, get_subword_label_dict +from wtpsplit.train.utils import Model, cleanup_cache_files logger = logging.getLogger(__name__) @@ -66,111 +65,6 @@ def setup_logging(training_args: transformers.TrainingArguments) -> None: # logger.info(f"Training/evaluation parameters {training_args}") -class Model(nn.Module): - def __init__( - self, - backbone, - loss_margin=0.5, - use_loss_weights=False, - do_sentence_training=True, - do_auxiliary_training=False, - ): - super().__init__() - self.backbone = backbone - self.config = self.backbone.config - - assert loss_margin <= 0.5 - - self.loss_margin = loss_margin - self.use_loss_weights = use_loss_weights - self.do_sentence_training = do_sentence_training - self.do_auxiliary_training = do_auxiliary_training - - @property - def device(self): - return self.backbone.device - - def forward( - self, - input_ids, - language_ids=None, - 
attention_mask=None, - position_ids=None, - labels=None, - label_weights=None, - **kwargs, - ): - if position_ids is not None: - reduced_attention_mask = (input_ids != 0).to(torch.long) - else: - # XXX: 1 is pad token id - reduced_attention_mask = (input_ids != 1).to(torch.long) - - output = dict( - self.backbone.forward( - input_ids=input_ids, - language_ids=language_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **kwargs, - ) - ) - logits = output["logits"] - - if labels is not None: - loss_fn = nn.BCEWithLogitsLoss(reduction="none") - - losses = [] - - # main (newline prediction) objective - if self.do_sentence_training: - # label smoothing - sentence_labels = (0.5 - self.loss_margin) + (labels == Constants.NEWLINE_INDEX + 1).to( - logits.dtype - ).view(-1) * self.loss_margin * 2 - sentence_logits = logits[:, :, Constants.NEWLINE_INDEX].view(-1) - - losses.append( - ( - loss_fn( - sentence_logits, - sentence_labels, - ) - * (label_weights.view(-1) if label_weights is not None and self.use_loss_weights else 1) - * reduced_attention_mask.view(-1) - ).sum() - / reduced_attention_mask.sum() - ) - - # auxiliary (punctuation prediction) objective - if self.do_auxiliary_training: - loss_fn = nn.CrossEntropyLoss() - - # exclude newline and no labels - aux_labels = torch.where( - (labels == 0) | (labels == Constants.NEWLINE_INDEX + 1), - 0, - labels - Constants.AUX_OFFSET, - ) - # exclude reduced_attention_mask tokens from labels - aux_labels = torch.where( - reduced_attention_mask == 1, - aux_labels, - loss_fn.ignore_index, - ) - - losses.append( - loss_fn( - logits[:, :, Constants.AUX_OFFSET :].view(-1, self.config.num_labels - Constants.AUX_OFFSET), - aux_labels.view(-1), - ) - ) - - loss = torch.stack(losses).sum() - - output["loss"] = loss - - return output @dataclass @@ -347,7 +241,7 @@ def main(): num_labels = Constants.AUX_OFFSET + ((1 + len(Constants.PUNCTUATION_CHARS)) if args.do_auxiliary_training else 0) if args.use_subwords: if args.from_scratch: - config = SubwordXLMConfig.from_pretrained( + config = SubwordXLMConfig( args.model_name_or_path, num_hidden_layers=args.num_hidden_layers, num_labels=num_labels, @@ -408,7 +302,7 @@ def main(): do_auxiliary_training=args.do_auxiliary_training, ) - with training_args.main_process_first(): + if training_args.local_rank == 0: logger.info(summary(model, depth=4)) # backbone.push_to_hub("markus583/xlm-token-untrained", private=True) @@ -738,7 +632,7 @@ def compute_metrics(trainer): # because that would remove the cache files of the other dataset! 
cleanup_cache_files([train_dataset, valid_dataset]) logger.warning("Cleaned up cache files.") - time.sleep(20) + time.sleep(10) trainer = Trainer( model, diff --git a/wtpsplit/train/trainer.py b/wtpsplit/train/trainer.py index adf8fad8..dfce7909 100644 --- a/wtpsplit/train/trainer.py +++ b/wtpsplit/train/trainer.py @@ -1,3 +1,4 @@ +import os from typing import Dict import numpy as np @@ -5,8 +6,11 @@ import transformers from torch import nn from torch.optim.lr_scheduler import LambdaLR +from transformers import PreTrainedModel from transformers.trainer import ( ALL_LAYERNORM_LAYERS, + TRAINING_ARGS_NAME, + WEIGHTS_NAME, DataLoader, EvalLoopOutput, IterableDatasetShard, @@ -25,6 +29,9 @@ nested_numpify, nested_truncate, ) +from transformers.modeling_utils import unwrap_model + +from wtpsplit.train.utils import Model if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm # noqa: F401 @@ -408,3 +415,35 @@ def evaluation_loop( metrics=metrics, num_samples=num_samples, ) + + def _save_tpu(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + logger.info(f"Saving model checkpoint to {output_dir}") + + if xm.is_master_ordinal(): + os.makedirs(output_dir, exist_ok=True) + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + xm.rendezvous("saving_checkpoint") + if isinstance(self.model, Model): + actual_model = self.model.backbone + else: + actual_model = self.model + if not isinstance(actual_model, PreTrainedModel): + if isinstance(unwrap_model(actual_model), PreTrainedModel): + unwrap_model(actual_model).save_pretrained( + output_dir, + is_main_process=self.args.should_save, + state_dict=actual_model.state_dict(), + save_function=xm.save, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + state_dict = actual_model.state_dict() + xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + actual_model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) + if self.tokenizer is not None and self.args.should_save: + self.tokenizer.save_pretrained(output_dir) diff --git a/wtpsplit/train/utils.py b/wtpsplit/train/utils.py index 9d8a7578..ea5c9ec3 100644 --- a/wtpsplit/train/utils.py +++ b/wtpsplit/train/utils.py @@ -1,10 +1,119 @@ import logging import os -import time +import torch +import torch.nn as nn +from wtpsplit.utils import Constants logger = logging.getLogger(__name__) +class Model(nn.Module): + def __init__( + self, + backbone, + loss_margin=0.5, + use_loss_weights=False, + do_sentence_training=True, + do_auxiliary_training=False, + ): + super().__init__() + self.backbone = backbone + self.config = self.backbone.config + + assert loss_margin <= 0.5 + + self.loss_margin = loss_margin + self.use_loss_weights = use_loss_weights + self.do_sentence_training = do_sentence_training + self.do_auxiliary_training = do_auxiliary_training + + @property + def device(self): + return self.backbone.device + + def forward( + self, + input_ids, + language_ids=None, + attention_mask=None, + position_ids=None, + labels=None, + label_weights=None, + **kwargs, + ): + if position_ids is not None: + reduced_attention_mask = (input_ids != 0).to(torch.long) + else: + # XXX: 1 is pad token id + reduced_attention_mask = (input_ids != 1).to(torch.long) + + output = dict( + 
self.backbone.forward( + input_ids=input_ids, + language_ids=language_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + ) + logits = output["logits"] + + if labels is not None: + loss_fn = nn.BCEWithLogitsLoss(reduction="none") + + losses = [] + + # main (newline prediction) objective + if self.do_sentence_training: + # label smoothing + sentence_labels = (0.5 - self.loss_margin) + (labels == Constants.NEWLINE_INDEX + 1).to( + logits.dtype + ).view(-1) * self.loss_margin * 2 + sentence_logits = logits[:, :, Constants.NEWLINE_INDEX].view(-1) + + losses.append( + ( + loss_fn( + sentence_logits, + sentence_labels, + ) + * (label_weights.view(-1) if label_weights is not None and self.use_loss_weights else 1) + * reduced_attention_mask.view(-1) + ).sum() + / reduced_attention_mask.sum() + ) + + # auxiliary (punctuation prediction) objective + if self.do_auxiliary_training: + loss_fn = nn.CrossEntropyLoss() + + # exclude newline and no labels + aux_labels = torch.where( + (labels == 0) | (labels == Constants.NEWLINE_INDEX + 1), + 0, + labels - Constants.AUX_OFFSET, + ) + # exclude reduced_attention_mask tokens from labels + aux_labels = torch.where( + reduced_attention_mask == 1, + aux_labels, + loss_fn.ignore_index, + ) + + losses.append( + loss_fn( + logits[:, :, Constants.AUX_OFFSET :].view(-1, self.config.num_labels - Constants.AUX_OFFSET), + aux_labels.view(-1), + ) + ) + + loss = torch.stack(losses).sum() + + output["loss"] = loss + + return output + + def cleanup_cache_files(datasets) -> int: """Clean up all cache files in the dataset cache directory, except those currently used by any of the provided datasets. @@ -41,10 +150,6 @@ def cleanup_cache_files(datasets) -> int: for file_path in files_to_remove: logger.warning(f"Removing {file_path}") - try: - os.remove(file_path) - except Exception as e: - logger.warning(f"Error while trying to remove {file_path}: {e}") - time.sleep(0.5) + os.remove(file_path) return len(files_to_remove)
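
Note (illustrative, not part of the patch): the `_save_tpu` override above works because during training `Trainer.model` is the `Model` wrapper from wtpsplit.train.utils, a plain nn.Module whose `backbone` attribute holds the actual PreTrainedModel. Without the override, the stock save path falls back to dumping the wrapper's raw state dict (keys prefixed with `backbone.`), which `from_pretrained` cannot reload. The sketch below shows the same unwrap-then-save logic in a plain (non-TPU) setting; `save_backbone` and `save_dir` are placeholder names introduced here for illustration only.

    import os

    import torch
    from transformers import PreTrainedModel
    from transformers.modeling_utils import unwrap_model

    from wtpsplit.train.utils import Model


    def save_backbone(model, save_dir):
        """Write a checkpoint that the backbone class can reload with from_pretrained."""
        os.makedirs(save_dir, exist_ok=True)

        # Unwrap the training wrapper: save_pretrained must be called on the
        # backbone, not on the wrapper itself.
        actual_model = model.backbone if isinstance(model, Model) else model

        if isinstance(unwrap_model(actual_model), PreTrainedModel):
            unwrap_model(actual_model).save_pretrained(save_dir)
        else:
            # Fallback: the backbone is not a PreTrainedModel, so only a raw
            # state dict can be written.
            torch.save(actual_model.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))

A checkpoint written this way can then be reloaded with the backbone class's from_pretrained(save_dir), which is what the unchanged save path broke for the wrapped model.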