Skip to content

Commit

Permalink
Add from_per_simulation_time method
Browse files Browse the repository at this point in the history
Added the `JobTimeLimit.from_per_simulation_time` class method to easily
create instances from user provided inputs. Also added comparison
methods for ease of unit testing.
  • Loading branch information
TimothyWillard committed Nov 6, 2024
1 parent 74b4ccd commit bff518d
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 5 deletions.
70 changes: 69 additions & 1 deletion flepimop/gempyor_pkg/src/gempyor/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from shlex import quote
import subprocess
import sys
from typing import Any, Literal
from typing import Any, Literal, Self

from ._jinja import _render_template_to_file, _render_template_to_temp_file
from .logging import get_script_logger
Expand Down Expand Up @@ -173,6 +173,38 @@ def __post_init__(self) -> None:
def __str__(self) -> str:
return self.format()

def __hash__(self) -> int:
return hash(self.time_limit)

def __eq__(self, other: Self | timedelta) -> bool:
if isinstance(other, JobTimeLimit):
return self.time_limit == other.time_limit
if isinstance(other, timedelta):
return self.time_limit == other
raise TypeError(
"'==' not supported between instances of "
f"'JobTimeLimit' and '{type(other).__name__}'."
)

def __lt__(self, other: Self | timedelta) -> bool:
if isinstance(other, JobTimeLimit):
return self.time_limit < other.time_limit
if isinstance(other, timedelta):
return self.time_limit < other
raise TypeError(
"'<' not supported between instances of "
f"'JobTimeLimit' and '{type(other).__name__}'."
)

def __le__(self, other: Self | timedelta) -> bool:
return self.__eq__(other) or self.__lt__(other)

def __gt__(self, other: Self | timedelta) -> bool:
return not self.__le__(other)

def __ge__(self, other: Self | timedelta) -> bool:
return self.__eq__(other) or self.__gt__(other)

def format(self, batch_system: Literal["aws", "local", "slurm"] | None = None) -> str:
"""
Format the job time limit as a string appropriate for a given batch system.
Expand Down Expand Up @@ -202,6 +234,42 @@ def format(self, batch_system: Literal["aws", "local", "slurm"] | None = None) -
limit_in_mins = math.ceil(self.time_limit.total_seconds() / 60.0)
return str(limit_in_mins)

@classmethod
def from_per_simulation_time(
cls, job_size: JobSize, time_per_simulation: timedelta, initial_time: timedelta
) -> "JobTimeLimit":
"""
Construct a job time limit that scales with job size.
Args:
job_size: The job size to scale the time limit with.
time_per_simulation: The time per a simulation.
initial_time: Time required to setup per a job.
Returns:
A job time limit that is scaled to match `job_size`.
Raises:
ValueError: If `time_per_simulation` is non-positive.
ValueError: If `initial_time` is non-positive.
Examples:
"""
if (total_seconds := time_per_simulation.total_seconds()) <= 0.0:
raise ValueError(
f"The `time_per_simulation` is '{math.floor(total_seconds):,}' "
"seconds, which is less than or equal to 0."
)
if (total_seconds := initial_time.total_seconds()) <= 0.0:
raise ValueError(
f"The `initial_time` is '{math.floor(total_seconds):,}' "
"seconds, which is less than or equal to 0."
)
time_limit = (
job_size.blocks * job_size.simulations * time_per_simulation
) + initial_time
return cls(time_limit=time_limit)


def write_manifest(
job_name: str,
Expand Down
69 changes: 65 additions & 4 deletions flepimop/gempyor_pkg/tests/batch/test_job_time_limit_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

import pytest

from gempyor.batch import JobTimeLimit
from gempyor.batch import JobSize, JobTimeLimit


@pytest.mark.parametrize(
"time_limit", (timedelta(), timedelta(hours=-1.0), timedelta(days=-3.0))
)
NONPOSITIVE_TIMEDELTAS = (timedelta(), timedelta(hours=-1.0), timedelta(days=-3.0))


@pytest.mark.parametrize("time_limit", NONPOSITIVE_TIMEDELTAS)
def test_time_limit_non_positive_value_error(time_limit: timedelta) -> None:
with pytest.raises(
ValueError,
Expand Down Expand Up @@ -61,3 +62,63 @@ def test_format_exact_results(
) -> None:
job_time_limit = JobTimeLimit(time_limit=time_limit)
assert job_time_limit.format(batch_system=batch_system) == expected


@pytest.mark.parametrize("time_per_simulation", NONPOSITIVE_TIMEDELTAS)
def test_from_per_simulation_time_per_simulation_nonpositive_value_error(
time_per_simulation: timedelta,
) -> None:
job_size = JobSize(jobs=1, simulations=1, blocks=1)
with pytest.raises(
ValueError,
match=(
r"^The \`time\_per\_simulation\` is \'[0-9\,\-]+\' seconds\, "
r"which is less than or equal to 0\.$"
),
):
JobTimeLimit.from_per_simulation_time(
job_size, time_per_simulation, timedelta(minutes=10)
)


@pytest.mark.parametrize("initial_time", NONPOSITIVE_TIMEDELTAS)
def test_from_per_simulation_initial_time_nonpositive_value_error(
initial_time: timedelta,
) -> None:
job_size = JobSize(jobs=1, simulations=1, blocks=1)
with pytest.raises(
ValueError,
match=(
r"^The \`initial\_time\` is \'[0-9\,\-]+\' seconds\, "
r"which is less than or equal to 0\.$"
),
):
JobTimeLimit.from_per_simulation_time(job_size, timedelta(minutes=10), initial_time)


@pytest.mark.parametrize(
"job_size",
(
JobSize(jobs=1, simulations=10, blocks=1),
JobSize(jobs=10, simulations=25, blocks=15),
),
)
@pytest.mark.parametrize(
"time_per_simulation",
(timedelta(minutes=5), timedelta(seconds=120), timedelta(hours=1.5)),
)
@pytest.mark.parametrize(
"initial_time", (timedelta(minutes=10), timedelta(seconds=30), timedelta(hours=2))
)
def test_from_per_simulation_time(
job_size: JobSize, time_per_simulation: timedelta, initial_time: timedelta
) -> None:
job_time_limit = JobTimeLimit.from_per_simulation_time(
job_size, time_per_simulation, initial_time
)
assert job_time_limit.time_limit >= initial_time

double_job_time_limit = JobTimeLimit.from_per_simulation_time(
job_size, 2 * time_per_simulation, initial_time
)
assert double_job_time_limit > job_time_limit

0 comments on commit bff518d

Please sign in to comment.