diff --git a/cwltool/executors.py b/cwltool/executors.py index 31c9b052c..3cbac72c7 100644 --- a/cwltool/executors.py +++ b/cwltool/executors.py @@ -26,6 +26,7 @@ from .command_line_tool import CallbackJob, ExpressionJob from .context import RuntimeContext, getdefault +from .cuda import cuda_version_and_device_count from .cwlprov.provenance_profile import ProvenanceProfile from .errors import WorkflowException from .job import JobBase @@ -269,8 +270,10 @@ def __init__(self) -> None: self.max_ram = int(psutil.virtual_memory().available / 2**20) self.max_cores = float(psutil.cpu_count()) + self.max_cuda = cuda_version_and_device_count()[1] self.allocated_ram = float(0) self.allocated_cores = float(0) + self.allocated_cuda: int = 0 def select_resources( self, request: Dict[str, Union[int, float]], runtime_context: RuntimeContext @@ -278,7 +281,11 @@ def select_resources( """Naïve check for available cpu cores and memory.""" result: Dict[str, Union[int, float]] = {} maxrsc = {"cores": self.max_cores, "ram": self.max_ram} - for rsc in ("cores", "ram"): + resources_types = {"cores", "ram"} + if "cudaDeviceCountMin" in request or "cudaDeviceCountMax" in request: + maxrsc["cudaDeviceCount"] = self.max_cuda + resources_types.add("cudaDeviceCount") + for rsc in resources_types: rsc_min = request[rsc + "Min"] if rsc_min > maxrsc[rsc]: raise WorkflowException( @@ -293,9 +300,6 @@ def select_resources( result["tmpdirSize"] = math.ceil(request["tmpdirMin"]) result["outdirSize"] = math.ceil(request["outdirMin"]) - if "cudaDeviceCount" in request: - result["cudaDeviceCount"] = request["cudaDeviceCount"] - return result def _runner( @@ -326,6 +330,10 @@ def _runner( self.allocated_ram -= ram cores = job.builder.resources["cores"] self.allocated_cores -= cores + cudaDevices: int = cast( + int, job.builder.resources.get("cudaDeviceCount", 0) + ) + self.allocated_cuda -= cudaDevices runtime_context.workflow_eval_lock.notify_all() def run_job( @@ -349,16 +357,21 @@ def run_job( if isinstance(job, JobBase): ram = job.builder.resources["ram"] cores = job.builder.resources["cores"] - if ram > self.max_ram or cores > self.max_cores: + cudaDevices = cast(int, job.builder.resources.get("cudaDeviceCount", 0)) + if ram > self.max_ram or cores > self.max_cores or cudaDevices > self.max_cuda: _logger.error( 'Job "%s" cannot be run, requests more resources (%s) ' - "than available on this host (max ram %d, max cores %d", + "than available on this host (already allocated ram is %d, " + "allocated cores is %d, allocated CUDA is %d, " + "max ram %d, max cores %d, max CUDA %d).", job.name, job.builder.resources, self.allocated_ram, self.allocated_cores, + self.allocated_cuda, self.max_ram, self.max_cores, + self.max_cuda, ) self.pending_jobs.remove(job) return @@ -366,17 +379,21 @@ def run_job( if ( self.allocated_ram + ram > self.max_ram or self.allocated_cores + cores > self.max_cores + or self.allocated_cuda + cudaDevices > self.max_cuda ): _logger.debug( 'Job "%s" cannot run yet, resources (%s) are not ' "available (already allocated ram is %d, allocated cores is %d, " - "max ram %d, max cores %d", + "allocated CUDA devices is %d, " + "max ram %d, max cores %d, max CUDA %d).", job.name, job.builder.resources, self.allocated_ram, self.allocated_cores, + self.allocated_cuda, self.max_ram, self.max_cores, + self.max_cuda, ) n += 1 continue @@ -386,6 +403,8 @@ def run_job( self.allocated_ram += ram cores = job.builder.resources["cores"] self.allocated_cores += cores + cuda = cast(int, job.builder.resources.get("cudaDevices", 0)) + self.allocated_cuda += cuda self.taskqueue.add( functools.partial(self._runner, job, runtime_context, TMPDIR_LOCK), runtime_context.workflow_eval_lock, diff --git a/cwltool/process.py b/cwltool/process.py index f0c44fe17..80d70be74 100644 --- a/cwltool/process.py +++ b/cwltool/process.py @@ -980,7 +980,8 @@ def evalResources( ): if rsc is None: continue - mn = mx = None # type: Optional[Union[int, float]] + mn: Optional[Union[int, float]] = None + mx: Optional[Union[int, float]] = None if rsc.get(a + "Min"): with SourceLine(rsc, f"{a}Min", WorkflowException, runtimeContext.debug): mn = cast(