diff --git a/configs/peft/adapter_lyrics.json b/configs/peft/adapter_lyrics.json
new file mode 100644
index 00000000..218e8fc6
--- /dev/null
+++ b/configs/peft/adapter_lyrics.json
@@ -0,0 +1,40 @@
+{
+    "model_name_or_path": "xlmr-normal-p-v3",
+    "output_dir": "xlmr-3l-v3_adapter_rf32_ep20_v2_100-1k-10k_lines",
+    "block_size": 256,
+    "eval_stride": 128,
+    "do_train": true,
+    "do_eval": true,
+    "per_device_train_batch_size": 64,
+    "per_device_eval_batch_size": 32,
+    "gradient_accumulation_steps": 1,
+    "eval_accumulation_steps": 8,
+    "dataloader_num_workers": 1,
+    "preprocessing_num_workers": 1,
+    "learning_rate": 3e-4,
+    "fp16": false,
+    "num_train_epochs": 20,
+    "logging_steps": 50,
+    "report_to": "wandb",
+    "wandb_project": "lyrics-peft",
+    "save_steps": 100000000,
+    "remove_unused_columns": false,
+    "one_sample_per_line": false,
+    "do_sentence_training": true,
+    "do_auxiliary_training": false,
+    "warmup_ratio": 0.1,
+    "non_punctuation_sample_ratio": null,
+    "prediction_loss_only": true,
+    "use_auxiliary": true,
+    "ddp_timeout": 3600,
+    "use_subwords": true,
+    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
+    "log_level": "warning",
+    "adapter_config": "seq_bn[reduction_factor=32]",
+    "weight_decay": 0.0,
+    "auxiliary_remove_prob": 0.0,
+    "do_process": true,
+    "n_train_steps": [100, 1000, 10000],
+    "do_lowercase": true,
+    "do_remove_punct": true
+}
\ No newline at end of file
diff --git a/configs/peft/adapter_pairwise.json b/configs/peft/adapter_pairwise.json
new file mode 100644
index 00000000..70dd60fc
--- /dev/null
+++ b/configs/peft/adapter_pairwise.json
@@ -0,0 +1,41 @@
+{
+    "model_name_or_path": "xlmr-normal-p-v3",
+    "output_dir": "xlmr-3l-v3_adapter_rf32_ep20_v2_100-1k-10k_pairwise_bs32",
+    "block_size": 32,
+    "eval_stride": 16,
+    "do_train": true,
+    "do_eval": true,
+    "per_device_train_batch_size": 64,
+    "per_device_eval_batch_size": 32,
+    "gradient_accumulation_steps": 1,
+    "eval_accumulation_steps": 8,
+    "dataloader_num_workers": 1,
+    "preprocessing_num_workers": 1,
+    "learning_rate": 3e-4,
+    "fp16": false,
+    "num_train_epochs": 20,
+    "logging_steps": 50,
+    "report_to": "wandb",
+    "wandb_project": "pairwise-peft",
+    "save_steps": 100000000,
+    "remove_unused_columns": false,
+    "one_sample_per_line": false,
+    "do_sentence_training": true,
+    "do_auxiliary_training": false,
+    "warmup_ratio": 0.1,
+    "non_punctuation_sample_ratio": null,
+    "prediction_loss_only": true,
+    "use_auxiliary": true,
+    "ddp_timeout": 3600,
+    "use_subwords": true,
+    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
+    "log_level": "warning",
+    "adapter_config": "seq_bn[reduction_factor=32]",
+    "weight_decay": 0.0,
+    "auxiliary_remove_prob": 0.0,
+    "do_process": true,
+    "n_train_steps": [100, 1000, 10000],
+    "do_lowercase": true,
+    "do_remove_punct": true,
+    "eval_pairwise": true
+}
\ No newline at end of file
diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py
index a9418a6b..cbf19e1b 100644
--- a/wtpsplit/evaluation/intrinsic_pairwise.py
+++ b/wtpsplit/evaluation/intrinsic_pairwise.py
@@ -6,6 +6,7 @@
 import time
 import random
 import sys
+import logging
 
 import h5py
 import skops.io as sio
@@ -14,6 +15,7 @@
 from tqdm.auto import tqdm
 from transformers import AutoModelForTokenClassification, HfArgumentParser
 import numpy as np
+import adapters
 
 import wtpsplit.models  # noqa: F401
 from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
@@ -22,10 +24,13 @@
 from wtpsplit.utils import Constants
 from wtpsplit.evaluation.intrinsic import compute_statistics, corrupt
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
 
 
 @dataclass
 class Args:
     model_path: str
+    adapter_path: str = None
     # eval data in the format:
     # {
     #     "<lang_code>": {
@@ -162,6 +167,25 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
 
             # eval data
             for dataset_name, dataset in eval_data[lang_code]["sentence"].items():
+                try:
+                    if args.adapter_path:
+                        model.model.load_adapter(
+                            args.adapter_path + "/" + dataset_name + "/" + lang_code,
+                            set_active=True,
+                            with_head=True,
+                            load_as="text",
+                        )
+                        if hasattr(model.model.config, "unfreeze_ln"):
+                            if model.model.config.unfreeze_ln:
+                                ln_dict = torch.load(
+                                    args.adapter_path + "/" + dataset_name + "/" + lang_code + "/ln_dict.pth"
+                                )
+                                for n, p in model.backbone.named_parameters():
+                                    if "LayerNorm" in n:
+                                        p.data = ln_dict[n].data
+                except Exception as e:
+                    print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
+                    continue
                 print(dataset_name)
                 if dataset_name not in lang_group:
                     dset_group = lang_group.create_group(dataset_name)
@@ -246,7 +270,12 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
 
 
 def main(args):
-    save_str = f"{args.model_path.replace('/','_')}_b{args.block_size}_u{args.threshold}{args.save_suffix}"
+    save_model_path = args.model_path
+    if args.adapter_path:
+        save_model_path = args.adapter_path
+    save_str = (
+        f"{save_model_path.replace('/','_')}_b{args.block_size}_u{args.threshold}{args.save_suffix}"
+    )
     if args.do_lowercase:
         save_str += "_lc"
     if args.do_remove_punct:
@@ -260,7 +289,20 @@ def main(args):
 
     print("Loading model...")
     model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))
-
+    if args.adapter_path:
+        model_type = model.model.config.model_type
+        # adapters need xlm-roberta as model type.
+        model.model.config.model_type = "xlm-roberta"
+        adapters.init(model.model)
+        # reset model type (used later)
+        model.model.config.model_type = model_type
+        if "meta-clf" in args.adapter_path:
+            clf = model.model.classifier
+            model.model.classifier = torch.nn.Sequential(
+                clf,
+                torch.nn.Linear(clf.out_features, 1)
+            )
+
     # first, logits for everything.
     f, total_test_time = load_or_compute_logits(args, model, eval_data, valid_data, save_str)
 
diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py
index f3927165..c2558ff6 100644
--- a/wtpsplit/train/evaluate.py
+++ b/wtpsplit/train/evaluate.py
@@ -37,7 +37,7 @@ def compute_f1(pred, true):
     )
 
 
-def get_metrics(labels, preds, threshold: float = 0.5):
+def get_metrics(labels, preds, threshold: float = 0.1):
     # Compute precision-recall curve and AUC
     precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, preds)
     pr_auc = sklearn.metrics.auc(recall, precision)
@@ -141,6 +141,7 @@ def evaluate_sentence_pairwise(
     positive_index=None,
     do_lowercase=False,
     do_remove_punct=False,
+    threshold: float = 0.1
 ):
     if positive_index is None:
         positive_index = Constants.NEWLINE_INDEX
@@ -173,7 +174,6 @@ def evaluate_sentence_pairwise(
     )
 
     # simulate performance for WtP-U
-    DEFAULT_THRESHOLD = 0.01
     for i, (sentence1, sentence2) in enumerate(sampled_pairs):
         newline_probs = logits[i][:, positive_index]
 
@@ -188,7 +188,7 @@ def evaluate_sentence_pairwise(
         # Get metrics for the pair
         pair_metrics, _ = get_metrics(newline_labels, newline_probs)
         metrics_list.append(pair_metrics["pr_auc"])
-        predicted_labels = newline_probs > np.log(DEFAULT_THRESHOLD / (1 - DEFAULT_THRESHOLD))  # inverse sigmoid
+        predicted_labels = newline_probs > np.log(threshold / (1 - threshold))  # inverse sigmoid
         # for accuracy, check if the single label in between is correctly predicted (ignore the one at the end)
         if sum(predicted_labels[:-1]) > 0:
             correct = (np.where(newline_labels[:-1])[0] == np.where(predicted_labels[:-1])[0]).all()
diff --git a/wtpsplit/train/train_adapter_parallel.py b/wtpsplit/train/train_adapter_parallel.py
index d8613234..f56a9d83 100644
--- a/wtpsplit/train/train_adapter_parallel.py
+++ b/wtpsplit/train/train_adapter_parallel.py
@@ -32,7 +32,7 @@
     ParallelTPUWandbCallback as WandbCallback,
 )
 from wtpsplit.train.adaptertrainer import AdapterTrainer
-from wtpsplit.train.evaluate import evaluate_sentence
+from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise
 from wtpsplit.train.train import collate_fn
 from wtpsplit.train.utils import Model
 from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
@@ -118,9 +118,11 @@ class Args:
     do_process: bool = False
     n_train_steps: List[int] = field(default_factory=lambda: [1000, 10000, 100000])
     meta_clf: bool = False
+    wandb_project: str = "sentence"
     # corruption
     do_lowercase: bool = False
     do_remove_punct: bool = False
+    eval_pairwise: bool = False
 
 
 def main(
@@ -151,7 +153,7 @@ def main(
 
     # 1 wandb run for all language-dataset combinations
     if "wandb" in training_args.report_to:
-        wandb.init(name=f"{wandb_name}-{tpu_core_idx}", project="sentence-peft-v2", group=wandb_name)
+        wandb.init(name=f"{wandb_name}-{tpu_core_idx}", project=args.wandb_project, group=wandb_name)
         wandb.config.update(args)
         wandb.config.update(training_args)
         wandb.config.update(label_args)
@@ -207,7 +209,7 @@ def main(
            index = random.choice(range(len(train_ds[(lang, dataset_name)])))
            sample = train_ds[(lang, dataset_name)][index]
 
-           logger.warning(f"TPU {tpu_core_idx}: Sample {index} of the training set: {sample}.")
+           logger.warning(f"{tpu_core_idx}: Sample {index} of the training set: {sample}.")
            if tokenizer:
                logger.warning(tokenizer.decode(sample["input_ids"]))
            count += 1
@@ -247,6 +249,18 @@ def compute_metrics(trainer):
                metrics[f"{dataset_name}/{lang}/corrupted/threshold_best"] = info_corrupted["threshold_best"]
            elif args.do_lowercase or args.do_remove_punct:
                raise NotImplementedError("Currently we only corrupt both ways!")
+           if args.eval_pairwise:
+               score_pairwise, avg_acc = evaluate_sentence_pairwise(
+                   lang,
+                   eval_data,
+                   model,
+                   stride=args.eval_stride,
+                   block_size=args.block_size,
+                   batch_size=training_args.per_device_eval_batch_size,
+                   threshold=0.1,
+               )
+               metrics[f"{dataset_name}/{lang}/pairwise/pr_auc"] = score_pairwise
+               metrics[f"{dataset_name}/{lang}/pairwise/acc"] = avg_acc
 
            xm.rendezvous("eval log done")
            return metrics
@@ -312,7 +326,7 @@ def compute_metrics(trainer):
                save_directory=os.path.join(training_args.output_dir, dataset_name, lang),
                with_head=True,
            )
-           # also save LNs
+
            if args.unfreeze_ln:
                # no way within adapters to do this, need to do it manually
                ln_dict = {n: p for n, p in save_model.named_parameters() if "LayerNorm" in n}
@@ -715,9 +729,26 @@ def merge_dict_into_sublists(d):
     xm.rendezvous("all training done")
     if index == 0:
         # eval here within 1 go
-        os.system(
-            f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.5"
-        )
+        if args.do_lowercase and args.do_remove_punct:
+            os.system(
+                f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --do_lowercase --do_remove_punct"
+            )
+        elif args.eval_pairwise:
+            os.system(
+                f"python3 wtpsplit/evaluation/intrinsic_pairwise.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1"
+            )
+        elif "lines" in args.text_path:
+            os.system(
+                f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_lines.pt --save_suffix lines"
+            )
+        elif "verses" in args.text_path:
+            os.system(
+                f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n.pt --save_suffix verses"
+            )
+        else:
+            os.system(
+                f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1"
+            )
 
 
 if __name__ == "__main__":
diff --git a/wtpsplit/utils.py b/wtpsplit/utils.py
index 14d652fe..cf2af9e1 100644
--- a/wtpsplit/utils.py
+++ b/wtpsplit/utils.py
@@ -114,7 +114,7 @@ def get_subword_label_dict(label_args, tokenizer):
     for i, c in enumerate(Constants.PUNCTUATION_CHARS):
         token_id = tokenizer.convert_tokens_to_ids(c)
         label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
-        logger.info(
+        logger.warning(
            f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}"
        )
        if token_id == tokenizer.unk_token_id:
@@ -126,8 +126,8 @@ def get_subword_label_dict(label_args, tokenizer):
 
     for c in label_args.newline_chars:
         token_id = tokenizer.convert_tokens_to_ids(c)
         label_dict[token_id] = 1 + Constants.NEWLINE_INDEX
-        logger.info(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
-        logger.info(f"{tokenizer.decode([token_id])}")
+        logger.warning(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
+        logger.warning(f"{tokenizer.decode([token_id])}")
 
     return label_dict
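
Note on the 0.1 thresholds introduced above: evaluate_sentence_pairwise now compares the raw newline logits against np.log(threshold / (1 - threshold)), i.e. the probability threshold mapped into logit space via the inverse sigmoid, so no sigmoid has to be applied during evaluation. A minimal standalone sketch of that equivalence (plain NumPy, not code from this repository; the logit values are made up for illustration):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

threshold = 0.1  # probability-space threshold, as passed via --threshold above
logit_threshold = np.log(threshold / (1 - threshold))  # inverse sigmoid, ~ -2.197

logits = np.array([-4.0, -2.0, 0.5])  # hypothetical newline logits
# Thresholding logits at log(t / (1 - t)) gives the same decisions as
# thresholding sigmoid(logits) at t, because the sigmoid is strictly monotonic.
assert np.array_equal(logits > logit_threshold, sigmoid(logits) > threshold)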