From 5a990386a5f24a5dec6dc9a5c62922e759e3853e Mon Sep 17 00:00:00 2001 From: markus583 Date: Tue, 21 May 2024 11:02:57 +0000 Subject: [PATCH] quick xlm-r small model fix --- wtpsplit/train/train_adapter.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/wtpsplit/train/train_adapter.py b/wtpsplit/train/train_adapter.py index 637599de..423e1ac6 100644 --- a/wtpsplit/train/train_adapter.py +++ b/wtpsplit/train/train_adapter.py @@ -57,6 +57,7 @@ class Args: adapter_warmup_steps: int = 0 adapter_lr_multiplier: float = 1.0 text_column: str = "text" + num_hidden_layers = None # NEW PARAMS use_subwords: bool = False @@ -90,10 +91,15 @@ def main(): if (label_args.use_auxiliary or args.do_auxiliary_training or args.meta_clf) else 0 ) - config = SubwordXLMConfig.from_pretrained( - args.model_name_or_path, - num_labels=num_labels, - ) + + if args.num_hidden_layers: + config = SubwordXLMConfig.from_pretrained( + args.model_name_or_path, + num_labels=num_labels, + num_hidden_layers=args.num_hidden_layers, + ) + else: + config = SubwordXLMConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels) def prepare_dataset( data,