diff --git a/snakemake_executor_plugin_slurm_jobstep/__init__.py b/snakemake_executor_plugin_slurm_jobstep/__init__.py
index 66b08b6..cc84692 100644
--- a/snakemake_executor_plugin_slurm_jobstep/__init__.py
+++ b/snakemake_executor_plugin_slurm_jobstep/__init__.py
@@ -6,8 +6,9 @@
 import os
 import subprocess
 import sys
+from typing import Generator, List
 from snakemake_interface_executor_plugins.executors.base import SubmittedJobInfo
-from snakemake_interface_executor_plugins.executors.real import RealExecutor
+from snakemake_interface_executor_plugins.executors.remote import RemoteExecutor
 from snakemake_interface_executor_plugins.jobs import (
     JobExecutorInterface,
 )
@@ -37,22 +38,14 @@
 # Required:
 # Implementation of your executor
-class Executor(RealExecutor):
+class Executor(RemoteExecutor):
     def __post_init__(self):
         # These environment variables are set by SLURM.
         # only needed for commented out jobstep handling below
         self.jobid = os.getenv("SLURM_JOB_ID")
+        self.jobs = dict()

     def run_job(self, job: JobExecutorInterface):
-        # Implement here how to run a job.
-        # You can access the job's resources, etc.
-        # via the job object.
-        # After submitting the job, you have to call
-        # self.report_job_submission(job_info).
-        # with job_info being of type
-        # snakemake_interface_executor_plugins.executors.base.SubmittedJobInfo.
-
-        jobsteps = dict()

         # TODO revisit special handling for group job levels via srun at a later stage
         # if job.is_group():
@@ -113,16 +106,27 @@ def run_job(self, job: JobExecutorInterface):
         # this dict is to support the to be implemented feature of oversubscription in
         # "ordinary" group jobs.
-        jobsteps[job] = subprocess.Popen(call, shell=True)
+        proc = subprocess.Popen(call, shell=True)
+        job_info = SubmittedJobInfo(job, aux={"proc": proc})
-        job_info = SubmittedJobInfo(job)
         self.report_job_submission(job_info)

-        # wait until all steps are finished
-        if any(proc.wait() != 0 for proc in jobsteps.values()):
-            self.report_job_error(job_info)
-        else:
-            self.report_job_success(job_info)
+    async def check_active_jobs(
+        self, active_jobs: List[SubmittedJobInfo]
+    ) -> Generator[SubmittedJobInfo, None, None]:
+        for active_job in active_jobs:
+            ret = active_job.aux["proc"].poll()
+            if ret is not None:
+                if ret != 0:
+                    self.report_job_error(active_job)
+                else:
+                    self.report_job_success(active_job)
+            else:
+                yield active_job
+
+    def cancel_jobs(self, active_jobs: List[SubmittedJobInfo]):
+        for active_job in active_jobs:
+            active_job.aux["proc"].terminate()

     def cancel(self):
         pass