From 4476cfe208dd21bd321eae7bb098c14dcf61856e Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Thu, 31 Aug 2023 13:38:38 -0400 Subject: [PATCH] Rename abstract base classes --- desc/basis.py | 14 +-- desc/coils.py | 41 ++++--- desc/equilibrium/equilibrium.py | 4 +- desc/equilibrium/initial_guess.py | 4 +- desc/equilibrium/utils.py | 4 +- desc/geometry/core.py | 6 +- desc/grid.py | 188 ++++++++++++++++-------------- desc/magnetic_fields.py | 44 +++---- desc/objectives/_generic.py | 5 +- desc/profiles.py | 62 +++++----- 10 files changed, 190 insertions(+), 182 deletions(-) diff --git a/desc/basis.py b/desc/basis.py index 1c7edeea2e..1e9920d8da 100644 --- a/desc/basis.py +++ b/desc/basis.py @@ -21,7 +21,7 @@ ] -class Basis(IOAble, ABC): +class _Basis(IOAble, ABC): """Basis is an abstract base class for spectral basis sets.""" _io_attrs_ = [ @@ -223,7 +223,7 @@ def __repr__(self): ) -class PowerSeries(Basis): +class PowerSeries(_Basis): """1D basis set for flux surface quantities. Power series in the radial coordinate. @@ -333,7 +333,7 @@ def change_resolution(self, L): self._set_up() -class FourierSeries(Basis): +class FourierSeries(_Basis): """1D basis set for use with the magnetic axis. Fourier series in the toroidal coordinate. @@ -453,7 +453,7 @@ def change_resolution(self, N, NFP=None, sym=None): self._set_up() -class DoubleFourierSeries(Basis): +class DoubleFourierSeries(_Basis): """2D basis set for use on a single flux surface. Fourier series in both the poloidal and toroidal coordinates. @@ -600,7 +600,7 @@ def change_resolution(self, M, N, NFP=None, sym=None): self._set_up() -class ZernikePolynomial(Basis): +class ZernikePolynomial(_Basis): """2D basis set for analytic functions in a unit disc. Parameters @@ -810,7 +810,7 @@ def change_resolution(self, L, M, sym=None): self._set_up() -class ChebyshevDoubleFourierBasis(Basis): +class ChebyshevDoubleFourierBasis(_Basis): """3D basis: tensor product of Chebyshev polynomials and two Fourier series. Fourier series in both the poloidal and toroidal coordinates. @@ -943,7 +943,7 @@ def change_resolution(self, L, M, N, NFP=None, sym=None): self._set_up() -class FourierZernikeBasis(Basis): +class FourierZernikeBasis(_Basis): """3D basis set for analytic functions in a toroidal volume. Zernike polynomials in the radial & poloidal coordinates, and a Fourier diff --git a/desc/coils.py b/desc/coils.py index a87e36a1ff..23869183f6 100644 --- a/desc/coils.py +++ b/desc/coils.py @@ -14,12 +14,11 @@ FourierXYZCurve, SplineXYZCurve, ) -from desc.grid import Grid -from desc.magnetic_fields import MagneticField, biot_savart +from desc.magnetic_fields import _MagneticField, biot_savart from desc.utils import flatten_list -class Coil(MagneticField, ABC): +class _Coil(_MagneticField, ABC): """Base class representing a magnetic field coil. Represents coils as a combination of a Curve and current @@ -37,7 +36,7 @@ class Coil(MagneticField, ABC): current passing through the coil, in Amperes """ - _io_attrs_ = MagneticField._io_attrs_ + ["_current"] + _io_attrs_ = _MagneticField._io_attrs_ + ["_current"] def __init__(self, current, *args, **kwargs): self._current = current @@ -77,7 +76,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): magnetic field at specified points, in either rpz or xyz coordinates """ assert basis.lower() in ["rpz", "xyz"] - if isinstance(coords, Grid): + if hasattr(coords, "nodes"): coords = coords.nodes coords = jnp.atleast_2d(coords) if basis == "rpz": @@ -102,7 +101,7 @@ def __repr__(self): ) -class FourierRZCoil(Coil, FourierRZCurve): +class FourierRZCoil(_Coil, FourierRZCurve): """Coil parameterized by fourier series for R,Z in terms of toroidal angle phi. Parameters @@ -154,7 +153,7 @@ class FourierRZCoil(Coil, FourierRZCurve): """ - _io_attrs_ = Coil._io_attrs_ + FourierRZCurve._io_attrs_ + _io_attrs_ = _Coil._io_attrs_ + FourierRZCurve._io_attrs_ def __init__( self, @@ -170,7 +169,7 @@ def __init__( super().__init__(current, R_n, Z_n, modes_R, modes_Z, NFP, sym, name) -class FourierXYZCoil(Coil, FourierXYZCurve): +class FourierXYZCoil(_Coil, FourierXYZCurve): """Coil parameterized by fourier series for X,Y,Z in terms of arbitrary angle phi. Parameters @@ -219,7 +218,7 @@ class FourierXYZCoil(Coil, FourierXYZCurve): """ - _io_attrs_ = Coil._io_attrs_ + FourierXYZCurve._io_attrs_ + _io_attrs_ = _Coil._io_attrs_ + FourierXYZCurve._io_attrs_ def __init__( self, @@ -233,7 +232,7 @@ def __init__( super().__init__(current, X_n, Y_n, Z_n, modes, name) -class FourierPlanarCoil(Coil, FourierPlanarCurve): +class FourierPlanarCoil(_Coil, FourierPlanarCurve): """Coil that lines in a plane. Parameterized by a point (the center of the coil), a vector (normal to the plane), @@ -290,7 +289,7 @@ class FourierPlanarCoil(Coil, FourierPlanarCurve): """ - _io_attrs_ = Coil._io_attrs_ + FourierPlanarCurve._io_attrs_ + _io_attrs_ = _Coil._io_attrs_ + FourierPlanarCurve._io_attrs_ def __init__( self, @@ -304,7 +303,7 @@ def __init__( super().__init__(current, center, normal, r_n, modes, name) -class SplineXYZCoil(Coil, SplineXYZCurve): +class SplineXYZCoil(_Coil, SplineXYZCurve): """Coil parameterized by spline points in X,Y,Z. Parameters @@ -338,7 +337,7 @@ class SplineXYZCoil(Coil, SplineXYZCurve): """ - _io_attrs_ = Coil._io_attrs_ + SplineXYZCurve._io_attrs_ + _io_attrs_ = _Coil._io_attrs_ + SplineXYZCurve._io_attrs_ def __init__( self, @@ -353,7 +352,7 @@ def __init__( super().__init__(current, X, Y, Z, knots, method, name) -class CoilSet(Coil, MutableSequence): +class CoilSet(_Coil, MutableSequence): """Set of coils of different geometry. Parameters @@ -367,11 +366,11 @@ class CoilSet(Coil, MutableSequence): """ - _io_attrs_ = Coil._io_attrs_ + ["_coils"] + _io_attrs_ = _Coil._io_attrs_ + ["_coils"] def __init__(self, *coils, name=""): coils = flatten_list(coils, flatten_tuple=True) - assert all([isinstance(coil, (Coil)) for coil in coils]) + assert all([isinstance(coil, (_Coil)) for coil in coils]) self._coils = list(coils) self._name = str(name) @@ -525,7 +524,7 @@ def linspaced_angular( endpoint : bool whether to include a coil at final angle """ - assert isinstance(coil, Coil) + assert isinstance(coil, _Coil) if current is None: current = coil.current currents = jnp.broadcast_to(current, (n,)) @@ -557,7 +556,7 @@ def linspaced_linear( endpoint : bool whether to include a coil at final point """ - assert isinstance(coil, Coil) + assert isinstance(coil, _Coil) if current is None: current = coil.current currents = jnp.broadcast_to(current, (n,)) @@ -762,7 +761,7 @@ def flatten_coils(coilset): coil_end_inds = [] # indices where the coils end, need to track these # to place the coilgroup number and name later, which MAKEGRID expects # at the end of each individual coil - if isinstance(grid, Grid): + if hasattr(grid, "endpoint"): endpoint = grid.endpoint elif isinstance(grid, numbers.Integral): endpoint = True # if int, will create a grid w/ endpoint=True in compute @@ -830,7 +829,7 @@ def __getitem__(self, i): return self.coils[i] def __setitem__(self, i, new_item): - if not isinstance(new_item, Coil): + if not isinstance(new_item, _Coil): raise TypeError("Members of CoilSet must be of type Coil.") self._coils[i] = new_item @@ -842,7 +841,7 @@ def __len__(self): def insert(self, i, new_item): """Insert a new coil into the coilset at position i.""" - if not isinstance(new_item, Coil): + if not isinstance(new_item, _Coil): raise TypeError("Members of CoilSet must be of type Coil.") self._coils.insert(i, new_item) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index d646ac0094..0fd24448d1 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -21,7 +21,7 @@ Surface, ZernikeRZToroidalSection, ) -from desc.grid import Grid, LinearGrid, QuadratureGrid +from desc.grid import LinearGrid, QuadratureGrid, _Grid from desc.io import IOAble from desc.objectives import ( ForceBalance, @@ -733,7 +733,7 @@ def compute( names = [names] if grid is None: grid = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP) - elif not isinstance(grid, Grid): + elif not isinstance(grid, _Grid): raise TypeError( "must pass in a Grid object for argument grid!" f" instead got type {type(grid)}" diff --git a/desc/equilibrium/initial_guess.py b/desc/equilibrium/initial_guess.py index 2d6e94065b..29e782f144 100644 --- a/desc/equilibrium/initial_guess.py +++ b/desc/equilibrium/initial_guess.py @@ -7,7 +7,7 @@ from desc.backend import fori_loop, jit, jnp, put from desc.basis import zernike_radial from desc.geometry import FourierRZCurve, Surface -from desc.grid import Grid +from desc.grid import Grid, _Grid from desc.io import load from desc.transform import Transform from desc.utils import copy_coeffs @@ -324,7 +324,7 @@ def _initial_guess_points(nodes, x, x_basis): Vector of flux surface coefficients associated with x_basis. """ - if not isinstance(nodes, Grid): + if not isinstance(nodes, _Grid): nodes = Grid(nodes, sort=False) transform = Transform(nodes, x_basis, build=False, build_pinv=True) x_lmn = transform.fit(x) diff --git a/desc/equilibrium/utils.py b/desc/equilibrium/utils.py index 3b061d7202..8a9024917d 100644 --- a/desc/equilibrium/utils.py +++ b/desc/equilibrium/utils.py @@ -11,7 +11,7 @@ Surface, ZernikeRZToroidalSection, ) -from desc.profiles import PowerSeriesProfile, Profile +from desc.profiles import PowerSeriesProfile, _Profile from desc.utils import isnonnegint @@ -41,7 +41,7 @@ def parse_profile(prof, name="", **kwargs): TypeError If the object cannot be parsed as a Profile """ - if isinstance(prof, Profile): + if isinstance(prof, _Profile): return prof if isinstance(prof, numbers.Number) or ( isinstance(prof, (np.ndarray, jnp.ndarray)) and prof.ndim == 1 diff --git a/desc/geometry/core.py b/desc/geometry/core.py index 1262356e08..dd5c4cfc17 100644 --- a/desc/geometry/core.py +++ b/desc/geometry/core.py @@ -15,7 +15,7 @@ get_params, get_transforms, ) -from desc.grid import Grid, LinearGrid, QuadratureGrid +from desc.grid import LinearGrid, QuadratureGrid, _Grid from desc.io import IOAble @@ -79,7 +79,7 @@ def compute( elif isinstance(grid, numbers.Integral): NFP = self.NFP if hasattr(self, "NFP") else 1 grid = LinearGrid(N=grid, NFP=NFP, endpoint=True) - elif isinstance(grid, Grid): + elif hasattr(grid, "NFP"): NFP = grid.NFP else: raise TypeError( @@ -343,7 +343,7 @@ def compute( elif hasattr(self, "zeta"): # constant zeta surface grid = QuadratureGrid(L=2 * self.L + 5, M=2 * self.M + 5, N=0, NFP=1) grid._nodes[:, 2] = self.zeta - elif not isinstance(grid, Grid): + elif not isinstance(grid, _Grid): raise TypeError( "must pass in a Grid object or an integer for argument grid!" f" instead got type {type(grid)}" diff --git a/desc/grid.py b/desc/grid.py index 81150b44dc..d3cb75e68e 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -1,5 +1,7 @@ """Classes for representing flux coordinates.""" +from abc import ABC, abstractmethod + import numpy as np from scipy import optimize, special @@ -17,24 +19,8 @@ ] -class Grid(IOAble): - """Base class for collocation grids. - - Unlike subclasses LinearGrid and ConcentricGrid, the base Grid allows the user - to pass in a custom set of collocation nodes. - - Parameters - ---------- - nodes : ndarray of float, size(num_nodes,3) - node coordinates, in (rho,theta,zeta) - sort : bool - whether to sort the nodes for use with FFT method. - jitable : bool - Whether to skip certain checks and conditionals that don't work under jit. - Allows grid to be created on the fly with custom nodes, but weights, symmetry - etc may be wrong if grid contains duplicate nodes. - - """ +class _Grid(IOAble, ABC): + """Base class for collocation grids.""" # TODO: calculate weights automatically using voronoi / delaunay triangulation _io_attrs_ = [ @@ -56,74 +42,10 @@ class Grid(IOAble): "_inverse_zeta_idx", ] - def __init__(self, nodes, sort=True, jitable=False): - # Python 3.3 (PEP 412) introduced key-sharing dictionaries. - # This change measurably reduces memory usage of objects that - # define all attributes in their __init__ method. - self._NFP = 1 - self._sym = False - self._node_pattern = "custom" - self._nodes, self._spacing = self._create_nodes(nodes) - if sort: - self._sort_nodes() - if jitable: - # dont do anything with symmetry since that changes # of nodes - # avoid point at the axis, for now. FIXME: make axis boolean mask? - r, t, z = self._nodes.T - r = jnp.where(r == 0, 1e-12, r) - self._nodes = jnp.array([r, t, z]).T - self._axis = np.array([], dtype=int) - self._unique_rho_idx = np.arange(self._nodes.shape[0]) - self._unique_theta_idx = np.arange(self._nodes.shape[0]) - self._unique_zeta_idx = np.arange(self._nodes.shape[0]) - self._inverse_rho_idx = np.arange(self._nodes.shape[0]) - self._inverse_theta_idx = np.arange(self._nodes.shape[0]) - self._inverse_zeta_idx = np.arange(self._nodes.shape[0]) - # don't do anything fancy with weights - self._weights = self._spacing.prod(axis=1) - else: - self._enforce_symmetry() - self._axis = self._find_axis() - ( - self._unique_rho_idx, - self._inverse_rho_idx, - self._unique_theta_idx, - self._inverse_theta_idx, - self._unique_zeta_idx, - self._inverse_zeta_idx, - ) = self._find_unique_inverse_nodes() - self._weights = self._scale_weights() - - self._L = self.num_rho - self._M = self.num_theta - self._N = self.num_zeta - - def _create_nodes(self, nodes): - """Allow for custom node creation. - - Parameters - ---------- - nodes : ndarray of float, size(num_nodes,3) - Node coordinates, in (rho,theta,zeta). - - Returns - ------- - nodes : ndarray of float, size(num_nodes,3) - Node coordinates, in (rho,theta,zeta). - spacing : ndarray of float, size(num_nodes,3) - Node spacing, in (rho,theta,zeta). - - """ - nodes = jnp.atleast_2d(nodes).reshape((-1, 3)).astype(float) - # Do not alter nodes given by the user for custom grids. - # In particular, do not modulo nodes by 2pi or 2pi/NFP. - # This may cause the surface_integrals() function to fail recognizing - # surfaces outside the interval [0, 2pi] as duplicates. However, most - # surface integral computations are done with LinearGrid anyway. - spacing = ( # make weights sum to 4pi^2 - jnp.ones_like(nodes) * jnp.array([1, 2 * np.pi, 2 * np.pi]) / nodes.shape[0] - ) - return nodes, spacing + @abstractmethod + def _create_nodes(self, *args, **kwargs): + """Allow for custom node creation.""" + pass def _enforce_symmetry(self): """Enforce stellarator symmetry. @@ -465,7 +387,95 @@ def replace_at_axis(self, x, y, copy=False, **kwargs): return x -class LinearGrid(Grid): +class Grid(_Grid): + """Collocation grid with custom node placement. + + Unlike subclasses LinearGrid and ConcentricGrid, the base Grid allows the user + to pass in a custom set of collocation nodes. + + Parameters + ---------- + nodes : ndarray of float, size(num_nodes,3) + node coordinates, in (rho,theta,zeta) + sort : bool + whether to sort the nodes for use with FFT method. + jitable : bool + Whether to skip certain checks and conditionals that don't work under jit. + Allows grid to be created on the fly with custom nodes, but weights, symmetry + etc may be wrong if grid contains duplicate nodes. + """ + + def __init__(self, nodes, sort=True, jitable=False): + # Python 3.3 (PEP 412) introduced key-sharing dictionaries. + # This change measurably reduces memory usage of objects that + # define all attributes in their __init__ method. + self._NFP = 1 + self._sym = False + self._node_pattern = "custom" + self._nodes, self._spacing = self._create_nodes(nodes) + if sort: + self._sort_nodes() + if jitable: + # dont do anything with symmetry since that changes # of nodes + # avoid point at the axis, for now. FIXME: make axis boolean mask? + r, t, z = self._nodes.T + r = jnp.where(r == 0, 1e-12, r) + self._nodes = jnp.array([r, t, z]).T + self._axis = np.array([], dtype=int) + self._unique_rho_idx = np.arange(self._nodes.shape[0]) + self._unique_theta_idx = np.arange(self._nodes.shape[0]) + self._unique_zeta_idx = np.arange(self._nodes.shape[0]) + self._inverse_rho_idx = np.arange(self._nodes.shape[0]) + self._inverse_theta_idx = np.arange(self._nodes.shape[0]) + self._inverse_zeta_idx = np.arange(self._nodes.shape[0]) + # don't do anything fancy with weights + self._weights = self._spacing.prod(axis=1) + else: + self._enforce_symmetry() + self._axis = self._find_axis() + ( + self._unique_rho_idx, + self._inverse_rho_idx, + self._unique_theta_idx, + self._inverse_theta_idx, + self._unique_zeta_idx, + self._inverse_zeta_idx, + ) = self._find_unique_inverse_nodes() + self._weights = self._scale_weights() + + self._L = self.num_rho + self._M = self.num_theta + self._N = self.num_zeta + + def _create_nodes(self, nodes): + """Allow for custom node creation. + + Parameters + ---------- + nodes : ndarray of float, size(num_nodes,3) + Node coordinates, in (rho,theta,zeta). + + Returns + ------- + nodes : ndarray of float, size(num_nodes,3) + Node coordinates, in (rho,theta,zeta). + spacing : ndarray of float, size(num_nodes,3) + Node spacing, in (rho,theta,zeta). + + """ + nodes = jnp.atleast_2d(nodes).reshape((-1, 3)).astype(float) + # Do not alter nodes given by the user for custom grids. + # In particular, do not modulo nodes by 2pi or 2pi/NFP. + # This may cause the surface_integrals() function to fail recognizing + # surfaces outside the interval [0, 2pi] as duplicates. However, most + # surface integral computations are done with LinearGrid anyway. + spacing = ( # make weights sum to 4pi^2 + jnp.ones_like(nodes) * jnp.array([1, 2 * np.pi, 2 * np.pi]) / nodes.shape[0] + ) + return nodes, spacing + + +class LinearGrid(_Grid): """Grid in which the nodes are linearly spaced in each coordinate. Useful for plotting and other analysis, though not very efficient for using as the @@ -827,7 +837,7 @@ def endpoint(self): return self.__dict__.setdefault("_endpoint", False) -class QuadratureGrid(Grid): +class QuadratureGrid(_Grid): """Grid used for numerical quadrature. Exactly integrates a Fourier-Zernike basis of resolution (L,M,N) @@ -957,7 +967,7 @@ def change_resolution(self, L, M, N, NFP=None): self._weights = self.spacing.prod(axis=1) # instead of _scale_weights -class ConcentricGrid(Grid): +class ConcentricGrid(_Grid): """Grid in which the nodes are arranged in concentric circles. Nodes are arranged concentrically within each toroidal cross-section, with more diff --git a/desc/magnetic_fields.py b/desc/magnetic_fields.py index 8125ab8a10..1f1b84ebfb 100644 --- a/desc/magnetic_fields.py +++ b/desc/magnetic_fields.py @@ -10,7 +10,7 @@ from desc.compute import rpz2xyz_vec, xyz2rpz from desc.derivatives import Derivative from desc.equilibrium import EquilibriaFamily, Equilibrium -from desc.grid import Grid, LinearGrid +from desc.grid import LinearGrid from desc.interpolate import _approx_df, interp2d, interp3d from desc.io import IOAble from desc.transform import Transform @@ -142,7 +142,7 @@ def read_BNORM_file(fname, surface, eval_grid=None, scale_by_curpol=True): return Bnorm -class MagneticField(IOAble, ABC): +class _MagneticField(IOAble, ABC): """Base class for all magnetic fields. Subclasses must implement the "compute_magnetic_field" method @@ -161,7 +161,7 @@ def __rmul__(self, x): return self.__mul__(x) def __add__(self, x): - if isinstance(x, MagneticField): + if isinstance(x, _MagneticField): return SumMagneticField(self, x) else: return NotImplemented @@ -378,7 +378,7 @@ def save_BNORM_file( return None -class ScaledMagneticField(MagneticField): +class ScaledMagneticField(_MagneticField): """Magnetic field scaled by a scalar value. ie B_new = scalar * B_old @@ -392,12 +392,12 @@ class ScaledMagneticField(MagneticField): """ - _io_attrs = MagneticField._io_attrs_ + ["_field", "_scalar"] + _io_attrs = _MagneticField._io_attrs_ + ["_field", "_scalar"] def __init__(self, scalar, field): assert np.isscalar(scalar), "scalar must actually be a scalar value" assert isinstance( - field, MagneticField + field, _MagneticField ), "field should be a subclass of MagneticField, got type {}".format( type(field) ) @@ -430,7 +430,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): ) -class SumMagneticField(MagneticField): +class SumMagneticField(_MagneticField): """Sum of two or more magnetic field sources. Parameters @@ -439,11 +439,11 @@ class SumMagneticField(MagneticField): two or more MagneticFields to add together """ - _io_attrs = MagneticField._io_attrs_ + ["_fields"] + _io_attrs = _MagneticField._io_attrs_ + ["_fields"] def __init__(self, *fields): assert all( - [isinstance(field, MagneticField) for field in fields] + [isinstance(field, _MagneticField) for field in fields] ), "fields should each be a subclass of MagneticField, got {}".format( [type(field) for field in fields] ) @@ -485,7 +485,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): return B -class ToroidalMagneticField(MagneticField): +class ToroidalMagneticField(_MagneticField): """Magnetic field purely in the toroidal (phi) direction. Magnitude is B0*R0/R where R0 is the major radius of the axis and B0 @@ -499,7 +499,7 @@ class ToroidalMagneticField(MagneticField): """ - _io_attrs_ = MagneticField._io_attrs_ + ["_B0", "_R0"] + _io_attrs_ = _MagneticField._io_attrs_ + ["_B0", "_R0"] def __init__(self, B0, R0): assert float(B0) == B0, "B0 must be a scalar" @@ -530,7 +530,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): """ assert basis.lower() in ["rpz", "xyz"] - if isinstance(coords, Grid): + if hasattr(coords, "nodes"): coords = coords.nodes coords = jnp.atleast_2d(coords) if basis == "xyz": @@ -544,7 +544,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): return B -class VerticalMagneticField(MagneticField): +class VerticalMagneticField(_MagneticField): """Uniform magnetic field purely in the vertical (Z) direction. Parameters @@ -554,7 +554,7 @@ class VerticalMagneticField(MagneticField): """ - _io_attrs_ = MagneticField._io_attrs_ + ["_B0"] + _io_attrs_ = _MagneticField._io_attrs_ + ["_B0"] def __init__(self, B0): assert np.isscalar(B0), "B0 must be a scalar" @@ -584,7 +584,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): """ assert basis.lower() in ["rpz", "xyz"] - if isinstance(coords, Grid): + if hasattr(coords, "nodes"): coords = coords.nodes coords = jnp.atleast_2d(coords) if basis == "xyz": @@ -598,7 +598,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): return B -class PoloidalMagneticField(MagneticField): +class PoloidalMagneticField(_MagneticField): """Pure poloidal magnetic field (ie in theta direction). Field strength is B0*iota*r/R0 where B0 is the toroidal field on axis, @@ -622,7 +622,7 @@ class PoloidalMagneticField(MagneticField): """ - _io_attrs_ = MagneticField._io_attrs_ + ["_B0", "_R0", "_iota"] + _io_attrs_ = _MagneticField._io_attrs_ + ["_B0", "_R0", "_iota"] def __init__(self, B0, R0, iota): assert np.isscalar(B0), "B0 must be a scalar" @@ -656,7 +656,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): """ assert basis.lower() in ["rpz", "xyz"] - if isinstance(coords, Grid): + if hasattr(coords, "nodes"): coords = coords.nodes coords = jnp.atleast_2d(coords) if basis == "xyz": @@ -676,7 +676,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): return B -class SplineMagneticField(MagneticField): +class SplineMagneticField(_MagneticField): """Magnetic field from precomputed values on a grid. Parameters @@ -788,7 +788,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): """ assert basis.lower() in ["rpz", "xyz"] - if isinstance(coords, Grid): + if hasattr(coords, "nodes"): coords = coords.nodes coords = jnp.atleast_2d(coords) if basis == "xyz": @@ -981,7 +981,7 @@ def from_field( ) -class ScalarPotentialField(MagneticField): +class ScalarPotentialField(_MagneticField): """Magnetic field due to a scalar magnetic potential in cylindrical coordinates. Parameters @@ -1023,7 +1023,7 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None): """ assert basis.lower() in ["rpz", "xyz"] - if isinstance(coords, Grid): + if hasattr(coords, "nodes"): coords = coords.nodes coords = jnp.atleast_2d(coords) if basis == "xyz": diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index e9bfbf6650..d0fc30a21b 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -7,7 +7,6 @@ from desc.compute import data_index from desc.compute.utils import get_params, get_profiles, get_transforms from desc.grid import LinearGrid, QuadratureGrid -from desc.profiles import Profile from desc.utils import Timer from .normalization import compute_scaling_factors @@ -412,7 +411,7 @@ def build(self, eq=None, use_jit=True, verbose=1): else: grid = self._grid - if isinstance(self._target, Profile): + if callable(self._target): self._target = self._target(grid.nodes[grid.unique_rho_idx]) self._dim_f = grid.num_rho @@ -602,7 +601,7 @@ def build(self, eq=None, use_jit=True, verbose=1): else: grid = self._grid - if isinstance(self._target, Profile): + if callable(self._target): self._target = self._target(grid.nodes[grid.unique_rho_idx]) self._dim_f = grid.num_rho diff --git a/desc/profiles.py b/desc/profiles.py index 1db526a1de..3199803223 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -9,14 +9,14 @@ from desc.backend import jit, jnp, put, sign from desc.basis import FourierZernikeBasis, PowerSeries from desc.derivatives import Derivative -from desc.grid import Grid, LinearGrid +from desc.grid import Grid, LinearGrid, _Grid from desc.interpolate import interp1d from desc.io import IOAble from desc.transform import Transform from desc.utils import combination_permutation, copy_coeffs, multinomial_coefficients -class Profile(IOAble, ABC): +class _Profile(IOAble, ABC): """Abstract base class for profiles. All profile classes inherit from this, and must implement @@ -51,7 +51,7 @@ def grid(self): @grid.setter def grid(self, grid): - if isinstance(grid, Grid): + if isinstance(grid, _Grid): self._grid = grid return if np.isscalar(grid): @@ -235,7 +235,7 @@ def __mul__(self, x): """Multiply this profile by another or a constant.""" if np.isscalar(x): return ScaledProfile(x, self) - elif isinstance(x, Profile): + elif isinstance(x, _Profile): return ProductProfile(self, x) else: raise NotImplementedError() @@ -246,7 +246,7 @@ def __rmul__(self, x): def __add__(self, x): """Add this profile with another.""" - if isinstance(x, Profile): + if isinstance(x, _Profile): return SumProfile(self, x) else: raise NotImplementedError() @@ -260,7 +260,7 @@ def __sub__(self, x): return self.__add__(-x) -class ScaledProfile(Profile): +class ScaledProfile(_Profile): """Profile times a constant value. f_1(x) = a*f(x) @@ -274,11 +274,11 @@ class ScaledProfile(Profile): """ - _io_attrs_ = Profile._io_attrs_ + ["_profile", "_scale"] + _io_attrs_ = _Profile._io_attrs_ + ["_profile", "_scale"] def __init__(self, scale, profile, **kwargs): assert isinstance( - profile, Profile + profile, _Profile ), "profile in a ScaledProfile must be a Profile or subclass, got {}.".format( str(profile) ) @@ -298,7 +298,7 @@ def grid(self): @grid.setter def grid(self, new): - Profile.grid.fset(self, new) + _Profile.grid.fset(self, new) self._profile.grid = new @property @@ -361,7 +361,7 @@ def __repr__(self): return s -class SumProfile(Profile): +class SumProfile(_Profile): """Sum of two or more Profiles. f(x) = f1(x) + f2(x) + f3(x) ... @@ -373,12 +373,12 @@ class SumProfile(Profile): """ - _io_attrs_ = Profile._io_attrs_ + ["_profiles"] + _io_attrs_ = _Profile._io_attrs_ + ["_profiles"] def __init__(self, *profiles, **kwargs): self._profiles = [] for profile in profiles: - assert isinstance(profile, Profile), ( + assert isinstance(profile, _Profile), ( "Each profile in a SumProfile must be a Profile or " + "subclass, got {}.".format(str(profile)) ) @@ -395,7 +395,7 @@ def grid(self): @grid.setter def grid(self, new): - Profile.grid.fset(self, new) + _Profile.grid.fset(self, new) for profile in self._profiles: profile.grid = new @@ -459,7 +459,7 @@ def __repr__(self): return s -class ProductProfile(Profile): +class ProductProfile(_Profile): """Product of two or more Profiles. f(x) = f1(x) * f2(x) * f3(x) ... @@ -471,12 +471,12 @@ class ProductProfile(Profile): """ - _io_attrs_ = Profile._io_attrs_ + ["_profiles"] + _io_attrs_ = _Profile._io_attrs_ + ["_profiles"] def __init__(self, *profiles, **kwargs): self._profiles = [] for profile in profiles: - assert isinstance(profile, Profile), ( + assert isinstance(profile, _Profile), ( "Each profile in a ProductProfile must be a Profile or " + "subclass, got {}.".format(str(profile)) ) @@ -493,7 +493,7 @@ def grid(self): @grid.setter def grid(self, new): - Profile.grid.fset(self, new) + _Profile.grid.fset(self, new) for profile in self._profiles: profile.grid = new @@ -567,7 +567,7 @@ def __repr__(self): return s -class PowerSeriesProfile(Profile): +class PowerSeriesProfile(_Profile): """Profile represented by a monic power series. f(x) = a[0] + a[1]*x + a[2]*x**2 + ... @@ -589,7 +589,7 @@ class PowerSeriesProfile(Profile): Name of the profile. """ - _io_attrs_ = Profile._io_attrs_ + ["_basis", "_transform"] + _io_attrs_ = _Profile._io_attrs_ + ["_basis", "_transform"] def __init__(self, params=None, modes=None, grid=None, sym="auto", name=""): super().__init__(grid, name) @@ -622,7 +622,7 @@ def __init__(self, params=None, modes=None, grid=None, sym="auto", name=""): def _get_transform(self, grid): if grid is None: return self._transform - if not isinstance(grid, Grid): + if not isinstance(grid, _Grid): if np.isscalar(grid): grid = np.linspace(0, 1, grid) grid = np.atleast_1d(grid) @@ -661,7 +661,7 @@ def grid(self): @grid.setter def grid(self, new): - Profile.grid.fset(self, new) + _Profile.grid.fset(self, new) if hasattr(self, "_transform"): self._transform.grid = self.grid self._transform.build() @@ -776,7 +776,7 @@ def from_values( return cls(params, grid=grid, sym=sym, name=name) -class SplineProfile(Profile): +class SplineProfile(_Profile): """Profile represented by a piecewise cubic spline. Parameters @@ -800,7 +800,7 @@ class SplineProfile(Profile): """ - _io_attrs_ = Profile._io_attrs_ + ["_knots", "_method"] + _io_attrs_ = _Profile._io_attrs_ + ["_knots", "_method"] def __init__(self, values=None, knots=None, grid=None, method="cubic2", name=""): super().__init__(grid, name) @@ -830,7 +830,7 @@ def grid(self): @grid.setter def grid(self, new): - Profile.grid.fset(self, new) + _Profile.grid.fset(self, new) @property def params(self): @@ -850,7 +850,7 @@ def params(self, new): def _get_xq(self, grid): if grid is None: return self.grid.nodes[:, 0] - if isinstance(grid, Grid): + if hasattr(grid, "nodes"): return grid.nodes[:, 0] if np.isscalar(grid): return np.linspace(0, 1, grid) @@ -890,7 +890,7 @@ def compute(self, params=None, grid=None, dr=0, dt=0, dz=0): return fq -class MTanhProfile(Profile): +class MTanhProfile(_Profile): r"""Profile represented by a modified hyperbolic tangent + polynomial. Profile is parameterized by pedestal height (ped, :math:`p`), SOL height @@ -1009,7 +1009,7 @@ def _mtanh(x, ped, offset, sym, width, core_poly, dx=0): def _get_xq(self, grid): if grid is None: return self.grid.nodes[:, 0] - if isinstance(grid, Grid): + if hasattr(grid, "nodes"): return grid.nodes[:, 0] if np.isscalar(grid): return np.linspace(0, 1, grid) @@ -1147,7 +1147,7 @@ def from_values( return MTanhProfile(params, grid, name) -class FourierZernikeProfile(Profile): +class FourierZernikeProfile(_Profile): """Possibly anisotropic profile represented by Fourier-Zernike basis. Parameters @@ -1168,7 +1168,7 @@ class FourierZernikeProfile(Profile): """ - _io_attrs_ = Profile._io_attrs_ + ["_basis", "_transform"] + _io_attrs_ = _Profile._io_attrs_ + ["_basis", "_transform"] def __init__(self, params=None, modes=None, grid=None, sym="auto", NFP=1, name=""): super().__init__(grid, name) @@ -1208,7 +1208,7 @@ def __init__(self, params=None, modes=None, grid=None, sym="auto", NFP=1, name=" def _get_transform(self, grid): if grid is None: return self._transform - if not isinstance(grid, Grid): + if not isinstance(grid, _Grid): if np.isscalar(grid): grid = np.linspace(0, 1, grid) grid = np.atleast_1d(grid) @@ -1242,7 +1242,7 @@ def grid(self): @grid.setter def grid(self, new): - Profile.grid.fset(self, new) + _Profile.grid.fset(self, new) if hasattr(self, "_transform"): self._transform.grid = self.grid self._transform.build()