Skip to content

Commit

Permalink
Initial JobSize class
Browse files Browse the repository at this point in the history
Added a dataclass to represent batch job sizes, including light
validation.
  • Loading branch information
TimothyWillard committed Nov 8, 2024
1 parent 3eb1cb2 commit 06dbf0d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
32 changes: 31 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
metadata and job size calculations for example.
"""

__all__ = ["write_manifest"]
__all__ = ["JobSize", "write_manifest"]


from dataclasses import dataclass
import json
from pathlib import Path
from shlex import quote
Expand All @@ -19,6 +20,35 @@
from .utils import _git_head, _shutil_which


@dataclass(frozen=True, slots=True)
class JobSize:
"""
A batch submission job size.
Attributes:
jobs: The number of jobs to use.
simulations: The number of simulations to run per a block.
blocks: The number of sequential blocks to run per a job.
Raises:
ValueError: If any of the attributes are less than 1.
"""

jobs: int
simulations: int
blocks: int

def __post_init__(self) -> None:
for p in self.__slots__:
if (val := getattr(self, p)) < 1:
raise ValueError(
(
f"The '{p}' attribute must be greater than 0, "
f"but instead was given '{val}'."
)
)


def write_manifest(
job_name: str,
flepi_path: Path,
Expand Down
31 changes: 31 additions & 0 deletions flepimop/gempyor_pkg/tests/batch/test_job_size_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Literal

import pytest

from gempyor.batch import JobSize


@pytest.mark.parametrize(
"kwargs",
(
{"jobs": 0, "simulations": 1, "blocks": 1},
{"jobs": 1, "simulations": 0, "blocks": 1},
{"jobs": 1, "simulations": 1, "blocks": 0},
{"jobs": 0, "simulations": 0, "blocks": 1},
{"jobs": 1, "simulations": 0, "blocks": 0},
{"jobs": 0, "simulations": 1, "blocks": 0},
{"jobs": 0, "simulations": 0, "blocks": 0},
),
)
def test_less_than_one_value_error(
kwargs: dict[Literal["jobs", "simulations", "blocks"], int]
) -> None:
param = next(k for k, v in kwargs.items() if v < 1)
with pytest.raises(
ValueError,
match=(
f"^The '{param}' attribute must be greater than 0, "
f"but instead was given '{kwargs.get(param)}'.$"
),
):
JobSize(**kwargs)

0 comments on commit 06dbf0d

Please sign in to comment.