Commit

Merge branch 'efficient-transfer' of https://github.com/bminixhofer/wtpsplit into efficient-transfer
markus583 committed Apr 4, 2024
2 parents 284e925 + 31254c3 commit 888c1c9
Showing 3 changed files with 50 additions and 28 deletions.
16 changes: 9 additions & 7 deletions configs/peft/adapter_lyrics.json
@@ -1,6 +1,6 @@
 {
-    "model_name_or_path": "xlmr-normal-p-v3",
-    "output_dir": "xlmr-3l-v3_adapter_rf32_ep30_v2_corrupted_np_no-sl",
+    "model_name_or_path": "xlmr-12l-v3",
+    "output_dir": "xlmr-12l-v3_adapter_rf4_ep30_v2_mldbW-verses-S_fs_ss0.1",
     "block_size": 256,
     "eval_stride": 128,
     "do_train": true,
@@ -20,7 +20,7 @@
     "wandb_project": "lyrics-peft",
     "save_steps": 100000000,
     "remove_unused_columns": false,
-    "one_sample_per_line": false,
+    "one_sample_per_line": true,
     "do_sentence_training": true,
     "do_auxiliary_training": false,
     "warmup_ratio": 0.1,
@@ -31,13 +31,15 @@
     "use_subwords": true,
     "custom_punctuation_file": "punctuation_xlmr_unk.txt",
     "log_level": "warning",
-    "adapter_config": "seq_bn[reduction_factor=32]",
+    "adapter_config": "seq_bn[reduction_factor=4]",
     "weight_decay": 0.0,
     "auxiliary_remove_prob": 0.0,
     "do_process": true,
-    "text_path": "data/lyrics_lines.pt",
+    "text_path": "data/mldbW_verses_strip_n_strip_single_f.pt",
     "skip_eval_loss": false,
     "shuffle": false,
-    "do_remove_punct": true,
-    "do_lowercase": true
+    "do_remove_punct": false,
+    "do_lowercase": false,
+    "train_adapter": true,
+    "subsample": 0.1
 }
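
The new "subsample": 0.1 key trains the adapter on 10% of the data; its semantics come from the prepare_dataset change shown further down. A minimal standalone sketch of that logic (the helper name subsample_size is hypothetical):

    # An int keeps an absolute number of examples, capped at the dataset size;
    # a float keeps a fraction of the dataset.
    def subsample_size(subsample, n_examples):
        if isinstance(subsample, int):
            return min(subsample, n_examples)
        return int(subsample * n_examples)

    assert subsample_size(0.1, 5000) == 500      # "subsample": 0.1 keeps 10%
    assert subsample_size(10000, 5000) == 5000   # an oversized int is capped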
21 changes: 15 additions & 6 deletions wtpsplit/evaluation/intrinsic_list.py
@@ -127,30 +127,39 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st

         # eval data
         for dataset_name, dataset in eval_data[lang_code]["sentence"].items():
+            # train on all mldb, eval on mldbW
+            if "mldbW" in args.eval_data_path and (
+                "mldbW" not in args.model_path and "mldbW" not in args.adapter_path
+            ):
+                dataset_load_name = "unk"
+            else:
+                dataset_load_name = dataset_name
             try:
                 if args.adapter_path:
                     model.model.load_adapter(
-                        args.adapter_path + "/" + dataset_name + "/" + lang_code,
+                        args.adapter_path + "/" + dataset_load_name + "/" + lang_code,
                         set_active=True,
                         with_head=True,
                         load_as="text",
                     )
                     if hasattr(model.model.config, "unfreeze_ln"):
                         if model.model.config.unfreeze_ln:
                             ln_dict = torch.load(
-                                args.adapter_path + "/" + dataset_name + "/" + lang_code + "/ln_dict.pth"
+                                args.adapter_path + "/" + dataset_load_name + "/" + lang_code + "/ln_dict.pth"
                             )
                             for n, p in model.backbone.named_parameters():
                                 if "LayerNorm" in n:
                                     p.data = ln_dict[n].data
                 if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")):
-                    model_path = os.path.join(args.model_path, dataset_name, "en")
+                    model_path = os.path.join(args.model_path, dataset_load_name, "en")
                     print(model_path)
-                    model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
+                    model = PyTorchWrapper(
+                        AutoModelForTokenClassification.from_pretrained(model_path).to(args.device)
+                    )
             except Exception as e:
-                print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
+                print(f"Error loading adapter for {dataset_load_name} in {lang_code}: {e}")
                 continue
-            print(dataset_name)
+            print(dataset_name, dataset_load_name)
             if dataset_name not in lang_group:
                 dset_group = lang_group.create_group(dataset_name)
             else:
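
The added block lets a model or adapter trained on the full mldb data be evaluated on mldbW datasets: when the eval data path contains "mldbW" but neither the model path nor the adapter path does, the adapter stored under the generic "unk" directory is loaded instead of one named after the eval dataset. A minimal sketch of that routing (the function name and example values are hypothetical):

    def resolve_load_name(eval_data_path, model_path, adapter_path, dataset_name):
        # train on all mldb, eval on mldbW: fall back to the generic "unk" adapter
        if "mldbW" in eval_data_path and (
            "mldbW" not in model_path and "mldbW" not in adapter_path
        ):
            return "unk"
        return dataset_name

    # mldbW eval data, but model and adapter were not trained on mldbW -> "unk"
    assert resolve_load_name("data/mldbW_verses.pt", "xlmr-12l-v3", "adapters/mldb", "verses") == "unk"
    # otherwise the dataset's own adapter directory is used
    assert resolve_load_name("data/lyrics_lines.pt", "xlmr-12l-v3", "adapters/mldb", "lines") == "lines"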
41 changes: 26 additions & 15 deletions wtpsplit/train/train_adapter.py
@@ -28,6 +28,7 @@
 from wtpsplit.train.utils import Model
 from wtpsplit.utils import Constants, LabelArgs, get_label_dict, get_subword_label_dict
 from tqdm import tqdm
+from typing import Union, Optional

 logger = logging.getLogger(__name__)

@@ -71,6 +72,7 @@ class Args:
     do_remove_punct: bool = False
     eval_pairwise: bool = False
     skip_eval_loss: bool = False
+    subsample: Optional[float] = None


 def main():
@@ -104,6 +106,7 @@ def prepare_dataset(
         split="train",
         do_lowercase=False,
         do_remove_punct=False,
+        subsample: Union[None, int, float] = None
     ):
         # maybe we use more than 1 lang later at once.
         with training_args.main_process_first():
@@ -154,6 +157,15 @@

         if shuffle:
             dataset = dataset.shuffle(seed=42)
+        if subsample:
+            old_length = len(dataset)
+            if isinstance(subsample, int):
+                # ensure that we don't try to select more than the dataset length
+                subsample = min(subsample, len(dataset))
+                dataset = dataset.select(range(subsample))
+            elif isinstance(subsample, float):
+                dataset = dataset.select(range(int(subsample * len(dataset))))
+            logger.warning(f"Subsampled {len(dataset)} examples from {old_length}.")

         # very likely not relevant / used only for the compound part
         if args.ignore_non_hyphen:
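
Because subsampling runs after the optional shuffle(seed=42), a float subsample with "shuffle": true keeps a reproducible random fraction, while without shuffling it simply keeps the head of the dataset (the config above sets "shuffle": false). A toy illustration with a hypothetical dataset:

    from datasets import Dataset

    ds = Dataset.from_dict({"text": [f"line {i}" for i in range(10)]})
    head = ds.select(range(3))                      # first three examples
    random3 = ds.shuffle(seed=42).select(range(3))  # reproducible random three
    print(head["text"], random3["text"])

Note that the Args field is typed Optional[float], while prepare_dataset also accepts an int for an absolute sample count.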
@@ -443,6 +455,7 @@ def maybe_pad(text):
             split="train",
             do_lowercase=args.do_lowercase,
             do_remove_punct=args.do_remove_punct,
+            subsample=args.subsample,
         )
         if train_dataset is None or valid_dataset is None:
             logger.warning(f"Skipping {lang} {dataset_name} due to missing data.")
@@ -580,7 +593,7 @@ def compute_metrics(trainer):
                     with_head=True,
                 )
             else:
-                save_model.save_pretrained(os.path.join(training_args.output_dir, dataset_name, lang))
+                save_model.to("cpu").save_pretrained(os.path.join(training_args.output_dir, dataset_name, lang))
             if training_args.local_rank == 0:
                 # eval here within 1 go
                 cmd = ""
@@ -591,21 +604,19 @@
                     eval_function = "intrinsic_list"
                 else:
                     eval_function = "intrinsic"
-                cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1"
-                if "lines" in args.text_path:
-                    if args.do_lowercase and args.do_remove_punct:
-                        cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_lines.pt --save_suffix lines --do_lowercase --do_remove_punct"
-                    else:
-                        cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_lines.pt --save_suffix lines"
-                elif "verses" in args.text_path:
-                    if args.do_lowercase and args.do_remove_punct:
-                        cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n_single.pt --save_suffix verses --do_lowercase --do_remove_punct"
-                    else:
-                        cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n.pt --save_suffix verses"
-                elif args.do_lowercase and args.do_remove_punct:
-                    cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --do_lowercase --do_remove_punct"
+                if args.do_lowercase and args.do_remove_punct:
+                    suffix = "--do_lowercase --do_remove_punct"
                 else:
-                    cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1"
+                    suffix = ""
+                if "adapter" in training_args.output_dir:
+                    model_info = f"--model_path {args.model_name_or_path} --adapter_path {training_args.output_dir}"
+                else:
+                    model_info = f"--model_path {training_args.output_dir}"
+
+                if "verses" in args.text_path or "lines" in args.text_path:
+                    cmd = f"python3 wtpsplit/evaluation/{eval_function}.py {model_info} --threshold 0.1 --custom_language_list data/mldb_langs.csv --eval_data_path {args.text_path} {suffix}"
+                else:
+                    cmd = f"python3 wtpsplit/evaluation/{eval_function}.py {model_info} --threshold 0.1 {suffix}"
                 print(cmd)
                 os.system(cmd)

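With the config above ("adapter" appears in the output_dir, the text path contains "verses", and lowercasing and punctuation removal are off), the new construction would emit a command along these lines (assuming eval_function resolves to intrinsic_list, which this hunk does not show):

    python3 wtpsplit/evaluation/intrinsic_list.py --model_path xlmr-12l-v3 --adapter_path xlmr-12l-v3_adapter_rf4_ep30_v2_mldbW-verses-S_fs_ss0.1 --threshold 0.1 --custom_language_list data/mldb_langs.csv --eval_data_path data/mldbW_verses_strip_n_strip_single_f.pt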