Skip to content

Commit

Permalink
Merge pull request #646 from PlasmaControl/rc/cleanup
Browse files Browse the repository at this point in the history
Fix spelling errors, rename abstract base classes
  • Loading branch information
f0uriest authored Aug 31, 2023
2 parents 9680907 + edfe4c7 commit c6a2c34
Show file tree
Hide file tree
Showing 50 changed files with 348 additions and 370 deletions.
6 changes: 3 additions & 3 deletions desc/_version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# git-archive tarball (such as those provided by github's download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
# directories (produced by setup.py build) will contain a much shorter file
# that just contains the computed version number.
Expand Down Expand Up @@ -292,7 +292,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# TAG-NUM-gHEX
mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo:
# unparseable. Maybe git-describe is misbehaving?
# unparsable. Maybe git-describe is misbehaving?
pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
return pieces

Expand Down Expand Up @@ -409,7 +409,7 @@ def render_pep440_old(pieces):
The ".dev0" means dirty.
Eexceptions:
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
Expand Down
16 changes: 8 additions & 8 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
]


class Basis(IOAble, ABC):
class _Basis(IOAble, ABC):
"""Basis is an abstract base class for spectral basis sets."""

_io_attrs_ = [
Expand Down Expand Up @@ -223,7 +223,7 @@ def __repr__(self):
)


class PowerSeries(Basis):
class PowerSeries(_Basis):
"""1D basis set for flux surface quantities.
Power series in the radial coordinate.
Expand Down Expand Up @@ -333,7 +333,7 @@ def change_resolution(self, L):
self._set_up()


class FourierSeries(Basis):
class FourierSeries(_Basis):
"""1D basis set for use with the magnetic axis.
Fourier series in the toroidal coordinate.
Expand Down Expand Up @@ -453,7 +453,7 @@ def change_resolution(self, N, NFP=None, sym=None):
self._set_up()


class DoubleFourierSeries(Basis):
class DoubleFourierSeries(_Basis):
"""2D basis set for use on a single flux surface.
Fourier series in both the poloidal and toroidal coordinates.
Expand Down Expand Up @@ -600,7 +600,7 @@ def change_resolution(self, M, N, NFP=None, sym=None):
self._set_up()


class ZernikePolynomial(Basis):
class ZernikePolynomial(_Basis):
"""2D basis set for analytic functions in a unit disc.
Parameters
Expand Down Expand Up @@ -810,7 +810,7 @@ def change_resolution(self, L, M, sym=None):
self._set_up()


class ChebyshevDoubleFourierBasis(Basis):
class ChebyshevDoubleFourierBasis(_Basis):
"""3D basis: tensor product of Chebyshev polynomials and two Fourier series.
Fourier series in both the poloidal and toroidal coordinates.
Expand Down Expand Up @@ -943,7 +943,7 @@ def change_resolution(self, L, M, N, NFP=None, sym=None):
self._set_up()


class FourierZernikeBasis(Basis):
class FourierZernikeBasis(_Basis):
"""3D basis set for analytic functions in a toroidal volume.
Zernike polynomials in the radial & poloidal coordinates, and a Fourier
Expand Down Expand Up @@ -1318,7 +1318,7 @@ def zernike_radial_coeffs(l, m, exact=True):
Returns
-------
coeffs : ndarray
Polynomial coefficients for Zernike polynomials, in descending powers of r.
Notes
-----
Expand Down
41 changes: 20 additions & 21 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
FourierXYZCurve,
SplineXYZCurve,
)
from desc.grid import Grid
from desc.magnetic_fields import MagneticField, biot_savart
from desc.magnetic_fields import _MagneticField, biot_savart
from desc.utils import flatten_list


class Coil(MagneticField, ABC):
class _Coil(_MagneticField, ABC):
"""Base class representing a magnetic field coil.
Represents coils as a combination of a Curve and current
Expand All @@ -37,7 +36,7 @@ class Coil(MagneticField, ABC):
current passing through the coil, in Amperes
"""

_io_attrs_ = MagneticField._io_attrs_ + ["_current"]
_io_attrs_ = _MagneticField._io_attrs_ + ["_current"]

