Skip to content

Commit

Permalink
Add CLI entrypoints for module and program invocation (#152)
Browse files Browse the repository at this point in the history
* Add CLI entrypoints for module and program invocation

Supports `nnbench <benchmark> <options>` as well as
`python -m nnbench <benchmark> <options>`.

Like pytest, we give the option to omit the target location, in which
case it becomes a directory named "benchmarks".

Supports a very rudimentary context building from switches method, and
tag filtering via switch.
  • Loading branch information
nicholasjng authored Oct 24, 2024
1 parent b66d484 commit 7532f99
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 1 deletion.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ docs = [
"neoteroi-mkdocs",
]

[project.scripts]
nnbench = "nnbench.cli:main"

[tool.setuptools]
package-dir = { "" = "src" }

Expand Down
11 changes: 10 additions & 1 deletion src/nnbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,13 @@
from .core import benchmark, parametrize, product
from .reporter import BenchmarkReporter
from .runner import BenchmarkRunner
from .types import Memo, Parameters
from .types import Benchmark, BenchmarkRecord, Memo, Parameters


# TODO: This isn't great, make it functional instead?
def default_runner() -> BenchmarkRunner:
return BenchmarkRunner()


def default_reporter() -> BenchmarkReporter:
return BenchmarkReporter()
6 changes: 6 additions & 0 deletions src/nnbench/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import sys

from nnbench.cli import main

if __name__ == "__main__":
sys.exit(main())
49 changes: 49 additions & 0 deletions src/nnbench/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import argparse
from typing import Any

from nnbench import default_reporter, default_runner


def main() -> int:
parser = argparse.ArgumentParser("nnbench")
# can be a directory, single file, or glob
parser.add_argument(
"benchmarks",
nargs="?",
metavar="<benchmarks>",
help="Python file or directory of files containing benchmarks to run.",
default="benchmarks",
)
parser.add_argument(
"--context",
action="append",
metavar="<key>=<value>",
help="Additional context values giving information about the benchmark run.",
default=list(),
)
parser.add_argument(
"-t",
"--tag",
action="append",
metavar="<tag>",
dest="tags",
help="Only run benchmarks marked with one or more given tag(s).",
default=tuple(),
)

args = parser.parse_args()

runner = default_runner()
reporter = default_reporter()

context: dict[str, Any] = {}
for val in args.context:
try:
k, v = val.split("=")
except ValueError:
raise ValueError("context values need to be of the form <key>=<value>")
context[k] = v

record = runner.run(args.benchmarks, tags=tuple(args.tags))
reporter.display(record)
return 0

0 comments on commit 7532f99

Please sign in to comment.