Skip to content

Commit

Permalink
Unsigned int tokenizer and srun args (huggingface#154)
Browse files (browse the repository at this point in the history)
* add srun_args dict

* add num_bytes to test_tokenization

* fix srun_args default value

* slurm_rank

* Revert "slurm_rank"

This reverts commit 9f0145d.

* revert tokenization changes

* linting

---------

Co-authored-by: marianna13 <[email protected]>
Co-authored-by: guipenedo <[email protected]>
  • Loading branch information
3 people authored May 7, 2024
1 parent ff50473 commit b2b96e4
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/datatrove/executor/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
mail_type: str = "ALL",
mail_user: str = None,
requeue: bool = True,
srun_args: dict = None,
tasks_per_job: int = 1,
):
super().__init__(pipeline, logging_dir, skip_completed)
Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
self.requeue_signals = requeue_signals
self.mail_type = mail_type
self.mail_user = mail_user
self.srun_args = srun_args
self.slurm_logs_folder = (
slurm_logs_folder
if slurm_logs_folder
Expand Down Expand Up @@ -256,9 +258,10 @@ def launch_job(self):
max_array = min(nb_jobs_to_launch, self.max_array_size) if self.max_array_size != -1 else nb_jobs_to_launch

# create the actual sbatch script
srun_args_str = " ".join([f"--{k}={v}" for k, v in self.srun_args.items()]) if self.srun_args else ""
launch_file_contents = self.get_launch_file_contents(
self.get_sbatch_args(max_array),
f"srun -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
f"srun {srun_args_str} -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
)
# save it
with self.logging_dir.open("launch_script.slurm", "w") as launchscript_f:
Expand Down

0 comments on commit b2b96e4

Please sign in to comment.