diff --git a/src/nnbench/context.py b/src/nnbench/context.py index 8964507..27a3c33 100644 --- a/src/nnbench/context.py +++ b/src/nnbench/context.py @@ -186,15 +186,15 @@ def __call__(self) -> dict[str, Any]: try: # The CPU frequency is not available on some ARM devices freq_struct = psutil.cpu_freq() - result["min_frequency"] = freq_struct.min - result["max_frequency"] = freq_struct.max + result["min_frequency"] = float(freq_struct.min) + result["max_frequency"] = float(freq_struct.max) freq_conversion = self.conversion_table[self.frequnit[0]] # result is in MHz, so we convert to Hz and apply the conversion factor. result["frequency"] = freq_struct.current * 1e6 / freq_conversion except RuntimeError: - result["frequency"] = 0 - result["min_frequency"] = 0 - result["max_frequency"] = 0 + result["frequency"] = 0.0 + result["min_frequency"] = 0.0 + result["max_frequency"] = 0.0 result["frequency_unit"] = self.frequnit result["num_cpus"] = psutil.cpu_count(logical=False) @@ -205,8 +205,11 @@ def __call__(self) -> dict[str, Any]: # result is in bytes, so no need for base conversion. result["total_memory"] = mem_struct.total / mem_conversion result["memory_unit"] = self.memunit - # TODO: Lacks CPU cache info, which requires a solution other than psutil. return {self.key: result} -builtin_providers: dict[str, ContextProvider] = {"cpu": CPUInfo(), "python": PythonInfo()} +builtin_providers: dict[str, ContextProvider] = { + "cpu": CPUInfo(), + "git": GitEnvironmentInfo(), + "python": PythonInfo(), +} diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..ad9a52a --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,56 @@ +import os +from pathlib import Path + +from nnbench.context import CPUInfo, GitEnvironmentInfo, PythonInfo + + +def test_cpu_info_provider() -> None: + """Tests CPU info integrity, along with some assumptions about metrics.""" + + c = CPUInfo() + ctx = c()["cpu"] + for k in [ + "architecture", + "system", + "frequency", + "min_frequency", + "max_frequency", + "frequency_unit", + "memory_unit", + ]: + assert k in ctx + + assert isinstance(ctx["frequency"], float) + assert isinstance(ctx["min_frequency"], float) + assert isinstance(ctx["max_frequency"], float) + assert isinstance(ctx["total_memory"], float) + + +def test_git_info_provider() -> None: + """Tests git provider value integrity, along with some data sanity checks.""" + g = GitEnvironmentInfo() + # git info needs to be collected inside the nnbench repo, otherwise we get no values. + os.chdir(Path(__file__).parent) + ctx = g()["git"] + + # tag is not checked, because that can be empty (e.g. in a shallow repo clone). + for k in ["provider", "repository", "commit"]: + assert k in ctx + assert ctx[k] != "", f"empty value for context {k!r}" + + assert ctx["repository"].split("/")[1] == "nnbench" + assert ctx["provider"] == "github.com" + + +def test_python_info_provider() -> None: + """Tests Python info, along with an example of Python package version scraping.""" + packages = ["rich", "pytest"] + p = PythonInfo(packages=packages) + ctx = p()["python"] + + for k in ["version", "implementation", "packages"]: + assert k in ctx + + assert list(ctx["packages"].keys()) == packages + for v in ctx["packages"].values(): + assert v != ""