diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py index 7abdb13f..15b6aa6e 100644 --- a/wtpsplit/evaluation/intrinsic.py +++ b/wtpsplit/evaluation/intrinsic.py @@ -4,6 +4,7 @@ from typing import List import os import time +import logging import h5py import skops.io as sio @@ -19,6 +20,8 @@ from wtpsplit.extract import PyTorchWrapper, extract from wtpsplit.utils import Constants +logger = logging.getLogger() +logger.setLevel(logging.INFO) @dataclass class Args: @@ -243,6 +246,12 @@ def main(args): adapters.init(model.model) # reset model type (used later) model.model.config.model_type = model_type + if "meta-clf" in args.adapter_path: + clf = model.model.classifier + model.model.classifier = torch.nn.Sequential( + clf, + torch.nn.Linear(clf.out_features, 1) + ) # first, logits for everything. f, total_test_time = load_or_compute_logits(args, model, eval_data, valid_data, save_str) diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py index 3a6fc688..9087ac74 100644 --- a/wtpsplit/extract.py +++ b/wtpsplit/extract.py @@ -174,7 +174,7 @@ def extract( np.zeros( ( length, - model.classifier.out_features + model.config.num_labels ), dtype=np.float16, )