Skip to content

Commit

Permalink
refactor: Remove context class, back to raw dictionaries
Browse files Browse the repository at this point in the history
With flattening and unflattening as functional methods moved into
`nnbench.util`.

The class was pretty much just a wrapper around a dict with nesting
support on a separator, and didn't provide a lot of value on its own.

Since we already typed the context provider as a callable returning a raw
dictionary, it doesn't create any churn to switch the context out for a
raw dict, either.
  • Loading branch information
nicholasjng committed Oct 29, 2024
1 parent 9173419 commit cf8973d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 241 deletions.
219 changes: 1 addition & 218 deletions src/nnbench/context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Utilities for collecting context key-value pairs as metadata in benchmark runs."""

import itertools
import platform
import sys
from collections.abc import Callable, Iterator
from collections.abc import Callable
from typing import Any, Literal

ContextProvider = Callable[[], dict[str, Any]]
Expand Down Expand Up @@ -194,219 +193,3 @@ def __call__(self) -> dict[str, Any]:
result["memory_unit"] = self.memunit
# TODO: Lacks CPU cache info, which requires a solution other than psutil.
return {self.key: result}


class Context:
def __init__(self, data: dict[str, Any] | None = None) -> None:
self._data: dict[str, Any] = data or {}

def __contains__(self, key: str) -> bool:
return key in self.keys()

def __eq__(self, other):
if not isinstance(other, Context):
raise NotImplementedError(
f"cannot compare {type(self)} for equality with type {type(other)}"
)
return self._data.__eq__(other._data)

@property
def data(self):
return self._data

@staticmethod
def _ctx_items(d: dict[str, Any], prefix: str, sep: str) -> Iterator[tuple[str, Any]]:
"""
Iterate over nested dictionary items. Keys are formatted to indicate their nested path.
Parameters
----------
d : dict[str, Any]
Dictionary to iterate over.
prefix : str
Current prefix to prepend to keys, used for recursion to build the full key path.
sep : str
The separator to use between levels of nesting in the key path.
Yields
------
tuple[str, Any]
Iterator over key-value tuples.
"""
for k, v in d.items():
new_key = prefix + sep + k if prefix else k
if isinstance(v, dict):
yield from Context._ctx_items(d=v, prefix=new_key, sep=sep)
else:
yield new_key, v

def keys(self, sep: str = ".") -> Iterator[str]:
"""
Keys of the context dictionary, with an optional separator for nested keys.
Parameters
----------
sep : str, optional
Separator to use for nested keys.
Yields
------
str
Iterator over the context dictionary keys.
"""
for k, _ in self._ctx_items(d=self._data, prefix="", sep=sep):
yield k

def values(self) -> Iterator[Any]:
"""
Values of the context dictionary, including values from nested dictionaries.
Yields
------
Any
Iterator over all values in the context dictionary.
"""
for _, v in self._ctx_items(d=self._data, prefix="", sep=""):
yield v

def items(self, sep: str = ".") -> Iterator[tuple[str, Any]]:
"""
Items (key-value pairs) of the context dictionary, with an separator for nested keys.
Parameters
----------
sep : str, optional
Separator to use for nested dictionary keys.
Yields
------
tuple[str, Any]
Iterator over the items of the context dictionary.
"""
yield from self._ctx_items(d=self._data, prefix="", sep=sep)

def add(self, provider: ContextProvider, replace: bool = False) -> None:
"""
Adds data from a provider to the context.
Parameters
----------
provider : ContextProvider
The provider to inject into this context.
replace : bool
Whether to replace existing context values upon key collision. Raises ValueError otherwise.
"""
self.update(Context.make(provider()), replace=replace)

def update(self, other: "Context", replace: bool = False) -> None:
"""
Updates the context.
Parameters
----------
other : Context
The other context to update this context with.
replace : bool
Whether to replace existing context values upon key collision. Raises ValueError otherwise.
Raises
------
ValueError
If ``other contains top-level keys already present in the context and ``replace=False``.
"""
duplicates = set(self.keys()) & set(other.keys())
if not replace and duplicates:
dupe, *_ = duplicates
raise ValueError(f"got multiple values for context key {dupe!r}")
self._data.update(other._data)

@staticmethod
def _flatten_dict(d: dict[str, Any], prefix: str = "", sep: str = ".") -> dict[str, Any]:
"""
Turn a nested dictionary into a flattened dictionary.
Parameters
----------
d : dict[str, Any]
(Possibly) nested dictionary to flatten.
prefix : str
Key prefix to apply at the top-level (nesting level 0).
sep : str
Separator on which to join keys, "." by default.
Returns
-------
dict[str, Any]
The flattened dictionary.
"""

items: list[tuple[str, Any]] = []
for key, value in d.items():
new_key = prefix + sep + key if prefix else key
if isinstance(value, dict):
items.extend(Context._flatten_dict(d=value, prefix=new_key, sep=sep).items())
else:
items.append((new_key, value))
return dict(items)

def flatten(self, sep: str = ".") -> dict[str, Any]:
"""
Flatten the context's dictionary, converting nested dictionaries into a single dictionary with keys separated by `sep`.
Parameters
----------
sep : str, optional
The separator used to join nested keys.
Returns
-------
dict[str, Any]
The flattened context values as a Python dictionary.
"""

return self._flatten_dict(self._data, prefix="", sep=sep)

@staticmethod
def unflatten(d: dict[str, Any], sep: str = ".") -> dict[str, Any]:
"""
Recursively unflatten a dictionary by expanding keys seperated by `sep` into nested dictionaries.
Parameters
----------
d : dict[str, Any]
The dictionary to unflatten.
sep : str, optional
The separator used in the flattened keys.
Returns
-------
dict[str, Any]
The unflattened dictionary.
"""
sorted_keys = sorted(d.keys())
unflattened = {}
for prefix, keys in itertools.groupby(sorted_keys, key=lambda key: key.split(sep, 1)[0]):
key_group = list(keys)
if len(key_group) == 1 and sep not in key_group[0]:
unflattened[prefix] = d[prefix]
else:
nested_dict = {key.split(sep, 1)[1]: d[key] for key in key_group}
unflattened[prefix] = Context.unflatten(d=nested_dict, sep=sep)
return unflattened

@classmethod
def make(cls, d: dict[str, Any]) -> "Context":
"""
Create a new Context instance from a given dictionary.
Parameters
----------
d : dict[str, Any]
The initialization dictionary.
Returns
-------
Context
The new Context instance.
"""
return cls(data=cls.unflatten(d))
20 changes: 11 additions & 9 deletions src/nnbench/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pathlib import Path
from typing import Any, get_origin

from nnbench.context import Context, ContextProvider
from nnbench.context import ContextProvider
from nnbench.types import Benchmark, BenchmarkRecord, Parameters, State
from nnbench.types.memo import is_memo, is_memo_type
from nnbench.util import import_file_as_module, ismodule
Expand Down Expand Up @@ -268,7 +268,7 @@ def run(
path_or_module: str | os.PathLike[str],
params: dict[str, Any] | Parameters | None = None,
tags: tuple[str, ...] = (),
context: Sequence[ContextProvider] | Context = (),
context: Sequence[ContextProvider] = (),
) -> BenchmarkRecord:
"""
Run a previously collected benchmark workload.
Expand All @@ -284,7 +284,7 @@ def run(
tags: tuple[str, ...]
Tags to filter for when collecting benchmarks. Only benchmarks containing either of
these tags are collected.
context: Sequence[ContextProvider] | Context
context: Sequence[ContextProvider]
Additional context to log with the benchmark in the output JSON record. Useful for
obtaining environment information and configuration, like CPU/GPU hardware info,
ML model metadata, and more.
Expand All @@ -302,12 +302,14 @@ def run(
family_sizes: dict[str, Any] = collections.defaultdict(int)
family_indices: dict[str, Any] = collections.defaultdict(int)

if isinstance(context, Context):
ctx = context
else:
ctx = Context()
for provider in context:
ctx.add(provider)
ctx: dict[str, Any] = {}
for provider in context:
val = provider()
duplicates = set(ctx.keys()) & set(val.keys())
if duplicates:
dupe, *_ = duplicates
raise ValueError(f"got multiple values for context key {dupe!r}")
ctx.update(val)

# if we didn't find any benchmarks, warn and return an empty record.
if not self.benchmarks:
Expand Down
20 changes: 6 additions & 14 deletions src/nnbench/types/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
else:
from typing_extensions import Self

from nnbench.context import Context
from nnbench.types.interface import Interface


Expand All @@ -30,7 +29,7 @@ def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None

@dataclass(frozen=True)
class BenchmarkRecord:
context: Context
context: dict[str, Any]
benchmarks: list[dict[str, Any]]

def compact(
Expand Down Expand Up @@ -63,12 +62,7 @@ def compact(

for b in self.benchmarks:
bc = copy.deepcopy(b)
if mode == "inline":
bc["context"] = self.context.data
elif mode == "flatten":
flat = self.context.flatten(sep=sep)
bc.update(flat)
bc["_contextkeys"] = list(self.context.keys())
bc["context"] = self.context
result.append(bc)
return result

Expand All @@ -90,16 +84,16 @@ def expand(cls, bms: list[dict[str, Any]]) -> Self:
The resulting record with the context extracted.
"""
dctx: dict[str, Any] = {}
ctx: dict[str, Any] = {}
for b in bms:
if "context" in b:
dctx = b.pop("context")
ctx = b.pop("context")
elif "_contextkeys" in b:
ctxkeys = b.pop("_contextkeys")
for k in ctxkeys:
# This should never throw, save for data corruption.
dctx[k] = b.pop(k)
return cls(context=Context.make(dctx), benchmarks=bms)
ctx[k] = b.pop(k)
return cls(context=ctx, benchmarks=bms)

# TODO: Add an expandmany() API for returning a sequence of records for heterogeneous
# context data.
Expand Down Expand Up @@ -151,5 +145,3 @@ class Parameters:
The main advantage over passing parameters as a dictionary is, of course,
static analysis and type safety for your benchmarking code.
"""

pass
25 changes: 25 additions & 0 deletions src/nnbench/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,38 @@

import importlib
import importlib.util
import itertools
import os
import sys
from importlib.machinery import ModuleSpec
from pathlib import Path
from types import ModuleType


def flatten(d: dict, sep: str = ".", prefix: str = "") -> dict:
d_flat = {}
for k, v in d.items():
new_key = prefix + sep + k if prefix else k
if isinstance(v, dict):
d_flat.update(flatten(v, sep=sep, prefix=new_key))
else:
d_flat[k] = v
return d_flat


def unflatten(d: dict, sep: str = ".") -> dict:
sorted_keys = sorted(d.keys())
unflattened = {}
for prefix, keys in itertools.groupby(sorted_keys, key=lambda key: key.split(sep, 1)[0]):
key_group = list(keys)
if len(key_group) == 1 and sep not in key_group[0]:
unflattened[prefix] = d[prefix]
else:
nested_dict = {key.split(sep, 1)[1]: d[key] for key in key_group}
unflattened[prefix] = unflatten(nested_dict, sep=sep)
return unflattened


def ismodule(name: str | os.PathLike[str]) -> bool:
"""Checks if the current interpreter has an available Python module named `name`."""
name = str(name)
Expand Down

0 comments on commit cf8973d

Please sign in to comment.