Skip to content

Commit

Permalink
Add core test coverage (#54)
Browse files Browse the repository at this point in the history
* Core decorator tests

* Multi-directory test collection

* Non-existent tag collection

* Context collection
  • Loading branch information
maxmynter authored Feb 2, 2024
1 parent 1ace1bd commit cf0a7c3
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 1 deletion.
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@
def testfolder() -> str:
    """Absolute path (as a string) of the primary benchmark collection directory."""
    benchmark_dir = HERE / "test_benchmarks"
    return str(benchmark_dir)


@pytest.fixture(scope="session")
def another_testfolder() -> str:
    """Absolute path (as a string) of a second directory, for multi-dir collection tests."""
    extra_dir = HERE / "test_benchmarks_multidir_collection"
    return str(extra_dir)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import nnbench


# Deliberately trivial benchmark: exists only so the runner's tag-based
# collection tests can find exactly one "runner-collect" benchmark in this
# directory (see test_runner_discovery).
@nnbench.benchmark(tags=("runner-collect",))
def bad_random_number_gen() -> int:
    return 1
67 changes: 67 additions & 0 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest

import nnbench
from nnbench import benchmark, parametrize, product

from .test_utils import has_expected_args


def test_benchmark_no_args():
    """Using @benchmark bare (no call) should produce a Benchmark instance."""

    @benchmark
    def wrapped() -> str:
        return "test"

    assert isinstance(wrapped, nnbench.types.Benchmark)


def test_benchmark_with_args():
    """@benchmark(...) keyword arguments should be stored on the Benchmark."""

    @benchmark(name="Test Name", tags=("tag1", "tag2"))
    def wrapped() -> str:
        return "test"

    assert wrapped.name == "Test Name"
    assert wrapped.tags == ("tag1", "tag2")


def test_parametrize():
    """@parametrize should yield one benchmark per parameter dict, in order."""

    @parametrize([{"param": 1}, {"param": 2}])
    def parametrized_benchmark(param: int) -> int:
        return param

    assert len(parametrized_benchmark) == 2
    for idx, expected in enumerate((1, 2)):
        bm = parametrized_benchmark[idx]
        assert has_expected_args(bm.fn, {"param": expected})
        assert bm.fn() == expected


def test_parametrize_with_duplicate_parameters():
    """Repeating the same parameter dict should trigger a UserWarning."""
    with pytest.warns(UserWarning, match="duplicate"):

        @parametrize([{"param": 1}, {"param": 1}])
        def duplicated_benchmark(param: int) -> int:
            return param


def test_product():
    """@product should expand to the cartesian product of all iterables, in order."""

    @product(iter1=[1, 2], iter2=["a", "b"])
    def product_benchmark(iter1: int, iter2: str) -> tuple[int, str]:
        return iter1, iter2

    expected_combos = [(1, "a"), (1, "b"), (2, "a"), (2, "b")]
    assert len(product_benchmark) == 4
    for idx, (num, letter) in enumerate(expected_combos):
        bm = product_benchmark[idx]
        assert has_expected_args(bm.fn, {"iter1": num, "iter2": letter})
        assert bm.fn() == (num, letter)


def test_product_with_duplicate_parameters():
    """Duplicate values within a product iterable should trigger a UserWarning."""
    with pytest.warns(UserWarning, match="duplicate"):

        # `value` instead of the original `iter`, which shadowed the builtin.
        @product(value=[1, 1])
        def product_benchmark(value: int) -> int:
            return value
44 changes: 43 additions & 1 deletion tests/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
import os

import pytest

import nnbench
from nnbench.context import cpuarch, python_version, system


def test_runner_discovery(testfolder: str) -> None:
def test_runner_discovery(testfolder: str, another_testfolder: str) -> None:
    """Collection by single file, by non-matching tag, and across two directories."""
    runner = nnbench.BenchmarkRunner()

    # A single file collected by tag yields exactly its one tagged benchmark.
    file_path = os.path.join(testfolder, "standard_benchmarks.py")
    runner.collect(file_path, tags=("runner-collect",))
    assert len(runner.benchmarks) == 1
    runner.clear()

    # A tag nothing carries matches no benchmarks.
    runner.collect(testfolder, tags=("non-existing-tag",))
    assert len(runner.benchmarks) == 0
    runner.clear()

    # Successive collect() calls accumulate across directories.
    runner.collect(testfolder, tags=("runner-collect",))
    assert len(runner.benchmarks) == 1
    runner.collect(another_testfolder, tags=("runner-collect",))
    assert len(runner.benchmarks) == 2


def test_tag_selection(testfolder: str) -> None:
Expand All @@ -31,3 +39,37 @@ def test_tag_selection(testfolder: str) -> None:
r.collect(PATH, tags=("tag2",))
assert len(r.benchmarks) == 1
r.clear()


def test_context_collection_in_runner(testfolder: str) -> None:
    """Each context provider passed to run() should contribute its key to the result."""
    r = nnbench.BenchmarkRunner()

    context_providers = [system, cpuarch, python_version]
    result = r.run(
        testfolder,
        tags=("standard",),
        params={"x": 1, "y": 1},
        context=context_providers,
    )

    # Removed a leftover debug print(result) from the original test.
    for key in ("system", "cpuarch", "python_version"):
        assert key in result["context"]


def test_error_on_duplicate_context_keys_in_runner(testfolder: str) -> None:
    """Two providers producing the same context key should make run() raise."""
    runner = nnbench.BenchmarkRunner()

    def clashing_provider() -> dict[str, str]:
        # Deliberately reuses the "system" key emitted by the `system` provider.
        return {"system": "DuplicateSystem"}

    with pytest.raises(ValueError, match="got multiple values for context key 'system'"):
        runner.run(
            testfolder,
            tags=("standard",),
            params={"x": 1, "y": 1},
            context=[system, clashing_provider],
        )
8 changes: 8 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import inspect

import pytest

from nnbench.util import ismodule, modulename
Expand All @@ -16,3 +18,9 @@ def test_ismodule(name: str, expected: bool) -> None:
def test_modulename(name: str, expected: str) -> None:
    # `name`/`expected` pairs come from a parametrize decorator hidden behind
    # the collapsed diff context above — confirm against the full file.
    actual = modulename(name)
    assert expected == actual


def has_expected_args(fn, expected_args):
    """Return True if every name in `expected_args` appears in fn's signature.

    Only parameter *names* are checked; the values in `expected_args` are ignored.
    """
    sig_params = inspect.signature(fn).parameters
    for name in expected_args:
        if name not in sig_params:
            return False
    return True

0 comments on commit cf0a7c3

Please sign in to comment.