Skip to content

Commit

Permalink
Remove typechecks from benchmark runner (#174)
Browse files Browse the repository at this point in the history
* Remove typechecks from benchmark runner

This was a nice idea, but since we allow per-benchmark-family fixtures
by now, and we cannot enforce hard typechecks in a benchmark loop because
of UX reasons (cannot just wreck a user's run like that), we remove the
typechecking facility.

* Remove typecheck input annotation, mentions of typechecks in docs
  • Loading branch information
nicholasjng authored Nov 27, 2024
1 parent 1ed24f6 commit f9f4d58
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 153 deletions.
6 changes: 3 additions & 3 deletions docs/guides/customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import nnbench

def set_envvar(**params):
os.environ["MY_ENV"] = "MY_VALUE"


@nnbench.benchmark(setUp=set_envvar)
def prod(a: int, b: int) -> int:
Expand All @@ -33,7 +33,7 @@ import nnbench

def set_envvar(**params):
os.environ["MY_ENV"] = "MY_VALUE"


def pop_envvar(**params):
os.environ.pop("MY_ENV")
Expand Down Expand Up @@ -108,4 +108,4 @@ runner = nnbench.BenchmarkRunner()
result = runner.run(__name__, params=params)
```

While this does not have a concrete advantage in terms of type safety over a raw dictionary (all inputs will be checked against the types expected from the benchmark interfaces), it guards against accidental modification of parameters breaking reproducibility.
While this does not have a concrete advantage in terms of type safety over a raw dictionary, it guards against accidental modification of parameters breaking reproducibility.
8 changes: 1 addition & 7 deletions src/nnbench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,6 @@ def main() -> int:
help="File or stream to write results to, defaults to stdout.",
default=sys.stdout,
)
run_parser.add_argument(
"--typecheck",
action=argparse.BooleanOptionalAction,
default=True,
help="Whether or not to strictly check types of benchmark inputs.",
)

compare_parser = subparsers.add_parser(
"compare",
Expand Down Expand Up @@ -136,7 +130,7 @@ def main() -> int:
# TODO: Support builtin providers in the runner
context[k] = v

record = BenchmarkRunner(typecheck=args.typecheck).run(
record = BenchmarkRunner().run(
args.benchmarks,
tags=tuple(args.tags),
context=[lambda: context],
Expand Down
102 changes: 2 additions & 100 deletions src/nnbench/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Any, get_origin
from typing import Any

from nnbench.context import ContextProvider
from nnbench.fixtures import FixtureManager
Expand Down Expand Up @@ -48,107 +48,12 @@ class BenchmarkRunner:
Collects benchmarks from a module or file using the collect() method.
Runs a previously collected benchmark workload with parameters in the run() method,
outputting the results to a JSON-like document.
Optionally checks input parameters against the benchmark function's interfaces,
raising an error if the input types do not match the expected types.
Parameters
----------
typecheck: bool
Whether to check parameter types against the expected benchmark input types.
Type mismatches will result in an error before the workload is run.
"""

benchmark_type = Benchmark

def __init__(self, typecheck: bool = True):
def __init__(self):
self.benchmarks: list[Benchmark] = list()
self.typecheck = typecheck

def _check(self, params: dict[str, Any]) -> None:
allvars: dict[str, tuple[type, Any]] = {}
required: set[str] = set()
empty = inspect.Parameter.empty

def _issubtype(t1: type, t2: type) -> bool:
"""Small helper to make typechecks work on generics."""

if t1 == t2:
return True

t1 = get_origin(t1) or t1
t2 = get_origin(t2) or t2
if not inspect.isclass(t1):
return False
# TODO: Extend typing checks to args.
return issubclass(t1, t2)

# stitch together the union interface comprised of all benchmarks.
for bm in self.benchmarks:
for var in bm.interface.variables:
name, typ, default = var
if default is empty:
required.add(name)
if name in params and default != empty:
logger.debug(
f"using given value {params[name]} over default value {default} "
f"for parameter {name!r} in benchmark {bm.name}()"
)

if typ == empty:
logger.debug(f"parameter {name!r} untyped in benchmark {bm.name}().")

if name in allvars:
currvar = allvars[name]
orig_type, orig_val = new_type, new_val = currvar
# If a benchmark has a variable without a default value,
# that variable is taken into the combined interface as no-default.
if default is empty:
new_val = default
# These types need not be exact matches, just compatible.
# Two types are compatible iff either is a subtype of the other.
# We only log the narrowest type for each varname in the final interface,
# since that determines whether an input value is admissible.
if _issubtype(orig_type, typ):
pass
elif _issubtype(typ, orig_type):
new_type = typ
else:
raise TypeError(
f"got incompatible types {orig_type}, {typ} for parameter {name!r}"
)
newvar = (new_type, new_val)
if newvar != currvar:
allvars[name] = newvar
else:
allvars[name] = (typ, default)

# check if any required variable has no parameter.
missing = required - params.keys()
if missing:
msng, *_ = missing
raise ValueError(f"missing value for required parameter {msng!r}")

for k, v in params.items():
if k not in allvars:
warnings.warn(
f"ignoring parameter {k!r} since it is not part of any benchmark interface."
)
continue

typ, default = allvars[k]
# skip the subsequent type check if the variable is untyped.
if typ == empty:
continue

vtype = type(v)
if is_memo(v) and not is_memo_type(typ):
# in case of a thunk, check the result type of __call__() instead.
vtype = inspect.signature(v).return_annotation

# type-check parameter value against the narrowest hinted type.
if not _issubtype(vtype, typ):
raise TypeError(f"expected type {typ} for parameter {k!r}, got {vtype}")

def jsonify_params(
self, params: dict[str, Any], repr_hooks: dict[type, Callable] | None = None
Expand Down Expand Up @@ -318,9 +223,6 @@ def run(
else:
dparams = params or {}

if dparams and self.typecheck:
self._check(dparams)

results: list[dict[str, Any]] = []

def _maybe_dememo(v, expected_type):
Expand Down
43 changes: 0 additions & 43 deletions tests/test_argcheck.py

This file was deleted.

0 comments on commit f9f4d58

Please sign in to comment.