Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExternalObjective function to wrap external codes #1028

Draft
wants to merge 69 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
456a02b
initial commit
daniel-dudt May 17, 2024
6a557ec
get external objective working
daniel-dudt May 20, 2024
fc9ef77
test comparison to generic
daniel-dudt May 20, 2024
9a64f25
allow string kwargs in external fun
daniel-dudt May 21, 2024
ae84d5e
Merge branch 'master' into dd/external
ddudt May 21, 2024
11c1438
exclude ExternalObjective from tests
daniel-dudt May 21, 2024
bedcee1
Merge branch 'master' into dd/external
ddudt May 23, 2024
aff7d46
make external fun take eq as its argument
daniel-dudt May 23, 2024
7f3ff5b
Merge branch 'master' into dd/external
ddudt May 31, 2024
2bb9017
simplify wrapped fun to take params
daniel-dudt May 31, 2024
7b1cfaf
Merge branch 'master' into dd/external
ddudt Jun 4, 2024
debecad
numpifying to make vectorization work
daniel-dudt Jun 5, 2024
633fa5b
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jun 5, 2024
9ea37fb
Revert "numpifying to make vectorization work"
daniel-dudt Jun 5, 2024
87ab19f
vectorization working!
daniel-dudt Jun 7, 2024
5395611
allow vectorized to be an int
daniel-dudt Jun 11, 2024
52d58d0
fix numpy cond
daniel-dudt Jun 11, 2024
96bf929
Merge branch 'master' into dd/external
ddudt Jun 17, 2024
30aeea4
merging but no change?
daniel-dudt Jun 17, 2024
90296ea
update test with new UI
daniel-dudt Jun 17, 2024
d16e95d
remove unused pool code
daniel-dudt Jun 18, 2024
6b3f86d
Merge branch 'master' into dd/external
ddudt Jun 19, 2024
f9b7562
remove comment note
daniel-dudt Jul 17, 2024
fe5e95c
Merge branch 'master' into dd/external
ddudt Jul 17, 2024
f1f466b
fix black formatting from merge conflict
daniel-dudt Jul 18, 2024
ecc5b3b
repair test from merge conflict
daniel-dudt Jul 18, 2024
800b9bb
Merge branch 'master' into dd/external
ddudt Jul 18, 2024
bf62014
remove multiprocessing from ExternalObjective class
daniel-dudt Jul 18, 2024
0547bd7
jaxify as a util function
daniel-dudt Jul 19, 2024
e3057dd
Merge branch 'master' into dd/external
ddudt Jul 19, 2024
4323b8a
Merge branch 'master' into dd/external
ddudt Jul 22, 2024
03d0cb5
ExternalObjective no longer an ABC
daniel-dudt Jul 22, 2024
7723ecd
re-add print logic in backend
daniel-dudt Jul 22, 2024
4864b56
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
ef9711b
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
180c503
Merge branch 'yge/cpu' into dd/external
ddudt Jul 24, 2024
cea3f4a
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
09c02ec
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
0b2207f
exclude ExternalObjective from tests
daniel-dudt Jul 26, 2024
f6a395b
Merge branch 'master' into dd/external
ddudt Jul 26, 2024
aa570d4
scale FD derivatives by tangent norm
daniel-dudt Jul 26, 2024
d62d9ca
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jul 26, 2024
8c7bcb1
Merge branch 'master' into dd/external
ddudt Jul 30, 2024
7f1907b
Merge branch 'master' into dd/external
ddudt Aug 11, 2024
76b2a3c
Merge branch 'master' into dd/external
dpanici Aug 20, 2024
ef98142
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
16bb59b
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
a83a671
resolve merge conflict
daniel-dudt Aug 22, 2024
8beb2e6
Merge branch 'master' into dd/external
ddudt Aug 23, 2024
9e57ee1
Merge branch 'master' into dd/external
ddudt Aug 25, 2024
11521b2
fix formatting from merge conflict
daniel-dudt Aug 25, 2024
c004724
add static_attrs, update test
daniel-dudt Aug 26, 2024
37f9ee3
Merge branch 'master' into dd/external
ddudt Aug 27, 2024
aab0bdb
update with master
daniel-dudt Nov 7, 2024
829af5a
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
ba1a252
update depricated jax.pure_callback vmap arg
daniel-dudt Nov 12, 2024
bb8a535
update vmap_method
daniel-dudt Nov 12, 2024
56d6662
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
655fe06
Merge branch 'master' into dd/external
YigitElma Dec 4, 2024
44c25a2
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
795350d
remove duplicate line from merge conflict
daniel-dudt Dec 12, 2024
6fef120
fix test with block_until_ready
daniel-dudt Dec 12, 2024
021106d
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
0f35919
Merge branch 'master' into dd/external
ddudt Dec 17, 2024
24dd2f3
update documentation
daniel-dudt Dec 17, 2024
d3aa2dd
Merge branch 'master' into dd/external
ddudt Dec 18, 2024
ac1aa63
make vectorized a required arg
daniel-dudt Dec 18, 2024
4db5c9f
make ExternalObjective args keyword only
daniel-dudt Dec 19, 2024
96aec58
Merge branch 'master' into dd/external
ddudt Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
RadialForceBalance,
)
from ._free_boundary import BoundaryError, VacuumBoundaryError
from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser
from ._generic import (
ExternalObjective,
GenericObjective,
LinearObjectiveFromUser,
ObjectiveFromUser,
)
from ._geometry import (
AspectRatio,
BScaleLength,
Expand Down
244 changes: 237 additions & 7 deletions desc/objectives/_generic.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Generic objectives that don't belong anywhere else."""

import functools
import inspect
import re

import numpy as np

from desc.backend import jnp
from desc.backend import jax, jnp
from desc.compute import data_index
from desc.compute.utils import _compute as compute_fun
from desc.compute.utils import get_profiles, get_transforms
Expand All @@ -15,6 +16,236 @@
from .objective_funs import _Objective


class ExternalObjective(_Objective):
"""Wrap an external code.

Similar to ``ObjectiveFromUser``, except derivatives of the objective function are
computed with finite differences instead of AD.
ddudt marked this conversation as resolved.
Show resolved Hide resolved

The user supplied function must take an Equilibrium as its only positional argument,
but can take additional keyword arguments.

Parameters
----------
eq : Equilibrium
Equilibrium that will be optimized to satisfy the Objective.
fun : callable
Custom objective function.
dim_f : int
Dimension of the output of fun.
target : {float, ndarray}, optional
Target value(s) of the objective. Only used if bounds is None.
Must be broadcastable to Objective.dim_f. Defaults to ``target=0``.
bounds : tuple of {float, ndarray}, optional
Lower and upper bounds on the objective. Overrides target.
Both bounds must be broadcastable to to Objective.dim_f.
Defaults to ``target=0``.
weight : {float, ndarray}, optional
Weighting to apply to the Objective, relative to other Objectives.
Must be broadcastable to to Objective.dim_f
normalize : bool, optional
Whether to compute the error in physical units or non-dimensionalize.
Has no effect for this objective.
normalize_target : bool, optional
Whether target and bounds should be normalized before comparing to computed
values. If `normalize` is `True` and the target is in physical units,
this should also be set to True.
loss_function : {None, 'mean', 'min', 'max'}, optional
Loss function to apply to the objective values once computed. This loss function
is called on the raw compute value, before any shifting, scaling, or
normalization.
name : str, optional
Name of the objective function.

# TODO: add example

"""

_units = "(Unknown)"
_print_value_fmt = "External objective value: {:10.3e}"

def __init__(
self,
eq,
fun,
dim_f,
target=None,
bounds=None,
weight=1,
normalize=False,
normalize_target=False,
loss_function=None,
fd_step=1e-4, # TODO: generalize this to allow a vector of different scales
ddudt marked this conversation as resolved.
Show resolved Hide resolved
name="external",
**kwargs,
):
if target is None and bounds is None:
target = 0

Check warning on line 83 in desc/objectives/_generic.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_generic.py#L83

Added line #L83 was not covered by tests
self._eq = eq.copy()
self._fun = fun
self._dim_f = dim_f
self._fd_step = fd_step
self._kwargs = kwargs
super().__init__(
things=eq,
target=target,
bounds=bounds,
weight=weight,
normalize=normalize,
normalize_target=normalize_target,
loss_function=loss_function,
deriv_mode="fwd",
name=name,
)

def build(self, use_jit=True, verbose=1):
"""Build constant arrays.

Parameters
----------
use_jit : bool, optional
Whether to just-in-time compile the objective and derivatives.
verbose : int, optional
Level of output.

"""
self._scalar = self._dim_f == 1
self._constants = {"quad_weights": 1.0}

def fun_wrapped(
R_lmn,
ddudt marked this conversation as resolved.
Show resolved Hide resolved
Z_lmn,
L_lmn,
p_l,
i_l,
c_l,
Psi,
Te_l,
ne_l,
Ti_l,
Zeff_l,
a_lmn,
Ra_n,
Za_n,
Rb_lmn,
Zb_lmn,
I,
G,
Phi_mn,
):
"""Wrap external function with optimiazable params arguments."""
for param in self._eq.optimizable_params:
ddudt marked this conversation as resolved.
Show resolved Hide resolved
par = eval(param) # FIXME: how bad is it to use eval here?
if len(par):
setattr(self._eq, param, par)
return self._fun(self._eq, **self._kwargs)

# check to make sure fun_wrapped has the correct signature
# in case we ever update Equilibrium.optimizable_params
args = inspect.getfullargspec(fun_wrapped).args
assert args == self._eq.optimizable_params

# wrap external function to work with JAX
abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f)
self._fun_wrapped = self._jaxify(fun_wrapped, abstract_eval)

super().build(use_jit=use_jit, verbose=verbose)

def compute(self, params, constants=None):
"""Compute the quantity.

Parameters
----------
params : list of dict
List of dictionaries of degrees of freedom, eg CoilSet.params_dict
constants : dict
Dictionary of constant data, eg transforms, profiles etc. Defaults to
self.constants

Returns
-------
f : ndarray
Computed quantity.

"""
# ensure positional args are passed in the correct order
args = [params[k] for k in self._eq.optimizable_params]
f = self._fun_wrapped(*args)
return f

def _jaxify(self, func, abstract_eval):
ddudt marked this conversation as resolved.
Show resolved Hide resolved
"""Make an external (python) function work with JAX.

Positional arguments to func can be differentiated,
use keyword args for static values and non-differentiable stuff.

Note: Only forward mode differentiation is supported currently.

Parameters
----------
func : callable
Function to wrap. Should be a "pure" function, in that it has no side
effects and doesn't maintain state. Does not need to be JAX transformable.
abstract_eval : callable
Auxilliary function that computes the output shape and dtype of func.
**Must be JAX transformable**. Should be of the form

abstract_eval(*args, **kwargs) -> Pytree with same shape and dtype as
func(*args, **kwargs)

For example, if func always returns a scalar:

abstract_eval = lambda *args, **kwargs: jnp.array(1.)

Or if func takes an array of shape(n) and returns a dict of arrays of
shape(n-2):

abstract_eval = lambda arr, **kwargs:
{"out1": jnp.empty(arr.size-2), "out2": jnp.empty(arr.size-2)}

Returns
-------
func : callable
New function that behaves as func but works with jit/vmap/jacfwd etc.

"""

