From ca50c3f9a10037d6fcb657e68d860a067227c61f Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 16 Aug 2024 19:40:27 -0400 Subject: [PATCH 1/9] Default to cubic hermite spline of iota in map_coordinates --- desc/equilibrium/coords.py | 15 ++++++++++----- desc/equilibrium/equilibrium.py | 18 ++++++++++++------ desc/profiles.py | 18 +++++++++++++++--- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 2fda119f04..5a4462c351 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -146,11 +146,11 @@ def compute(y, basis): grid = Grid(y, sort=False, jitable=True) data = {} if "iota" in deps: - data["iota"] = profiles["iota"](grid, params=params["i_l"]) + data["iota"] = profiles["iota"].compute(grid, params=params["i_l"]) if "iota_r" in deps: - data["iota_r"] = profiles["iota"](grid, dr=1, params=params["i_l"]) + data["iota_r"] = profiles["iota"].compute(grid, dr=1, params=params["i_l"]) if "iota_rr" in deps: - data["iota_rr"] = profiles["iota"](grid, dr=2, params=params["i_l"]) + data["iota_rr"] = profiles["iota"].compute(grid, dr=2, params=params["i_l"]) transforms = get_transforms(basis, eq, grid, jitable=True) data = compute_fun(eq, basis, params, transforms, profiles, data) x = jnp.array([data[k] for k in basis]).T @@ -243,7 +243,10 @@ def _initial_guess_heuristic(yk, coords, inbasis, eq, profiles): theta = coords[:, inbasis.index(poloidal)] elif poloidal == "alpha": alpha = coords[:, inbasis.index("alpha")] - iota = profiles["iota"](rho) + rho = jnp.atleast_1d(rho) + zero = jnp.zeros_like(rho) + grid = Grid(nodes=jnp.column_stack([rho, zero, zero]), sort=False, jitable=True) + iota = profiles["iota"].compute(grid) theta = (alpha + iota * zeta) % (2 * jnp.pi) yk = jnp.column_stack([rho, theta, zeta]) @@ -677,7 +680,7 @@ def get_rtz_grid( rtz : rho, theta, zeta period : tuple of float Assumed periodicity for each quantity in inbasis. - Use np.inf to denote no periodicity. + Use ``np.inf`` to denote no periodicity. jitable : bool, optional If false the returned grid has additional attributes. Required to be false to retain nodes at magnetic axis. @@ -691,6 +694,8 @@ def get_rtz_grid( grid = Grid.create_meshgrid( [radial, poloidal, toroidal], coordinates=coordinates, period=period ) + if "iota" in kwargs: + kwargs["iota"] = grid.expand(kwargs["iota"]) inbasis = { "r": "rho", "t": "theta", diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 41d205a595..2dae3ead21 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -732,6 +732,8 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs): ---------- name : str Name of the quantity to compute. + If list is given, then two names are expected: the quantity to spline + and its radial derivative. grid : Grid, optional Grid of coordinates to evaluate at. Defaults to the quadrature grid. Note profile will only be a function of the radial coordinate. @@ -748,10 +750,16 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs): if grid is None: grid = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP) data = self.compute(name, grid=grid, **kwargs) - f = data[name] - f = grid.compress(f, surface_label="rho") - x = grid.nodes[grid.unique_rho_idx, 0] - p = SplineProfile(f, x, name=name) + if isinstance(name, str): + f = data[name] + df = None + method = "cubic2" + else: + f = data[name[0]] + df = grid.compress(data[name[1]], surface_label="rho") + method = "cubic" + x, f = map(grid.compress, (grid.nodes[:, 0], f)) + p = SplineProfile(f, df=df, knots=x, method=method, name=name) if kind == "power_series": p = p.to_powerseries(order=min(self.L, len(x)), xs=x, sym=True) if kind == "fourier_zernike": @@ -1161,8 +1169,6 @@ def map_coordinates( Parameters ---------- - eq : Equilibrium - Equilibrium to use. coords : ndarray Shape (k, 3). 2D array of input coordinates. Each row is a different point in space. diff --git a/desc/profiles.py b/desc/profiles.py index e7dcc6dd77..733dbbca54 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -811,12 +811,14 @@ class SplineProfile(_Profile): - `'catmull-rom'`: C1 cubic centripetal "tension" splines name : str name of the profile + df : array-like + Optional. Values of the function derivative at knot locations. """ _io_attrs_ = _Profile._io_attrs_ + ["_knots", "_method"] - def __init__(self, values=None, knots=None, method="cubic2", name=""): + def __init__(self, values=None, knots=None, method="cubic2", name="", df=None): super().__init__(name) if values is None: @@ -828,6 +830,7 @@ def __init__(self, values=None, knots=None, method="cubic2", name=""): knots = np.atleast_1d(knots) self._knots = knots self._params = values + self._params_derivative = df self._method = method def __repr__(self): @@ -851,13 +854,14 @@ def params(self): def params(self, new): if len(new) == len(self._knots): self._params = jnp.asarray(new) + self._params_derivative = None else: raise ValueError( "params should have the same size as the knots, " + f"got {len(new)} values for {len(self._knots)} knots" ) - def compute(self, grid, params=None, dr=0, dt=0, dz=0): + def compute(self, grid, params=None, dr=0, dt=0, dz=0, params_derivative=None): """Compute values of profile at specified nodes. Parameters @@ -869,6 +873,9 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): values given by the params attribute dr, dt, dz : int derivative order in rho, theta, zeta + params_derivative : array-like + spline derivative values to use. If not given, uses the + values given by the params_derivative attribute Returns ------- @@ -878,12 +885,17 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): """ if params is None: params = self.params + if params_derivative is None: + params_derivative = self._params_derivative if dt != 0 or dz != 0: return jnp.zeros_like(grid.nodes[:, 0]) x = self.knots f = params + fx = {} + if params_derivative is not None: + fx["fx"] = params_derivative xq = grid.nodes[:, 0] - fq = interp1d(xq, x, f, method=self._method, derivative=dr, extrap=True) + fq = interp1d(xq, x, f, method=self._method, derivative=dr, extrap=True, **fx) return fq From bd24f083be0bab8fff22fd455fb7fa21cabcb956 Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 16 Aug 2024 19:48:14 -0400 Subject: [PATCH 2/9] Use jnp instead of np to avoid traced array conversion error Otherwise you'll get this error message in jitted functions: > g = obj.grad(obj.x()) tests/test_neoclassical.py:84: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ desc/objectives/objective_funs.py:415: in grad return jnp.atleast_1d(self._grad(x, constants).squeeze()) desc/derivatives.py:89: in __call__ return self.compute(*args, **kwargs) desc/derivatives.py:150: in compute return self._compute(*args, **kwargs) desc/objectives/objective_funs.py:335: in compute_scalar f = jnp.sum(self.compute_scaled_error(x, constants=constants) ** 2) / 2 desc/objectives/objective_funs.py:312: in compute_scaled_error [ desc/objectives/objective_funs.py:313: in obj.compute_scaled_error(*par, constants=const) desc/objectives/objective_funs.py:952: in compute_scaled_error f = self.compute(*args, **kwargs) desc/objectives/_neoclassical.py:235: in compute iota=SplineProfile(iota, df=iota_r, knots=self._rho, name="iota", jnp=jnp), desc/profiles.py:827: in __init__ values = np.atleast_1d(values) ../../../miniconda3/envs/desc-env/lib/python3.10/site-packages/numpy/core/shape_base.py:65: in atleast_1d ary = asanyarray(ary) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ self = TracerArrayConversionError(Tracedwith) tracer = Tracedwith def __init__(self, tracer: core.Tracer): super().__init__( "The numpy.ndarray conversion method __array__() was called on " > f"{tracer._error_repr()}{tracer._origin_msg()}") E IndexError: list index out of range --- desc/profiles.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/desc/profiles.py b/desc/profiles.py index 733dbbca54..192e157bd1 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -823,11 +823,10 @@ def __init__(self, values=None, knots=None, method="cubic2", name="", df=None): if values is None: values = [0, 0, 0] - values = np.atleast_1d(values) + values = jnp.atleast_1d(values) if knots is None: - knots = np.linspace(0, 1, values.size) - else: - knots = np.atleast_1d(knots) + knots = jnp.linspace(0, 1, values.size) + knots = jnp.atleast_1d(knots) self._knots = knots self._params = values self._params_derivative = df From 70219e085ae520ed70179a258073047889e4dd8c Mon Sep 17 00:00:00 2001 From: unalmis Date: Fri, 16 Aug 2024 19:58:23 -0400 Subject: [PATCH 3/9] Add missing part of commit ca50c3f9a10037d6fcb657e68d860a067227c61f --- desc/equilibrium/coords.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index 5a4462c351..a89742b40f 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -138,7 +138,7 @@ def map_coordinates( # noqa: C901 # do surface average to get iota once if "iota" in profiles and profiles["iota"] is None: - profiles["iota"] = eq.get_profile("iota", params=params) + profiles["iota"] = eq.get_profile(["iota", "iota_r"], params=params) params["i_l"] = profiles["iota"].params @functools.partial(jit, static_argnums=1) From fb81104271f60311d9592c3500401da98c183e6e Mon Sep 17 00:00:00 2001 From: unalmis Date: Sun, 18 Aug 2024 00:28:37 -0400 Subject: [PATCH 4/9] Add HermiteSplineProfile --- desc/equilibrium/equilibrium.py | 19 ++--- desc/profiles.py | 142 +++++++++++++++++++++++++------- 2 files changed, 119 insertions(+), 42 deletions(-) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 2dae3ead21..c58cd620b9 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -37,7 +37,7 @@ from desc.optimizable import Optimizable, optimizable_parameter from desc.optimize import Optimizer from desc.perturbations import perturb -from desc.profiles import PowerSeriesProfile, SplineProfile +from desc.profiles import HermiteSplineProfile, PowerSeriesProfile, SplineProfile from desc.transform import Transform from desc.utils import ( ResolutionWarning, @@ -750,20 +750,17 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs): if grid is None: grid = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP) data = self.compute(name, grid=grid, **kwargs) + knots = grid.compress(grid.nodes[:, 0]) if isinstance(name, str): - f = data[name] - df = None - method = "cubic2" + f = grid.compress(data[name]) + p = SplineProfile(f, knots, name=name) else: - f = data[name[0]] - df = grid.compress(data[name[1]], surface_label="rho") - method = "cubic" - x, f = map(grid.compress, (grid.nodes[:, 0], f)) - p = SplineProfile(f, df=df, knots=x, method=method, name=name) + f, dfdr = map(grid.compress, (data[name[0]], data[name[1]])) + p = HermiteSplineProfile(knots, f, dfdr, name=name) if kind == "power_series": - p = p.to_powerseries(order=min(self.L, len(x)), xs=x, sym=True) + p = p.to_powerseries(order=min(self.L, grid.num_rho), xs=knots, sym=True) if kind == "fourier_zernike": - p = p.to_fourierzernike(L=min(self.L, len(x)), xs=x) + p = p.to_fourierzernike(L=min(self.L, grid.num_rho), xs=knots) return p def get_axis(self): diff --git a/desc/profiles.py b/desc/profiles.py index 192e157bd1..aaa9e92d65 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -17,6 +17,7 @@ copy_coeffs, errorif, multinomial_coefficients, + setdefault, warnif, ) @@ -798,38 +799,35 @@ class SplineProfile(_Profile): Parameters ---------- values: array-like - Values of the function at knot locations. + Array containing values of the dependent variable. knots : int or ndarray - x locations to use for spline. If an integer, uses that many points linearly - spaced between 0,1 + 1-D array containing values of the independent variable. + Values must be real, finite, and in strictly increasing order in [0, 1]. + If an integer, uses that many uniformly spaced points. method : str - method of interpolation + Method of interpolation. Default is cubic2. - `'nearest'`: nearest neighbor interpolation - `'linear'`: linear interpolation - `'cubic'`: C1 cubic splines (aka local splines) - `'cubic2'`: C2 cubic splines (aka natural splines) - `'catmull-rom'`: C1 cubic centripetal "tension" splines name : str - name of the profile - df : array-like - Optional. Values of the function derivative at knot locations. + Optional name of the profile. """ _io_attrs_ = _Profile._io_attrs_ + ["_knots", "_method"] - def __init__(self, values=None, knots=None, method="cubic2", name="", df=None): + def __init__(self, values=None, knots=None, method="cubic2", name=""): super().__init__(name) if values is None: values = [0, 0, 0] - values = jnp.atleast_1d(values) if knots is None: knots = jnp.linspace(0, 1, values.size) - knots = jnp.atleast_1d(knots) + knots, values = jnp.atleast_1d(knots, values) self._knots = knots self._params = values - self._params_derivative = df self._method = method def __repr__(self): @@ -851,14 +849,12 @@ def params(self): @params.setter def params(self, new): - if len(new) == len(self._knots): - self._params = jnp.asarray(new) - self._params_derivative = None - else: - raise ValueError( - "params should have the same size as the knots, " - + f"got {len(new)} values for {len(self._knots)} knots" - ) + errorif( + len(new) != len(self._knots), + msg="params should have the same size as the knots, " + + f"got {len(new)} values for {len(self._knots)} knots", + ) + self._params = jnp.asarray(new) def compute(self, grid, params=None, dr=0, dt=0, dz=0, params_derivative=None): """Compute values of profile at specified nodes. @@ -882,20 +878,104 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0, params_derivative=None): values of the profile or its derivative at the points specified """ - if params is None: - params = self.params - if params_derivative is None: - params_derivative = self._params_derivative if dt != 0 or dz != 0: return jnp.zeros_like(grid.nodes[:, 0]) - x = self.knots - f = params - fx = {} - if params_derivative is not None: - fx["fx"] = params_derivative - xq = grid.nodes[:, 0] - fq = interp1d(xq, x, f, method=self._method, derivative=dr, extrap=True, **fx) - return fq + params = setdefault(params, self._params) + return interp1d( + xq=grid.nodes[:, 0], + x=self._knots, + f=params, + method=self._method, + derivative=dr, + extrap=True, + ) + + +class HermiteSplineProfile(_Profile): + """Profile represented by a piecewise cubic Hermite spline. + + Parameters + ---------- + r : array-like + 1-D array containing values of the independent variable. + Values must be real, finite, and in strictly increasing order in [0, 1]. + f: array-like + Array containing values of the dependent variable. + dfdr: array-like + Array containing derivatives of the dependent variable. + name : str + Optional name of the profile. + + """ + + _io_attrs_ = _Profile._io_attrs_ + ["_knots", "_params"] + + def __init__(self, r, f, dfdr, name=""): + super().__init__(name) + r, f, dfdr = jnp.atleast_1d(r, f, dfdr) + self._knots = r + self._params = jnp.stack([f, dfdr]) + + def __repr__(self): + """Get the string form of the object.""" + s = super().__repr__() + s = s[:-1] + s += ", num_knots={})".format(len(self._knots)) + return s + + @property + def knots(self): + """ndarray: Knot locations.""" + return self._knots + + @property + def params(self): + """ndarray: Parameters for computation. + + First (second) index stores function (derivative) values. + """ + return self._params + + @params.setter + def params(self, new): + new = jnp.asarray(new) + errorif( + new.shape[-1] != self._knots.shape[-1], + msg="Params should have shape that broadcast with knots. " + f"Got {new.shape} params for {self._knots} knots.", + ) + self._params = new + + def compute(self, grid, params=None, dr=0, dt=0, dz=0): + """Compute values of profile at specified nodes. + + Parameters + ---------- + grid : Grid + locations to compute values at. + params : array-like + First (second) index stores function (derivative) values + evaluated at knots. Defaults to ``self.params``. + dr, dt, dz : int + derivative order in rho, theta, zeta + + Returns + ------- + f : ndarray + Array containing values of the dependent variable at the points specified. + + """ + if dt != 0 or dz != 0: + return jnp.zeros_like(grid.nodes[:, 0]) + params = setdefault(params, self._params) + return interp1d( + xq=grid.nodes[:, 0], + x=self._knots, + f=params[0], + fx=params[1], + derivative=dr, + extrap=True, + ) class MTanhProfile(_Profile): From e4b93ca143332ad720eed7f7fce320b85de1b5d8 Mon Sep 17 00:00:00 2001 From: unalmis Date: Sun, 18 Aug 2024 00:35:08 -0400 Subject: [PATCH 5/9] Fix incorrect docstring for SplineProfile --- desc/profiles.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/desc/profiles.py b/desc/profiles.py index aaa9e92d65..74e70ca099 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -614,7 +614,7 @@ def get_params(self, l): def set_params(self, l, a=None): """Set specific power series coefficients.""" - l, a = np.atleast_1d(l), np.atleast_1d(a) + l, a = np.atleast_1d(l, a) a = np.broadcast_to(a, l.shape) for ll, aa in zip(l, a): idx = self.basis.get_idx(ll, 0, 0) @@ -800,10 +800,10 @@ class SplineProfile(_Profile): ---------- values: array-like Array containing values of the dependent variable. - knots : int or ndarray + knots : array-like 1-D array containing values of the independent variable. Values must be real, finite, and in strictly increasing order in [0, 1]. - If an integer, uses that many uniformly spaced points. + If not given, assumes values is uniformly spaced in [0, 1]. method : str Method of interpolation. Default is cubic2. - `'nearest'`: nearest neighbor interpolation From 974fc96265a22828436d80600db52a6fbe2a69b5 Mon Sep 17 00:00:00 2001 From: unalmis Date: Sun, 18 Aug 2024 00:38:55 -0400 Subject: [PATCH 6/9] Remove now unused parameter in SplineProfile --- desc/profiles.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/desc/profiles.py b/desc/profiles.py index 74e70ca099..6f6b3a41af 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -856,7 +856,7 @@ def params(self, new): ) self._params = jnp.asarray(new) - def compute(self, grid, params=None, dr=0, dt=0, dz=0, params_derivative=None): + def compute(self, grid, params=None, dr=0, dt=0, dz=0): """Compute values of profile at specified nodes. Parameters @@ -868,9 +868,6 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0, params_derivative=None): values given by the params attribute dr, dt, dz : int derivative order in rho, theta, zeta - params_derivative : array-like - spline derivative values to use. If not given, uses the - values given by the params_derivative attribute Returns ------- From f423d556d8eb380fd73a8b59c56746fb54979d1b Mon Sep 17 00:00:00 2001 From: unalmis Date: Sun, 18 Aug 2024 00:41:57 -0400 Subject: [PATCH 7/9] Undo unneeded change --- desc/profiles.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/desc/profiles.py b/desc/profiles.py index 6f6b3a41af..524a5ab727 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -823,9 +823,10 @@ def __init__(self, values=None, knots=None, method="cubic2", name=""): if values is None: values = [0, 0, 0] + values = jnp.atleast_1d(values) if knots is None: knots = jnp.linspace(0, 1, values.size) - knots, values = jnp.atleast_1d(knots, values) + knots = jnp.atleast_1d(knots) self._knots = knots self._params = values self._method = method From 031e840556bf568fc9985598518e03348ac424a4 Mon Sep 17 00:00:00 2001 From: unalmis Date: Sun, 18 Aug 2024 16:19:14 -0400 Subject: [PATCH 8/9] Change HermiteSplineAPI to ensure 2N parameters are optimizable --- desc/equilibrium/equilibrium.py | 4 +-- desc/profiles.py | 63 +++++++++++++++++++-------------- tests/test_profiles.py | 13 +++++++ 3 files changed, 51 insertions(+), 29 deletions(-) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index c58cd620b9..8d09d5f64b 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -755,8 +755,8 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs): f = grid.compress(data[name]) p = SplineProfile(f, knots, name=name) else: - f, dfdr = map(grid.compress, (data[name[0]], data[name[1]])) - p = HermiteSplineProfile(knots, f, dfdr, name=name) + f, df = map(grid.compress, (data[name[0]], data[name[1]])) + p = HermiteSplineProfile(f, df, knots, name=name) if kind == "power_series": p = p.to_powerseries(order=min(self.L, grid.num_rho), xs=knots, sym=True) if kind == "fourier_zernike": diff --git a/desc/profiles.py b/desc/profiles.py index 524a5ab727..69dff6ca56 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -794,16 +794,16 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): class SplineProfile(_Profile): - """Profile represented by a piecewise cubic spline. + """Radial profile represented by a piecewise cubic spline. Parameters ---------- values: array-like - Array containing values of the dependent variable. + 1-D Array containing values of the dependent variable. knots : array-like 1-D array containing values of the independent variable. - Values must be real, finite, and in strictly increasing order in [0, 1]. - If not given, assumes values is uniformly spaced in [0, 1]. + Must be real, finite, and in strictly increasing order in [0, 1]. + If not given, assumes ``values`` is uniformly spaced in [0, 1]. method : str Method of interpolation. Default is cubic2. - `'nearest'`: nearest neighbor interpolation @@ -827,6 +827,8 @@ def __init__(self, values=None, knots=None, method="cubic2", name=""): if knots is None: knots = jnp.linspace(0, 1, values.size) knots = jnp.atleast_1d(knots) + errorif(values.shape[-1] != knots.shape[-1]) + errorif(not (values.ndim == knots.ndim == 1), NotImplementedError) self._knots = knots self._params = values self._method = method @@ -835,7 +837,7 @@ def __repr__(self): """Get the string form of the object.""" s = super().__repr__() s = s[:-1] - s += ", method={}, num_knots={})".format(self._method, len(self._knots)) + s += ", method={}, num_knots={})".format(self._method, self._knots.size) return s @property @@ -851,9 +853,9 @@ def params(self): @params.setter def params(self, new): errorif( - len(new) != len(self._knots), + len(new) != self._knots.size, msg="params should have the same size as the knots, " - + f"got {len(new)} values for {len(self._knots)} knots", + + f"got {len(new)} values for {self._knots.size} knots", ) self._params = jnp.asarray(new) @@ -890,17 +892,18 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): class HermiteSplineProfile(_Profile): - """Profile represented by a piecewise cubic Hermite spline. + """Radial profile represented by a piecewise cubic Hermite spline. Parameters ---------- - r : array-like - 1-D array containing values of the independent variable. - Values must be real, finite, and in strictly increasing order in [0, 1]. f: array-like - Array containing values of the dependent variable. - dfdr: array-like - Array containing derivatives of the dependent variable. + 1-D Array containing values of the dependent variable. + df: array-like + 1-D Array containing derivatives of the dependent variable. + knots : array-like + 1-D array containing values of the independent variable. + Must be real, finite, and in strictly increasing order in [0, 1]. + If not given, assumes ``values`` is uniformly spaced in [0, 1]. name : str Optional name of the profile. @@ -908,17 +911,23 @@ class HermiteSplineProfile(_Profile): _io_attrs_ = _Profile._io_attrs_ + ["_knots", "_params"] - def __init__(self, r, f, dfdr, name=""): + def __init__(self, f, df, knots=None, name=""): super().__init__(name) - r, f, dfdr = jnp.atleast_1d(r, f, dfdr) - self._knots = r - self._params = jnp.stack([f, dfdr]) + + f, df = jnp.atleast_1d(f, df) + if knots is None: + knots = jnp.linspace(0, 1, f.size) + knots = jnp.atleast_1d(knots) + errorif(not (f.shape[-1] == df.shape[-1] == knots.shape[-1])) + errorif(not (f.ndim == df.ndim == knots.ndim == 1), NotImplementedError) + self._knots = knots + self._params = jnp.concatenate([f, df]) def __repr__(self): """Get the string form of the object.""" s = super().__repr__() s = s[:-1] - s += ", num_knots={})".format(len(self._knots)) + s += ", num_knots={})".format(self._knots.size) return s @property @@ -930,7 +939,7 @@ def knots(self): def params(self): """ndarray: Parameters for computation. - First (second) index stores function (derivative) values. + First (second) half stores function (derivative) values at ``knots``. """ return self._params @@ -938,9 +947,9 @@ def params(self): def params(self, new): new = jnp.asarray(new) errorif( - new.shape[-1] != self._knots.shape[-1], - msg="Params should have shape that broadcast with knots. " - f"Got {new.shape} params for {self._knots} knots.", + new.ndim != 1 or new.size != 2 * self._knots.size, + msg="Params should be 1D with size twice number of knots. " + f"Got {new.shape} params for {self._knots.size} knots.", ) self._params = new @@ -952,8 +961,8 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): grid : Grid locations to compute values at. params : array-like - First (second) index stores function (derivative) values - evaluated at knots. Defaults to ``self.params``. + First (second) half stores function (derivative) values at ``knots``. + Defaults to ``self.params``. dr, dt, dz : int derivative order in rho, theta, zeta @@ -969,8 +978,8 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): return interp1d( xq=grid.nodes[:, 0], x=self._knots, - f=params[0], - fx=params[1], + f=params[: self._knots.size], + fx=params[self._knots.size :], derivative=dr, extrap=True, ) diff --git a/tests/test_profiles.py b/tests/test_profiles.py index 530fd5f765..201996c1a7 100644 --- a/tests/test_profiles.py +++ b/tests/test_profiles.py @@ -6,6 +6,7 @@ from scipy.interpolate import interp1d from desc.equilibrium import Equilibrium +from desc.examples import get from desc.grid import LinearGrid from desc.io import InputReader from desc.objectives import ( @@ -15,6 +16,7 @@ ) from desc.profiles import ( FourierZernikeProfile, + HermiteSplineProfile, MTanhProfile, PowerSeriesProfile, SplineProfile, @@ -507,3 +509,14 @@ def test_kinetic_pressure(self): assert np.all(data2["Te_r"] == data2["Ti_r"]) np.testing.assert_allclose(data1["p"], data2["p"]) np.testing.assert_allclose(data1["p_r"], data2["p_r"]) + + @pytest.mark.unit + def test_hermite_spline_solve(self): + """Test that spline with double number of parameters is optimized.""" + eq = get("DSHAPE") + rho = np.linspace(0, 1.0, 20, endpoint=True) + eq.pressure = HermiteSplineProfile( + eq.pressure(rho), eq.pressure(rho, dr=1), rho + ) + eq.solve() + assert eq.is_nested() From c49f31c4c514cfe0a17ee9ce5a623d40da7e0e72 Mon Sep 17 00:00:00 2001 From: unalmis Date: Mon, 19 Aug 2024 17:21:28 -0400 Subject: [PATCH 9/9] Improve docstrings of spline profiles --- desc/compute/utils.py | 2 +- desc/profiles.py | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 92c41a000f..f6c7b12e68 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -216,7 +216,7 @@ def get_data_deps(keys, obj, has_axis=False, basis="rpz", data=None): Returns ------- - deps : list of str + deps : list[str] Names of quantities needed to compute key. """ diff --git a/desc/profiles.py b/desc/profiles.py index 69dff6ca56..064871e583 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -799,11 +799,11 @@ class SplineProfile(_Profile): Parameters ---------- values: array-like - 1-D Array containing values of the dependent variable. + 1-D array containing values of the dependent variable. knots : array-like 1-D array containing values of the independent variable. Must be real, finite, and in strictly increasing order in [0, 1]. - If not given, assumes ``values`` is uniformly spaced in [0, 1]. + If ``None``, assumes ``values`` is given on knots uniformly spaced in [0, 1]. method : str Method of interpolation. Default is cubic2. - `'nearest'`: nearest neighbor interpolation @@ -865,10 +865,10 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): Parameters ---------- grid : Grid - locations to compute values at. + Locations to compute values at. params : array-like - spline values to use. If not given, uses the - values given by the params attribute + Values of the function at ``self.knots``. + If not given, uses ``self.params``. dr, dt, dz : int derivative order in rho, theta, zeta @@ -897,13 +897,14 @@ class HermiteSplineProfile(_Profile): Parameters ---------- f: array-like - 1-D Array containing values of the dependent variable. + 1-D array containing values of the dependent variable. df: array-like - 1-D Array containing derivatives of the dependent variable. + 1-D array containing derivatives of the dependent variable. knots : array-like 1-D array containing values of the independent variable. Must be real, finite, and in strictly increasing order in [0, 1]. - If not given, assumes ``values`` is uniformly spaced in [0, 1]. + If ``None``, assumes ``f`` and ``df`` are given on knots uniformly + spaced in [0, 1]. name : str Optional name of the profile. @@ -959,10 +960,10 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0): Parameters ---------- grid : Grid - locations to compute values at. + Locations to compute values at. params : array-like First (second) half stores function (derivative) values at ``knots``. - Defaults to ``self.params``. + If not given, uses ``self.params``. dr, dt, dz : int derivative order in rho, theta, zeta