From a6b271bf0103a27d622b42c0f8b06c6d2e8e87be Mon Sep 17 00:00:00 2001 From: Nicholas Junge Date: Sat, 26 Oct 2024 11:16:30 +0200 Subject: [PATCH] refactor: Remove context class, back to raw dictionaries With flattening and unflattening as functional methods moved into `nnbench.util`. The class was pretty much just a wrapper around a dict with nesting support on a separator, and didn't provide a lot of value on its own. Since we already typed the context provider as a callable returning a raw dictionary, it doesn't create any churn to switch the context out for a raw dict, either. --- src/nnbench/context.py | 219 +-------------------------------- src/nnbench/runner.py | 20 +-- src/nnbench/types/benchmark.py | 20 +-- src/nnbench/util.py | 25 ++++ 4 files changed, 43 insertions(+), 241 deletions(-) diff --git a/src/nnbench/context.py b/src/nnbench/context.py index c02fa910..80ceedbc 100644 --- a/src/nnbench/context.py +++ b/src/nnbench/context.py @@ -1,9 +1,8 @@ """Utilities for collecting context key-value pairs as metadata in benchmark runs.""" -import itertools import platform import sys -from collections.abc import Callable, Iterator +from collections.abc import Callable from typing import Any, Literal ContextProvider = Callable[[], dict[str, Any]] @@ -194,219 +193,3 @@ def __call__(self) -> dict[str, Any]: result["memory_unit"] = self.memunit # TODO: Lacks CPU cache info, which requires a solution other than psutil. 
return {self.key: result} - - -class Context: - def __init__(self, data: dict[str, Any] | None = None) -> None: - self._data: dict[str, Any] = data or {} - - def __contains__(self, key: str) -> bool: - return key in self.keys() - - def __eq__(self, other): - if not isinstance(other, Context): - raise NotImplementedError( - f"cannot compare {type(self)} for equality with type {type(other)}" - ) - return self._data.__eq__(other._data) - - @property - def data(self): - return self._data - - @staticmethod - def _ctx_items(d: dict[str, Any], prefix: str, sep: str) -> Iterator[tuple[str, Any]]: - """ - Iterate over nested dictionary items. Keys are formatted to indicate their nested path. - - Parameters - ---------- - d : dict[str, Any] - Dictionary to iterate over. - prefix : str - Current prefix to prepend to keys, used for recursion to build the full key path. - sep : str - The separator to use between levels of nesting in the key path. - - Yields - ------ - tuple[str, Any] - Iterator over key-value tuples. - """ - for k, v in d.items(): - new_key = prefix + sep + k if prefix else k - if isinstance(v, dict): - yield from Context._ctx_items(d=v, prefix=new_key, sep=sep) - else: - yield new_key, v - - def keys(self, sep: str = ".") -> Iterator[str]: - """ - Keys of the context dictionary, with an optional separator for nested keys. - - Parameters - ---------- - sep : str, optional - Separator to use for nested keys. - - Yields - ------ - str - Iterator over the context dictionary keys. - """ - for k, _ in self._ctx_items(d=self._data, prefix="", sep=sep): - yield k - - def values(self) -> Iterator[Any]: - """ - Values of the context dictionary, including values from nested dictionaries. - - Yields - ------ - Any - Iterator over all values in the context dictionary. 
- """ - for _, v in self._ctx_items(d=self._data, prefix="", sep=""): - yield v - - def items(self, sep: str = ".") -> Iterator[tuple[str, Any]]: - """ - Items (key-value pairs) of the context dictionary, with an separator for nested keys. - - Parameters - ---------- - sep : str, optional - Separator to use for nested dictionary keys. - - Yields - ------ - tuple[str, Any] - Iterator over the items of the context dictionary. - """ - yield from self._ctx_items(d=self._data, prefix="", sep=sep) - - def add(self, provider: ContextProvider, replace: bool = False) -> None: - """ - Adds data from a provider to the context. - - Parameters - ---------- - provider : ContextProvider - The provider to inject into this context. - replace : bool - Whether to replace existing context values upon key collision. Raises ValueError otherwise. - """ - self.update(Context.make(provider()), replace=replace) - - def update(self, other: "Context", replace: bool = False) -> None: - """ - Updates the context. - - Parameters - ---------- - other : Context - The other context to update this context with. - replace : bool - Whether to replace existing context values upon key collision. Raises ValueError otherwise. - - Raises - ------ - ValueError - If ``other contains top-level keys already present in the context and ``replace=False``. - """ - duplicates = set(self.keys()) & set(other.keys()) - if not replace and duplicates: - dupe, *_ = duplicates - raise ValueError(f"got multiple values for context key {dupe!r}") - self._data.update(other._data) - - @staticmethod - def _flatten_dict(d: dict[str, Any], prefix: str = "", sep: str = ".") -> dict[str, Any]: - """ - Turn a nested dictionary into a flattened dictionary. - - Parameters - ---------- - d : dict[str, Any] - (Possibly) nested dictionary to flatten. - prefix : str - Key prefix to apply at the top-level (nesting level 0). - sep : str - Separator on which to join keys, "." by default. 
- - Returns - ------- - dict[str, Any] - The flattened dictionary. - """ - - items: list[tuple[str, Any]] = [] - for key, value in d.items(): - new_key = prefix + sep + key if prefix else key - if isinstance(value, dict): - items.extend(Context._flatten_dict(d=value, prefix=new_key, sep=sep).items()) - else: - items.append((new_key, value)) - return dict(items) - - def flatten(self, sep: str = ".") -> dict[str, Any]: - """ - Flatten the context's dictionary, converting nested dictionaries into a single dictionary with keys separated by `sep`. - - Parameters - ---------- - sep : str, optional - The separator used to join nested keys. - - Returns - ------- - dict[str, Any] - The flattened context values as a Python dictionary. - """ - - return self._flatten_dict(self._data, prefix="", sep=sep) - - @staticmethod - def unflatten(d: dict[str, Any], sep: str = ".") -> dict[str, Any]: - """ - Recursively unflatten a dictionary by expanding keys seperated by `sep` into nested dictionaries. - - Parameters - ---------- - d : dict[str, Any] - The dictionary to unflatten. - sep : str, optional - The separator used in the flattened keys. - - Returns - ------- - dict[str, Any] - The unflattened dictionary. - """ - sorted_keys = sorted(d.keys()) - unflattened = {} - for prefix, keys in itertools.groupby(sorted_keys, key=lambda key: key.split(sep, 1)[0]): - key_group = list(keys) - if len(key_group) == 1 and sep not in key_group[0]: - unflattened[prefix] = d[prefix] - else: - nested_dict = {key.split(sep, 1)[1]: d[key] for key in key_group} - unflattened[prefix] = Context.unflatten(d=nested_dict, sep=sep) - return unflattened - - @classmethod - def make(cls, d: dict[str, Any]) -> "Context": - """ - Create a new Context instance from a given dictionary. - - Parameters - ---------- - d : dict[str, Any] - The initialization dictionary. - - Returns - ------- - Context - The new Context instance. 
- """ - return cls(data=cls.unflatten(d)) diff --git a/src/nnbench/runner.py b/src/nnbench/runner.py index f23993e6..db05e6be 100644 --- a/src/nnbench/runner.py +++ b/src/nnbench/runner.py @@ -14,7 +14,7 @@ from pathlib import Path from typing import Any, get_origin -from nnbench.context import Context, ContextProvider +from nnbench.context import ContextProvider from nnbench.types import Benchmark, BenchmarkRecord, Parameters, State from nnbench.types.memo import is_memo, is_memo_type from nnbench.util import import_file_as_module, ismodule @@ -268,7 +268,7 @@ def run( path_or_module: str | os.PathLike[str], params: dict[str, Any] | Parameters | None = None, tags: tuple[str, ...] = (), - context: Sequence[ContextProvider] | Context = (), + context: Sequence[ContextProvider] = (), ) -> BenchmarkRecord: """ Run a previously collected benchmark workload. @@ -284,7 +284,7 @@ def run( tags: tuple[str, ...] Tags to filter for when collecting benchmarks. Only benchmarks containing either of these tags are collected. - context: Sequence[ContextProvider] | Context + context: Sequence[ContextProvider] Additional context to log with the benchmark in the output JSON record. Useful for obtaining environment information and configuration, like CPU/GPU hardware info, ML model metadata, and more. @@ -302,12 +302,14 @@ def run( family_sizes: dict[str, Any] = collections.defaultdict(int) family_indices: dict[str, Any] = collections.defaultdict(int) - if isinstance(context, Context): - ctx = context - else: - ctx = Context() - for provider in context: - ctx.add(provider) + ctx: dict[str, Any] = {} + for provider in context: + val = provider() + duplicates = set(ctx.keys()) & set(val.keys()) + if duplicates: + dupe, *_ = duplicates + raise ValueError(f"got multiple values for context key {dupe!r}") + ctx.update(val) # if we didn't find any benchmarks, warn and return an empty record. 
if not self.benchmarks: diff --git a/src/nnbench/types/benchmark.py b/src/nnbench/types/benchmark.py index 566b15f8..83c7216a 100644 --- a/src/nnbench/types/benchmark.py +++ b/src/nnbench/types/benchmark.py @@ -12,7 +12,6 @@ else: from typing_extensions import Self -from nnbench.context import Context from nnbench.types.interface import Interface @@ -30,7 +29,7 @@ def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None @dataclass(frozen=True) class BenchmarkRecord: - context: Context + context: dict[str, Any] benchmarks: list[dict[str, Any]] def compact( @@ -63,12 +62,7 @@ def compact( for b in self.benchmarks: bc = copy.deepcopy(b) - if mode == "inline": - bc["context"] = self.context.data - elif mode == "flatten": - flat = self.context.flatten(sep=sep) - bc.update(flat) - bc["_contextkeys"] = list(self.context.keys()) + bc["context"] = self.context result.append(bc) return result @@ -90,16 +84,16 @@ def expand(cls, bms: list[dict[str, Any]]) -> Self: The resulting record with the context extracted. """ - dctx: dict[str, Any] = {} + ctx: dict[str, Any] = {} for b in bms: if "context" in b: - dctx = b.pop("context") + ctx = b.pop("context") elif "_contextkeys" in b: ctxkeys = b.pop("_contextkeys") for k in ctxkeys: # This should never throw, save for data corruption. - dctx[k] = b.pop(k) - return cls(context=Context.make(dctx), benchmarks=bms) + ctx[k] = b.pop(k) + return cls(context=ctx, benchmarks=bms) # TODO: Add an expandmany() API for returning a sequence of records for heterogeneous # context data. @@ -151,5 +145,3 @@ class Parameters: The main advantage over passing parameters as a dictionary is, of course, static analysis and type safety for your benchmarking code. 
# Dictionary helpers added to src/nnbench/util.py by this patch; they replace the
# removed ``Context.flatten`` / ``Context.unflatten`` methods.
import itertools


