From 456a02bb80e99f8608baf4e48693854c010ab08f Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Fri, 17 May 2024 17:22:13 -0400 Subject: [PATCH 01/34] initial commit --- desc/objectives/__init__.py | 7 +- desc/objectives/_generic.py | 224 +++++++++++++++++++++++++++++- desc/objectives/objective_funs.py | 5 +- 3 files changed, 225 insertions(+), 11 deletions(-) diff --git a/desc/objectives/__init__.py b/desc/objectives/__init__.py index a99da8cfcf..222fb61e81 100644 --- a/desc/objectives/__init__.py +++ b/desc/objectives/__init__.py @@ -11,7 +11,12 @@ RadialForceBalance, ) from ._free_boundary import BoundaryError, VacuumBoundaryError -from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser +from ._generic import ( + ExternalObjective, + GenericObjective, + LinearObjectiveFromUser, + ObjectiveFromUser, +) from ._geometry import ( AspectRatio, BScaleLength, diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 23d4049876..1c666f219e 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -1,11 +1,12 @@ """Generic objectives that don't belong anywhere else.""" +import functools import inspect import re import numpy as np -from desc.backend import jnp +from desc.backend import jax, jnp from desc.compute import compute as compute_fun from desc.compute import data_index from desc.compute.utils import get_profiles, get_transforms @@ -14,6 +15,216 @@ from .linear_objectives import _FixedObjective from .objective_funs import _Objective +# Notes on deriv_mode: +# batched = Derivative class calls jax.jacfwd on ObjectiveFunction.compute +# looped = Derivative class calls jax.jvp on ObjectiveFunction.compute for each column +# blocked = block matrix of _Objective.jac for each objective in ObjectiveFunction + + +class ExternalObjective(_Objective): + """Wrap an external code. 
+ + Similar to ``ObjectiveFromUser``, except derivatives of the objective function are + computed with finite differences instead of AD. + + The user supplied function should take one positional argument ``params``, which is + a list of the same length as `things` and corresponds to `thing.params_dict` for + each thing in things. + + Parameters + ---------- + fun : callable + Custom objective function. + things : Optimizable or tuple/list of Optimizable + Objects that will be optimized to satisfy the Objective. + target : {float, ndarray}, optional + Target value(s) of the objective. Only used if bounds is None. + Must be broadcastable to Objective.dim_f. Defaults to ``target=0``. + bounds : tuple of {float, ndarray}, optional + Lower and upper bounds on the objective. Overrides target. + Both bounds must be broadcastable to to Objective.dim_f. + Defaults to ``target=0``. + weight : {float, ndarray}, optional + Weighting to apply to the Objective, relative to other Objectives. + Must be broadcastable to to Objective.dim_f + normalize : bool, optional + Whether to compute the error in physical units or non-dimensionalize. + Has no effect for this objective. + normalize_target : bool, optional + Whether target and bounds should be normalized before comparing to computed + values. If `normalize` is `True` and the target is in physical units, + this should also be set to True. + loss_function : {None, 'mean', 'min', 'max'}, optional + Loss function to apply to the objective values once computed. This loss function + is called on the raw compute value, before any shifting, scaling, or + normalization. + deriv_mode : {"auto", "fwd", "rev"} # TODO: edit this + Specify how to compute jacobian matrix, either forward mode or reverse mode AD. + "auto" selects forward or reverse mode based on the size of the input and output + of the objective. Has no effect on self.grad or self.hess which always use + reverse mode and forward over reverse mode respectively. 
+ name : str, optional + Name of the objective function. + + # TODO: add example + + """ + + _units = "(Unknown)" + _print_value_fmt = "External objective value: {:10.3e}" + + def __init__( + self, + fun, + things, + target=None, + bounds=None, + weight=1, + normalize=False, + normalize_target=False, + loss_function=None, + deriv_mode="auto", + fd_step=1e-4, # TODO: generalize this to allow a vector of different scales + name="external", # TODO: add kwargs to pass to external function + ): + if target is None and bounds is None: + target = 0 + self._fun = fun + self._fd_step = fd_step + super().__init__( + things=things, + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + loss_function=loss_function, + deriv_mode=deriv_mode, + name=name, + ) + + def build(self, use_jit=True, verbose=1): + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. + + """ + self._dim_f = 1 # FIXME: does this need to be a user input? + self._scalar = self._dim_f == 1 + self._constants = {"quad_weights": 1.0} + + self._args = self._fun.__code__.co_varnames[: self._fun.__code__.co_argcount] + + abstract_eval = lambda *args, **kwargs: jnp.array([1.0]) # FIXME: use dim_f? + self._fun_wrapped = self._jaxify(self._fun, abstract_eval) + + # TODO: test if jaxify is actually working + + super().build(use_jit=use_jit, verbose=verbose) + + def compute(self, params, constants=None): + """Compute the quantity. + + Parameters + ---------- + params : list of dict + List of dictionaries of degrees of freedom, eg CoilSet.params_dict + constants : dict + Dictionary of constant data, eg transforms, profiles etc. Defaults to + self.constants + + Returns + ------- + f : ndarray + Computed quantity. 
+ + """ + if constants is None: + constants = self.constants + + new_params = {k: params[k] for k in self._args} + f = self._fun_wrapped(**new_params) + return f + + def _jaxify(self, func, abstract_eval): + """Make an external (python) function work with JAX. + + Positional arguments to func can be differentiated, + use keyword args for static values and non-differentiable stuff. + + Note: Only forward mode differentiation is supported currently. + + Parameters + ---------- + func : callable + Function to wrap. Should be a "pure" function, in that it has no side + effects and doesn't maintain state. Does not need to be JAX transformable. + abstract_eval : callable + Auxilliary function that computes the output shape and dtype of func. + **Must be JAX transformable**. Should be of the form + + abstract_eval(*args, **kwargs) -> Pytree with same shape and dtype as + func(*args, **kwargs) + + For example, if func always returns a scalar: + + abstract_eval = lambda *args, **kwargs: jnp.array(1.) + + Or if func takes an array of shape(n) and returns a dict of arrays of + shape(n-2): + + abstract_eval = lambda arr, **kwargs: + {"out1": jnp.empty(arr.size-2), "out2": jnp.empty(arr.size-2)} + + Returns + ------- + func : callable + New function that behaves as func but works with jit/vmap/jacfwd etc. 
+ + """ + + def wrap_pure_callback(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result_shape_dtype = abstract_eval(*args, **kwargs) + return jax.pure_callback(func, result_shape_dtype, *args, **kwargs) + + return wrapper + + def define_fd_jvp(func): + func = jax.custom_jvp(func) + + @func.defjvp + def func_jvp(primals, tangents): + primal_out = func(*primals) + + # flatten everything into 1d vectors for easier finite differences + y, unflaty = jax.flatten_util.ravel_pytree(primal_out) + v, unflatv = jax.flatten_util.ravel_pytree( + *tangents + ) # remember that primals/tangets are passed as tuples + x, unflatx = jax.flatten_util.ravel_pytree(*primals) + normv = jnp.linalg.norm(v) + # scale to unit norm if its nonzero + vh = jnp.where(normv == 0, v, v / normv) + + def f(x): + return jax.flatten_util.ravel_pytree(func(unflatx(x)))[0] + + tangent_out = (f(x + self._fd_step * vh) - y) / self._fd_step * normv + tangent_out = unflaty(tangent_out) + + return primal_out, tangent_out + + return func + + return define_fd_jvp(wrap_pure_callback(func)) + class GenericObjective(_Objective): """A generic objective that can compute any quantity from the `data_index`. 
@@ -58,7 +269,7 @@ class GenericObjective(_Objective): """ - _print_value_fmt = "GenericObjective value: {:10.3e} " + _print_value_fmt = "Generic objective value: {:10.3e} " def __init__( self, @@ -207,7 +418,7 @@ class LinearObjectiveFromUser(_FixedObjective): _linear = True _fixed = True _units = "(Unknown)" - _print_value_fmt = "Custom linear Objective value: {:10.3e}" + _print_value_fmt = "Custom linear objective value: {:10.3e}" def __init__( self, @@ -345,10 +556,9 @@ class ObjectiveFromUser(_Objective): def myfun(grid, data): # This will compute the flux surface average of the function # R*B_T from the Grad-Shafranov equation - f = data['R']*data['B_phi'] + f = data['R'] * data['B_phi'] f_fsa = surface_averages(grid, f, sqrt_g=data['sqrt_g']) - # this has the FSA values on the full grid, but we just want - # the unique values: + # this is the FSA on the full grid, but we only want the unique values: return grid.compress(f_fsa) myobj = ObjectiveFromUser(myfun) @@ -356,7 +566,7 @@ def myfun(grid, data): """ _units = "(Unknown)" - _print_value_fmt = "Custom Objective value: {:10.3e}" + _print_value_fmt = "Custom objective value: {:10.3e}" def __init__( self, diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 0a98bb0a28..1e5f627a09 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -58,6 +58,7 @@ def __init__( def _set_derivatives(self): """Set up derivatives of the objective functions.""" + # TODO: does deriv_mode have to be "blocked" if there is an ExternalObjective? 
if self._deriv_mode == "auto": if all((obj._deriv_mode == "fwd") for obj in self.objectives): self._deriv_mode = "batched" @@ -90,9 +91,7 @@ def jac_(op, x, constants=None): for obj, const in zip(self.objectives, constants): # get the xs that go to that objective xi = [x for x, t in zip(xs, self.things) if t in obj.things] - Ji_ = getattr(obj, op)( - *xi, constants=const - ) # jac wrt to just those things + Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt only xi Ji = [] # jac wrt all things for thing in self.things: if thing in obj.things: From 6a557ec0e444b28fe3149ddd916813b665e01637 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Mon, 20 May 2024 16:19:31 -0400 Subject: [PATCH 02/34] get external objective working --- desc/objectives/_generic.py | 72 +++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 1c666f219e..b386c67931 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -11,6 +11,7 @@ from desc.compute import data_index from desc.compute.utils import get_profiles, get_transforms from desc.grid import QuadratureGrid +from desc.utils import errorif from .linear_objectives import _FixedObjective from .objective_funs import _Objective @@ -27,16 +28,18 @@ class ExternalObjective(_Objective): Similar to ``ObjectiveFromUser``, except derivatives of the objective function are computed with finite differences instead of AD. - The user supplied function should take one positional argument ``params``, which is - a list of the same length as `things` and corresponds to `thing.params_dict` for - each thing in things. + The user supplied function can take several positional arguments that should + correspond to parameter names from ``thing.optimizable_params``, in additional to + other (static) keyword arguments. Parameters ---------- fun : callable Custom objective function. 
- things : Optimizable or tuple/list of Optimizable - Objects that will be optimized to satisfy the Objective. + dim_f : int + Dimension of the output of fun. + thing : Optimizable + Object that will be optimized to satisfy the Objective. target : {float, ndarray}, optional Target value(s) of the objective. Only used if bounds is None. Must be broadcastable to Objective.dim_f. Defaults to ``target=0``. @@ -58,11 +61,6 @@ class ExternalObjective(_Objective): Loss function to apply to the objective values once computed. This loss function is called on the raw compute value, before any shifting, scaling, or normalization. - deriv_mode : {"auto", "fwd", "rev"} # TODO: edit this - Specify how to compute jacobian matrix, either forward mode or reverse mode AD. - "auto" selects forward or reverse mode based on the size of the input and output - of the objective. Has no effect on self.grad or self.hess which always use - reverse mode and forward over reverse mode respectively. name : str, optional Name of the objective function. @@ -76,30 +74,33 @@ class ExternalObjective(_Objective): def __init__( self, fun, - things, + dim_f, + thing, target=None, bounds=None, weight=1, normalize=False, normalize_target=False, loss_function=None, - deriv_mode="auto", fd_step=1e-4, # TODO: generalize this to allow a vector of different scales - name="external", # TODO: add kwargs to pass to external function + name="external", + **kwargs, ): if target is None and bounds is None: target = 0 self._fun = fun + self._dim_f = dim_f self._fd_step = fd_step + self._kwargs = kwargs super().__init__( - things=things, + things=thing, target=target, bounds=bounds, weight=weight, normalize=normalize, normalize_target=normalize_target, loss_function=loss_function, - deriv_mode=deriv_mode, + deriv_mode="fwd", name=name, ) @@ -114,17 +115,25 @@ def build(self, use_jit=True, verbose=1): Level of output. """ - self._dim_f = 1 # FIXME: does this need to be a user input? 
self._scalar = self._dim_f == 1 - self._constants = {"quad_weights": 1.0} - - self._args = self._fun.__code__.co_varnames[: self._fun.__code__.co_argcount] + self._constants = {"quad_weights": 1.0, "kwargs": self._kwargs} + + # positional arguments of the external function + kwargcount = len(self._fun.__defaults__) if self._fun.__defaults__ else 0 + self._args = self._fun.__code__.co_varnames[ + : self._fun.__code__.co_argcount - kwargcount + ] + errorif( + not set(self._args) <= set(self.things[0].optimizable_params), + ValueError, + "Positional arguments of `fun` must be a subset of " + + "`thing.optimizable_params`.", + ) - abstract_eval = lambda *args, **kwargs: jnp.array([1.0]) # FIXME: use dim_f? + # wrap external function to work with JAX + abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) self._fun_wrapped = self._jaxify(self._fun, abstract_eval) - # TODO: test if jaxify is actually working - super().build(use_jit=use_jit, verbose=verbose) def compute(self, params, constants=None): @@ -146,9 +155,12 @@ def compute(self, params, constants=None): """ if constants is None: constants = self.constants + kwargs = constants["kwargs"] + + # sort params into positional args for external function + args = [params[k] for k in self._args] - new_params = {k: params[k] for k in self._args} - f = self._fun_wrapped(**new_params) + f = self._fun_wrapped(*args, **kwargs) return f def _jaxify(self, func, abstract_eval): @@ -203,18 +215,16 @@ def define_fd_jvp(func): def func_jvp(primals, tangents): primal_out = func(*primals) - # flatten everything into 1d vectors for easier finite differences + # flatten everything into 1D vectors for easier finite differences y, unflaty = jax.flatten_util.ravel_pytree(primal_out) - v, unflatv = jax.flatten_util.ravel_pytree( - *tangents - ) # remember that primals/tangets are passed as tuples - x, unflatx = jax.flatten_util.ravel_pytree(*primals) + x, unflatx = jax.flatten_util.ravel_pytree(primals) + v, _______ = 
jax.flatten_util.ravel_pytree(tangents) + # scale to unit norm if nonzero normv = jnp.linalg.norm(v) - # scale to unit norm if its nonzero vh = jnp.where(normv == 0, v, v / normv) def f(x): - return jax.flatten_util.ravel_pytree(func(unflatx(x)))[0] + return jax.flatten_util.ravel_pytree(func(*unflatx(x)))[0] tangent_out = (f(x + self._fd_step * vh) - y) / self._fd_step * normv tangent_out = unflaty(tangent_out) From fc9ef77db40e011b0528d221de9f533084fd7d9f Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Mon, 20 May 2024 17:15:53 -0400 Subject: [PATCH 03/34] test comparison to generic --- desc/objectives/_generic.py | 6 +--- tests/test_examples.py | 63 +++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index b386c67931..f1679a4421 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -16,12 +16,8 @@ from .linear_objectives import _FixedObjective from .objective_funs import _Objective -# Notes on deriv_mode: -# batched = Derivative class calls jax.jacfwd on ObjectiveFunction.compute -# looped = Derivative class calls jax.jvp on ObjectiveFunction.compute for each column -# blocked = block matrix of _Objective.jac for each objective in ObjectiveFunction - +# TODO: add SPSA option class ExternalObjective(_Objective): """Wrap an external code. 
diff --git a/tests/test_examples.py b/tests/test_examples.py index adad7d7104..0e7b1d39dd 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -30,6 +30,7 @@ CoilLength, CoilTorsion, CurrentDensity, + ExternalObjective, FixBoundaryR, FixBoundaryZ, FixCurrent, @@ -1294,3 +1295,65 @@ def test_second_stage_optimization(): np.testing.assert_allclose(field[0].R0, 3.5) # this value was fixed np.testing.assert_allclose(field[0].B0, 1) # toroidal field (no change) np.testing.assert_allclose(field[1].B0, 0, atol=1e-12) # vertical field (vanishes) + + +@pytest.mark.unit +def test_external_vs_generic_objectives(): + """Test ExternalObjective compared to GenericObjective.""" + fname = "R" + target = [4.2, 5.3] + rho = np.array([0, 1]) + + # TODO: need to add profiles and spectral_indexing + + def extfun(Psi, R_lmn, Z_lmn, L_lmn, L=1, M=1, N=0, NFP=1, sym=False): + eq = Equilibrium( + Psi=float(Psi[0]), + R_lmn=R_lmn, + Z_lmn=Z_lmn, + L_lmn=L_lmn, + L=L, + M=M, + N=N, + NFP=NFP, + sym=sym, + spectral_indexing="fringe", + ) + grid = LinearGrid(rho=rho) + data = eq.compute(fname, grid=grid) + return np.atleast_1d(data[fname]) + + eq0 = get("SOLOVEV") + optimizer = Optimizer("lsq-exact") + grid = LinearGrid(rho=rho) + grid._weights = np.ones_like(grid.weights) + + # generic + objective = ObjectiveFunction( + GenericObjective("R", eq=eq0, target=target, grid=grid) + ) + constraints = FixParameters( + eq0, {"Psi": True, "Z_lmn": True, "L_lmn": True, "p_l": True, "c_l": True} + ) + [eq_generic], _ = optimizer.optimize( + things=eq0, + objective=objective, + constraints=constraints, + copy=True, + ftol=0, + verbose=2, + ) + + # external + kwargs = {"L": eq0.L, "M": eq0.M, "N": eq0.N, "NFP": eq0.NFP, "sym": eq0.sym} + objective = ObjectiveFunction( + ExternalObjective(extfun, len(target), eq0, target=target, **kwargs) + ) + constraints = FixParameters( + eq0, {"Psi": True, "Z_lmn": True, "L_lmn": True, "p_l": True, "c_l": True} + ) + [eq_external], _ = 
optimizer.optimize( + things=eq0, objective=objective, constraints=constraints, copy=True, verbose=2 + ) + + np.testing.assert_allclose(eq_generic.R_lmn, eq_external.R_lmn) From 9a64f253c7194188feb6fc7a5aec2c9ce7de0099 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 21 May 2024 14:34:45 -0400 Subject: [PATCH 04/34] allow string kwargs in external fun --- desc/objectives/_generic.py | 27 ++++++++++++++------------- tests/test_examples.py | 26 +++++++++++++++++++++----- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index f1679a4421..5815742661 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -24,9 +24,10 @@ class ExternalObjective(_Objective): Similar to ``ObjectiveFromUser``, except derivatives of the objective function are computed with finite differences instead of AD. - The user supplied function can take several positional arguments that should - correspond to parameter names from ``thing.optimizable_params``, in additional to - other (static) keyword arguments. + The user supplied function can take positional arguments that must correspond to + parameter names from ``thing.optimizable_params`` and expect JAX arrays as types. + The function can also take additional keyword arguments, which can be of any names + and data types since they are treated as static and not differentiable. 
Parameters ---------- @@ -112,7 +113,7 @@ def build(self, use_jit=True, verbose=1): """ self._scalar = self._dim_f == 1 - self._constants = {"quad_weights": 1.0, "kwargs": self._kwargs} + self._constants = {"quad_weights": 1.0} # positional arguments of the external function kwargcount = len(self._fun.__defaults__) if self._fun.__defaults__ else 0 @@ -126,9 +127,15 @@ def build(self, use_jit=True, verbose=1): + "`thing.optimizable_params`.", ) + # wrap keyword arguments, which may not be JAX compatable data types + no_kwargs_fun = lambda *args, **kwargs: self._fun(*args, **self._kwargs) + sig = inspect.signature(self._fun) + params = [p for p in sig.parameters.values() if p.default == p.empty] + no_kwargs_fun.__signature__ = sig.replace(parameters=params) + # wrap external function to work with JAX abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) - self._fun_wrapped = self._jaxify(self._fun, abstract_eval) + self._fun_wrapped = self._jaxify(no_kwargs_fun, abstract_eval) super().build(use_jit=use_jit, verbose=verbose) @@ -149,14 +156,8 @@ def compute(self, params, constants=None): Computed quantity. 
""" - if constants is None: - constants = self.constants - kwargs = constants["kwargs"] - - # sort params into positional args for external function - args = [params[k] for k in self._args] - - f = self._fun_wrapped(*args, **kwargs) + args = [params[k] for k in self._args] # sort positional args to external order + f = self._fun_wrapped(*args) return f def _jaxify(self, func, abstract_eval): diff --git a/tests/test_examples.py b/tests/test_examples.py index 0e7b1d39dd..f32ecf844f 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1304,9 +1304,18 @@ def test_external_vs_generic_objectives(): target = [4.2, 5.3] rho = np.array([0, 1]) - # TODO: need to add profiles and spectral_indexing - - def extfun(Psi, R_lmn, Z_lmn, L_lmn, L=1, M=1, N=0, NFP=1, sym=False): + def extfun( + Psi, + R_lmn, + Z_lmn, + L_lmn, + L=1, + M=1, + N=0, + NFP=1, + sym=False, + spectral_indexing="ansi", + ): eq = Equilibrium( Psi=float(Psi[0]), R_lmn=R_lmn, @@ -1317,7 +1326,7 @@ def extfun(Psi, R_lmn, Z_lmn, L_lmn, L=1, M=1, N=0, NFP=1, sym=False): N=N, NFP=NFP, sym=sym, - spectral_indexing="fringe", + spectral_indexing=spectral_indexing, ) grid = LinearGrid(rho=rho) data = eq.compute(fname, grid=grid) @@ -1345,7 +1354,14 @@ def extfun(Psi, R_lmn, Z_lmn, L_lmn, L=1, M=1, N=0, NFP=1, sym=False): ) # external - kwargs = {"L": eq0.L, "M": eq0.M, "N": eq0.N, "NFP": eq0.NFP, "sym": eq0.sym} + kwargs = { + "L": eq0.L, + "M": eq0.M, + "N": eq0.N, + "NFP": eq0.NFP, + "sym": eq0.sym, + "spectral_indexing": eq0.spectral_indexing, + } objective = ObjectiveFunction( ExternalObjective(extfun, len(target), eq0, target=target, **kwargs) ) From 11c143823a9cab9aea19e79f8d2281e4f8100cae Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 21 May 2024 15:00:49 -0400 Subject: [PATCH 05/34] exclude ExternalObjective from tests --- tests/test_objective_funs.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 
9dd4ec5fee..1fcbc95291 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -39,6 +39,7 @@ CoilTorsion, Elongation, Energy, + ExternalObjective, ForceBalance, ForceBalanceAnisotropic, GenericObjective, @@ -1754,6 +1755,7 @@ class TestComputeScalarResolution: # need to avoid blowup near the axis MercierStability, # don't test these since they depend on what user wants + ExternalObjective, LinearObjectiveFromUser, ObjectiveFromUser, ] @@ -2075,6 +2077,7 @@ class TestObjectiveNaNGrad: QuadraticFlux, ToroidalFlux, # we don't test these since they depend too much on what exactly the user wants + ExternalObjective, GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser, From aff7d46a9323c001521d111a61e0d05185b7cef3 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Thu, 23 May 2024 18:09:04 -0400 Subject: [PATCH 06/34] make external fun take eq as its argument --- desc/objectives/_generic.py | 72 +++++++++++++++------------ tests/test_examples.py | 98 +++++++++++++++++++------------------ 2 files changed, 92 insertions(+), 78 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 9bcbe85fa0..f39521a204 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -7,37 +7,32 @@ import numpy as np from desc.backend import jax, jnp -from desc.compute import compute as compute_fun from desc.compute import data_index from desc.compute.utils import _compute as compute_fun from desc.compute.utils import get_profiles, get_transforms from desc.grid import QuadratureGrid -from desc.utils import errorif from .linear_objectives import _FixedObjective from .objective_funs import _Objective -# TODO: add SPSA option class ExternalObjective(_Objective): """Wrap an external code. Similar to ``ObjectiveFromUser``, except derivatives of the objective function are computed with finite differences instead of AD. 
- The user supplied function can take positional arguments that must correspond to - parameter names from ``thing.optimizable_params`` and expect JAX arrays as types. - The function can also take additional keyword arguments, which can be of any names - and data types since they are treated as static and not differentiable. + The user supplied function must take an Equilibrium as its only positional argument, + but can take additional keyword arguments. Parameters ---------- + eq : Equilibrium + Equilibrium that will be optimized to satisfy the Objective. fun : callable Custom objective function. dim_f : int Dimension of the output of fun. - thing : Optimizable - Object that will be optimized to satisfy the Objective. target : {float, ndarray}, optional Target value(s) of the objective. Only used if bounds is None. Must be broadcastable to Objective.dim_f. Defaults to ``target=0``. @@ -71,9 +66,9 @@ class ExternalObjective(_Objective): def __init__( self, + eq, fun, dim_f, - thing, target=None, bounds=None, weight=1, @@ -86,12 +81,13 @@ def __init__( ): if target is None and bounds is None: target = 0 + self._eq = eq.copy() self._fun = fun self._dim_f = dim_f self._fd_step = fd_step self._kwargs = kwargs super().__init__( - things=thing, + things=eq, target=target, bounds=bounds, weight=weight, @@ -116,27 +112,42 @@ def build(self, use_jit=True, verbose=1): self._scalar = self._dim_f == 1 self._constants = {"quad_weights": 1.0} - # positional arguments of the external function - kwargcount = len(self._fun.__defaults__) if self._fun.__defaults__ else 0 - self._args = self._fun.__code__.co_varnames[ - : self._fun.__code__.co_argcount - kwargcount - ] - errorif( - not set(self._args) <= set(self.things[0].optimizable_params), - ValueError, - "Positional arguments of `fun` must be a subset of " - + "`thing.optimizable_params`.", - ) - - # wrap keyword arguments, which may not be JAX compatable data types - no_kwargs_fun = lambda *args, **kwargs: self._fun(*args, 
**self._kwargs) - sig = inspect.signature(self._fun) - params = [p for p in sig.parameters.values() if p.default == p.empty] - no_kwargs_fun.__signature__ = sig.replace(parameters=params) + def fun_wrapped( + R_lmn, + Z_lmn, + L_lmn, + p_l, + i_l, + c_l, + Psi, + Te_l, + ne_l, + Ti_l, + Zeff_l, + a_lmn, + Ra_n, + Za_n, + Rb_lmn, + Zb_lmn, + I, + G, + Phi_mn, + ): + """Wrap external function with optimiazable params arguments.""" + for param in self._eq.optimizable_params: + par = eval(param) # FIXME: how bad is it to use eval here? + if len(par): + setattr(self._eq, param, par) + return self._fun(self._eq, **self._kwargs) + + # check to make sure fun_wrapped has the correct signature + # in case we ever update Equilibrium.optimizable_params + args = inspect.getfullargspec(fun_wrapped).args + assert args == self._eq.optimizable_params # wrap external function to work with JAX abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) - self._fun_wrapped = self._jaxify(no_kwargs_fun, abstract_eval) + self._fun_wrapped = self._jaxify(fun_wrapped, abstract_eval) super().build(use_jit=use_jit, verbose=verbose) @@ -157,7 +168,8 @@ def compute(self, params, constants=None): Computed quantity. 
""" - args = [params[k] for k in self._args] # sort positional args to external order + # ensure positional args are passed in the correct order + args = [params[k] for k in self._eq.optimizable_params] f = self._fun_wrapped(*args) return f diff --git a/tests/test_examples.py b/tests/test_examples.py index 6d5eee4ecf..5ee219d2fb 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -6,6 +6,7 @@ import numpy as np import pytest +from netCDF4 import Dataset from qic import Qic from qsc import Qsc @@ -60,6 +61,7 @@ ) from desc.optimize import Optimizer from desc.profiles import FourierZernikeProfile, PowerSeriesProfile +from desc.vmec import VMECIO from desc.vmec_utils import vmec_boundary_subspace from .utils import area_difference_desc, area_difference_vmec @@ -1317,51 +1319,42 @@ def test_optimize_with_fourier_planar_coil(): @pytest.mark.unit -def test_external_vs_generic_objectives(): +def test_external_vs_generic_objectives(tmpdir_factory): """Test ExternalObjective compared to GenericObjective.""" - fname = "R" - target = [4.2, 5.3] - rho = np.array([0, 1]) - - def extfun( - Psi, - R_lmn, - Z_lmn, - L_lmn, - L=1, - M=1, - N=0, - NFP=1, - sym=False, - spectral_indexing="ansi", - ): - eq = Equilibrium( - Psi=float(Psi[0]), - R_lmn=R_lmn, - Z_lmn=Z_lmn, - L_lmn=L_lmn, - L=L, - M=M, - N=N, - NFP=NFP, - sym=sym, - spectral_indexing=spectral_indexing, - ) - grid = LinearGrid(rho=rho) - data = eq.compute(fname, grid=grid) - return np.atleast_1d(data[fname]) + target = np.array([6.2e-3, 1.1e-1, 6.5e-3, 0]) # values at p_l = [2e2, -2e2] + + def data_from_vmec(eq, path=""): + VMECIO.save(eq, path, surfs=8, verbose=0) + file = Dataset(path, mode="r") + betatot = float(file.variables["betatotal"][0]) + betapol = float(file.variables["betapol"][0]) + betator = float(file.variables["betator"][0]) + presf1 = float(file.variables["presf"][-1]) + file.close() + return np.atleast_1d([betatot, betapol, betator, presf1]) eq0 = get("SOLOVEV") optimizer = 
Optimizer("lsq-exact") - grid = LinearGrid(rho=rho) - grid._weights = np.ones_like(grid.weights) # generic objective = ObjectiveFunction( - GenericObjective("R", eq=eq0, target=target, grid=grid) + ( + GenericObjective("_vol", eq=eq0, target=target[0]), + GenericObjective("_vol", eq=eq0, target=target[1]), + GenericObjective("_vol", eq=eq0, target=target[2]), + GenericObjective("p", eq=eq0, target=0, grid=LinearGrid(rho=[1], M=0, N=0)), + ) ) constraints = FixParameters( - eq0, {"Psi": True, "Z_lmn": True, "L_lmn": True, "p_l": True, "c_l": True} + eq0, + { + "R_lmn": True, + "Z_lmn": True, + "L_lmn": True, + "p_l": np.arange(2, len(eq0.p_l)), + "i_l": True, + "Psi": True, + }, ) [eq_generic], _ = optimizer.optimize( things=eq0, @@ -1373,22 +1366,31 @@ def extfun( ) # external - kwargs = { - "L": eq0.L, - "M": eq0.M, - "N": eq0.N, - "NFP": eq0.NFP, - "sym": eq0.sym, - "spectral_indexing": eq0.spectral_indexing, - } + dir = tmpdir_factory.mktemp("results") + path = dir.join("wout_result.nc") objective = ObjectiveFunction( - ExternalObjective(extfun, len(target), eq0, target=target, **kwargs) + ExternalObjective(eq=eq0, fun=data_from_vmec, dim_f=4, target=target, path=path) ) constraints = FixParameters( - eq0, {"Psi": True, "Z_lmn": True, "L_lmn": True, "p_l": True, "c_l": True} + eq0, + { + "R_lmn": True, + "Z_lmn": True, + "L_lmn": True, + "p_l": np.arange(2, len(eq0.p_l)), + "i_l": True, + "Psi": True, + }, ) [eq_external], _ = optimizer.optimize( - things=eq0, objective=objective, constraints=constraints, copy=True, verbose=2 + things=eq0, + objective=objective, + constraints=constraints, + copy=True, + ftol=0, + verbose=2, ) - np.testing.assert_allclose(eq_generic.R_lmn, eq_external.R_lmn) + np.testing.assert_allclose(eq_generic.p_l, eq_external.p_l) + np.testing.assert_allclose(eq_generic.p_l[:2], [2e2, -2e2], rtol=4e-2) + np.testing.assert_allclose(eq_external.p_l[:2], [2e2, -2e2], rtol=4e-2) From 2bb90173c52f9093386c647a49b9aff7b7f14e72 Mon Sep 17 00:00:00 
2001 From: daniel-dudt Date: Fri, 31 May 2024 15:43:17 -0400 Subject: [PATCH 07/34] simplify wrapped fun to take params --- desc/objectives/_generic.py | 41 +++++++------------------------------ 1 file changed, 7 insertions(+), 34 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index f39521a204..00159e1ebc 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -112,39 +112,14 @@ def build(self, use_jit=True, verbose=1): self._scalar = self._dim_f == 1 self._constants = {"quad_weights": 1.0} - def fun_wrapped( - R_lmn, - Z_lmn, - L_lmn, - p_l, - i_l, - c_l, - Psi, - Te_l, - ne_l, - Ti_l, - Zeff_l, - a_lmn, - Ra_n, - Za_n, - Rb_lmn, - Zb_lmn, - I, - G, - Phi_mn, - ): - """Wrap external function with optimiazable params arguments.""" - for param in self._eq.optimizable_params: - par = eval(param) # FIXME: how bad is it to use eval here? - if len(par): - setattr(self._eq, param, par) + def fun_wrapped(params): + """Wrap external function with optimizable params arguments.""" + for param_key in self._eq.optimizable_params: + param_value = params[param_key] + if len(param_value): + setattr(self._eq, param_key, param_value) return self._fun(self._eq, **self._kwargs) - # check to make sure fun_wrapped has the correct signature - # in case we ever update Equilibrium.optimizable_params - args = inspect.getfullargspec(fun_wrapped).args - assert args == self._eq.optimizable_params - # wrap external function to work with JAX abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) self._fun_wrapped = self._jaxify(fun_wrapped, abstract_eval) @@ -168,9 +143,7 @@ def compute(self, params, constants=None): Computed quantity. 
""" - # ensure positional args are passed in the correct order - args = [params[k] for k in self._eq.optimizable_params] - f = self._fun_wrapped(*args) + f = self._fun_wrapped(params) return f def _jaxify(self, func, abstract_eval): From debecad3624a0f8e64edd763e7218bfc4b3323c5 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 4 Jun 2024 18:02:48 -0600 Subject: [PATCH 08/34] numpifying to make vectorization work --- desc/backend.py | 5 ++-- desc/basis.py | 10 ++++---- desc/equilibrium/equilibrium.py | 21 ++++++++++------- desc/geometry/core.py | 17 +++++++------- desc/geometry/curve.py | 36 ++++++++++++++--------------- desc/geometry/surface.py | 8 +++---- desc/objectives/_generic.py | 41 ++++++++++++++++++++++++++++----- desc/profiles.py | 24 +++++++++---------- tests/test_examples.py | 1 + 9 files changed, 99 insertions(+), 64 deletions(-) diff --git a/desc/backend.py b/desc/backend.py index 77cf6f090d..240db14afd 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -126,8 +126,9 @@ def sign(x): 1 where x>=0, -1 where x<0 """ - x = jnp.asarray(x) - y = jnp.where(x == 0, 1, jnp.sign(x)) + # FIXME: when this is jnp, Basis with sym is a JAX object for some reason + x = np.asarray(x) + y = np.where(x == 0, 1, np.sign(x)) return y @jit diff --git a/desc/basis.py b/desc/basis.py index 9ba700f499..2c8099e5b2 100644 --- a/desc/basis.py +++ b/desc/basis.py @@ -290,7 +290,7 @@ def evaluate( if modes is None: modes = self.modes if (derivatives[1] != 0) or (derivatives[2] != 0): - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -404,7 +404,7 @@ def evaluate( if modes is None: modes = self.modes if (derivatives[0] != 0) or (derivatives[1] != 0): - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -535,7 +535,7 @@ def evaluate( if 
modes is None: modes = self.modes if derivatives[0] != 0: - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -742,7 +742,7 @@ def evaluate( if modes is None: modes = self.modes if derivatives[2] != 0: - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -1241,7 +1241,7 @@ def evaluate( if modes is None: modes = self.modes if (derivatives[1] != 0) or (derivatives[2] != 0): - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index b32ae3b0fa..c213d0115d 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -10,7 +10,6 @@ from scipy.constants import mu_0 from termcolor import colored -from desc.backend import jnp from desc.basis import FourierZernikeBasis, fourier, zernike_radial from desc.compat import ensure_positive_jacobian from desc.compute import compute as compute_fun @@ -376,7 +375,7 @@ def __init__( assert ("R_lmn" in kwargs) and ("Z_lmn" in kwargs), "Must give both R and Z" self.R_lmn = kwargs.pop("R_lmn") self.Z_lmn = kwargs.pop("Z_lmn") - self.L_lmn = kwargs.pop("L_lmn", jnp.zeros(self.L_basis.num_modes)) + self.L_lmn = kwargs.pop("L_lmn", np.zeros(self.L_basis.num_modes)) else: self.set_initial_guess(ensure_nested=ensure_nested) if check_orientation: @@ -600,9 +599,15 @@ def change_resolution( ) self.axis.change_resolution(self.N, NFP=self.NFP, sym=self.sym) - self._R_lmn = copy_coeffs(self.R_lmn, old_modes_R, self.R_basis.modes) - self._Z_lmn = copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes) - self._L_lmn = copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes) + self._R_lmn = 
np.asarray( + copy_coeffs(self.R_lmn, old_modes_R, self.R_basis.modes) + ) + self._Z_lmn = np.asarray( + copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes) + ) + self._L_lmn = np.asarray( + copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes) + ) def get_surface_at(self, rho=None, theta=None, zeta=None): """Return a representation for a given coordinate surface. @@ -1265,7 +1270,7 @@ def R_lmn(self): @R_lmn.setter def R_lmn(self, R_lmn): - R_lmn = jnp.atleast_1d(jnp.asarray(R_lmn)) + R_lmn = np.atleast_1d(np.asarray(R_lmn)) errorif( R_lmn.size != self._R_lmn.size, ValueError, @@ -1282,7 +1287,7 @@ def Z_lmn(self): @Z_lmn.setter def Z_lmn(self, Z_lmn): - Z_lmn = jnp.atleast_1d(jnp.asarray(Z_lmn)) + Z_lmn = np.atleast_1d(np.asarray(Z_lmn)) errorif( Z_lmn.size != self._Z_lmn.size, ValueError, @@ -1299,7 +1304,7 @@ def L_lmn(self): @L_lmn.setter def L_lmn(self, L_lmn): - L_lmn = jnp.atleast_1d(jnp.asarray(L_lmn)) + L_lmn = np.atleast_1d(np.asarray(L_lmn)) errorif( L_lmn.size != self._L_lmn.size, ValueError, diff --git a/desc/geometry/core.py b/desc/geometry/core.py index a637ff529e..e2cb8417c3 100644 --- a/desc/geometry/core.py +++ b/desc/geometry/core.py @@ -5,7 +5,6 @@ import numpy as np -from desc.backend import jnp from desc.compute import compute as compute_fun from desc.compute import data_index from desc.compute.geom_utils import reflection_matrix, rotation_matrix @@ -26,8 +25,8 @@ class Curve(IOAble, Optimizable, ABC): _io_attrs_ = ["_name", "_shift", "_rotmat"] def __init__(self, name=""): - self._shift = jnp.array([0, 0, 0], dtype=float) - self._rotmat = jnp.eye(3, dtype=float).flatten() + self._shift = np.array([0, 0, 0], dtype=float) + self._rotmat = np.eye(3, dtype=float).flatten() self._name = name def _set_up(self): @@ -43,12 +42,12 @@ def _set_up(self): @property def shift(self): """Displacement of curve in X, Y, Z.""" - return self.__dict__.setdefault("_shift", jnp.array([0, 0, 0], dtype=float)) + return self.__dict__.setdefault("_shift", 
np.array([0, 0, 0], dtype=float)) @shift.setter def shift(self, new): if len(new) == 3: - self._shift = jnp.asarray(new) + self._shift = np.asarray(new) else: raise ValueError("shift should be a 3 element vector, got {}".format(new)) @@ -56,14 +55,14 @@ def shift(self, new): @property def rotmat(self): """Rotation matrix of curve in X, Y, Z.""" - return self.__dict__.setdefault("_rotmat", jnp.eye(3, dtype=float).flatten()) + return self.__dict__.setdefault("_rotmat", np.eye(3, dtype=float).flatten()) @rotmat.setter def rotmat(self, new): if len(new) == 9: - self._rotmat = jnp.asarray(new) + self._rotmat = np.asarray(new) else: - self._rotmat = jnp.asarray(new.flatten()) + self._rotmat = np.asarray(new.flatten()) @property def name(self): @@ -179,7 +178,7 @@ def compute( def translate(self, displacement=[0, 0, 0]): """Translate the curve by a rigid displacement in X,Y,Z coordinates.""" - self.shift = self.shift + jnp.asarray(displacement) + self.shift = self.shift + np.asarray(displacement) def rotate(self, axis=[0, 0, 1], angle=0): """Rotate the curve by a fixed angle about axis in X,Y,Z coordinates.""" diff --git a/desc/geometry/curve.py b/desc/geometry/curve.py index 3f670656e7..bd087dda61 100644 --- a/desc/geometry/curve.py +++ b/desc/geometry/curve.py @@ -2,7 +2,7 @@ import numpy as np -from desc.backend import jnp, put +from desc.backend import put from desc.basis import FourierSeries from desc.compute import rpz2xyz, xyz2rpz from desc.grid import LinearGrid @@ -86,11 +86,11 @@ def __init__( NZ = np.max(abs(modes_Z)) N = max(NR, NZ) self._NFP = check_posint(NFP, "NFP", False) - self._R_basis = FourierSeries(N, int(NFP), sym="cos" if sym else False) - self._Z_basis = FourierSeries(N, int(NFP), sym="sin" if sym else False) + self._R_basis = FourierSeries(int(N), int(NFP), sym="cos" if sym else False) + self._Z_basis = FourierSeries(int(N), int(NFP), sym="sin" if sym else False) - self._R_n = copy_coeffs(R_n, modes_R, self.R_basis.modes[:, 2]) - self._Z_n = 
copy_coeffs(Z_n, modes_Z, self.Z_basis.modes[:, 2]) + self._R_n = np.array(copy_coeffs(R_n, modes_R, self.R_basis.modes[:, 2])) + self._Z_n = np.array(copy_coeffs(Z_n, modes_Z, self.Z_basis.modes[:, 2])) @property def sym(self): @@ -138,8 +138,8 @@ def change_resolution(self, N=None, NFP=None, sym=None): self.Z_basis.change_resolution( N=N, NFP=self.NFP, sym="sin" if self.sym else self.sym ) - self.R_n = copy_coeffs(self.R_n, R_modes_old, self.R_basis.modes) - self.Z_n = copy_coeffs(self.Z_n, Z_modes_old, self.Z_basis.modes) + self.R_n = np.array(copy_coeffs(self.R_n, R_modes_old, self.R_basis.modes)) + self.Z_n = np.array(copy_coeffs(self.Z_n, Z_modes_old, self.Z_basis.modes)) def get_coeffs(self, n): """Get Fourier coefficients for given mode number(s).""" @@ -176,7 +176,7 @@ def R_n(self): @R_n.setter def R_n(self, new): if len(new) == self.R_basis.num_modes: - self._R_n = jnp.asarray(new) + self._R_n = np.asarray(new) else: raise ValueError( f"R_n should have the same size as the basis, got {len(new)} for " @@ -192,7 +192,7 @@ def Z_n(self): @Z_n.setter def Z_n(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_n = jnp.asarray(new) + self._Z_n = np.asarray(new) else: raise ValueError( f"Z_n should have the same size as the basis, got {len(new)} for " @@ -439,7 +439,7 @@ def X_n(self): @X_n.setter def X_n(self, new): if len(new) == self.X_basis.num_modes: - self._X_n = jnp.asarray(new) + self._X_n = np.asarray(new) else: raise ValueError( f"X_n should have the same size as the basis, got {len(new)} for " @@ -455,7 +455,7 @@ def Y_n(self): @Y_n.setter def Y_n(self, new): if len(new) == self.Y_basis.num_modes: - self._Y_n = jnp.asarray(new) + self._Y_n = np.asarray(new) else: raise ValueError( f"Y_n should have the same size as the basis, got {len(new)} for " @@ -471,7 +471,7 @@ def Z_n(self): @Z_n.setter def Z_n(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_n = jnp.asarray(new) + self._Z_n = np.asarray(new) else: raise ValueError( f"Z_n 
should have the same size as the basis, got {len(new)} for " @@ -663,7 +663,7 @@ def r_n(self): @r_n.setter def r_n(self, new): if len(np.asarray(new)) == self.r_basis.num_modes: - self._r_n = jnp.asarray(new) + self._r_n = np.asarray(new) else: raise ValueError( f"r_n should have the same size as the basis, got {len(new)} for " @@ -829,7 +829,7 @@ def X(self): @X.setter def X(self, new): if len(new) == len(self.knots): - self._X = jnp.asarray(new) + self._X = np.asarray(new) else: raise ValueError( "X should have the same size as the knots, " @@ -845,7 +845,7 @@ def Y(self): @Y.setter def Y(self, new): if len(new) == len(self.knots): - self._Y = jnp.asarray(new) + self._Y = np.asarray(new) else: raise ValueError( "Y should have the same size as the knots, " @@ -861,7 +861,7 @@ def Z(self): @Z.setter def Z(self, new): if len(new) == len(self.knots): - self._Z = jnp.asarray(new) + self._Z = np.asarray(new) else: raise ValueError( "Z should have the same size as the knots, " @@ -876,7 +876,7 @@ def knots(self): @knots.setter def knots(self, new): if len(new) == len(self.knots): - knots = jnp.atleast_1d(jnp.asarray(new)) + knots = np.atleast_1d(np.asarray(new)) errorif( not np.all(np.diff(knots) > 0), ValueError, @@ -884,7 +884,7 @@ def knots(self, new): ) errorif(knots[0] < 0, ValueError, "knots must lie in [0, 2pi]") errorif(knots[-1] > 2 * np.pi, ValueError, "knots must lie in [0, 2pi]") - self._knots = jnp.asarray(knots) + self._knots = np.asarray(knots) else: raise ValueError( "new knots should have the same size as the current knots, " diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 3a4cf0c136..0bbd750a93 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -225,7 +225,7 @@ def R_lmn(self): @R_lmn.setter def R_lmn(self, new): if len(new) == self.R_basis.num_modes: - self._R_lmn = jnp.asarray(new) + self._R_lmn = np.atleast_1d(np.asarray(new)) else: raise ValueError( f"R_lmn should have the same size as the basis, got 
{len(new)} for " @@ -241,7 +241,7 @@ def Z_lmn(self): @Z_lmn.setter def Z_lmn(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_lmn = jnp.asarray(new) + self._Z_lmn = np.atleast_1d(np.asarray(new)) else: raise ValueError( f"Z_lmn should have the same size as the basis, got {len(new)} for " @@ -963,7 +963,7 @@ def R_lmn(self): @R_lmn.setter def R_lmn(self, new): if len(new) == self.R_basis.num_modes: - self._R_lmn = jnp.asarray(new) + self._R_lmn = np.atleast_1d(np.asarray(new)) else: raise ValueError( f"R_lmn should have the same size as the basis, got {len(new)} for " @@ -979,7 +979,7 @@ def Z_lmn(self): @Z_lmn.setter def Z_lmn(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_lmn = jnp.asarray(new) + self._Z_lmn = np.atleast_1d(np.asarray(new)) else: raise ValueError( f"Z_lmn should have the same size as the basis, got {len(new)} for " diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 00159e1ebc..f140781148 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -2,7 +2,9 @@ import functools import inspect +import multiprocessing import re +import warnings import numpy as np @@ -76,6 +78,7 @@ def __init__( normalize_target=False, loss_function=None, fd_step=1e-4, # TODO: generalize this to allow a vector of different scales + vectorized=False, name="external", **kwargs, ): @@ -85,6 +88,7 @@ def __init__( self._fun = fun self._dim_f = dim_f self._fd_step = fd_step + self._vectorized = vectorized self._kwargs = kwargs super().__init__( things=eq, @@ -114,11 +118,30 @@ def build(self, use_jit=True, verbose=1): def fun_wrapped(params): """Wrap external function with optimizable params arguments.""" - for param_key in self._eq.optimizable_params: - param_value = params[param_key] - if len(param_value): - setattr(self._eq, param_key, param_value) - return self._fun(self._eq, **self._kwargs) + param_shape = params["Psi"].shape + num_eq = param_shape[0] if len(param_shape) > 1 else 1 + if 
self._vectorized and num_eq > 1: + # convert params to list of Equilibria + eqs = [self._eq.copy() for _ in range(num_eq)] + for k, eq in enumerate(eqs): + for param_key in self._eq.optimizable_params: + param_value = np.array(params[param_key][k, :]) + if len(param_value): + setattr(eq, param_key, param_value) + # parallelize calls to external function + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with multiprocessing.Pool() as pool: + results = pool.map( + functools.partial(self._fun, **self._kwargs), eqs + ) + return jnp.vstack(results) + else: + # update Equilibrium with params + for param_key in self._eq.optimizable_params: + param_value = params[param_key] + if len(param_value): + setattr(self._eq, param_key, param_value) + return self._fun(self._eq, **self._kwargs) # wrap external function to work with JAX abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) @@ -187,7 +210,13 @@ def wrap_pure_callback(func): @functools.wraps(func) def wrapper(*args, **kwargs): result_shape_dtype = abstract_eval(*args, **kwargs) - return jax.pure_callback(func, result_shape_dtype, *args, **kwargs) + return jax.pure_callback( + func, + result_shape_dtype, + *args, + vectorized=self._vectorized, + **kwargs, + ) return wrapper diff --git a/desc/profiles.py b/desc/profiles.py index 2190359501..4223f8283f 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -277,7 +277,7 @@ def __init__(self, scale, profile, **kwargs): @property def params(self): """ndarray: Parameters for computation [scale, profile.params].""" - return jnp.concatenate([jnp.atleast_1d(self._scale), self._profile.params]) + return np.concatenate([np.atleast_1d(self._scale), self._profile.params]) @params.setter def params(self, x): @@ -364,7 +364,7 @@ def __init__(self, *profiles, **kwargs): @property def params(self): """ndarray: Concatenated array of parameters for computation.""" - return jnp.concatenate([profile.params for profile in self._profiles]) + return 
np.concatenate([profile.params for profile in self._profiles]) @params.setter def params(self, x): @@ -451,7 +451,7 @@ def __init__(self, *profiles, **kwargs): @property def params(self): """ndarray: Concatenated array of parameters for computation.""" - return jnp.concatenate([profile.params for profile in self._profiles]) + return np.concatenate([profile.params for profile in self._profiles]) @params.setter def params(self, x): @@ -591,9 +591,9 @@ def params(self): @params.setter def params(self, new): - new = jnp.atleast_1d(jnp.asarray(new)) + new = np.atleast_1d(np.asarray(new)) if new.size == self._basis.num_modes: - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError( "params should have the same size as the basis, " @@ -745,9 +745,9 @@ def params(self): @params.setter def params(self, new): - new = jnp.atleast_1d(jnp.asarray(new)) + new = np.atleast_1d(np.asarray(new)) if new.size == 3: - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError(f"params should be an array of size 3, got {len(new)}.") @@ -849,7 +849,7 @@ def params(self): @params.setter def params(self, new): if len(new) == len(self._knots): - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError( "params should have the same size as the knots, " @@ -932,9 +932,9 @@ def params(self): @params.setter def params(self, new): - new = jnp.atleast_1d(jnp.asarray(new)) + new = np.atleast_1d(np.asarray(new)) if new.size >= 5: - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError( "params should have at least 5 elements [ped, offset, sym, width," @@ -1199,9 +1199,9 @@ def params(self): @params.setter def params(self, new): - new = jnp.atleast_1d(jnp.asarray(new)) + new = np.atleast_1d(np.asarray(new)) if new.size == self._basis.num_modes: - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError( f"params should have the same size as 
the basis, got {new.size} " diff --git a/tests/test_examples.py b/tests/test_examples.py index 5ee219d2fb..23b5c7e899 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1319,6 +1319,7 @@ def test_optimize_with_fourier_planar_coil(): @pytest.mark.unit +@pytest.mark.slow def test_external_vs_generic_objectives(tmpdir_factory): """Test ExternalObjective compared to GenericObjective.""" target = np.array([6.2e-3, 1.1e-1, 6.5e-3, 0]) # values at p_l = [2e2, -2e2] From 9ea37fbe5c47e0b88e14d854fca836b7f591f010 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 4 Jun 2024 22:52:48 -0600 Subject: [PATCH 09/34] Revert "numpifying to make vectorization work" This reverts commit debecad3624a0f8e64edd763e7218bfc4b3323c5. --- desc/backend.py | 5 ++-- desc/basis.py | 10 ++++---- desc/equilibrium/equilibrium.py | 21 +++++++---------- desc/geometry/core.py | 17 +++++++------- desc/geometry/curve.py | 36 ++++++++++++++--------------- desc/geometry/surface.py | 8 +++---- desc/objectives/_generic.py | 41 +++++---------------------------- desc/profiles.py | 24 +++++++++---------- tests/test_examples.py | 1 - 9 files changed, 64 insertions(+), 99 deletions(-) diff --git a/desc/backend.py b/desc/backend.py index 240db14afd..77cf6f090d 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -126,9 +126,8 @@ def sign(x): 1 where x>=0, -1 where x<0 """ - # FIXME: when this is jnp, Basis with sym is a JAX object for some reason - x = np.asarray(x) - y = np.where(x == 0, 1, np.sign(x)) + x = jnp.asarray(x) + y = jnp.where(x == 0, 1, jnp.sign(x)) return y @jit diff --git a/desc/basis.py b/desc/basis.py index 2c8099e5b2..9ba700f499 100644 --- a/desc/basis.py +++ b/desc/basis.py @@ -290,7 +290,7 @@ def evaluate( if modes is None: modes = self.modes if (derivatives[1] != 0) or (derivatives[2] != 0): - return np.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -404,7 
+404,7 @@ def evaluate( if modes is None: modes = self.modes if (derivatives[0] != 0) or (derivatives[1] != 0): - return np.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -535,7 +535,7 @@ def evaluate( if modes is None: modes = self.modes if derivatives[0] != 0: - return np.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -742,7 +742,7 @@ def evaluate( if modes is None: modes = self.modes if derivatives[2] != 0: - return np.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -1241,7 +1241,7 @@ def evaluate( if modes is None: modes = self.modes if (derivatives[1] != 0) or (derivatives[2] != 0): - return np.zeros((nodes.shape[0], modes.shape[0])) + return jnp.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index c213d0115d..b32ae3b0fa 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -10,6 +10,7 @@ from scipy.constants import mu_0 from termcolor import colored +from desc.backend import jnp from desc.basis import FourierZernikeBasis, fourier, zernike_radial from desc.compat import ensure_positive_jacobian from desc.compute import compute as compute_fun @@ -375,7 +376,7 @@ def __init__( assert ("R_lmn" in kwargs) and ("Z_lmn" in kwargs), "Must give both R and Z" self.R_lmn = kwargs.pop("R_lmn") self.Z_lmn = kwargs.pop("Z_lmn") - self.L_lmn = kwargs.pop("L_lmn", np.zeros(self.L_basis.num_modes)) + self.L_lmn = kwargs.pop("L_lmn", jnp.zeros(self.L_basis.num_modes)) else: self.set_initial_guess(ensure_nested=ensure_nested) if check_orientation: @@ -599,15 +600,9 @@ def 
change_resolution( ) self.axis.change_resolution(self.N, NFP=self.NFP, sym=self.sym) - self._R_lmn = np.asarray( - copy_coeffs(self.R_lmn, old_modes_R, self.R_basis.modes) - ) - self._Z_lmn = np.asarray( - copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes) - ) - self._L_lmn = np.asarray( - copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes) - ) + self._R_lmn = copy_coeffs(self.R_lmn, old_modes_R, self.R_basis.modes) + self._Z_lmn = copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes) + self._L_lmn = copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes) def get_surface_at(self, rho=None, theta=None, zeta=None): """Return a representation for a given coordinate surface. @@ -1270,7 +1265,7 @@ def R_lmn(self): @R_lmn.setter def R_lmn(self, R_lmn): - R_lmn = np.atleast_1d(np.asarray(R_lmn)) + R_lmn = jnp.atleast_1d(jnp.asarray(R_lmn)) errorif( R_lmn.size != self._R_lmn.size, ValueError, @@ -1287,7 +1282,7 @@ def Z_lmn(self): @Z_lmn.setter def Z_lmn(self, Z_lmn): - Z_lmn = np.atleast_1d(np.asarray(Z_lmn)) + Z_lmn = jnp.atleast_1d(jnp.asarray(Z_lmn)) errorif( Z_lmn.size != self._Z_lmn.size, ValueError, @@ -1304,7 +1299,7 @@ def L_lmn(self): @L_lmn.setter def L_lmn(self, L_lmn): - L_lmn = np.atleast_1d(np.asarray(L_lmn)) + L_lmn = jnp.atleast_1d(jnp.asarray(L_lmn)) errorif( L_lmn.size != self._L_lmn.size, ValueError, diff --git a/desc/geometry/core.py b/desc/geometry/core.py index e2cb8417c3..a637ff529e 100644 --- a/desc/geometry/core.py +++ b/desc/geometry/core.py @@ -5,6 +5,7 @@ import numpy as np +from desc.backend import jnp from desc.compute import compute as compute_fun from desc.compute import data_index from desc.compute.geom_utils import reflection_matrix, rotation_matrix @@ -25,8 +26,8 @@ class Curve(IOAble, Optimizable, ABC): _io_attrs_ = ["_name", "_shift", "_rotmat"] def __init__(self, name=""): - self._shift = np.array([0, 0, 0], dtype=float) - self._rotmat = np.eye(3, dtype=float).flatten() + self._shift = jnp.array([0, 0, 0], dtype=float) 
+ self._rotmat = jnp.eye(3, dtype=float).flatten() self._name = name def _set_up(self): @@ -42,12 +43,12 @@ def _set_up(self): @property def shift(self): """Displacement of curve in X, Y, Z.""" - return self.__dict__.setdefault("_shift", np.array([0, 0, 0], dtype=float)) + return self.__dict__.setdefault("_shift", jnp.array([0, 0, 0], dtype=float)) @shift.setter def shift(self, new): if len(new) == 3: - self._shift = np.asarray(new) + self._shift = jnp.asarray(new) else: raise ValueError("shift should be a 3 element vector, got {}".format(new)) @@ -55,14 +56,14 @@ def shift(self, new): @property def rotmat(self): """Rotation matrix of curve in X, Y, Z.""" - return self.__dict__.setdefault("_rotmat", np.eye(3, dtype=float).flatten()) + return self.__dict__.setdefault("_rotmat", jnp.eye(3, dtype=float).flatten()) @rotmat.setter def rotmat(self, new): if len(new) == 9: - self._rotmat = np.asarray(new) + self._rotmat = jnp.asarray(new) else: - self._rotmat = np.asarray(new.flatten()) + self._rotmat = jnp.asarray(new.flatten()) @property def name(self): @@ -178,7 +179,7 @@ def compute( def translate(self, displacement=[0, 0, 0]): """Translate the curve by a rigid displacement in X,Y,Z coordinates.""" - self.shift = self.shift + np.asarray(displacement) + self.shift = self.shift + jnp.asarray(displacement) def rotate(self, axis=[0, 0, 1], angle=0): """Rotate the curve by a fixed angle about axis in X,Y,Z coordinates.""" diff --git a/desc/geometry/curve.py b/desc/geometry/curve.py index bd087dda61..3f670656e7 100644 --- a/desc/geometry/curve.py +++ b/desc/geometry/curve.py @@ -2,7 +2,7 @@ import numpy as np -from desc.backend import put +from desc.backend import jnp, put from desc.basis import FourierSeries from desc.compute import rpz2xyz, xyz2rpz from desc.grid import LinearGrid @@ -86,11 +86,11 @@ def __init__( NZ = np.max(abs(modes_Z)) N = max(NR, NZ) self._NFP = check_posint(NFP, "NFP", False) - self._R_basis = FourierSeries(int(N), int(NFP), sym="cos" if sym else 
False) - self._Z_basis = FourierSeries(int(N), int(NFP), sym="sin" if sym else False) + self._R_basis = FourierSeries(N, int(NFP), sym="cos" if sym else False) + self._Z_basis = FourierSeries(N, int(NFP), sym="sin" if sym else False) - self._R_n = np.array(copy_coeffs(R_n, modes_R, self.R_basis.modes[:, 2])) - self._Z_n = np.array(copy_coeffs(Z_n, modes_Z, self.Z_basis.modes[:, 2])) + self._R_n = copy_coeffs(R_n, modes_R, self.R_basis.modes[:, 2]) + self._Z_n = copy_coeffs(Z_n, modes_Z, self.Z_basis.modes[:, 2]) @property def sym(self): @@ -138,8 +138,8 @@ def change_resolution(self, N=None, NFP=None, sym=None): self.Z_basis.change_resolution( N=N, NFP=self.NFP, sym="sin" if self.sym else self.sym ) - self.R_n = np.array(copy_coeffs(self.R_n, R_modes_old, self.R_basis.modes)) - self.Z_n = np.array(copy_coeffs(self.Z_n, Z_modes_old, self.Z_basis.modes)) + self.R_n = copy_coeffs(self.R_n, R_modes_old, self.R_basis.modes) + self.Z_n = copy_coeffs(self.Z_n, Z_modes_old, self.Z_basis.modes) def get_coeffs(self, n): """Get Fourier coefficients for given mode number(s).""" @@ -176,7 +176,7 @@ def R_n(self): @R_n.setter def R_n(self, new): if len(new) == self.R_basis.num_modes: - self._R_n = np.asarray(new) + self._R_n = jnp.asarray(new) else: raise ValueError( f"R_n should have the same size as the basis, got {len(new)} for " @@ -192,7 +192,7 @@ def Z_n(self): @Z_n.setter def Z_n(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_n = np.asarray(new) + self._Z_n = jnp.asarray(new) else: raise ValueError( f"Z_n should have the same size as the basis, got {len(new)} for " @@ -439,7 +439,7 @@ def X_n(self): @X_n.setter def X_n(self, new): if len(new) == self.X_basis.num_modes: - self._X_n = np.asarray(new) + self._X_n = jnp.asarray(new) else: raise ValueError( f"X_n should have the same size as the basis, got {len(new)} for " @@ -455,7 +455,7 @@ def Y_n(self): @Y_n.setter def Y_n(self, new): if len(new) == self.Y_basis.num_modes: - self._Y_n = np.asarray(new) + 
self._Y_n = jnp.asarray(new) else: raise ValueError( f"Y_n should have the same size as the basis, got {len(new)} for " @@ -471,7 +471,7 @@ def Z_n(self): @Z_n.setter def Z_n(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_n = np.asarray(new) + self._Z_n = jnp.asarray(new) else: raise ValueError( f"Z_n should have the same size as the basis, got {len(new)} for " @@ -663,7 +663,7 @@ def r_n(self): @r_n.setter def r_n(self, new): if len(np.asarray(new)) == self.r_basis.num_modes: - self._r_n = np.asarray(new) + self._r_n = jnp.asarray(new) else: raise ValueError( f"r_n should have the same size as the basis, got {len(new)} for " @@ -829,7 +829,7 @@ def X(self): @X.setter def X(self, new): if len(new) == len(self.knots): - self._X = np.asarray(new) + self._X = jnp.asarray(new) else: raise ValueError( "X should have the same size as the knots, " @@ -845,7 +845,7 @@ def Y(self): @Y.setter def Y(self, new): if len(new) == len(self.knots): - self._Y = np.asarray(new) + self._Y = jnp.asarray(new) else: raise ValueError( "Y should have the same size as the knots, " @@ -861,7 +861,7 @@ def Z(self): @Z.setter def Z(self, new): if len(new) == len(self.knots): - self._Z = np.asarray(new) + self._Z = jnp.asarray(new) else: raise ValueError( "Z should have the same size as the knots, " @@ -876,7 +876,7 @@ def knots(self): @knots.setter def knots(self, new): if len(new) == len(self.knots): - knots = np.atleast_1d(np.asarray(new)) + knots = jnp.atleast_1d(jnp.asarray(new)) errorif( not np.all(np.diff(knots) > 0), ValueError, @@ -884,7 +884,7 @@ def knots(self, new): ) errorif(knots[0] < 0, ValueError, "knots must lie in [0, 2pi]") errorif(knots[-1] > 2 * np.pi, ValueError, "knots must lie in [0, 2pi]") - self._knots = np.asarray(knots) + self._knots = jnp.asarray(knots) else: raise ValueError( "new knots should have the same size as the current knots, " diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 0bbd750a93..3a4cf0c136 100644 --- 
a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -225,7 +225,7 @@ def R_lmn(self): @R_lmn.setter def R_lmn(self, new): if len(new) == self.R_basis.num_modes: - self._R_lmn = np.atleast_1d(np.asarray(new)) + self._R_lmn = jnp.asarray(new) else: raise ValueError( f"R_lmn should have the same size as the basis, got {len(new)} for " @@ -241,7 +241,7 @@ def Z_lmn(self): @Z_lmn.setter def Z_lmn(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_lmn = np.atleast_1d(np.asarray(new)) + self._Z_lmn = jnp.asarray(new) else: raise ValueError( f"Z_lmn should have the same size as the basis, got {len(new)} for " @@ -963,7 +963,7 @@ def R_lmn(self): @R_lmn.setter def R_lmn(self, new): if len(new) == self.R_basis.num_modes: - self._R_lmn = np.atleast_1d(np.asarray(new)) + self._R_lmn = jnp.asarray(new) else: raise ValueError( f"R_lmn should have the same size as the basis, got {len(new)} for " @@ -979,7 +979,7 @@ def Z_lmn(self): @Z_lmn.setter def Z_lmn(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_lmn = np.atleast_1d(np.asarray(new)) + self._Z_lmn = jnp.asarray(new) else: raise ValueError( f"Z_lmn should have the same size as the basis, got {len(new)} for " diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index f140781148..00159e1ebc 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -2,9 +2,7 @@ import functools import inspect -import multiprocessing import re -import warnings import numpy as np @@ -78,7 +76,6 @@ def __init__( normalize_target=False, loss_function=None, fd_step=1e-4, # TODO: generalize this to allow a vector of different scales - vectorized=False, name="external", **kwargs, ): @@ -88,7 +85,6 @@ def __init__( self._fun = fun self._dim_f = dim_f self._fd_step = fd_step - self._vectorized = vectorized self._kwargs = kwargs super().__init__( things=eq, @@ -118,30 +114,11 @@ def build(self, use_jit=True, verbose=1): def fun_wrapped(params): """Wrap external function with optimizable 
params arguments.""" - param_shape = params["Psi"].shape - num_eq = param_shape[0] if len(param_shape) > 1 else 1 - if self._vectorized and num_eq > 1: - # convert params to list of Equilibria - eqs = [self._eq.copy() for _ in range(num_eq)] - for k, eq in enumerate(eqs): - for param_key in self._eq.optimizable_params: - param_value = np.array(params[param_key][k, :]) - if len(param_value): - setattr(eq, param_key, param_value) - # parallelize calls to external function - with warnings.catch_warnings(action="ignore", category=RuntimeWarning): - with multiprocessing.Pool() as pool: - results = pool.map( - functools.partial(self._fun, **self._kwargs), eqs - ) - return jnp.vstack(results) - else: - # update Equilibrium with params - for param_key in self._eq.optimizable_params: - param_value = params[param_key] - if len(param_value): - setattr(self._eq, param_key, param_value) - return self._fun(self._eq, **self._kwargs) + for param_key in self._eq.optimizable_params: + param_value = params[param_key] + if len(param_value): + setattr(self._eq, param_key, param_value) + return self._fun(self._eq, **self._kwargs) # wrap external function to work with JAX abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) @@ -210,13 +187,7 @@ def wrap_pure_callback(func): @functools.wraps(func) def wrapper(*args, **kwargs): result_shape_dtype = abstract_eval(*args, **kwargs) - return jax.pure_callback( - func, - result_shape_dtype, - *args, - vectorized=self._vectorized, - **kwargs, - ) + return jax.pure_callback(func, result_shape_dtype, *args, **kwargs) return wrapper diff --git a/desc/profiles.py b/desc/profiles.py index 1aaf1aae91..359ffefb42 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -277,7 +277,7 @@ def __init__(self, scale, profile, **kwargs): @property def params(self): """ndarray: Parameters for computation [scale, profile.params].""" - return np.concatenate([np.atleast_1d(self._scale), self._profile.params]) + return 
jnp.concatenate([jnp.atleast_1d(self._scale), self._profile.params]) @params.setter def params(self, x): @@ -364,7 +364,7 @@ def __init__(self, *profiles, **kwargs): @property def params(self): """ndarray: Concatenated array of parameters for computation.""" - return np.concatenate([profile.params for profile in self._profiles]) + return jnp.concatenate([profile.params for profile in self._profiles]) @params.setter def params(self, x): @@ -451,7 +451,7 @@ def __init__(self, *profiles, **kwargs): @property def params(self): """ndarray: Concatenated array of parameters for computation.""" - return np.concatenate([profile.params for profile in self._profiles]) + return jnp.concatenate([profile.params for profile in self._profiles]) @params.setter def params(self, x): @@ -591,9 +591,9 @@ def params(self): @params.setter def params(self, new): - new = np.atleast_1d(np.asarray(new)) + new = jnp.atleast_1d(jnp.asarray(new)) if new.size == self._basis.num_modes: - self._params = np.asarray(new) + self._params = jnp.asarray(new) else: raise ValueError( "params should have the same size as the basis, " @@ -745,9 +745,9 @@ def params(self): @params.setter def params(self, new): - new = np.atleast_1d(np.asarray(new)) + new = jnp.atleast_1d(jnp.asarray(new)) if new.size == 3: - self._params = np.asarray(new) + self._params = jnp.asarray(new) else: raise ValueError(f"params should be an array of size 3, got {len(new)}.") @@ -849,7 +849,7 @@ def params(self): @params.setter def params(self, new): if len(new) == len(self._knots): - self._params = np.asarray(new) + self._params = jnp.asarray(new) else: raise ValueError( "params should have the same size as the knots, " @@ -932,9 +932,9 @@ def params(self): @params.setter def params(self, new): - new = np.atleast_1d(np.asarray(new)) + new = jnp.atleast_1d(jnp.asarray(new)) if new.size >= 5: - self._params = np.asarray(new) + self._params = jnp.asarray(new) else: raise ValueError( "params should have at least 5 elements [ped, offset, 
sym, width," @@ -1199,9 +1199,9 @@ def params(self): @params.setter def params(self, new): - new = np.atleast_1d(np.asarray(new)) + new = jnp.atleast_1d(jnp.asarray(new)) if new.size == self._basis.num_modes: - self._params = np.asarray(new) + self._params = jnp.asarray(new) else: raise ValueError( f"params should have the same size as the basis, got {new.size} " diff --git a/tests/test_examples.py b/tests/test_examples.py index 23b5c7e899..5ee219d2fb 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1319,7 +1319,6 @@ def test_optimize_with_fourier_planar_coil(): @pytest.mark.unit -@pytest.mark.slow def test_external_vs_generic_objectives(tmpdir_factory): """Test ExternalObjective compared to GenericObjective.""" target = np.array([6.2e-3, 1.1e-1, 6.5e-3, 0]) # values at p_l = [2e2, -2e2] From 87ab19fb05b99c08f4cbe59128852b474f55efab Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Fri, 7 Jun 2024 11:28:05 -0600 Subject: [PATCH 10/34] vectorization working! --- desc/backend.py | 38 +++++++++++++++-------- desc/objectives/__init__.py | 7 +---- desc/objectives/_generic.py | 60 ++++++++++++++++++++++++++++++------- 3 files changed, 76 insertions(+), 29 deletions(-) diff --git a/desc/backend.py b/desc/backend.py index 77cf6f090d..901ab28253 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -1,5 +1,6 @@ """Backend functions for DESC, with options for JAX or regular numpy.""" +import multiprocessing import os import warnings @@ -10,15 +11,23 @@ from desc import config as desc_config from desc import set_device +verbose = True + +# set child processes to use numpy backend and suppress print statements +if not multiprocessing.current_process().name == "MainProcess": + os.environ["DESC_BACKEND"] = "numpy" + verbose = False + if os.environ.get("DESC_BACKEND") == "numpy": jnp = np use_jax = False set_device(kind="cpu") - print( - "DESC version {}, using numpy backend, version={}, dtype={}".format( - desc.__version__, np.__version__, np.linspace(0, 
1).dtype + if verbose: + print( + "DESC version {}, using numpy backend, version={}, dtype={}".format( + desc.__version__, np.__version__, np.linspace(0, 1).dtype + ) ) - ) else: if desc_config.get("device") is None: set_device("cpu") @@ -40,11 +49,12 @@ x = jnp.linspace(0, 5) y = jnp.exp(x) use_jax = True - print( - f"DESC version {desc.__version__}," - + f"using JAX backend, jax version={jax.__version__}, " - + f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}" - ) + if verbose: + print( + f"DESC version {desc.__version__}, " + + f"using JAX backend, jax version={jax.__version__}, " + + f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}" + ) del x, y except ModuleNotFoundError: jnp = np @@ -58,11 +68,13 @@ desc.__version__, np.__version__, y.dtype ) ) -print( - "Using device: {}, with {:.2f} GB available memory".format( - desc_config.get("device"), desc_config.get("avail_mem") + +if verbose: + print( + "Using device: {}, with {:.2f} GB available memory".format( + desc_config.get("device"), desc_config.get("avail_mem") + ) ) -) if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign? 
jit = jax.jit diff --git a/desc/objectives/__init__.py b/desc/objectives/__init__.py index ce35e1dcfc..29d875b7ed 100644 --- a/desc/objectives/__init__.py +++ b/desc/objectives/__init__.py @@ -18,12 +18,7 @@ RadialForceBalance, ) from ._free_boundary import BoundaryError, VacuumBoundaryError -from ._generic import ( - ExternalObjective, - GenericObjective, - LinearObjectiveFromUser, - ObjectiveFromUser, -) +from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser from ._geometry import ( AspectRatio, BScaleLength, diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 00159e1ebc..0114a0d0b8 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -2,11 +2,13 @@ import functools import inspect +import multiprocessing import re +from abc import ABC import numpy as np -from desc.backend import jax, jnp +from desc.backend import jnp from desc.compute import data_index from desc.compute.utils import _compute as compute_fun from desc.compute.utils import get_profiles, get_transforms @@ -16,7 +18,7 @@ from .objective_funs import _Objective -class ExternalObjective(_Objective): +class _ExternalObjective(_Objective, ABC): """Wrap an external code. Similar to ``ObjectiveFromUser``, except derivatives of the objective function are @@ -25,6 +27,8 @@ class ExternalObjective(_Objective): The user supplied function must take an Equilibrium as its only positional argument, but can take additional keyword arguments. 
+ # TODO: add Parameters documentation + Parameters ---------- eq : Equilibrium @@ -76,6 +80,7 @@ def __init__( normalize_target=False, loss_function=None, fd_step=1e-4, # TODO: generalize this to allow a vector of different scales + vectorized=False, name="external", **kwargs, ): @@ -85,7 +90,13 @@ def __init__( self._fun = fun self._dim_f = dim_f self._fd_step = fd_step + self._vectorized = vectorized self._kwargs = kwargs + if self._vectorized: + try: # spawn a new environment so the backend can be set to numpy + multiprocessing.set_start_method("spawn") + except RuntimeError: # context can only be set once + pass super().__init__( things=eq, target=target, @@ -113,12 +124,33 @@ def build(self, use_jit=True, verbose=1): self._constants = {"quad_weights": 1.0} def fun_wrapped(params): - """Wrap external function with optimizable params arguments.""" - for param_key in self._eq.optimizable_params: - param_value = params[param_key] - if len(param_value): - setattr(self._eq, param_key, param_value) - return self._fun(self._eq, **self._kwargs) + """Wrap external function with possibly vectorized params.""" + # number of equilibria for vectorized computations + param_shape = params["Psi"].shape + num_eq = param_shape[0] if len(param_shape) > 1 else 1 + + if self._vectorized and num_eq > 1: + # convert params to list of equilibria + eqs = [self._eq.copy() for _ in range(num_eq)] + for k, eq in enumerate(eqs): + # update equilibria with new params + for param_key in self._eq.optimizable_params: + param_value = np.array(params[param_key][k, :]) + if len(param_value): + setattr(eq, param_key, param_value) + # parallelize calls to external function + with multiprocessing.Pool(processes=num_eq) as pool: + results = pool.map( + functools.partial(self._fun, **self._kwargs), eqs + ) + return jnp.vstack(results, dtype=float) + else: # no vectorization + # update equilibrium with new params + for param_key in self._eq.optimizable_params: + param_value = params[param_key] + if 
len(param_value): + setattr(self._eq, param_key, param_value) + return self._fun(self._eq, **self._kwargs) # wrap external function to work with JAX abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) @@ -182,12 +214,19 @@ def _jaxify(self, func, abstract_eval): New function that behaves as func but works with jit/vmap/jacfwd etc. """ + import jax def wrap_pure_callback(func): @functools.wraps(func) def wrapper(*args, **kwargs): result_shape_dtype = abstract_eval(*args, **kwargs) - return jax.pure_callback(func, result_shape_dtype, *args, **kwargs) + return jax.pure_callback( + func, + result_shape_dtype, + *args, + vectorized=self._vectorized, + **kwargs, + ) return wrapper @@ -602,6 +641,8 @@ def build(self, use_jit=True, verbose=1): Level of output. """ + import jax + eq = self.things[0] if self._grid is None: grid = QuadratureGrid(eq.L_grid, eq.M_grid, eq.N_grid, eq.NFP) @@ -628,7 +669,6 @@ def get_vars(fun): ).squeeze() self._fun_wrapped = lambda data: self._fun(grid, data) - import jax self._dim_f = jax.eval_shape(self._fun_wrapped, dummy_data).size self._scalar = self._dim_f == 1 From 539561182246cde61dabc94bbd9a285b43aae299 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 11 Jun 2024 14:05:52 -0600 Subject: [PATCH 11/34] allow vectorized to be an int --- desc/objectives/_generic.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 0114a0d0b8..937cd2d88d 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -2,7 +2,8 @@ import functools import inspect -import multiprocessing +import multiprocessing as mp +import os import re from abc import ABC @@ -80,10 +81,11 @@ def __init__( normalize_target=False, loss_function=None, fd_step=1e-4, # TODO: generalize this to allow a vector of different scales - vectorized=False, + vectorized=False, # False or int name="external", **kwargs, ): + assert isinstance(vectorized, bool) or 
isinstance(vectorized, int) if target is None and bounds is None: target = 0 self._eq = eq.copy() @@ -94,7 +96,7 @@ def __init__( self._kwargs = kwargs if self._vectorized: try: # spawn a new environment so the backend can be set to numpy - multiprocessing.set_start_method("spawn") + mp.set_start_method("spawn") except RuntimeError: # context can only be set once pass super().__init__( @@ -139,10 +141,17 @@ def fun_wrapped(params): if len(param_value): setattr(eq, param_key, param_value) # parallelize calls to external function - with multiprocessing.Pool(processes=num_eq) as pool: + max_processes = ( + self._vectorized + if isinstance(self._vectorized, int) + else os.cpu_count() + ) + with mp.Pool(processes=min(max_processes, num_eq)) as pool: results = pool.map( functools.partial(self._fun, **self._kwargs), eqs ) + pool.close() + pool.join() return jnp.vstack(results, dtype=float) else: # no vectorization # update equilibrium with new params @@ -224,7 +233,7 @@ def wrapper(*args, **kwargs): func, result_shape_dtype, *args, - vectorized=self._vectorized, + vectorized=bool(self._vectorized), **kwargs, ) From 52d58d05e51e133094bee9bbd9a39f22ac64c713 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 11 Jun 2024 17:01:30 -0600 Subject: [PATCH 12/34] fix numpy cond --- desc/backend.py | 8 ++++---- desc/compute/utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/desc/backend.py b/desc/backend.py index 901ab28253..1ad2a90904 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -501,7 +501,7 @@ def fori_loop(lower, upper, body_fun, init_val): val = body_fun(i, val) return val - def cond(pred, true_fun, false_fun, *operand): + def cond(pred, true_fun, false_fun, *operands): """Conditionally apply true_fun or false_fun. This version is for the numpy backend, for jax backend see jax.lax.cond @@ -514,7 +514,7 @@ def cond(pred, true_fun, false_fun, *operand): Function (A -> B), to be applied if pred is True. 
false_fun: callable Function (A -> B), to be applied if pred is False. - operand: any + operands: any input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof. @@ -527,9 +527,9 @@ def cond(pred, true_fun, false_fun, *operand): """ if pred: - return true_fun(*operand) + return true_fun(*operands) else: - return false_fun(*operand) + return false_fun(*operands) def switch(index, branches, operand): """Apply exactly one of branches given by index. diff --git a/desc/compute/utils.py b/desc/compute/utils.py index b65b9365a9..6cd6e41a0d 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -927,7 +927,7 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14) has_endpoint_dupe, lambda _: put(mask, jnp.array([0, -1]), mask[0] | mask[-1]), lambda _: mask, - operand=None, + None, ) else: expand_out = False From 30aeea4602b9df3a3fac4c63d1f0b55147212023 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Mon, 17 Jun 2024 18:04:10 -0400 Subject: [PATCH 13/34] merging but no change? 
--- desc/objectives/_generic.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 937cd2d88d..40280c685b 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -125,6 +125,14 @@ def build(self, use_jit=True, verbose=1): self._scalar = self._dim_f == 1 self._constants = {"quad_weights": 1.0} + if self._vectorized: + max_processes = ( + self._vectorized + if isinstance(self._vectorized, int) + else os.cpu_count() + ) + self._pool = mp.Pool(processes=max_processes) + def fun_wrapped(params): """Wrap external function with possibly vectorized params.""" # number of equilibria for vectorized computations From 90296ea83fd4a8fbefba8badfd175fd8c4157b88 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Mon, 17 Jun 2024 18:20:49 -0400 Subject: [PATCH 14/34] update test with new UI --- desc/objectives/__init__.py | 7 ++++++- tests/test_examples.py | 35 +++++++++++++++++++++++++++++++++-- tests/test_objective_funs.py | 3 --- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/desc/objectives/__init__.py b/desc/objectives/__init__.py index 29d875b7ed..05a44a63d4 100644 --- a/desc/objectives/__init__.py +++ b/desc/objectives/__init__.py @@ -18,7 +18,12 @@ RadialForceBalance, ) from ._free_boundary import BoundaryError, VacuumBoundaryError -from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser +from ._generic import ( + GenericObjective, + LinearObjectiveFromUser, + ObjectiveFromUser, + _ExternalObjective, +) from ._geometry import ( AspectRatio, BScaleLength, diff --git a/tests/test_examples.py b/tests/test_examples.py index 5ee219d2fb..b523250282 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -31,7 +31,6 @@ CoilLength, CoilTorsion, CurrentDensity, - ExternalObjective, FixBoundaryR, FixBoundaryZ, FixCurrent, @@ -56,6 +55,7 @@ QuasisymmetryTwoTerm, VacuumBoundaryError, Volume, + _ExternalObjective, 
get_fixed_boundary_constraints, get_NAE_constraints, ) @@ -1319,6 +1319,7 @@ def test_optimize_with_fourier_planar_coil(): @pytest.mark.unit +@pytest.mark.slow def test_external_vs_generic_objectives(tmpdir_factory): """Test ExternalObjective compared to GenericObjective.""" target = np.array([6.2e-3, 1.1e-1, 6.5e-3, 0]) # values at p_l = [2e2, -2e2] @@ -1333,6 +1334,36 @@ def data_from_vmec(eq, path=""): file.close() return np.atleast_1d([betatot, betapol, betator, presf1]) + class TestExternalObjective(_ExternalObjective): + + def __init__( + self, + eq, + target=None, + bounds=None, + weight=1, + normalize=False, + normalize_target=False, + loss_function=None, + path="", + name="external", + ): + super().__init__( + eq=eq, + fun=data_from_vmec, + dim_f=4, + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + loss_function=loss_function, + fd_step=1e-4, + vectorized=False, + name=name, + path=path, + ) + eq0 = get("SOLOVEV") optimizer = Optimizer("lsq-exact") @@ -1369,7 +1400,7 @@ def data_from_vmec(eq, path=""): dir = tmpdir_factory.mktemp("results") path = dir.join("wout_result.nc") objective = ObjectiveFunction( - ExternalObjective(eq=eq0, fun=data_from_vmec, dim_f=4, target=target, path=path) + TestExternalObjective(eq=eq0, target=target, path=path) ) constraints = FixParameters( eq0, diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 1fcbc95291..9dd4ec5fee 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -39,7 +39,6 @@ CoilTorsion, Elongation, Energy, - ExternalObjective, ForceBalance, ForceBalanceAnisotropic, GenericObjective, @@ -1755,7 +1754,6 @@ class TestComputeScalarResolution: # need to avoid blowup near the axis MercierStability, # don't test these since they depend on what user wants - ExternalObjective, LinearObjectiveFromUser, ObjectiveFromUser, ] @@ -2077,7 +2075,6 @@ class TestObjectiveNaNGrad: QuadraticFlux, ToroidalFlux, # we 
don't test these since they depend too much on what exactly the user wants - ExternalObjective, GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser, From d16e95da3ad0be57ed9aa3c86df79d2f6bb24ccf Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 18 Jun 2024 16:57:04 -0400 Subject: [PATCH 15/34] remove unused pool code --- desc/objectives/_generic.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 40280c685b..937cd2d88d 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -125,14 +125,6 @@ def build(self, use_jit=True, verbose=1): self._scalar = self._dim_f == 1 self._constants = {"quad_weights": 1.0} - if self._vectorized: - max_processes = ( - self._vectorized - if isinstance(self._vectorized, int) - else os.cpu_count() - ) - self._pool = mp.Pool(processes=max_processes) - def fun_wrapped(params): """Wrap external function with possibly vectorized params.""" # number of equilibria for vectorized computations From f9b7562febe29694556d111538557b212b9d0afc Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Wed, 17 Jul 2024 12:42:33 -0600 Subject: [PATCH 16/34] remove comment note --- desc/objectives/objective_funs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 1e5f627a09..848dec17bf 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -58,7 +58,6 @@ def __init__( def _set_derivatives(self): """Set up derivatives of the objective functions.""" - # TODO: does deriv_mode have to be "blocked" if there is an ExternalObjective? 
if self._deriv_mode == "auto": if all((obj._deriv_mode == "fwd") for obj in self.objectives): self._deriv_mode = "batched" From f1f466b0600eabee8add2dc99bf0ca4b35c5ddea Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Thu, 18 Jul 2024 11:02:26 -0600 Subject: [PATCH 17/34] fix black formatting from merge conflict --- tests/test_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index 4e07438e6b..f6fe24c7f4 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -72,7 +72,6 @@ from desc.optimize import Optimizer from desc.profiles import FourierZernikeProfile, PowerSeriesProfile from desc.vmec import VMECIO -from desc.vmec_utils import vmec_boundary_subspace from .utils import area_difference_desc, area_difference_vmec @@ -1551,6 +1550,7 @@ def circle_constraint(params): rtol=2e-2, ) + @pytest.mark.unit @pytest.mark.slow def test_external_vs_generic_objectives(tmpdir_factory): From ecc5b3bed7e164bfedb387be62449d4e6d06dde7 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Thu, 18 Jul 2024 11:04:57 -0600 Subject: [PATCH 18/34] repair test from merge conflict --- tests/test_examples.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index f6fe24c7f4..c719699c69 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1603,10 +1603,12 @@ def __init__( # generic objective = ObjectiveFunction( ( - GenericObjective("_vol", eq=eq0, target=target[0]), - GenericObjective("_vol", eq=eq0, target=target[1]), - GenericObjective("_vol", eq=eq0, target=target[2]), - GenericObjective("p", eq=eq0, target=0, grid=LinearGrid(rho=[1], M=0, N=0)), + GenericObjective("_vol", thing=eq0, target=target[0]), + GenericObjective("_vol", thing=eq0, target=target[1]), + GenericObjective("_vol", thing=eq0, target=target[2]), + GenericObjective( + "p", thing=eq0, target=0, grid=LinearGrid(rho=[1], M=0, N=0) + ), ) ) constraints 
= FixParameters( From bf62014c4fa54b318d62fe751f7f1ae87a0a7718 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Thu, 18 Jul 2024 13:03:24 -0600 Subject: [PATCH 19/34] remove multiprocessing from ExternalObjective class --- desc/backend.py | 37 +++++++++++------------------- desc/objectives/_generic.py | 45 +++++++++---------------------------- 2 files changed, 24 insertions(+), 58 deletions(-) diff --git a/desc/backend.py b/desc/backend.py index 1ad2a90904..1ac07a729c 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -1,6 +1,5 @@ """Backend functions for DESC, with options for JAX or regular numpy.""" -import multiprocessing import os import warnings @@ -11,23 +10,15 @@ from desc import config as desc_config from desc import set_device -verbose = True - -# set child processes to use numpy backend and suppress print statements -if not multiprocessing.current_process().name == "MainProcess": - os.environ["DESC_BACKEND"] = "numpy" - verbose = False - if os.environ.get("DESC_BACKEND") == "numpy": jnp = np use_jax = False set_device(kind="cpu") - if verbose: - print( - "DESC version {}, using numpy backend, version={}, dtype={}".format( - desc.__version__, np.__version__, np.linspace(0, 1).dtype - ) + print( + "DESC version {}, using numpy backend, version={}, dtype={}".format( + desc.__version__, np.__version__, np.linspace(0, 1).dtype ) + ) else: if desc_config.get("device") is None: set_device("cpu") @@ -49,12 +40,11 @@ x = jnp.linspace(0, 5) y = jnp.exp(x) use_jax = True - if verbose: - print( - f"DESC version {desc.__version__}, " - + f"using JAX backend, jax version={jax.__version__}, " - + f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}" - ) + print( + f"DESC version {desc.__version__}, " + + f"using JAX backend, jax version={jax.__version__}, " + + f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}" + ) del x, y except ModuleNotFoundError: jnp = np @@ -69,12 +59,11 @@ ) ) -if verbose: - print( - "Using device: {}, with {:.2f} GB available 
memory".format( - desc_config.get("device"), desc_config.get("avail_mem") - ) +print( + "Using device: {}, with {:.2f} GB available memory".format( + desc_config.get("device"), desc_config.get("avail_mem") ) +) if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign? jit = jax.jit diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 1cc16fef0c..814d0f3631 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -2,8 +2,6 @@ import functools import inspect -import multiprocessing as mp -import os import re from abc import ABC @@ -96,11 +94,6 @@ def __init__( self._fd_step = fd_step self._vectorized = vectorized self._kwargs = kwargs - if self._vectorized: - try: # spawn a new environment so the backend can be set to numpy - mp.set_start_method("spawn") - except RuntimeError: # context can only be set once - pass super().__init__( things=eq, target=target, @@ -133,35 +126,19 @@ def fun_wrapped(params): param_shape = params["Psi"].shape num_eq = param_shape[0] if len(param_shape) > 1 else 1 - if self._vectorized and num_eq > 1: - # convert params to list of equilibria - eqs = [self._eq.copy() for _ in range(num_eq)] - for k, eq in enumerate(eqs): - # update equilibria with new params - for param_key in self._eq.optimizable_params: - param_value = np.array(params[param_key][k, :]) - if len(param_value): - setattr(eq, param_key, param_value) - # parallelize calls to external function - max_processes = ( - self._vectorized - if isinstance(self._vectorized, int) - else os.cpu_count() - ) - with mp.Pool(processes=min(max_processes, num_eq)) as pool: - results = pool.map( - functools.partial(self._fun, **self._kwargs), eqs - ) - pool.close() - pool.join() - return jnp.vstack(results, dtype=float) - else: # no vectorization - # update equilibrium with new params + # convert params to list of equilibria + eqs = [self._eq.copy() for _ in range(num_eq)] + for k, eq in enumerate(eqs): + # update equilibria 
with new params for param_key in self._eq.optimizable_params: - param_value = params[param_key] + param_value = np.atleast_2d(params[param_key])[k, :] if len(param_value): - setattr(self._eq, param_key, param_value) - return self._fun(self._eq, **self._kwargs) + setattr(eq, param_key, param_value) + + # call external function on equilibrium or list of equilibria + if not self._vectorized: + eqs = eqs[0] + return self._fun(eqs, **self._kwargs) # wrap external function to work with JAX abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) From 0547bd74245e74d2f31eef21f4fb64f453177be9 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Fri, 19 Jul 2024 12:49:13 -0600 Subject: [PATCH 20/34] jaxify as a util function --- desc/objectives/_generic.py | 117 +++++++++--------------------------- desc/utils.py | 94 ++++++++++++++++++++++++++++- tests/test_examples.py | 27 +-------- tests/test_utils.py | 41 ++++++++++++- 4 files changed, 161 insertions(+), 118 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 814d0f3631..36275c35f2 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -1,6 +1,5 @@ """Generic objectives that don't belong anywhere else.""" -import functools import inspect import re from abc import ABC @@ -13,7 +12,7 @@ from desc.compute.utils import _parse_parameterization, get_profiles, get_transforms from desc.grid import QuadratureGrid from desc.optimizable import OptimizableCollection -from desc.utils import errorif, parse_argname_change +from desc.utils import errorif, jaxify, parse_argname_change from .linear_objectives import _FixedObjective from .objective_funs import _Objective @@ -23,21 +22,22 @@ class _ExternalObjective(_Objective, ABC): """Wrap an external code. Similar to ``ObjectiveFromUser``, except derivatives of the objective function are - computed with finite differences instead of AD. + computed with finite differences instead of AD. 
The function does not need to be + JAX transformable. The user supplied function must take an Equilibrium as its only positional argument, but can take additional keyword arguments. - # TODO: add Parameters documentation - Parameters ---------- eq : Equilibrium Equilibrium that will be optimized to satisfy the Objective. fun : callable - Custom objective function. + External objective function. It must take an Equilibrium as its only positional + argument, but can take additional keyword arguments. It does not need to be JAX + transformable. dim_f : int - Dimension of the output of fun. + Dimension of the output of ``fun``. target : {float, ndarray}, optional Target value(s) of the objective. Only used if bounds is None. Must be broadcastable to Objective.dim_f. Defaults to ``target=0``. @@ -59,6 +59,14 @@ class _ExternalObjective(_Objective, ABC): Loss function to apply to the objective values once computed. This loss function is called on the raw compute value, before any shifting, scaling, or normalization. + vectorized : bool, optional + Whether or not ``fun`` is vectorized. Default = False. + abs_step : float, optional + Absolute finite difference step size. Default = 1e-4. + Total step size is ``abs_step + rel_step * mean(abs(x))``. + rel_step : float, optional + Relative finite difference step size. Default = 0. + Total step size is ``abs_step + rel_step * mean(abs(x))``. name : str, optional Name of the objective function. 
@@ -80,8 +88,9 @@ def __init__( normalize=False, normalize_target=False, loss_function=None, - fd_step=1e-4, # TODO: generalize this to allow a vector of different scales - vectorized=False, # False or int + vectorized=False, + abs_step=1e-4, + rel_step=0, name="external", **kwargs, ): @@ -91,8 +100,9 @@ def __init__( self._eq = eq.copy() self._fun = fun self._dim_f = dim_f - self._fd_step = fd_step self._vectorized = vectorized + self._abs_step = abs_step + self._rel_step = rel_step self._kwargs = kwargs super().__init__( things=eq, @@ -142,7 +152,13 @@ def fun_wrapped(params): # wrap external function to work with JAX abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) - self._fun_wrapped = self._jaxify(fun_wrapped, abstract_eval) + self._fun_wrapped = jaxify( + fun_wrapped, + abstract_eval, + vectorized=self._vectorized, + abs_step=self._abs_step, + rel_step=self._rel_step, + ) super().build(use_jit=use_jit, verbose=verbose) @@ -166,85 +182,6 @@ def compute(self, params, constants=None): f = self._fun_wrapped(params) return f - def _jaxify(self, func, abstract_eval): - """Make an external (python) function work with JAX. - - Positional arguments to func can be differentiated, - use keyword args for static values and non-differentiable stuff. - - Note: Only forward mode differentiation is supported currently. - - Parameters - ---------- - func : callable - Function to wrap. Should be a "pure" function, in that it has no side - effects and doesn't maintain state. Does not need to be JAX transformable. - abstract_eval : callable - Auxilliary function that computes the output shape and dtype of func. - **Must be JAX transformable**. Should be of the form - - abstract_eval(*args, **kwargs) -> Pytree with same shape and dtype as - func(*args, **kwargs) - - For example, if func always returns a scalar: - - abstract_eval = lambda *args, **kwargs: jnp.array(1.) 
- - Or if func takes an array of shape(n) and returns a dict of arrays of - shape(n-2): - - abstract_eval = lambda arr, **kwargs: - {"out1": jnp.empty(arr.size-2), "out2": jnp.empty(arr.size-2)} - - Returns - ------- - func : callable - New function that behaves as func but works with jit/vmap/jacfwd etc. - - """ - import jax - - def wrap_pure_callback(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - result_shape_dtype = abstract_eval(*args, **kwargs) - return jax.pure_callback( - func, - result_shape_dtype, - *args, - vectorized=bool(self._vectorized), - **kwargs, - ) - - return wrapper - - def define_fd_jvp(func): - func = jax.custom_jvp(func) - - @func.defjvp - def func_jvp(primals, tangents): - primal_out = func(*primals) - - # flatten everything into 1D vectors for easier finite differences - y, unflaty = jax.flatten_util.ravel_pytree(primal_out) - x, unflatx = jax.flatten_util.ravel_pytree(primals) - v, _______ = jax.flatten_util.ravel_pytree(tangents) - # scale to unit norm if nonzero - normv = jnp.linalg.norm(v) - vh = jnp.where(normv == 0, v, v / normv) - - def f(x): - return jax.flatten_util.ravel_pytree(func(*unflatx(x)))[0] - - tangent_out = (f(x + self._fd_step * vh) - y) / self._fd_step * normv - tangent_out = unflaty(tangent_out) - - return primal_out, tangent_out - - return func - - return define_fd_jvp(wrap_pure_callback(func)) - class GenericObjective(_Objective): """A generic objective that can compute any quantity from the `data_index`. 
diff --git a/desc/utils.py b/desc/utils.py index 6427d3f300..48dc4587c4 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -1,5 +1,6 @@ """Utility functions, independent of the rest of DESC.""" +import functools import operator import warnings from itertools import combinations_with_replacement, permutations @@ -8,7 +9,7 @@ from scipy.special import factorial from termcolor import colored -from desc.backend import fori_loop, jit, jnp +from desc.backend import fori_loop, jax, jit, jnp class Timer: @@ -676,3 +677,94 @@ def broadcast_tree(tree_in, tree_out, dtype=int): # invalid tree structure else: raise ValueError("trees must be nested lists of dicts") + + +def jaxify(func, abstract_eval, vectorized=False, abs_step=1e-4, rel_step=0): + """Make an external (python) function work with JAX. + + Positional arguments to func can be differentiated, + use keyword args for static values and non-differentiable stuff. + + Note: Only forward mode differentiation is supported currently. + + Parameters + ---------- + func : callable + Function to wrap. Should be a "pure" function, in that it has no side + effects and doesn't maintain state. Does not need to be JAX transformable. + abstract_eval : callable + Auxilliary function that computes the output shape and dtype of func. + **Must be JAX transformable**. Should be of the form + + abstract_eval(*args, **kwargs) -> Pytree with same shape and dtype as + func(*args, **kwargs) + + For example, if func always returns a scalar: + + abstract_eval = lambda *args, **kwargs: jnp.array(1.) + + Or if func takes an array of shape(n) and returns a dict of arrays of + shape(n-2): + + abstract_eval = lambda arr, **kwargs: + {"out1": jnp.empty(arr.size-2), "out2": jnp.empty(arr.size-2)} + vectorized : bool, optional + Whether or not the wrapped function is vectorized. Default = False. + abs_step : float, optional + Absolute finite difference step size. Default = 1e-4. + Total step size is ``abs_step + rel_step * mean(abs(x))``. 
+ rel_step : float, optional + Relative finite difference step size. Default = 0. + Total step size is ``abs_step + rel_step * mean(abs(x))``. + + Returns + ------- + func : callable + New function that behaves as func but works with jit/vmap/jacfwd etc. + + """ + + def wrap_pure_callback(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result_shape_dtype = abstract_eval(*args, **kwargs) + return jax.pure_callback( + func, + result_shape_dtype, + *args, + vectorized=vectorized, + **kwargs, + ) + + return wrapper + + def define_fd_jvp(func): + func = jax.custom_jvp(func) + + @func.defjvp + def func_jvp(primals, tangents): + primal_out = func(*primals) + + # flatten everything into 1D vectors for easier finite differences + y, unflaty = jax.flatten_util.ravel_pytree(primal_out) + x, unflatx = jax.flatten_util.ravel_pytree(primals) + v, _______ = jax.flatten_util.ravel_pytree(tangents) + + # finite difference step size + fd_step = abs_step + rel_step * jnp.mean(jnp.abs(x)) + + # scale tangents to unit norm if nonzero + normv = jnp.linalg.norm(v) + vh = jnp.where(normv == 0, v, v / normv) + + def f(x): + return jax.flatten_util.ravel_pytree(func(*unflatx(x)))[0] + + tangent_out = (f(x + fd_step * vh) - y) / fd_step + tangent_out = unflaty(tangent_out) + + return primal_out, tangent_out + + return func + + return define_fd_jvp(wrap_pure_callback(func)) diff --git a/tests/test_examples.py b/tests/test_examples.py index c719699c69..efbfcd1e6a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1569,32 +1569,9 @@ def data_from_vmec(eq, path=""): class TestExternalObjective(_ExternalObjective): - def __init__( - self, - eq, - target=None, - bounds=None, - weight=1, - normalize=False, - normalize_target=False, - loss_function=None, - path="", - name="external", - ): + def __init__(self, eq, target=None, path=""): super().__init__( - eq=eq, - fun=data_from_vmec, - dim_f=4, - target=target, - bounds=bounds, - weight=weight, - 
normalize=normalize, - normalize_target=normalize_target, - loss_function=loss_function, - fd_step=1e-4, - vectorized=False, - name=name, - path=path, + eq=eq, fun=data_from_vmec, dim_f=4, target=target, path=path ) eq0 = get("SOLOVEV") diff --git a/tests/test_utils.py b/tests/test_utils.py index 6bfadb4008..4fd128afae 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,9 +3,9 @@ import numpy as np import pytest -from desc.backend import tree_leaves, tree_structure +from desc.backend import jax, jnp, tree_leaves, tree_structure from desc.grid import LinearGrid -from desc.utils import broadcast_tree, isalmostequal, islinspaced +from desc.utils import broadcast_tree, isalmostequal, islinspaced, jaxify @pytest.mark.unit @@ -197,3 +197,40 @@ def test_broadcast_tree(): ] for leaf, leaf_correct in zip(tree_leaves(tree), tree_leaves(tree_correct)): np.testing.assert_allclose(leaf, leaf_correct) + + +@pytest.mark.unit +def test_jaxify(): + """Test that jaxify gives accurate finite difference derivatives.""" + + def fun(x): + """Function that is not JAX transformable.""" + return np.sin(x) + 2 * np.cos(x) + 3 * x**2 - 4 * x - 5 + + x = np.linspace(0, 2 * np.pi, 45) + f = fun(x) + df_true = np.cos(x) - 2 * np.sin(x) + 6 * x - 4 + + # finite differences with a range of abs and rel step sizes + abstract_eval = lambda *args, **kwargs: jnp.empty(f.size) + fun_jax_abs1 = jaxify(fun, abstract_eval, abs_step=1e-2) + fun_jax_abs2 = jaxify(fun, abstract_eval, abs_step=1e-3) + fun_jax_abs3 = jaxify(fun, abstract_eval, abs_step=1e-4) + fun_jax_rel1 = jaxify(fun, abstract_eval, rel_step=1e-3) + fun_jax_rel2 = jaxify(fun, abstract_eval, rel_step=1e-4) + fun_jax_rel3 = jaxify(fun, abstract_eval, rel_step=1e-5) + + df_abs1 = np.diagonal(jax.jacfwd(fun_jax_abs1)(x)) + df_abs2 = np.diagonal(jax.jacfwd(fun_jax_abs2)(x)) + df_abs3 = np.diagonal(jax.jacfwd(fun_jax_abs3)(x)) + df_rel1 = np.diagonal(jax.jacfwd(fun_jax_rel1)(x)) + df_rel2 = np.diagonal(jax.jacfwd(fun_jax_rel2)(x)) + 
df_rel3 = np.diagonal(jax.jacfwd(fun_jax_rel3)(x)) + + # convergence test: smaller step sizes should be more accurate + np.testing.assert_allclose(df_abs1, df_true, atol=5e-2) + np.testing.assert_allclose(df_abs2, df_true, atol=5e-3) + np.testing.assert_allclose(df_abs3, df_true, atol=5e-4) + np.testing.assert_allclose(df_rel1, df_true, rtol=2e-1) + np.testing.assert_allclose(df_rel2, df_true, rtol=2e-2) + np.testing.assert_allclose(df_rel3, df_true, rtol=3e-3) From 03d0cb593b0191863c106b8b410376faefff6d1e Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Mon, 22 Jul 2024 10:23:55 -0600 Subject: [PATCH 21/34] ExternalObjective no longer an ABC --- desc/objectives/__init__.py | 2 +- desc/objectives/_generic.py | 5 +++-- tests/test_examples.py | 11 ++--------- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/desc/objectives/__init__.py b/desc/objectives/__init__.py index 6a2584a1ce..dad325909a 100644 --- a/desc/objectives/__init__.py +++ b/desc/objectives/__init__.py @@ -21,10 +21,10 @@ ) from ._free_boundary import BoundaryError, VacuumBoundaryError from ._generic import ( + ExternalObjective, GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser, - _ExternalObjective, ) from ._geometry import ( AspectRatio, diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 36275c35f2..b52b8e3fa2 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -18,7 +18,7 @@ from .objective_funs import _Objective -class _ExternalObjective(_Objective, ABC): +class ExternalObjective(_Objective, ABC): """Wrap an external code. Similar to ``ObjectiveFromUser``, except derivatives of the objective function are @@ -69,6 +69,8 @@ class _ExternalObjective(_Objective, ABC): Total step size is ``abs_step + rel_step * mean(abs(x))``. name : str, optional Name of the objective function. + kwargs : any, optional + Keyword arguments that are passed as inputs to ``fun``. 
# TODO: add example @@ -94,7 +96,6 @@ def __init__( name="external", **kwargs, ): - assert isinstance(vectorized, bool) or isinstance(vectorized, int) if target is None and bounds is None: target = 0 self._eq = eq.copy() diff --git a/tests/test_examples.py b/tests/test_examples.py index efbfcd1e6a..f97570a435 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -39,6 +39,7 @@ CoilSetMinDistance, CoilTorsion, CurrentDensity, + ExternalObjective, FixBoundaryR, FixBoundaryZ, FixCoilCurrent, @@ -65,7 +66,6 @@ QuasisymmetryTwoTerm, VacuumBoundaryError, Volume, - _ExternalObjective, get_fixed_boundary_constraints, get_NAE_constraints, ) @@ -1567,13 +1567,6 @@ def data_from_vmec(eq, path=""): file.close() return np.atleast_1d([betatot, betapol, betator, presf1]) - class TestExternalObjective(_ExternalObjective): - - def __init__(self, eq, target=None, path=""): - super().__init__( - eq=eq, fun=data_from_vmec, dim_f=4, target=target, path=path - ) - eq0 = get("SOLOVEV") optimizer = Optimizer("lsq-exact") @@ -1612,7 +1605,7 @@ def __init__(self, eq, target=None, path=""): dir = tmpdir_factory.mktemp("results") path = dir.join("wout_result.nc") objective = ObjectiveFunction( - TestExternalObjective(eq=eq0, target=target, path=path) + ExternalObjective(eq=eq0, fun=data_from_vmec, dim_f=4, target=target, path=path) ) constraints = FixParameters( eq0, From 7723ecd65598e6e678fd13ddbf657375c7c4744b Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Mon, 22 Jul 2024 11:01:28 -0600 Subject: [PATCH 22/34] re-add print logic in backend --- desc/backend.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/desc/backend.py b/desc/backend.py index 208abafc88..9940c5f234 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -1,5 +1,6 @@ """Backend functions for DESC, with options for JAX or regular numpy.""" +import multiprocessing as mp import os import warnings @@ -10,15 +11,19 @@ from desc import config as desc_config 
from desc import set_device +# only print details in the main process, not child processes spawned by multiprocessing +verbose = bool(mp.current_process().name == "MainProcess") + if os.environ.get("DESC_BACKEND") == "numpy": jnp = np use_jax = False set_device(kind="cpu") - print( - "DESC version {}, using numpy backend, version={}, dtype={}".format( - desc.__version__, np.__version__, np.linspace(0, 1).dtype + if verbose: + print( + "DESC version {}, using numpy backend, version={}, dtype={}".format( + desc.__version__, np.__version__, np.linspace(0, 1).dtype + ) ) - ) else: if desc_config.get("device") is None: set_device("cpu") @@ -40,11 +45,12 @@ x = jnp.linspace(0, 5) y = jnp.exp(x) use_jax = True - print( - f"DESC version {desc.__version__}, " - + f"using JAX backend, jax version={jax.__version__}, " - + f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}" - ) + if verbose: + print( + f"DESC version {desc.__version__}, " + + f"using JAX backend, jax version={jax.__version__}, " + + f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}" + ) del x, y except ModuleNotFoundError: jnp = np @@ -59,11 +65,12 @@ ) ) -print( - "Using device: {}, with {:.2f} GB available memory".format( - desc_config.get("device"), desc_config.get("avail_mem") +if verbose: + print( + "Using device: {}, with {:.2f} GB available memory".format( + desc_config.get("device"), desc_config.get("avail_mem") + ) ) -) if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign? 
jit = jax.jit From 0b2207f552e53cf5456a22cef33ec6a61165f3eb Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Fri, 26 Jul 2024 13:15:18 -0600 Subject: [PATCH 23/34] exclude ExternalObjective from tests --- tests/test_objective_funs.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 9fcc12634b..dd60edb271 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -46,6 +46,7 @@ CoilTorsion, Elongation, Energy, + ExternalObjective, ForceBalance, ForceBalanceAnisotropic, GenericObjective, @@ -1954,7 +1955,8 @@ class TestComputeScalarResolution: VacuumBoundaryError, # need to avoid blowup near the axis MercierStability, - # don't test these since they depend on what user wants + # we do not test these since they depend too much on what the user wants + ExternalObjective, LinearObjectiveFromUser, ObjectiveFromUser, ] @@ -2274,17 +2276,17 @@ class TestObjectiveNaNGrad: CoilSetMinDistance, CoilTorsion, ForceBalanceAnisotropic, + Omnigenity, PlasmaCoilSetMinDistance, PlasmaVesselDistance, QuadraticFlux, ToroidalFlux, VacuumBoundaryError, - # we don't test these since they depend too much on what exactly the user wants + # we do not test these since they depend too much on what the user wants + ExternalObjective, GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser, - # TODO: add Omnigenity objective (see GH issue #943) - Omnigenity, ] other_objectives = list(set(objectives) - set(specials)) From aa570d4cb5608072be5a83af558c27521f6f106f Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Fri, 26 Jul 2024 13:18:02 -0600 Subject: [PATCH 24/34] scale FD derivatives by tangent norm --- desc/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/utils.py b/desc/utils.py index 96fe1ea7b3..c1f191f74f 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -766,7 +766,7 @@ def func_jvp(primals, tangents): def f(x): return 
jax.flatten_util.ravel_pytree(func(*unflatx(x)))[0] - tangent_out = (f(x + fd_step * vh) - y) / fd_step + tangent_out = (f(x + fd_step * vh) - y) / fd_step * normv tangent_out = unflaty(tangent_out) return primal_out, tangent_out From a83a6719fa51f652c6383d3710aaeb169db3515e Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Thu, 22 Aug 2024 13:24:56 -0600 Subject: [PATCH 25/34] resolve merge conflict --- desc/integrals/surface_integral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/integrals/surface_integral.py b/desc/integrals/surface_integral.py index acc1e6c1b9..c0f603fe97 100644 --- a/desc/integrals/surface_integral.py +++ b/desc/integrals/surface_integral.py @@ -256,7 +256,7 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14) has_endpoint_dupe, lambda _: put(mask, jnp.array([0, -1]), mask[0] | mask[-1]), lambda _: mask, - operand=None, + None, ) else: # If we don't have the idx attributes, we are forced to expand out. From 11521b2f77465a2f2cfb63abfd9a222f28ef5c89 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Sun, 25 Aug 2024 17:39:13 -0600 Subject: [PATCH 26/34] fix formatting from merge conflict --- desc/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/desc/utils.py b/desc/utils.py index 71f39f7489..53f95b753a 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -11,7 +11,6 @@ from desc.backend import fori_loop, jax, jit, jnp - PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text From c004724cc051b10f0ddf44b12d7a5e323cf8844d Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Mon, 26 Aug 2024 17:21:13 -0600 Subject: [PATCH 27/34] add static_attrs, update test --- desc/objectives/_generic.py | 3 ++- tests/test_examples.py | 31 +++++++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index ec4cb76a0e..477b683c1d 100644 --- a/desc/objectives/_generic.py +++ 
b/desc/objectives/_generic.py @@ -77,7 +77,8 @@ class ExternalObjective(_Objective, ABC): """ _units = "(Unknown)" - _print_value_fmt = "External objective value: {:10.3e}" + _print_value_fmt = "External objective value: " + _static_attrs = ["_fun_wrapped", "_kwargs"] def __init__( self, diff --git a/tests/test_examples.py b/tests/test_examples.py index ef3080dfa5..e97441ffc0 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -71,7 +71,6 @@ ) from desc.optimize import Optimizer from desc.profiles import FourierZernikeProfile, PowerSeriesProfile -from desc.vmec import VMECIO from .utils import area_difference_desc, area_difference_vmec @@ -1568,8 +1567,30 @@ def test_external_vs_generic_objectives(tmpdir_factory): """Test ExternalObjective compared to GenericObjective.""" target = np.array([6.2e-3, 1.1e-1, 6.5e-3, 0]) # values at p_l = [2e2, -2e2] - def data_from_vmec(eq, path=""): - VMECIO.save(eq, path, surfs=8, verbose=0) + def data_from_vmec(eq, path="", surfs=8): + # write data + file = Dataset(path, mode="w", format="NETCDF3_64BIT_OFFSET") + NFP = eq.NFP + M = eq.M + N = eq.N + M_nyq = M + 4 + N_nyq = N + 2 if N > 0 else 0 + s_full = np.linspace(0, 1, surfs) + r_full = np.sqrt(s_full) + file.createDimension("radius", surfs) + grid_full = LinearGrid(M=M_nyq, N=N_nyq, NFP=NFP, rho=r_full) + data_full = eq.compute(["p"], grid=grid_full) + data_quad = eq.compute(["_vol", "_vol", "_vol"]) + betatotal = file.createVariable("betatotal", np.float64) + betatotal[:] = data_quad["_vol"] + betapol = file.createVariable("betapol", np.float64) + betapol[:] = data_quad["_vol"] + betator = file.createVariable("betator", np.float64) + betator[:] = data_quad["_vol"] + presf = file.createVariable("presf", np.float64, ("radius",)) + presf[:] = grid_full.compress(data_full["p"]) + file.close() + # read data file = Dataset(path, mode="r") betatot = float(file.variables["betatotal"][0]) betapol = float(file.variables["betapol"][0]) @@ -1616,7 +1637,9 @@ def 
data_from_vmec(eq, path=""): dir = tmpdir_factory.mktemp("results") path = dir.join("wout_result.nc") objective = ObjectiveFunction( - ExternalObjective(eq=eq0, fun=data_from_vmec, dim_f=4, target=target, path=path) + ExternalObjective( + eq=eq0, fun=data_from_vmec, dim_f=4, target=target, path=path, surfs=8 + ) ) constraints = FixParameters( eq0, From ba1a252e27cfde6ba24baebc27d46b62f4c402d4 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 12 Nov 2024 12:56:47 -0700 Subject: [PATCH 28/34] update deprecated jax.pure_callback vmap arg --- desc/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/utils.py b/desc/utils.py index 5b6ce947a4..95089d26b5 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -799,7 +799,7 @@ def wrapper(*args, **kwargs): func, result_shape_dtype, *args, - vectorized=vectorized, + vmap_method="legacy_vectorized", # TODO: use "expand_dims" instead? **kwargs, ) From bb8a53521e9474e47716b7a6311ddf69c4072d02 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 12 Nov 2024 17:27:08 -0500 Subject: [PATCH 29/34] update vmap_method --- desc/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/utils.py b/desc/utils.py index 95089d26b5..ff9f08dc95 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -799,7 +799,7 @@ def wrapper(*args, **kwargs): func, result_shape_dtype, *args, - vmap_method="legacy_vectorized", # TODO: use "expand_dims" instead? 
+ vmap_method="expand_dims" if vectorized else "sequential", **kwargs, ) From 795350db3c18b56a13837dad993574ad36e65f85 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Thu, 12 Dec 2024 09:56:20 -0700 Subject: [PATCH 30/34] remove duplicate line from merge conflict --- desc/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/desc/utils.py b/desc/utils.py index 5f212f79ad..d276c75e4e 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -848,9 +848,6 @@ def atleast_2d_end(ary): return ary[:, jnp.newaxis] if ary.ndim == 1 else ary -PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text - - def dot(a, b, axis=-1): """Batched vector dot product. From 6fef120c7f89f4eed61ce4ee346550b2ef75bacd Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Thu, 12 Dec 2024 13:28:08 -0700 Subject: [PATCH 31/34] fix test with block_until_ready --- desc/optimize/least_squares.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index 56e7d6e0ba..73843d4a59 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -178,7 +178,7 @@ def lsqtr( # noqa: C901 f = fun(x, *args) nfev += 1 cost = 0.5 * jnp.dot(f, f) - J = jac(x, *args) + J = jac(x, *args).block_until_ready() # FIXME: block is needed for jaxify util njev += 1 g = jnp.dot(J.T, f) From 24dd2f39819e47ea2911601d1b2f793b0c230748 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 17 Dec 2024 11:13:14 -0700 Subject: [PATCH 32/34] update documentation --- CHANGELOG.md | 1 + desc/objectives/_generic.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 348e970cd9..4d4321bb0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ New Feature - Adds an option ``scaled_termination`` (defaults to True) to all of the desc optimizers to measure the norms for ``xtol`` and ``gtol`` in the scaled norm provided by ``x_scale`` (which defaults 
to using an adaptive scaling based on the Jacobian or Hessian). This should make things more robust when optimizing parameters with widely different magnitudes. The old behavior can be recovered by passing ``options={"scaled_termination": False}``. - ``desc.objectives.Omnigenity`` is now vectorized and able to optimize multiple surfaces at the same time. Previously it was required to use a different objective for each surface. - Adds a new objective ``desc.objectives.MirrorRatio`` for targeting a particular mirror ratio on each flux surface, for either an ``Equilibrium`` or ``OmnigenousField``. +- Adds a new objective ``desc.objectives.ExternalObjective`` for wrapping external codes with finite differences. Bug Fixes diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index cf5f66c311..8570c23893 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -25,17 +25,17 @@ class ExternalObjective(_Objective, ABC): computed with finite differences instead of AD. The function does not need not be JAX transformable. - The user supplied function must take an Equilibrium as its only positional argument, - but can take additional keyword arguments. + The user supplied function must take an Equilibrium or a list of Equilibria as its + only positional argument, but can take additional keyword arguments. Parameters ---------- eq : Equilibrium Equilibrium that will be optimized to satisfy the Objective. fun : callable - External objective function. It must take an Equilibrium as its only positional - argument, but can take additional kewyord arguments. It does not need to be JAX - transformable. + External objective function. It must take an Equilibrium or list of Equilibria + as its only positional argument, but can take additional kewyord arguments. + It does not need to be JAX transformable. dim_f : int Dimension of the output of ``fun``. 
target : {float, ndarray}, optional @@ -60,7 +60,8 @@ class ExternalObjective(_Objective, ABC): is called on the raw compute value, before any shifting, scaling, or normalization. vectorized : bool, optional - Whether or not ``fun`` is vectorized. Default = False. + Set to False if ``fun`` takes a single Equilibrium as its positional argument. + Set to True if ``fun`` instead takes a list of Equilibria. Default = False. abs_step : float, optional Absolute finite difference step size. Default = 1e-4. Total step size is ``abs_step + rel_step * mean(abs(x))``. @@ -72,8 +73,6 @@ class ExternalObjective(_Objective, ABC): kwargs : any, optional Keyword arguments that are passed as inputs to ``fun``. - # TODO: add example - """ _units = "(Unknown)" From ac1aa637a2907bd84192a60efc08ace88c8a59ed Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Wed, 18 Dec 2024 13:18:37 -0700 Subject: [PATCH 33/34] make vectorized a required arg --- desc/objectives/_generic.py | 8 ++++---- tests/test_examples.py | 8 +++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 8570c23893..737d684379 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -38,6 +38,9 @@ class ExternalObjective(_Objective, ABC): It does not need to be JAX transformable. dim_f : int Dimension of the output of ``fun``. + vectorized : bool + Set to False if ``fun`` takes a single Equilibrium as its positional argument. + Set to True if ``fun`` instead takes a list of Equilibria. target : {float, ndarray}, optional Target value(s) of the objective. Only used if bounds is None. Must be broadcastable to Objective.dim_f. Defaults to ``target=0``. @@ -59,9 +62,6 @@ class ExternalObjective(_Objective, ABC): Loss function to apply to the objective values once computed. This loss function is called on the raw compute value, before any shifting, scaling, or normalization. 
- vectorized : bool, optional - Set to False if ``fun`` takes a single Equilibrium as its positional argument. - Set to True if ``fun`` instead takes a list of Equilibria. Default = False. abs_step : float, optional Absolute finite difference step size. Default = 1e-4. Total step size is ``abs_step + rel_step * mean(abs(x))``. @@ -84,13 +84,13 @@ def __init__( eq, fun, dim_f, + vectorized, target=None, bounds=None, weight=1, normalize=False, normalize_target=False, loss_function=None, - vectorized=False, abs_step=1e-4, rel_step=0, name="external", diff --git a/tests/test_examples.py b/tests/test_examples.py index 932bea3741..955573e388 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1972,7 +1972,13 @@ def data_from_vmec(eq, path="", surfs=8): path = dir.join("wout_result.nc") objective = ObjectiveFunction( ExternalObjective( - eq=eq0, fun=data_from_vmec, dim_f=4, target=target, path=path, surfs=8 + eq=eq0, + fun=data_from_vmec, + dim_f=4, + vectorized=False, + target=target, + path=path, + surfs=8, ) ) constraints = FixParameters( From 4db5c9f73ee894d1d267eb3b33fdf099c888e44c Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Thu, 19 Dec 2024 14:17:58 -0700 Subject: [PATCH 34/34] make ExternalObjective args keyword only --- desc/objectives/_generic.py | 32 +++++++------------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 737d684379..5d5a25473b 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -41,40 +41,21 @@ class ExternalObjective(_Objective, ABC): vectorized : bool Set to False if ``fun`` takes a single Equilibrium as its positional argument. Set to True if ``fun`` instead takes a list of Equilibria. - target : {float, ndarray}, optional - Target value(s) of the objective. Only used if bounds is None. - Must be broadcastable to Objective.dim_f. Defaults to ``target=0``. 
- bounds : tuple of {float, ndarray}, optional - Lower and upper bounds on the objective. Overrides target. - Both bounds must be broadcastable to to Objective.dim_f. - Defaults to ``target=0``. - weight : {float, ndarray}, optional - Weighting to apply to the Objective, relative to other Objectives. - Must be broadcastable to to Objective.dim_f - normalize : bool, optional - Whether to compute the error in physical units or non-dimensionalize. - Has no effect for this objective. - normalize_target : bool, optional - Whether target and bounds should be normalized before comparing to computed - values. If `normalize` is `True` and the target is in physical units, - this should also be set to True. - loss_function : {None, 'mean', 'min', 'max'}, optional - Loss function to apply to the objective values once computed. This loss function - is called on the raw compute value, before any shifting, scaling, or - normalization. abs_step : float, optional Absolute finite difference step size. Default = 1e-4. Total step size is ``abs_step + rel_step * mean(abs(x))``. rel_step : float, optional Relative finite difference step size. Default = 0. Total step size is ``abs_step + rel_step * mean(abs(x))``. - name : str, optional - Name of the objective function. kwargs : any, optional Keyword arguments that are passed as inputs to ``fun``. """ + __doc__ = __doc__.rstrip() + collect_docs( + target_default="``target=0``.", bounds_default="``target=0``." + ) + _units = "(Unknown)" _print_value_fmt = "External objective value: " _static_attrs = ["_fun_wrapped", "_kwargs"] @@ -82,17 +63,18 @@ class ExternalObjective(_Objective, ABC): def __init__( self, eq, + *, fun, dim_f, vectorized, + abs_step=1e-4, + rel_step=0, target=None, bounds=None, weight=1, normalize=False, normalize_target=False, loss_function=None, - abs_step=1e-4, - rel_step=0, name="external", **kwargs, ):