Skip to content

Commit

Permalink
Add API docs for nnbench.types and nnbench.reporter
Browse files Browse the repository at this point in the history
Completes the nnbench public API documentation.
  • Loading branch information
nicholasjng committed Dec 2, 2024
1 parent 49332e4 commit 95c5d35
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 52 deletions.
2 changes: 1 addition & 1 deletion src/nnbench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A framework for organizing and running benchmark workloads on machine learning models."""

from .core import benchmark, parametrize, product
from .reporter import BenchmarkReporter, ConsoleReporter
from .reporter import BenchmarkReporter, ConsoleReporter, FileReporter
from .runner import BenchmarkRunner
from .types import Benchmark, BenchmarkRecord, Memo, Parameters

Expand Down
2 changes: 1 addition & 1 deletion src/nnbench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def main() -> int:
builtin_providers[p.name] = klass(**p.arguments)
for val in args.context:
try:
k, v = val.split("=")
k, v = val.split("=", 1)
except ValueError:
raise ValueError("context values need to be of the form <key>=<value>")
if k == "provider":
Expand Down
3 changes: 2 additions & 1 deletion src/nnbench/reporter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
A lightweight interface for refining, displaying, and streaming benchmark results to various sinks.
An interface for displaying, writing, or streaming benchmark results to
files, databases, or web services.
"""

from .base import BenchmarkReporter
Expand Down
21 changes: 17 additions & 4 deletions src/nnbench/reporter/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,34 @@ def get_value_by_name(result: dict[str, Any]) -> str:
class ConsoleReporter(BenchmarkReporter):
"""
The base interface for a console reporter class.
Wraps a ``rich.Console()`` to display values in a rich-text table.
"""

def __init__(self, *args, **kwargs):
"""
Initialize a console reporter.
Parameters
----------
*args: Any
Positional arguments, unused.
**kwargs: Any
Keyword arguments, forwarded directly to ``rich.Console()``.
"""
super().__init__(*args, **kwargs)
# TODO: Add context manager to register live console prints
self.console = Console(**kwargs)

def display(self, record: BenchmarkRecord) -> None:
"""
Display a benchmark record in the console.
Display a benchmark record in the console as a rich-text table.
Benchmarks and context values will be filtered before display
if any filtering is applied.
Gives a summary of all present context values directly above the table,
as a pretty-printed JSON record.
Columns that do not contain any useful information are omitted by default.
By default, displays only the benchmark name, value, execution wall time,
and parameters.
Parameters
----------
Expand Down
15 changes: 9 additions & 6 deletions src/nnbench/reporter/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

_Options = dict[str, Any]
SerDe = tuple[
Callable[[BenchmarkRecord, IO, dict[str, Any]], None],
Callable[[IO, dict[str, Any]], BenchmarkRecord],
Callable[[BenchmarkRecord, IO, _Options], None],
Callable[[IO, _Options], BenchmarkRecord],
]


Expand Down Expand Up @@ -183,9 +183,10 @@ def read(
Parameters
----------
file: str | os.PathLike[str] | IO[str]
The file name to read from.
The file name, or object, to read from.
mode: str
File mode to use. Can be any of the modes used in builtin ``open()``.
Mode to use when opening a new file from a path.
Can be any of the read modes supported by built-in ``open()``.
driver: str | None
File driver implementation to use. If None, the file driver inferred from the
given file path's extension will be used.
Expand Down Expand Up @@ -251,9 +252,10 @@ def write(
record: BenchmarkRecord
The record to write to the database.
file: str | os.PathLike[str]
The file name to write to.
The file name, or object, to write to.
mode: str
File mode to use. Can be any of the modes used in builtin ``open()``.
Mode to use when opening a new file from a path.
Can be any of the write modes supported by built-in ``open()``.
driver: str | None
File driver implementation to use. If None, the file driver inferred from the
given file path's extension will be used.
Expand All @@ -266,6 +268,7 @@ def write(
If no registered file driver matches the file extension and no other driver
was explicitly specified.
"""
# TODO: Guard against file
driver = driver or get_extension(file)

if not driver:
Expand Down
56 changes: 32 additions & 24 deletions src/nnbench/types/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,38 @@

@dataclass(frozen=True)
class State:
"""
A dataclass holding some basic information about a benchmark and its hierarchy
inside its *family* (i.e. a series of the same benchmark for different parameters).
For benchmarks registered with ``@nnbench.benchmark``, meaning no parametrization,
each benchmark constitutes its own family, and ``family_size == 1`` holds true.
"""

name: str
family: str
family_size: int
family_index: int


def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None:
"""A no-op setup/teardown callback that does nothing."""
pass


@dataclass(frozen=True)
class BenchmarkRecord:
"""
A dataclass representing the result of a benchmark run, i.e. the return value
of a call to ``BenchmarkRunner.run()``.
"""

