Skip to content

Commit

Permalink
fix: fix parallelism in case of group jobs (#12)
Browse files Browse the repository at this point in the history
Previously, jobs within a group job were executed sequentially instead of in parallel.
  • Loading branch information
johanneskoester authored Jan 13, 2024
1 parent bcb2ee0 commit 6c06d14
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions snakemake_executor_plugin_slurm_jobstep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import os
import subprocess
import sys
from typing import Generator, List
from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo
from snakemake_interface_executor_plugins.executors.real import RealExecutor
from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor
from snakemake_interface_executor_plugins.jobs import (
JobExecutorInterface,
)
Expand Down Expand Up @@ -37,22 +38,14 @@

# Required:
# Implementation of your executor
class Executor(RealExecutor):
class Executor(RemoteExecutor):
def __post_init__(self):
# These environment variables are set by SLURM.
# only needed for commented out jobstep handling below
self.jobid = os.getenv("SLURM_JOB_ID")
self.jobs = dict()

def run_job(self, job: JobExecutorInterface):
# Implement here how to run a job.
# You can access the job's resources, etc.
# via the job object.
# After submitting the job, you have to call
# self.report_job_submission(job_info).
# with job_info being of type
# snakemake_interface_executor_plugins.executors.base.SubmittedJobInfo.

jobsteps = dict()
# TODO revisit special handling for group job levels via srun at a later stage
# if job.is_group():

Expand Down Expand Up @@ -113,16 +106,27 @@ def run_job(self, job: JobExecutorInterface):

# this dict is to support the to be implemented feature of oversubscription in
# "ordinary" group jobs.
jobsteps[job] = subprocess.Popen(call, shell=True)
proc = subprocess.Popen(call, shell=True)
job_info = SubmittedJobInfo(job, aux={"proc": proc})

job_info = SubmittedJobInfo(job)
self.report_job_submission(job_info)

# wait until all steps are finished
if any(proc.wait() != 0 for proc in jobsteps.values()):
self.report_job_error(job_info)
else:
self.report_job_success(job_info)
async def check_active_jobs(
self, active_jobs: List[SubmittedJobInfo]
) -> Generator[SubmittedJobInfo, None, None]:
for active_job in active_jobs:
ret = active_job.aux["proc"].poll()
if ret is not None:
if ret != 0:
self.report_job_error(active_job)
else:
self.report_job_success(active_job)
else:
yield active_job

def cancel_jobs(self, active_jobs: List[SubmittedJobInfo]):
for active_job in active_jobs:
active_job.aux["proc"].terminate()

def cancel(self):
pass
Expand Down

0 comments on commit 6c06d14

Please sign in to comment.