add aux_training_weight
markus583 committed Jan 10, 2024
1 parent 9e0b3e2 commit c931647
Showing 3 changed files with 59 additions and 9 deletions.
44 changes: 44 additions & 0 deletions configs/xlmr_stratify_0.1_3layers_p_v2_0.5aux.json
@@ -0,0 +1,44 @@
{
"model_name_or_path": "xlm-roberta-base",
"output_dir": "xlmr-normal-p-v2-aux0.5",
"train_text_path": "data/sentence/train.parquet",
"valid_text_path": "data/sentence/valid.parquet",
"block_size": 512,
"use_bert": true,
"do_train": true,
"do_eval": true,
"evaluation_strategy": "steps",
"per_device_train_batch_size": 32,
"per_device_eval_batch_size": 32,
"gradient_accumulation_steps": 2,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 4,
"preprocessing_num_workers": 32,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
"max_steps": 200000,
"save_steps": 100000,
"eval_steps": 5000,
"logging_steps": 50,
"report_to": "wandb",
"is_decoder": false,
"remove_unused_columns": false,
"lookahead": null,
"one_sample_per_line": false,
"do_sentence_training": true,
"do_auxiliary_training": true,
"aux_training_weight": 0.5,
"warmup_steps": 5000,
"adapter_warmup_steps": 0,
"adapter_lr_multiplier": 1,
"ngram_order": 1,
"non_punctuation_sample_ratio": 0.1,
"prediction_loss_only": true,
"use_auxiliary": true,
"ddp_timeout": 3600,
"use_subwords": true,
"num_hidden_layers": 3,
"custom_punctuation_file": "punctuation_xlmr_unk.txt",
"log_level": "warning"
}
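
For reference, a minimal sketch of how a JSON config like the one above can be loaded into the training dataclasses. It assumes the usual HfArgumentParser pattern for Hugging Face training scripts; the actual parsing code in train.py is not part of this diff, so treat the exact calls as an assumption.

# Hypothetical sketch: parsing the JSON config into the Args dataclass
# extended by this commit plus HF TrainingArguments. The parsing pattern
# is assumed, not shown in this diff.
import sys

from transformers import HfArgumentParser, TrainingArguments

from wtpsplit.train.train import Args

parser = HfArgumentParser([Args, TrainingArguments])
# e.g. python train.py configs/xlmr_stratify_0.1_3layers_p_v2_0.5aux.json
args, training_args = parser.parse_json_file(json_file=sys.argv[1])
print(args.aux_training_weight)  # 0.5 with the config above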
18 changes: 10 additions & 8 deletions wtpsplit/train/train.py
@@ -92,6 +92,7 @@ class Args:
    use_loss_weights: bool = False
    do_sentence_training: bool = True
    do_auxiliary_training: bool = True
+   aux_training_weight: float = 1.0
    ignore_non_hyphen: bool = False
    non_punctuation_sample_ratio: float = None
    adapter_warmup_steps: int = 0
@@ -267,6 +268,7 @@ def main():
        use_loss_weights=args.use_loss_weights,
        do_sentence_training=args.do_sentence_training,
        do_auxiliary_training=args.do_auxiliary_training,
+       aux_training_weight=args.aux_training_weight,
    )

    if training_args.local_rank == 0:
@@ -525,7 +527,7 @@ def maybe_pad(text):
        num_workers=args.preprocessing_num_workers,
        include_languages=args.include_languages,
        shuffle=args.shuffle,
-       split="valid",
+       split="train",
    )
    logger.warning(f"Train dataset has {len(train_dataset)} examples.")

@@ -597,13 +599,13 @@ def compute_metrics(trainer):
    training_args.adapter_lr_multiplier = args.adapter_lr_multiplier

    # give .map in multiprocessing enough of time to finish, to be safe
-   # time.sleep(10)
-   # if training_args.local_rank == 0:
-   #     # since both share the *same* cache_dir, we cannot simply call dataset.cleanup_cache_files()
-   #     # because that would remove the cache files of the other dataset!
-   #     cleanup_cache_files([train_dataset, valid_dataset])
-   #     logger.warning("Cleaned up cache files.")
-   # time.sleep(10)
+   time.sleep(10)
+   if training_args.local_rank == 0:
+       # since both share the *same* cache_dir, we cannot simply call dataset.cleanup_cache_files()
+       # because that would remove the cache files of the other dataset!
+       cleanup_cache_files([train_dataset, valid_dataset])
+       logger.warning("Cleaned up cache files.")
+   time.sleep(10)

    trainer = Trainer(
        model,
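
The cleanup_cache_files helper re-enabled above is defined elsewhere in the repo, not in this diff. A minimal sketch of the pattern the comment implies, assuming it mirrors the datasets library's Dataset.cleanup_cache_files but protects the cache files of every dataset passed in:

import os

def cleanup_cache_files(datasets) -> int:
    """Delete .arrow cache files in the shared cache_dir that none of the
    given datasets still references (inferred behavior, based on the
    comment in train.py above)."""
    live_files = {
        os.path.abspath(cache_file["filename"])
        for dataset in datasets
        for cache_file in dataset.cache_files
    }
    if not live_files:
        return 0
    cache_dir = os.path.dirname(next(iter(live_files)))
    removed = 0
    for name in os.listdir(cache_dir):
        path = os.path.abspath(os.path.join(cache_dir, name))
        # only touch datasets-style cache files, and keep the live ones
        if name.startswith("cache-") and name.endswith(".arrow") and path not in live_files:
            os.remove(path)
            removed += 1
    return removed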
6 changes: 5 additions & 1 deletion wtpsplit/train/utils.py
@@ -15,6 +15,7 @@ def __init__(
        use_loss_weights=False,
        do_sentence_training=True,
        do_auxiliary_training=False,
+       aux_training_weight=1.0,
    ):
        super().__init__()
        self.backbone = backbone
@@ -26,6 +27,7 @@ def __init__(
        self.use_loss_weights = use_loss_weights
        self.do_sentence_training = do_sentence_training
        self.do_auxiliary_training = do_auxiliary_training
+       self.aux_training_weight = aux_training_weight

    @property
    def device(self):
@@ -107,7 +109,9 @@ def forward(
            )
        )

-       loss = torch.stack(losses).sum()
+       loss = losses[0]
+       if len(losses) > 1:
+           loss += self.aux_training_weight * losses[1]

        output["loss"] = loss

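
To make the effect of the new weighting concrete, a small worked example with toy loss values, assuming (as in forward() above) that losses[0] is the sentence-segmentation loss and losses[1] the auxiliary punctuation loss:

import torch

aux_training_weight = 0.5  # value from the config added in this commit
losses = [torch.tensor(1.2), torch.tensor(0.8)]

# old behavior: torch.stack(losses).sum() == 2.0, i.e. an implicit weight of 1.0
loss = losses[0]
if len(losses) > 1:
    loss = loss + aux_training_weight * losses[1]

print(loss)  # tensor(1.6000) -> the auxiliary objective counts half as much

With aux_training_weight=1.0, the new default, the result matches the old torch.stack(losses).sum() behavior for the two-loss case here, so existing configs train unchanged.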
