Commit 0cd453a
add meta clf for ft
markus583 committed Feb 27, 2024
1 parent 0fee97d commit 0cd453a
Showing 1 changed file with 8 additions and 1 deletion.
wtpsplit/train/train_adapter_parallel.py
@@ -115,6 +115,7 @@ class Args:
     clf_from_scratch: bool = False
     do_process: bool = False
     n_train_steps: List[int] = field(default_factory=lambda: [1000, 10000, 100000])
+    meta_clf: bool = False
 
 
 def main(
@@ -140,7 +141,7 @@ def main(
     )
     config = SubwordXLMConfig.from_pretrained(
         args.model_name_or_path,
-        num_labels=num_labels,
+        num_labels=num_labels if not args.meta_clf else 1,
     )
 
     # 1 wandb run for all language-dataset combinations
@@ -242,6 +243,12 @@ def compute_metrics(trainer):
                 p.requires_grad = False
     if args.clf_from_scratch:
         model.backbone.classifier = torch.nn.Linear(model.backbone.config.hidden_size, num_labels)
+    if args.meta_clf:
+        clf = model.backbone.classifier
+        model.backbone.classifier = torch.nn.Sequential(
+            clf,  # original classifier - if frozen above, also frozen here
+            torch.nn.Linear(clf.out_features, num_labels)
+        )
 
     trainer = AdapterTrainer(
         model,
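Note: a minimal standalone sketch of what this change does (not part of the commit). With meta_clf enabled, the config builds the backbone classifier with a single output logit (num_labels=1), and a fresh trainable "meta" linear layer is stacked on top of the possibly frozen original head via torch.nn.Sequential. hidden_size and num_labels below are placeholder values.

    import torch

    hidden_size = 768  # placeholder; would match the backbone's hidden size
    num_labels = 3     # placeholder label count for the fine-tuning task

    # Original classifier head; with meta_clf the config is created with
    # num_labels=1, so it maps hidden states to a single logit.
    clf = torch.nn.Linear(hidden_size, 1)

    # Mirrors the freezing loop earlier in the diff context: parameters
    # frozen here stay frozen inside the Sequential wrapper.
    for p in clf.parameters():
        p.requires_grad = False

    # Meta classifier: only the newly stacked layer receives gradients.
    meta_head = torch.nn.Sequential(
        clf,
        torch.nn.Linear(clf.out_features, num_labels),
    )

    hidden_states = torch.randn(2, 10, hidden_size)  # (batch, seq_len, hidden)
    logits = meta_head(hidden_states)                # -> (2, 10, num_labels)

Design-wise, every prediction is routed through the original head's single logit, so the meta layer only rescales and shifts that score per output label; the pretrained head is reused while very few new parameters are trained.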
