Skip to content

Commit

Permalink
Rename abstract base classes
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Aug 31, 2023
1 parent 27bd131 commit 4476cfe
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 182 deletions.
14 changes: 7 additions & 7 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
]


class Basis(IOAble, ABC):
class _Basis(IOAble, ABC):
"""Basis is an abstract base class for spectral basis sets."""

_io_attrs_ = [
Expand Down Expand Up @@ -223,7 +223,7 @@ def __repr__(self):
)


class PowerSeries(Basis):
class PowerSeries(_Basis):
"""1D basis set for flux surface quantities.
Power series in the radial coordinate.
Expand Down Expand Up @@ -333,7 +333,7 @@ def change_resolution(self, L):
self._set_up()


class FourierSeries(Basis):
class FourierSeries(_Basis):
"""1D basis set for use with the magnetic axis.
Fourier series in the toroidal coordinate.
Expand Down Expand Up @@ -453,7 +453,7 @@ def change_resolution(self, N, NFP=None, sym=None):
self._set_up()


class DoubleFourierSeries(Basis):
class DoubleFourierSeries(_Basis):
"""2D basis set for use on a single flux surface.
Fourier series in both the poloidal and toroidal coordinates.
Expand Down Expand Up @@ -600,7 +600,7 @@ def change_resolution(self, M, N, NFP=None, sym=None):
self._set_up()


class ZernikePolynomial(Basis):
class ZernikePolynomial(_Basis):
"""2D basis set for analytic functions in a unit disc.
Parameters
Expand Down Expand Up @@ -810,7 +810,7 @@ def change_resolution(self, L, M, sym=None):
self._set_up()


class ChebyshevDoubleFourierBasis(Basis):
class ChebyshevDoubleFourierBasis(_Basis):
"""3D basis: tensor product of Chebyshev polynomials and two Fourier series.
Fourier series in both the poloidal and toroidal coordinates.
Expand Down Expand Up @@ -943,7 +943,7 @@ def change_resolution(self, L, M, N, NFP=None, sym=None):
self._set_up()


class FourierZernikeBasis(Basis):
class FourierZernikeBasis(_Basis):
"""3D basis set for analytic functions in a toroidal volume.
Zernike polynomials in the radial & poloidal coordinates, and a Fourier
Expand Down
41 changes: 20 additions & 21 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
FourierXYZCurve,
SplineXYZCurve,
)
from desc.grid import Grid
from desc.magnetic_fields import MagneticField, biot_savart
from desc.magnetic_fields import _MagneticField, biot_savart
from desc.utils import flatten_list


class Coil(MagneticField, ABC):
class _Coil(_MagneticField, ABC):
"""Base class representing a magnetic field coil.
Represents coils as a combination of a Curve and current
Expand All @@ -37,7 +36,7 @@ class Coil(MagneticField, ABC):
current passing through the coil, in Amperes
"""

_io_attrs_ = MagneticField._io_attrs_ + ["_current"]
_io_attrs_ = _MagneticField._io_attrs_ + ["_current"]

