Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExternalObjective function to wrap external codes #1028

Draft
wants to merge 69 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
456a02b
initial commit
daniel-dudt May 17, 2024
6a557ec
get external objective working
daniel-dudt May 20, 2024
fc9ef77
test comparison to generic
daniel-dudt May 20, 2024
9a64f25
allow string kwargs in external fun
daniel-dudt May 21, 2024
ae84d5e
Merge branch 'master' into dd/external
ddudt May 21, 2024
11c1438
exclude ExternalObjective from tests
daniel-dudt May 21, 2024
bedcee1
Merge branch 'master' into dd/external
ddudt May 23, 2024
aff7d46
make external fun take eq as its argument
daniel-dudt May 23, 2024
7f3ff5b
Merge branch 'master' into dd/external
ddudt May 31, 2024
2bb9017
simplify wrapped fun to take params
daniel-dudt May 31, 2024
7b1cfaf
Merge branch 'master' into dd/external
ddudt Jun 4, 2024
debecad
numpifying to make vectorization work
daniel-dudt Jun 5, 2024
633fa5b
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jun 5, 2024
9ea37fb
Revert "numpifying to make vectorization work"
daniel-dudt Jun 5, 2024
87ab19f
vectorization working!
daniel-dudt Jun 7, 2024
5395611
allow vectorized to be an int
daniel-dudt Jun 11, 2024
52d58d0
fix numpy cond
daniel-dudt Jun 11, 2024
96bf929
Merge branch 'master' into dd/external
ddudt Jun 17, 2024
30aeea4
merging but no change?
daniel-dudt Jun 17, 2024
90296ea
update test with new UI
daniel-dudt Jun 17, 2024
d16e95d
remove unused pool code
daniel-dudt Jun 18, 2024
6b3f86d
Merge branch 'master' into dd/external
ddudt Jun 19, 2024
f9b7562
remove comment note
daniel-dudt Jul 17, 2024
fe5e95c
Merge branch 'master' into dd/external
ddudt Jul 17, 2024
f1f466b
fix black formatting from merge conflict
daniel-dudt Jul 18, 2024
ecc5b3b
repair test from merge conflict
daniel-dudt Jul 18, 2024
800b9bb
Merge branch 'master' into dd/external
ddudt Jul 18, 2024
bf62014
remove multiprocessing from ExternalObjective class
daniel-dudt Jul 18, 2024
0547bd7
jaxify as a util function
daniel-dudt Jul 19, 2024
e3057dd
Merge branch 'master' into dd/external
ddudt Jul 19, 2024
4323b8a
Merge branch 'master' into dd/external
ddudt Jul 22, 2024
03d0cb5
ExternalObjective no longer an ABC
daniel-dudt Jul 22, 2024
7723ecd
re-add print logic in backend
daniel-dudt Jul 22, 2024
4864b56
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
ef9711b
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
180c503
Merge branch 'yge/cpu' into dd/external
ddudt Jul 24, 2024
cea3f4a
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
09c02ec
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
0b2207f
exclude ExternalObjective from tests
daniel-dudt Jul 26, 2024
f6a395b
Merge branch 'master' into dd/external
ddudt Jul 26, 2024
aa570d4
scale FD derivatives by tangent norm
daniel-dudt Jul 26, 2024
d62d9ca
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jul 26, 2024
8c7bcb1
Merge branch 'master' into dd/external
ddudt Jul 30, 2024
7f1907b
Merge branch 'master' into dd/external
ddudt Aug 11, 2024
76b2a3c
Merge branch 'master' into dd/external
dpanici Aug 20, 2024
ef98142
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
16bb59b
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
a83a671
resolve merge conflict
daniel-dudt Aug 22, 2024
8beb2e6
Merge branch 'master' into dd/external
ddudt Aug 23, 2024
9e57ee1
Merge branch 'master' into dd/external
ddudt Aug 25, 2024
11521b2
fix formatting from merge conflict
daniel-dudt Aug 25, 2024
c004724
add static_attrs, update test
daniel-dudt Aug 26, 2024
37f9ee3
Merge branch 'master' into dd/external
ddudt Aug 27, 2024
aab0bdb
update with master
daniel-dudt Nov 7, 2024
829af5a
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
ba1a252
update depricated jax.pure_callback vmap arg
daniel-dudt Nov 12, 2024
bb8a535
update vmap_method
daniel-dudt Nov 12, 2024
56d6662
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
655fe06
Merge branch 'master' into dd/external
YigitElma Dec 4, 2024
44c25a2
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
795350d
remove duplicate line from merge conflict
daniel-dudt Dec 12, 2024
6fef120
fix test with block_until_ready
daniel-dudt Dec 12, 2024
021106d
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
0f35919
Merge branch 'master' into dd/external
ddudt Dec 17, 2024
24dd2f3
update documentation
daniel-dudt Dec 17, 2024
d3aa2dd
Merge branch 'master' into dd/external
ddudt Dec 18, 2024
ac1aa63
make vectorized a required arg
daniel-dudt Dec 18, 2024
4db5c9f
make ExternalObjective args keyword only
daniel-dudt Dec 19, 2024
96aec58
Merge branch 'master' into dd/external
ddudt Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions desc/backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Backend functions for DESC, with options for JAX or regular numpy."""

import functools
import multiprocessing as mp
import os
import warnings

Expand All @@ -11,15 +12,19 @@
from desc import config as desc_config
from desc import set_device

# only print details in the main process, not child processes spawned by multiprocessing
verbose = bool(mp.current_process().name == "MainProcess")

if os.environ.get("DESC_BACKEND") == "numpy":
jnp = np
use_jax = False
set_device(kind="cpu")
print(
"DESC version {}, using numpy backend, version={}, dtype={}".format(
desc.__version__, np.__version__, np.linspace(0, 1).dtype
if verbose:
print(

Check warning on line 23 in desc/backend.py

View check run for this annotation

Codecov / codecov/patch

desc/backend.py#L22-L23

Added lines #L22 - L23 were not covered by tests
"DESC version {}, using numpy backend, version={}, dtype={}".format(
desc.__version__, np.__version__, np.linspace(0, 1).dtype
)
)
)
else:
if desc_config.get("device") is None:
set_device("cpu")
Expand All @@ -41,11 +46,12 @@
x = jnp.linspace(0, 5)
y = jnp.exp(x)
use_jax = True
print(
f"DESC version {desc.__version__},"
+ f"using JAX backend, jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}"
)
if verbose:
print(
f"DESC version {desc.__version__}, "
+ f"using JAX backend, jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}"
)
del x, y
except ModuleNotFoundError:
jnp = np
Expand All @@ -59,11 +65,13 @@
desc.__version__, np.__version__, y.dtype
)
)
print(
"Using device: {}, with {:.2f} GB available memory".format(
desc_config.get("device"), desc_config.get("avail_mem")

if verbose:
print(
"Using device: {}, with {:.2f} GB available memory".format(
desc_config.get("device"), desc_config.get("avail_mem")
)
)
)

if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign?
jit = jax.jit
Expand Down Expand Up @@ -515,7 +523,7 @@
val = body_fun(i, val)
return val

def cond(pred, true_fun, false_fun, *operand):
def cond(pred, true_fun, false_fun, *operands):
"""Conditionally apply true_fun or false_fun.

