From fa5cb3aea5481858227ba1301136c06251a13687 Mon Sep 17 00:00:00 2001
From: markus583
Date: Thu, 28 Mar 2024 12:05:21 +0000
Subject: [PATCH] add full model ft support

---
 wtpsplit/evaluation/intrinsic.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py
index 1c6e0b6e..580da0ef 100644
--- a/wtpsplit/evaluation/intrinsic.py
+++ b/wtpsplit/evaluation/intrinsic.py
@@ -156,6 +156,10 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
                 for n, p in model.backbone.named_parameters():
                     if "LayerNorm" in n:
                         p.data = ln_dict[n].data
+                if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")):
+                    model_path = os.path.join(args.model_path, dataset_name, "en")
+                    print(model_path)
+                    model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
             except Exception as e:
                 print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
                 continue
@@ -248,7 +252,14 @@ def main(args):
         valid_data = None
 
     print("Loading model...")
-    model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))
+    # if model_path does not contain a model, take first subfolder
+    if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")):
+        model_path = os.path.join(args.model_path, os.listdir(args.model_path)[0], "en")
+        print("joined")
+        print(model_path)
+    else:
+        model_path = args.model_path
+    model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
     if args.adapter_path:
         model_type = model.model.config.model_type
         # adapters need xlm-roberta as model type.
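
Both hunks implement the same fallback: when args.model_path does not directly
contain a pytorch_model.bin checkpoint, the script descends into a per-dataset
subfolder and loads <model_path>/<dataset_name>/en instead, so fully fine-tuned
models saved one-per-dataset can be evaluated with the same entry point. A
minimal standalone sketch of that resolution logic follows; resolve_model_path
is a hypothetical helper name, not part of wtpsplit, and the <dataset_name>/en
checkpoint layout is assumed from the patch:

    import os
    from typing import Optional


    def resolve_model_path(model_path: str, dataset_name: Optional[str] = None) -> str:
        """Sketch of the patch's fallback: prefer a direct checkpoint, else descend."""
        if os.path.exists(os.path.join(model_path, "pytorch_model.bin")):
            # model_path itself holds a checkpoint; use it as-is
            return model_path
        # assumed layout from the patch: <model_path>/<dataset_name>/en,
        # falling back to the first subfolder when no dataset is given
        subfolder = dataset_name if dataset_name is not None else os.listdir(model_path)[0]
        return os.path.join(model_path, subfolder, "en")

Note that os.listdir returns entries in filesystem-dependent order, so the
first-subfolder fallback used in main() is only deterministic when the
directory contains a single fine-tuned model.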