From cf0a7c3a2f38379b13148a3f34b5a7b6c9024f75 Mon Sep 17 00:00:00 2001
From: Max Mynter <32773644+maxmynter@users.noreply.github.com>
Date: Fri, 2 Feb 2024 17:16:07 +0100
Subject: [PATCH] Add core test coverage (#54)

* Core decorator tests

* Multi-directory test collection

* Non-existent tag collection

* Context collection
---
 tests/conftest.py                             |  6 ++
 .../benchmark_in_another_dir.py               |  6 ++
 tests/test_decorators.py                      | 67 +++++++++++++++++++
 tests/test_runner.py                          | 44 +++++++++++-
 tests/test_utils.py                           |  8 +++
 5 files changed, 130 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_benchmarks_multidir_collection/benchmark_in_another_dir.py
 create mode 100644 tests/test_decorators.py

diff --git a/tests/conftest.py b/tests/conftest.py
index 9af4ff21..cab0fbf0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,3 +13,9 @@
 def testfolder() -> str:
     """A test directory for benchmark collection."""
     return str(HERE / "test_benchmarks")
+
+
+@pytest.fixture(scope="session")
+def another_testfolder() -> str:
+    """Another test directory for benchmark collection."""
+    return str(HERE / "test_benchmarks_multidir_collection")
diff --git a/tests/test_benchmarks_multidir_collection/benchmark_in_another_dir.py b/tests/test_benchmarks_multidir_collection/benchmark_in_another_dir.py
new file mode 100644
index 00000000..855b8d56
--- /dev/null
+++ b/tests/test_benchmarks_multidir_collection/benchmark_in_another_dir.py
@@ -0,0 +1,6 @@
+import nnbench
+
+
+@nnbench.benchmark(tags=("runner-collect",))
+def bad_random_number_gen() -> int:
+    return 1
diff --git a/tests/test_decorators.py b/tests/test_decorators.py
new file mode 100644
index 00000000..743257f4
--- /dev/null
+++ b/tests/test_decorators.py
@@ -0,0 +1,67 @@
+import pytest
+
+import nnbench
+from nnbench import benchmark, parametrize, product
+
+from .test_utils import has_expected_args
+
+
+def test_benchmark_no_args():
+    @benchmark
+    def sample_benchmark() -> str:
+        return "test"
+
+    assert isinstance(sample_benchmark, nnbench.types.Benchmark)
+
+
+def test_benchmark_with_args():
+    @benchmark(name="Test Name", tags=("tag1", "tag2"))
+    def another_benchmark() -> str:
+        return "test"
+
+    assert another_benchmark.name == "Test Name"
+    assert another_benchmark.tags == ("tag1", "tag2")
+
+
+def test_parametrize():
+    @parametrize([{"param": 1}, {"param": 2}])
+    def parametrized_benchmark(param: int) -> int:
+        return param
+
+    assert len(parametrized_benchmark) == 2
+    assert has_expected_args(parametrized_benchmark[0].fn, {"param": 1})
+    assert parametrized_benchmark[0].fn() == 1
+    assert has_expected_args(parametrized_benchmark[1].fn, {"param": 2})
+    assert parametrized_benchmark[1].fn() == 2
+
+
+def test_parametrize_with_duplicate_parameters():
+    with pytest.warns(UserWarning, match="duplicate"):
+
+        @parametrize([{"param": 1}, {"param": 1}])
+        def parametrized_benchmark(param: int) -> int:
+            return param
+
+
+def test_product():
+    @product(iter1=[1, 2], iter2=["a", "b"])
+    def product_benchmark(iter1: int, iter2: str) -> tuple[int, str]:
+        return iter1, iter2
+
+    assert len(product_benchmark) == 4
+    assert has_expected_args(product_benchmark[0].fn, {"iter1": 1, "iter2": "a"})
+    assert product_benchmark[0].fn() == (1, "a")
+    assert has_expected_args(product_benchmark[1].fn, {"iter1": 1, "iter2": "b"})
+    assert product_benchmark[1].fn() == (1, "b")
+    assert has_expected_args(product_benchmark[2].fn, {"iter1": 2, "iter2": "a"})
+    assert product_benchmark[2].fn() == (2, "a")
+    assert has_expected_args(product_benchmark[3].fn, {"iter1": 2, "iter2": "b"})
+    assert product_benchmark[3].fn() == (2, "b")
+
+
+def test_product_with_duplicate_parameters():
+    with pytest.warns(UserWarning, match="duplicate"):
+
+        @product(iter=[1, 1])
+        def product_benchmark(iter: int) -> int:
+            return iter
diff --git a/tests/test_runner.py b/tests/test_runner.py
index 73648c85..1cf98fb5 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -1,18 +1,26 @@
 import os
 
+import pytest
+
 import nnbench
+from nnbench.context import cpuarch, python_version, system
 
 
-def test_runner_discovery(testfolder: str) -> None:
+def test_runner_discovery(testfolder: str, another_testfolder: str) -> None:
     r = nnbench.BenchmarkRunner()
 
     r.collect(os.path.join(testfolder, "standard_benchmarks.py"), tags=("runner-collect",))
     assert len(r.benchmarks) == 1
+    r.clear()
+    r.collect(testfolder, tags=("non-existing-tag",))
+    assert len(r.benchmarks) == 0
 
     r.clear()
     r.collect(testfolder, tags=("runner-collect",))
     assert len(r.benchmarks) == 1
+    r.collect(another_testfolder, tags=("runner-collect",))
+    assert len(r.benchmarks) == 2
 
 
 def test_tag_selection(testfolder: str) -> None:
@@ -31,3 +39,37 @@ def test_tag_selection(testfolder: str) -> None:
     r.collect(PATH, tags=("tag2",))
     assert len(r.benchmarks) == 1
     r.clear()
+
+
+def test_context_collection_in_runner(testfolder: str) -> None:
+    r = nnbench.BenchmarkRunner()
+
+    context_providers = [system, cpuarch, python_version]
+    result = r.run(
+        testfolder,
+        tags=("standard",),
+        params={"x": 1, "y": 1},
+        context=context_providers,
+    )
+
+    print(result)
+    assert "system" in result["context"]
+    assert "cpuarch" in result["context"]
+    assert "python_version" in result["context"]
+
+
+def test_error_on_duplicate_context_keys_in_runner(testfolder: str) -> None:
+    r = nnbench.BenchmarkRunner()
+
+    def duplicate_context_provider() -> dict[str, str]:
+        return {"system": "DuplicateSystem"}
+
+    context_providers = [system, duplicate_context_provider]
+
+    with pytest.raises(ValueError, match="got multiple values for context key 'system'"):
+        r.run(
+            testfolder,
+            tags=("standard",),
+            params={"x": 1, "y": 1},
+            context=context_providers,
+        )
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 41da9384..e5206ca0 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,3 +1,5 @@
+import inspect
+
 import pytest
 
 from nnbench.util import ismodule, modulename
@@ -16,3 +18,9 @@ def test_ismodule(name: str, expected: bool) -> None:
 def test_modulename(name: str, expected: str) -> None:
     actual = modulename(name)
     assert expected == actual
+
+
+def has_expected_args(fn, expected_args):
+    signature = inspect.signature(fn)
+    params = signature.parameters
+    return all(param in params for param in expected_args)