This version is for the numpy backend, for jax backend see jax.lax.cond
Expand All @@ -528,7 +536,7 @@
Function (A -> B), to be applied if pred is True.
false_fun: callable
Function (A -> B), to be applied if pred is False.
operand: any
operands: any
input to either branch depending on pred. The type can be a scalar, array,
or any pytree (nested Python tuple/list/dict) thereof.

Expand All @@ -541,9 +549,9 @@

"""
if pred:
return true_fun(*operand)
return true_fun(*operands)
else:
return false_fun(*operand)
return false_fun(*operands)

def switch(index, branches, operand):
"""Apply exactly one of branches given by index.
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14)
has_endpoint_dupe,
lambda _: put(mask, jnp.array([0, -1]), mask[0] | mask[-1]),
lambda _: mask,
operand=None,
None,
)
else:
# If we don't have the idx attributes, we are forced to expand out.
Expand Down
7 changes: 6 additions & 1 deletion desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
RadialForceBalance,
)
from ._free_boundary import BoundaryError, VacuumBoundaryError
from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser
from ._generic import (
ExternalObjective,
GenericObjective,
LinearObjectiveFromUser,
ObjectiveFromUser,
)
from ._geometry import (
AspectRatio,
BScaleLength,
Expand Down
177 changes: 172 additions & 5 deletions desc/objectives/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect
import re
from abc import ABC

import numpy as np

Expand All @@ -11,12 +12,178 @@
from desc.compute.utils import _parse_parameterization, get_profiles, get_transforms
from desc.grid import QuadratureGrid
from desc.optimizable import OptimizableCollection
from desc.utils import errorif, parse_argname_change
from desc.utils import errorif, jaxify, parse_argname_change

from .linear_objectives import _FixedObjective
from .objective_funs import _Objective


class ExternalObjective(_Objective, ABC):
"""Wrap an external code.

Similar to ``ObjectiveFromUser``, except derivatives of the objective function are
computed with finite differences instead of AD. The function does not need not be
JAX transformable.

The user supplied function must take an Equilibrium as its only positional argument,
but can take additional keyword arguments.

Parameters
----------
eq : Equilibrium
Equilibrium that will be optimized to satisfy the Objective.
fun : callable
External objective function. It must take an Equilibrium as its only positional
argument, but can take additional kewyord arguments. It does not need to be JAX
transformable.
dim_f : int
Dimension of the output of ``fun``.
target : {float, ndarray}, optional
Target value(s) of the objective. Only used if bounds is None.
Must be broadcastable to Objective.dim_f. Defaults to ``target=0``.
bounds : tuple of {float, ndarray}, optional
Lower and upper bounds on the objective. Overrides target.
Both bounds must be broadcastable to to Objective.dim_f.
Defaults to ``target=0``.
weight : {float, ndarray}, optional
Weighting to apply to the Objective, relative to other Objectives.
Must be broadcastable to to Objective.dim_f
normalize : bool, optional
Whether to compute the error in physical units or non-dimensionalize.
Has no effect for this objective.
normalize_target : bool, optional
Whether target and bounds should be normalized before comparing to computed
values. If `normalize` is `True` and the target is in physical units,
this should also be set to True.
loss_function : {None, 'mean', 'min', 'max'}, optional
Loss function to apply to the objective values once computed. This loss function
is called on the raw compute value, before any shifting, scaling, or
normalization.
vectorized : bool, optional
Whether or not ``fun`` is vectorized. Default = False.
abs_step : float, optional
Absolute finite difference step size. Default = 1e-4.
Total step size is ``abs_step + rel_step * mean(abs(x))``.
rel_step : float, optional
Relative finite difference step size. Default = 0.
Total step size is ``abs_step + rel_step * mean(abs(x))``.
name : str, optional
Name of the objective function.
kwargs : any, optional
Keyword arguments that are passed as inputs to ``fun``.

# TODO: add example

"""

_units = "(Unknown)"
_print_value_fmt = "External objective value: {:10.3e}"

def __init__(
self,
eq,
fun,
dim_f,
target=None,
bounds=None,
weight=1,
normalize=False,
normalize_target=False,
loss_function=None,
vectorized=False,
abs_step=1e-4,
rel_step=0,
name="external",
**kwargs,
):
if target is None and bounds is None:
target = 0

Check warning on line 100 in desc/objectives/_generic.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_generic.py#L100

Added line #L100 was not covered by tests
self._eq = eq.copy()
self._fun = fun
self._dim_f = dim_f
self._vectorized = vectorized
self._abs_step = abs_step
self._rel_step = rel_step
self._kwargs = kwargs
super().__init__(
things=eq,
target=target,
bounds=bounds,
weight=weight,
normalize=normalize,
normalize_target=normalize_target,
loss_function=loss_function,
deriv_mode="fwd",
name=name,
)

def build(self, use_jit=True, verbose=1):
"""Build constant arrays.

Parameters
----------
use_jit : bool, optional
Whether to just-in-time compile the objective and derivatives.
verbose : int, optional
Level of output.

"""
self._scalar = self._dim_f == 1
self._constants = {"quad_weights": 1.0}

def fun_wrapped(params):
"""Wrap external function with possibly vectorized params."""
# number of equilibria for vectorized computations
param_shape = params["Psi"].shape
num_eq = param_shape[0] if len(param_shape) > 1 else 1

# convert params to list of equilibria
eqs = [self._eq.copy() for _ in range(num_eq)]
for k, eq in enumerate(eqs):
# update equilibria with new params
for param_key in self._eq.optimizable_params:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't this just be eq.params_dict = params?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because when vectorized params is a dict of 2D arrays with each row the params for a different equilibrium

param_value = np.atleast_2d(params[param_key])[k, :]
if len(param_value):
setattr(eq, param_key, param_value)

# call external function on equilibrium or list of equilibria
if not self._vectorized:
eqs = eqs[0]
return self._fun(eqs, **self._kwargs)

# wrap external function to work with JAX
abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f)
self._fun_wrapped = jaxify(
fun_wrapped,
abstract_eval,
vectorized=self._vectorized,
abs_step=self._abs_step,
rel_step=self._rel_step,
)

