From c6361d9d80508a9898a19435d8336c5ca1d86482 Mon Sep 17 00:00:00 2001 From: markus583 Date: Thu, 7 Mar 2024 08:56:46 +0000 Subject: [PATCH] update non-parallel adp training --- wtpsplit/train/adaptertrainer.py | 33 ++--- wtpsplit/train/train_adapter.py | 210 ++++++++++++++++++++----------- 2 files changed, 156 insertions(+), 87 deletions(-) diff --git a/wtpsplit/train/adaptertrainer.py b/wtpsplit/train/adaptertrainer.py index 5303ee48..5a735471 100644 --- a/wtpsplit/train/adaptertrainer.py +++ b/wtpsplit/train/adaptertrainer.py @@ -806,22 +806,22 @@ def evaluation_loop( """ args = self.args - if not self.skip_eval_loss: - prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only - - # if eval is called w/o train init deepspeed here - if args.deepspeed and not self.deepspeed: - # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval - # from the checkpoint eventually - deepspeed_engine, _, _ = deepspeed_init( - self, num_training_steps=0, resume_from_checkpoint=None, inference=True - ) - self.model = deepspeed_engine.module - self.model_wrapped = deepspeed_engine - self.deepspeed = deepspeed_engine - - model = self._wrap_model(self.model, training=False, dataloader=dataloader) + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train init deepspeed here + if args.deepspeed and not self.deepspeed: + # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval + # from the checkpoint eventually + deepspeed_engine, _, _ = deepspeed_init( + self, num_training_steps=0, resume_from_checkpoint=None, inference=True + ) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + if not self.skip_eval_loss: # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # while ``train`` is running, cast it to the right dtype first and then put on device if not self.is_in_train: @@ -832,7 +832,7 @@ def evaluation_loop( batch_size = self.args.eval_batch_size - logger.info(f"***** Running {description} *****") + logger.warning(f"***** Running {description} *****") if has_length(dataloader): logger.warning(f" Num examples = {self.num_examples(dataloader)}") else: @@ -983,6 +983,7 @@ def evaluation_loop( if all_inputs is not None: all_inputs = nested_truncate(all_inputs, num_samples) else: + xm.rendezvous("eval_metrics") all_losses, all_preds, all_labels, all_inputs, num_samples = None, None, None, None, 0 # Metrics! diff --git a/wtpsplit/train/train_adapter.py b/wtpsplit/train/train_adapter.py index 3e5959cf..efb5b762 100644 --- a/wtpsplit/train/train_adapter.py +++ b/wtpsplit/train/train_adapter.py @@ -1,29 +1,32 @@ -from dataclasses import dataclass +import copy import logging -import sys +import math import os -import copy +import random +import sys +from collections import Counter +from dataclasses import dataclass +from functools import partial +from glob import glob from typing import List -from adapters import AdapterArguments -from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed -from wtpsplit.train.evaluate import evaluate_sentence -from wtpsplit.train.adaptertrainer import AdapterTrainer -from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict -from wtpsplit.train.utils import Model -from wtpsplit.train.train import setup_logging, collate_fn -from wtpsplit.models import SubwordXLMForTokenClassification, SubwordXLMConfig -from tokenizers import AddedToken -import adapters import datasets import numpy as np -import math -from collections import Counter import torch -import random +from tokenizers import AddedToken +from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed + +import adapters import wandb -from glob import glob -from functools import partial +from adapters import AdapterArguments +from wtpsplit.evaluation.intrinsic import corrupt +from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification +from wtpsplit.train.adaptertrainer import AdapterTrainer +from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise +from wtpsplit.train.train import collate_fn, setup_logging +from wtpsplit.train.utils import Model +from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict +from tqdm import tqdm logger = logging.getLogger(__name__) @@ -58,6 +61,15 @@ class Args: use_subwords: bool = False freeze_classifier: bool = False clf_from_scratch: bool = False + unfreeze_ln: bool = False + do_process: bool = False + meta_clf: bool = False + wandb_project: str = "sentence" + # corruption + do_lowercase: bool = False + do_remove_punct: bool = False + eval_pairwise: bool = False + skip_eval_loss: bool = False def main(): @@ -73,21 +85,15 @@ def main(): set_seed(training_args.seed) num_labels = Constants.AUX_OFFSET + ( - (1 + len(Constants.PUNCTUATION_CHARS)) if args.do_auxiliary_training or label_args.use_auxiliary else 0 + (1 + len(Constants.PUNCTUATION_CHARS)) + if (label_args.use_auxiliary or args.do_auxiliary_training or args.meta_clf) + else 0 ) config = SubwordXLMConfig.from_pretrained( args.model_name_or_path, num_labels=num_labels, ) - # since we pre-tokenize, running multiple epochs would iterate over data in same order - # hence, we duplicate & shuffle train data sentences in prepare_dataset - # and set num_train_epochs to 1 --> simulate multiple epochs, each with different sentence order - num_train_epochs = training_args.num_train_epochs - - training_args.num_train_epochs = 1 - training_args.evaluation_strategy = "steps" - def prepare_dataset( data, num_workers=1, @@ -95,46 +101,36 @@ def prepare_dataset( dataset_name="ud", shuffle=False, split="train", + do_lowercase=False, + do_remove_punct=False, ): # maybe we use more than 1 lang later at once. with training_args.main_process_first(): + # maybe we use more than 1 lang later at once. for lang in include_languages: if split == "train": dataset = data[lang]["sentence"][dataset_name]["meta"]["train_data"] elif split == "valid": dataset = data[lang]["sentence"][dataset_name]["data"] - data_list = [] if dataset is None: return None - for sample in dataset: - ends_with_punctuation = sample.endswith(tuple(Constants.PUNCTUATION_CHARS)) - data_list.append( + dataset = datasets.Dataset.from_list( + [ { - args.text_column: sample + "\n" if len(sample) > 0 and sample[-1] != "\n" else sample, + args.text_column: corrupt(sample, do_lowercase, do_remove_punct) + "\n" + if sample and sample[-1] != "\n" + else corrupt(sample, do_lowercase, do_remove_punct), "lang": lang, - "ends_with_punctuation": ends_with_punctuation, + "ends_with_punctuation": sample.endswith(tuple(Constants.PUNCTUATION_CHARS)), } - ) - dataset = datasets.Dataset.from_list(data_list) - with training_args.main_process_first(): - logger.warning(f"Loaded {len(dataset)} examples for {lang} {dataset_name} {split} dataset.") - - if include_languages is not None: - include_languages = set(include_languages) - - dataset = dataset.filter( - lambda example: example["lang"] in include_languages, - num_proc=args.preprocessing_num_workers, - ) + for sample in dataset + ] + ) with training_args.main_process_first(): - logger.warning(f"Filtered to {len(dataset)} examples.") + logger.warning(f"Loaded {len(dataset)} examples for {lang} {dataset_name} {split} dataset.") if shuffle: - # create n_epochs copies of the dataset and shuffle them individually - dataset = datasets.concatenate_datasets([dataset.shuffle(seed=i) for i in range(num_train_epochs)]) - - with training_args.main_process_first(): - logger.warning(f"Shuffled dataset to {len(dataset)} examples.") + dataset = dataset.shuffle(seed=42) # very likely not relevant / used only for the compound part if args.ignore_non_hyphen: @@ -347,20 +343,21 @@ def maybe_pad(text): # 1 wandb run for all language-dataset combinations if "wandb" in training_args.report_to and training_args.process_index == 0: - wandb.init(name=wandb_name, project="sentence-peft") + wandb.init(name=wandb_name, project=args.wandb_project, group=wandb_name) wandb.config.update(args) wandb.config.update(training_args) wandb.config.update(label_args) + wandb.config.update(adapter_args) for file in glob(os.path.join(os.path.dirname(__file__), "*.py")): wandb.save(os.path.abspath(file), policy="now") - for lang in data.keys(): + for lang in tqdm(data.keys(), desc="Language"): if lang in args.include_languages: for dataset_name in data[lang]["sentence"].keys(): # do model stuff here; otherwise, head params would be overwritten every time backbone = SubwordXLMForTokenClassification.from_pretrained( - args.model_name_or_path, config=config, ignore_mismatched_sizes=True + args.model_name_or_path, config=copy.deepcopy(config), ignore_mismatched_sizes=True ) backbone.config.base_model = args.base_model @@ -397,6 +394,8 @@ def maybe_pad(text): dataset_name=dataset_name, shuffle=False, split="valid", + do_lowercase=args.do_lowercase, + do_remove_punct=args.do_remove_punct, ) logger.warning(f"Valid ds for {lang} {dataset_name} has {len(valid_dataset)} examples.") @@ -407,22 +406,14 @@ def maybe_pad(text): dataset_name=dataset_name, shuffle=args.shuffle, split="train", + do_lowercase=args.do_lowercase, + do_remove_punct=args.do_remove_punct, ) if train_dataset is None or valid_dataset is None: logger.warning(f"Skipping {lang} {dataset_name} due to missing data.") continue logger.warning(f"Train ds for {lang} {dataset_name} has {len(train_dataset)} examples.") - # eval every actual epoch, based on steps - training_args.eval_steps = ( - len(train_dataset) - // ( - training_args.per_device_train_batch_size - * training_args.gradient_accumulation_steps - * num_train_epochs - ) - ) + 1 - # print some samples from the dataset count = 0 while count < 1: @@ -446,15 +437,44 @@ def compute_metrics(trainer): eval_data, model, stride=64, - block_size=512, ## TODO: change to args version x2? + block_size=512, + batch_size=training_args.per_device_eval_batch_size, + ) + metrics[f"{dataset_name}/{lang}/pr_auc"] = score + metrics[f"{dataset_name}/{lang}/f1"] = info["f1"] + metrics[f"{dataset_name}/{lang}/f1_best"] = info["f1_best"] + metrics[f"{dataset_name}/{lang}/threshold_best"] = info["threshold_best"] + if args.do_lowercase and args.do_remove_punct: + score_corrupted, info_corrupted = evaluate_sentence( + lang, + eval_data, + model, + stride=64, + block_size=512, batch_size=training_args.per_device_eval_batch_size, + do_lowercase=True, + do_remove_punct=True, ) - metrics[f"{lang}_{dataset_name}_pr_auc"] = score - metrics[f"{lang}_{dataset_name}_f1"] = info["f1"] - metrics[f"{lang}_{dataset_name}_f1_best"] = info["f1_best"] - metrics[f"{lang}_{dataset_name}_threshold_best"] = info["threshold_best"] + metrics[f"{dataset_name}/{lang}/corrupted/pr_auc"] = score_corrupted + metrics[f"{dataset_name}/{lang}/corrupted/f1"] = info_corrupted["f1"] + metrics[f"{dataset_name}/{lang}/corrupted/f1_best"] = info_corrupted["f1_best"] + metrics[f"{dataset_name}/{lang}/corrupted/threshold_best"] = info_corrupted["threshold_best"] + elif args.do_lowercase or args.do_remove_punct: + raise NotImplementedError("Currently we only corrupt both ways!") + if args.eval_pairwise: + score_pairwise, avg_acc = evaluate_sentence_pairwise( + lang, + eval_data, + model, + stride=args.eval_stride, + block_size=args.block_size, + batch_size=training_args.per_device_eval_batch_size, + threshold=0.1, + ) + metrics[f"{dataset_name}/{lang}/pairwise/pr_auc"] = score_pairwise + metrics[f"{dataset_name}/{lang}/pairwise/acc"] = avg_acc - return metrics + return metrics label_dict = ( get_subword_label_dict(label_args, tokenizer) if args.use_subwords else get_label_dict(label_args) @@ -475,6 +495,19 @@ def compute_metrics(trainer): if args.clf_from_scratch: model.backbone.classifier = torch.nn.Linear(model.backbone.config.hidden_size, num_labels) + if args.unfreeze_ln: + for n, p in model.backbone.named_parameters(): + if "LayerNorm" in n: + p.requires_grad = True + + if args.meta_clf: + clf = model.backbone.classifier + model.backbone.classifier = torch.nn.Sequential( + clf, # original classifier - if frozen above, also frozen here + torch.nn.Linear(clf.out_features, 1), + ) + model.backbone.config.num_labels = 1 + trainer = AdapterTrainer( model, training_args, @@ -487,8 +520,10 @@ def compute_metrics(trainer): label_args=label_args, label_dict=label_dict, tokenizer=tokenizer, + add_lang_ids=False, ), - logging_suffix=f"{lang}_{dataset_name}", + logging_prefix=f"{dataset_name}/{lang}/", + skip_eval_loss=args.skip_eval_loss, ) trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) with training_args.main_process_first(): @@ -501,6 +536,39 @@ def compute_metrics(trainer): save_directory=os.path.join(training_args.output_dir, dataset_name, lang), with_head=True, ) + if training_args.local_rank == 0: + # eval here within 1 go + if args.do_lowercase and args.do_remove_punct: + os.system( + f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --do_lowercase --do_remove_punct" + ) + elif args.eval_pairwise: + os.system( + f"python3 wtpsplit/evaluation/intrinsic_pairwise.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1" + ) + elif "lines" in args.text_path: + if args.do_lowercase and args.do_remove_punct: + os.system( + f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_lines.pt --save_suffix lines --do_lowercase --do_remove_punct" + ) + else: + os.system( + f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_lines.pt --save_suffix lines" + ) + elif "verses" in args.text_path: + if args.do_lowercase and args.do_remove_punct: + os.system( + f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n.pt --save_suffix verses --do_lowercase --do_remove_punct" + ) + else: + os.system( + f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n.pt --save_suffix verses" + ) + else: + os.system( + f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1" + ) + def _mp_fn(index): # For xla_spawn (TPUs)