Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure correct data types in getter methods #1030

Merged
merged 22 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def NFP(self):

@property
def sym(self):
"""str: {``'cos'``, ``'sin'``, ``False``} Type of symmetry."""
"""str: Type of symmetry."""
# one of: {'even', 'sin', 'cos', 'cos(t)', False}
return self.__dict__.setdefault("_sym", False)

@property
Expand Down Expand Up @@ -238,7 +239,7 @@ def __init__(self, L, sym="even"):
self._M = 0
self._N = 0
self._NFP = 1
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
kianorr marked this conversation as resolved.
Show resolved Hide resolved
self._spectral_indexing = "linear"

self._modes = self._get_modes(L=self.L)
Expand Down Expand Up @@ -351,7 +352,7 @@ def __init__(self, N, NFP=1, sym=False):
self._M = 0
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(N=self.N)
Expand Down Expand Up @@ -474,7 +475,7 @@ def __init__(self, M, N, NFP=1, sym=False):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(M=self.M, N=self.N)
Expand Down Expand Up @@ -635,8 +636,8 @@ def __init__(self, L, M, sym=False, spectral_indexing="ansi"):
self._M = check_nonnegint(M, "M", False)
self._N = 0
self._NFP = 1
self._sym = sym
self._spectral_indexing = spectral_indexing
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = str(spectral_indexing)

