Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure correct data types in getter methods #1030

Merged
merged 22 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,27 +169,32 @@ def change_resolution(self):
@property
def L(self):
"""int: Maximum radial resolution."""
return self.__dict__.setdefault("_L", 0)
return int(self.__dict__.setdefault("_L", 0))

@property
def M(self):
"""int: Maximum poloidal resolution."""
return self.__dict__.setdefault("_M", 0)
return int(self.__dict__.setdefault("_M", 0))

@property
def N(self):
"""int: Maximum toroidal resolution."""
return self.__dict__.setdefault("_N", 0)
return int(self.__dict__.setdefault("_N", 0))

@property
def NFP(self):
"""int: Number of field periods."""
return self.__dict__.setdefault("_NFP", 1)
return int(self.__dict__.setdefault("_NFP", 1))

@property
def sym(self):
"""str: {``'cos'``, ``'sin'``, ``False``} Type of symmetry."""
return self.__dict__.setdefault("_sym", False)
"""str: Type of symmetry."""
# one of: {'even', 'sin', 'cos', 'cos(t)', False}
sym = self.__dict__.setdefault("_sym", False)
if not sym:
return bool(sym)
else:
return str(sym)

@property
def modes(self):
Expand All @@ -199,12 +204,12 @@ def modes(self):
@property
def num_modes(self):
"""int: Total number of modes in the spectral basis."""
return self.modes.shape[0]
return int(self.modes.shape[0])

@property
def spectral_indexing(self):
"""str: Type of indexing used for the spectral basis."""
return self.__dict__.setdefault("_spectral_indexing", "linear")
return str(self.__dict__.setdefault("_spectral_indexing", "linear"))

def __repr__(self):
"""Get the string form of the object."""
Expand Down Expand Up @@ -238,7 +243,7 @@ def __init__(self, L, sym="even"):
self._M = 0
self._N = 0
self._NFP = 1
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
kianorr marked this conversation as resolved.
Show resolved Hide resolved
self._spectral_indexing = "linear"

self._modes = self._get_modes(L=self.L)
Expand Down Expand Up @@ -351,7 +356,7 @@ def __init__(self, N, NFP=1, sym=False):
self._M = 0
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(N=self.N)
Expand Down Expand Up @@ -474,7 +479,7 @@ def __init__(self, M, N, NFP=1, sym=False):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(M=self.M, N=self.N)
Expand Down Expand Up @@ -635,8 +640,8 @@ def __init__(self, L, M, sym=False, spectral_indexing="ansi"):
self._M = check_nonnegint(M, "M", False)
self._N = 0
self._NFP = 1
self._sym = sym
self._spectral_indexing = spectral_indexing
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = str(spectral_indexing)

