Skip to content

Commit

Permalink
Add tests for nnbench.context
Browse files Browse the repository at this point in the history
Puts some unit tests for expected values in provided info, and some
data sanity checks where applicable.
  • Loading branch information
nicholasjng committed Dec 4, 2024
1 parent b210c3e commit d5c0a27
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/nnbench/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ def __call__(self) -> dict[str, Any]:
try:
# The CPU frequency is not available on some ARM devices
freq_struct = psutil.cpu_freq()
result["min_frequency"] = freq_struct.min
result["max_frequency"] = freq_struct.max
result["min_frequency"] = float(freq_struct.min)
result["max_frequency"] = float(freq_struct.max)
freq_conversion = self.conversion_table[self.frequnit[0]]
# result is in MHz, so we convert to Hz and apply the conversion factor.
result["frequency"] = freq_struct.current * 1e6 / freq_conversion
except RuntimeError:
result["frequency"] = 0
result["min_frequency"] = 0
result["max_frequency"] = 0
result["frequency"] = 0.0
result["min_frequency"] = 0.0
result["max_frequency"] = 0.0

result["frequency_unit"] = self.frequnit
result["num_cpus"] = psutil.cpu_count(logical=False)
Expand All @@ -205,8 +205,11 @@ def __call__(self) -> dict[str, Any]:
# result is in bytes, so no need for base conversion.
result["total_memory"] = mem_struct.total / mem_conversion
result["memory_unit"] = self.memunit
# TODO: Lacks CPU cache info, which requires a solution other than psutil.
return {self.key: result}


builtin_providers: dict[str, ContextProvider] = {"cpu": CPUInfo(), "python": PythonInfo()}
builtin_providers: dict[str, ContextProvider] = {
"cpu": CPUInfo(),
"git": GitEnvironmentInfo(),
"python": PythonInfo(),
}
56 changes: 56 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
from pathlib import Path

from nnbench.context import CPUInfo, GitEnvironmentInfo, PythonInfo


def test_cpu_info_provider() -> None:
"""Tests CPU info integrity, along with some assumptions about metrics."""

c = CPUInfo()
ctx = c()["cpu"]
for k in [
"architecture",
"system",
"frequency",
"min_frequency",
"max_frequency",
"frequency_unit",
"memory_unit",
]:
assert k in ctx

assert isinstance(ctx["frequency"], float)
assert isinstance(ctx["min_frequency"], float)
assert isinstance(ctx["max_frequency"], float)
assert isinstance(ctx["total_memory"], float)


def test_git_info_provider() -> None:
"""Tests git provider value integrity, along with some data sanity checks."""
g = GitEnvironmentInfo()
# git info needs to be collected inside the nnbench repo, otherwise we get no values.
os.chdir(Path(__file__).parent)
ctx = g()["git"]

# tag is not checked, because that can be empty (e.g. in a shallow repo clone).
for k in ["provider", "repository", "commit"]:
assert k in ctx
assert ctx[k] != "", f"empty value for context {k!r}"

assert ctx["repository"].split("/")[1] == "nnbench"
assert ctx["provider"] == "github.com"


def test_python_info_provider() -> None:
"""Tests Python info, along with an example of Python package version scraping."""
packages = ["rich", "pytest"]
p = PythonInfo(packages=packages)
ctx = p()["python"]

for k in ["version", "implementation", "packages"]:
assert k in ctx

assert list(ctx["packages"].keys()) == packages
for v in ctx["packages"].values():
assert v != ""

0 comments on commit d5c0a27

Please sign in to comment.