fix model saving during training
markus583 committed Jan 2, 2024
1 parent f9dbd9f commit fbc2c5d
Showing 4 changed files with 197 additions and 116 deletions.
43 changes: 43 additions & 0 deletions configs/xlmr_stratify_0.1_3layers_100k.json
@@ -0,0 +1,43 @@
{
"model_name_or_path": "xlm-roberta-base",
"output_dir": "xlmr-normal-100k",
"train_text_path": "data/sentence/train.parquet",
"valid_text_path": "data/sentence/valid.parquet",
"block_size": 512,
"use_bert": true,
"do_train": true,
"do_eval": true,
"evaluation_strategy": "steps",
"per_device_train_batch_size": 32,
"per_device_eval_batch_size": 32,
"gradient_accumulation_steps": 2,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 4,
"preprocessing_num_workers": 32,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
"max_steps": 100000,
"save_steps": 50000,
"eval_steps": 5000,
"logging_steps": 50,
"report_to": "wandb",
"is_decoder": false,
"remove_unused_columns": false,
"lookahead": null,
"one_sample_per_line": false,
"do_sentence_training": true,
"do_auxiliary_training": true,
"warmup_steps": 5000,
"adapter_warmup_steps": 0,
"adapter_lr_multiplier": 1,
"ngram_order": 1,
"non_punctuation_sample_ratio": 0.1,
"prediction_loss_only": true,
"use_auxiliary": true,
"ddp_timeout": 3600,
"use_subwords": true,
"num_hidden_layers": 3,
"custom_punctuation_file": "punctuation_xlmr_unk.txt",
"log_level": "info"
}
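
For orientation, a minimal, hedged sketch of how a JSON config like this is typically consumed with `HfArgumentParser`; the `ScriptArgs` dataclass and the `allow_extra_keys` shortcut below are illustrative assumptions, not the exact setup in wtpsplit/train/train.py.

# Hedged sketch: parsing a JSON config such as the one above.
# ScriptArgs is a hypothetical stand-in that covers only a few of the keys;
# the real train.py defines its own, complete argument dataclass.
import sys
from dataclasses import dataclass

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class ScriptArgs:
    model_name_or_path: str = "xlm-roberta-base"
    num_hidden_layers: int = 3
    use_subwords: bool = False
    do_sentence_training: bool = True
    do_auxiliary_training: bool = False


if __name__ == "__main__":
    # e.g. python train.py configs/xlmr_stratify_0.1_3layers_100k.json
    parser = HfArgumentParser([ScriptArgs, TrainingArguments])
    args, training_args = parser.parse_json_file(sys.argv[1], allow_extra_keys=True)
    print(args.model_name_or_path, training_args.max_steps)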
114 changes: 4 additions & 110 deletions wtpsplit/train/train.py
@@ -17,7 +17,6 @@
from datasets import load_dataset
from datasets.download import DownloadConfig
from tokenizers import AddedToken
from torch import nn
from torchinfo import summary
from tqdm.auto import tqdm
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
@@ -33,8 +32,8 @@
)
from wtpsplit.train.evaluate import evaluate_sentence
from wtpsplit.train.trainer import Trainer
from wtpsplit.train.utils import cleanup_cache_files
from wtpsplit.utils import Constants, LabelArgs, corrupt, get_label_dict, get_subword_label_dict
from wtpsplit.train.utils import Model, cleanup_cache_files

logger = logging.getLogger(__name__)

@@ -66,111 +65,6 @@ def setup_logging(training_args: transformers.TrainingArguments) -> None:
# logger.info(f"Training/evaluation parameters {training_args}")


class Model(nn.Module):
    def __init__(
        self,
        backbone,
        loss_margin=0.5,
        use_loss_weights=False,
        do_sentence_training=True,
        do_auxiliary_training=False,
    ):
        super().__init__()
        self.backbone = backbone
        self.config = self.backbone.config

        assert loss_margin <= 0.5

        self.loss_margin = loss_margin
        self.use_loss_weights = use_loss_weights
        self.do_sentence_training = do_sentence_training
        self.do_auxiliary_training = do_auxiliary_training

    @property
    def device(self):
        return self.backbone.device

    def forward(
        self,
        input_ids,
        language_ids=None,
        attention_mask=None,
        position_ids=None,
        labels=None,
        label_weights=None,
        **kwargs,
    ):
        if position_ids is not None:
            reduced_attention_mask = (input_ids != 0).to(torch.long)
        else:
            # XXX: 1 is pad token id
            reduced_attention_mask = (input_ids != 1).to(torch.long)

        output = dict(
            self.backbone.forward(
                input_ids=input_ids,
                language_ids=language_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **kwargs,
            )
        )
        logits = output["logits"]

        if labels is not None:
            loss_fn = nn.BCEWithLogitsLoss(reduction="none")

            losses = []

            # main (newline prediction) objective
            if self.do_sentence_training:
                # label smoothing
                sentence_labels = (0.5 - self.loss_margin) + (labels == Constants.NEWLINE_INDEX + 1).to(
                    logits.dtype
                ).view(-1) * self.loss_margin * 2
                sentence_logits = logits[:, :, Constants.NEWLINE_INDEX].view(-1)

                losses.append(
                    (
                        loss_fn(
                            sentence_logits,
                            sentence_labels,
                        )
                        * (label_weights.view(-1) if label_weights is not None and self.use_loss_weights else 1)
                        * reduced_attention_mask.view(-1)
                    ).sum()
                    / reduced_attention_mask.sum()
                )

            # auxiliary (punctuation prediction) objective
            if self.do_auxiliary_training:
                loss_fn = nn.CrossEntropyLoss()

                # exclude newline and no labels
                aux_labels = torch.where(
                    (labels == 0) | (labels == Constants.NEWLINE_INDEX + 1),
                    0,
                    labels - Constants.AUX_OFFSET,
                )
                # exclude reduced_attention_mask tokens from labels
                aux_labels = torch.where(
                    reduced_attention_mask == 1,
                    aux_labels,
                    loss_fn.ignore_index,
                )

                losses.append(
                    loss_fn(
                        logits[:, :, Constants.AUX_OFFSET :].view(-1, self.config.num_labels - Constants.AUX_OFFSET),
                        aux_labels.view(-1),
                    )
                )

            loss = torch.stack(losses).sum()

            output["loss"] = loss

        return output


@dataclass
@@ -347,7 +241,7 @@ def main():
    num_labels = Constants.AUX_OFFSET + ((1 + len(Constants.PUNCTUATION_CHARS)) if args.do_auxiliary_training else 0)
    if args.use_subwords:
        if args.from_scratch:
            config = SubwordXLMConfig.from_pretrained(
            config = SubwordXLMConfig(
                args.model_name_or_path,
                num_hidden_layers=args.num_hidden_layers,
                num_labels=num_labels,
@@ -408,7 +302,7 @@ def main():
        do_auxiliary_training=args.do_auxiliary_training,
    )

    with training_args.main_process_first():
    if training_args.local_rank == 0:
        logger.info(summary(model, depth=4))
        # backbone.push_to_hub("markus583/xlm-token-untrained", private=True)

@@ -738,7 +632,7 @@ def compute_metrics(trainer):
    # because that would remove the cache files of the other dataset!
    cleanup_cache_files([train_dataset, valid_dataset])
    logger.warning("Cleaned up cache files.")
    time.sleep(20)
    time.sleep(10)

    trainer = Trainer(
        model,
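
As a side note on the guard swapped in the train.py diff above: `with training_args.main_process_first()` lets every process run the block (main process first), whereas the `local_rank == 0` check restricts it to the main process only. A minimal, hedged illustration (the `output_dir` is a placeholder):

# Hedged illustration of the two guards; not part of this commit.
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="tmp")  # placeholder arguments

with training_args.main_process_first():
    pass  # every process eventually runs this; the main process goes first

if training_args.local_rank == 0:
    pass  # in distributed runs, only the (local) main process executes this, e.g. logging the model summary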
39 changes: 39 additions & 0 deletions wtpsplit/train/trainer.py
@@ -1,12 +1,16 @@
import os
from typing import Dict

import numpy as np
import torch
import transformers
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import PreTrainedModel
from transformers.trainer import (
    ALL_LAYERNORM_LAYERS,
    TRAINING_ARGS_NAME,
    WEIGHTS_NAME,
    DataLoader,
    EvalLoopOutput,
    IterableDatasetShard,
@@ -25,6 +29,9 @@
    nested_numpify,
    nested_truncate,
)
from transformers.modeling_utils import unwrap_model

from wtpsplit.train.utils import Model

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm  # noqa: F401
@@ -408,3 +415,35 @@ def evaluation_loop(
            metrics=metrics,
            num_samples=num_samples,
        )

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info(f"Saving model checkpoint to {output_dir}")

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        xm.rendezvous("saving_checkpoint")
        if isinstance(self.model, Model):
            actual_model = self.model.backbone
        else:
            actual_model = self.model
        if not isinstance(actual_model, PreTrainedModel):
            if isinstance(unwrap_model(actual_model), PreTrainedModel):
                unwrap_model(actual_model).save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=actual_model.state_dict(),
                    save_function=xm.save,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                state_dict = actual_model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            actual_model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)
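
The override above only covers the TPU path. For comparison, a hedged sketch (not part of this commit) of applying the same backbone-unwrapping idea to the regular save path, assuming the stock `Trainer._save` signature; `UnwrappingTrainer` is an illustrative name.

# Hedged sketch: the same unwrapping applied to the non-TPU save path.
import os
from typing import Optional

import torch
from transformers import PreTrainedModel
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME

from wtpsplit.train.trainer import Trainer
from wtpsplit.train.utils import Model


class UnwrappingTrainer(Trainer):
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Serialize only the backbone, not the training wrapper.
        actual_model = self.model.backbone if isinstance(self.model, Model) else self.model
        if isinstance(actual_model, PreTrainedModel):
            actual_model.save_pretrained(output_dir, state_dict=state_dict)
        else:
            torch.save(actual_model.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))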
117 changes: 111 additions & 6 deletions wtpsplit/train/utils.py
@@ -1,10 +1,119 @@
import logging
import os
import time
import torch
import torch.nn as nn
from wtpsplit.utils import Constants

logger = logging.getLogger(__name__)


class Model(nn.Module):
    def __init__(
        self,
        backbone,
        loss_margin=0.5,
        use_loss_weights=False,
        do_sentence_training=True,
        do_auxiliary_training=False,
    ):
        super().__init__()
        self.backbone = backbone
        self.config = self.backbone.config

        assert loss_margin <= 0.5

        self.loss_margin = loss_margin
        self.use_loss_weights = use_loss_weights
        self.do_sentence_training = do_sentence_training
        self.do_auxiliary_training = do_auxiliary_training

    @property
    def device(self):
        return self.backbone.device

    def forward(
        self,
        input_ids,
        language_ids=None,
        attention_mask=None,
        position_ids=None,
        labels=None,
        label_weights=None,
        **kwargs,
    ):
        if position_ids is not None:
            reduced_attention_mask = (input_ids != 0).to(torch.long)
        else:
            # XXX: 1 is pad token id
            reduced_attention_mask = (input_ids != 1).to(torch.long)

        output = dict(
            self.backbone.forward(
                input_ids=input_ids,
                language_ids=language_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **kwargs,
            )
        )
        logits = output["logits"]

        if labels is not None:
            loss_fn = nn.BCEWithLogitsLoss(reduction="none")

            losses = []

            # main (newline prediction) objective
            if self.do_sentence_training:
                # label smoothing
                sentence_labels = (0.5 - self.loss_margin) + (labels == Constants.NEWLINE_INDEX + 1).to(
                    logits.dtype
                ).view(-1) * self.loss_margin * 2
                sentence_logits = logits[:, :, Constants.NEWLINE_INDEX].view(-1)

                losses.append(
                    (
                        loss_fn(
                            sentence_logits,
                            sentence_labels,
                        )
                        * (label_weights.view(-1) if label_weights is not None and self.use_loss_weights else 1)
                        * reduced_attention_mask.view(-1)
                    ).sum()
                    / reduced_attention_mask.sum()
                )

            # auxiliary (punctuation prediction) objective
            if self.do_auxiliary_training:
                loss_fn = nn.CrossEntropyLoss()

                # exclude newline and no labels
                aux_labels = torch.where(
                    (labels == 0) | (labels == Constants.NEWLINE_INDEX + 1),
                    0,
                    labels - Constants.AUX_OFFSET,
                )
                # exclude reduced_attention_mask tokens from labels
                aux_labels = torch.where(
                    reduced_attention_mask == 1,
                    aux_labels,
                    loss_fn.ignore_index,
                )

                losses.append(
                    loss_fn(
                        logits[:, :, Constants.AUX_OFFSET :].view(-1, self.config.num_labels - Constants.AUX_OFFSET),
                        aux_labels.view(-1),
                    )
                )

            loss = torch.stack(losses).sum()

            output["loss"] = loss

        return output


def cleanup_cache_files(datasets) -> int:
"""Clean up all cache files in the dataset cache directory, except those currently used by any of the provided datasets.
@@ -41,10 +150,6 @@ def cleanup_cache_files(datasets) -> int:

    for file_path in files_to_remove:
        logger.warning(f"Removing {file_path}")
        try:
            os.remove(file_path)
        except Exception as e:
            logger.warning(f"Error while trying to remove {file_path}: {e}")
        time.sleep(0.5)
        os.remove(file_path)

    return len(files_to_remove)
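
For context, a self-contained, hedged sketch of the `Model` wrapper's interface as relocated into this file. `DummyBackbone` is purely illustrative (the real backbone is the project's subword XLM-R model) and implements only the pieces the wrapper touches.

# Hedged usage sketch of the Model wrapper; DummyBackbone is illustrative only.
import torch
import torch.nn as nn

from wtpsplit.train.utils import Model
from wtpsplit.utils import Constants


class DummyBackbone(nn.Module):
    def __init__(self, vocab_size=1000):
        super().__init__()
        # Model only reads .config.num_labels, so a tiny namespace object suffices.
        self.config = type("Cfg", (), {"num_labels": Constants.AUX_OFFSET + 1})()
        self.emb = nn.Embedding(vocab_size, self.config.num_labels)

    def forward(self, input_ids, language_ids=None, attention_mask=None, position_ids=None, **kwargs):
        # Real backbones return logits of shape (batch, seq_len, num_labels).
        return {"logits": self.emb(input_ids)}


model = Model(DummyBackbone(), do_sentence_training=True, do_auxiliary_training=False)
input_ids = torch.randint(2, 1000, (2, 16))  # avoid pad id 1
labels = torch.zeros_like(input_ids)         # 0 = "no boundary" everywhere
out = model(input_ids=input_ids, labels=labels)
print(out["loss"])

Because the Trainer holds this wrapper rather than a `PreTrainedModel`, checkpointing has to unwrap `model.backbone`, which is what the new `_save_tpu` override in trainer.py above does.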
