diff --git a/desc/compute/_omnigenity.py b/desc/compute/_omnigenity.py index 949ee9c21f..892800d48e 100644 --- a/desc/compute/_omnigenity.py +++ b/desc/compute/_omnigenity.py @@ -490,7 +490,7 @@ def _omni_map_theta_B(params, transforms, profiles, data, **kwargs): parameterization="desc.magnetic_fields._core.OmnigenousField", ) def _omni_map_zeta_B(params, transforms, profiles, data, **kwargs): - return data + return data # noqa: unused dependency @register_compute_fun( diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index cf0a32714e..de4ecf9314 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -176,10 +176,11 @@ def _decorator(func): if name in data_index[base_class]: if p == data_index[base_class][name]["parameterization"]: raise ValueError( - f"Already registered function with parameterization {p} and name {name}." + f"Already registered function with parameterization {p}" + f" and name {name}." ) - # if it was already registered from a parent class, we prefer - # the child class. + # if it was already registered from a parent class, we + # prefer the child class. inheritance_order = [base_class] + superclasses if inheritance_order.index(p) > inheritance_order.index( data_index[base_class][name]["parameterization"] diff --git a/tests/test_data_index.py b/tests/test_data_index.py index b582ff878c..ebb3df2479 100644 --- a/tests/test_data_index.py +++ b/tests/test_data_index.py @@ -11,121 +11,127 @@ from desc.utils import errorif -class TestDataIndex: - """Tests for things related to data_index.""" - - @staticmethod - def get_matches(fun, pattern): - """Return all matches of ``pattern`` in source code of function ``fun``.""" - src = inspect.getsource(fun) - # attempt to remove any decorator functions - # (currently works without this filter, but better to be defensive) +def _get_matches(fun, pattern, ignore_comments=True): + """Return all matches of ``pattern`` in source code of function ``fun``.""" + src = inspect.getsource(fun) + if ignore_comments: + # remove any decorator functions src = src.partition("def ")[2] - # attempt to remove comments + # remove comments src = "\n".join(line.partition("#")[0] for line in src.splitlines()) - matches = pattern.findall(src) - matches = {s.strip().strip('"') for s in matches} - return matches + matches = pattern.findall(src) + matches = {s.strip().strip('"') for s in matches} + return matches + + +def _get_parameterization(fun, default="desc.equilibrium.equilibrium.Equilibrium"): + """Get parameterization of thing computed by function ``fun``.""" + pattern = re.compile(r'parameterization=(?:\[([^]]+)]|"([^"]+)")') + decorator = inspect.getsource(fun).partition("def ")[0] + matches = pattern.findall(decorator) + # if list was found, split strings in list, else string was found so get that + matches = [match[0].split(",") if match[0] else [match[1]] for match in matches] + # flatten the list + matches = {s.strip().strip('"') for sublist in matches for s in sublist} + matches.discard("") + return matches if matches else {default} - @staticmethod - def get_parameterization(fun, default="desc.equilibrium.equilibrium.Equilibrium"): - """Get parameterization of thing computed by function ``fun``.""" - pattern = re.compile(r'parameterization=(?:\[([^]]+)]|"([^"]+)")') - decorator = inspect.getsource(fun).partition("def ")[0] - matches = pattern.findall(decorator) - # if list was found, split strings in list, else string was found so get that - matches = [match[0].split(",") if match[0] else [match[1]] for match in matches] - # flatten the list - matches = {s.strip().strip('"') for sublist in matches for s in sublist} - matches.discard("") - return matches if matches else {default} - @pytest.mark.unit - def test_data_index_deps(self): - """Ensure developers do not add extra (or forget needed) dependencies. +@pytest.mark.unit +def test_data_index_deps(): + """Ensure developers do not add extra (or forget needed) dependencies. - The regular expressions used in this test will fail to detect the data - dependencies in the source code of compute functions if the query to - the key in the data dictionary is split across multiple lines. - To avoid failing this test unnecessarily in this case, try to refactor - code by wrapping the query to the key in the data dictionary inside a - parenthesis. + The regular expressions used in this test will fail to detect the data + dependencies in the source code of compute functions if the query to + the key in the data dictionary is split across multiple lines. + To avoid failing this test unnecessarily in this case, try to refactor + code by wrapping the query to the key in the data dictionary inside a + parenthesis. - Examples - -------- - .. code-block:: python + Examples + -------- + .. code-block:: python - # Don't do this. - x_square = data[ - "x" - ] ** 2 - # Either do this - x_square = ( - data["x"] - ) ** 2 - # or do this - x_square = data["x"] ** 2 + # Don't do this. + x_square = data[ + "x" + ] ** 2 + # Either do this + x_square = ( + data["x"] + ) ** 2 + # or do this + x_square = data["x"] ** 2 - """ - queried_deps = {p: {} for p in _class_inheritance} + """ + queried_deps = {p: {} for p in _class_inheritance} - pattern_names = re.compile(r"(? inheritance_order.index( + data_index[base_class][register_name][ + "parameterization" + ] + ): + continue + queried_deps[base_class][register_name] = deps + aliases = data_index[base_class][register_name]["aliases"] + for alias in aliases: + queried_deps[base_class][alias] = deps - for p in data_index: - for name, val in data_index[p].items(): - err_msg = f"Parameterization: {p}. Name: {name}." - deps = val["dependencies"] - data = set(deps["data"]) - axis_limit_data = set(deps["axis_limit_data"]) - profiles = set(deps["profiles"]) - params = set(deps["params"]) - # assert no duplicate dependencies - assert len(data) == len(deps["data"]), err_msg - assert len(axis_limit_data) == len(deps["axis_limit_data"]), err_msg - assert data.isdisjoint(axis_limit_data), err_msg - assert len(profiles) == len(deps["profiles"]), err_msg - assert len(params) == len(deps["params"]), err_msg - # assert correct dependencies are queried - # TODO: conversion from rpz to xyz is taken out of actual function - # registration because of this data["phi"] is not queried in - # the source code but actually needed for the computation. This - # is a temporary fix until we have a better way to automatically - # handle this. - assert queried_deps[p][name]["data"].issubset( - data | axis_limit_data - ), err_msg - errorif( - name not in queried_deps[p], - AssertionError, - "Did you reuse the function name (i.e. def_...) for" - f" '{name}' for some other quantity?", - ) - assert queried_deps[p][name]["profiles"] == profiles, err_msg - assert queried_deps[p][name]["params"] == params, err_msg + for p in data_index: + for name, val in data_index[p].items(): + err_msg = f"Parameterization: {p}. Name: {name}." + deps = val["dependencies"] + data = set(deps["data"]) + axis_limit_data = set(deps["axis_limit_data"]) + profiles = set(deps["profiles"]) + params = set(deps["params"]) + # assert no duplicate dependencies + assert len(data) == len(deps["data"]), err_msg + assert len(axis_limit_data) == len(deps["axis_limit_data"]), err_msg + assert data.isdisjoint(axis_limit_data), err_msg + assert len(profiles) == len(deps["profiles"]), err_msg + assert len(params) == len(deps["params"]), err_msg + errorif( + name not in queried_deps[p], + AssertionError, + "Did you reuse the function name (i.e. def_...) for" + f" '{name}' for some other quantity?", + ) + # assert correct dependencies are queried + if not queried_deps[p][name]["ignore"]: + assert queried_deps[p][name]["data"] == data | axis_limit_data, err_msg + assert queried_deps[p][name]["profiles"] == profiles, err_msg + assert queried_deps[p][name]["params"] == params, err_msg