self._modes = self._get_modes(
L=self.L, M=self.M, spectral_indexing=self.spectral_indexing
Expand Down Expand Up @@ -831,7 +836,7 @@ def __init__(self, L, M, N, NFP=1, sym=False):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = "linear"

self._modes = self._get_modes(L=self.L, M=self.M, N=self.N)
Expand Down Expand Up @@ -983,8 +988,8 @@ def __init__(self, L, M, N, NFP=1, sym=False, spectral_indexing="ansi"):
self._M = check_nonnegint(M, "M", False)
self._N = check_nonnegint(N, "N", False)
self._NFP = check_posint(NFP, "NFP", False)
self._sym = sym
self._spectral_indexing = spectral_indexing
self._sym = bool(sym) if not sym else str(sym)
self._spectral_indexing = str(spectral_indexing)

self._modes = self._get_modes(
L=self.L, M=self.M, N=self.N, spectral_indexing=self.spectral_indexing
Expand Down
10 changes: 5 additions & 5 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,14 +748,14 @@ def __init__(self, *coils, NFP=1, sym=False, name=""):
assert all([isinstance(coil, (_Coil)) for coil in coils])
[_check_type(coil, coils[0]) for coil in coils]
self._coils = list(coils)
self._NFP = NFP
self._sym = sym
self._NFP = int(NFP)
self._sym = bool(sym)
self._name = str(name)

@property
def name(self):
"""str: Name of the curve."""
return self._name
return str(self.__dict__.setdefault("_name", ""))

@name.setter
def name(self, new):
Expand All @@ -769,12 +769,12 @@ def coils(self):
@property
def NFP(self):
"""int: Number of (toroidal) field periods."""
return self._NFP
return int(self._NFP)

@property
def sym(self):
"""bool: Whether this coil set is stellarator symmetric."""
return self._sym
return bool(self._sym)

@property
def current(self):
Expand Down
14 changes: 7 additions & 7 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(
ValueError,
f"sym should be one of True, False, None, got {sym}",
)
self._sym = setdefault(sym, getattr(surface, "sym", False))
self._sym = bool(setdefault(sym, getattr(surface, "sym", False)))
self._R_sym = "cos" if self.sym else False
self._Z_sym = "sin" if self.sym else False

Expand Down Expand Up @@ -565,7 +565,7 @@ def change_resolution(
self._M_grid = int(setdefault(M_grid, self.M_grid))
self._N_grid = int(setdefault(N_grid, self.N_grid))
self._NFP = int(setdefault(NFP, self.NFP))
self._sym = setdefault(sym, self.sym)
self._sym = bool(setdefault(sym, self.sym))

old_modes_R = self.R_basis.modes
old_modes_Z = self.Z_basis.modes
Expand Down Expand Up @@ -1220,7 +1220,7 @@ def spectral_indexing(self):
@property
def sym(self):
"""bool: Whether this equilibrium is stellarator symmetric."""
return self._sym
return bool(self._sym)

@property
def bdry_mode(self):
Expand All @@ -1240,22 +1240,22 @@ def Psi(self, Psi):
@property
def NFP(self):
"""int: Number of (toroidal) field periods."""
return self._NFP
return int(self._NFP)

@property
def L(self):
"""int: Maximum radial mode number."""
return self._L
return int(self._L)

@property
def M(self):
"""int: Maximum poloidal fourier mode number."""
return self._M
return int(self._M)

@property
def N(self):
"""int: Maximum toroidal fourier mode number."""
return self._N
return int(self._N)

@optimizable_parameter
@property
Expand Down
16 changes: 8 additions & 8 deletions desc/geometry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def rotmat(self, new):
@property
def name(self):
"""Name of the curve."""
return self._name
return str(self.__dict__.setdefault("_name", ""))

@name.setter
def name(self, new):
self._name = new
self._name = str(new)

def compute(
self,
Expand Down Expand Up @@ -323,31 +323,31 @@ def _set_up(self):
@property
def name(self):
"""str: Name of the surface."""
return self._name
return str(self.__dict__.setdefault("_name", ""))

@name.setter
def name(self, new):
self._name = new
self._name = str(new)

@property
def L(self):
"""int: Maximum radial mode number."""
return self._L
return int(self._L)

@property
def M(self):
"""int: Maximum poloidal mode number."""
return self._M
return int(self._M)

@property
def N(self):
"""int: Maximum toroidal mode number."""
return self._N
return int(self._N)

@property
def sym(self):
"""bool: Whether or not the surface is stellarator symmetric."""
return self._sym
return bool(self._sym)

def _compute_orientation(self):
"""Handedness of coordinate system.
Expand Down
16 changes: 8 additions & 8 deletions desc/geometry/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def __init__(

@property
def sym(self):
"""Whether this curve has stellarator symmetry."""
return self._sym
"""bool: Whether or not the curve is stellarator symmetric."""
return bool(self._sym)

@property
def R_basis(self):
Expand All @@ -110,7 +110,7 @@ def Z_basis(self):
@property
def NFP(self):
"""Number of field periods."""
return self._NFP
return int(self._NFP)

@property
def N(self):
Expand All @@ -128,7 +128,7 @@ def change_resolution(self, N=None, NFP=None, sym=None):
and (sym != self.sym)
):
self._NFP = int(NFP if NFP is not None else self.NFP)
self._sym = sym if sym is not None else self.sym
self._sym = bool(sym) if sym is not None else self.sym
N = int(N if N is not None else self.N)
R_modes_old = self.R_basis.modes
Z_modes_old = self.Z_basis.modes
Expand Down Expand Up @@ -371,7 +371,7 @@ def Z_basis(self):
@property
def N(self):
"""Maximum mode number."""
return max(self.X_basis.N, self.Y_basis.N, self.Z_basis.N)
return int(max(self.X_basis.N, self.Y_basis.N, self.Z_basis.N))

def change_resolution(self, N=None):
"""Change the maximum angular resolution."""
Expand Down Expand Up @@ -609,7 +609,7 @@ def r_basis(self):
@property
def N(self):
"""Maximum mode number."""
return self.r_basis.N
return int(self.r_basis.N)

def change_resolution(self, N=None):
"""Change the maximum angular resolution."""
Expand Down Expand Up @@ -894,12 +894,12 @@ def knots(self, new):
@property
def N(self):
"""Number of knots in the spline."""
return self.knots.size
return int(self.knots.size)

@property
def method(self):
"""Method of interpolation to use."""
return self._method
return str(self._method)

@method.setter
def method(self, new):
Expand Down
10 changes: 5 additions & 5 deletions desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(

self._R_lmn = copy_coeffs(R_lmn, modes_R, self.R_basis.modes[:, 1:])
self._Z_lmn = copy_coeffs(Z_lmn, modes_Z, self.Z_basis.modes[:, 1:])
self._sym = sym
self._sym = bool(sym)
self._rho = rho

if check_orientation and self._compute_orientation() == -1:
Expand All @@ -142,7 +142,7 @@ def __init__(
@property
def NFP(self):
"""int: Number of (toroidal) field periods."""
return self._NFP
return int(self._NFP)

@property
def R_basis(self):
Expand Down Expand Up @@ -870,8 +870,8 @@ def __init__(

self._R_lmn = copy_coeffs(R_lmn, modes_R, self.R_basis.modes[:, :2])
self._Z_lmn = copy_coeffs(Z_lmn, modes_Z, self.Z_basis.modes[:, :2])
self._sym = sym
self._spectral_indexing = spectral_indexing
self._sym = bool(sym)
self._spectral_indexing = str(spectral_indexing)

self._zeta = zeta

Expand All @@ -889,7 +889,7 @@ def __init__(
@property
def spectral_indexing(self):
"""str: Type of spectral indexing for Zernike basis."""
return self._spectral_indexing
return str(self._spectral_indexing)

@property
def R_basis(self):
Expand Down
2 changes: 1 addition & 1 deletion desc/input_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ def desc_output_to_input( # noqa: C901 - fxn too complex
Fourier coefficients below this value will be set to 0.
"""
from desc.grid import LinearGrid
from desc.io.equilibrium_io import load
from desc.io.optimizable_io import load
from desc.profiles import PowerSeriesProfile
from desc.utils import copy_coeffs

Expand Down
4 changes: 2 additions & 2 deletions desc/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Functions and classes for reading and writing DESC data."""

# InputReader lives outside this module for import ordering reasons, so we can
# import InputReader in __main__ without importing equilibrium_io which imports JAX
# import InputReader in __main__ without importing optimizable_io which imports JAX
# stuff potentially before we've set the GPU correctly.
# We include a link to it here for backwards compatibility
from desc.input_reader import InputReader

from .ascii_io import read_ascii, write_ascii
from .equilibrium_io import IOAble, load
from .hdf5_io import hdf5Reader, hdf5Writer
from .optimizable_io import IOAble, load
from .pickle_io import PickleReader, PickleWriter

__all__ = ["InputReader", "load"]
7 changes: 6 additions & 1 deletion desc/io/hdf5_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ def read_obj(self, obj, where=None):
continue
if isinstance(loc[attr], h5py.Dataset):
s = self._decode_attr(loc, attr)
# cast NumPy data types to native Python types
ddudt marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(s, np.bool_):
s = bool(s)
elif isinstance(s, np.int_):
s = int(s)
if not isinstance(s, str) or s != "__class__":
setattr(obj, attr, s)
elif isinstance(loc[attr], h5py.Group):
Expand Down Expand Up @@ -332,7 +337,7 @@ def isarray(x):
group = loc.create_group(attr)
self.write_list(data, where=group)
else:
from .equilibrium_io import IOAble
from .optimizable_io import IOAble

if isinstance(data, IOAble):
group = loc.create_group(attr)
Expand Down
Loading
Loading