Commit

update non-parallel adp training
markus583 committed Mar 7, 2024
1 parent 5c89af1 commit c6361d9
Showing 2 changed files with 156 additions and 87 deletions.
33 changes: 17 additions & 16 deletions wtpsplit/train/adaptertrainer.py
@@ -806,22 +806,22 @@ def evaluation_loop(
"""
args = self.args

if not self.skip_eval_loss:
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

# if eval is called w/o train init deepspeed here
if args.deepspeed and not self.deepspeed:
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
# from the checkpoint eventually
deepspeed_engine, _, _ = deepspeed_init(
self, num_training_steps=0, resume_from_checkpoint=None, inference=True
)
self.model = deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine

model = self._wrap_model(self.model, training=False, dataloader=dataloader)
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

# if eval is called w/o train init deepspeed here
if args.deepspeed and not self.deepspeed:
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
# from the checkpoint eventually
deepspeed_engine, _, _ = deepspeed_init(
self, num_training_steps=0, resume_from_checkpoint=None, inference=True
)
self.model = deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine

model = self._wrap_model(self.model, training=False, dataloader=dataloader)

if not self.skip_eval_loss:
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
@@ -832,7 +832,7 @@ def evaluation_loop(

batch_size = self.args.eval_batch_size

logger.info(f"***** Running {description} *****")
logger.warning(f"***** Running {description} *****")
if has_length(dataloader):
logger.warning(f" Num examples = {self.num_examples(dataloader)}")
else:
@@ -983,6 +983,7 @@ def evaluation_loop(
if all_inputs is not None:
all_inputs = nested_truncate(all_inputs, num_samples)
else:
xm.rendezvous("eval_metrics")
all_losses, all_preds, all_labels, all_inputs, num_samples = None, None, None, None, 0

# Metrics!
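
The net effect of these hunks: DeepSpeed initialization and `_wrap_model` now run unconditionally, `skip_eval_loss` gates only the loss/prediction loop, and the skip path gains `xm.rendezvous("eval_metrics")` so TPU ranks stay aligned at a common barrier. A minimal sketch of the resulting control flow; the helper names (`wrap`, `run_loss_loop`) are illustrative, not from the repo:

```python
# Minimal sketch of the reordered evaluation_loop control flow (not the full method).
def evaluation_loop(trainer, dataloader, description):
    # Engine setup / model wrapping now happen on every rank, unconditionally.
    model = trainer.wrap(trainer.model, training=False, dataloader=dataloader)

    if not trainer.skip_eval_loss:
        # Only the costly prediction/loss loop is gated.
        return trainer.run_loss_loop(model, dataloader, description)

    # Skip path: reach the same named barrier the non-skip path presumably
    # hits during metric gathering, so no rank blocks forever.
    import torch_xla.core.xla_model as xm  # assumed TPU environment
    xm.rendezvous("eval_metrics")
    return None, None, None, None, 0  # losses, preds, labels, inputs, num_samples
```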
210 changes: 139 additions & 71 deletions wtpsplit/train/train_adapter.py
@@ -1,29 +1,32 @@
from dataclasses import dataclass
import copy
import logging
import sys
import math
import os
import copy
import random
import sys
from collections import Counter
from dataclasses import dataclass
from functools import partial
from glob import glob
from typing import List
from adapters import AdapterArguments
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
from wtpsplit.train.evaluate import evaluate_sentence
from wtpsplit.train.adaptertrainer import AdapterTrainer
from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
from wtpsplit.train.utils import Model
from wtpsplit.train.train import setup_logging, collate_fn
from wtpsplit.models import SubwordXLMForTokenClassification, SubwordXLMConfig
from tokenizers import AddedToken

import adapters
import datasets
import numpy as np
import math
from collections import Counter
import torch
import random
from tokenizers import AddedToken
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed

import adapters
import wandb
from glob import glob
from functools import partial
from adapters import AdapterArguments
from wtpsplit.evaluation.intrinsic import corrupt
from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
from wtpsplit.train.adaptertrainer import AdapterTrainer
from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise
from wtpsplit.train.train import collate_fn, setup_logging
from wtpsplit.train.utils import Model
from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
from tqdm import tqdm

logger = logging.getLogger(__name__)

@@ -58,6 +61,15 @@ class Args:
use_subwords: bool = False
freeze_classifier: bool = False
clf_from_scratch: bool = False
unfreeze_ln: bool = False
do_process: bool = False
meta_clf: bool = False
wandb_project: str = "sentence"
# corruption
do_lowercase: bool = False
do_remove_punct: bool = False
eval_pairwise: bool = False
skip_eval_loss: bool = False


def main():
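
For reference, the new corruption and evaluation switches ride the existing `HfArgumentParser` flow, so each boolean field surfaces as a CLI flag. A minimal sketch, trimmed to the newly added fields; the invocation is illustrative:

```python
# Minimal sketch: the new boolean dataclass fields become CLI flags via HfArgumentParser.
from dataclasses import dataclass
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class Args:  # trimmed to the newly added switches
    do_lowercase: bool = False
    do_remove_punct: bool = False
    eval_pairwise: bool = False
    skip_eval_loss: bool = False

parser = HfArgumentParser([Args, TrainingArguments])
# Equivalent to: python train_adapter.py --output_dir out --do_lowercase --do_remove_punct
args, training_args = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--do_lowercase", "--do_remove_punct"]
)
assert args.do_lowercase and args.do_remove_punct
```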
@@ -73,68 +85,52 @@ def main():
set_seed(training_args.seed)

num_labels = Constants.AUX_OFFSET + (
(1 + len(Constants.PUNCTUATION_CHARS)) if args.do_auxiliary_training or label_args.use_auxiliary else 0
(1 + len(Constants.PUNCTUATION_CHARS))
if (label_args.use_auxiliary or args.do_auxiliary_training or args.meta_clf)
else 0
)
config = SubwordXLMConfig.from_pretrained(
args.model_name_or_path,
num_labels=num_labels,
)

# since we pre-tokenize, running multiple epochs would iterate over data in same order
# hence, we duplicate & shuffle train data sentences in prepare_dataset
# and set num_train_epochs to 1 --> simulate multiple epochs, each with different sentence order
num_train_epochs = training_args.num_train_epochs

training_args.num_train_epochs = 1
training_args.evaluation_strategy = "steps"

def prepare_dataset(
data,
num_workers=1,
include_languages=None,
dataset_name="ud",
shuffle=False,
split="train",
do_lowercase=False,
do_remove_punct=False,
):
# maybe we use more than 1 lang later at once.
with training_args.main_process_first():
# maybe we use more than 1 lang later at once.
for lang in include_languages:
if split == "train":
dataset = data[lang]["sentence"][dataset_name]["meta"]["train_data"]
elif split == "valid":
dataset = data[lang]["sentence"][dataset_name]["data"]
data_list = []
if dataset is None:
return None
for sample in dataset:
ends_with_punctuation = sample.endswith(tuple(Constants.PUNCTUATION_CHARS))
data_list.append(
dataset = datasets.Dataset.from_list(
[
{
args.text_column: sample + "\n" if len(sample) > 0 and sample[-1] != "\n" else sample,
args.text_column: corrupt(sample, do_lowercase, do_remove_punct) + "\n"
if sample and sample[-1] != "\n"
else corrupt(sample, do_lowercase, do_remove_punct),
"lang": lang,
"ends_with_punctuation": ends_with_punctuation,
"ends_with_punctuation": sample.endswith(tuple(Constants.PUNCTUATION_CHARS)),
}
)
dataset = datasets.Dataset.from_list(data_list)
with training_args.main_process_first():
logger.warning(f"Loaded {len(dataset)} examples for {lang} {dataset_name} {split} dataset.")

if include_languages is not None:
include_languages = set(include_languages)

dataset = dataset.filter(
lambda example: example["lang"] in include_languages,
num_proc=args.preprocessing_num_workers,
)
for sample in dataset
]
)
with training_args.main_process_first():
logger.warning(f"Filtered to {len(dataset)} examples.")
logger.warning(f"Loaded {len(dataset)} examples for {lang} {dataset_name} {split} dataset.")

if shuffle:
# create n_epochs copies of the dataset and shuffle them individually
dataset = datasets.concatenate_datasets([dataset.shuffle(seed=i) for i in range(num_train_epochs)])

with training_args.main_process_first():
logger.warning(f"Shuffled dataset to {len(dataset)} examples.")
dataset = dataset.shuffle(seed=42)

# very likely not relevant / used only for the compound part
if args.ignore_non_hyphen:
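
Condensed, the rewritten `prepare_dataset` corrupts each sample, guarantees a trailing newline, builds the `datasets.Dataset` in a single `from_list`, and shuffles with a fixed seed. A reduced sketch of that flow; the inline `corrupt` here is an illustrative stand-in for `wtpsplit.evaluation.intrinsic.corrupt`, which may differ in detail:

```python
import datasets

def corrupt(text, do_lowercase, do_remove_punct):
    # Illustrative stand-in for wtpsplit.evaluation.intrinsic.corrupt.
    if do_lowercase:
        text = text.lower()
    if do_remove_punct:
        text = "".join(ch for ch in text if ch not in ".,;:!?")
    return text

def prepare(samples, lang, punct_chars, do_lowercase=False, do_remove_punct=False):
    rows = []
    for s in samples:
        c = corrupt(s, do_lowercase, do_remove_punct)
        rows.append({
            "text": c + "\n" if s and s[-1] != "\n" else c,  # ensure trailing newline
            "lang": lang,
            # punctuation flag is computed on the *uncorrupted* sample
            "ends_with_punctuation": s.endswith(tuple(punct_chars)),
        })
    return datasets.Dataset.from_list(rows).shuffle(seed=42)
```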
@@ -347,20 +343,21 @@ def maybe_pad(text):

# 1 wandb run for all language-dataset combinations
if "wandb" in training_args.report_to and training_args.process_index == 0:
wandb.init(name=wandb_name, project="sentence-peft")
wandb.init(name=wandb_name, project=args.wandb_project, group=wandb_name)
wandb.config.update(args)
wandb.config.update(training_args)
wandb.config.update(label_args)
wandb.config.update(adapter_args)

for file in glob(os.path.join(os.path.dirname(__file__), "*.py")):
wandb.save(os.path.abspath(file), policy="now")

for lang in data.keys():
for lang in tqdm(data.keys(), desc="Language"):
if lang in args.include_languages:
for dataset_name in data[lang]["sentence"].keys():
# do model stuff here; otherwise, head params would be overwritten every time
backbone = SubwordXLMForTokenClassification.from_pretrained(
args.model_name_or_path, config=config, ignore_mismatched_sizes=True
args.model_name_or_path, config=copy.deepcopy(config), ignore_mismatched_sizes=True
)
backbone.config.base_model = args.base_model
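
Switching to `copy.deepcopy(config)` matters because `from_pretrained` and the loop body mutate the config object per language-dataset pair (e.g. `backbone.config.base_model = ...`); reusing one shared instance would leak those mutations into later iterations. A minimal sketch of the pattern, using the repo's classes with an illustrative loop:

```python
# Sketch: a fresh config copy per iteration, so per-run mutations can't leak
# into later language-dataset pairs.
import copy

base_config = SubwordXLMConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels)
for lang in include_langs:  # illustrative loop variable
    backbone = SubwordXLMForTokenClassification.from_pretrained(
        args.model_name_or_path,
        config=copy.deepcopy(base_config),  # never hand out the shared instance
        ignore_mismatched_sizes=True,
    )
```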

@@ -397,6 +394,8 @@ def maybe_pad(text):
dataset_name=dataset_name,
shuffle=False,
split="valid",
do_lowercase=args.do_lowercase,
do_remove_punct=args.do_remove_punct,
)
logger.warning(f"Valid ds for {lang} {dataset_name} has {len(valid_dataset)} examples.")

@@ -407,22 +406,14 @@ def maybe_pad(text):
dataset_name=dataset_name,
shuffle=args.shuffle,
split="train",
do_lowercase=args.do_lowercase,
do_remove_punct=args.do_remove_punct,
)
if train_dataset is None or valid_dataset is None:
logger.warning(f"Skipping {lang} {dataset_name} due to missing data.")
continue
logger.warning(f"Train ds for {lang} {dataset_name} has {len(train_dataset)} examples.")

# eval every actual epoch, based on steps
training_args.eval_steps = (
len(train_dataset)
// (
training_args.per_device_train_batch_size
* training_args.gradient_accumulation_steps
* num_train_epochs
)
) + 1

# print some samples from the dataset
count = 0
while count < 1:
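
The removed block computed `eval_steps` so evaluation fired once per simulated epoch; the cadence now comes straight from `training_args.eval_steps`. For reference, the dropped formula with illustrative numbers (the dataset already contains `num_train_epochs` shuffled copies, hence the division):

```python
# Worked example of the dropped per-epoch eval cadence (values illustrative).
train_examples = 10_000      # already includes num_train_epochs shuffled copies
per_device_bs, grad_accum, num_epochs = 32, 2, 3
eval_steps = train_examples // (per_device_bs * grad_accum * num_epochs) + 1
print(eval_steps)            # 10_000 // 192 + 1 == 53
```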
@@ -446,15 +437,44 @@ def compute_metrics(trainer):
eval_data,
model,
stride=64,
block_size=512, ## TODO: change to args version x2?
block_size=512,
batch_size=training_args.per_device_eval_batch_size,
)
metrics[f"{dataset_name}/{lang}/pr_auc"] = score
metrics[f"{dataset_name}/{lang}/f1"] = info["f1"]
metrics[f"{dataset_name}/{lang}/f1_best"] = info["f1_best"]
metrics[f"{dataset_name}/{lang}/threshold_best"] = info["threshold_best"]
if args.do_lowercase and args.do_remove_punct:
score_corrupted, info_corrupted = evaluate_sentence(
lang,
eval_data,
model,
stride=64,
block_size=512,
batch_size=training_args.per_device_eval_batch_size,
do_lowercase=True,
do_remove_punct=True,
)
metrics[f"{lang}_{dataset_name}_pr_auc"] = score
metrics[f"{lang}_{dataset_name}_f1"] = info["f1"]
metrics[f"{lang}_{dataset_name}_f1_best"] = info["f1_best"]
metrics[f"{lang}_{dataset_name}_threshold_best"] = info["threshold_best"]
metrics[f"{dataset_name}/{lang}/corrupted/pr_auc"] = score_corrupted
metrics[f"{dataset_name}/{lang}/corrupted/f1"] = info_corrupted["f1"]
metrics[f"{dataset_name}/{lang}/corrupted/f1_best"] = info_corrupted["f1_best"]
metrics[f"{dataset_name}/{lang}/corrupted/threshold_best"] = info_corrupted["threshold_best"]
elif args.do_lowercase or args.do_remove_punct:
raise NotImplementedError("Currently we only corrupt both ways!")
if args.eval_pairwise:
score_pairwise, avg_acc = evaluate_sentence_pairwise(
lang,
eval_data,
model,
stride=args.eval_stride,
block_size=args.block_size,
batch_size=training_args.per_device_eval_batch_size,
threshold=0.1,
)
metrics[f"{dataset_name}/{lang}/pairwise/pr_auc"] = score_pairwise
metrics[f"{dataset_name}/{lang}/pairwise/acc"] = avg_acc

return metrics
return metrics

label_dict = (
get_subword_label_dict(label_args, tokenizer) if args.use_subwords else get_label_dict(label_args)
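
The metric keys change from `{lang}_{dataset_name}_*` to `{dataset_name}/{lang}/*`; W&B splits chart sections on `/`, so every language of a dataset now groups into one panel section, with `corrupted` and `pairwise` sub-groups. A tiny sketch of the resulting key layout (offline mode, illustrative values):

```python
# Sketch: slash-separated keys group into W&B panel sections.
import wandb

run = wandb.init(project="sentence", name="demo", mode="offline")  # offline: no login needed
wandb.log({
    "ud/en/f1": 0.93,
    "ud/en/pr_auc": 0.97,
    "ud/de/f1": 0.91,            # same "ud" section as the English panels
    "ud/de/corrupted/f1": 0.84,  # corrupted eval gets its own sub-group
})
run.finish()
```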
@@ -475,6 +495,19 @@ def compute_metrics(trainer):
if args.clf_from_scratch:
model.backbone.classifier = torch.nn.Linear(model.backbone.config.hidden_size, num_labels)

if args.unfreeze_ln:
for n, p in model.backbone.named_parameters():
if "LayerNorm" in n:
p.requires_grad = True

if args.meta_clf:
clf = model.backbone.classifier
model.backbone.classifier = torch.nn.Sequential(
clf, # original classifier - if frozen above, also frozen here
torch.nn.Linear(clf.out_features, 1),
)
model.backbone.config.num_labels = 1

trainer = AdapterTrainer(
model,
training_args,
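
`meta_clf` stacks a fresh 1-logit linear layer on top of the original classifier, so only the thin meta-head (plus adapters) trains while a frozen head keeps its pretrained weights. A standalone sketch of the same pattern with illustrative dimensions:

```python
import torch

hidden_size, num_labels = 768, 3
clf = torch.nn.Linear(hidden_size, num_labels)  # stands in for the pretrained head
for p in clf.parameters():                      # if frozen above, it stays frozen here
    p.requires_grad = False

meta_head = torch.nn.Sequential(
    clf,
    torch.nn.Linear(num_labels, 1),             # trainable 1-logit meta classifier
)
logits = meta_head(torch.randn(4, hidden_size))  # -> shape (4, 1)
```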
@@ -487,8 +520,10 @@ def compute_metrics(trainer):
label_args=label_args,
label_dict=label_dict,
tokenizer=tokenizer,
add_lang_ids=False,
),
logging_suffix=f"{lang}_{dataset_name}",
logging_prefix=f"{dataset_name}/{lang}/",
skip_eval_loss=args.skip_eval_loss,
)
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
with training_args.main_process_first():
@@ -501,6 +536,39 @@ def compute_metrics(trainer):
save_directory=os.path.join(training_args.output_dir, dataset_name, lang),
with_head=True,
)
if training_args.local_rank == 0:
# eval here within 1 go
if args.do_lowercase and args.do_remove_punct:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --do_lowercase --do_remove_punct"
)
elif args.eval_pairwise:
os.system(
f"python3 wtpsplit/evaluation/intrinsic_pairwise.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1"
)
elif "lines" in args.text_path:
if args.do_lowercase and args.do_remove_punct:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_lines.pt --save_suffix lines --do_lowercase --do_remove_punct"
)
else:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_lines.pt --save_suffix lines"
)
elif "verses" in args.text_path:
if args.do_lowercase and args.do_remove_punct:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n.pt --save_suffix verses --do_lowercase --do_remove_punct"
)
else:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n.pt --save_suffix verses"
)
else:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1"
)
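
These `os.system` calls ignore the child's exit status; an equivalent `subprocess.run` form with the same flags, sketched below with placeholder paths, would fail loudly instead:

```python
# Sketch: same invocation via subprocess.run, which raises on a non-zero exit.
import subprocess

model_path, adapter_path = "xlm-roberta-base", "out"  # illustrative placeholders
subprocess.run(
    [
        "python3", "wtpsplit/evaluation/intrinsic.py",
        "--model_path", model_path,
        "--adapter_path", adapter_path,
        "--threshold", "0.1",
    ],
    check=True,  # CalledProcessError on failure instead of a silently ignored status
)
```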


def _mp_fn(index):
# For xla_spawn (TPUs)