super().build(use_jit=use_jit, verbose=verbose)

def compute(self, params, constants=None):
"""Compute the quantity.

Parameters
----------
params : list of dict
List of dictionaries of degrees of freedom, eg CoilSet.params_dict
constants : dict
Dictionary of constant data, eg transforms, profiles etc. Defaults to
self.constants

Returns
-------
f : ndarray
Computed quantity.

"""
f = self._fun_wrapped(params)
return f


class GenericObjective(_Objective):
"""A generic objective that can compute any quantity from the `data_index`.

Expand Down Expand Up @@ -352,10 +519,9 @@
def myfun(grid, data):
# This will compute the flux surface average of the function
# R*B_T from the Grad-Shafranov equation
f = data['R']*data['B_phi']
f = data['R'] * data['B_phi']
f_fsa = surface_averages(grid, f, sqrt_g=data['sqrt_g'])
# this has the FSA values on the full grid, but we just want
# the unique values:
# this is the FSA on the full grid, but we only want the unique values:
return grid.compress(f_fsa)

myobj = ObjectiveFromUser(fun=myfun, thing=eq)
Expand Down Expand Up @@ -414,6 +580,8 @@
Level of output.

"""
import jax

thing = self.things[0]
if self._grid is None:
errorif(
Expand Down Expand Up @@ -444,7 +612,6 @@
).squeeze()

self._fun_wrapped = lambda data: self._fun(grid, data)
import jax

self._dim_f = jax.eval_shape(self._fun_wrapped, dummy_data).size
self._scalar = self._dim_f == 1
Expand Down
4 changes: 1 addition & 3 deletions desc/objectives/objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ def jac_(op, x, constants=None):
for obj, const in zip(self.objectives, constants):
# get the xs that go to that objective
xi = [x for x, t in zip(xs, self.things) if t in obj.things]
Ji_ = getattr(obj, op)(
*xi, constants=const
) # jac wrt to just those things
Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt only xi
Ji = [] # jac wrt all things
for thing in self.things:
if thing in obj.things:
Expand Down
Loading
Loading