From 6459683cbaf94f1cdb54375f3c786e9ffe5c3df8 Mon Sep 17 00:00:00 2001
From: markus583
Date: Sat, 2 Mar 2024 07:25:45 +0000
Subject: [PATCH] add some stuff for adp

---
 wtpsplit/evaluation/intrinsic.py         |  8 +++++
 wtpsplit/extract.py                      |  2 +-
 wtpsplit/train/train_adapter_parallel.py | 37 +++++++++++++++++++++-------
 3 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py
index f5ddeb75..7abdb13f 100644
--- a/wtpsplit/evaluation/intrinsic.py
+++ b/wtpsplit/evaluation/intrinsic.py
@@ -144,6 +144,14 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
                     with_head=True,
                     load_as="text",
                 )
+                if hasattr(model.model.config, "unfreeze_ln"):
+                    if model.model.config.unfreeze_ln:
+                        ln_dict = torch.load(
+                            args.adapter_path + "/" + dataset_name + "/" + lang_code + "/ln_dict.pth"
+                        )
+                        for n, p in model.backbone.named_parameters():
+                            if "LayerNorm" in n:
+                                p.data = ln_dict[n].data
             except Exception as e:
                 print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
                 continue
diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py
index 75d65560..3a6fc688 100644
--- a/wtpsplit/extract.py
+++ b/wtpsplit/extract.py
@@ -174,7 +174,7 @@ def extract(
             np.zeros(
                 (
                     length,
-                    model.config.num_labels,
+                    model.classifier.out_features
                 ),
                 dtype=np.float16,
             )
diff --git a/wtpsplit/train/train_adapter_parallel.py b/wtpsplit/train/train_adapter_parallel.py
index da616c7c..8cbd4c9b 100644
--- a/wtpsplit/train/train_adapter_parallel.py
+++ b/wtpsplit/train/train_adapter_parallel.py
@@ -113,6 +113,7 @@ class Args:
     use_subwords: bool = False
     freeze_classifier: bool = False
     clf_from_scratch: bool = False
+    unfreeze_ln: bool = False
     do_process: bool = False
     n_train_steps: List[int] = field(default_factory=lambda: [1000, 10000, 100000])
     meta_clf: bool = False
@@ -137,11 +138,11 @@ def main(
 
     logger.warning(f"{tpu_core_idx}: LANG GROUP {lang_groups}")
     num_labels = Constants.AUX_OFFSET + (
-        (1 + len(Constants.PUNCTUATION_CHARS)) if label_args.use_auxiliary or args.do_auxiliary_training else 0
+        (1 + len(Constants.PUNCTUATION_CHARS)) if (label_args.use_auxiliary or args.do_auxiliary_training or args.meta_clf) else 0
     )
     config = SubwordXLMConfig.from_pretrained(
         args.model_name_or_path,
-        num_labels=num_labels if not args.meta_clf else 1,
+        num_labels=num_labels,
     )
 
     # 1 wandb run for all language-dataset combinations
@@ -165,7 +166,7 @@
     for i, ((lang, dataset_name), train_step) in tqdm(enumerate(zip(lang_groups, train_steps)), total=len(lang_groups)):
         # do model stuff here; otherwise, head params would be overwritten every time
         backbone = SubwordXLMForTokenClassification.from_pretrained(
-            args.model_name_or_path, config=config, ignore_mismatched_sizes=True
+            args.model_name_or_path, config=copy.deepcopy(config), ignore_mismatched_sizes=True
         )
         logger.warning(f"{tpu_core_idx}: Loaded backbone {args.model_name_or_path}.")
         backbone.config.base_model = args.base_model
@@ -243,12 +244,19 @@ def compute_metrics(trainer):
                 p.requires_grad = False
         if args.clf_from_scratch:
             model.backbone.classifier = torch.nn.Linear(model.backbone.config.hidden_size, num_labels)
+
+        if args.unfreeze_ln:
+            for n, p in model.backbone.named_parameters():
+                if "LayerNorm" in n:
+                    p.requires_grad = True
+
         if args.meta_clf:
             clf = model.backbone.classifier
             model.backbone.classifier = torch.nn.Sequential(
                 clf,  # original classifier - if frozen above, also frozen here
-                torch.nn.Linear(clf.out_features, num_labels)
+                torch.nn.Linear(clf.out_features, 1)
             )
+            model.backbone.config.num_labels = 1
 
         trainer = AdapterTrainer(
             model,
@@ -277,19 +285,24 @@ def compute_metrics(trainer):
             if not os.path.exists(os.path.join(training_args.output_dir, dataset_name, lang)):
                 os.makedirs(os.path.join(training_args.output_dir, dataset_name, lang))
             save_model = copy.deepcopy(model.backbone)
-            # TODO: check if concurrent saving is fine (if ds duplicated for TPUs)
             save_model = save_model.to("cpu")
             save_model.save_adapter(
                 adapter_name="text",
                 save_directory=os.path.join(training_args.output_dir, dataset_name, lang),
                 with_head=True,
             )
+            # also save LNs
+            if args.unfreeze_ln:
+                # no way within adapters to do this, need to do it manually
+                ln_dict = {n: p for n, p in save_model.named_parameters() if "LayerNorm" in n}
+                torch.save(ln_dict, os.path.join(training_args.output_dir, dataset_name, lang, "ln_dict.pth"))
 
         logger.warning(f"{tpu_core_idx}: DONE TRAIN {lang} {dataset_name}.")
         if callbacks:
             wandb.log({"train/batch_progress": (i + 1) / len(lang_groups)})
 
     xm.rendezvous("end_training")
+    xm.mark_step()
     xm.rendezvous("all_done")
 
     wandb.finish()
@@ -671,6 +684,13 @@ def merge_dict_into_sublists(d):
             current_lang_groups,
             train_steps,
         )
+
+    xm.rendezvous("all training done")
+    if index == 0:
+        # eval here within 1 go
+        os.system(
+            f"python3 wtpsplit/evaluation/intrinsic.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.5"
+        )
 
 
 if __name__ == "__main__":
@@ -681,9 +701,4 @@ def merge_dict_into_sublists(d):
         args=(),
         nprocs=8,
     )
-
-# TODO: check grouping for TPUs: 1k, 10k, ...; what is most sensible?
-
-# TODO: see if shuffle x1, shuffle x num_epochs, or no shuffle is best
-# TODO: double-check effect of non_punctuation_sample_ratio
-# TODO: try: freeze head, add clf on top (or do not freeze head, diff LR, etc.)
+ 
\ No newline at end of file
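
Note on the manual LayerNorm handling above: as the in-code comment says, the adapters library offers no built-in way to save LayerNorm weights together with the adapter, so the patch dumps them to ln_dict.pth next to the adapter at train time and copies them back into the backbone at eval time. The following is a minimal stand-alone sketch of that save/restore round-trip, assuming only PyTorch; the Toy module and the output path are illustrative stand-ins for the XLM-R backbone and the <output_dir>/<dataset_name>/<lang_code> layout used in the patch.

import os
import torch


class Toy(torch.nn.Module):
    # Illustrative stand-in for the backbone: the submodule is named `LayerNorm`
    # so that, as in the HF XLM-R layers, "LayerNorm" appears in the parameter names.
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)


backbone = Toy()
save_dir = "out/dataset_x/lang_y"  # stands in for <output_dir>/<dataset_name>/<lang_code>
os.makedirs(save_dir, exist_ok=True)

# Train time (train_adapter_parallel.py): collect the unfrozen LayerNorm
# parameters by name and save them separately from the adapter.
ln_dict = {n: p for n, p in backbone.named_parameters() if "LayerNorm" in n}
torch.save(ln_dict, os.path.join(save_dir, "ln_dict.pth"))

# Eval time (intrinsic.py): reload the dict and copy the tensors back
# into the matching backbone parameters.
restored = torch.load(os.path.join(save_dir, "ln_dict.pth"))
for n, p in backbone.named_parameters():
    if "LayerNorm" in n:
        p.data = restored[n].data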