Skip to content

Commit

Permalink
fix meta-clf head loading
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed Mar 2, 2024
1 parent 6459683 commit 153485a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions wtpsplit/evaluation/intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List
import os
import time
import logging

import h5py
import skops.io as sio

Check failure on line 10 in wtpsplit/evaluation/intrinsic.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (F401)

wtpsplit/evaluation/intrinsic.py:10:20: F401 `skops.io` imported but unused

Check failure on line 10 in wtpsplit/evaluation/intrinsic.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F401)

wtpsplit/evaluation/intrinsic.py:10:20: F401 `skops.io` imported but unused

Check failure on line 10 in wtpsplit/evaluation/intrinsic.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F401)

wtpsplit/evaluation/intrinsic.py:10:20: F401 `skops.io` imported but unused
Expand All @@ -19,6 +20,8 @@
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants

logger = logging.getLogger()
logger.setLevel(logging.INFO)

@dataclass
class Args:
Expand Down Expand Up @@ -243,6 +246,12 @@ def main(args):
adapters.init(model.model)
# reset model type (used later)
model.model.config.model_type = model_type
if "meta-clf" in args.adapter_path:
clf = model.model.classifier
model.model.classifier = torch.nn.Sequential(
clf,
torch.nn.Linear(clf.out_features, 1)
)

# first, logits for everything.
f, total_test_time = load_or_compute_logits(args, model, eval_data, valid_data, save_str)
Expand Down
2 changes: 1 addition & 1 deletion wtpsplit/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def extract(
np.zeros(
(
length,
model.classifier.out_features
model.config.num_labels
),
dtype=np.float16,
)
Expand Down

0 comments on commit 153485a

Please sign in to comment.