
eval mldbW for mldb models
markus583 committed Mar 31, 2024
1 parent 8ded82c commit a37f2e2
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions wtpsplit/evaluation/intrinsic_list.py
@@ -127,30 +127,39 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
 
             # eval data
             for dataset_name, dataset in eval_data[lang_code]["sentence"].items():
+                # train on all mldb, eval on mldbW
+                if "mldbW" in args.eval_data_path and (
+                    "mldbW" not in args.model_path or "mldbW" not in args.adapter_path
+                ):
+                    dataset_load_name = "unk"
+                else:
+                    dataset_load_name = dataset_name
                 try:
                     if args.adapter_path:
                         model.model.load_adapter(
-                            args.adapter_path + "/" + dataset_name + "/" + lang_code,
+                            args.adapter_path + "/" + dataset_load_name + "/" + lang_code,
                             set_active=True,
                             with_head=True,
                             load_as="text",
                         )
                         if hasattr(model.model.config, "unfreeze_ln"):
                             if model.model.config.unfreeze_ln:
                                 ln_dict = torch.load(
-                                    args.adapter_path + "/" + dataset_name + "/" + lang_code + "/ln_dict.pth"
+                                    args.adapter_path + "/" + dataset_load_name + "/" + lang_code + "/ln_dict.pth"
                                 )
                                 for n, p in model.backbone.named_parameters():
                                     if "LayerNorm" in n:
                                         p.data = ln_dict[n].data
                     if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")):
-                        model_path = os.path.join(args.model_path, dataset_name, "en")
+                        model_path = os.path.join(args.model_path, dataset_load_name, "en")
                         print(model_path)
-                        model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
+                        model = PyTorchWrapper(
+                            AutoModelForTokenClassification.from_pretrained(model_path).to(args.device)
+                        )
                 except Exception as e:
-                    print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
+                    print(f"Error loading adapter for {dataset_load_name} in {lang_code}: {e}")
                     continue
-                print(dataset_name)
+                print(dataset_name, dataset_load_name)
                 if dataset_name not in lang_group:
                     dset_group = lang_group.create_group(dataset_name)
                 else:
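For context on the change above: when the evaluation data is mldbW but the model or adapter was trained on all of mldb (so no mldbW-specific adapter directory exists), adapters and models are now loaded under the generic "unk" name instead of the dataset's own name. A minimal sketch of this fallback, using a hypothetical resolve_load_name helper and made-up paths that are not part of the repository:

def resolve_load_name(dataset_name, eval_data_path, model_path, adapter_path):
    # Mirrors the condition added in this commit: evaluating mldbW data with a
    # model/adapter that was not trained on mldbW falls back to the shared "unk" entry.
    if "mldbW" in eval_data_path and ("mldbW" not in model_path or "mldbW" not in adapter_path):
        return "unk"
    return dataset_name

# Hypothetical usage:
# resolve_load_name("some_dataset", "data/eval_mldbW.pth", "models/mldb_all", "adapters/mldb_all")  # -> "unk"
# resolve_load_name("some_dataset", "data/eval_mldbW.pth", "models/mldbW", "adapters/mldbW")        # -> "some_dataset"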
