add xlm-v
markus583 committed Dec 31, 2023
1 parent 0d8f765 commit f9dbd9f
Showing 9 changed files with 277 additions and 4 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -11,4 +11,5 @@ notebooks
 final_outputs
 .cache*
 data_subset/**
-*.pth
+*.pth
+**/checkpoint-*/**
43 changes: 43 additions & 0 deletions configs/xlmv_stratify_0.1_3layers.json
@@ -0,0 +1,43 @@
{
"model_name_or_path": "facebook/xlm-v-base",
"output_dir": "xlmv-normal",
"train_text_path": "data/sentence/train.parquet",
"valid_text_path": "data/sentence/valid.parquet",
"block_size": 512,
"use_bert": true,
"do_train": true,
"do_eval": true,
"evaluation_strategy": "steps",
"per_device_train_batch_size": 32,
"per_device_eval_batch_size": 32,
"gradient_accumulation_steps": 2,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 1,
"preprocessing_num_workers": 32,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
"max_steps": 400000,
"save_steps": 20000,
"eval_steps": 5000,
"logging_steps": 50,
"report_to": "wandb",
"is_decoder": false,
"remove_unused_columns": false,
"lookahead": null,
"one_sample_per_line": false,
"do_sentence_training": true,
"do_auxiliary_training": true,
"warmup_steps": 5000,
"adapter_warmup_steps": 0,
"adapter_lr_multiplier": 1,
"ngram_order": 1,
"non_punctuation_sample_ratio": 0.1,
"prediction_loss_only": true,
"use_auxiliary": true,
"ddp_timeout": 3600,
"use_subwords": true,
"num_hidden_layers": 3,
"custom_punctuation_file": "punctuation_xlmv_unk.txt",
"log_level": "info"
}
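
For reference, a flat JSON config like this is typically fed to Hugging Face's HfArgumentParser. The sketch below is an assumption about the wiring (the dataclass fields are illustrative, mirroring keys from the JSON above), not code from this commit:

from dataclasses import dataclass
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class ModelArgs:
    # Hypothetical fields mirroring keys in the JSON above; the repo's
    # actual argument dataclasses may differ.
    model_name_or_path: str = "facebook/xlm-v-base"
    num_hidden_layers: int = 3
    use_subwords: bool = True

parser = HfArgumentParser((ModelArgs, TrainingArguments))
# allow_extra_keys lets keys that belong to neither dataclass pass through.
model_args, training_args = parser.parse_json_file(
    "configs/xlmv_stratify_0.1_3layers.json", allow_extra_keys=True
)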
25 changes: 25 additions & 0 deletions get_mem.sh
@@ -0,0 +1,25 @@
#!/bin/bash

# Every second, append the "Mem" line from `free -h`, prefixed with the current time, to a log file, and also print it to the console

output_file="memory_usage_log.txt"

# Header line
header=" total used free shared buff/cache available"

# Write header to both console and file
echo "$header" | tee $output_file

while true; do
    # Get the "Mem" line from free -h and store it
    mem_usage=$(free -h | grep "Mem:")

    # Format the output string
    output="$(date +"%a %b %d %H:%M:%S") | $mem_usage"

    # Write to both console and file
    echo -e "$output" | tee -a "$output_file"

    # Wait for one second
    sleep 1
done
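
For environments without free(1), a hypothetical Python equivalent of this logger using psutil (an assumption, not part of this commit) could look like:

import time
from datetime import datetime

import psutil  # third-party; pip install psutil

LOG_FILE = "memory_usage_log.txt"

with open(LOG_FILE, "a") as log:
    while True:
        mem = psutil.virtual_memory()
        line = (
            f"{datetime.now():%a %b %d %H:%M:%S} | "
            f"total={mem.total >> 20}M used={mem.used >> 20}M "
            f"free={mem.free >> 20}M available={mem.available >> 20}M"
        )
        print(line)
        log.write(line + "\n")
        log.flush()
        time.sleep(1)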
98 changes: 98 additions & 0 deletions wtpsplit/data/punctuation_xlmv.txt
@@ -0,0 +1,98 @@
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
^
_
`
{
|
}
~
¡
£
§
©
«
¬
®
°
±
·
»
¿
÷
՝
՞
։
־
׳
،
؛
؟
۔
99 changes: 99 additions & 0 deletions wtpsplit/data/punctuation_xlmv_unk.txt
@@ -0,0 +1,99 @@
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
^
_
`
{
|
}
~
¡
£
<unk>
§
©
«
¬
®
°
±
·
»
¿
÷
՝
՞
։
־
׳
،
؛
؟
۔
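
These lists pair with the config's "custom_punctuation_file" key; the _unk variant adds a literal "<unk>" entry for characters the XLM-V tokenizer cannot represent. A minimal sketch of a reader for this one-symbol-per-line format (hypothetical; the repo's actual loader is not part of this diff):

def load_custom_punctuation(path: str) -> list[str]:
    # One punctuation symbol per line; "<unk>" marks characters that the
    # XLM-V tokenizer maps to its unknown token.
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f if line.strip()]

punctuation = load_custom_punctuation("wtpsplit/data/punctuation_xlmv_unk.txt")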
1 change: 1 addition & 0 deletions wtpsplit/evaluation/intrinsic.py
@@ -132,6 +132,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
     else:
         valid_data = None
 
+    print("Loading model...")
     model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))
 
     # first, logits for everything.
4 changes: 2 additions & 2 deletions wtpsplit/extract.py
@@ -9,7 +9,7 @@
 
 from wtpsplit.utils import Constants, hash_encode
 
-logger = logging.getLogger(__name__)
+# logger = logging.getLogger(__name__)
 
 
 class ORTWrapper:
@@ -224,7 +224,7 @@ def extract(
         )["logits"]
         if use_subwords:
             logits = logits[:, 1:-1, :]  # remove CLS and SEP tokens
-        logger.debug(np.max(logits[0, :, 0]))
+        # logger.debug(np.max(logits[0, :, 0]))
 
         for i in range(start, end):
             original_idx, start_char_idx, end_char_idx = locs[i]
1 change: 1 addition & 0 deletions wtpsplit/train/train.py
@@ -738,6 +738,7 @@ def compute_metrics(trainer):
     # because that would remove the cache files of the other dataset!
     cleanup_cache_files([train_dataset, valid_dataset])
     logger.warning("Cleaned up cache files.")
+    time.sleep(20)
 
     trainer = Trainer(
         model,
7 changes: 6 additions & 1 deletion wtpsplit/train/utils.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import time
 
 logger = logging.getLogger(__name__)

@@ -40,6 +41,10 @@ def cleanup_cache_files(datasets) -> int:
 
     for file_path in files_to_remove:
         logger.warning(f"Removing {file_path}")
-        os.remove(file_path)
+        try:
+            os.remove(file_path)
+        except Exception as e:
+            logger.warning(f"Error while trying to remove {file_path}: {e}")
+        time.sleep(0.5)
 
     return len(files_to_remove)