self._modes = self._get_modes(
L=self.L, M=self.M, spectral_indexing=self.spectral_indexing
Expand Down Expand Up @@ -831,7 +832,7 @@ def __init__(self, L, M, N, NFP=1, sym=False):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(L=self.L, M=self.M, N=self.N)
Expand Down Expand Up @@ -983,8 +984,8 @@ def __init__(self, L, M, N, NFP=1, sym=False, spectral_indexing="ansi"):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._spectral_indexing = spectral_indexing
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = str(spectral_indexing)

self._modes = self._get_modes(
L=self.L, M=self.M, N=self.N, spectral_indexing=self.spectral_indexing
Expand Down
7 changes: 4 additions & 3 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,14 +748,14 @@ def __init__(self, *coils, NFP=1, sym=False, name=""):
assert all([isinstance(coil, (_Coil)) for coil in coils])
[_check_type(coil, coils[0]) for coil in coils]
self._coils = list(coils)
self._NFP = NFP
self._sym = sym
self._NFP = int(NFP)
self._sym = bool(sym)
self._name = str(name)

@property
def name(self):
"""str: Name of the curve."""
return self._name
return self.__dict__.setdefault("_name", "")

@name.setter
def name(self, new):
Expand Down Expand Up @@ -837,6 +837,7 @@ def compute(
params = [get_params(names, coil) for coil in self]
if data is None:
data = [{}] * len(self)

# if user supplied initial data for each coil we also need to vmap over that.
data = vmap(
lambda d, x: self[0].compute(
Expand Down
4 changes: 2 additions & 2 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(
ValueError,
f"sym should be one of True, False, None, got {sym}",
)
self._sym = setdefault(sym, getattr(surface, "sym", False))
self._sym = bool(setdefault(sym, getattr(surface, "sym", False)))
self._R_sym = "cos" if self.sym else False
self._Z_sym = "sin" if self.sym else False

Expand Down Expand Up @@ -564,7 +564,7 @@ def change_resolution(
self._M_grid = int(setdefault(M_grid, self.M_grid))
self._N_grid = int(setdefault(N_grid, self.N_grid))
self._NFP = int(setdefault(NFP, self.NFP))
self._sym = setdefault(sym, self.sym)
self._sym = bool(setdefault(sym, self.sym))

old_modes_R = self.R_basis.modes
old_modes_Z = self.Z_basis.modes
Expand Down
8 changes: 4 additions & 4 deletions desc/geometry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def rotmat(self, new):
@property
def name(self):
"""Name of the curve."""
return self._name
return self.__dict__.setdefault("_name", "")

@name.setter
def name(self, new):
self._name = new
self._name = str(new)

def compute(
self,
Expand Down Expand Up @@ -323,11 +323,11 @@ def _set_up(self):
@property
def name(self):
"""str: Name of the surface."""
return self._name
return self.__dict__.setdefault("_name", "")

@name.setter
def name(self, new):
self._name = new
self._name = str(new)

@property
def L(self):
Expand Down
4 changes: 2 additions & 2 deletions desc/geometry/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(

@property
def sym(self):
"""Whether this curve has stellarator symmetry."""
"""bool: Whether or not the curve is stellarator symmetric."""
return self._sym

@property
Expand Down Expand Up @@ -128,7 +128,7 @@ def change_resolution(self, N=None, NFP=None, sym=None):
and (sym != self.sym)
):
self._NFP = int(NFP if NFP is not None else self.NFP)
self._sym = sym if sym is not None else self.sym
self._sym = bool(sym) if sym is not None else self.sym
N = int(N if N is not None else self.N)
R_modes_old = self.R_basis.modes
Z_modes_old = self.Z_basis.modes
Expand Down
6 changes: 3 additions & 3 deletions desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(

self._R_lmn = copy_coeffs(R_lmn, modes_R, self.R_basis.modes[:, 1:])
self._Z_lmn = copy_coeffs(Z_lmn, modes_Z, self.Z_basis.modes[:, 1:])
self._sym = sym
self._sym = bool(sym)
self._rho = rho

if check_orientation and self._compute_orientation() == -1:
Expand Down Expand Up @@ -870,8 +870,8 @@ def __init__(

self._R_lmn = copy_coeffs(R_lmn, modes_R, self.R_basis.modes[:, :2])
self._Z_lmn = copy_coeffs(Z_lmn, modes_Z, self.Z_basis.modes[:, :2])
self._sym = sym
self._spectral_indexing = spectral_indexing
self._sym = bool(sym)
self._spectral_indexing = str(spectral_indexing)

self._zeta = zeta

Expand Down
2 changes: 1 addition & 1 deletion desc/input_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ def desc_output_to_input( # noqa: C901 - fxn too complex
Fourier coefficients below this value will be set to 0.
"""
from desc.grid import LinearGrid
from desc.io.equilibrium_io import load
from desc.io.optimizable_io import load
from desc.profiles import PowerSeriesProfile
from desc.utils import copy_coeffs

Expand Down
4 changes: 2 additions & 2 deletions desc/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Functions and classes for reading and writing DESC data."""

# InputReader lives outside this module for import ordering reasons, so we can
# import InputReader in __main__ without importing equilibrium_io which imports JAX
# import InputReader in __main__ without importing optimizable_io which imports JAX
# stuff potentially before we've set the GPU correctly.
# We include a link to it here for backwards compatibility
from desc.input_reader import InputReader

from .ascii_io import read_ascii, write_ascii
from .equilibrium_io import IOAble, load
from .hdf5_io import hdf5Reader, hdf5Writer
from .optimizable_io import IOAble, load
from .pickle_io import PickleReader, PickleWriter

__all__ = ["InputReader", "load"]
2 changes: 1 addition & 1 deletion desc/io/hdf5_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def isarray(x):
group = loc.create_group(attr)
self.write_list(data, where=group)
else:
from .equilibrium_io import IOAble
from .optimizable_io import IOAble

if isinstance(data, IOAble):
group = loc.create_group(attr)
Expand Down
3 changes: 3 additions & 0 deletions desc/io/equilibrium_io.py → desc/io/optimizable_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def load(load_from, file_format=None):
-------
obj :
The object saved in the file

"""
if file_format is None and isinstance(load_from, (str, os.PathLike)):
name = str(load_from)
Expand Down Expand Up @@ -83,6 +84,8 @@ def _unjittable(x):
return any([_unjittable(y) for y in x])
if isinstance(x, dict):
return any([_unjittable(y) for y in x.values()])
if hasattr(x, "dtype") and np.ndim(x) == 0:
return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_)
return isinstance(x, (str, types.FunctionType, bool, int, np.int_))


Expand Down
124 changes: 48 additions & 76 deletions desc/objectives/_coils.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While not directly related to the underlying issue for this PR, a lot of the data type errors were coming from the _CoilObjective.build method. The pytree stuff we had in here was really clunky, like creating a MixedCoilSet that contained _Grids instead of _Coils. I tried to simplify the logic.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've looked at it a little bit but haven't gotten the chance to understand the changes yet

Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
of the objective. Has no effect on self.grad or self.hess which always use
reverse mode and forward over reverse mode respectively.
grid : Grid, list, optional
Collocation grid containing the nodes to evaluate at. If list, has to adhere to
Objective.dim_f
Collocation grid containing the nodes to evaluate at.
If a list, must have the same structure as coil.
name : str, optional
Name of the objective function.

Expand Down Expand Up @@ -96,102 +96,74 @@

"""
# local import to avoid circular import
from desc.coils import CoilSet, MixedCoilSet
from desc.coils import CoilSet, MixedCoilSet, _Coil

self._dim_f = 0
self._quad_weights = jnp.array([])
def _is_single_coil(c):
return isinstance(c, _Coil) and not isinstance(c, CoilSet)

def to_list(coilset):
"""Turn a MixedCoilSet container into a list of what it's containing."""
if isinstance(coilset, list):
return [to_list(x) for x in coilset]
elif isinstance(coilset, MixedCoilSet):
return [to_list(x) for x in coilset]
def _prune_coilset_tree(coilset):
"""Remove extra members from CoilSets (but not MixedCoilSets)."""
if isinstance(coilset, list) or isinstance(coilset, MixedCoilSet):
return [_prune_coilset_tree(c) for c in coilset]
elif isinstance(coilset, CoilSet):
# use the same grid/transform for CoilSet
return to_list(coilset.coils[0])
# CoilSet only uses a single grid/transform for all coils
return _prune_coilset_tree(coilset.coils[0])
else:
return [coilset]
return coilset # single coil

coil = self.things[0]
grid = self._grid

# get individual coils from coilset
coils, structure = tree_flatten(coil, is_leaf=_is_single_coil)
self._num_coils = len(coils)

# map grid to list of length coils
if grid is None:
grid = [LinearGrid(N=2 * c.N + 5, endpoint=False) for c in coils]
if isinstance(grid, numbers.Integral):
grid = LinearGrid(N=self._grid, endpoint=False)

Check warning on line 125 in desc/objectives/_coils.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_coils.py#L125

Added line #L125 was not covered by tests
if isinstance(grid, _Grid):
grid = [grid] * self._num_coils
if isinstance(grid, list):
grid = tree_leaves(grid, is_leaf=lambda g: isinstance(g, _Grid))
assert len(grid) == len(coils)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might want to make this a ValueError with a more descriptive error message?

Or maybe put a separate check after if isinstance(grid, list): since that's the only case where it could cause problems

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also what if the user passes in different grids for each coil in a regular CoilSet? I think in that case the code as is now will only use the first grid?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re your second question: yes the code will only use the first grid. That is how the existing code also behaves, so this PR is not changing that. My changes earlier this week corrected this, but then I reverted them because my solution didn't work for nested coilsets. We could try to fix this in the future.


# gives structure of coils, e.g. MixedCoilSet(coils, coils) would give a
# a structure of [[*, *], [*, *]] if n = 2 coils
coil_leaves, coil_structure = tree_flatten(
self.things[0], is_leaf=lambda x: not hasattr(x, "__len__")
errorif(
np.any([g.num_rho > 1 or g.num_theta > 1 for g in grid]),
ValueError,
"Only use toroidal resolution for coil grids.",
)
self._num_coils = len(coil_leaves)

# check type
if isinstance(self._grid, numbers.Integral):
self._grid = LinearGrid(N=self._grid, endpoint=False)
# all of these cases return a container MixedCoilSet that contains
# LinearGrids. i.e. MixedCoilSet.coils = list of LinearGrid
if self._grid is None:
# map default grid to structure of inputted coils
self._grid = tree_map(
lambda x: LinearGrid(
N=2 * x.N + 5, NFP=getattr(x, "NFP", 1), endpoint=False
),
self.things[0],
is_leaf=lambda x: not hasattr(x, "__len__"),
)
elif isinstance(self._grid, _Grid):
# map inputted single LinearGrid to structure of inputted coils
self._grid = [self._grid] * self._num_coils
self._grid = tree_unflatten(coil_structure, self._grid)
else:
# this case covers an inputted list of grids that matches the size
# of the inputted coils. Can be a 1D list or nested list.
flattened_grid = tree_leaves(
self._grid, is_leaf=lambda x: isinstance(x, _Grid)
)
self._grid = tree_unflatten(coil_structure, flattened_grid)

self._dim_f = np.sum([g.num_nodes for g in grid])
quad_weights = np.concatenate([g.spacing[:, 2] for g in grid])

timer = Timer()
if verbose > 0:
print("Precomputing transforms")
timer.start("Precomputing transforms")

# map grid/transform to the same structure as coil
grid = tree_unflatten(structure, grid)
transforms = tree_map(
lambda x, y: get_transforms(self._data_keys, obj=x, grid=y),
self.things[0],
self._grid,
is_leaf=lambda x: not hasattr(x, "__len__"),
)

grids = tree_leaves(self._grid, is_leaf=lambda x: hasattr(x, "num_nodes"))
self._dim_f = np.sum([grid.num_nodes for grid in grids])
self._quad_weights = np.concatenate([grid.spacing[:, 2] for grid in grids])

# get only needed grids (1 per CoilSet) and flatten that list
self._grid = tree_leaves(
to_list(self._grid), is_leaf=lambda x: isinstance(x, _Grid)
)
transforms = tree_leaves(
to_list(transforms), is_leaf=lambda x: isinstance(x, dict)
)

errorif(
np.any([grid.num_rho > 1 or grid.num_theta > 1 for grid in self._grid]),
ValueError,
"Only use toroidal resolution for coil grids.",
lambda c, g: get_transforms(self._data_keys, obj=c, grid=g),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor point: This will get more transforms than needed, since the ones for CoilSet are later pruned to just the unique ones. Would be nice if we can avoid that redundant calculation but not sure how feasible it is.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I've resolved this in the latest commit

coil,
grid,
is_leaf=lambda x: _is_single_coil(x) or isinstance(x, _Grid),
)

# CoilSet and _Coil have one grid/transform
if not isinstance(self.things[0], MixedCoilSet):
self._grid = self._grid[0]
transforms = transforms[0]
# remove unnecessary members from coil tree structure
self._grid = _prune_coilset_tree(grid)
transforms = _prune_coilset_tree(transforms)

self._constants = {
"transforms": transforms,
"quad_weights": self._quad_weights,
}
self._constants = {"transforms": transforms, "quad_weights": quad_weights}

timer.stop("Precomputing transforms")
if verbose > 1:
timer.disp("Precomputing transforms")

if self._normalize:
self._scales = [compute_scaling_factors(coil) for coil in coil_leaves]
self._scales = [compute_scaling_factors(coil) for coil in coils]

super().build(use_jit=use_jit, verbose=verbose)

Expand Down
Loading
Loading