Skip to content

Commit

Permalink
Further unit tests of JobResources
Browse files Browse the repository at this point in the history
More complete set of unit tests for the `JobResources` class. Also
incorporated the formatting methods of that class into `_click_batch`.
  • Loading branch information
TimothyWillard committed Nov 8, 2024
1 parent 42a9d78 commit 5ca83c5
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 8 deletions.
26 changes: 19 additions & 7 deletions flepimop/gempyor_pkg/src/gempyor/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,17 @@ def total_resources(self) -> tuple[int, int, int]:
"""
return (self.nodes, self.total_cpus, self.total_memory)

def format_nodes(self, batch_system: BatchSystem | None) -> str:
return str(self.nodes)

def format_cpus(self, batch_system: BatchSystem | None) -> str:
return str(self.cpus)

def format_memory(self, batch_system: BatchSystem | None) -> str:
if batch_system == BatchSystem.SLURM:
return f"{self.memory}MB"
return str(self.memory)


@dataclass(frozen=True, slots=True)
class JobTimeLimit:
Expand Down Expand Up @@ -290,7 +301,7 @@ def __gt__(self, other: Self | timedelta) -> bool:
def __ge__(self, other: Self | timedelta) -> bool:
return self.__eq__(other) or self.__gt__(other)

def format(self, batch_system: Literal["aws", "local", "slurm"] | None = None) -> str:
def format(self, batch_system: BatchSystem | None = None) -> str:
"""
Format the job time limit as a string appropriate for a given batch system.
Expand All @@ -301,16 +312,17 @@ def format(self, batch_system: Literal["aws", "local", "slurm"] | None = None) -
The time limit formatted for the batch system.
Examples:
>>> from gempyor.batch import BatchSystem
>>> from datetime import timedelta
>>> job_time_limit = JobTimeLimit(
... time_limit=timedelta(days=1, hours=2, minutes=34, seconds=5)
... )
>>> job_time_limit.format()
'1595'
>>> job_time_limit.format(batch_system="slurm")
>>> job_time_limit.format(batch_system=BatchSystem.SLURM)
'26:34:05'
"""
if batch_system == "slurm":
if batch_system == BatchSystem.SLURM:
total_seconds = self.time_limit.total_seconds()
hours = math.floor(total_seconds / (60.0 * 60.0))
minutes = math.floor((total_seconds - (60.0 * 60.0 * hours)) / 60.0)
Expand Down Expand Up @@ -846,12 +858,12 @@ def _click_batch(ctx: click.Context = mock_context, **kwargs) -> None:
options = {
"chdir": kwargs["project_path"].absolute(),
"comment": f"Generated on {now:%c %Z} and submitted by {getuser()}.",
"cpus-per-task": job_resources.cpus,
"cpus-per-task": job_resources.format_cpus(batch_system),
"job-name": job_name,
"mem": job_resources.memory,
"nodes": job_resources.nodes,
"mem": job_resources.format_memory(batch_system),
"nodes": job_resources.format_nodes(batch_system),
"ntasks-per-node": 1,
"time": job_time_limit.format("slurm"),
"time": job_time_limit.format(batch_system),
}
if kwargs["partition"] is not None:
options["partition"] = kwargs["partition"]
Expand Down
54 changes: 53 additions & 1 deletion flepimop/gempyor_pkg/tests/batch/test_job_resources_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from gempyor.batch import JobResources
from gempyor.batch import BatchSystem, JobResources, JobSize


@pytest.mark.parametrize(
Expand All @@ -29,3 +29,55 @@ def test_less_than_one_value_error(
),
):
JobResources(**kwargs)


@pytest.mark.parametrize("nodes", (1, 2, 4, 8))
@pytest.mark.parametrize("cpus", (1, 2, 4, 8))
@pytest.mark.parametrize("memory", (1024, 2 * 1024, 4 * 1024, 8 * 1024))
def test_instance_attributes(nodes: int, cpus: int, memory: int) -> None:
job_resources = JobResources(nodes=nodes, cpus=cpus, memory=memory)
assert job_resources.total_cpus >= cpus
assert job_resources.total_memory >= memory
assert job_resources.total_resources() == (nodes, nodes * cpus, nodes * memory)


@pytest.mark.parametrize("nodes", (1, 2, 4, 8))
@pytest.mark.parametrize("cpus", (1, 2, 4, 8))
@pytest.mark.parametrize("memory", (1024, 2 * 1024, 4 * 1024, 8 * 1024))
@pytest.mark.parametrize(
"batch_system", (BatchSystem.AWS, BatchSystem.LOCAL, BatchSystem.SLURM, None)
)
def test_formatting(
nodes: int, cpus: int, memory: int, batch_system: BatchSystem | None
) -> None:
job_resources = JobResources(nodes=nodes, cpus=cpus, memory=memory)

formatted_nodes = job_resources.format_nodes(batch_system)
assert isinstance(formatted_nodes, str)
assert str(nodes) in formatted_nodes

formatted_cpus = job_resources.format_cpus(batch_system)
assert isinstance(formatted_cpus, str)
assert str(cpus) in formatted_cpus

formatted_memory = job_resources.format_memory(batch_system)
assert isinstance(formatted_memory, str)
assert str(memory) in formatted_memory


@pytest.mark.parametrize("jobs", (1, 4, 16, 32))
@pytest.mark.parametrize("simulations", (250, 4 * 250, 16 * 250, 32 * 250))
@pytest.mark.parametrize("blocks", (1, 4, 16, 32))
@pytest.mark.parametrize("inference_method", ("emcee", None))
def test_from_presets_for_select_inputs(
jobs: int, simulations: int, blocks: int, inference_method: Literal["emcee"] | None
) -> None:
job_size = JobSize(jobs=jobs, simulations=simulations, blocks=blocks)
job_resources = JobResources.from_presets(job_size, inference_method)
if inference_method == "emcee":
assert job_resources.nodes == 1
assert job_resources.cpus % 2 == 0
assert job_resources.memory % (2 * 1024) == 0
else:
assert job_resources.cpus == 2
assert job_resources.memory == 2 * 1024

0 comments on commit 5ca83c5

Please sign in to comment.