From 7532f99442942e74e26c5fc5efb49998532e9d3f Mon Sep 17 00:00:00 2001 From: Nicholas Junge Date: Thu, 24 Oct 2024 15:31:36 +0200 Subject: [PATCH] Add CLI entrypoints for module and program invocation (#152) * Add CLI entrypoints for module and program invocation Supports `nnbench ` as well as `python -m nnbench `. Like pytest, we give the option to omit the target location, in which case it becomes a directory named "benchmarks". Supports a very rudimentary context building from switches method, and tag filtering via switch. --- pyproject.toml | 3 +++ src/nnbench/__init__.py | 11 ++++++++- src/nnbench/__main__.py | 6 +++++ src/nnbench/cli.py | 49 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 src/nnbench/__main__.py create mode 100644 src/nnbench/cli.py diff --git a/pyproject.toml b/pyproject.toml index 0d09d440..17aef06c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,9 @@ docs = [ "neoteroi-mkdocs", ] +[project.scripts] +nnbench = "nnbench.cli:main" + [tool.setuptools] package-dir = { "" = "src" } diff --git a/src/nnbench/__init__.py b/src/nnbench/__init__.py index d35141e1..ea293277 100644 --- a/src/nnbench/__init__.py +++ b/src/nnbench/__init__.py @@ -11,4 +11,13 @@ from .core import benchmark, parametrize, product from .reporter import BenchmarkReporter from .runner import BenchmarkRunner -from .types import Memo, Parameters +from .types import Benchmark, BenchmarkRecord, Memo, Parameters + + +# TODO: This isn't great, make it functional instead? +def default_runner() -> BenchmarkRunner: + return BenchmarkRunner() + + +def default_reporter() -> BenchmarkReporter: + return BenchmarkReporter() diff --git a/src/nnbench/__main__.py b/src/nnbench/__main__.py new file mode 100644 index 00000000..8562f76f --- /dev/null +++ b/src/nnbench/__main__.py @@ -0,0 +1,6 @@ +import sys + +from nnbench.cli import main + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/nnbench/cli.py b/src/nnbench/cli.py new file mode 100644 index 00000000..aa114742 --- /dev/null +++ b/src/nnbench/cli.py @@ -0,0 +1,49 @@ +import argparse +from typing import Any + +from nnbench import default_reporter, default_runner + + +def main() -> int: + parser = argparse.ArgumentParser("nnbench") + # can be a directory, single file, or glob + parser.add_argument( + "benchmarks", + nargs="?", + metavar="", + help="Python file or directory of files containing benchmarks to run.", + default="benchmarks", + ) + parser.add_argument( + "--context", + action="append", + metavar="=", + help="Additional context values giving information about the benchmark run.", + default=list(), + ) + parser.add_argument( + "-t", + "--tag", + action="append", + metavar="", + dest="tags", + help="Only run benchmarks marked with one or more given tag(s).", + default=tuple(), + ) + + args = parser.parse_args() + + runner = default_runner() + reporter = default_reporter() + + context: dict[str, Any] = {} + for val in args.context: + try: + k, v = val.split("=") + except ValueError: + raise ValueError("context values need to be of the form =") + context[k] = v + + record = runner.run(args.benchmarks, tags=tuple(args.tags)) + reporter.display(record) + return 0