def __init__(self, current, *args, **kwargs):
self._current = current
Expand Down Expand Up @@ -77,7 +76,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None):
magnetic field at specified points, in either rpz or xyz coordinates
"""
assert basis.lower() in ["rpz", "xyz"]
if isinstance(coords, Grid):
if hasattr(coords, "nodes"):
coords = coords.nodes
coords = jnp.atleast_2d(coords)
if basis == "rpz":
Expand All @@ -102,7 +101,7 @@ def __repr__(self):
)


class FourierRZCoil(Coil, FourierRZCurve):
class FourierRZCoil(_Coil, FourierRZCurve):
"""Coil parameterized by fourier series for R,Z in terms of toroidal angle phi.
Parameters
Expand Down Expand Up @@ -154,7 +153,7 @@ class FourierRZCoil(Coil, FourierRZCurve):
"""

_io_attrs_ = Coil._io_attrs_ + FourierRZCurve._io_attrs_
_io_attrs_ = _Coil._io_attrs_ + FourierRZCurve._io_attrs_

def __init__(
self,
Expand All @@ -170,7 +169,7 @@ def __init__(
super().__init__(current, R_n, Z_n, modes_R, modes_Z, NFP, sym, name)


class FourierXYZCoil(Coil, FourierXYZCurve):
class FourierXYZCoil(_Coil, FourierXYZCurve):
"""Coil parameterized by fourier series for X,Y,Z in terms of arbitrary angle phi.
Parameters
Expand Down Expand Up @@ -219,7 +218,7 @@ class FourierXYZCoil(Coil, FourierXYZCurve):
"""

_io_attrs_ = Coil._io_attrs_ + FourierXYZCurve._io_attrs_
_io_attrs_ = _Coil._io_attrs_ + FourierXYZCurve._io_attrs_

def __init__(
self,
Expand All @@ -233,7 +232,7 @@ def __init__(
super().__init__(current, X_n, Y_n, Z_n, modes, name)


class FourierPlanarCoil(Coil, FourierPlanarCurve):
class FourierPlanarCoil(_Coil, FourierPlanarCurve):
"""Coil that lines in a plane.
Parameterized by a point (the center of the coil), a vector (normal to the plane),
Expand Down Expand Up @@ -290,7 +289,7 @@ class FourierPlanarCoil(Coil, FourierPlanarCurve):
"""

_io_attrs_ = Coil._io_attrs_ + FourierPlanarCurve._io_attrs_
_io_attrs_ = _Coil._io_attrs_ + FourierPlanarCurve._io_attrs_

def __init__(
self,
Expand All @@ -304,7 +303,7 @@ def __init__(
super().__init__(current, center, normal, r_n, modes, name)


class SplineXYZCoil(Coil, SplineXYZCurve):
class SplineXYZCoil(_Coil, SplineXYZCurve):
"""Coil parameterized by spline points in X,Y,Z.
Parameters
Expand Down Expand Up @@ -338,7 +337,7 @@ class SplineXYZCoil(Coil, SplineXYZCurve):
"""

_io_attrs_ = Coil._io_attrs_ + SplineXYZCurve._io_attrs_
_io_attrs_ = _Coil._io_attrs_ + SplineXYZCurve._io_attrs_

def __init__(
self,
Expand All @@ -353,7 +352,7 @@ def __init__(
super().__init__(current, X, Y, Z, knots, method, name)


class CoilSet(Coil, MutableSequence):
class CoilSet(_Coil, MutableSequence):
"""Set of coils of different geometry.
Parameters
Expand All @@ -367,11 +366,11 @@ class CoilSet(Coil, MutableSequence):
"""

_io_attrs_ = Coil._io_attrs_ + ["_coils"]
_io_attrs_ = _Coil._io_attrs_ + ["_coils"]

def __init__(self, *coils, name=""):
coils = flatten_list(coils, flatten_tuple=True)
assert all([isinstance(coil, (Coil)) for coil in coils])
assert all([isinstance(coil, (_Coil)) for coil in coils])
self._coils = list(coils)
self._name = str(name)

Expand Down Expand Up @@ -525,7 +524,7 @@ def linspaced_angular(
endpoint : bool
whether to include a coil at final angle
"""
assert isinstance(coil, Coil)
assert isinstance(coil, _Coil)
if current is None:
current = coil.current
currents = jnp.broadcast_to(current, (n,))
Expand Down Expand Up @@ -557,7 +556,7 @@ def linspaced_linear(
endpoint : bool
whether to include a coil at final point
"""
assert isinstance(coil, Coil)
assert isinstance(coil, _Coil)
if current is None:
current = coil.current
currents = jnp.broadcast_to(current, (n,))
Expand Down Expand Up @@ -762,7 +761,7 @@ def flatten_coils(coilset):
coil_end_inds = [] # indices where the coils end, need to track these
# to place the coilgroup number and name later, which MAKEGRID expects
# at the end of each individual coil
if isinstance(grid, Grid):
if hasattr(grid, "endpoint"):
endpoint = grid.endpoint
elif isinstance(grid, numbers.Integral):
endpoint = True # if int, will create a grid w/ endpoint=True in compute
Expand Down Expand Up @@ -830,7 +829,7 @@ def __getitem__(self, i):
return self.coils[i]

def __setitem__(self, i, new_item):
if not isinstance(new_item, Coil):
if not isinstance(new_item, _Coil):
raise TypeError("Members of CoilSet must be of type Coil.")
self._coils[i] = new_item

Expand All @@ -842,7 +841,7 @@ def __len__(self):

def insert(self, i, new_item):
"""Insert a new coil into the coilset at position i."""
if not isinstance(new_item, Coil):
if not isinstance(new_item, _Coil):
raise TypeError("Members of CoilSet must be of type Coil.")
self._coils.insert(i, new_item)

Expand Down
4 changes: 2 additions & 2 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Surface,
ZernikeRZToroidalSection,
)
from desc.grid import Grid, LinearGrid, QuadratureGrid
from desc.grid import LinearGrid, QuadratureGrid, _Grid
from desc.io import IOAble
from desc.objectives import (
ForceBalance,
Expand Down Expand Up @@ -733,7 +733,7 @@ def compute(
names = [names]
if grid is None:
grid = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP)
elif not isinstance(grid, Grid):
elif not isinstance(grid, _Grid):
raise TypeError(
"must pass in a Grid object for argument grid!"
f" instead got type {type(grid)}"
Expand Down
4 changes: 2 additions & 2 deletions desc/equilibrium/initial_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from desc.backend import fori_loop, jit, jnp, put
from desc.basis import zernike_radial
from desc.geometry import FourierRZCurve, Surface
from desc.grid import Grid
from desc.grid import Grid, _Grid
from desc.io import load
from desc.transform import Transform
from desc.utils import copy_coeffs
Expand Down Expand Up @@ -324,7 +324,7 @@ def _initial_guess_points(nodes, x, x_basis):
Vector of flux surface coefficients associated with x_basis.
"""
if not isinstance(nodes, Grid):
if not isinstance(nodes, _Grid):
nodes = Grid(nodes, sort=False)
transform = Transform(nodes, x_basis, build=False, build_pinv=True)
x_lmn = transform.fit(x)
Expand Down
4 changes: 2 additions & 2 deletions desc/equilibrium/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Surface,
ZernikeRZToroidalSection,
)
from desc.profiles import PowerSeriesProfile, Profile
from desc.profiles import PowerSeriesProfile, _Profile
from desc.utils import isnonnegint