def __init__(self, current, *args, **kwargs):
self._current = current
Expand Down Expand Up @@ -77,7 +76,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None):
magnetic field at specified points, in either rpz or xyz coordinates
"""
assert basis.lower() in ["rpz", "xyz"]
if isinstance(coords, Grid):
if hasattr(coords, "nodes"):
coords = coords.nodes
coords = jnp.atleast_2d(coords)
if basis == "rpz":
Expand All @@ -102,7 +101,7 @@ def __repr__(self):
)


class FourierRZCoil(Coil, FourierRZCurve):
class FourierRZCoil(_Coil, FourierRZCurve):
"""Coil parameterized by fourier series for R,Z in terms of toroidal angle phi.
Parameters
Expand Down Expand Up @@ -154,7 +153,7 @@ class FourierRZCoil(Coil, FourierRZCurve):
"""

_io_attrs_ = Coil._io_attrs_ + FourierRZCurve._io_attrs_
_io_attrs_ = _Coil._io_attrs_ + FourierRZCurve._io_attrs_

def __init__(
self,
Expand All @@ -170,7 +169,7 @@ def __init__(
super().__init__(current, R_n, Z_n, modes_R, modes_Z, NFP, sym, name)


class FourierXYZCoil(Coil, FourierXYZCurve):
class FourierXYZCoil(_Coil, FourierXYZCurve):
"""Coil parameterized by fourier series for X,Y,Z in terms of arbitrary angle phi.
Parameters
Expand Down Expand Up @@ -219,7 +218,7 @@ class FourierXYZCoil(Coil, FourierXYZCurve):
"""

_io_attrs_ = Coil._io_attrs_ + FourierXYZCurve._io_attrs_
_io_attrs_ = _Coil._io_attrs_ + FourierXYZCurve._io_attrs_

def __init__(
self,
Expand All @@ -233,7 +232,7 @@ def __init__(
super().__init__(current, X_n, Y_n, Z_n, modes, name)


class FourierPlanarCoil(Coil, FourierPlanarCurve):
class FourierPlanarCoil(_Coil, FourierPlanarCurve):
"""Coil that lines in a plane.
Parameterized by a point (the center of the coil), a vector (normal to the plane),
Expand Down Expand Up @@ -290,7 +289,7 @@ class FourierPlanarCoil(Coil, FourierPlanarCurve):
"""

_io_attrs_ = Coil._io_attrs_ + FourierPlanarCurve._io_attrs_
_io_attrs_ = _Coil._io_attrs_ + FourierPlanarCurve._io_attrs_

def __init__(
self,
Expand All @@ -304,7 +303,7 @@ def __init__(
super().__init__(current, center, normal, r_n, modes, name)


class SplineXYZCoil(Coil, SplineXYZCurve):
class SplineXYZCoil(_Coil, SplineXYZCurve):
"""Coil parameterized by spline points in X,Y,Z.
Parameters
Expand Down Expand Up @@ -338,7 +337,7 @@ class SplineXYZCoil(Coil, SplineXYZCurve):
"""

_io_attrs_ = Coil._io_attrs_ + SplineXYZCurve._io_attrs_
_io_attrs_ = _Coil._io_attrs_ + SplineXYZCurve._io_attrs_

def __init__(
self,
Expand All @@ -353,7 +352,7 @@ def __init__(
super().__init__(current, X, Y, Z, knots, method, name)


class CoilSet(Coil, MutableSequence):
class CoilSet(_Coil, MutableSequence):
"""Set of coils of different geometry.
Parameters
Expand All @@ -367,11 +366,11 @@ class CoilSet(Coil, MutableSequence):
"""

_io_attrs_ = Coil._io_attrs_ + ["_coils"]
_io_attrs_ = _Coil._io_attrs_ + ["_coils"]

def __init__(self, *coils, name=""):
coils = flatten_list(coils, flatten_tuple=True)
assert all([isinstance(coil, (Coil)) for coil in coils])
assert all([isinstance(coil, (_Coil)) for coil in coils])
self._coils = list(coils)
self._name = str(name)

Expand Down Expand Up @@ -525,7 +524,7 @@ def linspaced_angular(
endpoint : bool
whether to include a coil at final angle
"""
assert isinstance(coil, Coil)
assert isinstance(coil, _Coil)
if current is None:
current = coil.current
currents = jnp.broadcast_to(current, (n,))
Expand Down Expand Up @@ -557,7 +556,7 @@ def linspaced_linear(
endpoint : bool
whether to include a coil at final point
"""
assert isinstance(coil, Coil)
assert isinstance(coil, _Coil)
if current is None:
current = coil.current
currents = jnp.broadcast_to(current, (n,))
Expand Down Expand Up @@ -762,7 +761,7 @@ def flatten_coils(coilset):
coil_end_inds = [] # indices where the coils end, need to track these
# to place the coilgroup number and name later, which MAKEGRID expects
# at the end of each individual coil
if isinstance(grid, Grid):
if hasattr(grid, "endpoint"):
endpoint = grid.endpoint
elif isinstance(grid, numbers.Integral):
endpoint = True # if int, will create a grid w/ endpoint=True in compute
Expand Down Expand Up @@ -830,7 +829,7 @@ def __getitem__(self, i):
return self.coils[i]

def __setitem__(self, i, new_item):
if not isinstance(new_item, Coil):
if not isinstance(new_item, _Coil):
raise TypeError("Members of CoilSet must be of type Coil.")
self._coils[i] = new_item

Expand All @@ -842,7 +841,7 @@ def __len__(self):

def insert(self, i, new_item):
"""Insert a new coil into the coilset at position i."""
if not isinstance(new_item, Coil):
if not isinstance(new_item, _Coil):
raise TypeError("Members of CoilSet must be of type Coil.")
self._coils.insert(i, new_item)

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _R_rrt(params, transforms, profiles, data, **kwargs):
label="\\partial_{\\rho \\rho \\theta \\theta} R",
units="m",
units_long="meters",
description="Major radius in lab frame, fouth derivative, wrt radius twice "
description="Major radius in lab frame, fourth derivative, wrt radius twice "
"and poloidal angle twice",
dim=1,
params=["R_lmn"],
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/data_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _decorator(func):
return _decorator


# This allows us to handle subclasses whos data_index stuff should inherit
# This allows us to handle subclasses whose data_index stuff should inherit
# from parent classes.
# This is the least bad solution I've found, since everything else requires
# crazy circular imports
Expand Down
60 changes: 31 additions & 29 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,6 @@

from .data_index import data_index

# defines the order in which objective arguments get concatenated into the state vector
arg_order = (
"R_lmn",
"Z_lmn",
"L_lmn",
"p_l",
"i_l",
"c_l",
"Psi",
"Te_l",
"ne_l",
"Ti_l",
"Zeff_l",
"Ra_n",
"Za_n",
"Rb_lmn",
"Zb_lmn",
)
# map from profile name to equilibrium parameter name
profile_names = {
"pressure": "p_l",
"iota": "i_l",
"current": "c_l",
"electron_temperature": "Te_l",
"electron_density": "ne_l",
"ion_temperature": "Ti_l",
"atomic_number": "Zeff_l",
}


def _parse_parameterization(p):
if isinstance(p, str):
Expand Down Expand Up @@ -1309,3 +1280,34 @@ def body(i, mins):
# The above implementation was benchmarked to be more efficient than
# alternatives without explicit loops in GitHub pull request #501.
return grid.expand(mins, surface_label)


# defines the order in which objective arguments get concatenated into the state vector
arg_order = (
"R_lmn",
"Z_lmn",
"L_lmn",
"p_l",
"i_l",
"c_l",
"Psi",
"Te_l",
"ne_l",
"Ti_l",
"Zeff_l",
"Ra_n",
"Za_n",
"Rb_lmn",
"Zb_lmn",
)

# map from profile name to equilibrium parameter name
profile_names = {
"pressure": "p_l",
"iota": "i_l",
"current": "c_l",
"electron_temperature": "Te_l",
"electron_density": "ne_l",
"ion_temperature": "Ti_l",
"atomic_number": "Zeff_l",
}
Loading

0 comments on commit c6a2c34

Please sign in to comment.