Adds tasks_per_job to slurm executor (huggingface#153)
* added tasks_per_job option

* Update src/datatrove/executor/slurm.py
guipenedo authored May 7, 2024
1 parent 8805e13 commit ff50473
Showing 1 changed file with 17 additions and 8 deletions.
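The change lets one slurm array job process several datatrove tasks sequentially instead of launching one array job per task, which helps on clusters that cap array sizes or queue lengths. A minimal usage sketch, assuming a placeholder pipeline, partition, and paths (none of these values come from the commit itself):

```python
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.readers import JsonlReader

# Hypothetical setup: 1000 datatrove tasks packed 4 per slurm job,
# so the job array only needs ceil(1000 / 4) = 250 slurm jobs.
executor = SlurmPipelineExecutor(
    pipeline=[JsonlReader("s3://some-bucket/raw-data/")],  # placeholder input
    tasks=1000,
    tasks_per_job=4,  # the new option introduced by this commit
    time="24:00:00",
    partition="cpu-partition",  # placeholder partition name
    logging_dir="/scratch/logs/my_run",  # placeholder logging path
)
executor.run()
```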
25 changes: 17 additions & 8 deletions src/datatrove/executor/slurm.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import math
 import os
 import random
 import signal
@@ -79,7 +80,7 @@ class SlurmPipelineExecutor(PipelineExecutor):
         mail_type: see https://slurm.schedmd.com/sbatch.html. Common values are (NONE, BEGIN, END, FAIL, REQUEUE, ALL)
         mail_user: email address to send notifications to
         requeue: requeue the job if it fails
+        tasks_per_job: each slurm job in the job array will run this many datatrove tasks. This reduces the total number of slurm jobs launched.
     """

def __init__(
@@ -111,13 +112,15 @@ def __init__(
         mail_type: str = "ALL",
         mail_user: str = None,
         requeue: bool = True,
+        tasks_per_job: int = 1,
     ):
         super().__init__(pipeline, logging_dir, skip_completed)
         self.tasks = tasks
         self.workers = workers
         self.partition = partition
         self.cpus_per_task = cpus_per_task
         self.mem_per_cpu_gb = mem_per_cpu_gb
+        self.tasks_per_job = tasks_per_job
         self.time = time
         self.job_name = job_name
         self.qos = qos
@@ -160,18 +163,23 @@ def run(self):
             slurm_rank = int(os.environ["SLURM_ARRAY_TASK_ID"]) + self.max_array_size * int(
                 os.environ.get("RUN_OFFSET", 0)
             )
+            ranks_to_run_range = (slurm_rank * self.tasks_per_job, (slurm_rank + 1) * self.tasks_per_job)
             with self.logging_dir.open("ranks_to_run.json", "r") as ranks_to_run_file:
                 all_ranks = json.load(ranks_to_run_file)
-            if slurm_rank >= len(all_ranks):
+            if ranks_to_run_range[0] >= len(all_ranks):
                 return
-            rank = all_ranks[slurm_rank]
 
             for ss in self.requeue_signals or []:
                 signal.signal(signal.Signals[ss], requeue_handler)
 
-            if self.randomize_start:
-                time.sleep(random.randint(0, 60 * 3))
-            self._run_for_rank(rank)
+            for rank_to_run in range(*ranks_to_run_range):
+                if rank_to_run >= len(all_ranks):
+                    break
+                rank = all_ranks[rank_to_run]
+
+                if self.randomize_start:
+                    time.sleep(random.randint(0, 60 * 3))
+                self._run_for_rank(rank)
         else:
             # we still have to launch the job
             self.launch_job()
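With tasks_per_job, array job i no longer runs a single rank: it runs the half-open range of datatrove ranks [i * tasks_per_job, (i + 1) * tasks_per_job), stopping early once it runs past the end of the incomplete-rank list. A standalone sketch of that mapping (hypothetical helper, not part of the commit):

```python
def ranks_for_array_job(slurm_rank: int, tasks_per_job: int, total_ranks: int) -> list[int]:
    """Indices into all_ranks handled by one slurm array job, mirroring run() above."""
    start = slurm_rank * tasks_per_job
    # the final job may get fewer than tasks_per_job ranks (or none at all)
    return [r for r in range(start, start + tasks_per_job) if r < total_ranks]

# e.g. tasks_per_job=4 with 10 incomplete ranks:
# job 0 -> [0, 1, 2, 3], job 1 -> [4, 5, 6, 7], job 2 -> [8, 9], job 3 -> []
```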
@@ -244,7 +252,8 @@ def launch_job(self):
             # we actually save this (only once) to avoid race conditions
             json.dump(ranks_to_run, ranks_to_run_file)
 
-        max_array = min(len(ranks_to_run), self.max_array_size) if self.max_array_size != -1 else len(ranks_to_run)
+        nb_jobs_to_launch = math.ceil(len(ranks_to_run) / self.tasks_per_job)
+        max_array = min(nb_jobs_to_launch, self.max_array_size) if self.max_array_size != -1 else nb_jobs_to_launch
 
         # create the actual sbatch script
         launch_file_contents = self.get_launch_file_contents(
@@ -261,7 +270,7 @@
 
         # launch (possibly multiple) jobs
         launched_jobs = 0
-        while launched_jobs * max_array < len(ranks_to_run):
+        while launched_jobs * max_array < nb_jobs_to_launch:
             if launched_jobs and self.max_array_launch_parallel and self.stagger_max_array_jobs > 0:
                 time.sleep(self.stagger_max_array_jobs)
             args = [f"--export=ALL,RUN_OFFSET={launched_jobs}"]
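On the launch side, packing tasks_per_job ranks into each job means only ceil(len(ranks_to_run) / tasks_per_job) array jobs are needed, and the RUN_OFFSET loop still splits them into chunks of at most max_array. A worked sketch of the arithmetic, with assumed example numbers:

```python
import math

ranks_to_run = list(range(1000))  # assumed: 1000 incomplete datatrove tasks
tasks_per_job = 4
max_array_size = 100  # assumed cluster cap on slurm array size

nb_jobs_to_launch = math.ceil(len(ranks_to_run) / tasks_per_job)  # 250
max_array = min(nb_jobs_to_launch, max_array_size)  # 100

# number of sbatch calls; each launch gets RUN_OFFSET=0, 1, 2, ...
launches = math.ceil(nb_jobs_to_launch / max_array)  # 3
# array jobs in the last launch whose rank range starts past the end of
# ranks_to_run simply return early, as the ranks_to_run_range check in run() shows.
print(nb_jobs_to_launch, max_array, launches)
```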