run: str
"""A name describing the run."""
context: dict[str, Any]
"""A map of key-value pairs describing context information around the benchmark run."""
benchmarks: list[dict[str, Any]]
"""The list of benchmark results, each given as a Python dictionary."""

def to_json(self) -> dict[str, Any]:
"""
Expand All @@ -47,7 +64,7 @@ def to_json(self) -> dict[str, Any]:
def to_list(self) -> list[dict[str, Any]]:
"""
Export a benchmark record to a list of individual results,
each with the benchmark context inlined.
each with the benchmark run name and context inlined.
"""
results = []
for b in self.benchmarks:
Expand All @@ -72,8 +89,7 @@ def expand(cls, bms: dict[str, Any] | list[dict[str, Any]]) -> Self:
Returns
-------
BenchmarkRecord
The resulting record with the context extracted.
The resulting record, with the context and run name extracted.
"""
context: dict[str, Any]
if isinstance(bms, dict):
Expand Down Expand Up @@ -102,32 +118,22 @@ def expand(cls, bms: dict[str, Any] | list[dict[str, Any]]) -> Self:
class Benchmark:
"""
Data model representing a benchmark. Subclass this to define your own custom benchmark.
Parameters
----------
fn: Callable[..., Any]
The function defining the benchmark.
name: str | None
A name to display for the given benchmark. If not given, will be constructed from the
function name and given parameters.
params: dict[str, Any]
A partial parametrization to apply to the benchmark function. Internal only,
you should not need to set this yourself.
setUp: Callable[..., None]
A setup hook run before the benchmark. Must take all members of `params` as inputs.
tearDown: Callable[..., None]
A teardown hook run after the benchmark. Must take all members of `params` as inputs.
tags: tuple[str, ...]
Additional tags to attach for bookkeeping and selective filtering during runs.
"""

fn: Callable[..., Any]
"""The function defining the benchmark."""
name: str = ""
"""A name to display for the given benchmark. If not given, a name will be constructed from the function name and given parameters."""
params: dict[str, Any] = field(default_factory=dict)
"""A partial parametrization to apply to the benchmark function. Internal only, you should not need to set this yourself."""
setUp: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
"""A setup hook run before the benchmark. Must take all members of ``params`` as inputs."""
tearDown: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
"""A teardown hook run after the benchmark. Must take all members of ``params`` as inputs."""
tags: tuple[str, ...] = field(repr=False, default=())
"""Additional tags to attach for bookkeeping and selective filtering during runs."""
interface: Interface = field(init=False, repr=False)
"""Benchmark interface, constructed from the given function. Implementation detail."""

def __post_init__(self):
if not self.name:
Expand All @@ -138,9 +144,11 @@ def __post_init__(self):
@dataclass(init=False, frozen=True)
class Parameters:
"""
A dataclass designed to hold benchmark parameters. This class is not functional
on its own, and needs to be subclassed according to your benchmarking workloads.
A dataclass designed to hold benchmark parameters.
This class is not functional on its own, and needs to be subclassed
according to your benchmarking workloads.
The main advantage over passing parameters as a dictionary is, of course,
static analysis and type safety for your benchmarking code.
The main advantages over passing parameters as a dictionary are static analysis
and type safety for your benchmarking code.
"""
29 changes: 14 additions & 15 deletions src/nnbench/types/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,33 @@
@dataclass(frozen=True)
class Interface:
"""
Data model representing a function's interface. An instance of this class
is created using the `from_callable` class method.
Parameters:
----------
names : tuple[str, ...]
Names of the function parameters.
types : tuple[type, ...]
Types of the function parameters.
defaults : tuple
A tuple of the function parameters' default values.
variables : tuple[Variable, ...]
A tuple of tuples, where each inner tuple contains the parameter name and type.
returntype: type
The function's return type annotation, or NoneType if left untyped.
Data model representing a function's interface.
An instance of this class is created using the ``Interface.from_callable()``
class method.
"""

funcname: str
"""Name of the function."""
names: tuple[str, ...]
"""Names of the function parameters."""
types: tuple[type, ...]
"""Type hints of the function parameters."""
defaults: tuple
"""The function parameters' default values, or inspect.Parameter.empty if a parameter has no default."""
variables: tuple[Variable, ...]
"""A tuple of tuples, where each inner tuple contains the parameter name, type, and default value."""
returntype: type
"""The function's return type annotation, or NoneType if left untyped."""

@classmethod
def from_callable(cls, fn: Callable, defaults: dict[str, Any]) -> Self:
"""
Creates an interface instance from the given callable.
Wraps the information given by ``inspect.signature()``, with the option to
supply a ``defaults`` map and overwrite any default set in the function's
signature.
"""
# Set `follow_wrapped=False` to get the partially filled interfaces.
# Otherwise we get missing value errors for parameters supplied in benchmark decorators.
Expand Down

0 comments on commit 95c5d35

Please sign in to comment.