diff --git a/desc/io/optimizable_io.py b/desc/io/optimizable_io.py index 5a11d39245..554cdac070 100644 --- a/desc/io/optimizable_io.py +++ b/desc/io/optimizable_io.py @@ -1,6 +1,7 @@ """Functions and methods for saving and loading equilibria and other objects.""" import copy +import functools import os import pickle import pydoc @@ -86,7 +87,9 @@ def _unjittable(x): return any([_unjittable(y) for y in x.values()]) if hasattr(x, "dtype") and np.ndim(x) == 0: return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_) - return isinstance(x, (str, types.FunctionType, bool, int, np.int_)) + return isinstance( + x, (str, types.FunctionType, functools.partial, bool, int, np.int_) + ) def _make_hashable(x): diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index cde5ec78e3..5905802f89 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -39,8 +39,8 @@ def update_target(self, thing): assert len(new_target) == len(self.target) self.target = new_target self._target_from_user = self.target # in case the Objective is re-built - if self._use_jit: - self.jit() + if not self._use_jit: + self._unjit() def _parse_target_from_user( self, target_from_user, default_target, default_bounds, idx @@ -232,8 +232,8 @@ def update_target(self, thing): """ self.target = self.compute(thing.params_dict) - if self._use_jit: - self.jit() + if not self._use_jit: + self._unjit() class BoundaryRSelfConsistency(_Objective): @@ -3184,6 +3184,7 @@ class FixNearAxisR(_FixedObjective): """ + _static_attrs = ["_nae_eq"] _target_arg = "R_lmn" _fixed = False # not "diagonal", since its fixing a sum _units = "(m)" @@ -3320,6 +3321,7 @@ class FixNearAxisZ(_FixedObjective): """ + _static_attrs = ["_nae_eq"] _target_arg = "Z_lmn" _fixed = False # not "diagonal", since its fixing a sum _units = "(m)" @@ -3462,6 +3464,7 @@ class FixNearAxisLambda(_FixedObjective): """ + _static_attrs = ["_nae_eq"] _target_arg = "L_lmn" _fixed = False # not "diagonal", since its fixing a sum _units = "(dimensionless)" diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 8f7843a2cf..59734aa261 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1,7 +1,7 @@ """Base classes for objectives.""" +import functools from abc import ABC, abstractmethod -from functools import partial import numpy as np @@ -9,7 +9,14 @@ from desc.derivatives import Derivative from desc.io import IOAble from desc.optimizable import Optimizable -from desc.utils import Timer, flatten_list, is_broadcastable, setdefault, unique_list +from desc.utils import ( + Timer, + errorif, + flatten_list, + is_broadcastable, + setdefault, + unique_list, +) class ObjectiveFunction(IOAble): @@ -57,68 +64,15 @@ def __init__( self._name = name def _set_derivatives(self): - """Set up derivatives of the objective functions.""" + """Choose derivative mode based on mode of sub-objectives.""" if self._deriv_mode == "auto": if all((obj._deriv_mode == "fwd") for obj in self.objectives): self._deriv_mode = "batched" else: self._deriv_mode = "blocked" - if self._deriv_mode in {"batched", "looped", "blocked"}: - self._grad = Derivative(self.compute_scalar, mode="grad") - self._hess = Derivative(self.compute_scalar, mode="hess") - if self._deriv_mode == "batched": - self._jac_scaled = Derivative(self.compute_scaled, mode="fwd") - self._jac_scaled_error = Derivative(self.compute_scaled_error, mode="fwd") - self._jac_unscaled = Derivative(self.compute_unscaled, mode="fwd") - if self._deriv_mode == "looped": - self._jac_scaled = Derivative(self.compute_scaled, mode="looped") - self._jac_scaled_error = Derivative( - self.compute_scaled_error, mode="looped" - ) - self._jac_unscaled = Derivative(self.compute_unscaled, mode="looped") - if self._deriv_mode == "blocked": - # could also do something similar for grad and hess, but probably not - # worth it. grad is already super cheap to eval all at once, and blocked - # hess would only be block diag which may miss important interactions. - - def jac_(op, x, constants=None): - if constants is None: - constants = self.constants - xs_splits = np.cumsum([t.dim_x for t in self.things]) - xs = jnp.split(x, xs_splits) - J = [] - for obj, const in zip(self.objectives, constants): - # get the xs that go to that objective - xi = [x for x, t in zip(xs, self.things) if t in obj.things] - Ji_ = getattr(obj, op)( - *xi, constants=const - ) # jac wrt to just those things - Ji = [] # jac wrt all things - for thing in self.things: - if thing in obj.things: - i = obj.things.index(thing) - Ji += [Ji_[i]] - else: - Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] - Ji = jnp.hstack(Ji) - J += [Ji] - return jnp.vstack(J) - - self._jac_scaled = partial(jac_, "jac_scaled") - self._jac_scaled_error = partial(jac_, "jac_scaled_error") - self._jac_unscaled = partial(jac_, "jac_unscaled") - - def jit(self): # noqa: C901 - """Apply JIT to compute methods, or re-apply after updating self.""" - # can't loop here because del doesn't work on getattr - # main idea is that when jitting a method, jax replaces that method - # with a CompiledFunction object, with self compiled in. To re-jit - # (ie, after updating attributes of self), we just need to delete the jax - # CompiledFunction object, which will then leave the raw method in its place, - # and then jit the raw method with the new self - - self._use_jit = True + def _unjit(self): + """Remove jit compiled methods.""" methods = [ "compute_scaled", "compute_scaled_error", @@ -136,17 +90,13 @@ def jit(self): # noqa: C901 "vjp_scaled_error", "vjp_unscaled", ] - for method in methods: try: - delattr(self, method) + setattr( + self, method, functools.partial(getattr(self, method)._fun, self) + ) except AttributeError: pass - setattr(self, method, jit(getattr(self, method))) - - for obj in self._objectives: - if obj._use_jit: - obj.jit() @execute_on_cpu def build(self, use_jit=None, verbose=1): @@ -179,8 +129,8 @@ def build(self, use_jit=None, verbose=1): self._scalar = False self._set_derivatives() - if self.use_jit: - self.jit() + if not self.use_jit: + self._unjit() self._set_things() @@ -219,23 +169,18 @@ def _set_things(self, things=None): ) unique_, inds_ = unique_list(flat_) - def unflatten(unique): - assert len(unique) == len(unique_) - flat = [unique[i] for i in inds_] - return tree_unflatten(treedef_, flat) - - def flatten(things): - flat, treedef = tree_flatten( - things, is_leaf=lambda x: isinstance(x, Optimizable) + # this is needed to know which "thing" goes with which sub-objective, + # ie objectives[i].things == [things[k] for k in things_per_objective_idx[i]] + self._things_per_objective_idx = [] + for obj in self.objectives: + self._things_per_objective_idx.append( + [i for i, t in enumerate(unique_) if t in obj.things] ) - assert treedef == treedef_ - assert len(flat) == len(flat_) - unique, _ = unique_list(flat) - return unique - self._unflatten = unflatten - self._flatten = flatten + self._unflatten = _ThingUnflattener(len(unique_), inds_, treedef_) + self._flatten = _ThingFlattener(len(flat_), treedef_) + @jit def compute_unscaled(self, x, constants=None): """Compute the raw value of the objective function. @@ -255,6 +200,7 @@ def compute_unscaled(self, x, constants=None): params = self.unpack_state(x) if constants is None: constants = self.constants + assert len(params) == len(constants) == len(self.objectives) f = jnp.concatenate( [ obj.compute_unscaled(*par, constants=const) @@ -263,6 +209,7 @@ def compute_unscaled(self, x, constants=None): ) return f + @jit def compute_scaled(self, x, constants=None): """Compute the objective function and apply weighting and normalization. @@ -282,6 +229,7 @@ def compute_scaled(self, x, constants=None): params = self.unpack_state(x) if constants is None: constants = self.constants + assert len(params) == len(constants) == len(self.objectives) f = jnp.concatenate( [ obj.compute_scaled(*par, constants=const) @@ -290,6 +238,7 @@ def compute_scaled(self, x, constants=None): ) return f + @jit def compute_scaled_error(self, x, constants=None): """Compute and apply the target/bounds, weighting, and normalization. @@ -309,6 +258,7 @@ def compute_scaled_error(self, x, constants=None): params = self.unpack_state(x) if constants is None: constants = self.constants + assert len(params) == len(constants) == len(self.objectives) f = jnp.concatenate( [ obj.compute_scaled_error(*par, constants=const) @@ -317,6 +267,7 @@ def compute_scaled_error(self, x, constants=None): ) return f + @jit def compute_scalar(self, x, constants=None): """Compute the sum of squares error. @@ -355,6 +306,7 @@ def print_value(self, x, constants=None): f = jnp.sum(self.compute_scaled_error(x, constants=constants) ** 2) / 2 print("Total (sum of squares): {:10.3e}, ".format(f)) params = self.unpack_state(x) + assert len(params) == len(constants) == len(self.objectives) for par, obj, const in zip(params, self.objectives, constants): obj.print_value(*par, constants=const) return None @@ -390,14 +342,17 @@ def unpack_state(self, x, per_objective=True): xs_splits = np.cumsum([t.dim_x for t in self.things]) xs = jnp.split(x, xs_splits) + xs = xs[: len(self.things)] # jnp.split returns an empty array at the end + assert len(xs) == len(self.things) params = [t.unpack_params(xi) for t, xi in zip(self.things, xs)] if per_objective: # params is a list of lists of dicts, for each thing and for each objective params = self._unflatten(params) # this filters out the params of things that are unused by each objective + assert len(params) == len(self._things_per_objective_idx) params = [ - [par for par, thing in zip(param, self.things) if thing in obj.things] - for param, obj in zip(params, self.objectives) + [param[i] for i in idx] + for param, idx in zip(params, self._things_per_objective_idx) ] return params @@ -405,39 +360,116 @@ def x(self, *things): """Return the full state vector from the Optimizable objects things.""" # TODO: also check resolution etc? things = things or self.things - assert all([type(t1) is type(t2) for t1, t2 in zip(things, self.things)]) + errorif( + len(things) != len(self.things), + ValueError, + "Got the wrong number of things, " + f"expected {len(self.things)} got {len(things)}", + ) + for t1, t2 in zip(things, self.things): + errorif( + not isinstance(t1, type(t2)), + TypeError, + f"got incompatible types between things {type(t1)} " + f"and self.things {type(t2)}", + ) xs = [t.pack_params(t.params_dict) for t in things] return jnp.concatenate(xs) + @jit def grad(self, x, constants=None): """Compute gradient vector of self.compute_scalar wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_1d(self._grad(x, constants).squeeze()) + return jnp.atleast_1d( + Derivative(self.compute_scalar, mode="grad")(x, constants).squeeze() + ) + @jit def hess(self, x, constants=None): """Compute Hessian matrix of self.compute_scalar wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._hess(x, constants).squeeze()) + return jnp.atleast_2d( + Derivative(self.compute_scalar, mode="hess")(x, constants).squeeze() + ) + + def _jac_blocked(self, op, x, constants=None): + # could also do something similar for grad and hess, but probably not + # worth it. grad is already super cheap to eval all at once, and blocked + # hess would only be block diag which may miss important interactions. + if constants is None: + constants = self.constants + xs_splits = np.cumsum([t.dim_x for t in self.things]) + xs = jnp.split(x, xs_splits) + J = [] + assert len(self.objectives) == len(self.constants) + # basic idea is we compute the jacobian of each objective wrt each thing + # one by one, and assemble into big block matrix + # if objective doesn't depend on a given thing, that part is set to 0. + for k, (obj, const) in enumerate(zip(self.objectives, constants)): + # get the xs that go to that objective + thing_idx = self._things_per_objective_idx[k] + xi = [xs[i] for i in thing_idx] + Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt to just those things + Ji = [] # jac wrt all things + for i, thing in enumerate(self.things): + if i in thing_idx: # dfi/dxj != 0 + Ji += [Ji_[thing_idx.index(i)]] + else: # dfi/dxj == 0 + Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] + Ji = jnp.hstack(Ji) # something like [df1/dx1, df1/dx2, 0] + J += [Ji] + # something like [df1/dx1, df1/dx2, 0] + # [df2/dx1, 0, df2/dx3] # noqa:E800 + J = jnp.vstack(J) + return J + + @jit def jac_scaled(self, x, constants=None): """Compute Jacobian matrix of self.compute_scaled wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._jac_scaled(x, constants).squeeze()) + if self._deriv_mode == "batched": + J = Derivative(self.compute_scaled, mode="fwd")(x, constants) + if self._deriv_mode == "looped": + J = Derivative(self.compute_scaled, mode="looped")(x, constants) + if self._deriv_mode == "blocked": + J = self._jac_blocked("jac_scaled", x, constants) + + return jnp.atleast_2d(J.squeeze()) + + @jit def jac_scaled_error(self, x, constants=None): """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._jac_scaled_error(x, constants).squeeze()) + if self._deriv_mode == "batched": + J = Derivative(self.compute_scaled_error, mode="fwd")(x, constants) + if self._deriv_mode == "looped": + J = Derivative(self.compute_scaled_error, mode="looped")(x, constants) + if self._deriv_mode == "blocked": + J = self._jac_blocked("jac_scaled_error", x, constants) + + return jnp.atleast_2d(J.squeeze()) + + @jit def jac_unscaled(self, x, constants=None): """Compute Jacobian matrix of self.compute_unscaled wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._jac_unscaled(x, constants).squeeze()) + + if self._deriv_mode == "batched": + J = Derivative(self.compute_unscaled, mode="fwd")(x, constants) + if self._deriv_mode == "looped": + J = Derivative(self.compute_unscaled, mode="looped")(x, constants) + if self._deriv_mode == "blocked": + J = self._jac_blocked("jac_unscaled", x, constants) + + return jnp.atleast_2d(J.squeeze()) def _jvp(self, v, x, constants=None, op="compute_scaled"): v = v if isinstance(v, (tuple, list)) else (v,) @@ -457,6 +489,7 @@ def _jvp(self, v, x, constants=None, op="compute_scaled"): else: raise NotImplementedError("Cannot compute JVP higher than 3rd order.") + @jit def jvp_scaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled. @@ -473,6 +506,7 @@ def jvp_scaled(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled") + @jit def jvp_scaled_error(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled_error. @@ -489,6 +523,7 @@ def jvp_scaled_error(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled_error") + @jit def jvp_unscaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_unscaled. @@ -509,6 +544,7 @@ def _vjp(self, v, x, constants=None, op="compute_scaled"): fun = lambda x: getattr(self, op)(x, constants) return Derivative.compute_vjp(fun, 0, v, x) + @jit def vjp_scaled(self, v, x, constants=None): """Compute vector-Jacobian product of self.compute_scaled. @@ -524,6 +560,7 @@ def vjp_scaled(self, v, x, constants=None): """ return self._vjp(v, x, constants, "compute_scaled") + @jit def vjp_scaled_error(self, v, x, constants=None): """Compute vector-Jacobian product of self.compute_scaled_error. @@ -539,6 +576,7 @@ def vjp_scaled_error(self, v, x, constants=None): """ return self._vjp(v, x, constants, "compute_scaled_error") + @jit def vjp_unscaled(self, v, x, constants=None): """Compute vector-Jacobian product of self.compute_unscaled. @@ -803,7 +841,7 @@ def __init__( self._normalization = 1 self._deriv_mode = deriv_mode self._name = name - self._use_jit = None + self._use_jit = True self._built = False self._loss_function = { "mean": jnp.mean, @@ -815,11 +853,7 @@ def __init__( self._things = flatten_list([things], True) def _set_derivatives(self): - """Set up derivatives of the objective wrt each argument.""" - argnums = tuple(range(len(self.things))) - # derivatives return tuple, one for each thing - self._grad = Derivative(self.compute_scalar, argnums, mode="grad") - self._hess = Derivative(self.compute_scalar, argnums, mode="hess") + """Choose derivative mode based on size of inputs/outputs.""" if self._deriv_mode == "auto": # choose based on shape of jacobian. fwd mode is more memory efficient # so we prefer that unless the jacobian is really wide @@ -828,20 +862,9 @@ def _set_derivatives(self): if self.dim_f >= 0.5 * sum(t.dim_x for t in self.things) else "rev" ) - self._jac_scaled = Derivative( - self.compute_scaled, argnums, mode=self._deriv_mode - ) - self._jac_scaled_error = Derivative( - self.compute_scaled_error, argnums, mode=self._deriv_mode - ) - self._jac_unscaled = Derivative( - self.compute_unscaled, argnums, mode=self._deriv_mode - ) - - def jit(self): # noqa: C901 - """Apply JIT to compute methods, or re-apply after updating self.""" - self._use_jit = True + def _unjit(self): + """Remove jit compiled methods.""" methods = [ "compute_scaled", "compute_scaled_error", @@ -850,16 +873,19 @@ def jit(self): # noqa: C901 "jac_scaled", "jac_scaled_error", "jac_unscaled", + "jvp_scaled", + "jvp_scaled_error", + "jvp_unscaled", "hess", "grad", ] - for method in methods: try: - delattr(self, method) + setattr( + self, method, functools.partial(getattr(self, method)._fun, self) + ) except AttributeError: pass - setattr(self, method, jit(getattr(self, method))) def _check_dimensions(self): """Check that len(target) = len(bounds) = len(weight) = dim_f.""" @@ -913,8 +939,8 @@ def build(self, use_jit=True, verbose=1): if use_jit is not None: self._use_jit = use_jit - if self._use_jit: - self.jit() + if not self._use_jit: + self._unjit() self._built = True @@ -924,6 +950,7 @@ def compute(self, *args, **kwargs): def _maybe_array_to_params(self, *args): argsout = tuple() + assert len(args) == len(self.things) for arg, thing in zip(args, self.things): if isinstance(arg, (np.ndarray, jnp.ndarray)): argsout += (thing.unpack_params(arg),) @@ -931,6 +958,7 @@ def _maybe_array_to_params(self, *args): argsout += (arg,) return argsout + @jit def compute_unscaled(self, *args, **kwargs): """Compute the raw value of the objective.""" args = self._maybe_array_to_params(*args) @@ -939,6 +967,7 @@ def compute_unscaled(self, *args, **kwargs): f = self._loss_function(f) return jnp.atleast_1d(f) + @jit def compute_scaled(self, *args, **kwargs): """Compute and apply weighting and normalization.""" args = self._maybe_array_to_params(*args) @@ -947,6 +976,7 @@ def compute_scaled(self, *args, **kwargs): f = self._loss_function(f) return jnp.atleast_1d(self._scale(f, **kwargs)) + @jit def compute_scaled_error(self, *args, **kwargs): """Compute and apply the target/bounds, weighting, and normalization.""" args = self._maybe_array_to_params(*args) @@ -989,6 +1019,7 @@ def _scale(self, f, *args, **kwargs): f_norm = jnp.atleast_1d(f) / self.normalization # normalization return f_norm * w * self.weight + @jit def compute_scalar(self, *args, **kwargs): """Compute the scalar form of the objective.""" if self.scalar: @@ -997,25 +1028,41 @@ def compute_scalar(self, *args, **kwargs): f = jnp.sum(self.compute_scaled_error(*args, **kwargs) ** 2) / 2 return f.squeeze() + @jit def grad(self, *args, **kwargs): """Compute gradient vector of self.compute_scalar wrt x.""" - return self._grad(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scalar, argnums, mode="grad")(*args, **kwargs) + @jit def hess(self, *args, **kwargs): """Compute Hessian matrix of self.compute_scalar wrt x.""" - return self._hess(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scalar, argnums, mode="hess")(*args, **kwargs) + @jit def jac_scaled(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_scaled wrt x.""" - return self._jac_scaled(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scaled, argnums, mode=self._deriv_mode)( + *args, **kwargs + ) + @jit def jac_scaled_error(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" - return self._jac_scaled_error(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scaled_error, argnums, mode=self._deriv_mode)( + *args, **kwargs + ) + @jit def jac_unscaled(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_unscaled wrt x.""" - return self._jac_unscaled(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_unscaled, argnums, mode=self._deriv_mode)( + *args, **kwargs + ) def _jvp(self, v, x, constants=None, op="compute_scaled"): v = v if isinstance(v, (tuple, list)) else (v,) @@ -1027,6 +1074,7 @@ def _jvp(self, v, x, constants=None, op="compute_scaled"): sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) + @jit def jvp_scaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled. @@ -1042,6 +1090,7 @@ def jvp_scaled(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled") + @jit def jvp_scaled_error(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled_error. @@ -1057,6 +1106,7 @@ def jvp_scaled_error(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled_error") + @jit def jvp_unscaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_unscaled. @@ -1148,6 +1198,19 @@ def print_value(self, *args, **kwargs): def xs(self, *things): """Return a tuple of args required by this objective from optimizable things.""" things = things or self.things + errorif( + len(things) != len(self.things), + ValueError, + "Got the wrong number of things, " + f"expected {len(self.things)} got {len(things)}", + ) + for t1, t2 in zip(things, self.things): + errorif( + not isinstance(t1, type(t2)), + TypeError, + f"got incompatible types between things {type(t1)} " + f"and self.things {type(t2)}", + ) return tuple([t.params_dict for t in things]) @property @@ -1241,7 +1304,45 @@ def things(self, new): if not isinstance(new, (tuple, list)): new = [new] assert all(isinstance(x, Optimizable) for x in new) + assert len(new) == len(self.things) assert all(type(a) is type(b) for a, b in zip(new, self.things)) self._things = list(new) # can maybe improve this later to not rebuild if resolution is the same self._built = False + + +# local functions assigned as attributes aren't hashable so they cause stuff to +# recompile, so instead we define a hashable class to do the same thing. + + +class _ThingUnflattener(IOAble): + + _static_attrs = ["length", "inds", "treedef"] + + def __init__(self, length, inds, treedef): + self.length = length + self.inds = inds + self.treedef = treedef + + def __call__(self, unique): + assert len(unique) == self.length + flat = [unique[i] for i in self.inds] + return tree_unflatten(self.treedef, flat) + + +class _ThingFlattener(IOAble): + + _static_attrs = ["length", "treedef"] + + def __init__(self, length, treedef): + self.length = length + self.treedef = treedef + + def __call__(self, things): + flat, treedef = tree_flatten( + things, is_leaf=lambda x: isinstance(x, Optimizable) + ) + assert treedef == self.treedef + assert len(flat) == self.length + unique, _ = unique_list(flat) + return unique diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 5167573f51..ffaa507d3b 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -6,6 +6,7 @@ import numpy as np from desc.backend import cond, jit, jnp, logsumexp, put +from desc.io import IOAble from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif @@ -168,18 +169,8 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa Z = jnp.asarray(Z) D = jnp.asarray(D) - @jit - def project(x_full): - """Project a full state vector into the reduced optimization vector.""" - x_reduced = Z.T @ ((1 / D) * x_full - xp)[unfixed_idx] - return jnp.atleast_1d(jnp.squeeze(x_reduced)) - - @jit - def recover(x_reduced): - """Recover the full state vector from the reduced optimization vector.""" - dx = put(jnp.zeros(objective.dim_x), unfixed_idx, Z @ x_reduced) - x_full = D * (xp + dx) - return jnp.atleast_1d(jnp.squeeze(x_full)) + project = _Project(Z, D, xp, unfixed_idx) + recover = _Recover(Z, D, xp, unfixed_idx, objective.dim_x) # check that all constraints are actually satisfiable params = objective.unpack_state(D * xp, False) @@ -227,6 +218,41 @@ def recover(x_reduced): return xp, A, b, Z, D, unfixed_idx, project, recover +class _Project(IOAble): + _io_attrs_ = ["Z", "D", "xp", "unfixed_idx"] + + def __init__(self, Z, D, xp, unfixed_idx): + self.Z = Z + self.D = D + self.xp = xp + self.unfixed_idx = unfixed_idx + + @jit + def __call__(self, x_full): + """Project a full state vector into the reduced optimization vector.""" + x_reduced = self.Z.T @ ((1 / self.D) * x_full - self.xp)[self.unfixed_idx] + return jnp.atleast_1d(jnp.squeeze(x_reduced)) + + +class _Recover(IOAble): + _io_attrs_ = ["Z", "D", "xp", "unfixed_idx", "dim_x"] + _static_attrs = ["dim_x"] + + def __init__(self, Z, D, xp, unfixed_idx, dim_x): + self.Z = Z + self.D = D + self.xp = xp + self.unfixed_idx = unfixed_idx + self.dim_x = dim_x + + @jit + def __call__(self, x_reduced): + """Recover the full state vector from the reduced optimization vector.""" + dx = put(jnp.zeros(self.dim_x), self.unfixed_idx, self.Z @ x_reduced) + x_full = self.D * (self.xp + dx) + return jnp.atleast_1d(jnp.squeeze(x_full)) + + def softmax(arr, alpha): """JAX softmax implementation. diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 8d4e480161..2ee6645742 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1020,21 +1020,6 @@ def jvp_unscaled(self, v, x, constants=None): jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="unscaled") return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) - @functools.partial(jit, static_argnames=("self", "op")) - def _jvp_f(self, xf, dc, constants, op): - Fx = getattr(self._constraint, "jac_" + op)(xf, constants) - # TODO: replace with self._unfixed_idx_mat? - Fx_reduced = Fx @ jnp.diag(self._D)[:, self._unfixed_idx] @ self._Z - Fc = Fx @ (self._dxdc @ dc) - Fxh = Fx_reduced - cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) - uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) - sf += sf[-1] # add a tiny bit of regularization - sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) - Fxh_inv = vtf.T @ (sfi[..., None] * uf.T) - return Fxh_inv @ Fc - - @functools.partial(jit, static_argnames=("self", "op")) def _jvp(self, v, xf, xg, constants, op): # we're replacing stuff like this with jvps # Fx_reduced = Fx[:, unfixed_idx] @ Z # noqa: E800 @@ -1047,7 +1032,17 @@ def _jvp(self, v, xf, xg, constants, op): # want jvp_f to only get parts from equilibrium, not other things vs = jnp.split(v, np.cumsum(self._dimc_per_thing)) # this is Fx_reduced_inv @ Fc - dfdc = self._jvp_f(xf, vs[self._eq_idx], constants[1], op) + dfdc = _proximal_jvp_f_pure( + self._constraint, + xf, + constants[1], + vs[self._eq_idx], + self._unfixed_idx, + self._Z, + self._D, + self._dxdc, + op, + ) # broadcasting against multiple things dfdcs = [jnp.zeros(dim) for dim in self._dimc_per_thing] dfdcs[self._eq_idx] = dfdc @@ -1069,27 +1064,7 @@ def _jvp(self, v, xf, xg, constants, op): else: # deriv_mode == "blocked" vgs = jnp.split(tangent, np.cumsum(self._dimx_per_thing)) xgs = jnp.split(xg, np.cumsum(self._dimx_per_thing)) - out = [] - for obj, const in zip( - self._objective.objectives, self._objective.constants - ): - xi = [x for x, t in zip(xgs, self._objective.things) if t in obj.things] - vi = [v for v, t in zip(vgs, self._objective.things) if t in obj.things] - if obj._deriv_mode == "rev": - # obj might now allow fwd mode, so compute full rev mode jacobian - # and do matmul manually. This is slightly inefficient, but usually - # when rev mode is used, dim_f <<< dim_x, so its not too bad. - Ji = getattr(obj, "jac_" + op)(*xi, const) - outi = jnp.array([Jii @ vii.T for Jii, vii in zip(Ji, vi)]).sum( - axis=0 - ) - out.append(outi) - else: - outi = getattr(obj, "jvp_" + op)( - [_vi for _vi in vi], xi, constants=const - ).T - out.append(outi) - out = jnp.concatenate(out) + out = _proximal_jvp_blocked_pure(self._objective, vgs, xgs, op) return -out @property @@ -1100,3 +1075,48 @@ def constants(self): def __getattr__(self, name): """For other attributes we defer to the base objective.""" return getattr(self._objective, name) + + +# in ProximalProjection we have an explicit state that we keep track of (and add +# to as we go) meaning if we jit anything with self static it doesn't update +# correctly, while if we leave self unstatic then it recompiles every time because +# the pytree structure of ProximalProjection is changing. To get around that we +# define these helper functions that are stateless so we can safely jit them + + +@functools.partial(jit, static_argnames=["op"]) +def _proximal_jvp_f_pure(constraint, xf, constants, dc, unfixed_idx, Z, D, dxdc, op): + Fx = getattr(constraint, "jac_" + op)(xf, constants) + Fx_reduced = Fx @ jnp.diag(D)[:, unfixed_idx] @ Z + Fc = Fx @ (dxdc @ dc) + Fxh = Fx_reduced + cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) + uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) + sf += sf[-1] # add a tiny bit of regularization + sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) + Fxh_inv = vtf.T @ (sfi[..., None] * uf.T) + return Fxh_inv @ Fc + + +@functools.partial(jit, static_argnames=["op"]) +def _proximal_jvp_blocked_pure(objective, vgs, xgs, op): + out = [] + for k, (obj, const) in enumerate(zip(objective.objectives, objective.constants)): + thing_idx = objective._things_per_objective_idx[k] + xi = [xgs[i] for i in thing_idx] + vi = [vgs[i] for i in thing_idx] + assert len(xi) > 0 + assert len(vi) > 0 + assert len(xi) == len(vi) + if obj._deriv_mode == "rev": + # obj might not allow fwd mode, so compute full rev mode jacobian + # and do matmul manually. This is slightly inefficient, but usually + # when rev mode is used, dim_f <<< dim_x, so its not too bad. + Ji = getattr(obj, "jac_" + op)(*xi, constants=const) + outi = jnp.array([Jii @ vii.T for Jii, vii in zip(Ji, vi)]).sum(axis=0) + out.append(outi) + else: + outi = getattr(obj, "jvp_" + op)([_vi for _vi in vi], xi, constants=const).T + out.append(outi) + out = jnp.concatenate(out) + return out diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index a95980599a..ff566c6416 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -640,7 +640,7 @@ def test_plasma_vessel_distance(self): surface_fixed=True, ) obj.build() - d = obj.compute_unscaled(*obj.xs(eq, surface)) + d = obj.compute_unscaled(*obj.xs(eq)) assert d.size == obj.dim_f assert abs(d.min() - (a_s - a_p)) < 1e-14 assert abs(d.max() - (a_s - a_p)) < surf_grid.spacing[0, 1] * a_p @@ -1535,7 +1535,7 @@ def test_boundary_error_print(capsys): obj = VacuumBoundaryError(eq, coilset, field_grid=coil_grid) obj.build() - f = np.abs(obj.compute_unscaled(*obj.xs(eq))) + f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) n = len(f) // 2 f1 = f[:n] f2 = f[n:] @@ -1610,7 +1610,7 @@ def test_boundary_error_print(capsys): obj = BoundaryError(eq, coilset, field_grid=coil_grid) obj.build() - f = np.abs(obj.compute_unscaled(*obj.xs(eq))) + f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) n = len(f) // 2 f1 = f[:n] f2 = f[n:] @@ -1686,7 +1686,7 @@ def test_boundary_error_print(capsys): obj = BoundaryError(eq, coilset, field_grid=coil_grid) obj.build() - f = np.abs(obj.compute_unscaled(*obj.xs(eq))) + f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) n = len(f) // 3 f1 = f[:n] f2 = f[n : 2 * n]