Skip to content

Commit

Permalink
Add BatchSystem enum
Browse files Browse the repository at this point in the history
Added an enum for batch systems to formalize the valid systems flepiMoP
can run on.
  • Loading branch information
TimothyWillard committed Nov 7, 2024
1 parent aa3564f commit 37ce2c1
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 61 deletions.
112 changes: 64 additions & 48 deletions flepimop/gempyor_pkg/src/gempyor/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
metadata and job size calculations for example.
"""

__all__ = ["JobSize", "JobTimeLimit", "write_manifest"]
__all__ = ["BatchSystem", "JobSize", "JobTimeLimit", "write_manifest"]


from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum, auto
from getpass import getuser
import json
import math
Expand Down Expand Up @@ -43,6 +44,62 @@
_JOB_NAME_REGEX = re.compile(r"^[a-z]{1}([a-z0-9\_\-]+)?$", flags=re.IGNORECASE)


class BatchSystem(Enum):
"""
Enum representing the various batch systems that flepiMoP can run on.
"""

AWS = auto()
LOCAL = auto()
SLURM = auto()

@classmethod
def from_options(
cls,
batch_system: Literal["aws", "local", "slurm"] | None,
aws: bool,
local: bool,
slurm: bool,
) -> "BatchSystem":
"""
Resolve the batch system options.
Args:
batch_system: The name of the batch system to use if provided explicitly by
name or `None` to rely on the other flags.
aws: A flag indicating if the batch system should be AWS.
local: A flag indicating if the batch system should be local.
slurm: A flag indicating if the batch system should be slurm.
Returns:
The name of the batch system to use given the user options.
"""
batch_system = batch_system.lower() if batch_system is not None else batch_system
if (boolean_flags := sum((aws, local, slurm))) > 1:
raise ValueError(
f"There were {boolean_flags} boolean flags given, expected either 0 or 1."
)
if batch_system is not None:
for name, flag in zip(("aws", "local", "slurm"), (aws, local, slurm)):
if flag and batch_system != name:
raise ValueError(
"Conflicting batch systems given. The batch system name "
f"is '{batch_system}' and the flags indicate '{name}'."
)
if batch_system is None:
if aws:
batch_system = "aws"
elif local:
batch_system = "local"
else:
batch_system = "slurm"
if batch_system == "aws":
return cls.AWS
elif batch_system == "local":
return cls.LOCAL
return cls.SLURM


@dataclass(frozen=True, slots=True)
class JobSize:
"""
Expand Down Expand Up @@ -80,7 +137,7 @@ def size_from_jobs_sims_blocks(
iterations_per_slot: int | None,
slots: int | None,
subpops: int | None,
batch_system: Literal["aws", "local", "slurm"],
batch_system: BatchSystem,
) -> "JobSize":
"""
Infer a job size from several explicit and implicit parameters.
Expand Down Expand Up @@ -142,7 +199,7 @@ def size_from_jobs_sims_blocks(
"provided, then a subpops must be given."
)
)
if batch_system == "aws":
if batch_system == BatchSystem.AWS:
simulations = 5 * math.ceil(max(60 - math.sqrt(subpops), 10) / 5)
else:
simulations = iterations_per_slot
Expand Down Expand Up @@ -514,47 +571,6 @@ def _job_name(name: str | None, timestamp: datetime | None) -> str:
return f"{name}-{timestamp}" if name else timestamp


def _resolve_batch_system(
batch_system: Literal["aws", "local", "slurm"] | None,
aws: bool,
local: bool,
slurm: bool,
) -> Literal["aws", "local", "slurm"]:
"""
Resolve the batch system options.
Args:
batch_system: The name of the batch system to use if provided explicitly by
name or `None` to rely on the other flags.
aws: A flag indicating if the batch system should be AWS.
local: A flag indicating if the batch system should be local.
slurm: A flag indicating if the batch system should be slurm.
Returns:
The name of the batch system to use given the user options.
"""
batch_system = batch_system.lower() if batch_system is not None else batch_system
if (boolean_flags := sum((aws, local, slurm))) > 1:
raise ValueError(
f"There were {boolean_flags} boolean flags given, expected either 0 or 1."
)
if batch_system is not None:
for name, flag in zip(("aws", "local", "slurm"), (aws, local, slurm)):
if flag and batch_system != name:
raise ValueError(
"Conflicting batch systems given. The batch system name "
f"is '{batch_system}' and the flags indicate '{name}'."
)
if batch_system is None:
if aws:
batch_system = "aws"
elif local:
batch_system = "local"
else:
batch_system = "slurm"
return batch_system


@cli.command(
name="batch",
params=[config_files_argument]
Expand Down Expand Up @@ -720,16 +736,16 @@ def _click_batch(ctx: click.Context = mock_context, **kwargs) -> None:
logger.info("Using a run id of '%s'", kwargs["run_id"])

# Batch system
batch_system = _resolve_batch_system(
batch_system = BatchSystem.from_options(
kwargs["batch_system"], kwargs["aws"], kwargs["local"], kwargs["slurm"]
)
if batch_system != "slurm":
if batch_system != BatchSystem.SLURM:
# Temporary limitation
raise NotImplementedError(
"The `flepimop batch` CLI only supports batch submission to slurm."
)
logger.info("Constructing a job to submit to %s", batch_system)
if batch_system != "slurm" and kwargs["email"] is not None:
if batch_system != BatchSystem.SLURM and kwargs["email"] is not None:
logger.warning(
"The email option, given '%s', is only used when "
"the batch system is slurm, but is instead %s.",
Expand Down Expand Up @@ -777,7 +793,7 @@ def _click_batch(ctx: click.Context = mock_context, **kwargs) -> None:

# Cluster info
cluster: Cluster | None = None
if batch_system == "slurm":
if batch_system == BatchSystem.SLURM:
if kwargs["cluster"] is None:
raise ValueError("When submitting a batch job to slurm a cluster is required.")
cluster = get_cluster_info(kwargs["cluster"])
Expand Down
10 changes: 10 additions & 0 deletions flepimop/gempyor_pkg/tests/batch/test__click_batch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path
from typing import Any

Expand All @@ -8,6 +9,15 @@
from gempyor.batch import _click_batch


@pytest.fixture
def add_sbatch_to_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
sbatch = tmp_path / "bin" / "sbatch"
sbatch.parent.mkdir(parents=True, exist_ok=True)
sbatch.touch(mode=0o755)
monkeypatch.setenv("PATH", str(sbatch.parent.absolute()), prepend=os.pathsep)
return sbatch


@pytest.mark.parametrize(
"config",
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from gempyor.batch import _resolve_batch_system
from gempyor.batch import BatchSystem


@pytest.mark.parametrize(
Expand All @@ -17,7 +17,7 @@ def test_multiple_flags_value_error(aws: bool, local: bool, slurm: bool) -> None
"flags given, expected either 0 or 1.$"
),
):
_resolve_batch_system(None, aws, local, slurm)
BatchSystem.from_options(None, aws, local, slurm)


@pytest.mark.parametrize(
Expand All @@ -42,21 +42,21 @@ def test_batch_system_flag_mismatch_value_error(
f"is '{batch_system}' and the flags indicate '{name}'.$"
),
):
_resolve_batch_system(batch_system, aws, local, slurm)
BatchSystem.from_options(batch_system, aws, local, slurm)


@pytest.mark.parametrize(
("batch_system", "aws", "local", "slurm", "expected"),
(
(None, True, False, False, "aws"),
("aws", False, False, False, "aws"),
("aws", True, False, False, "aws"),
(None, False, True, False, "local"),
("local", False, False, False, "local"),
("local", False, True, False, "local"),
(None, False, False, True, "slurm"),
("slurm", False, False, False, "slurm"),
("slurm", False, False, True, "slurm"),
(None, True, False, False, BatchSystem.AWS),
("aws", False, False, False, BatchSystem.AWS),
("aws", True, False, False, BatchSystem.AWS),
(None, False, True, False, BatchSystem.LOCAL),
("local", False, False, False, BatchSystem.LOCAL),
("local", False, True, False, BatchSystem.LOCAL),
(None, False, False, True, BatchSystem.SLURM),
("slurm", False, False, False, BatchSystem.SLURM),
("slurm", False, False, True, BatchSystem.SLURM),
),
)
def test_output_validation(
Expand All @@ -66,4 +66,4 @@ def test_output_validation(
slurm: bool,
expected: Literal["aws", "local", "slurm"],
) -> None:
assert _resolve_batch_system(batch_system, aws, local, slurm) == expected
assert BatchSystem.from_options(batch_system, aws, local, slurm) == expected

0 comments on commit 37ce2c1

Please sign in to comment.