diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 92c41a000f..f6c7b12e68 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -216,7 +216,7 @@ def get_data_deps(keys, obj, has_axis=False, basis="rpz", data=None): Returns ------- - deps : list of str + deps : list[str] Names of quantities needed to compute key. """ diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 2fda119f04..a89742b40f 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -138,7 +138,7 @@ def map_coordinates( # noqa: C901 # do surface average to get iota once if "iota" in profiles and profiles["iota"] is None: - profiles["iota"] = eq.get_profile("iota", params=params) + profiles["iota"] = eq.get_profile(["iota", "iota_r"], params=params) params["i_l"] = profiles["iota"].params @functools.partial(jit, static_argnums=1) @@ -146,11 +146,11 @@ def compute(y, basis): grid = Grid(y, sort=False, jitable=True) data = {} if "iota" in deps: - data["iota"] = profiles["iota"](grid, params=params["i_l"]) + data["iota"] = profiles["iota"].compute(grid, params=params["i_l"]) if "iota_r" in deps: - data["iota_r"] = profiles["iota"](grid, dr=1, params=params["i_l"]) + data["iota_r"] = profiles["iota"].compute(grid, dr=1, params=params["i_l"]) if "iota_rr" in deps: - data["iota_rr"] = profiles["iota"](grid, dr=2, params=params["i_l"]) + data["iota_rr"] = profiles["iota"].compute(grid, dr=2, params=params["i_l"]) transforms = get_transforms(basis, eq, grid, jitable=True) data = compute_fun(eq, basis, params, transforms, profiles, data) x = jnp.array([data[k] for k in basis]).T @@ -243,7 +243,10 @@ def _initial_guess_heuristic(yk, coords, inbasis, eq, profiles): theta = coords[:, inbasis.index(poloidal)] elif poloidal == "alpha": alpha = coords[:, inbasis.index("alpha")] - iota = profiles["iota"](rho) + rho = jnp.atleast_1d(rho) + zero = jnp.zeros_like(rho) + grid = Grid(nodes=jnp.column_stack([rho, zero, zero]), sort=False, jitable=True) + iota = profiles["iota"].compute(grid) theta = (alpha + iota * zeta) % (2 * jnp.pi) yk = jnp.column_stack([rho, theta, zeta]) @@ -677,7 +680,7 @@ def get_rtz_grid( rtz : rho, theta, zeta period : tuple of float Assumed periodicity for each quantity in inbasis. - Use np.inf to denote no periodicity. + Use ``np.inf`` to denote no periodicity. jitable : bool, optional If false the returned grid has additional attributes. Required to be false to retain nodes at magnetic axis. @@ -691,6 +694,8 @@ def get_rtz_grid( grid = Grid.create_meshgrid( [radial, poloidal, toroidal], coordinates=coordinates, period=period ) + if "iota" in kwargs: + kwargs["iota"] = grid.expand(kwargs["iota"]) inbasis = { "r": "rho", "t": "theta", diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 41d205a595..8d09d5f64b 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -37,7 +37,7 @@ from desc.optimizable import Optimizable, optimizable_parameter from desc.optimize import Optimizer from desc.perturbations import perturb -from desc.profiles import PowerSeriesProfile, SplineProfile +from desc.profiles import HermiteSplineProfile, PowerSeriesProfile, SplineProfile from desc.transform import Transform from desc.utils import ( ResolutionWarning, @@ -732,6 +732,8 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs): ---------- name : str Name of the quantity to compute. + If list is given, then two names are expected: the quantity to spline + and its radial derivative. grid : Grid, optional Grid of coordinates to evaluate at. Defaults to the quadrature grid. Note profile will only be a function of the radial coordinate. @@ -748,14 +750,17 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs): if grid is None: grid = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP) data = self.compute(name, grid=grid, **kwargs) - f = data[name] - f = grid.compress(f, surface_label="rho") - x = grid.nodes[grid.unique_rho_idx, 0] - p = SplineProfile(f, x, name=name) + knots = grid.compress(grid.nodes[:, 0]) + if isinstance(name, str): + f = grid.compress(data[name]) + p = SplineProfile(f, knots, name=name) + else: + f, df = map(grid.compress, (data[name[0]], data[name[1]])) + p = HermiteSplineProfile(f, df, knots, name=name) if kind == "power_series": - p = p.to_powerseries(order=min(self.L, len(x)), xs=x, sym=True) + p = p.to_powerseries(order=min(self.L, grid.num_rho), xs=knots, sym=True) if kind == "fourier_zernike": - p = p.to_fourierzernike(L=min(self.L, len(x)), xs=x) + p = p.to_fourierzernike(L=min(self.L, grid.num_rho), xs=knots) return p def get_axis(self): @@ -1161,8 +1166,6 @@ def map_coordinates( Parameters ---------- - eq : Equilibrium - Equilibrium to use. coords : ndarray Shape (k, 3). 2D array of input coordinates. Each row is a different point in space. diff --git a/desc/profiles.py b/desc/profiles.py index e7dcc6dd77..064871e583 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -17,6 +17,7 @@ copy_coeffs, errorif, multinomial_coefficients, + setdefault, warnif, ) @@ -613,7 +614,7 @@ def get_params(self, l): def set_params(self, l, a=None): """Set specific power series coefficients.""" - l, a = np.atleast_1d(l), np.atleast_1d(a) + l, a = np.atleast_1d(l, a) a = np.broadcast_to(a, l.shape) for ll, aa in zip(l, a): idx = self.basis.get_idx(ll, 0, 0) @@ -793,24 +794,25 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): class SplineProfile(_Profile): - """Profile represented by a piecewise cubic spline. + """Radial profile represented by a piecewise cubic spline. Parameters ---------- values: array-like - Values of the function at knot locations. - knots : int or ndarray - x locations to use for spline. If an integer, uses that many points linearly - spaced between 0,1 + 1-D array containing values of the dependent variable. + knots : array-like + 1-D array containing values of the independent variable. + Must be real, finite, and in strictly increasing order in [0, 1]. + If ``None``, assumes ``values`` is given on knots uniformly spaced in [0, 1]. method : str - method of interpolation + Method of interpolation. Default is cubic2. - `'nearest'`: nearest neighbor interpolation - `'linear'`: linear interpolation - `'cubic'`: C1 cubic splines (aka local splines) - `'cubic2'`: C2 cubic splines (aka natural splines) - `'catmull-rom'`: C1 cubic centripetal "tension" splines name : str - name of the profile + Optional name of the profile. """ @@ -821,11 +823,12 @@ def __init__(self, values=None, knots=None, method="cubic2", name=""): if values is None: values = [0, 0, 0] - values = np.atleast_1d(values) + values = jnp.atleast_1d(values) if knots is None: - knots = np.linspace(0, 1, values.size) - else: - knots = np.atleast_1d(knots) + knots = jnp.linspace(0, 1, values.size) + knots = jnp.atleast_1d(knots) + errorif(values.shape[-1] != knots.shape[-1]) + errorif(not (values.ndim == knots.ndim == 1), NotImplementedError) self._knots = knots self._params = values self._method = method @@ -834,7 +837,7 @@ def __repr__(self): """Get the string form of the object.""" s = super().__repr__() s = s[:-1] - s += ", method={}, num_knots={})".format(self._method, len(self._knots)) + s += ", method={}, num_knots={})".format(self._method, self._knots.size) return s @property @@ -849,13 +852,12 @@ def params(self): @params.setter def params(self, new): - if len(new) == len(self._knots): - self._params = jnp.asarray(new) - else: - raise ValueError( - "params should have the same size as the knots, " - + f"got {len(new)} values for {len(self._knots)} knots" - ) + errorif( + len(new) != self._knots.size, + msg="params should have the same size as the knots, " + + f"got {len(new)} values for {self._knots.size} knots", + ) + self._params = jnp.asarray(new) def compute(self, grid, params=None, dr=0, dt=0, dz=0): """Compute values of profile at specified nodes. @@ -863,10 +865,10 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): Parameters ---------- grid : Grid - locations to compute values at. + Locations to compute values at. params : array-like - spline values to use. If not given, uses the - values given by the params attribute + Values of the function at ``self.knots``. + If not given, uses ``self.params``. dr, dt, dz : int derivative order in rho, theta, zeta @@ -876,15 +878,112 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): values of the profile or its derivative at the points specified """ - if params is None: - params = self.params if dt != 0 or dz != 0: return jnp.zeros_like(grid.nodes[:, 0]) - x = self.knots - f = params - xq = grid.nodes[:, 0] - fq = interp1d(xq, x, f, method=self._method, derivative=dr, extrap=True) - return fq + params = setdefault(params, self._params) + return interp1d( + xq=grid.nodes[:, 0], + x=self._knots, + f=params, + method=self._method, + derivative=dr, + extrap=True, + ) + + +class HermiteSplineProfile(_Profile): + """Radial profile represented by a piecewise cubic Hermite spline. + + Parameters + ---------- + f: array-like + 1-D array containing values of the dependent variable. + df: array-like + 1-D array containing derivatives of the dependent variable. + knots : array-like + 1-D array containing values of the independent variable. + Must be real, finite, and in strictly increasing order in [0, 1]. + If ``None``, assumes ``f`` and ``df`` are given on knots uniformly + spaced in [0, 1]. + name : str + Optional name of the profile. + + """ + + _io_attrs_ = _Profile._io_attrs_ + ["_knots", "_params"] + + def __init__(self, f, df, knots=None, name=""): + super().__init__(name) + + f, df = jnp.atleast_1d(f, df) + if knots is None: + knots = jnp.linspace(0, 1, f.size) + knots = jnp.atleast_1d(knots) + errorif(not (f.shape[-1] == df.shape[-1] == knots.shape[-1])) + errorif(not (f.ndim == df.ndim == knots.ndim == 1), NotImplementedError) + self._knots = knots + self._params = jnp.concatenate([f, df]) + + def __repr__(self): + """Get the string form of the object.""" + s = super().__repr__() + s = s[:-1] + s += ", num_knots={})".format(self._knots.size) + return s + + @property + def knots(self): + """ndarray: Knot locations.""" + return self._knots + + @property + def params(self): + """ndarray: Parameters for computation. + + First (second) half stores function (derivative) values at ``knots``. + """ + return self._params + + @params.setter + def params(self, new): + new = jnp.asarray(new) + errorif( + new.ndim != 1 or new.size != 2 * self._knots.size, + msg="Params should be 1D with size twice number of knots. " + f"Got {new.shape} params for {self._knots.size} knots.", + ) + self._params = new + + def compute(self, grid, params=None, dr=0, dt=0, dz=0): + """Compute values of profile at specified nodes. + + Parameters + ---------- + grid : Grid + Locations to compute values at. + params : array-like + First (second) half stores function (derivative) values at ``knots``. + If not given, uses ``self.params``. + dr, dt, dz : int + derivative order in rho, theta, zeta + + Returns + ------- + f : ndarray + Array containing values of the dependent variable at the points specified. + + """ + if dt != 0 or dz != 0: + return jnp.zeros_like(grid.nodes[:, 0]) + params = setdefault(params, self._params) + return interp1d( + xq=grid.nodes[:, 0], + x=self._knots, + f=params[: self._knots.size], + fx=params[self._knots.size :], + derivative=dr, + extrap=True, + ) class MTanhProfile(_Profile): diff --git a/tests/test_profiles.py b/tests/test_profiles.py index 530fd5f765..201996c1a7 100644 --- a/tests/test_profiles.py +++ b/tests/test_profiles.py @@ -6,6 +6,7 @@ from scipy.interpolate import interp1d from desc.equilibrium import Equilibrium +from desc.examples import get from desc.grid import LinearGrid from desc.io import InputReader from desc.objectives import ( @@ -15,6 +16,7 @@ ) from desc.profiles import ( FourierZernikeProfile, + HermiteSplineProfile, MTanhProfile, PowerSeriesProfile, SplineProfile, @@ -507,3 +509,14 @@ def test_kinetic_pressure(self): assert np.all(data2["Te_r"] == data2["Ti_r"]) np.testing.assert_allclose(data1["p"], data2["p"]) np.testing.assert_allclose(data1["p_r"], data2["p_r"]) + + @pytest.mark.unit + def test_hermite_spline_solve(self): + """Test that spline with double number of parameters is optimized.""" + eq = get("DSHAPE") + rho = np.linspace(0, 1.0, 20, endpoint=True) + eq.pressure = HermiteSplineProfile( + eq.pressure(rho), eq.pressure(rho, dr=1), rho + ) + eq.solve() + assert eq.is_nested()