clean up cache right before training
markus583 committed Dec 28, 2023
1 parent 55bc87f commit 89c8a0c
Showing 10 changed files with 101 additions and 10 deletions.
2 changes: 1 addition & 1 deletion configs/xlmr_stratify_0.1_3layers_bs128.json
@@ -1,6 +1,6 @@
 {
     "model_name_or_path": "xlm-roberta-base",
-    "output_dir": "xlmr-normal",
+    "output_dir": "xlmr-normal-bs-128",
     "train_text_path": "data/sentence/train.parquet",
     "valid_text_path": "data/sentence/valid.parquet",
     "block_size": 128,
2 changes: 1 addition & 1 deletion configs/xlmr_stratify_0.1_3layers_bs256.json
@@ -1,6 +1,6 @@
 {
     "model_name_or_path": "xlm-roberta-base",
-    "output_dir": "xlmr-normal",
+    "output_dir": "xlmr-normal-bs256",
     "train_text_path": "data/sentence/train.parquet",
     "valid_text_path": "data/sentence/valid.parquet",
     "block_size": 256,
2 changes: 1 addition & 1 deletion configs/xlmr_stratify_0.1_3layers_bs64.json
@@ -1,6 +1,6 @@
 {
     "model_name_or_path": "xlm-roberta-base",
-    "output_dir": "xlmr-normal",
+    "output_dir": "xlmr-normal-bs64",
     "train_text_path": "data/sentence/train.parquet",
     "valid_text_path": "data/sentence/valid.parquet",
     "block_size": 64,
43 changes: 43 additions & 0 deletions configs/xlmr_stratify_0.1_3layers_highlr.json
@@ -0,0 +1,43 @@
+{
+    "model_name_or_path": "xlm-roberta-base",
+    "output_dir": "xlmr-normal-highlr",
+    "train_text_path": "data/sentence/train.parquet",
+    "valid_text_path": "data/sentence/valid.parquet",
+    "block_size": 512,
+    "use_bert": true,
+    "do_train": true,
+    "do_eval": true,
+    "evaluation_strategy": "steps",
+    "per_device_train_batch_size": 32,
+    "per_device_eval_batch_size": 32,
+    "gradient_accumulation_steps": 2,
+    "eval_accumulation_steps": 8,
+    "dataloader_num_workers": 4,
+    "preprocessing_num_workers": 32,
+    "learning_rate": 3e-4,
+    "save_strategy": "steps",
+    "fp16": false,
+    "max_steps": 2000000,
+    "save_steps": 100000,
+    "eval_steps": 5000,
+    "logging_steps": 50,
+    "report_to": "wandb",
+    "is_decoder": false,
+    "remove_unused_columns": false,
+    "lookahead": null,
+    "one_sample_per_line": false,
+    "do_sentence_training": true,
+    "do_auxiliary_training": true,
+    "warmup_steps": 5000,
+    "adapter_warmup_steps": 0,
+    "adapter_lr_multiplier": 1,
+    "ngram_order": 1,
+    "non_punctuation_sample_ratio": 0.1,
+    "prediction_loss_only": true,
+    "use_auxiliary": true,
+    "ddp_timeout": 3600,
+    "use_subwords": true,
+    "num_hidden_layers": 3,
+    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
+    "log_level": "info"
+}
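
For context, configs like the one above are the usual HfArgumentParser flavor: keys matching TrainingArguments fields (output_dir, learning_rate, ...) land there, while project-specific keys (use_bert, lookahead, ...) land in the project's own dataclasses. A minimal sketch, with a hypothetical CustomArgs dataclass standing in for those argument classes:

from dataclasses import dataclass
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class CustomArgs:
    # Hypothetical stand-in: the real project defines a field for every
    # non-TrainingArguments key in the config (use_bert, lookahead, ...).
    model_name_or_path: str = "xlm-roberta-base"
    train_text_path: str = "data/sentence/train.parquet"
    valid_text_path: str = "data/sentence/valid.parquet"
    block_size: int = 512
    num_hidden_layers: int = 3

parser = HfArgumentParser((CustomArgs, TrainingArguments))
# Keys matching a CustomArgs field go there; the rest fill TrainingArguments.
# allow_extra_keys=True (recent transformers) skips the keys this reduced
# CustomArgs sketch does not model.
custom_args, training_args = parser.parse_json_file(
    "configs/xlmr_stratify_0.1_3layers_highlr.json", allow_extra_keys=True
)
print(training_args.learning_rate)  # 0.0003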
2 changes: 1 addition & 1 deletion configs/xlmr_stratify_0.1_3layers_no_aux.json
@@ -1,6 +1,6 @@
 {
     "model_name_or_path": "xlm-roberta-base",
-    "output_dir": "xlmr-normal",
+    "output_dir": "xlmr-normal-noaux",
     "train_text_path": "data/sentence/train.parquet",
     "valid_text_path": "data/sentence/valid.parquet",
     "block_size": 512,
2 changes: 1 addition & 1 deletion configs/xlmr_stratify_0.1_3layers_nounks.json
@@ -1,6 +1,6 @@
 {
     "model_name_or_path": "xlm-roberta-base",
-    "output_dir": "xlmr-normal-no_unks",
+    "output_dir": "xlmr-normal-nounks",
     "train_text_path": "data/sentence/train.parquet",
     "valid_text_path": "data/sentence/valid.parquet",
     "block_size": 512,
2 changes: 1 addition & 1 deletion configs/xlmr_stratify_0.1_3layers_shorter.json
@@ -1,6 +1,6 @@
 {
     "model_name_or_path": "xlm-roberta-base",
-    "output_dir": "xlmr-shorter",
+    "output_dir": "xlmr-normal-shorter",
     "train_text_path": "data/sentence/train.parquet",
     "valid_text_path": "data/sentence/valid.parquet",
     "block_size": 512,
43 changes: 43 additions & 0 deletions configs/xlmr_stratify_0.1_6layers.json
@@ -0,0 +1,43 @@
+{
+    "model_name_or_path": "xlm-roberta-base",
+    "output_dir": "xlmr-normal-6",
+    "train_text_path": "data/sentence/train.parquet",
+    "valid_text_path": "data/sentence/valid.parquet",
+    "block_size": 512,
+    "use_bert": true,
+    "do_train": true,
+    "do_eval": true,
+    "evaluation_strategy": "steps",
+    "per_device_train_batch_size": 32,
+    "per_device_eval_batch_size": 32,
+    "gradient_accumulation_steps": 2,
+    "eval_accumulation_steps": 8,
+    "dataloader_num_workers": 4,
+    "preprocessing_num_workers": 32,
+    "learning_rate": 1e-4,
+    "save_strategy": "steps",
+    "fp16": false,
+    "max_steps": 2000000,
+    "save_steps": 100000,
+    "eval_steps": 5000,
+    "logging_steps": 50,
+    "report_to": "wandb",
+    "is_decoder": false,
+    "remove_unused_columns": false,
+    "lookahead": null,
+    "one_sample_per_line": false,
+    "do_sentence_training": true,
+    "do_auxiliary_training": true,
+    "warmup_steps": 5000,
+    "adapter_warmup_steps": 0,
+    "adapter_lr_multiplier": 1,
+    "ngram_order": 1,
+    "non_punctuation_sample_ratio": 0.1,
+    "prediction_loss_only": true,
+    "use_auxiliary": true,
+    "ddp_timeout": 3600,
+    "use_subwords": true,
+    "num_hidden_layers": 6,
+    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
+    "log_level": "info"
+}
5 changes: 5 additions & 0 deletions wtpsplit/train/train.py
@@ -666,6 +666,11 @@ def maybe_pad(text):
split="train",
)
logger.info(f"Train dataset has {len(train_dataset)} examples.")

with training_args.main_process_first():
train_dataset.cleanup_cache_files()
valid_dataset.cleanup_cache_files()
logger.warning("Cleaned up cache files.")

# print some samples from the dataset
count = 0
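
The added block is the point of the commit: datasets' .map()/.filter() calls leave Arrow cache files on disk, and cleanup_cache_files() deletes the stale ones right before training so a rerun cannot silently reuse outdated caches. A minimal sketch of the pattern, not the project's exact code (paths and output_dir are taken from the configs above):

from datasets import load_dataset
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="xlmr-normal")
train_dataset = load_dataset("parquet", data_files="data/sentence/train.parquet", split="train")
valid_dataset = load_dataset("parquet", data_files="data/sentence/valid.parquet", split="train")

# cleanup_cache_files() removes the dataset's on-disk cache files (except any
# currently in use) and returns how many it deleted. main_process_first()
# lets rank 0 run the block before other ranks in a distributed job, so
# workers do not race on deleting the same files.
with training_args.main_process_first():
    n = train_dataset.cleanup_cache_files()
    n += valid_dataset.cleanup_cache_files()
    print(f"Cleaned up {n} cache file(s).")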
8 changes: 4 additions & 4 deletions wtpsplit/utils.py
@@ -102,11 +102,11 @@ def get_subword_label_dict(label_args, tokenizer):

     n_unks = 0
     # Map auxiliary characters to token IDs with labels
-    logger.warning(f"Using {Constants.PUNCTUATION_CHARS} auxiliary characters.")
+    logger.info(f"Using {Constants.PUNCTUATION_CHARS} auxiliary characters.")
     for i, c in enumerate(Constants.PUNCTUATION_CHARS):
         token_id = tokenizer.convert_tokens_to_ids(c)
         label_dict[token_id] = 1 + Constants.AUX_OFFSET + i
-        logger.warning(
+        logger.info(
             f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}"
         )
         if token_id == tokenizer.unk_token_id:
@@ -118,8 +118,8 @@ def get_subword_label_dict(label_args, tokenizer):
     for c in label_args.newline_chars:
         token_id = tokenizer.convert_tokens_to_ids(c)
         label_dict[token_id] = 1 + Constants.NEWLINE_INDEX
-        logger.warning(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
-        logger.warning(f"{tokenizer.decode([token_id])}")
+        logger.info(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:")
+        logger.info(f"{tokenizer.decode([token_id])}")
 
     return label_dict

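
The warning-to-info demotions are consistent with the configs' "log_level": "info": the per-character dumps are diagnostics rather than problems, so they still appear at info level but no longer flood runs that log at warning or above. A minimal sketch of the standard-library logging behavior this relies on:

import logging

logging.basicConfig()  # attach a default handler to the root logger
logger = logging.getLogger("wtpsplit")

logger.setLevel(logging.INFO)  # mirrors "log_level": "info" in the configs
logger.info("auxiliary character detail ...")  # emitted at INFO

logger.setLevel(logging.WARNING)
logger.info("auxiliary character detail ...")  # suppressed: INFO < WARNING
logger.warning("Cleaned up cache files.")      # warnings still emitted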
