From d8fe0bb242d043ef5b0139f2904cfc3ca6e706b3 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 4 Jun 2024 16:14:37 -0400 Subject: [PATCH 01/24] Remove jit method of objective, directly compile methods --- desc/compute/utils.py | 2 + desc/objectives/linear_objectives.py | 8 +- desc/objectives/objective_funs.py | 315 +++++++++++++++------------ 3 files changed, 185 insertions(+), 140 deletions(-) diff --git a/desc/compute/utils.py b/desc/compute/utils.py index b65b9365a9..c1cc580671 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -344,6 +344,8 @@ def get_transforms(keys, obj, grid, jitable=False, **kwargs): basis = getattr(obj, c + "_basis") # first check if we already have a transform with a compatible basis for transform in transforms.values(): + if jitable: # re-using transforms doesn't work under jit, so skip + continue if basis.equiv(getattr(transform, "basis", None)): ders = np.unique( np.vstack([derivs[c], transform.derivatives]), axis=0 diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index b20317e3aa..ea91b4bcb6 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -38,8 +38,8 @@ def update_target(self, thing): assert len(new_target) == len(self.target) self.target = new_target self._target_from_user = self.target # in case the Objective is re-built - if self._use_jit: - self.jit() + # if self._use_jit: + # self.jit() def _parse_target_from_user( self, target_from_user, default_target, default_bounds, idx @@ -231,8 +231,8 @@ def update_target(self, thing): """ self.target = self.compute(thing.params_dict) - if self._use_jit: - self.jit() + # if self._use_jit: + # self.jit() class BoundaryRSelfConsistency(_Objective): diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 0a98bb0a28..6211a7cc18 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1,7 +1,6 @@ """Base classes for objectives.""" from abc import ABC, abstractmethod -from functools import partial import numpy as np @@ -63,90 +62,47 @@ def _set_derivatives(self): self._deriv_mode = "batched" else: self._deriv_mode = "blocked" - if self._deriv_mode in {"batched", "looped", "blocked"}: - self._grad = Derivative(self.compute_scalar, mode="grad") - self._hess = Derivative(self.compute_scalar, mode="hess") - if self._deriv_mode == "batched": - self._jac_scaled = Derivative(self.compute_scaled, mode="fwd") - self._jac_scaled_error = Derivative(self.compute_scaled_error, mode="fwd") - self._jac_unscaled = Derivative(self.compute_unscaled, mode="fwd") - if self._deriv_mode == "looped": - self._jac_scaled = Derivative(self.compute_scaled, mode="looped") - self._jac_scaled_error = Derivative( - self.compute_scaled_error, mode="looped" - ) - self._jac_unscaled = Derivative(self.compute_unscaled, mode="looped") - if self._deriv_mode == "blocked": - # could also do something similar for grad and hess, but probably not - # worth it. grad is already super cheap to eval all at once, and blocked - # hess would only be block diag which may miss important interactions. - - def jac_(op, x, constants=None): - if constants is None: - constants = self.constants - xs_splits = np.cumsum([t.dim_x for t in self.things]) - xs = jnp.split(x, xs_splits) - J = [] - for obj, const in zip(self.objectives, constants): - # get the xs that go to that objective - xi = [x for x, t in zip(xs, self.things) if t in obj.things] - Ji_ = getattr(obj, op)( - *xi, constants=const - ) # jac wrt to just those things - Ji = [] # jac wrt all things - for thing in self.things: - if thing in obj.things: - i = obj.things.index(thing) - Ji += [Ji_[i]] - else: - Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] - Ji = jnp.hstack(Ji) - J += [Ji] - return jnp.vstack(J) - - self._jac_scaled = partial(jac_, "jac_scaled") - self._jac_scaled_error = partial(jac_, "jac_scaled_error") - self._jac_unscaled = partial(jac_, "jac_unscaled") - - def jit(self): # noqa: C901 - """Apply JIT to compute methods, or re-apply after updating self.""" - # can't loop here because del doesn't work on getattr - # main idea is that when jitting a method, jax replaces that method - # with a CompiledFunction object, with self compiled in. To re-jit - # (ie, after updating attributes of self), we just need to delete the jax - # CompiledFunction object, which will then leave the raw method in its place, - # and then jit the raw method with the new self - - self._use_jit = True - - methods = [ - "compute_scaled", - "compute_scaled_error", - "compute_unscaled", - "compute_scalar", - "jac_scaled", - "jac_scaled_error", - "jac_unscaled", - "hess", - "grad", - "jvp_scaled", - "jvp_scaled_error", - "jvp_unscaled", - "vjp_scaled", - "vjp_scaled_error", - "vjp_unscaled", - ] - - for method in methods: - try: - delattr(self, method) - except AttributeError: - pass - setattr(self, method, jit(getattr(self, method))) - - for obj in self._objectives: - if obj._use_jit: - obj.jit() + + # TODO: figure out how to not use jit if user requests + # def jit(self): # noqa: C901 + # """Apply JIT to compute methods, or re-apply after updating self.""" + # # can't loop here because del doesn't work on getattr + # # main idea is that when jitting a method, jax replaces that method + # # with a CompiledFunction object, with self compiled in. To re-jit + # # (ie, after updating attributes of self), we just need to delete the jax + # # CompiledFunction object, which will then leave the raw method in its place, + # # and then jit the raw method with the new self + + # self._use_jit = True + + # methods = [ + # "compute_scaled", + # "compute_scaled_error", + # "compute_unscaled", + # "compute_scalar", + # "jac_scaled", + # "jac_scaled_error", + # "jac_unscaled", + # "hess", + # "grad", + # "jvp_scaled", + # "jvp_scaled_error", + # "jvp_unscaled", + # "vjp_scaled", + # "vjp_scaled_error", + # "vjp_unscaled", + # ] + + # for method in methods: + # try: + # delattr(self, method) + # except AttributeError: + # pass + # setattr(self, method, jit(getattr(self, method))) + + # for obj in self._objectives: + # if obj._use_jit: + # obj.jit() def build(self, use_jit=None, verbose=1): """Build the objective. @@ -178,11 +134,19 @@ def build(self, use_jit=None, verbose=1): self._scalar = False self._set_derivatives() - if self.use_jit: - self.jit() + # if self.use_jit: + # self.jit() self._set_things() + # this is needed to know which "thing" goes with which sub-objective, + # ie objectives[i].things == [things[k] for k in things_per_objective_idx[i]] + self._things_per_objective_idx = [] + for obj in self.objectives: + self._things_per_objective_idx.append( + [i for i, t in enumerate(self.things) if t in obj.things] + ) + self._built = True timer.stop("Objective build") if verbose > 1: @@ -235,6 +199,7 @@ def flatten(things): self._unflatten = unflatten self._flatten = flatten + @jit def compute_unscaled(self, x, constants=None): """Compute the raw value of the objective function. @@ -262,6 +227,7 @@ def compute_unscaled(self, x, constants=None): ) return f + @jit def compute_scaled(self, x, constants=None): """Compute the objective function and apply weighting and normalization. @@ -289,6 +255,7 @@ def compute_scaled(self, x, constants=None): ) return f + @jit def compute_scaled_error(self, x, constants=None): """Compute and apply the target/bounds, weighting, and normalization. @@ -316,6 +283,7 @@ def compute_scaled_error(self, x, constants=None): ) return f + @jit def compute_scalar(self, x, constants=None): """Compute the sum of squares error. @@ -395,8 +363,8 @@ def unpack_state(self, x, per_objective=True): params = self._unflatten(params) # this filters out the params of things that are unused by each objective params = [ - [par for par, thing in zip(param, self.things) if thing in obj.things] - for param, obj in zip(params, self.objectives) + [param[i] for i in idx] + for param, idx in zip(params, self._things_per_objective_idx) ] return params @@ -408,35 +376,94 @@ def x(self, *things): xs = [t.pack_params(t.params_dict) for t in things] return jnp.concatenate(xs) + @jit def grad(self, x, constants=None): """Compute gradient vector of self.compute_scalar wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_1d(self._grad(x, constants).squeeze()) + return jnp.atleast_1d( + Derivative(self.compute_scalar, mode="grad")(x, constants).squeeze() + ) + @jit def hess(self, x, constants=None): """Compute Hessian matrix of self.compute_scalar wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._hess(x, constants).squeeze()) + return jnp.atleast_2d( + Derivative(self.compute_scalar, mode="hess")(x, constants).squeeze() + ) + + def _jac(self, op, x, constants=None): + # could also do something similar for grad and hess, but probably not + # worth it. grad is already super cheap to eval all at once, and blocked + # hess would only be block diag which may miss important interactions. + if constants is None: + constants = self.constants + xs_splits = np.cumsum([t.dim_x for t in self.things]) + xs = jnp.split(x, xs_splits) + J = [] + for k, (obj, const) in enumerate(zip(self.objectives, constants)): + # get the xs that go to that objective + xi = [xs[i] for i in self._things_per_objective_idx[k]] + Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt to just those things + Ji = [] # jac wrt all things + for i, (thing, idx) in enumerate( + zip(self.things, self._things_per_objective_idx) + ): + if i in idx: + Ji += [Ji_[idx.index(i)]] + else: + Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] + Ji = jnp.hstack(Ji) + J += [Ji] + return jnp.vstack(J) + + @jit def jac_scaled(self, x, constants=None): """Compute Jacobian matrix of self.compute_scaled wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._jac_scaled(x, constants).squeeze()) + if self._deriv_mode == "batched": + J = Derivative(self.compute_scaled, mode="fwd")(x, constants) + if self._deriv_mode == "looped": + J = Derivative(self.compute_scaled, mode="looped")(x, constants) + if self._deriv_mode == "blocked": + J = self._jac("jac_scaled", x, constants) + + return jnp.atleast_2d(J.squeeze()) + + @jit def jac_scaled_error(self, x, constants=None): """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._jac_scaled_error(x, constants).squeeze()) + if self._deriv_mode == "batched": + J = Derivative(self.compute_scaled_error, mode="fwd")(x, constants) + if self._deriv_mode == "looped": + J = Derivative(self.compute_scaled_error, mode="looped")(x, constants) + if self._deriv_mode == "blocked": + J = self._jac("jac_scaled_error", x, constants) + + return jnp.atleast_2d(J.squeeze()) + + @jit def jac_unscaled(self, x, constants=None): """Compute Jacobian matrix of self.compute_unscaled wrt x.""" if constants is None: constants = self.constants - return jnp.atleast_2d(self._jac_unscaled(x, constants).squeeze()) + + if self._deriv_mode == "batched": + J = Derivative(self.compute_unscaled, mode="fwd")(x, constants) + if self._deriv_mode == "looped": + J = Derivative(self.compute_unscaled, mode="looped")(x, constants) + if self._deriv_mode == "blocked": + J = self._jac("jac_unscaled", x, constants) + + return jnp.atleast_2d(J.squeeze()) def _jvp(self, v, x, constants=None, op="compute_scaled"): v = v if isinstance(v, (tuple, list)) else (v,) @@ -456,6 +483,7 @@ def _jvp(self, v, x, constants=None, op="compute_scaled"): else: raise NotImplementedError("Cannot compute JVP higher than 3rd order.") + @jit def jvp_scaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled. @@ -472,6 +500,7 @@ def jvp_scaled(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled") + @jit def jvp_scaled_error(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled_error. @@ -488,6 +517,7 @@ def jvp_scaled_error(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled_error") + @jit def jvp_unscaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_unscaled. @@ -508,6 +538,7 @@ def _vjp(self, v, x, constants=None, op="compute_scaled"): fun = lambda x: getattr(self, op)(x, constants) return Derivative.compute_vjp(fun, 0, v, x) + @jit def vjp_scaled(self, v, x, constants=None): """Compute vector-Jacobian product of self.compute_scaled. @@ -523,6 +554,7 @@ def vjp_scaled(self, v, x, constants=None): """ return self._vjp(v, x, constants, "compute_scaled") + @jit def vjp_scaled_error(self, v, x, constants=None): """Compute vector-Jacobian product of self.compute_scaled_error. @@ -538,6 +570,7 @@ def vjp_scaled_error(self, v, x, constants=None): """ return self._vjp(v, x, constants, "compute_scaled_error") + @jit def vjp_unscaled(self, v, x, constants=None): """Compute vector-Jacobian product of self.compute_unscaled. @@ -815,10 +848,6 @@ def __init__( def _set_derivatives(self): """Set up derivatives of the objective wrt each argument.""" - argnums = tuple(range(len(self.things))) - # derivatives return tuple, one for each thing - self._grad = Derivative(self.compute_scalar, argnums, mode="grad") - self._hess = Derivative(self.compute_scalar, argnums, mode="hess") if self._deriv_mode == "auto": # choose based on shape of jacobian. fwd mode is more memory efficient # so we prefer that unless the jacobian is really wide @@ -827,38 +856,29 @@ def _set_derivatives(self): if self.dim_f >= 0.5 * sum(t.dim_x for t in self.things) else "rev" ) - self._jac_scaled = Derivative( - self.compute_scaled, argnums, mode=self._deriv_mode - ) - self._jac_scaled_error = Derivative( - self.compute_scaled_error, argnums, mode=self._deriv_mode - ) - self._jac_unscaled = Derivative( - self.compute_unscaled, argnums, mode=self._deriv_mode - ) - def jit(self): # noqa: C901 - """Apply JIT to compute methods, or re-apply after updating self.""" - self._use_jit = True - - methods = [ - "compute_scaled", - "compute_scaled_error", - "compute_unscaled", - "compute_scalar", - "jac_scaled", - "jac_scaled_error", - "jac_unscaled", - "hess", - "grad", - ] - - for method in methods: - try: - delattr(self, method) - except AttributeError: - pass - setattr(self, method, jit(getattr(self, method))) + # def jit(self): # noqa: C901 + # """Apply JIT to compute methods, or re-apply after updating self.""" + # self._use_jit = True + + # methods = [ + # "compute_scaled", + # "compute_scaled_error", + # "compute_unscaled", + # "compute_scalar", + # "jac_scaled", + # "jac_scaled_error", + # "jac_unscaled", + # "hess", + # "grad", + # ] + + # for method in methods: + # try: + # delattr(self, method) + # except AttributeError: + # pass + # setattr(self, method, jit(getattr(self, method))) def _check_dimensions(self): """Check that len(target) = len(bounds) = len(weight) = dim_f.""" @@ -912,8 +932,8 @@ def build(self, use_jit=True, verbose=1): if use_jit is not None: self._use_jit = use_jit - if self._use_jit: - self.jit() + # if self._use_jit: + # self.jit() self._built = True @@ -930,6 +950,7 @@ def _maybe_array_to_params(self, *args): argsout += (arg,) return argsout + @jit def compute_unscaled(self, *args, **kwargs): """Compute the raw value of the objective.""" args = self._maybe_array_to_params(*args) @@ -938,6 +959,7 @@ def compute_unscaled(self, *args, **kwargs): f = self._loss_function(f) return jnp.atleast_1d(f) + @jit def compute_scaled(self, *args, **kwargs): """Compute and apply weighting and normalization.""" args = self._maybe_array_to_params(*args) @@ -946,6 +968,7 @@ def compute_scaled(self, *args, **kwargs): f = self._loss_function(f) return jnp.atleast_1d(self._scale(f, **kwargs)) + @jit def compute_scaled_error(self, *args, **kwargs): """Compute and apply the target/bounds, weighting, and normalization.""" args = self._maybe_array_to_params(*args) @@ -988,6 +1011,7 @@ def _scale(self, f, *args, **kwargs): f_norm = jnp.atleast_1d(f) / self.normalization # normalization return f_norm * w * self.weight + @jit def compute_scalar(self, *args, **kwargs): """Compute the scalar form of the objective.""" if self.scalar: @@ -996,25 +1020,41 @@ def compute_scalar(self, *args, **kwargs): f = jnp.sum(self.compute_scaled_error(*args, **kwargs) ** 2) / 2 return f.squeeze() + @jit def grad(self, *args, **kwargs): """Compute gradient vector of self.compute_scalar wrt x.""" - return self._grad(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scalar, argnums, mode="grad")(*args, **kwargs) + @jit def hess(self, *args, **kwargs): """Compute Hessian matrix of self.compute_scalar wrt x.""" - return self._hess(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scalar, argnums, mode="hess")(*args, **kwargs) + @jit def jac_scaled(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_scaled wrt x.""" - return self._jac_scaled(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scaled, argnums, mode=self._deriv_mode)( + *args, **kwargs + ) + @jit def jac_scaled_error(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" - return self._jac_scaled_error(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_scaled_error, argnums, mode=self._deriv_mode)( + *args, **kwargs + ) + @jit def jac_unscaled(self, *args, **kwargs): """Compute Jacobian matrix of self.compute_unscaled wrt x.""" - return self._jac_unscaled(*args, **kwargs) + argnums = tuple(range(len(self.things))) + return Derivative(self.compute_unscaled, argnums, mode=self._deriv_mode)( + *args, **kwargs + ) def _jvp(self, v, x, constants=None, op="compute_scaled"): v = v if isinstance(v, (tuple, list)) else (v,) @@ -1026,6 +1066,7 @@ def _jvp(self, v, x, constants=None, op="compute_scaled"): sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" return jnp.vectorize(jvpfun, signature=sig)(*v) + @jit def jvp_scaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled. @@ -1041,6 +1082,7 @@ def jvp_scaled(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled") + @jit def jvp_scaled_error(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_scaled_error. @@ -1056,6 +1098,7 @@ def jvp_scaled_error(self, v, x, constants=None): """ return self._jvp(v, x, constants, "compute_scaled_error") + @jit def jvp_unscaled(self, v, x, constants=None): """Compute Jacobian-vector product of self.compute_unscaled. From f99c3d182d68cdcf3bb00e5fef3adb3c4f39358a Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 4 Jun 2024 17:52:34 -0400 Subject: [PATCH 02/24] Fix incorrect indexing --- desc/objectives/objective_funs.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 6211a7cc18..dedb019687 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -406,14 +406,13 @@ def _jac(self, op, x, constants=None): J = [] for k, (obj, const) in enumerate(zip(self.objectives, constants)): # get the xs that go to that objective - xi = [xs[i] for i in self._things_per_objective_idx[k]] + thing_idx = self._things_per_objective_idx[k] + xi = [xs[i] for i in thing_idx] Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt to just those things Ji = [] # jac wrt all things - for i, (thing, idx) in enumerate( - zip(self.things, self._things_per_objective_idx) - ): - if i in idx: - Ji += [Ji_[idx.index(i)]] + for i, thing in enumerate(self.things): + if i in thing_idx: + Ji += [Ji_[thing_idx.index(i)]] else: Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] Ji = jnp.hstack(Ji) From e37eb35c8a71c5de089888f6d440fa8ddf0fecaf Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 25 Jun 2024 19:36:46 -0400 Subject: [PATCH 03/24] Add unjit method --- desc/objectives/linear_objectives.py | 8 +- desc/objectives/objective_funs.py | 120 ++++++++++++--------------- 2 files changed, 58 insertions(+), 70 deletions(-) diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index ea91b4bcb6..50479542cd 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -38,8 +38,8 @@ def update_target(self, thing): assert len(new_target) == len(self.target) self.target = new_target self._target_from_user = self.target # in case the Objective is re-built - # if self._use_jit: - # self.jit() + if not self._use_jit: + self._unjit() def _parse_target_from_user( self, target_from_user, default_target, default_bounds, idx @@ -231,8 +231,8 @@ def update_target(self, thing): """ self.target = self.compute(thing.params_dict) - # if self._use_jit: - # self.jit() + if not self._use_jit: + self._unjit() class BoundaryRSelfConsistency(_Objective): diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 930b8e000d..68e71731dc 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -1,5 +1,6 @@ """Base classes for objectives.""" +import functools from abc import ABC, abstractmethod import numpy as np @@ -63,46 +64,32 @@ def _set_derivatives(self): else: self._deriv_mode = "blocked" - # TODO: figure out how to not use jit if user requests - # def jit(self): # noqa: C901 - # """Apply JIT to compute methods, or re-apply after updating self.""" - # # can't loop here because del doesn't work on getattr - # # main idea is that when jitting a method, jax replaces that method - # # with a CompiledFunction object, with self compiled in. To re-jit - # # (ie, after updating attributes of self), we just need to delete the jax - # # CompiledFunction object, which will then leave the raw method in its place, - # # and then jit the raw method with the new self - - # self._use_jit = True - - # methods = [ - # "compute_scaled", - # "compute_scaled_error", - # "compute_unscaled", - # "compute_scalar", - # "jac_scaled", - # "jac_scaled_error", - # "jac_unscaled", - # "hess", - # "grad", - # "jvp_scaled", - # "jvp_scaled_error", - # "jvp_unscaled", - # "vjp_scaled", - # "vjp_scaled_error", - # "vjp_unscaled", - # ] - - # for method in methods: - # try: - # delattr(self, method) - # except AttributeError: - # pass - # setattr(self, method, jit(getattr(self, method))) - - # for obj in self._objectives: - # if obj._use_jit: - # obj.jit() + def _unjit(self): + """Remove jit compiled methods.""" + methods = [ + "compute_scaled", + "compute_scaled_error", + "compute_unscaled", + "compute_scalar", + "jac_scaled", + "jac_scaled_error", + "jac_unscaled", + "hess", + "grad", + "jvp_scaled", + "jvp_scaled_error", + "jvp_unscaled", + "vjp_scaled", + "vjp_scaled_error", + "vjp_unscaled", + ] + for method in methods: + try: + setattr( + self, method, functools.partial(getattr(self, method)._fun, self) + ) + except AttributeError: + pass def build(self, use_jit=None, verbose=1): """Build the objective. @@ -134,8 +121,8 @@ def build(self, use_jit=None, verbose=1): self._scalar = False self._set_derivatives() - # if self.use_jit: - # self.jit() + if not self.use_jit: + self._unjit() self._set_things() @@ -856,28 +843,29 @@ def _set_derivatives(self): else "rev" ) - # def jit(self): # noqa: C901 - # """Apply JIT to compute methods, or re-apply after updating self.""" - # self._use_jit = True - - # methods = [ - # "compute_scaled", - # "compute_scaled_error", - # "compute_unscaled", - # "compute_scalar", - # "jac_scaled", - # "jac_scaled_error", - # "jac_unscaled", - # "hess", - # "grad", - # ] - - # for method in methods: - # try: - # delattr(self, method) - # except AttributeError: - # pass - # setattr(self, method, jit(getattr(self, method))) + def _unjit(self): + """Remove jit compiled methods.""" + methods = [ + "compute_scaled", + "compute_scaled_error", + "compute_unscaled", + "compute_scalar", + "jac_scaled", + "jac_scaled_error", + "jac_unscaled", + "jvp_scaled", + "jvp_scaled_error", + "jvp_unscaled", + "hess", + "grad", + ] + for method in methods: + try: + setattr( + self, method, functools.partial(getattr(self, method)._fun, self) + ) + except AttributeError: + pass def _check_dimensions(self): """Check that len(target) = len(bounds) = len(weight) = dim_f.""" @@ -931,8 +919,8 @@ def build(self, use_jit=True, verbose=1): if use_jit is not None: self._use_jit = use_jit - # if self._use_jit: - # self.jit() + if not self._use_jit: + self._unjit() self._built = True From 8980654899a7c8a5907d6aab9f90442899a75cbb Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 25 Jun 2024 19:37:54 -0400 Subject: [PATCH 04/24] Add some asserts to catch weird edge cases --- desc/objectives/objective_funs.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 68e71731dc..b8bcd00a56 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -126,14 +126,6 @@ def build(self, use_jit=None, verbose=1): self._set_things() - # this is needed to know which "thing" goes with which sub-objective, - # ie objectives[i].things == [things[k] for k in things_per_objective_idx[i]] - self._things_per_objective_idx = [] - for obj in self.objectives: - self._things_per_objective_idx.append( - [i for i, t in enumerate(self.things) if t in obj.things] - ) - self._built = True timer.stop("Objective build") if verbose > 1: @@ -169,6 +161,14 @@ def _set_things(self, things=None): ) unique_, inds_ = unique_list(flat_) + # this is needed to know which "thing" goes with which sub-objective, + # ie objectives[i].things == [things[k] for k in things_per_objective_idx[i]] + self._things_per_objective_idx = [] + for obj in self.objectives: + self._things_per_objective_idx.append( + [i for i, t in enumerate(unique_) if t in obj.things] + ) + def unflatten(unique): assert len(unique) == len(unique_) flat = [unique[i] for i in inds_] @@ -206,6 +206,7 @@ def compute_unscaled(self, x, constants=None): params = self.unpack_state(x) if constants is None: constants = self.constants + assert len(params) == len(constants) == len(self.objectives) f = jnp.concatenate( [ obj.compute_unscaled(*par, constants=const) @@ -234,6 +235,7 @@ def compute_scaled(self, x, constants=None): params = self.unpack_state(x) if constants is None: constants = self.constants + assert len(params) == len(constants) == len(self.objectives) f = jnp.concatenate( [ obj.compute_scaled(*par, constants=const) @@ -262,6 +264,7 @@ def compute_scaled_error(self, x, constants=None): params = self.unpack_state(x) if constants is None: constants = self.constants + assert len(params) == len(constants) == len(self.objectives) f = jnp.concatenate( [ obj.compute_scaled_error(*par, constants=const) @@ -309,6 +312,7 @@ def print_value(self, x, constants=None): f = jnp.sum(self.compute_scaled_error(x, constants=constants) ** 2) / 2 print("Total (sum of squares): {:10.3e}, ".format(f)) params = self.unpack_state(x) + assert len(params) == len(constants) == len(self.objectives) for par, obj, const in zip(params, self.objectives, constants): obj.print_value(*par, constants=const) return None @@ -344,11 +348,14 @@ def unpack_state(self, x, per_objective=True): xs_splits = np.cumsum([t.dim_x for t in self.things]) xs = jnp.split(x, xs_splits) + xs = xs[: len(self.things)] # jnp.split returns an empty array at the end + assert len(xs) == len(self.things) params = [t.unpack_params(xi) for t, xi in zip(self.things, xs)] if per_objective: # params is a list of lists of dicts, for each thing and for each objective params = self._unflatten(params) # this filters out the params of things that are unused by each objective + assert len(params) == len(self._things_per_objective_idx) params = [ [param[i] for i in idx] for param, idx in zip(params, self._things_per_objective_idx) @@ -359,6 +366,7 @@ def x(self, *things): """Return the full state vector from the Optimizable objects things.""" # TODO: also check resolution etc? things = things or self.things + assert len(things) == len(self.things) assert all([type(t1) is type(t2) for t1, t2 in zip(things, self.things)]) xs = [t.pack_params(t.params_dict) for t in things] return jnp.concatenate(xs) @@ -391,6 +399,7 @@ def _jac(self, op, x, constants=None): xs_splits = np.cumsum([t.dim_x for t in self.things]) xs = jnp.split(x, xs_splits) J = [] + assert len(self.objectives) == len(self.constants) for k, (obj, const) in enumerate(zip(self.objectives, constants)): # get the xs that go to that objective thing_idx = self._things_per_objective_idx[k] @@ -930,6 +939,7 @@ def compute(self, *args, **kwargs): def _maybe_array_to_params(self, *args): argsout = tuple() + assert len(args) == len(self.things) for arg, thing in zip(args, self.things): if isinstance(arg, (np.ndarray, jnp.ndarray)): argsout += (thing.unpack_params(arg),) @@ -1270,6 +1280,7 @@ def things(self, new): if not isinstance(new, (tuple, list)): new = [new] assert all(isinstance(x, Optimizable) for x in new) + assert len(new) == len(self.things) assert all(type(a) is type(b) for a, b in zip(new, self.things)) self._things = list(new) # can maybe improve this later to not rebuild if resolution is the same From 780b27f0bda85cbc8dc23dc4bf7bf9b7817cf8e7 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Tue, 6 Aug 2024 00:52:24 -0400 Subject: [PATCH 05/24] take p_newton out, add update_qr function --- desc/optimize/aug_lagrangian_ls.py | 14 +++++++-- desc/optimize/least_squares.py | 14 +++++++-- desc/optimize/tr_subproblems.py | 25 +++++----------- desc/optimize/utils.py | 48 +++++++++++++++++++++++++++++- 4 files changed, 79 insertions(+), 22 deletions(-) diff --git a/desc/optimize/aug_lagrangian_ls.py b/desc/optimize/aug_lagrangian_ls.py index 65c7de50a0..7caa1896f9 100644 --- a/desc/optimize/aug_lagrangian_ls.py +++ b/desc/optimize/aug_lagrangian_ls.py @@ -2,7 +2,7 @@ from scipy.optimize import NonlinearConstraint, OptimizeResult -from desc.backend import jnp +from desc.backend import jnp, qr from desc.utils import errorif, setdefault from .bound_utils import ( @@ -25,6 +25,7 @@ inequality_to_bounds, print_header_nonlinear, print_iteration_nonlinear, + solve_triangular_regularized, ) @@ -368,6 +369,15 @@ def lagjac(z, y, mu, *args): U, s, Vt = jnp.linalg.svd(J_a, full_matrices=False) elif tr_method == "cho": B_h = jnp.dot(J_a.T, J_a) + elif tr_method == "qr": + # try full newton step + tall = J_a.shape[0] >= J_a.shape[1] + if tall: + Q, R = qr(J_a, mode="economic") + p_newton = solve_triangular_regularized(R, -Q.T @ L_a) + else: + Q, R = qr(J_a.T, mode="economic") + p_newton = Q @ solve_triangular_regularized(R.T, L_a, lower=True) actual_reduction = -1 Lactual_reduction = -1 @@ -390,7 +400,7 @@ def lagjac(z, y, mu, *args): ) elif tr_method == "qr": step_h, hits_boundary, alpha = trust_region_step_exact_qr( - L_a, J_a, trust_radius, alpha + Q, R, p_newton, L_a, J_a, trust_radius, alpha ) step = d * step_h # Trust-region solution in the original space. diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index 0ed35a506a..ebd09af617 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -2,7 +2,7 @@ from scipy.optimize import OptimizeResult -from desc.backend import jnp +from desc.backend import jnp, qr from desc.utils import errorif, setdefault from .bound_utils import ( @@ -24,6 +24,7 @@ compute_jac_scale, print_header_nonlinear, print_iteration_nonlinear, + solve_triangular_regularized, ) @@ -268,6 +269,15 @@ def lsqtr( # noqa: C901 - FIXME: simplify this U, s, Vt = jnp.linalg.svd(J_a, full_matrices=False) elif tr_method == "cho": B_h = jnp.dot(J_a.T, J_a) + elif tr_method == "qr": + # try full newton step + tall = J_a.shape[0] >= J_a.shape[1] + if tall: + Q, R = qr(J_a, mode="economic") + p_newton = solve_triangular_regularized(R, -Q.T @ f_a) + else: + Q, R = qr(J_a.T, mode="economic") + p_newton = Q @ solve_triangular_regularized(R.T, f_a, lower=True) actual_reduction = -1 @@ -289,7 +299,7 @@ def lsqtr( # noqa: C901 - FIXME: simplify this ) elif tr_method == "qr": step_h, hits_boundary, alpha = trust_region_step_exact_qr( - f_a, J_a, trust_radius, alpha + Q, R, p_newton, f_a, J_a, trust_radius, alpha ) step = d * step_h # Trust-region solution in the original space. diff --git a/desc/optimize/tr_subproblems.py b/desc/optimize/tr_subproblems.py index c9149d2bb3..fe50fd4a34 100644 --- a/desc/optimize/tr_subproblems.py +++ b/desc/optimize/tr_subproblems.py @@ -14,7 +14,7 @@ ) from desc.utils import setdefault -from .utils import chol, solve_triangular_regularized +from .utils import chol, solve_triangular_regularized, update_qr_jax @jit @@ -378,7 +378,7 @@ def loop_body(state): @jit def trust_region_step_exact_qr( - f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10 + Q, R, p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10 ): """Solve a trust-region problem using a semi-exact method. @@ -414,14 +414,6 @@ def trust_region_step_exact_qr( Sometimes called Levenberg-Marquardt parameter. """ - # try full newton step - tall = J.shape[0] >= J.shape[1] - if tall: - Q, R = qr(J, mode="economic") - p_newton = solve_triangular_regularized(R, -Q.T @ f) - else: - Q, R = qr(J.T, mode="economic") - p_newton = Q @ solve_triangular_regularized(R.T, f, lower=True) def truefun(*_): return p_newton, False, 0.0 @@ -450,10 +442,9 @@ def loop_body(state): alpha, ) - Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) - # Ji is always tall since its padded by alpha*I - Q, R = qr(Ji, mode="economic") - p = solve_triangular_regularized(R, -Q.T @ fp) + Q1, R1 = update_qr_jax(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) + R1 = R1[: R1.shape[1], : R1.shape[1]] + p = solve_triangular_regularized(R1, -Q1.T @ fp) p_norm = jnp.linalg.norm(p) phi = p_norm - trust_radius alpha_upper = jnp.where(phi < 0, alpha, alpha_upper) @@ -474,9 +465,9 @@ def loop_body(state): alpha, *_ = while_loop( loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k) ) - Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) - Q, R = qr(Ji, mode="economic") - p = solve_triangular(R, -Q.T @ fp) + Q1, R1 = update_qr_jax(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) + R1 = R1[: R1.shape[1], : R1.shape[1]] + p = solve_triangular_regularized(R1, -Q1.T @ fp) # Make the norm of p equal to trust_radius; p is changed only slightly. # This is done to prevent p from lying outside the trust region diff --git a/desc/optimize/utils.py b/desc/optimize/utils.py index bbee5a22f7..cf079c3195 100644 --- a/desc/optimize/utils.py +++ b/desc/optimize/utils.py @@ -5,7 +5,7 @@ import numpy as np -from desc.backend import cond, jit, jnp, put, solve_triangular +from desc.backend import cond, fori_loop, jit, jnp, put, solve_triangular from desc.utils import Index @@ -551,3 +551,49 @@ def solve_triangular_regularized(R, b, lower=False): Rs = R * dri[:, None] b = dri * b return solve_triangular(Rs, b, unit_diagonal=True, lower=lower) + + +# TODO: add references to the docstrings +def _givens_jax(a, b): + """Compute Givens rotation matrix. + + Compute the Givens rotation matrix G2 that zeros out the second element + of a 2-vector. + G2*[a; b] = [r; 0] + where r = sqrt(a^2 + b^2) + G2 = [[c, -s], [s, c]] + """ + r = jnp.sqrt(a**2 + b**2) + c = a / r + s = -b / r + + G2 = jnp.array([[c, -s], [s, c]]) + return G2.astype(float) + + +@jit +def update_qr_jax(A, w, q, r): + """Update QR factorization with a diagonal matrix w at the bottom.""" + m, n = A.shape + Q = jnp.eye(m + n) + Q = Q.at[:m, :m].set(q) + + R = jnp.vstack([r, w]) + + def body_inner(i, jQR): + j, Q, R = jQR + i = m + j - i + a, b = R[i - 1, j], R[i, j] + G2 = _givens_jax(a, b) + R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])]) + Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T) + return j, Q, R + + def body(j, QR): + Q, R = QR + j, Q, R = fori_loop(0, m, body_inner, (j, Q, R)) + return Q, R + + Q, R = fori_loop(0, n, body, (Q, R)) + + return Q, R From 4c72ab52888cbd0b0665120d559e1f61326279ba Mon Sep 17 00:00:00 2001 From: YigitElma Date: Tue, 6 Aug 2024 12:42:38 -0400 Subject: [PATCH 06/24] add update_qr for economic --- desc/optimize/tr_subproblems.py | 16 ++++++++-------- desc/optimize/utils.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/desc/optimize/tr_subproblems.py b/desc/optimize/tr_subproblems.py index fe50fd4a34..20d874644a 100644 --- a/desc/optimize/tr_subproblems.py +++ b/desc/optimize/tr_subproblems.py @@ -14,7 +14,7 @@ ) from desc.utils import setdefault -from .utils import chol, solve_triangular_regularized, update_qr_jax +from .utils import chol, solve_triangular_regularized, update_qr_jax_eco @jit @@ -441,16 +441,15 @@ def loop_body(state): jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), alpha, ) + Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) - Q1, R1 = update_qr_jax(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) - R1 = R1[: R1.shape[1], : R1.shape[1]] - p = solve_triangular_regularized(R1, -Q1.T @ fp) + p = solve_triangular_regularized(R2, -Q2.T @ fp) p_norm = jnp.linalg.norm(p) phi = p_norm - trust_radius alpha_upper = jnp.where(phi < 0, alpha, alpha_upper) alpha_lower = jnp.where(phi > 0, alpha, alpha_lower) - q = solve_triangular_regularized(R.T, p, lower=True) + q = solve_triangular_regularized(R2.T, p, lower=True) q_norm = jnp.linalg.norm(q) alpha += (p_norm / q_norm) ** 2 * phi / trust_radius @@ -465,9 +464,10 @@ def loop_body(state): alpha, *_ = while_loop( loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k) ) - Q1, R1 = update_qr_jax(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) - R1 = R1[: R1.shape[1], : R1.shape[1]] - p = solve_triangular_regularized(R1, -Q1.T @ fp) + + Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) + + p = solve_triangular(R2, -Q2.T @ fp) # Make the norm of p equal to trust_radius; p is changed only slightly. # This is done to prevent p from lying outside the trust region diff --git a/desc/optimize/utils.py b/desc/optimize/utils.py index cf079c3195..f64cf169d5 100644 --- a/desc/optimize/utils.py +++ b/desc/optimize/utils.py @@ -597,3 +597,33 @@ def body(j, QR): Q, R = fori_loop(0, n, body, (Q, R)) return Q, R + + +@jit +def update_qr_jax_eco(A, w, q, r): + """Update QR factorization with a diagonal matrix w at the bottom.""" + m, n = A.shape + Q = jnp.eye(m + n) + Q = Q.at[:m, :m].set(q) + + R = jnp.vstack([r, w]) + + def body_inner(i, jQR): + j, Q, R = jQR + i = m + j - i + a, b = R[i - 1, j], R[i, j] + G2 = _givens_jax(a, b) + R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])]) + Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T) + return j, Q, R + + def body(j, QR): + Q, R = QR + j, Q, R = fori_loop(0, m, body_inner, (j, Q, R)) + return Q, R + + Q, R = fori_loop(0, n, body, (Q, R)) + Re = R.at[: R.shape[1], : R.shape[1]].get() + Qe = Q.at[:, : R.shape[1]].get() + + return Qe, Re From bdc048c437c4d34f9e4f16a3eadb916b60b74d14 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Tue, 6 Aug 2024 12:55:20 -0400 Subject: [PATCH 07/24] change initial qr to return full array for compatibility with update --- desc/optimize/least_squares.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index ebd09af617..91d453450f 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -273,11 +273,15 @@ def lsqtr( # noqa: C901 - FIXME: simplify this # try full newton step tall = J_a.shape[0] >= J_a.shape[1] if tall: - Q, R = qr(J_a, mode="economic") - p_newton = solve_triangular_regularized(R, -Q.T @ f_a) + Q, R = qr(J_a) + p_newton = solve_triangular_regularized( + R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ f_a + ) else: - Q, R = qr(J_a.T, mode="economic") - p_newton = Q @ solve_triangular_regularized(R.T, f_a, lower=True) + Q, R = qr(J_a.T) + p_newton = Q @ solve_triangular_regularized( + R[: R.shape[1], : R.shape[1]].T, f_a, lower=True + ) actual_reduction = -1 From 49ce7b226856ee76b1997dfff1e2a5f1df4d510f Mon Sep 17 00:00:00 2001 From: YigitElma Date: Tue, 6 Aug 2024 14:14:41 -0400 Subject: [PATCH 08/24] use economic qr both in jax and scipy versions --- desc/optimize/least_squares.py | 12 ++++-------- desc/optimize/utils.py | 10 ++++++---- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index 91d453450f..ebd09af617 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -273,15 +273,11 @@ def lsqtr( # noqa: C901 - FIXME: simplify this # try full newton step tall = J_a.shape[0] >= J_a.shape[1] if tall: - Q, R = qr(J_a) - p_newton = solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ f_a - ) + Q, R = qr(J_a, mode="economic") + p_newton = solve_triangular_regularized(R, -Q.T @ f_a) else: - Q, R = qr(J_a.T) - p_newton = Q @ solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]].T, f_a, lower=True - ) + Q, R = qr(J_a.T, mode="economic") + p_newton = Q @ solve_triangular_regularized(R.T, f_a, lower=True) actual_reduction = -1 diff --git a/desc/optimize/utils.py b/desc/optimize/utils.py index f64cf169d5..cd3d824415 100644 --- a/desc/optimize/utils.py +++ b/desc/optimize/utils.py @@ -603,14 +603,16 @@ def body(j, QR): def update_qr_jax_eco(A, w, q, r): """Update QR factorization with a diagonal matrix w at the bottom.""" m, n = A.shape - Q = jnp.eye(m + n) - Q = Q.at[:m, :m].set(q) + mr, nr = r.shape + Q = jnp.zeros([m + n, 2 * n]) + Q = Q.at[:m, :n].set(q) + Q = Q.at[-n:, -n:].set(jnp.eye(n)) R = jnp.vstack([r, w]) def body_inner(i, jQR): j, Q, R = jQR - i = m + j - i + i = n + j - i a, b = R[i - 1, j], R[i, j] G2 = _givens_jax(a, b) R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])]) @@ -619,7 +621,7 @@ def body_inner(i, jQR): def body(j, QR): Q, R = QR - j, Q, R = fori_loop(0, m, body_inner, (j, Q, R)) + j, Q, R = fori_loop(0, n, body_inner, (j, Q, R)) return Q, R Q, R = fori_loop(0, n, body, (Q, R)) From 28e391e1ec3adf20f0d1679a8f8e0b2668559eb2 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Tue, 6 Aug 2024 15:03:44 -0400 Subject: [PATCH 09/24] back to full QR --- desc/optimize/aug_lagrangian_ls.py | 13 +++++++++---- desc/optimize/least_squares.py | 13 +++++++++---- desc/optimize/utils.py | 10 ++++------ 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/desc/optimize/aug_lagrangian_ls.py b/desc/optimize/aug_lagrangian_ls.py index 7caa1896f9..fb222d776c 100644 --- a/desc/optimize/aug_lagrangian_ls.py +++ b/desc/optimize/aug_lagrangian_ls.py @@ -373,11 +373,16 @@ def lagjac(z, y, mu, *args): # try full newton step tall = J_a.shape[0] >= J_a.shape[1] if tall: - Q, R = qr(J_a, mode="economic") - p_newton = solve_triangular_regularized(R, -Q.T @ L_a) + Q, R = qr(J_a) + p_newton = solve_triangular_regularized( + R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ L_a + ) else: - Q, R = qr(J_a.T, mode="economic") - p_newton = Q @ solve_triangular_regularized(R.T, L_a, lower=True) + Q, R = qr(J_a.T) + p_newton = Q[:, : R.shape[1]] @ solve_triangular_regularized( + R[: R.shape[1], : R.shape[1]].T, L_a, lower=True + ) + Q, R = qr(J_a) actual_reduction = -1 Lactual_reduction = -1 diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index ebd09af617..a459e7f760 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -273,11 +273,16 @@ def lsqtr( # noqa: C901 - FIXME: simplify this # try full newton step tall = J_a.shape[0] >= J_a.shape[1] if tall: - Q, R = qr(J_a, mode="economic") - p_newton = solve_triangular_regularized(R, -Q.T @ f_a) + Q, R = qr(J_a) + p_newton = solve_triangular_regularized( + R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ f_a + ) else: - Q, R = qr(J_a.T, mode="economic") - p_newton = Q @ solve_triangular_regularized(R.T, f_a, lower=True) + Q, R = qr(J_a.T) + p_newton = Q[:, : R.shape[1]] @ solve_triangular_regularized( + R[: R.shape[1], : R.shape[1]].T, f_a, lower=True + ) + Q, R = qr(J_a) actual_reduction = -1 diff --git a/desc/optimize/utils.py b/desc/optimize/utils.py index cd3d824415..f64cf169d5 100644 --- a/desc/optimize/utils.py +++ b/desc/optimize/utils.py @@ -603,16 +603,14 @@ def body(j, QR): def update_qr_jax_eco(A, w, q, r): """Update QR factorization with a diagonal matrix w at the bottom.""" m, n = A.shape - mr, nr = r.shape - Q = jnp.zeros([m + n, 2 * n]) - Q = Q.at[:m, :n].set(q) - Q = Q.at[-n:, -n:].set(jnp.eye(n)) + Q = jnp.eye(m + n) + Q = Q.at[:m, :m].set(q) R = jnp.vstack([r, w]) def body_inner(i, jQR): j, Q, R = jQR - i = n + j - i + i = m + j - i a, b = R[i - 1, j], R[i, j] G2 = _givens_jax(a, b) R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])]) @@ -621,7 +619,7 @@ def body_inner(i, jQR): def body(j, QR): Q, R = QR - j, Q, R = fori_loop(0, n, body_inner, (j, Q, R)) + j, Q, R = fori_loop(0, m, body_inner, (j, Q, R)) return Q, R Q, R = fori_loop(0, n, body, (Q, R)) From c6c42c444e6159aea14cefe7bfeb0cb77d6fe2c1 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Thu, 8 Aug 2024 11:56:57 -0400 Subject: [PATCH 10/24] enforce very small numbers to be 0, replace custom Givens with Jax implementation of it --- desc/backend.py | 5 +++++ desc/optimize/utils.py | 18 ++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/desc/backend.py b/desc/backend.py index c26213b045..22c5c60c6f 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -76,6 +76,7 @@ repeat = jnp.repeat take = jnp.take scan = jax.lax.scan + rsqrt = jax.lax.rsqrt from jax import custom_jvp from jax.experimental.ode import odeint from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular @@ -811,3 +812,7 @@ def take( else: out = np.take(a, indices, axis, out, mode) return out + + def rsqrt(x): + """Reciprocal square root.""" + return 1 / np.sqrt(x) diff --git a/desc/optimize/utils.py b/desc/optimize/utils.py index f64cf169d5..8c31cdc1a2 100644 --- a/desc/optimize/utils.py +++ b/desc/optimize/utils.py @@ -5,7 +5,7 @@ import numpy as np -from desc.backend import cond, fori_loop, jit, jnp, put, solve_triangular +from desc.backend import cond, fori_loop, jit, jnp, put, rsqrt, solve_triangular from desc.utils import Index @@ -563,11 +563,14 @@ def _givens_jax(a, b): where r = sqrt(a^2 + b^2) G2 = [[c, -s], [s, c]] """ - r = jnp.sqrt(a**2 + b**2) - c = a / r - s = -b / r - - G2 = jnp.array([[c, -s], [s, c]]) + # Taken from jax._src.scipy.sparse.linalg._givens_rotation + b_zero = abs(b) == 0 + a_lt_b = abs(a) < abs(b) + t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a) + r = rsqrt(1 + abs(t) ** 2).astype(t.dtype) + cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r)) + sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t)) + G2 = jnp.array([[cs, -sn], [sn, cs]]) return G2.astype(float) @@ -595,6 +598,7 @@ def body(j, QR): return Q, R Q, R = fori_loop(0, n, body, (Q, R)) + R = jnp.where(jnp.abs(R) < 1e-10, 0, R) return Q, R @@ -623,6 +627,8 @@ def body(j, QR): return Q, R Q, R = fori_loop(0, n, body, (Q, R)) + R = jnp.where(jnp.abs(R) < 1e-10, 0, R) + Re = R.at[: R.shape[1], : R.shape[1]].get() Qe = Q.at[:, : R.shape[1]].get() From cfcb3937c1dfa3be4424fa31fe48a90b8dce39e7 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 12 Aug 2024 23:23:26 -0400 Subject: [PATCH 11/24] Fix hashing of partial objects --- desc/io/optimizable_io.py | 5 ++++- desc/optimize/_constraint_wrappers.py | 22 +++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/desc/io/optimizable_io.py b/desc/io/optimizable_io.py index 5a11d39245..554cdac070 100644 --- a/desc/io/optimizable_io.py +++ b/desc/io/optimizable_io.py @@ -1,6 +1,7 @@ """Functions and methods for saving and loading equilibria and other objects.""" import copy +import functools import os import pickle import pydoc @@ -86,7 +87,9 @@ def _unjittable(x): return any([_unjittable(y) for y in x.values()]) if hasattr(x, "dtype") and np.ndim(x) == 0: return np.issubdtype(x.dtype, np.bool_) or np.issubdtype(x.dtype, np.int_) - return isinstance(x, (str, types.FunctionType, bool, int, np.int_)) + return isinstance( + x, (str, types.FunctionType, functools.partial, bool, int, np.int_) + ) def _make_hashable(x): diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 93bf5385ae..31384f4987 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1,7 +1,5 @@ """Wrappers for doing STELLOPT/SIMSOPT like optimization.""" -import functools - import numpy as np from desc.backend import jit, jnp @@ -972,7 +970,7 @@ def jvp_scaled(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled") + jvpfun = lambda u: self._jvp_scaled(u, xf, xg, constants) return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) def jvp_scaled_error(self, v, x, constants=None): @@ -992,7 +990,7 @@ def jvp_scaled_error(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled_error") + jvpfun = lambda u: self._jvp_scaled_error(u, xf, xg, constants) return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) def jvp_unscaled(self, v, x, constants=None): @@ -1012,10 +1010,9 @@ def jvp_unscaled(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="unscaled") + jvpfun = lambda u: self._jvp_unscaled(u, xf, xg, constants) return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) - @functools.partial(jit, static_argnames=("self", "op")) def _jvp_f(self, xf, dc, constants, op): Fx = getattr(self._constraint, "jac_" + op)(xf, constants) Fx_reduced = Fx[:, self._unfixed_idx] @ self._Z @@ -1028,7 +1025,6 @@ def _jvp_f(self, xf, dc, constants, op): Fxh_inv = vtf.T @ (sfi[..., None] * uf.T) return Fxh_inv @ Fc - @functools.partial(jit, static_argnames=("self", "op")) def _jvp(self, v, xf, xg, constants, op): # we're replacing stuff like this with jvps # Fx_reduced = Fx[:, unfixed_idx] @ Z # noqa: E800 @@ -1086,6 +1082,18 @@ def _jvp(self, v, xf, xg, constants, op): out = jnp.concatenate(out) return -out + @jit + def _jvp_scaled(self, v, xf, xg, constants): + return self._jvp(v, xf, xg, constants, "scaled") + + @jit + def _jvp_scaled_error(self, v, xf, xg, constants): + return self._jvp(v, xf, xg, constants, "scaled_error") + + @jit + def _jvp_unscaled(self, v, xf, xg, constants): + return self._jvp(v, xf, xg, constants, "unscaled") + @property def constants(self): """list: constant parameters for each sub-objective.""" From 2b04c096d1d09e8bb376ebde1e23bc4d10b90875 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 12 Aug 2024 23:26:21 -0400 Subject: [PATCH 12/24] Better error message when using wrong things to get state vector --- desc/objectives/objective_funs.py | 37 ++++++++++++++++++++++++++++--- tests/test_objective_funs.py | 8 +++---- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index b42dc0040c..8e04112bb1 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -9,7 +9,14 @@ from desc.derivatives import Derivative from desc.io import IOAble from desc.optimizable import Optimizable -from desc.utils import Timer, flatten_list, is_broadcastable, setdefault, unique_list +from desc.utils import ( + Timer, + errorif, + flatten_list, + is_broadcastable, + setdefault, + unique_list, +) class ObjectiveFunction(IOAble): @@ -367,8 +374,19 @@ def x(self, *things): """Return the full state vector from the Optimizable objects things.""" # TODO: also check resolution etc? things = things or self.things - assert len(things) == len(self.things) - assert all([type(t1) is type(t2) for t1, t2 in zip(things, self.things)]) + errorif( + len(things) != len(self.things), + ValueError, + "Got the wrong number of things, " + f"expected {len(self.things)} got {len(things)}", + ) + for t1, t2 in zip(things, self.things): + errorif( + not isinstance(t1, type(t2)), + TypeError, + f"got incompatible types between things {type(t1)} " + f"and self.things {type(t2)}", + ) xs = [t.pack_params(t.params_dict) for t in things] return jnp.concatenate(xs) @@ -1188,6 +1206,19 @@ def print_value(self, *args, **kwargs): def xs(self, *things): """Return a tuple of args required by this objective from optimizable things.""" things = things or self.things + errorif( + len(things) != len(self.things), + ValueError, + "Got the wrong number of things, " + f"expected {len(self.things)} got {len(things)}", + ) + for t1, t2 in zip(things, self.things): + errorif( + not isinstance(t1, type(t2)), + TypeError, + f"got incompatible types between things {type(t1)} " + f"and self.things {type(t2)}", + ) return tuple([t.params_dict for t in things]) @property diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index d6f67a5ffc..d664ac356a 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -639,7 +639,7 @@ def test_plasma_vessel_distance(self): surface_fixed=True, ) obj.build() - d = obj.compute_unscaled(*obj.xs(eq, surface)) + d = obj.compute_unscaled(*obj.xs(eq)) assert d.size == obj.dim_f assert abs(d.min() - (a_s - a_p)) < 1e-14 assert abs(d.max() - (a_s - a_p)) < surf_grid.spacing[0, 1] * a_p @@ -1534,7 +1534,7 @@ def test_boundary_error_print(capsys): obj = VacuumBoundaryError(eq, coilset, field_grid=coil_grid) obj.build() - f = np.abs(obj.compute_unscaled(*obj.xs(eq))) + f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) n = len(f) // 2 f1 = f[:n] f2 = f[n:] @@ -1609,7 +1609,7 @@ def test_boundary_error_print(capsys): obj = BoundaryError(eq, coilset, field_grid=coil_grid) obj.build() - f = np.abs(obj.compute_unscaled(*obj.xs(eq))) + f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) n = len(f) // 2 f1 = f[:n] f2 = f[n:] @@ -1685,7 +1685,7 @@ def test_boundary_error_print(capsys): obj = BoundaryError(eq, coilset, field_grid=coil_grid) obj.build() - f = np.abs(obj.compute_unscaled(*obj.xs(eq))) + f = np.abs(obj.compute_unscaled(*obj.xs(eq, coilset))) n = len(f) // 3 f1 = f[:n] f2 = f[n : 2 * n] From 7b1553fe36ce226e9c426031d3d6f1003dd99a4f Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 12 Aug 2024 23:27:07 -0400 Subject: [PATCH 13/24] Don't use object identity comparison in blocked jacobian --- desc/optimize/_constraint_wrappers.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 31384f4987..f141485845 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1060,16 +1060,20 @@ def _jvp(self, v, xf, xg, constants, op): vgs = jnp.split(tangent, np.cumsum(self._dimx_per_thing)) xgs = jnp.split(xg, np.cumsum(self._dimx_per_thing)) out = [] - for obj, const in zip( - self._objective.objectives, self._objective.constants + for k, (obj, const) in enumerate( + zip(self._objective.objectives, self._objective.constants) ): - xi = [x for x, t in zip(xgs, self._objective.things) if t in obj.things] - vi = [v for v, t in zip(vgs, self._objective.things) if t in obj.things] + thing_idx = self._objective._things_per_objective_idx[k] + xi = [xgs[i] for i in thing_idx] + vi = [vgs[i] for i in thing_idx] + assert len(xi) > 0 + assert len(vi) > 0 + assert len(xi) == len(vi) if obj._deriv_mode == "rev": - # obj might now allow fwd mode, so compute full rev mode jacobian + # obj might not allow fwd mode, so compute full rev mode jacobian # and do matmul manually. This is slightly inefficient, but usually # when rev mode is used, dim_f <<< dim_x, so its not too bad. - Ji = getattr(obj, "jac_" + op)(*xi, const) + Ji = getattr(obj, "jac_" + op)(*xi, constants=const) outi = jnp.array([Jii @ vii.T for Jii, vii in zip(Ji, vi)]).sum( axis=0 ) From b2762464a667b70584d4d05af1bd98d5a30cc788 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Mon, 12 Aug 2024 23:37:12 -0400 Subject: [PATCH 14/24] Make sure objectives use jit by default --- desc/objectives/objective_funs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 8e04112bb1..eeaccc2c25 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -849,7 +849,7 @@ def __init__( self._normalization = 1 self._deriv_mode = deriv_mode self._name = name - self._use_jit = None + self._use_jit = True self._built = False self._loss_function = { "mean": jnp.mean, From e690715c7652a5437d298688c9f9f83130acd965 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 14 Aug 2024 17:36:30 -0400 Subject: [PATCH 15/24] Mark unhashable attributes as static --- desc/objectives/linear_objectives.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index d855a025a6..5905802f89 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -3184,6 +3184,7 @@ class FixNearAxisR(_FixedObjective): """ + _static_attrs = ["_nae_eq"] _target_arg = "R_lmn" _fixed = False # not "diagonal", since its fixing a sum _units = "(m)" @@ -3320,6 +3321,7 @@ class FixNearAxisZ(_FixedObjective): """ + _static_attrs = ["_nae_eq"] _target_arg = "Z_lmn" _fixed = False # not "diagonal", since its fixing a sum _units = "(m)" @@ -3462,6 +3464,7 @@ class FixNearAxisLambda(_FixedObjective): """ + _static_attrs = ["_nae_eq"] _target_arg = "L_lmn" _fixed = False # not "diagonal", since its fixing a sum _units = "(dimensionless)" From de22dd88b5e6c116d5f8dabda2dea32fbf5b2f77 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 14 Aug 2024 17:36:59 -0400 Subject: [PATCH 16/24] Use hashable callable classes instead of local functions to avoid recompilation --- desc/objectives/objective_funs.py | 55 ++++++++++++++++++++++--------- desc/objectives/utils.py | 46 +++++++++++++++++++------- 2 files changed, 74 insertions(+), 27 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index eeaccc2c25..43ce4e7348 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -177,22 +177,8 @@ def _set_things(self, things=None): [i for i, t in enumerate(unique_) if t in obj.things] ) - def unflatten(unique): - assert len(unique) == len(unique_) - flat = [unique[i] for i in inds_] - return tree_unflatten(treedef_, flat) - - def flatten(things): - flat, treedef = tree_flatten( - things, is_leaf=lambda x: isinstance(x, Optimizable) - ) - assert treedef == treedef_ - assert len(flat) == len(flat_) - unique, _ = unique_list(flat) - return unique - - self._unflatten = unflatten - self._flatten = flatten + self._unflatten = _ThingUnflattener(len(unique_), inds_, treedef_) + self._flatten = _ThingFlattener(len(flat_), treedef_) @jit def compute_unscaled(self, x, constants=None): @@ -1317,3 +1303,40 @@ def things(self, new): self._things = list(new) # can maybe improve this later to not rebuild if resolution is the same self._built = False + + +# local functions assigned as attributes aren't hashable so they cause stuff to +# recompile, so instead we define a hashable class to do the same thing. + + +class _ThingUnflattener(IOAble): + + _static_attrs = ["length", "inds", "treedef"] + + def __init__(self, length, inds, treedef): + self.length = length + self.inds = inds + self.treedef = treedef + + def __call__(self, unique): + assert len(unique) == self.length + flat = [unique[i] for i in self.inds] + return tree_unflatten(self.treedef, flat) + + +class _ThingFlattener(IOAble): + + _static_attrs = ["length", "treedef"] + + def __init__(self, length, treedef): + self.length = length + self.treedef = treedef + + def __call__(self, things): + flat, treedef = tree_flatten( + things, is_leaf=lambda x: isinstance(x, Optimizable) + ) + assert treedef == self.treedef + assert len(flat) == self.length + unique, _ = unique_list(flat) + return unique diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 27a8d30b37..f34d7d2d0b 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -6,6 +6,7 @@ import numpy as np from desc.backend import cond, jit, jnp, logsumexp, put +from desc.io import IOAble from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif @@ -142,17 +143,8 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901 xp = put(xp, unfixed_idx, Ainv_full @ b) xp = jnp.asarray(xp) - @jit - def project(x): - """Project a full state vector into the reduced optimization vector.""" - x_reduced = Z.T @ ((x - xp)[unfixed_idx]) - return jnp.atleast_1d(jnp.squeeze(x_reduced)) - - @jit - def recover(x_reduced): - """Recover the full state vector from the reduced optimization vector.""" - dx = put(jnp.zeros(objective.dim_x), unfixed_idx, Z @ x_reduced) - return jnp.atleast_1d(jnp.squeeze(xp + dx)) + project = _Project(Z, xp, unfixed_idx) + recover = _Recover(Z, xp, unfixed_idx, objective.dim_x) # check that all constraints are actually satisfiable params = objective.unpack_state(xp, False) @@ -200,6 +192,38 @@ def recover(x_reduced): return xp, A, b, Z, unfixed_idx, project, recover +class _Project(IOAble): + _io_attrs_ = ["Z", "xp", "unfixed_idx"] + + def __init__(self, Z, xp, unfixed_idx): + self.Z = Z + self.xp = xp + self.unfixed_idx = unfixed_idx + + @jit + def __call__(self, x): + """Project a full state vector into the reduced optimization vector.""" + x_reduced = self.Z.T @ ((x - self.xp)[self.unfixed_idx]) + return jnp.atleast_1d(jnp.squeeze(x_reduced)) + + +class _Recover(IOAble): + _io_attrs_ = ["Z", "xp", "unfixed_idx", "dim_x"] + _static_attrs = ["dim_x"] + + def __init__(self, Z, xp, unfixed_idx, dim_x): + self.Z = Z + self.xp = xp + self.unfixed_idx = unfixed_idx + self.dim_x = dim_x + + @jit + def __call__(self, x_reduced): + """Recover the full state vector from the reduced optimization vector.""" + dx = put(jnp.zeros(self.dim_x), self.unfixed_idx, self.Z @ x_reduced) + return jnp.atleast_1d(jnp.squeeze(self.xp + dx)) + + def softmax(arr, alpha): """JAX softmax implementation. From 330a04d88ae8b74b8bc16420f78eec8d26aa4dcd Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 14 Aug 2024 17:37:39 -0400 Subject: [PATCH 17/24] Move some proximal logic to pure jax functions to avoid recompilation --- desc/optimize/_constraint_wrappers.py | 114 ++++++++++++++------------ 1 file changed, 61 insertions(+), 53 deletions(-) diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index f141485845..4249b36a00 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -1,5 +1,7 @@ """Wrappers for doing STELLOPT/SIMSOPT like optimization.""" +import functools + import numpy as np from desc.backend import jit, jnp @@ -970,7 +972,7 @@ def jvp_scaled(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp_scaled(u, xf, xg, constants) + jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled") return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) def jvp_scaled_error(self, v, x, constants=None): @@ -990,7 +992,7 @@ def jvp_scaled_error(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp_scaled_error(u, xf, xg, constants) + jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled_error") return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) def jvp_unscaled(self, v, x, constants=None): @@ -1010,21 +1012,9 @@ def jvp_unscaled(self, v, x, constants=None): v = v[0] if isinstance(v, (tuple, list)) else v constants = setdefault(constants, self.constants) xg, xf = self._update_equilibrium(x, store=True) - jvpfun = lambda u: self._jvp_unscaled(u, xf, xg, constants) + jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="unscaled") return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) - def _jvp_f(self, xf, dc, constants, op): - Fx = getattr(self._constraint, "jac_" + op)(xf, constants) - Fx_reduced = Fx[:, self._unfixed_idx] @ self._Z - Fc = Fx @ (self._dxdc @ dc) - Fxh = Fx_reduced - cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) - uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) - sf += sf[-1] # add a tiny bit of regularization - sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) - Fxh_inv = vtf.T @ (sfi[..., None] * uf.T) - return Fxh_inv @ Fc - def _jvp(self, v, xf, xg, constants, op): # we're replacing stuff like this with jvps # Fx_reduced = Fx[:, unfixed_idx] @ Z # noqa: E800 @@ -1037,7 +1027,16 @@ def _jvp(self, v, xf, xg, constants, op): # want jvp_f to only get parts from equilibrium, not other things vs = jnp.split(v, np.cumsum(self._dimc_per_thing)) # this is Fx_reduced_inv @ Fc - dfdc = self._jvp_f(xf, vs[self._eq_idx], constants[1], op) + dfdc = _proximal_jvp_f_pure( + self._constraint, + xf, + constants[1], + vs[self._eq_idx], + self._unfixed_idx, + self._Z, + self._dxdc, + op, + ) # broadcasting against multiple things dfdcs = [jnp.zeros(dim) for dim in self._dimc_per_thing] dfdcs[self._eq_idx] = dfdc @@ -1059,45 +1058,9 @@ def _jvp(self, v, xf, xg, constants, op): else: # deriv_mode == "blocked" vgs = jnp.split(tangent, np.cumsum(self._dimx_per_thing)) xgs = jnp.split(xg, np.cumsum(self._dimx_per_thing)) - out = [] - for k, (obj, const) in enumerate( - zip(self._objective.objectives, self._objective.constants) - ): - thing_idx = self._objective._things_per_objective_idx[k] - xi = [xgs[i] for i in thing_idx] - vi = [vgs[i] for i in thing_idx] - assert len(xi) > 0 - assert len(vi) > 0 - assert len(xi) == len(vi) - if obj._deriv_mode == "rev": - # obj might not allow fwd mode, so compute full rev mode jacobian - # and do matmul manually. This is slightly inefficient, but usually - # when rev mode is used, dim_f <<< dim_x, so its not too bad. - Ji = getattr(obj, "jac_" + op)(*xi, constants=const) - outi = jnp.array([Jii @ vii.T for Jii, vii in zip(Ji, vi)]).sum( - axis=0 - ) - out.append(outi) - else: - outi = getattr(obj, "jvp_" + op)( - [_vi for _vi in vi], xi, constants=const - ).T - out.append(outi) - out = jnp.concatenate(out) + out = _proximal_jvp_blocked_pure(self._objective, vgs, xgs, op) return -out - @jit - def _jvp_scaled(self, v, xf, xg, constants): - return self._jvp(v, xf, xg, constants, "scaled") - - @jit - def _jvp_scaled_error(self, v, xf, xg, constants): - return self._jvp(v, xf, xg, constants, "scaled_error") - - @jit - def _jvp_unscaled(self, v, xf, xg, constants): - return self._jvp(v, xf, xg, constants, "unscaled") - @property def constants(self): """list: constant parameters for each sub-objective.""" @@ -1106,3 +1069,48 @@ def constants(self): def __getattr__(self, name): """For other attributes we defer to the base objective.""" return getattr(self._objective, name) + + +# in ProximalProjection we have an explicit state that we keep track of (and add +# to as we go) meaning if we jit anything with self static it doesn't update +# correctly, while if we leave self unstatic then it recompiles every time because +# the pytree structure of ProximalProjection is changing. To get around that we +# define these helper functions that are stateless so we can safely jit them + + +@functools.partial(jit, static_argnames=["op"]) +def _proximal_jvp_f_pure(constraint, xf, constants, dc, unfixed_idx, Z, dxdc, op): + Fx = getattr(constraint, "jac_" + op)(xf, constants) + Fx_reduced = Fx[:, unfixed_idx] @ Z + Fc = Fx @ (dxdc @ dc) + Fxh = Fx_reduced + cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape) + uf, sf, vtf = jnp.linalg.svd(Fxh, full_matrices=False) + sf += sf[-1] # add a tiny bit of regularization + sfi = jnp.where(sf < cutoff * sf[0], 0, 1 / sf) + Fxh_inv = vtf.T @ (sfi[..., None] * uf.T) + return Fxh_inv @ Fc + + +@functools.partial(jit, static_argnames=["op"]) +def _proximal_jvp_blocked_pure(objective, vgs, xgs, op): + out = [] + for k, (obj, const) in enumerate(zip(objective.objectives, objective.constants)): + thing_idx = objective._things_per_objective_idx[k] + xi = [xgs[i] for i in thing_idx] + vi = [vgs[i] for i in thing_idx] + assert len(xi) > 0 + assert len(vi) > 0 + assert len(xi) == len(vi) + if obj._deriv_mode == "rev": + # obj might not allow fwd mode, so compute full rev mode jacobian + # and do matmul manually. This is slightly inefficient, but usually + # when rev mode is used, dim_f <<< dim_x, so its not too bad. + Ji = getattr(obj, "jac_" + op)(*xi, constants=const) + outi = jnp.array([Jii @ vii.T for Jii, vii in zip(Ji, vi)]).sum(axis=0) + out.append(outi) + else: + outi = getattr(obj, "jvp_" + op)([_vi for _vi in vi], xi, constants=const).T + out.append(outi) + out = jnp.concatenate(out) + return out From 5e30fdc1962bc2f33ffe49bda1d8084bd84eedd6 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Thu, 15 Aug 2024 01:40:23 -0400 Subject: [PATCH 18/24] remove qr_update stuff, only take p_newton out --- desc/backend.py | 5 -- desc/optimize/aug_lagrangian_ls.py | 17 +++--- desc/optimize/least_squares.py | 17 +++--- desc/optimize/tr_subproblems.py | 19 ++++--- desc/optimize/utils.py | 84 +----------------------------- 5 files changed, 24 insertions(+), 118 deletions(-) diff --git a/desc/backend.py b/desc/backend.py index 22c5c60c6f..c26213b045 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -76,7 +76,6 @@ repeat = jnp.repeat take = jnp.take scan = jax.lax.scan - rsqrt = jax.lax.rsqrt from jax import custom_jvp from jax.experimental.ode import odeint from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular @@ -812,7 +811,3 @@ def take( else: out = np.take(a, indices, axis, out, mode) return out - - def rsqrt(x): - """Reciprocal square root.""" - return 1 / np.sqrt(x) diff --git a/desc/optimize/aug_lagrangian_ls.py b/desc/optimize/aug_lagrangian_ls.py index fb222d776c..a5d75416b2 100644 --- a/desc/optimize/aug_lagrangian_ls.py +++ b/desc/optimize/aug_lagrangian_ls.py @@ -371,18 +371,13 @@ def lagjac(z, y, mu, *args): B_h = jnp.dot(J_a.T, J_a) elif tr_method == "qr": # try full newton step - tall = J_a.shape[0] >= J_a.shape[1] + tall = J.shape[0] >= J.shape[1] if tall: - Q, R = qr(J_a) - p_newton = solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ L_a - ) + Q, R = qr(J, mode="economic") + p_newton = solve_triangular_regularized(R, -Q.T @ f) else: - Q, R = qr(J_a.T) - p_newton = Q[:, : R.shape[1]] @ solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]].T, L_a, lower=True - ) - Q, R = qr(J_a) + Q, R = qr(J.T, mode="economic") + p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True) actual_reduction = -1 Lactual_reduction = -1 @@ -405,7 +400,7 @@ def lagjac(z, y, mu, *args): ) elif tr_method == "qr": step_h, hits_boundary, alpha = trust_region_step_exact_qr( - Q, R, p_newton, L_a, J_a, trust_radius, alpha + p_newton, L_a, J_a, trust_radius, alpha ) step = d * step_h # Trust-region solution in the original space. diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index a459e7f760..d373642995 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -271,18 +271,13 @@ def lsqtr( # noqa: C901 - FIXME: simplify this B_h = jnp.dot(J_a.T, J_a) elif tr_method == "qr": # try full newton step - tall = J_a.shape[0] >= J_a.shape[1] + tall = J.shape[0] >= J.shape[1] if tall: - Q, R = qr(J_a) - p_newton = solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ f_a - ) + Q, R = qr(J, mode="economic") + p_newton = solve_triangular_regularized(R, -Q.T @ f) else: - Q, R = qr(J_a.T) - p_newton = Q[:, : R.shape[1]] @ solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]].T, f_a, lower=True - ) - Q, R = qr(J_a) + Q, R = qr(J.T, mode="economic") + p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True) actual_reduction = -1 @@ -304,7 +299,7 @@ def lsqtr( # noqa: C901 - FIXME: simplify this ) elif tr_method == "qr": step_h, hits_boundary, alpha = trust_region_step_exact_qr( - Q, R, p_newton, f_a, J_a, trust_radius, alpha + p_newton, f_a, J_a, trust_radius, alpha ) step = d * step_h # Trust-region solution in the original space. diff --git a/desc/optimize/tr_subproblems.py b/desc/optimize/tr_subproblems.py index 20d874644a..8c39e82295 100644 --- a/desc/optimize/tr_subproblems.py +++ b/desc/optimize/tr_subproblems.py @@ -14,7 +14,7 @@ ) from desc.utils import setdefault -from .utils import chol, solve_triangular_regularized, update_qr_jax_eco +from .utils import chol, solve_triangular_regularized @jit @@ -378,7 +378,7 @@ def loop_body(state): @jit def trust_region_step_exact_qr( - Q, R, p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10 + p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10 ): """Solve a trust-region problem using a semi-exact method. @@ -441,15 +441,18 @@ def loop_body(state): jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), alpha, ) - Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) - p = solve_triangular_regularized(R2, -Q2.T @ fp) + Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) + # Ji is always tall since its padded by alpha*I + Q, R = qr(Ji, mode="economic") + + p = solve_triangular_regularized(R, -Q.T @ fp) p_norm = jnp.linalg.norm(p) phi = p_norm - trust_radius alpha_upper = jnp.where(phi < 0, alpha, alpha_upper) alpha_lower = jnp.where(phi > 0, alpha, alpha_lower) - q = solve_triangular_regularized(R2.T, p, lower=True) + q = solve_triangular_regularized(R.T, p, lower=True) q_norm = jnp.linalg.norm(q) alpha += (p_norm / q_norm) ** 2 * phi / trust_radius @@ -465,9 +468,9 @@ def loop_body(state): loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k) ) - Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) - - p = solve_triangular(R2, -Q2.T @ fp) + Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) + Q, R = qr(Ji, mode="economic") + p = solve_triangular(R, -Q.T @ fp) # Make the norm of p equal to trust_radius; p is changed only slightly. # This is done to prevent p from lying outside the trust region diff --git a/desc/optimize/utils.py b/desc/optimize/utils.py index 8c31cdc1a2..bbee5a22f7 100644 --- a/desc/optimize/utils.py +++ b/desc/optimize/utils.py @@ -5,7 +5,7 @@ import numpy as np -from desc.backend import cond, fori_loop, jit, jnp, put, rsqrt, solve_triangular +from desc.backend import cond, jit, jnp, put, solve_triangular from desc.utils import Index @@ -551,85 +551,3 @@ def solve_triangular_regularized(R, b, lower=False): Rs = R * dri[:, None] b = dri * b return solve_triangular(Rs, b, unit_diagonal=True, lower=lower) - - -# TODO: add references to the docstrings -def _givens_jax(a, b): - """Compute Givens rotation matrix. - - Compute the Givens rotation matrix G2 that zeros out the second element - of a 2-vector. - G2*[a; b] = [r; 0] - where r = sqrt(a^2 + b^2) - G2 = [[c, -s], [s, c]] - """ - # Taken from jax._src.scipy.sparse.linalg._givens_rotation - b_zero = abs(b) == 0 - a_lt_b = abs(a) < abs(b) - t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a) - r = rsqrt(1 + abs(t) ** 2).astype(t.dtype) - cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r)) - sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t)) - G2 = jnp.array([[cs, -sn], [sn, cs]]) - return G2.astype(float) - - -@jit -def update_qr_jax(A, w, q, r): - """Update QR factorization with a diagonal matrix w at the bottom.""" - m, n = A.shape - Q = jnp.eye(m + n) - Q = Q.at[:m, :m].set(q) - - R = jnp.vstack([r, w]) - - def body_inner(i, jQR): - j, Q, R = jQR - i = m + j - i - a, b = R[i - 1, j], R[i, j] - G2 = _givens_jax(a, b) - R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])]) - Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T) - return j, Q, R - - def body(j, QR): - Q, R = QR - j, Q, R = fori_loop(0, m, body_inner, (j, Q, R)) - return Q, R - - Q, R = fori_loop(0, n, body, (Q, R)) - R = jnp.where(jnp.abs(R) < 1e-10, 0, R) - - return Q, R - - -@jit -def update_qr_jax_eco(A, w, q, r): - """Update QR factorization with a diagonal matrix w at the bottom.""" - m, n = A.shape - Q = jnp.eye(m + n) - Q = Q.at[:m, :m].set(q) - - R = jnp.vstack([r, w]) - - def body_inner(i, jQR): - j, Q, R = jQR - i = m + j - i - a, b = R[i - 1, j], R[i, j] - G2 = _givens_jax(a, b) - R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])]) - Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T) - return j, Q, R - - def body(j, QR): - Q, R = QR - j, Q, R = fori_loop(0, m, body_inner, (j, Q, R)) - return Q, R - - Q, R = fori_loop(0, n, body, (Q, R)) - R = jnp.where(jnp.abs(R) < 1e-10, 0, R) - - Re = R.at[: R.shape[1], : R.shape[1]].get() - Qe = Q.at[:, : R.shape[1]].get() - - return Qe, Re From dc1c6aaa448452701a69c05d2d112b43f9257521 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Thu, 15 Aug 2024 11:21:32 -0400 Subject: [PATCH 19/24] fix type J must be J_a --- desc/optimize/aug_lagrangian_ls.py | 6 +++--- desc/optimize/least_squares.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/desc/optimize/aug_lagrangian_ls.py b/desc/optimize/aug_lagrangian_ls.py index a5d75416b2..507f500995 100644 --- a/desc/optimize/aug_lagrangian_ls.py +++ b/desc/optimize/aug_lagrangian_ls.py @@ -371,12 +371,12 @@ def lagjac(z, y, mu, *args): B_h = jnp.dot(J_a.T, J_a) elif tr_method == "qr": # try full newton step - tall = J.shape[0] >= J.shape[1] + tall = J_a.shape[0] >= J_a.shape[1] if tall: - Q, R = qr(J, mode="economic") + Q, R = qr(J_a, mode="economic") p_newton = solve_triangular_regularized(R, -Q.T @ f) else: - Q, R = qr(J.T, mode="economic") + Q, R = qr(J_a.T, mode="economic") p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True) actual_reduction = -1 diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index d373642995..878f674390 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -271,12 +271,12 @@ def lsqtr( # noqa: C901 - FIXME: simplify this B_h = jnp.dot(J_a.T, J_a) elif tr_method == "qr": # try full newton step - tall = J.shape[0] >= J.shape[1] + tall = J_a.shape[0] >= J_a.shape[1] if tall: - Q, R = qr(J, mode="economic") + Q, R = qr(J_a, mode="economic") p_newton = solve_triangular_regularized(R, -Q.T @ f) else: - Q, R = qr(J.T, mode="economic") + Q, R = qr(J_a.T, mode="economic") p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True) actual_reduction = -1 From b379412041c8110c737f02c7b7bcdf5ccf0b38a1 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Thu, 15 Aug 2024 13:15:24 -0400 Subject: [PATCH 20/24] fix typo f->f_a and f->L_a --- desc/optimize/aug_lagrangian_ls.py | 4 ++-- desc/optimize/least_squares.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/desc/optimize/aug_lagrangian_ls.py b/desc/optimize/aug_lagrangian_ls.py index 507f500995..27a1d9f8ba 100644 --- a/desc/optimize/aug_lagrangian_ls.py +++ b/desc/optimize/aug_lagrangian_ls.py @@ -374,10 +374,10 @@ def lagjac(z, y, mu, *args): tall = J_a.shape[0] >= J_a.shape[1] if tall: Q, R = qr(J_a, mode="economic") - p_newton = solve_triangular_regularized(R, -Q.T @ f) + p_newton = solve_triangular_regularized(R, -Q.T @ L_a) else: Q, R = qr(J_a.T, mode="economic") - p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True) + p_newton = Q @ solve_triangular_regularized(R.T, -L_a, lower=True) actual_reduction = -1 Lactual_reduction = -1 diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index 878f674390..227cd93f70 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -274,10 +274,10 @@ def lsqtr( # noqa: C901 - FIXME: simplify this tall = J_a.shape[0] >= J_a.shape[1] if tall: Q, R = qr(J_a, mode="economic") - p_newton = solve_triangular_regularized(R, -Q.T @ f) + p_newton = solve_triangular_regularized(R, -Q.T @ f_a) else: Q, R = qr(J_a.T, mode="economic") - p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True) + p_newton = Q @ solve_triangular_regularized(R.T, -f_a, lower=True) actual_reduction = -1 From 46926fbd27ac125c722b5d1a89c2cbc36dd43ded Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Sat, 17 Aug 2024 16:23:56 -0400 Subject: [PATCH 21/24] Add comments to explain some stuff better --- desc/objectives/objective_funs.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 43ce4e7348..59734aa261 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -64,7 +64,7 @@ def __init__( self._name = name def _set_derivatives(self): - """Set up derivatives of the objective functions.""" + """Choose derivative mode based on mode of sub-objectives.""" if self._deriv_mode == "auto": if all((obj._deriv_mode == "fwd") for obj in self.objectives): self._deriv_mode = "batched" @@ -394,7 +394,7 @@ def hess(self, x, constants=None): Derivative(self.compute_scalar, mode="hess")(x, constants).squeeze() ) - def _jac(self, op, x, constants=None): + def _jac_blocked(self, op, x, constants=None): # could also do something similar for grad and hess, but probably not # worth it. grad is already super cheap to eval all at once, and blocked # hess would only be block diag which may miss important interactions. @@ -405,6 +405,9 @@ def _jac(self, op, x, constants=None): xs = jnp.split(x, xs_splits) J = [] assert len(self.objectives) == len(self.constants) + # basic idea is we compute the jacobian of each objective wrt each thing + # one by one, and assemble into big block matrix + # if objective doesn't depend on a given thing, that part is set to 0. for k, (obj, const) in enumerate(zip(self.objectives, constants)): # get the xs that go to that objective thing_idx = self._things_per_objective_idx[k] @@ -412,13 +415,16 @@ def _jac(self, op, x, constants=None): Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt to just those things Ji = [] # jac wrt all things for i, thing in enumerate(self.things): - if i in thing_idx: + if i in thing_idx: # dfi/dxj != 0 Ji += [Ji_[thing_idx.index(i)]] - else: + else: # dfi/dxj == 0 Ji += [jnp.zeros((obj.dim_f, thing.dim_x))] - Ji = jnp.hstack(Ji) + Ji = jnp.hstack(Ji) # something like [df1/dx1, df1/dx2, 0] J += [Ji] - return jnp.vstack(J) + # something like [df1/dx1, df1/dx2, 0] + # [df2/dx1, 0, df2/dx3] # noqa:E800 + J = jnp.vstack(J) + return J @jit def jac_scaled(self, x, constants=None): @@ -431,7 +437,7 @@ def jac_scaled(self, x, constants=None): if self._deriv_mode == "looped": J = Derivative(self.compute_scaled, mode="looped")(x, constants) if self._deriv_mode == "blocked": - J = self._jac("jac_scaled", x, constants) + J = self._jac_blocked("jac_scaled", x, constants) return jnp.atleast_2d(J.squeeze()) @@ -446,7 +452,7 @@ def jac_scaled_error(self, x, constants=None): if self._deriv_mode == "looped": J = Derivative(self.compute_scaled_error, mode="looped")(x, constants) if self._deriv_mode == "blocked": - J = self._jac("jac_scaled_error", x, constants) + J = self._jac_blocked("jac_scaled_error", x, constants) return jnp.atleast_2d(J.squeeze()) @@ -461,7 +467,7 @@ def jac_unscaled(self, x, constants=None): if self._deriv_mode == "looped": J = Derivative(self.compute_unscaled, mode="looped")(x, constants) if self._deriv_mode == "blocked": - J = self._jac("jac_unscaled", x, constants) + J = self._jac_blocked("jac_unscaled", x, constants) return jnp.atleast_2d(J.squeeze()) @@ -847,7 +853,7 @@ def __init__( self._things = flatten_list([things], True) def _set_derivatives(self): - """Set up derivatives of the objective wrt each argument.""" + """Choose derivative mode based on size of inputs/outputs.""" if self._deriv_mode == "auto": # choose based on shape of jacobian. fwd mode is more memory efficient # so we prefer that unless the jacobian is really wide From 0cccaf3d7e635af986dcf9c48c8b26135d33482b Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 21 Aug 2024 14:53:35 -0400 Subject: [PATCH 22/24] Reorder nodes from create_meshgrid to match other sorting --- desc/grid.py | 35 +++++++++++++++++++++-------------- tests/test_grid.py | 25 +++++++++++++------------ 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/desc/grid.py b/desc/grid.py index 06ce1329ae..a37b70a6e4 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -216,7 +216,7 @@ def is_meshgrid(self): Let the tuple (r, p, t) ∈ R³ denote a radial, poloidal, and toroidal coordinate value. The is_meshgrid flag denotes whether any coordinate can be iterated over along the relevant axis of the reshaped grid: - nodes.reshape(num_radial, num_poloidal, num_toroidal, 3). + nodes.reshape((num_poloidal, num_radial, num_toroidal, 3), order="F"). """ return self.__dict__.setdefault("_is_meshgrid", False) @@ -632,7 +632,7 @@ class Grid(_Grid): Let the tuple (r, p, t) ∈ R³ denote a radial, poloidal, and toroidal coordinate value. The is_meshgrid flag denotes whether any coordinate can be iterated over along the relevant axis of the reshaped grid: - nodes.reshape(num_radial, num_poloidal, num_toroidal, 3). + nodes.reshape((num_poloidal, num_radial, num_toroidal, 3), order="F"). jitable : bool Whether to skip certain checks and conditionals that don't work under jit. Allows grid to be created on the fly with custom nodes, but weights, symmetry @@ -762,11 +762,16 @@ def create_meshgrid( dc = _periodic_spacing(c, period[2])[1] * NFP else: da, db, dc = spacing + + bb, aa, cc = jnp.meshgrid(b, a, c, indexing="ij") + nodes = jnp.column_stack( - list(map(jnp.ravel, jnp.meshgrid(a, b, c, indexing="ij"))) + [aa.flatten(order="F"), bb.flatten(order="F"), cc.flatten(order="F")] ) + bb, aa, cc = jnp.meshgrid(db, da, dc, indexing="ij") + spacing = jnp.column_stack( - list(map(jnp.ravel, jnp.meshgrid(da, db, dc, indexing="ij"))) + [aa.flatten(order="F"), bb.flatten(order="F"), cc.flatten(order="F")] ) weights = ( spacing.prod(axis=1) @@ -776,19 +781,18 @@ def create_meshgrid( else None ) - unique_a_idx = jnp.arange(a.size) * b.size * c.size - unique_b_idx = jnp.arange(b.size) * c.size - unique_c_idx = jnp.arange(c.size) - inverse_a_idx = repeat( - unique_a_idx // (b.size * c.size), - b.size * c.size, - total_repeat_length=a.size * b.size * c.size, + unique_a_idx = jnp.arange(a.size) * b.size + unique_b_idx = jnp.arange(b.size) + unique_c_idx = jnp.arange(c.size) * a.size * b.size + inverse_a_idx = jnp.tile( + repeat(unique_a_idx // b.size, b.size, total_repeat_length=a.size * b.size), + c.size, ) inverse_b_idx = jnp.tile( - repeat(unique_b_idx // c.size, c.size, total_repeat_length=b.size * c.size), - a.size, + unique_b_idx, + a.size * c.size, ) - inverse_c_idx = jnp.tile(unique_c_idx, a.size * b.size) + inverse_c_idx = repeat(unique_c_idx // (a.size * b.size), (a.size * b.size)) return Grid( nodes=nodes, spacing=spacing, @@ -908,6 +912,7 @@ def __init__( self._toroidal_endpoint = False self._node_pattern = "linear" self._coordinates = "rtz" + self._is_meshgrid = True self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP) self._nodes, self._spacing = self._create_nodes( L=L, @@ -1200,6 +1205,7 @@ def __init__(self, L, M, N, NFP=1): self._sym = False self._node_pattern = "quad" self._coordinates = "rtz" + self._is_meshgrid = True self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP) self._nodes, self._spacing = self._create_nodes(L=L, M=M, N=N, NFP=NFP) # symmetry is never enforced for Quadrature Grid @@ -1341,6 +1347,7 @@ def __init__(self, L, M, N, NFP=1, sym=False, axis=False, node_pattern="jacobi") self._sym = sym self._node_pattern = node_pattern self._coordinates = "rtz" + self._is_meshgrid = False self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP) self._nodes, self._spacing = self._create_nodes( L=L, M=M, N=N, NFP=NFP, axis=axis, node_pattern=node_pattern diff --git a/tests/test_grid.py b/tests/test_grid.py index 9d298b8c85..a47961dda7 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -738,20 +738,21 @@ def test_meshgrid(self): """Test meshgrid constructor.""" R = np.linspace(0, 1, 4) A = np.linspace(0, 2 * np.pi, 2) - Z = np.linspace(0, 10 * np.pi, 3) + Z = np.linspace(0, 2 * np.pi, 3) grid = Grid.create_meshgrid( - [R, A, Z], coordinates="raz", period=(np.inf, 2 * np.pi, np.inf) + [R, A, Z], coordinates="raz", period=(np.inf, 2 * np.pi, 2 * np.pi) ) - r, a, z = grid.nodes.T - _, unique, inverse = np.unique(r, return_index=True, return_inverse=True) - np.testing.assert_allclose(grid.unique_rho_idx, unique) - np.testing.assert_allclose(grid.inverse_rho_idx, inverse) - _, unique, inverse = np.unique(a, return_index=True, return_inverse=True) - np.testing.assert_allclose(grid.unique_alpha_idx, unique) - np.testing.assert_allclose(grid.inverse_alpha_idx, inverse) - _, unique, inverse = np.unique(z, return_index=True, return_inverse=True) - np.testing.assert_allclose(grid.unique_zeta_idx, unique) - np.testing.assert_allclose(grid.inverse_zeta_idx, inverse) + # treating theta == alpha just for grid construction + grid1 = LinearGrid(rho=R, theta=A, zeta=Z) + # atol=1e-12 bc Grid by default shifts points away from the axis a tiny bit + np.testing.assert_allclose(grid1.nodes, grid.nodes, atol=1e-12) + # want radial/poloidal/toroidal nodes sorted in the same order for both + np.testing.assert_allclose(grid1.unique_rho_idx, grid.unique_rho_idx) + np.testing.assert_allclose(grid1.unique_theta_idx, grid.unique_alpha_idx) + np.testing.assert_allclose(grid1.unique_zeta_idx, grid.unique_zeta_idx) + np.testing.assert_allclose(grid1.inverse_rho_idx, grid.inverse_rho_idx) + np.testing.assert_allclose(grid1.inverse_theta_idx, grid.inverse_alpha_idx) + np.testing.assert_allclose(grid1.inverse_zeta_idx, grid.inverse_zeta_idx) @pytest.mark.unit From 5052e026dcec533f15f07d316163c3aff3ad27f8 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 21 Aug 2024 17:07:49 -0400 Subject: [PATCH 23/24] Add utility for reshaping meshgrids --- desc/grid.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_grid.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/desc/grid.py b/desc/grid.py index a37b70a6e4..2b0914bb8d 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -598,6 +598,52 @@ def replace_at_axis(self, x, y, copy=False, **kwargs): ) return x + def meshgrid_reshape(self, x, order): + """Reshape data to match grid coordinates. + + Given flattened data on a tensor product grid, reshape the data such that + the axes of the array correspond to coordinate values on the grid. + + Parameters + ---------- + x : ndarray, shape(N,) or shape(N,3) + Data to reshape. + order : str + Desired order of axes for returned data. Should be a permutation of + ``grid.coordinates``, eg ``order="rtz"`` has the first axis of the returned + data correspond to different rho coordinates, the second axis to different + theta, etc. ``order="trz"`` would have the first axis correspond to theta, + and so on. + + Returns + ------- + x : ndarray + Data reshaped to align with grid nodes. + """ + errorif( + not self.is_meshgrid, + ValueError, + "grid is not a tensor product grid, so meshgrid_reshape doesn't " + "make any sense", + ) + errorif( + sorted(order) != sorted(self.coordinates), + ValueError, + f"order should be a permutation of {self.coordinates}, got {order}", + ) + shape = (self.num_poloidal, self.num_rho, self.num_zeta) + vec = False + if x.ndim > 1: + vec = True + shape += (-1,) + x = x.reshape(shape, order="F") + x = jnp.moveaxis(x, 1, 0) # now shape rtz/raz etc + newax = tuple(list(self.coordinates).index(c) for c in order) + if vec: + newax += (3,) + x = jnp.transpose(x, newax) + return x + class Grid(_Grid): """Collocation grid with custom node placement. diff --git a/tests/test_grid.py b/tests/test_grid.py index a47961dda7..b475410625 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -754,6 +754,46 @@ def test_meshgrid(self): np.testing.assert_allclose(grid1.inverse_theta_idx, grid.inverse_alpha_idx) np.testing.assert_allclose(grid1.inverse_zeta_idx, grid.inverse_zeta_idx) + @pytest.mark.unit + def test_meshgrid_reshape(self): + """Test that reshaping meshgrids works correctly.""" + grid = LinearGrid(2, 3, 4) + + r = grid.nodes[grid.unique_rho_idx, 0] + t = grid.nodes[grid.unique_theta_idx, 1] + z = grid.nodes[grid.unique_zeta_idx, 2] + + # reshaping rtz should have rho along first axis + np.testing.assert_allclose( + grid.meshgrid_reshape(grid.nodes[:, 0], "rtz")[0], r[0] + ) + np.testing.assert_allclose( + grid.meshgrid_reshape(grid.nodes[:, 0], "rtz")[2], r[2] + ) + # reshaping rzt should have theta along last axis + np.testing.assert_allclose( + grid.meshgrid_reshape(grid.nodes[:, 1], "rzt")[:, :, 0], t[0] + ) + np.testing.assert_allclose( + grid.meshgrid_reshape(grid.nodes[:, 1], "rzt")[:, :, 3], t[3] + ) + # reshaping tzr should have zeta along 2nd axis + np.testing.assert_allclose( + grid.meshgrid_reshape(grid.nodes, "tzr")[:, 0, :, 2], z[0] + ) + np.testing.assert_allclose( + grid.meshgrid_reshape(grid.nodes, "tzr")[:, 3, :, 2], z[3] + ) + + # coordinates are rtz, not raz + with pytest.raises(ValueError): + grid.meshgrid_reshape(grid.nodes[:, 0], "raz") + + # not a meshgrid + grid = ConcentricGrid(2, 3, 4) + with pytest.raises(ValueError): + grid.meshgrid_reshape(grid.nodes[:, 0], "rtz") + @pytest.mark.unit def test_find_most_rational_surfaces(): From 3bd14bf80225861b5c36bbdcda550f255aaa269b Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Wed, 21 Aug 2024 20:35:36 -0400 Subject: [PATCH 24/24] Requested changes --- desc/grid.py | 2 +- tests/test_grid.py | 50 ++++++++++++++++++++++++++++++++-------------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/desc/grid.py b/desc/grid.py index 2b0914bb8d..ee471e5d1b 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -638,7 +638,7 @@ def meshgrid_reshape(self, x, order): shape += (-1,) x = x.reshape(shape, order="F") x = jnp.moveaxis(x, 1, 0) # now shape rtz/raz etc - newax = tuple(list(self.coordinates).index(c) for c in order) + newax = tuple(self.coordinates.index(c) for c in order) if vec: newax += (3,) x = jnp.transpose(x, newax) diff --git a/tests/test_grid.py b/tests/test_grid.py index b475410625..160c6aac9c 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -763,26 +763,18 @@ def test_meshgrid_reshape(self): t = grid.nodes[grid.unique_theta_idx, 1] z = grid.nodes[grid.unique_zeta_idx, 2] + # user regular allclose for broadcasting to work correctly # reshaping rtz should have rho along first axis - np.testing.assert_allclose( - grid.meshgrid_reshape(grid.nodes[:, 0], "rtz")[0], r[0] - ) - np.testing.assert_allclose( - grid.meshgrid_reshape(grid.nodes[:, 0], "rtz")[2], r[2] + assert np.allclose( + grid.meshgrid_reshape(grid.nodes[:, 0], "rtz"), r[:, None, None] ) # reshaping rzt should have theta along last axis - np.testing.assert_allclose( - grid.meshgrid_reshape(grid.nodes[:, 1], "rzt")[:, :, 0], t[0] - ) - np.testing.assert_allclose( - grid.meshgrid_reshape(grid.nodes[:, 1], "rzt")[:, :, 3], t[3] + assert np.allclose( + grid.meshgrid_reshape(grid.nodes[:, 1], "rzt"), t[None, None, :] ) # reshaping tzr should have zeta along 2nd axis - np.testing.assert_allclose( - grid.meshgrid_reshape(grid.nodes, "tzr")[:, 0, :, 2], z[0] - ) - np.testing.assert_allclose( - grid.meshgrid_reshape(grid.nodes, "tzr")[:, 3, :, 2], z[3] + assert np.allclose( + grid.meshgrid_reshape(grid.nodes, "tzr")[:, :, :, 2], z[None, :, None] ) # coordinates are rtz, not raz @@ -794,6 +786,34 @@ def test_meshgrid_reshape(self): with pytest.raises(ValueError): grid.meshgrid_reshape(grid.nodes[:, 0], "rtz") + rho = np.linspace(0, 1, 3) + alpha = np.linspace(0, 2 * np.pi, 4) + zeta = np.linspace(0, 6 * np.pi, 5) + grid = Grid.create_meshgrid([rho, alpha, zeta], coordinates="raz") + r, a, z = grid.nodes.T + r = grid.meshgrid_reshape(r, "raz") + a = grid.meshgrid_reshape(a, "raz") + z = grid.meshgrid_reshape(z, "raz") + # functions of zeta should separate along first two axes + # since those are contiguous, this should work + f = z.reshape(-1, zeta.size) + for i in range(1, f.shape[0]): + np.testing.assert_allclose(f[i - 1], f[i]) + # likewise for rho + f = r.reshape(rho.size, -1) + for i in range(1, f.shape[-1]): + np.testing.assert_allclose(f[:, i - 1], f[:, i]) + # test reshaping result won't mix data + f = (a**2 + z).reshape(rho.size, alpha.size, zeta.size) + for i in range(1, f.shape[0]): + np.testing.assert_allclose(f[i - 1], f[i]) + f = (r**2 + z).reshape(rho.size, alpha.size, zeta.size) + for i in range(1, f.shape[1]): + np.testing.assert_allclose(f[:, i - 1], f[:, i]) + f = (r**2 + a).reshape(rho.size, alpha.size, zeta.size) + for i in range(1, f.shape[-1]): + np.testing.assert_allclose(f[..., i - 1], f[..., i]) + @pytest.mark.unit def test_find_most_rational_surfaces():