From e2bff46114b91246f1c48142bf139cefa11f501e Mon Sep 17 00:00:00 2001
From: markus583
Date: Wed, 27 Dec 2023 07:30:37 +0000
Subject: [PATCH] fix collate length

---
 configs/xlmr_stratify_0.1_3layers.json        |  6 +--
 .../xlmr_stratify_0.1_3layers_shorter.json    | 41 +++++++++++++++++++
 wtpsplit/train/train.py                       |  4 ++
 3 files changed, 48 insertions(+), 3 deletions(-)
 create mode 100644 configs/xlmr_stratify_0.1_3layers_shorter.json

diff --git a/configs/xlmr_stratify_0.1_3layers.json b/configs/xlmr_stratify_0.1_3layers.json
index fc40f4c0..d8425261 100644
--- a/configs/xlmr_stratify_0.1_3layers.json
+++ b/configs/xlmr_stratify_0.1_3layers.json
@@ -1,6 +1,6 @@
 {
     "model_name_or_path": "xlm-roberta-base",
-    "output_dir": "xlmr-TEST",
+    "output_dir": "xlmr-normal",
     "train_text_path": "data/sentence/train.parquet",
     "valid_text_path": "data/sentence/valid.parquet",
     "block_size": 512,
@@ -8,9 +8,9 @@
     "do_train": true,
     "do_eval": true,
     "evaluation_strategy": "steps",
-    "per_device_train_batch_size": 64,
+    "per_device_train_batch_size": 32,
     "per_device_eval_batch_size": 32,
-    "gradient_accumulation_steps": 1,
+    "gradient_accumulation_steps": 2,
     "eval_accumulation_steps": 8,
     "dataloader_num_workers": 4,
     "preprocessing_num_workers": 32,
diff --git a/configs/xlmr_stratify_0.1_3layers_shorter.json b/configs/xlmr_stratify_0.1_3layers_shorter.json
new file mode 100644
index 00000000..056dee98
--- /dev/null
+++ b/configs/xlmr_stratify_0.1_3layers_shorter.json
@@ -0,0 +1,41 @@
+{
+    "model_name_or_path": "xlm-roberta-base",
+    "output_dir": "xlmr-shorter",
+    "train_text_path": "data/sentence/train.parquet",
+    "valid_text_path": "data/sentence/valid.parquet",
+    "block_size": 512,
+    "use_bert": true,
+    "do_train": true,
+    "do_eval": true,
+    "evaluation_strategy": "steps",
+    "per_device_train_batch_size": 32,
+    "per_device_eval_batch_size": 32,
+    "gradient_accumulation_steps": 2,
+    "eval_accumulation_steps": 8,
+    "dataloader_num_workers": 4,
+    "preprocessing_num_workers": 32,
+    "learning_rate": 1e-4,
+    "save_strategy": "steps",
+    "fp16": false,
+    "max_steps": 400000,
+    "save_steps": 100000,
+    "eval_steps": 5000,
+    "logging_steps": 50,
+    "report_to": "wandb",
+    "is_decoder": false,
+    "remove_unused_columns": false,
+    "lookahead": null,
+    "one_sample_per_line": false,
+    "do_sentence_training": true,
+    "do_auxiliary_training": true,
+    "warmup_steps": 5000,
+    "adapter_warmup_steps": 0,
+    "adapter_lr_multiplier": 1,
+    "ngram_order": 1,
+    "non_punctuation_sample_ratio": 0.1,
+    "prediction_loss_only": true,
+    "use_auxiliary": true,
+    "ddp_timeout": 3600,
+    "use_subwords": true,
+    "num_hidden_layers": 3
+}
\ No newline at end of file
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index 8de8f6dd..c4177774 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -240,6 +240,10 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
         labels = labels[start : start + args.block_size - 1]
         # always include SEP
         if input_ids[-1] != tokenizer.sep_token_id:
+            # also insert PAD token as long as len < block_size
+            while len(input_ids) < args.block_size - 1:
+                input_ids = input_ids + [tokenizer.pad_token_id]
+                labels = labels + [0]
             input_ids = input_ids + [tokenizer.sep_token_id]
             labels = labels + [0]
         else:
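
Note on the train.py hunk: before this patch, a truncated chunk whose last
token was not SEP only got a single SEP appended, so it could leave this
branch of collate_fn shorter than args.block_size. The patched branch first
right-pads with PAD (label 0) up to block_size - 1 and then appends SEP, so
the example exits at exactly block_size tokens. A minimal standalone sketch
of that logic, for reference only: pad_to_block is a hypothetical helper,
not part of wtpsplit, and the ids in the demo assume XLM-R's special-token
conventions (<pad> = 1, </s> = 2).

    def pad_to_block(input_ids, labels, sep_id, pad_id, block_size):
        # Mirror of the patched branch: pad to block_size - 1 with PAD
        # (label 0), then close with SEP, so the output is exactly
        # block_size tokens long.
        assert len(input_ids) <= block_size - 1
        while len(input_ids) < block_size - 1:
            input_ids = input_ids + [pad_id]
            labels = labels + [0]
        return input_ids + [sep_id], labels + [0]

    # 96 dummy subword ids -> padded to 511, then SEP makes 512.
    ids, lbls = pad_to_block(list(range(4, 100)), [0] * 96,
                             sep_id=2, pad_id=1, block_size=512)
    assert len(ids) == len(lbls) == 512 and ids[-1] == 2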