From b2b96e4858dd4cc022059ff336852e5293af98e4 Mon Sep 17 00:00:00 2001
From: Marianna <43296932+marianna13@users.noreply.github.com>
Date: Tue, 7 May 2024 15:24:06 +0200
Subject: [PATCH] Unsigned int tokenizer and srun args (#154)

* add srun_args dict

* add num_bytes to test_tokenization

* fix srun_args default value

* slurm_rank

* Revert "slurm_rank"

This reverts commit 9f0145d64cc49a106aa0f954017d0ae7189a9b64.

* revert tokenization changes

* linting

---------

Co-authored-by: marianna13
Co-authored-by: guipenedo
---
 src/datatrove/executor/slurm.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/datatrove/executor/slurm.py b/src/datatrove/executor/slurm.py
index 3e74a662..a8133f68 100644
--- a/src/datatrove/executor/slurm.py
+++ b/src/datatrove/executor/slurm.py
@@ -112,6 +112,7 @@ def __init__(
         mail_type: str = "ALL",
         mail_user: str = None,
         requeue: bool = True,
+        srun_args: dict = None,
         tasks_per_job: int = 1,
     ):
         super().__init__(pipeline, logging_dir, skip_completed)
@@ -139,6 +140,7 @@ def __init__(
         self.requeue_signals = requeue_signals
         self.mail_type = mail_type
         self.mail_user = mail_user
+        self.srun_args = srun_args
         self.slurm_logs_folder = (
             slurm_logs_folder
             if slurm_logs_folder
@@ -256,9 +258,10 @@ def launch_job(self):
         max_array = min(nb_jobs_to_launch, self.max_array_size) if self.max_array_size != -1 else nb_jobs_to_launch
 
         # create the actual sbatch script
+        srun_args_str = " ".join([f"--{k}={v}" for k, v in self.srun_args.items()]) if self.srun_args else ""
         launch_file_contents = self.get_launch_file_contents(
             self.get_sbatch_args(max_array),
-            f"srun -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
+            f"srun {srun_args_str} -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
         )
         # save it
         with self.logging_dir.open("launch_script.slurm", "w") as launchscript_f: