add corruptions & domain setup
markus583 committed Mar 6, 2024
1 parent 1773c47 commit 45b892b
Showing 6 changed files with 169 additions and 15 deletions.
40 changes: 40 additions & 0 deletions configs/peft/adapter_lyrics.json
@@ -0,0 +1,40 @@
{
    "model_name_or_path": "xlmr-normal-p-v3",
    "output_dir": "xlmr-3l-v3_adapter_rf32_ep20_v2_100-1k-10k_lines",
    "block_size": 256,
    "eval_stride": 128,
    "do_train": true,
    "do_eval": true,
    "per_device_train_batch_size": 64,
    "per_device_eval_batch_size": 32,
    "gradient_accumulation_steps": 1,
    "eval_accumulation_steps": 8,
    "dataloader_num_workers": 1,
    "preprocessing_num_workers": 1,
    "learning_rate": 3e-4,
    "fp16": false,
    "num_train_epochs": 20,
    "logging_steps": 50,
    "report_to": "wandb",
    "wandb_project": "lyrics-peft",
    "save_steps": 100000000,
    "remove_unused_columns": false,
    "one_sample_per_line": false,
    "do_sentence_training": true,
    "do_auxiliary_training": false,
    "warmup_ratio": 0.1,
    "non_punctuation_sample_ratio": null,
    "prediction_loss_only": true,
    "use_auxiliary": true,
    "ddp_timeout": 3600,
    "use_subwords": true,
    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
    "log_level": "warning",
    "adapter_config": "seq_bn[reduction_factor=32]",
    "weight_decay": 0.0,
    "auxiliary_remove_prob": 0.0,
    "do_process": true,
    "n_train_steps": [100, 1000, 10000],
    "do_lowercase": true,
    "do_remove_punct": true
}
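
For reference, the `adapter_config` value above uses the `adapters` library's configuration-string syntax, where bracketed key-value pairs override fields of a named base config. A minimal sketch of how such a string resolves (assuming the `adapters` package; this snippet is illustrative and not part of the commit):

from adapters import AdapterConfig

# "seq_bn" names a sequential bottleneck adapter config; the bracket
# syntax overrides individual fields of that base config.
config = AdapterConfig.load("seq_bn[reduction_factor=32]")
print(config.reduction_factor)  # -> 32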
41 changes: 41 additions & 0 deletions configs/peft/adapter_pairwise.json
@@ -0,0 +1,41 @@
{
    "model_name_or_path": "xlmr-normal-p-v3",
    "output_dir": "xlmr-3l-v3_adapter_rf32_ep20_v2_100-1k-10k_pairwise_bs32",
    "block_size": 32,
    "eval_stride": 16,
    "do_train": true,
    "do_eval": true,
    "per_device_train_batch_size": 64,
    "per_device_eval_batch_size": 32,
    "gradient_accumulation_steps": 1,
    "eval_accumulation_steps": 8,
    "dataloader_num_workers": 1,
    "preprocessing_num_workers": 1,
    "learning_rate": 3e-4,
    "fp16": false,
    "num_train_epochs": 20,
    "logging_steps": 50,
    "report_to": "wandb",
    "wandb_project": "pairwise-peft",
    "save_steps": 100000000,
    "remove_unused_columns": false,
    "one_sample_per_line": false,
    "do_sentence_training": true,
    "do_auxiliary_training": false,
    "warmup_ratio": 0.1,
    "non_punctuation_sample_ratio": null,
    "prediction_loss_only": true,
    "use_auxiliary": true,
    "ddp_timeout": 3600,
    "use_subwords": true,
    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
    "log_level": "warning",
    "adapter_config": "seq_bn[reduction_factor=32]",
    "weight_decay": 0.0,
    "auxiliary_remove_prob": 0.0,
    "do_process": true,
    "n_train_steps": [100, 1000, 10000],
    "do_lowercase": true,
    "do_remove_punct": true,
    "eval_pairwise": true
}
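
Both configs enable `do_lowercase` and `do_remove_punct`, the corruption flags this commit threads through training and evaluation. A rough sketch of what such a corruption step amounts to (a hypothetical stand-in; the actual `corrupt` imported from `wtpsplit.evaluation.intrinsic` may handle more than ASCII punctuation):

import string

def corrupt(text: str, do_lowercase: bool = True, do_remove_punct: bool = True) -> str:
    # hypothetical stand-in for wtpsplit.evaluation.intrinsic.corrupt
    if do_lowercase:
        text = text.lower()
    if do_remove_punct:
        text = text.translate(str.maketrans("", "", string.punctuation))
    return text

print(corrupt("Hello, World!"))  # -> "hello world"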
46 changes: 44 additions & 2 deletions wtpsplit/evaluation/intrinsic_pairwise.py
@@ -6,6 +6,7 @@
import time
import random
import sys
import logging

import h5py
import skops.io as sio

Check failure (GitHub Actions / build, Python 3.8 and 3.11): Ruff (F401) at wtpsplit/evaluation/intrinsic_pairwise.py:12:20: `skops.io` imported but unused
@@ -14,6 +15,7 @@
from tqdm.auto import tqdm
from transformers import AutoModelForTokenClassification, HfArgumentParser
import numpy as np
import adapters

import wtpsplit.models # noqa: F401
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
Expand All @@ -22,10 +24,13 @@
from wtpsplit.utils import Constants
from wtpsplit.evaluation.intrinsic import compute_statistics, corrupt

logger = logging.getLogger()
logger.setLevel(logging.INFO)

@dataclass
class Args:
    model_path: str
    adapter_path: str = None
    # eval data in the format:
    # {
    #     "<lang_code>": {
@@ -162,6 +167,25 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st

        # eval data
        for dataset_name, dataset in eval_data[lang_code]["sentence"].items():
            try:
                if args.adapter_path:
                    model.model.load_adapter(
                        args.adapter_path + "/" + dataset_name + "/" + lang_code,
                        set_active=True,
                        with_head=True,
                        load_as="text",
                    )
                    if hasattr(model.model.config, "unfreeze_ln"):
                        if model.model.config.unfreeze_ln:
                            ln_dict = torch.load(
                                args.adapter_path + "/" + dataset_name + "/" + lang_code + "/ln_dict.pth"
                            )
                            for n, p in model.backbone.named_parameters():
                                if "LayerNorm" in n:
                                    p.data = ln_dict[n].data
            except Exception as e:
                print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
                continue
            print(dataset_name)
            if dataset_name not in lang_group:
                dset_group = lang_group.create_group(dataset_name)
@@ -246,7 +270,12 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st


def main(args):
    save_str = f"{args.model_path.replace('/','_')}_b{args.block_size}_u{args.threshold}{args.save_suffix}"
    save_model_path = args.model_path
    if args.adapter_path:
        save_model_path = args.adapter_path
    save_str = (
        f"{save_model_path.replace('/','_')}_b{args.block_size}_u{args.threshold}{args.save_suffix}"
    )
    if args.do_lowercase:
        save_str += "_lc"
    if args.do_remove_punct:
@@ -260,7 +289,20 @@ def main(args):

print("Loading model...")
model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))

if args.adapter_path:
model_type = model.model.config.model_type
# adapters need xlm-roberta as model type.
model.model.config.model_type = "xlm-roberta"
adapters.init(model.model)
# reset model type (used later)
model.model.config.model_type = model_type
if "meta-clf" in args.adapter_path:
clf = model.model.classifier
model.model.classifier = torch.nn.Sequential(
clf,
torch.nn.Linear(clf.out_features, 1)
)

# first, logits for everything.
f, total_test_time = load_or_compute_logits(args, model, eval_data, valid_data, save_str)

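Two details in the file above are easy to miss: `adapters.init` dispatches on `config.model_type`, so the custom backbone briefly masquerades as `xlm-roberta`, and adapters are stored one per dataset-language pair. A condensed standalone sketch of that loading path (paths are hypothetical; assumes the `adapters` package):

import adapters
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("xlmr-normal-p-v3")

# adapters.init only recognizes stock architectures, so temporarily
# present the custom model as xlm-roberta, then restore the real type
original_type = model.config.model_type
model.config.model_type = "xlm-roberta"
adapters.init(model)
model.config.model_type = original_type

# one adapter (plus classification head) per <adapter_path>/<dataset_name>/<lang_code>
model.load_adapter(
    "adapters_out/ersatz/en",  # hypothetical path
    set_active=True,
    with_head=True,
    load_as="text",
)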
6 changes: 3 additions & 3 deletions wtpsplit/train/evaluate.py
@@ -37,7 +37,7 @@ def compute_f1(pred, true):
    )


def get_metrics(labels, preds, threshold: float = 0.5):
def get_metrics(labels, preds, threshold: float = 0.1):
    # Compute precision-recall curve and AUC
    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, preds)
    pr_auc = sklearn.metrics.auc(recall, precision)
@@ -141,6 +141,7 @@ def evaluate_sentence_pairwise(
    positive_index=None,
    do_lowercase=False,
    do_remove_punct=False,
    threshold: float = 0.1
):
    if positive_index is None:
        positive_index = Constants.NEWLINE_INDEX
@@ -173,7 +174,6 @@
    )

    # simulate performance for WtP-U
    DEFAULT_THRESHOLD = 0.01

    for i, (sentence1, sentence2) in enumerate(sampled_pairs):
        newline_probs = logits[i][:, positive_index]
@@ -188,7 +188,7 @@
        # Get metrics for the pair
        pair_metrics, _ = get_metrics(newline_labels, newline_probs)
        metrics_list.append(pair_metrics["pr_auc"])
        predicted_labels = newline_probs > np.log(DEFAULT_THRESHOLD / (1 - DEFAULT_THRESHOLD))  # inverse sigmoid
        predicted_labels = newline_probs > np.log(threshold / (1 - threshold))  # inverse sigmoid
        # for accuracy, check if the single label in between is correctly predicted (ignore the one at the end)
        if sum(predicted_labels[:-1]) > 0:
            correct = (np.where(newline_labels[:-1])[0] == np.where(predicted_labels[:-1])[0]).all()
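
The `np.log(threshold / (1 - threshold))` term is the inverse sigmoid (logit) of the threshold, so comparing raw logits against it is equivalent to thresholding sigmoid probabilities without ever materializing them. A quick standalone check of that equivalence:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

threshold = 0.1
logits = np.array([-4.0, -1.0, 0.0, 2.0])

direct = sigmoid(logits) > threshold                       # threshold the probabilities
via_logits = logits > np.log(threshold / (1 - threshold))  # threshold the raw logits

assert (direct == via_logits).all()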
45 changes: 38 additions & 7 deletions wtpsplit/train/train_adapter_parallel.py
@@ -32,7 +32,7 @@
    ParallelTPUWandbCallback as WandbCallback,
)
from wtpsplit.train.adaptertrainer import AdapterTrainer
from wtpsplit.train.evaluate import evaluate_sentence
from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise
from wtpsplit.train.train import collate_fn
from wtpsplit.train.utils import Model
from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
@@ -118,9 +118,11 @@ class Args:
    do_process: bool = False
    n_train_steps: List[int] = field(default_factory=lambda: [1000, 10000, 100000])
    meta_clf: bool = False
    wandb_project = "sentence"
    # corruption
    do_lowercase: bool = False
    do_remove_punct: bool = False
    eval_pairwise: bool = False


def main(
@@ -151,7 +153,7 @@ def main(

    # 1 wandb run for all language-dataset combinations
    if "wandb" in training_args.report_to:
        wandb.init(name=f"{wandb_name}-{tpu_core_idx}", project="sentence-peft-v2", group=wandb_name)
        wandb.init(name=f"{wandb_name}-{tpu_core_idx}", project=args.wandb_project, group=wandb_name)
        wandb.config.update(args)
        wandb.config.update(training_args)
        wandb.config.update(label_args)
@@ -207,7 +209,7 @@ def main(
        index = random.choice(range(len(train_ds[(lang, dataset_name)])))
        sample = train_ds[(lang, dataset_name)][index]

        logger.warning(f"TPU {tpu_core_idx}: Sample {index} of the training set: {sample}.")
        logger.warning(f"{tpu_core_idx}: Sample {index} of the training set: {sample}.")
        if tokenizer:
            logger.warning(tokenizer.decode(sample["input_ids"]))
        count += 1
@@ -247,6 +249,18 @@ def compute_metrics(trainer):
metrics[f"{dataset_name}/{lang}/corrupted/threshold_best"] = info_corrupted["threshold_best"]
elif args.do_lowercase or args.do_remove_punct:
raise NotImplementedError("Currently we only corrupt both ways!")
if args.eval_pairwise:
score_pairwise, avg_acc = evaluate_sentence_pairwise(
lang,
eval_data,
model,
stride=args.eval_stride,
block_size=args.block_size,
batch_size=training_args.per_device_eval_batch_size,
threshold=0.1,
)
metrics[f"{dataset_name}/{lang}/pairwise/pr_auc"] = score_pairwise
metrics[f"{dataset_name}/{lang}/pairwise/acc"] = avg_acc
xm.rendezvous("eval log done")

return metrics
@@ -312,7 +326,7 @@ def compute_metrics(trainer):
                save_directory=os.path.join(training_args.output_dir, dataset_name, lang),
                with_head=True,
            )
            # also save LNs

            if args.unfreeze_ln:
                # no way within adapters to do this, need to do it manually
                ln_dict = {n: p for n, p in save_model.named_parameters() if "LayerNorm" in n}
@@ -715,9 +729,26 @@ def merge_dict_into_sublists(d):
xm.rendezvous("all training done")
if index == 0:
# eval here within 1 go
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.5"
)
if args.do_lowercase and args.do_remove_punct:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --do_lowercase --do_remove_punct"
)
elif args.eval_pairwise:
os.system(
f"python3 wtpsplit/evaluation/intrinsic_pairwise.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1"
)
elif "lines" in args.text_path:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1--custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_lines.pt --save_suffix lines"
)
elif "verses" in args.text_path:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n.pt --save_suffix verses"
)
else:
os.system(
f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1"
)


if __name__ == "__main__":
6 changes: 3 additions & 3 deletions wtpsplit/utils.py
@@ -114,7 +114,7 @@ def get_subword_label_dict(label_args, tokenizer):
    for i, c in enumerate(Constants.PUNCTUATION_CHARS):
        token_id = tokenizer.convert_tokens_to_ids(c)
        label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
        logger.info(
        logger.warning(
            f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}"
        )
        if token_id == tokenizer.unk_token_id:
@@ -126,8 +126,8 @@ def get_subword_label_dict(label_args, tokenizer):
    for c in label_args.newline_chars:
        token_id = tokenizer.convert_tokens_to_ids(c)
        label_dict[token_id] = 1 + Constants.NEWLINE_INDEX
        logger.info(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
        logger.info(f"{tokenizer.decode([token_id])}")
        logger.warning(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
        logger.warning(f"{tokenizer.decode([token_id])}")

    return label_dict