def flatten(d: dict, sep: str = ".", prefix: str = "") -> dict:
    """
    Turn a (possibly) nested dictionary into a flat, single-level dictionary.

    Parameters
    ----------
    d : dict
        (Possibly) nested dictionary to flatten. Keys are joined as strings.
    sep : str
        Separator on which to join nested keys, "." by default.
    prefix : str
        Key prefix to apply at the top level (nesting level 0), empty by default.

    Returns
    -------
    dict
        The flattened dictionary. Nested keys are joined with ``sep``;
        empty nested dictionaries contribute no keys.
    """
    d_flat = {}
    for k, v in d.items():
        new_key = prefix + sep + k if prefix else k
        if isinstance(v, dict):
            d_flat.update(flatten(v, sep=sep, prefix=new_key))
        else:
            # Store the leaf under its full path. Using the bare key ``k`` here
            # would drop the prefix and let equally named leaves from different
            # branches overwrite each other.
            d_flat[new_key] = v
    return d_flat


def unflatten(d: dict, sep: str = ".") -> dict:
    """
    Recursively expand keys separated by ``sep`` into nested dictionaries.

    Parameters
    ----------
    d : dict
        The flat dictionary to unflatten. Keys must be strings.
    sep : str
        The separator used in the flattened keys, "." by default.

    Returns
    -------
    dict
        The unflattened (nested) dictionary.

    Raises
    ------
    ValueError
        If a key is used both as a plain value and as a prefix of nested keys
        (e.g. ``"a"`` and ``"a.b"``), since it cannot be a leaf and a dictionary
        at the same time.
    """

    def _head(key: str) -> str:
        # First path segment, i.e. the key at the current nesting level.
        return key.split(sep, 1)[0]

    unflattened = {}
    # groupby() only merges adjacent items, so sort by the same key we group on.
    for prefix, keys in itertools.groupby(sorted(d, key=_head), key=_head):
        key_group = list(keys)
        if key_group == [prefix]:
            # A single key without a separator is a leaf at this level.
            unflattened[prefix] = d[prefix]
            continue
        if prefix in key_group:
            raise ValueError(
                f"key {prefix!r} is used both as a value and as a prefix of nested keys"
            )
        # Strip the first segment and recurse on the remainder of each key.
        nested_dict = {key.split(sep, 1)[1]: d[key] for key in key_group}
        unflattened[prefix] = unflatten(nested_dict, sep=sep)
    return unflattened