diff --git a/flepimop/gempyor_pkg/src/gempyor/batch.py b/flepimop/gempyor_pkg/src/gempyor/batch.py index acad4235a..dbb242f8b 100644 --- a/flepimop/gempyor_pkg/src/gempyor/batch.py +++ b/flepimop/gempyor_pkg/src/gempyor/batch.py @@ -18,7 +18,7 @@ from shlex import quote import subprocess import sys -from typing import Any, Literal +from typing import Any, Literal, Self from ._jinja import _render_template_to_file, _render_template_to_temp_file from .logging import get_script_logger @@ -173,6 +173,38 @@ def __post_init__(self) -> None: def __str__(self) -> str: return self.format() + def __hash__(self) -> int: + return hash(self.time_limit) + + def __eq__(self, other: Self | timedelta) -> bool: + if isinstance(other, JobTimeLimit): + return self.time_limit == other.time_limit + if isinstance(other, timedelta): + return self.time_limit == other + raise TypeError( + "'==' not supported between instances of " + f"'JobTimeLimit' and '{type(other).__name__}'." + ) + + def __lt__(self, other: Self | timedelta) -> bool: + if isinstance(other, JobTimeLimit): + return self.time_limit < other.time_limit + if isinstance(other, timedelta): + return self.time_limit < other + raise TypeError( + "'<' not supported between instances of " + f"'JobTimeLimit' and '{type(other).__name__}'." + ) + + def __le__(self, other: Self | timedelta) -> bool: + return self.__eq__(other) or self.__lt__(other) + + def __gt__(self, other: Self | timedelta) -> bool: + return not self.__le__(other) + + def __ge__(self, other: Self | timedelta) -> bool: + return self.__eq__(other) or self.__gt__(other) + def format(self, batch_system: Literal["aws", "local", "slurm"] | None = None) -> str: """ Format the job time limit as a string appropriate for a given batch system. @@ -202,6 +234,42 @@ def format(self, batch_system: Literal["aws", "local", "slurm"] | None = None) - limit_in_mins = math.ceil(self.time_limit.total_seconds() / 60.0) return str(limit_in_mins) + @classmethod + def from_per_simulation_time( + cls, job_size: JobSize, time_per_simulation: timedelta, initial_time: timedelta + ) -> "JobTimeLimit": + """ + Construct a job time limit that scales with job size. + + Args: + job_size: The job size to scale the time limit with. + time_per_simulation: The time per a simulation. + initial_time: Time required to setup per a job. + + Returns: + A job time limit that is scaled to match `job_size`. + + Raises: + ValueError: If `time_per_simulation` is non-positive. + ValueError: If `initial_time` is non-positive. + + Examples: + """ + if (total_seconds := time_per_simulation.total_seconds()) <= 0.0: + raise ValueError( + f"The `time_per_simulation` is '{math.floor(total_seconds):,}' " + "seconds, which is less than or equal to 0." + ) + if (total_seconds := initial_time.total_seconds()) <= 0.0: + raise ValueError( + f"The `initial_time` is '{math.floor(total_seconds):,}' " + "seconds, which is less than or equal to 0." + ) + time_limit = ( + job_size.blocks * job_size.simulations * time_per_simulation + ) + initial_time + return cls(time_limit=time_limit) + def write_manifest( job_name: str, diff --git a/flepimop/gempyor_pkg/tests/batch/test_job_time_limit_class.py b/flepimop/gempyor_pkg/tests/batch/test_job_time_limit_class.py index 6fe1a9b8d..f51be9ef7 100644 --- a/flepimop/gempyor_pkg/tests/batch/test_job_time_limit_class.py +++ b/flepimop/gempyor_pkg/tests/batch/test_job_time_limit_class.py @@ -4,12 +4,13 @@ import pytest -from gempyor.batch import JobTimeLimit +from gempyor.batch import JobSize, JobTimeLimit -@pytest.mark.parametrize( - "time_limit", (timedelta(), timedelta(hours=-1.0), timedelta(days=-3.0)) -) +NONPOSITIVE_TIMEDELTAS = (timedelta(), timedelta(hours=-1.0), timedelta(days=-3.0)) + + +@pytest.mark.parametrize("time_limit", NONPOSITIVE_TIMEDELTAS) def test_time_limit_non_positive_value_error(time_limit: timedelta) -> None: with pytest.raises( ValueError, @@ -61,3 +62,63 @@ def test_format_exact_results( ) -> None: job_time_limit = JobTimeLimit(time_limit=time_limit) assert job_time_limit.format(batch_system=batch_system) == expected + + +@pytest.mark.parametrize("time_per_simulation", NONPOSITIVE_TIMEDELTAS) +def test_from_per_simulation_time_per_simulation_nonpositive_value_error( + time_per_simulation: timedelta, +) -> None: + job_size = JobSize(jobs=1, simulations=1, blocks=1) + with pytest.raises( + ValueError, + match=( + r"^The \`time\_per\_simulation\` is \'[0-9\,\-]+\' seconds\, " + r"which is less than or equal to 0\.$" + ), + ): + JobTimeLimit.from_per_simulation_time( + job_size, time_per_simulation, timedelta(minutes=10) + ) + + +@pytest.mark.parametrize("initial_time", NONPOSITIVE_TIMEDELTAS) +def test_from_per_simulation_initial_time_nonpositive_value_error( + initial_time: timedelta, +) -> None: + job_size = JobSize(jobs=1, simulations=1, blocks=1) + with pytest.raises( + ValueError, + match=( + r"^The \`initial\_time\` is \'[0-9\,\-]+\' seconds\, " + r"which is less than or equal to 0\.$" + ), + ): + JobTimeLimit.from_per_simulation_time(job_size, timedelta(minutes=10), initial_time) + + +@pytest.mark.parametrize( + "job_size", + ( + JobSize(jobs=1, simulations=10, blocks=1), + JobSize(jobs=10, simulations=25, blocks=15), + ), +) +@pytest.mark.parametrize( + "time_per_simulation", + (timedelta(minutes=5), timedelta(seconds=120), timedelta(hours=1.5)), +) +@pytest.mark.parametrize( + "initial_time", (timedelta(minutes=10), timedelta(seconds=30), timedelta(hours=2)) +) +def test_from_per_simulation_time( + job_size: JobSize, time_per_simulation: timedelta, initial_time: timedelta +) -> None: + job_time_limit = JobTimeLimit.from_per_simulation_time( + job_size, time_per_simulation, initial_time + ) + assert job_time_limit.time_limit >= initial_time + + double_job_time_limit = JobTimeLimit.from_per_simulation_time( + job_size, 2 * time_per_simulation, initial_time + ) + assert double_job_time_limit > job_time_limit