diff --git a/CHANGELOG.md b/CHANGELOG.md index d39af8d57..06f391829 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ New Feature - ``desc.objectives.Omnigenity`` is now vectorized and able to optimize multiple surfaces at the same time. Previously it was required to use a different objective for each surface. - Adds a new objective ``desc.objectives.MirrorRatio`` for targeting a particular mirror ratio on each flux surface, for either an ``Equilibrium`` or ``OmnigenousField``. - Adds the output quantities ``wb`` and ``wp`` to ``VMECIO.save``. +- Adds a new objective ``desc.objectives.ExternalObjective`` for wrapping external codes with finite differences. Bug Fixes diff --git a/desc/backend.py b/desc/backend.py index 3704cb0bb..1ae84c338 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -1,6 +1,7 @@ """Backend functions for DESC, with options for JAX or regular numpy.""" import functools +import multiprocessing as mp import os import warnings @@ -11,15 +12,19 @@ from desc import config as desc_config from desc import set_device +# only print details in the main process, not child processes spawned by multiprocessing +verbose = bool(mp.current_process().name == "MainProcess") + if os.environ.get("DESC_BACKEND") == "numpy": jnp = np use_jax = False set_device(kind="cpu") - print( - "DESC version {}, using numpy backend, version={}, dtype={}".format( - desc.__version__, np.__version__, np.linspace(0, 1).dtype + if verbose: + print( + "DESC version {}, using numpy backend, version={}, dtype={}".format( + desc.__version__, np.__version__, np.linspace(0, 1).dtype + ) ) - ) else: if desc_config.get("device") is None: set_device("cpu") @@ -41,11 +46,12 @@ x = jnp.linspace(0, 5) y = jnp.exp(x) use_jax = True - print( - f"DESC version {desc.__version__}," - + f"using JAX backend, jax version={jax.__version__}, " - + f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}" - ) + if verbose: + print( + f"DESC version {desc.__version__}, " + + f"using JAX 
backend, jax version={jax.__version__}, " + + f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}" + ) del x, y except ModuleNotFoundError: jnp = np @@ -59,11 +65,13 @@ desc.__version__, np.__version__, y.dtype ) ) -print( - "Using device: {}, with {:.2f} GB available memory".format( - desc_config.get("device"), desc_config.get("avail_mem") + +if verbose: + print( + "Using device: {}, with {:.2f} GB available memory".format( + desc_config.get("device"), desc_config.get("avail_mem") + ) ) -) if use_jax: # noqa: C901 from jax import custom_jvp, jit, vmap @@ -588,7 +596,7 @@ def fori_loop(lower, upper, body_fun, init_val): val = body_fun(i, val) return val - def cond(pred, true_fun, false_fun, *operand): + def cond(pred, true_fun, false_fun, *operands): """Conditionally apply true_fun or false_fun. This version is for the numpy backend, for jax backend see jax.lax.cond @@ -601,7 +609,7 @@ def cond(pred, true_fun, false_fun, *operand): Function (A -> B), to be applied if pred is True. false_fun: callable Function (A -> B), to be applied if pred is False. - operand: any + operands: any input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof. @@ -614,9 +622,9 @@ def cond(pred, true_fun, false_fun, *operand): """ if pred: - return true_fun(*operand) + return true_fun(*operands) else: - return false_fun(*operand) + return false_fun(*operands) def switch(index, branches, operand): """Apply exactly one of branches given by index. 
diff --git a/desc/integrals/surface_integral.py b/desc/integrals/surface_integral.py index 54cfdabe1..f2cb674f6 100644 --- a/desc/integrals/surface_integral.py +++ b/desc/integrals/surface_integral.py @@ -254,7 +254,7 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14) has_endpoint_dupe, lambda _: put(mask, jnp.array([0, -1]), mask[0] | mask[-1]), lambda _: mask, - operand=None, + None, ) else: # If we don't have the idx attributes, we are forced to expand out. diff --git a/desc/objectives/__init__.py b/desc/objectives/__init__.py index 680b4f083..8cbd4b4d7 100644 --- a/desc/objectives/__init__.py +++ b/desc/objectives/__init__.py @@ -25,7 +25,12 @@ RadialForceBalance, ) from ._free_boundary import BoundaryError, VacuumBoundaryError -from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser +from ._generic import ( + ExternalObjective, + GenericObjective, + LinearObjectiveFromUser, + ObjectiveFromUser, +) from ._geometry import ( AspectRatio, BScaleLength, diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 8e16e61cf..5d5a25473 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -2,6 +2,7 @@ import inspect import re +from abc import ABC import numpy as np @@ -11,12 +12,160 @@ from desc.compute.utils import _parse_parameterization, get_profiles, get_transforms from desc.grid import QuadratureGrid from desc.optimizable import OptimizableCollection -from desc.utils import errorif, parse_argname_change +from desc.utils import errorif, jaxify, parse_argname_change from .linear_objectives import _FixedObjective from .objective_funs import _Objective, collect_docs +class ExternalObjective(_Objective, ABC): + """Wrap an external code. + + Similar to ``ObjectiveFromUser``, except derivatives of the objective function are + computed with finite differences instead of AD. The function does not need to be + JAX transformable. 
+ + The user supplied function must take an Equilibrium or a list of Equilibria as its + only positional argument, but can take additional keyword arguments. + + Parameters + ---------- + eq : Equilibrium + Equilibrium that will be optimized to satisfy the Objective. + fun : callable + External objective function. It must take an Equilibrium or list of Equilibria + as its only positional argument, but can take additional keyword arguments. + It does not need to be JAX transformable. + dim_f : int + Dimension of the output of ``fun``. + vectorized : bool + Set to False if ``fun`` takes a single Equilibrium as its positional argument. + Set to True if ``fun`` instead takes a list of Equilibria. + abs_step : float, optional + Absolute finite difference step size. Default = 1e-4. + Total step size is ``abs_step + rel_step * mean(abs(x))``. + rel_step : float, optional + Relative finite difference step size. Default = 0. + Total step size is ``abs_step + rel_step * mean(abs(x))``. + kwargs : any, optional + Keyword arguments that are passed as inputs to ``fun``. + + """ + + __doc__ = __doc__.rstrip() + collect_docs( + target_default="``target=0``.", bounds_default="``target=0``." 
+ ) + + _units = "(Unknown)" + _print_value_fmt = "External objective value: " + _static_attrs = ["_fun_wrapped", "_kwargs"] + + def __init__( + self, + eq, + *, + fun, + dim_f, + vectorized, + abs_step=1e-4, + rel_step=0, + target=None, + bounds=None, + weight=1, + normalize=False, + normalize_target=False, + loss_function=None, + name="external", + **kwargs, + ): + if target is None and bounds is None: + target = 0 + self._eq = eq.copy() + self._fun = fun + self._dim_f = dim_f + self._vectorized = vectorized + self._abs_step = abs_step + self._rel_step = rel_step + self._kwargs = kwargs + super().__init__( + things=eq, + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + loss_function=loss_function, + deriv_mode="fwd", + name=name, + ) + + def build(self, use_jit=True, verbose=1): + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. 
+ + """ + self._scalar = self._dim_f == 1 + self._constants = {"quad_weights": 1.0} + + def fun_wrapped(params): + """Wrap external function with possibly vectorized params.""" + # number of equilibria for vectorized computations + param_shape = params["Psi"].shape + num_eq = param_shape[0] if len(param_shape) > 1 else 1 + + # convert params to list of equilibria + eqs = [self._eq.copy() for _ in range(num_eq)] + for k, eq in enumerate(eqs): + # update equilibria with new params + for param_key in self._eq.optimizable_params: + param_value = np.atleast_2d(params[param_key])[k, :] + if len(param_value): + setattr(eq, param_key, param_value) + + # call external function on equilibrium or list of equilibria + if not self._vectorized: + eqs = eqs[0] + return self._fun(eqs, **self._kwargs) + + # wrap external function to work with JAX + abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) + self._fun_wrapped = jaxify( + fun_wrapped, + abstract_eval, + vectorized=self._vectorized, + abs_step=self._abs_step, + rel_step=self._rel_step, + ) + + super().build(use_jit=use_jit, verbose=verbose) + + def compute(self, params, constants=None): + """Compute the quantity. + + Parameters + ---------- + params : dict + Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict + constants : dict + Dictionary of constant data, eg transforms, profiles etc. Defaults to + self.constants + + Returns + ------- + f : ndarray + Computed quantity. + + """ + f = self._fun_wrapped(params) + return f + + class GenericObjective(_Objective): """A generic objective that can compute any quantity from the `data_index`. 
@@ -295,10 +444,9 @@ class ObjectiveFromUser(_Objective): def myfun(grid, data): # This will compute the flux surface average of the function # R*B_T from the Grad-Shafranov equation - f = data['R']*data['B_phi'] + f = data['R'] * data['B_phi'] f_fsa = surface_averages(grid, f, sqrt_g=data['sqrt_g']) - # this has the FSA values on the full grid, but we just want - # the unique values: + # this is the FSA on the full grid, but we only want the unique values: return grid.compress(f_fsa) myobj = ObjectiveFromUser(fun=myfun, thing=eq) @@ -359,6 +507,8 @@ def build(self, use_jit=True, verbose=1): Level of output. """ + import jax + thing = self.things[0] if self._grid is None: errorif( @@ -389,7 +539,6 @@ def get_vars(fun): ).squeeze() self._fun_wrapped = lambda data: self._fun(grid, data) - import jax self._dim_f = jax.eval_shape(self._fun_wrapped, dummy_data).size self._scalar = self._dim_f == 1 diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index 56e7d6e0b..73843d4a5 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -178,7 +178,7 @@ def lsqtr( # noqa: C901 f = fun(x, *args) nfev += 1 cost = 0.5 * jnp.dot(f, f) - J = jac(x, *args) + J = jac(x, *args).block_until_ready() # FIXME: block is needed for jaxify util njev += 1 g = jnp.dot(J.T, f) diff --git a/desc/utils.py b/desc/utils.py index 134108c9c..d276c75e4 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -9,7 +9,9 @@ from scipy.special import factorial from termcolor import colored -from desc.backend import flatnonzero, fori_loop, jit, jnp, take +from desc.backend import flatnonzero, fori_loop, jax, jit, jnp, take + +PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text class Timer: @@ -743,6 +745,97 @@ def atleast_nd(ndmin, ary): return jnp.array(ary, ndmin=ndmin) if jnp.ndim(ary) < ndmin else ary +def jaxify(func, abstract_eval, vectorized=False, abs_step=1e-4, rel_step=0): + """Make an external (python) function work with 
JAX. + + Positional arguments to func can be differentiated, + use keyword args for static values and non-differentiable stuff. + + Note: Only forward mode differentiation is supported currently. + + Parameters + ---------- + func : callable + Function to wrap. Should be a "pure" function, in that it has no side + effects and doesn't maintain state. Does not need to be JAX transformable. + abstract_eval : callable + Auxiliary function that computes the output shape and dtype of func. + **Must be JAX transformable**. Should be of the form + + abstract_eval(*args, **kwargs) -> Pytree with same shape and dtype as + func(*args, **kwargs) + + For example, if func always returns a scalar: + + abstract_eval = lambda *args, **kwargs: jnp.array(1.) + + Or if func takes an array of shape(n) and returns a dict of arrays of + shape(n-2): + + abstract_eval = lambda arr, **kwargs: + {"out1": jnp.empty(arr.size-2), "out2": jnp.empty(arr.size-2)} + vectorized : bool, optional + Whether or not the wrapped function is vectorized. Default = False. + abs_step : float, optional + Absolute finite difference step size. Default = 1e-4. + Total step size is ``abs_step + rel_step * mean(abs(x))``. + rel_step : float, optional + Relative finite difference step size. Default = 0. + Total step size is ``abs_step + rel_step * mean(abs(x))``. + + Returns + ------- + func : callable + New function that behaves as func but works with jit/vmap/jacfwd etc. 
+ + """ + + def wrap_pure_callback(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result_shape_dtype = abstract_eval(*args, **kwargs) + return jax.pure_callback( + func, + result_shape_dtype, + *args, + vmap_method="expand_dims" if vectorized else "sequential", + **kwargs, + ) + + return wrapper + + def define_fd_jvp(func): + func = jax.custom_jvp(func) + + @func.defjvp + def func_jvp(primals, tangents): + primal_out = func(*primals) + + # flatten everything into 1D vectors for easier finite differences + y, unflaty = jax.flatten_util.ravel_pytree(primal_out) + x, unflatx = jax.flatten_util.ravel_pytree(primals) + v, _______ = jax.flatten_util.ravel_pytree(tangents) + + # finite difference step size + fd_step = abs_step + rel_step * jnp.mean(jnp.abs(x)) + + # scale tangents to unit norm if nonzero + normv = jnp.linalg.norm(v) + vh = jnp.where(normv == 0, v, v / normv) + + def f(x): + return jax.flatten_util.ravel_pytree(func(*unflatx(x)))[0] + + tangent_out = (f(x + fd_step * vh) - y) / fd_step * normv + tangent_out = unflaty(tangent_out) + + return primal_out, tangent_out + + return func + + return define_fd_jvp(wrap_pure_callback(func)) + + def atleast_3d_mid(ary): """Like np.atleast_3d but if adds dim at axis 1 for 2d arrays.""" ary = jnp.atleast_2d(ary) @@ -755,9 +848,6 @@ def atleast_2d_end(ary): return ary[:, jnp.newaxis] if ary.ndim == 1 else ary -PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text - - def dot(a, b, axis=-1): """Batched vector dot product. 
diff --git a/tests/test_examples.py b/tests/test_examples.py index b43245e28..955573e38 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -6,6 +6,7 @@ import numpy as np import pytest +from netCDF4 import Dataset from qic import Qic from qsc import Qsc from scipy.constants import mu_0 @@ -43,6 +44,7 @@ CoilSetMinDistance, CoilTorsion, CurrentDensity, + ExternalObjective, FixBoundaryR, FixBoundaryZ, FixCoilCurrent, @@ -1893,6 +1895,117 @@ def circle_constraint(params): ) +@pytest.mark.unit +@pytest.mark.slow +def test_external_vs_generic_objectives(tmpdir_factory): + """Test ExternalObjective compared to GenericObjective.""" + target = np.array([6.2e-3, 1.1e-1, 6.5e-3, 0]) # values at p_l = [2e2, -2e2] + + def data_from_vmec(eq, path="", surfs=8): + # write data + file = Dataset(path, mode="w", format="NETCDF3_64BIT_OFFSET") + NFP = eq.NFP + M = eq.M + N = eq.N + M_nyq = M + 4 + N_nyq = N + 2 if N > 0 else 0 + s_full = np.linspace(0, 1, surfs) + r_full = np.sqrt(s_full) + file.createDimension("radius", surfs) + grid_full = LinearGrid(M=M_nyq, N=N_nyq, NFP=NFP, rho=r_full) + data_full = eq.compute(["p"], grid=grid_full) + data_quad = eq.compute(["_vol", "_vol", "_vol"]) + betatotal = file.createVariable("betatotal", np.float64) + betatotal[:] = data_quad["_vol"] + betapol = file.createVariable("betapol", np.float64) + betapol[:] = data_quad["_vol"] + betator = file.createVariable("betator", np.float64) + betator[:] = data_quad["_vol"] + presf = file.createVariable("presf", np.float64, ("radius",)) + presf[:] = grid_full.compress(data_full["p"]) + file.close() + # read data + file = Dataset(path, mode="r") + betatot = float(file.variables["betatotal"][0]) + betapol = float(file.variables["betapol"][0]) + betator = float(file.variables["betator"][0]) + presf1 = float(file.variables["presf"][-1]) + file.close() + return np.atleast_1d([betatot, betapol, betator, presf1]) + + eq0 = get("SOLOVEV") + optimizer = Optimizer("lsq-exact") + + # generic + 
objective = ObjectiveFunction( + ( + GenericObjective("_vol", thing=eq0, target=target[0]), + GenericObjective("_vol", thing=eq0, target=target[1]), + GenericObjective("_vol", thing=eq0, target=target[2]), + GenericObjective( + "p", thing=eq0, target=0, grid=LinearGrid(rho=[1], M=0, N=0) + ), + ) + ) + constraints = FixParameters( + eq0, + { + "R_lmn": True, + "Z_lmn": True, + "L_lmn": True, + "p_l": np.arange(2, len(eq0.p_l)), + "i_l": True, + "Psi": True, + }, + ) + [eq_generic], _ = optimizer.optimize( + things=eq0, + objective=objective, + constraints=constraints, + copy=True, + ftol=0, + verbose=2, + ) + + # external + dir = tmpdir_factory.mktemp("results") + path = dir.join("wout_result.nc") + objective = ObjectiveFunction( + ExternalObjective( + eq=eq0, + fun=data_from_vmec, + dim_f=4, + vectorized=False, + target=target, + path=path, + surfs=8, + ) + ) + constraints = FixParameters( + eq0, + { + "R_lmn": True, + "Z_lmn": True, + "L_lmn": True, + "p_l": np.arange(2, len(eq0.p_l)), + "i_l": True, + "Psi": True, + }, + ) + [eq_external], _ = optimizer.optimize( + things=eq0, + objective=objective, + constraints=constraints, + copy=True, + ftol=0, + verbose=2, + ) + + np.testing.assert_allclose(eq_generic.p_l, eq_external.p_l) + np.testing.assert_allclose(eq_generic.p_l[:2], [2e2, -2e2], rtol=4e-2) + np.testing.assert_allclose(eq_external.p_l[:2], [2e2, -2e2], rtol=4e-2) + + @pytest.mark.unit @pytest.mark.optimize def test_coil_arclength_optimization(): diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 47a06db06..9111a60bb 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -53,6 +53,7 @@ EffectiveRipple, Elongation, Energy, + ExternalObjective, ForceBalance, ForceBalanceAnisotropic, FusionPower, @@ -2676,7 +2677,8 @@ class TestComputeScalarResolution: VacuumBoundaryError, # need to avoid blowup near the axis MercierStability, - # don't test these since they depend on what user wants + # we do not test 
these since they depend too much on what the user wants + ExternalObjective, LinearObjectiveFromUser, ObjectiveFromUser, ] @@ -3156,7 +3158,8 @@ class TestObjectiveNaNGrad: SurfaceQuadraticFlux, ToroidalFlux, VacuumBoundaryError, - # we don't test these since they depend too much on what exactly the user wants + # we do not test these since they depend too much on what the user wants + ExternalObjective, GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser, diff --git a/tests/test_utils.py b/tests/test_utils.py index 2812e8a01..ca8330f4c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,9 +5,9 @@ import numpy as np import pytest -from desc.backend import flatnonzero, jnp, tree_leaves, tree_structure +from desc.backend import flatnonzero, jax, jnp, tree_leaves, tree_structure from desc.grid import LinearGrid -from desc.utils import broadcast_tree, isalmostequal, islinspaced, take_mask +from desc.utils import broadcast_tree, isalmostequal, islinspaced, jaxify, take_mask @pytest.mark.unit @@ -201,6 +201,43 @@ def test_broadcast_tree(): np.testing.assert_allclose(leaf, leaf_correct) +@pytest.mark.unit +def test_jaxify(): + """Test that jaxify gives accurate finite difference derivatives.""" + + def fun(x): + """Function that is not JAX transformable.""" + return np.sin(x) + 2 * np.cos(x) + 3 * x**2 - 4 * x - 5 + + x = np.linspace(0, 2 * np.pi, 45) + f = fun(x) + df_true = np.cos(x) - 2 * np.sin(x) + 6 * x - 4 + + # finite differences with a range of abs and rel step sizes + abstract_eval = lambda *args, **kwargs: jnp.empty(f.size) + fun_jax_abs1 = jaxify(fun, abstract_eval, abs_step=1e-2) + fun_jax_abs2 = jaxify(fun, abstract_eval, abs_step=1e-3) + fun_jax_abs3 = jaxify(fun, abstract_eval, abs_step=1e-4) + fun_jax_rel1 = jaxify(fun, abstract_eval, rel_step=1e-3) + fun_jax_rel2 = jaxify(fun, abstract_eval, rel_step=1e-4) + fun_jax_rel3 = jaxify(fun, abstract_eval, rel_step=1e-5) + + df_abs1 = np.diagonal(jax.jacfwd(fun_jax_abs1)(x)) + df_abs2 = 
np.diagonal(jax.jacfwd(fun_jax_abs2)(x)) + df_abs3 = np.diagonal(jax.jacfwd(fun_jax_abs3)(x)) + df_rel1 = np.diagonal(jax.jacfwd(fun_jax_rel1)(x)) + df_rel2 = np.diagonal(jax.jacfwd(fun_jax_rel2)(x)) + df_rel3 = np.diagonal(jax.jacfwd(fun_jax_rel3)(x)) + + # convergence test: smaller step sizes should be more accurate + np.testing.assert_allclose(df_abs1, df_true, atol=5e-2) + np.testing.assert_allclose(df_abs2, df_true, atol=5e-3) + np.testing.assert_allclose(df_abs3, df_true, atol=5e-4) + np.testing.assert_allclose(df_rel1, df_true, rtol=2e-1) + np.testing.assert_allclose(df_rel2, df_true, rtol=2e-2) + np.testing.assert_allclose(df_rel3, df_true, rtol=3e-3) + + @partial(jnp.vectorize, signature="(m)->()") def _last_value(a): """Return the last non-nan value in ``a``."""