From 74b4ccdd553b97f63e28b51b039490fbf773caf0 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:44:09 -0500 Subject: [PATCH] Add `JobTimeLimit` class Added a representation of batch job time limits, similar to `JobSize`, along with corresponding documentation and unit tests. --- flepimop/gempyor_pkg/src/gempyor/batch.py | 58 ++++++++++++++++- .../tests/batch/test_job_time_limit_class.py | 63 +++++++++++++++++++ 2 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 flepimop/gempyor_pkg/tests/batch/test_job_time_limit_class.py diff --git a/flepimop/gempyor_pkg/src/gempyor/batch.py b/flepimop/gempyor_pkg/src/gempyor/batch.py index cc01def50..acad4235a 100644 --- a/flepimop/gempyor_pkg/src/gempyor/batch.py +++ b/flepimop/gempyor_pkg/src/gempyor/batch.py @@ -5,12 +5,12 @@ metadata and job size calculations for example. """ -__all__ = ["JobSize", "write_manifest"] +__all__ = ["JobSize", "JobTimeLimit", "write_manifest"] import click from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone import json import math from pathlib import Path @@ -149,6 +149,60 @@ def size_from_jobs_sims_blocks( return cls(jobs=jobs, simulations=simulations, blocks=blocks) +@dataclass(frozen=True, slots=True) +class JobTimeLimit: + """ + A batch submission job time limit. + + Attributes: + time_limit: The time limit of the batch job. + + Raises: + ValueError: If the `time_limit` attribute is not positive. + """ + + time_limit: timedelta + + def __post_init__(self) -> None: + if (total_seconds := self.time_limit.total_seconds()) <= 0.0: + raise ValueError( + f"The `time_limit` attribute has {math.floor(total_seconds):,} " + "seconds, which is less than or equal to 0." + ) + + def __str__(self) -> str: + return self.format() + + def format(self, batch_system: Literal["aws", "local", "slurm"] | None = None) -> str: + """ + Format the job time limit as a string appropriate for a given batch system. + + Args: + batch_system: The batch system the format should be formatted for. + + Returns: + The time limit formatted for the batch system. + + Examples: + >>> from datetime import timedelta + >>> job_time_limit = JobTimeLimit( + ... time_limit=timedelta(days=1, hours=2, minutes=34, seconds=5) + ... ) + >>> job_time_limit.format() + '1595' + >>> job_time_limit.format(batch_system="slurm") + '26:34:05' + """ + if batch_system == "slurm": + total_seconds = self.time_limit.total_seconds() + hours = math.floor(total_seconds / (60.0 * 60.0)) + minutes = math.floor((total_seconds - (60.0 * 60.0 * hours)) / 60.0) + seconds = math.ceil(total_seconds - (60.0 * minutes) - (60.0 * 60.0 * hours)) + return f"{hours}:{minutes:02d}:{seconds:02d}" + limit_in_mins = math.ceil(self.time_limit.total_seconds() / 60.0) + return str(limit_in_mins) + + def write_manifest( job_name: str, flepi_path: Path, diff --git a/flepimop/gempyor_pkg/tests/batch/test_job_time_limit_class.py b/flepimop/gempyor_pkg/tests/batch/test_job_time_limit_class.py new file mode 100644 index 000000000..6fe1a9b8d --- /dev/null +++ b/flepimop/gempyor_pkg/tests/batch/test_job_time_limit_class.py @@ -0,0 +1,63 @@ +from datetime import timedelta +import re +from typing import Literal + +import pytest + +from gempyor.batch import JobTimeLimit + + +@pytest.mark.parametrize( + "time_limit", (timedelta(), timedelta(hours=-1.0), timedelta(days=-3.0)) +) +def test_time_limit_non_positive_value_error(time_limit: timedelta) -> None: + with pytest.raises( + ValueError, + match=( + r"^The \`time\_limit\` attribute has [0-9\,\-]+ seconds\, " + r"which is less than or equal to 0\.$" + ), + ): + JobTimeLimit(time_limit=time_limit) + + +@pytest.mark.parametrize( + "time_limit", + ( + timedelta(hours=1), + timedelta(hours=2, minutes=34, seconds=56), + timedelta(days=1, seconds=3), + timedelta(minutes=12345), + ), +) +@pytest.mark.parametrize("batch_system", (None, "aws", "local", "slurm")) +def test_format_output_validation( + time_limit: timedelta, batch_system: Literal["aws", "local", "slurm"] | None +) -> None: + job_time_limit = JobTimeLimit(time_limit=time_limit) + formatted_time_limit = job_time_limit.format(batch_system=batch_system) + assert isinstance(formatted_time_limit, str) + if batch_system == "slurm": + assert re.match(r"^[0-9]+\:[0-9]{2}\:[0-9]{2}$", formatted_time_limit) + else: + assert formatted_time_limit.isdigit() + + +@pytest.mark.parametrize( + ("time_limit", "batch_system", "expected"), + ( + (timedelta(hours=1), None, "60"), + (timedelta(seconds=20), None, "1"), + (timedelta(days=2, hours=3, minutes=45), None, "3105"), + (timedelta(hours=1), "slurm", "1:00:00"), + (timedelta(seconds=20), "slurm", "0:00:20"), + (timedelta(days=1, hours=2, minutes=34, seconds=5), "slurm", "26:34:05"), + ), +) +def test_format_exact_results( + time_limit: timedelta, + batch_system: Literal["aws", "local", "slurm"] | None, + expected: str, +) -> None: + job_time_limit = JobTimeLimit(time_limit=time_limit) + assert job_time_limit.format(batch_system=batch_system) == expected