full FT compat
markus583 committed Mar 29, 2024
1 parent 19aa4b2 commit ab776f8
Showing 1 changed file with 12 additions and 1 deletion.
wtpsplit/evaluation/intrinsic_list.py (12 additions & 1 deletion)
@@ -143,6 +143,10 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
     for n, p in model.backbone.named_parameters():
         if "LayerNorm" in n:
             p.data = ln_dict[n].data
+    if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")):
+        model_path = os.path.join(args.model_path, dataset_name, "en")
+        print(model_path)
+        model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
 except Exception as e:
     print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
     continue
@@ -276,7 +280,14 @@ def main(args):
     valid_data = None

     print("Loading model...")
-    model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))
+    # if model_path does not contain a model, take first subfolder
+    if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")):
+        model_path = os.path.join(args.model_path, os.listdir(args.model_path)[0], "en")
+        print("joined")
+        print(model_path)
+    else:
+        model_path = args.model_path
+    model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
     if args.adapter_path:
         model_type = model.model.config.model_type
         # adapters need xlm-roberta as model type.
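For context, this change lets the intrinsic evaluation script accept full fine-tuning (FT) checkpoints in addition to adapters: if `args.model_path` does not itself contain a `pytorch_model.bin`, it is treated as a directory of full-FT models, one subfolder per dataset, each holding an `en` checkpoint directory. Below is a minimal sketch of that path-resolution logic; `resolve_model_path` is a hypothetical helper, not a function in the repository, and the sorted-listdir fallback is an assumption (the commit simply takes `os.listdir(...)[0]`).

```python
import os
from typing import Optional


def resolve_model_path(model_path: str, dataset_name: Optional[str] = None) -> str:
    """Sketch of the fallback added in this commit (hypothetical helper).

    If `model_path` holds a checkpoint directly (pytorch_model.bin), use it as-is.
    Otherwise assume it is a directory of full fine-tuned models, one subfolder
    per dataset, each containing an "en" checkpoint directory.
    """
    if os.path.exists(os.path.join(model_path, "pytorch_model.bin")):
        return model_path
    # Fall back to a per-dataset subfolder; if no dataset is given, pick the
    # first subfolder (sorted here for determinism, unlike the raw listdir order).
    subfolder = dataset_name if dataset_name is not None else sorted(os.listdir(model_path))[0]
    return os.path.join(model_path, subfolder, "en")
```

The same pattern appears twice in the diff: once inside `load_or_compute_logits` (per-dataset checkpoint selection) and once in `main` (initial model loading), which is what makes the script compatible with full-FT checkpoint layouts.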
