diff --git a/pyproject.toml b/pyproject.toml index 0d09d44..17aef06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,9 @@ docs = [ "neoteroi-mkdocs", ] +[project.scripts] +nnbench = "nnbench.cli:main" + [tool.setuptools] package-dir = { "" = "src" } diff --git a/src/nnbench/__init__.py b/src/nnbench/__init__.py index d35141e..ea29327 100644 --- a/src/nnbench/__init__.py +++ b/src/nnbench/__init__.py @@ -11,4 +11,13 @@ from .core import benchmark, parametrize, product from .reporter import BenchmarkReporter from .runner import BenchmarkRunner -from .types import Memo, Parameters +from .types import Benchmark, BenchmarkRecord, Memo, Parameters + + +# TODO: This isn't great, make it functional instead? +def default_runner() -> BenchmarkRunner: + return BenchmarkRunner() + + +def default_reporter() -> BenchmarkReporter: + return BenchmarkReporter() diff --git a/src/nnbench/__main__.py b/src/nnbench/__main__.py new file mode 100644 index 0000000..8562f76 --- /dev/null +++ b/src/nnbench/__main__.py @@ -0,0 +1,6 @@ +import sys + +from nnbench.cli import main + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/nnbench/cli.py b/src/nnbench/cli.py new file mode 100644 index 0000000..aa11474 --- /dev/null +++ b/src/nnbench/cli.py @@ -0,0 +1,49 @@ +import argparse +from typing import Any + +from nnbench import default_reporter, default_runner + + +def main() -> int: + parser = argparse.ArgumentParser("nnbench") + # can be a directory, single file, or glob + parser.add_argument( + "benchmarks", + nargs="?", + metavar="", + help="Python file or directory of files containing benchmarks to run.", + default="benchmarks", + ) + parser.add_argument( + "--context", + action="append", + metavar="=", + help="Additional context values giving information about the benchmark run.", + default=list(), + ) + parser.add_argument( + "-t", + "--tag", + action="append", + metavar="", + dest="tags", + help="Only run benchmarks marked with one or more given tag(s).", + default=tuple(), + ) + + args = parser.parse_args() + + runner = default_runner() + reporter = default_reporter() + + context: dict[str, Any] = {} + for val in args.context: + try: + k, v = val.split("=") + except ValueError: + raise ValueError("context values need to be of the form =") + context[k] = v + + record = runner.run(args.benchmarks, tags=tuple(args.tags)) + reporter.display(record) + return 0