def wrap_pure_callback(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result_shape_dtype = abstract_eval(*args, **kwargs)
return jax.pure_callback(func, result_shape_dtype, *args, **kwargs)

return wrapper

def define_fd_jvp(func):
func = jax.custom_jvp(func)

@func.defjvp
def func_jvp(primals, tangents):
primal_out = func(*primals)

# flatten everything into 1D vectors for easier finite differences
y, unflaty = jax.flatten_util.ravel_pytree(primal_out)
x, unflatx = jax.flatten_util.ravel_pytree(primals)
v, _______ = jax.flatten_util.ravel_pytree(tangents)
# scale to unit norm if nonzero
normv = jnp.linalg.norm(v)
vh = jnp.where(normv == 0, v, v / normv)

def f(x):
return jax.flatten_util.ravel_pytree(func(*unflatx(x)))[0]

tangent_out = (f(x + self._fd_step * vh) - y) / self._fd_step * normv
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we may want to also scale fd_step by norm(x) to ensure that dx << x

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

normalize vh to be same norm as |x|

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(and undo at the end as well considering the scales)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x here is really params, which includes variables at very different scales like boundary coefficients vs current profile coefficients. We are using the same finite difference step size for everything, which means that step could be huge for some variables and tiny for others. Then when generating wout files and running external codes for those perturbed equilibria, the results might not be physically meaningful.

tangent_out = unflaty(tangent_out)

return primal_out, tangent_out

return func

return define_fd_jvp(wrap_pure_callback(func))


class GenericObjective(_Objective):
"""A generic objective that can compute any quantity from the `data_index`.

Expand Down Expand Up @@ -58,7 +289,7 @@

"""

_print_value_fmt = "GenericObjective value: {:10.3e} "
_print_value_fmt = "Generic objective value: {:10.3e} "

def __init__(
self,
Expand Down Expand Up @@ -207,7 +438,7 @@
_linear = True
_fixed = True
_units = "(Unknown)"
_print_value_fmt = "Custom linear Objective value: {:10.3e}"
_print_value_fmt = "Custom linear objective value: {:10.3e}"

def __init__(
self,
Expand Down Expand Up @@ -345,18 +576,17 @@
def myfun(grid, data):
# This will compute the flux surface average of the function
# R*B_T from the Grad-Shafranov equation
f = data['R']*data['B_phi']
f = data['R'] * data['B_phi']
f_fsa = surface_averages(grid, f, sqrt_g=data['sqrt_g'])
# this has the FSA values on the full grid, but we just want
# the unique values:
# this is the FSA on the full grid, but we only want the unique values:
return grid.compress(f_fsa)

myobj = ObjectiveFromUser(myfun)

"""

_units = "(Unknown)"
_print_value_fmt = "Custom Objective value: {:10.3e}"
_print_value_fmt = "Custom objective value: {:10.3e}"

def __init__(
self,
Expand Down
5 changes: 2 additions & 3 deletions desc/objectives/objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(

def _set_derivatives(self):
"""Set up derivatives of the objective functions."""
# TODO: does deriv_mode have to be "blocked" if there is an ExternalObjective?
ddudt marked this conversation as resolved.
Show resolved Hide resolved
if self._deriv_mode == "auto":
if all((obj._deriv_mode == "fwd") for obj in self.objectives):
self._deriv_mode = "batched"
Expand Down Expand Up @@ -90,9 +91,7 @@ def jac_(op, x, constants=None):
for obj, const in zip(self.objectives, constants):
# get the xs that go to that objective
xi = [x for x, t in zip(xs, self.things) if t in obj.things]
Ji_ = getattr(obj, op)(
*xi, constants=const
) # jac wrt to just those things
Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt only xi
Ji = [] # jac wrt all things
for thing in self.things:
if thing in obj.things:
Expand Down
Loading
Loading