From f9f4d588332c3b75d69902067f476e6ba9fefc36 Mon Sep 17 00:00:00 2001
From: Nicholas Junge
Date: Wed, 27 Nov 2024 11:12:32 +0100
Subject: [PATCH] Remove typechecks from benchmark runner (#174)

* Remove typechecks from benchmark runner

This was a nice idea, but since we now allow per-benchmark-family fixtures, and we cannot enforce hard typechecks in a benchmark loop for UX reasons (we cannot just wreck a user's run like that), we remove the typechecking facility.

* Remove typecheck input annotation, mentions of typechecks in docs
---
 docs/guides/customization.md |   6 +--
 src/nnbench/cli.py           |   8 +--
 src/nnbench/runner.py        | 102 +----------------------------------
 tests/test_argcheck.py       |  43 ---------------
 4 files changed, 6 insertions(+), 153 deletions(-)
 delete mode 100644 tests/test_argcheck.py

diff --git a/docs/guides/customization.md b/docs/guides/customization.md
index 9b6d2612..abb1dd5f 100644
--- a/docs/guides/customization.md
+++ b/docs/guides/customization.md
@@ -16,7 +16,7 @@ import nnbench
 
 def set_envvar(**params):
     os.environ["MY_ENV"] = "MY_VALUE"
-    
+
 
 @nnbench.benchmark(setUp=set_envvar)
 def prod(a: int, b: int) -> int:
@@ -33,7 +33,7 @@ import nnbench
 
 def set_envvar(**params):
     os.environ["MY_ENV"] = "MY_VALUE"
-    
+
 
 def pop_envvar(**params):
     os.environ.pop("MY_ENV")
@@ -108,4 +108,4 @@ runner = nnbench.BenchmarkRunner()
 result = runner.run(__name__, params=params)
 ```
 
-While this does not have a concrete advantage in terms of type safety over a raw dictionary (all inputs will be checked against the types expected from the benchmark interfaces), it guards against accidental modification of parameters breaking reproducibility.
+While this does not have a concrete advantage in terms of type safety over a raw dictionary, it guards against accidental modification of parameters breaking reproducibility.
diff --git a/src/nnbench/cli.py b/src/nnbench/cli.py
index c97e6e78..104a1fdd 100644
--- a/src/nnbench/cli.py
+++ b/src/nnbench/cli.py
@@ -78,12 +78,6 @@ def main() -> int:
         help="File or stream to write results to, defaults to stdout.",
         default=sys.stdout,
     )
-    run_parser.add_argument(
-        "--typecheck",
-        action=argparse.BooleanOptionalAction,
-        default=True,
-        help="Whether or not to strictly check types of benchmark inputs.",
-    )
 
     compare_parser = subparsers.add_parser(
         "compare",
@@ -136,7 +130,7 @@ def main() -> int:
                 # TODO: Support builtin providers in the runner
                 context[k] = v
 
-        record = BenchmarkRunner(typecheck=args.typecheck).run(
+        record = BenchmarkRunner().run(
             args.benchmarks,
             tags=tuple(args.tags),
             context=[lambda: context],
diff --git a/src/nnbench/runner.py b/src/nnbench/runner.py
index f4bcf8f5..5ebe5170 100644
--- a/src/nnbench/runner.py
+++ b/src/nnbench/runner.py
@@ -14,7 +14,7 @@
 from dataclasses import asdict
 from datetime import datetime
 from pathlib import Path
-from typing import Any, get_origin
+from typing import Any
 
 from nnbench.context import ContextProvider
 from nnbench.fixtures import FixtureManager
@@ -48,107 +48,12 @@ class BenchmarkRunner:
     Collects benchmarks from a module or file using the collect() method.
     Runs a previously collected benchmark workload with parameters in the run() method,
     outputting the results to a JSON-like document.
-
-    Optionally checks input parameters against the benchmark function's interfaces,
-    raising an error if the input types do not match the expected types.
-
-    Parameters
-    ----------
-    typecheck: bool
-        Whether to check parameter types against the expected benchmark input types.
-        Type mismatches will result in an error before the workload is run.
     """
 
     benchmark_type = Benchmark
 
-    def __init__(self, typecheck: bool = True):
+    def __init__(self):
         self.benchmarks: list[Benchmark] = list()
-        self.typecheck = typecheck
-
-    def _check(self, params: dict[str, Any]) -> None:
-        allvars: dict[str, tuple[type, Any]] = {}
-        required: set[str] = set()
-        empty = inspect.Parameter.empty
-
-        def _issubtype(t1: type, t2: type) -> bool:
-            """Small helper to make typechecks work on generics."""
-
-            if t1 == t2:
-                return True
-
-            t1 = get_origin(t1) or t1
-            t2 = get_origin(t2) or t2
-            if not inspect.isclass(t1):
-                return False
-            # TODO: Extend typing checks to args.
-            return issubclass(t1, t2)
-
-        # stitch together the union interface comprised of all benchmarks.
-        for bm in self.benchmarks:
-            for var in bm.interface.variables:
-                name, typ, default = var
-                if default is empty:
-                    required.add(name)
-                if name in params and default != empty:
-                    logger.debug(
-                        f"using given value {params[name]} over default value {default} "
-                        f"for parameter {name!r} in benchmark {bm.name}()"
-                    )
-
-                if typ == empty:
-                    logger.debug(f"parameter {name!r} untyped in benchmark {bm.name}().")
-
-                if name in allvars:
-                    currvar = allvars[name]
-                    orig_type, orig_val = new_type, new_val = currvar
-                    # If a benchmark has a variable without a default value,
-                    # that variable is taken into the combined interface as no-default.
-                    if default is empty:
-                        new_val = default
-                    # These types need not be exact matches, just compatible.
-                    # Two types are compatible iff either is a subtype of the other.
-                    # We only log the narrowest type for each varname in the final interface,
-                    # since that determines whether an input value is admissible.
-                    if _issubtype(orig_type, typ):
-                        pass
-                    elif _issubtype(typ, orig_type):
-                        new_type = typ
-                    else:
-                        raise TypeError(
-                            f"got incompatible types {orig_type}, {typ} for parameter {name!r}"
-                        )
-                    newvar = (new_type, new_val)
-                    if newvar != currvar:
-                        allvars[name] = newvar
-                else:
-                    allvars[name] = (typ, default)
-
-        # check if any required variable has no parameter.
-        missing = required - params.keys()
-        if missing:
-            msng, *_ = missing
-            raise ValueError(f"missing value for required parameter {msng!r}")
-
-        for k, v in params.items():
-            if k not in allvars:
-                warnings.warn(
-                    f"ignoring parameter {k!r} since it is not part of any benchmark interface."
-                )
-                continue
-
-            typ, default = allvars[k]
-            # skip the subsequent type check if the variable is untyped.
-            if typ == empty:
-                continue
-
-            vtype = type(v)
-            if is_memo(v) and not is_memo_type(typ):
-                # in case of a thunk, check the result type of __call__() instead.
-                vtype = inspect.signature(v).return_annotation
-
-            # type-check parameter value against the narrowest hinted type.
-            if not _issubtype(vtype, typ):
-                raise TypeError(f"expected type {typ} for parameter {k!r}, got {vtype}")
 
     def jsonify_params(
         self, params: dict[str, Any], repr_hooks: dict[type, Callable] | None = None
@@ -318,9 +223,6 @@ def run(
         else:
             dparams = params or {}
 
-        if dparams and self.typecheck:
-            self._check(dparams)
-
         results: list[dict[str, Any]] = []
 
         def _maybe_dememo(v, expected_type):
diff --git a/tests/test_argcheck.py b/tests/test_argcheck.py
deleted file mode 100644
index b547e7c9..00000000
--- a/tests/test_argcheck.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import logging
-import os
-
-import pytest
-
-import nnbench
-
-
-def test_argcheck(testfolder: str) -> None:
-    benchmarks = os.path.join(testfolder, "standard.py")
-    r = nnbench.BenchmarkRunner()
-
-    with pytest.raises(TypeError, match="expected type .*"):
-        r.run(benchmarks, params={"x": 1, "y": "1"}, tags=("standard",))
-
-    with pytest.raises(ValueError, match="missing value for required parameter 'y'.*"):
-        r.run(benchmarks, params={"x": 1}, tags=("standard",))
-
-    r.run(benchmarks, params={"x": 1, "y": 1}, tags=("standard",))
-
-
-def test_error_on_duplicate_params(testfolder: str) -> None:
-    benchmarks = os.path.join(testfolder, "argchecks.py")
-    r = nnbench.BenchmarkRunner()
-
-    with pytest.raises(TypeError, match="got incompatible types.*"):
-        r.run(benchmarks, params={"x": 1, "y": 1}, tags=("duplicate",))
-
-
-def test_log_warn_on_overwrite_default(testfolder: str, caplog: pytest.LogCaptureFixture) -> None:
-    benchmark = os.path.join(testfolder, "argchecks.py")
-    r = nnbench.BenchmarkRunner()
-
-    with caplog.at_level(logging.DEBUG):
-        r.run(benchmark, params={"a": 1}, tags=("with_default",))
-        assert "using given value 1 over default value" in caplog.text
-
-
-def test_untyped_interface(testfolder: str) -> None:
-    benchmarks = os.path.join(testfolder, "argchecks.py")
-
-    r = nnbench.BenchmarkRunner()
-    r.run(benchmarks, params={"value": 2}, tags=("untyped",))
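
For downstream users who relied on the runner-side typecheck removed above, a simplified version of the same guard can still live in user code, applied to the parameters before they reach `BenchmarkRunner.run()`. The following sketch is illustrative and not part of this patch: `check_params` and `prod` are hypothetical names, and unlike the removed `_check()` method, it skips generic annotations and memoized values instead of resolving them.

```python
import inspect
from typing import Any, get_origin


def prod(a: int, b: int) -> int:
    # Hypothetical benchmark function; in a real module this would be
    # decorated with @nnbench.benchmark.
    return a * b


def check_params(func, params: dict[str, Any]) -> None:
    """Hypothetical user-side stand-in for the removed runner typecheck.

    Raises ValueError for missing required parameters and TypeError for
    values whose type does not match a plain (non-generic) annotation.
    """
    for name, param in inspect.signature(func).parameters.items():
        if name not in params:
            if param.default is inspect.Parameter.empty:
                raise ValueError(f"missing value for required parameter {name!r}")
            continue
        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            continue  # untyped parameter, nothing to check
        if get_origin(annotation) is not None or not inspect.isclass(annotation):
            continue  # generic or non-class annotation, skipped in this sketch
        if not isinstance(params[name], annotation):
            raise TypeError(
                f"expected type {annotation} for parameter {name!r}, got {type(params[name])}"
            )


if __name__ == "__main__":
    params = {"a": 2, "b": 3}
    check_params(prod, params)  # raises before any benchmark is run
    # With the inputs validated, hand them to the runner as usual, e.g.
    # nnbench.BenchmarkRunner().run(__name__, params=params)
```

A generic-aware variant could reuse the `_issubtype()` logic deleted above, which reduces parameterized annotations to their origin types via `get_origin` before calling `issubclass`.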