Expand Down Expand Up @@ -41,7 +41,7 @@ def parse_profile(prof, name="", **kwargs):
TypeError
If the object cannot be parsed as a Profile
"""
if isinstance(prof, Profile):
if isinstance(prof, _Profile):
return prof
if isinstance(prof, numbers.Number) or (
isinstance(prof, (np.ndarray, jnp.ndarray)) and prof.ndim == 1
Expand Down
6 changes: 3 additions & 3 deletions desc/geometry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
get_params,
get_transforms,
)
from desc.grid import Grid, LinearGrid, QuadratureGrid
from desc.grid import LinearGrid, QuadratureGrid, _Grid
from desc.io import IOAble


Expand Down Expand Up @@ -79,7 +79,7 @@ def compute(
elif isinstance(grid, numbers.Integral):
NFP = self.NFP if hasattr(self, "NFP") else 1
grid = LinearGrid(N=grid, NFP=NFP, endpoint=True)
elif isinstance(grid, Grid):
elif hasattr(grid, "NFP"):
NFP = grid.NFP
else:
raise TypeError(
Expand Down Expand Up @@ -343,7 +343,7 @@ def compute(
elif hasattr(self, "zeta"): # constant zeta surface
grid = QuadratureGrid(L=2 * self.L + 5, M=2 * self.M + 5, N=0, NFP=1)
grid._nodes[:, 2] = self.zeta
elif not isinstance(grid, Grid):
elif not isinstance(grid, _Grid):
raise TypeError(
"must pass in a Grid object or an integer for argument grid!"
f" instead got type {type(grid)}"
Expand Down
Loading

0 comments on commit 4476cfe

Please sign in to comment.