Skip to content

Commit

Permalink
Allow GenericObjective and ObjectiveFromUser to work with generic…
Browse files Browse the repository at this point in the history
… "things" (#1061)

Resolves #841
  • Loading branch information
ddudt authored Jun 21, 2024
2 parents caed4ce + d2e9a95 commit f738381
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 80 deletions.
116 changes: 67 additions & 49 deletions desc/objectives/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from desc.backend import jnp
from desc.compute import data_index
from desc.compute.utils import _compute as compute_fun
from desc.compute.utils import get_profiles, get_transforms
from desc.compute.utils import _parse_parameterization, get_profiles, get_transforms
from desc.grid import QuadratureGrid
from desc.optimizable import OptimizableCollection
from desc.utils import errorif, parse_argname_change

from .linear_objectives import _FixedObjective
from .objective_funs import _Objective
Expand All @@ -22,8 +24,8 @@ class GenericObjective(_Objective):
----------
f : str
Name of the quantity to compute.
eq : Equilibrium
Equilibrium that will be optimized to satisfy the Objective.
thing : Optimizable
Object that will be optimized to satisfy the Objective.
target : {float, ndarray}, optional
Target value(s) of the objective. Only used if bounds is None.
Must be broadcastable to Objective.dim_f. Defaults to ``target=0``.
Expand Down Expand Up @@ -51,19 +53,19 @@ class GenericObjective(_Objective):
of the objective. Has no effect on self.grad or self.hess which always use
reverse mode and forward over reverse mode respectively.
grid : Grid, optional
Collocation grid containing the nodes to evaluate at.
Defaults to ``QuadratureGrid(eq.L_grid, eq.M_grid, eq.N_grid)``.
Collocation grid containing the nodes to evaluate at. Defaults to
``QuadratureGrid(eq.L_grid, eq.M_grid, eq.N_grid)`` if thing is an Equilibrium.
name : str, optional
Name of the objective function.
"""

_print_value_fmt = "GenericObjective value: {:10.3e} "
_print_value_fmt = "Generic objective value: {:10.3e} "

def __init__(
self,
f,
eq,
thing,
target=None,
bounds=None,
weight=1,
Expand All @@ -73,13 +75,20 @@ def __init__(
deriv_mode="auto",
grid=None,
name="generic",
**kwargs,
):
errorif(
isinstance(thing, OptimizableCollection),
NotImplementedError,
"thing must be of type Optimizable and not OptimizableCollection.",
)
thing = parse_argname_change(thing, kwargs, "eq", "thing")
if target is None and bounds is None:
target = 0
self.f = f
self._grid = grid
super().__init__(
things=eq,
things=thing,
target=target,
bounds=bounds,
weight=weight,
Expand All @@ -89,17 +98,10 @@ def __init__(
deriv_mode=deriv_mode,
name=name,
)
self._scalar = not bool(
data_index["desc.equilibrium.equilibrium.Equilibrium"][self.f]["dim"]
)
self._coordinates = data_index["desc.equilibrium.equilibrium.Equilibrium"][
self.f
]["coordinates"]
self._units = (
"("
+ data_index["desc.equilibrium.equilibrium.Equilibrium"][self.f]["units"]
+ ")"
)
self._p = _parse_parameterization(thing)
self._scalar = not bool(data_index[self._p][self.f]["dim"])
self._coordinates = data_index[self._p][self.f]["coordinates"]
self._units = "(" + data_index[self._p][self.f]["units"] + ")"

def build(self, use_jit=True, verbose=1):
"""Build constant arrays.
Expand All @@ -112,21 +114,25 @@ def build(self, use_jit=True, verbose=1):
Level of output.
"""
eq = self.things[0]
thing = self.things[0]
if self._grid is None:
grid = QuadratureGrid(eq.L_grid, eq.M_grid, eq.N_grid, eq.NFP)
errorif(
self._p != "desc.equilibrium.equilibrium.Equilibrium",
ValueError,
"grid must be supplied for things besides an Equilibrium.",
)
grid = QuadratureGrid(thing.L_grid, thing.M_grid, thing.N_grid, thing.NFP)
else:
grid = self._grid

p = "desc.equilibrium.equilibrium.Equilibrium"
if data_index[p][self.f]["dim"] == 0:
if data_index[self._p][self.f]["dim"] == 0:
self._dim_f = 1
elif data_index[p][self.f]["coordinates"] == "r":
elif data_index[self._p][self.f]["coordinates"] == "r":
self._dim_f = grid.num_rho
else:
self._dim_f = grid.num_nodes * np.prod(data_index[p][self.f]["dim"])
profiles = get_profiles(self.f, obj=eq, grid=grid)
transforms = get_transforms(self.f, obj=eq, grid=grid)
self._dim_f = grid.num_nodes * np.prod(data_index[self._p][self.f]["dim"])
profiles = get_profiles(self.f, obj=thing, grid=grid)
transforms = get_transforms(self.f, obj=thing, grid=grid)
self._constants = {
"transforms": transforms,
"profiles": profiles,
Expand All @@ -141,8 +147,8 @@ def compute(self, params, constants=None):
params : dict
Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict
constants : dict
Dictionary of constant data, eg transforms, profiles etc. Defaults to
self.constants
Dictionary of constant data, eg transforms, profiles etc.
Defaults to self.constants
Returns
-------
Expand All @@ -153,7 +159,7 @@ def compute(self, params, constants=None):
if constants is None:
constants = self.constants
data = compute_fun(
"desc.equilibrium.equilibrium.Equilibrium",
self._p,
self.f,
params=params,
transforms=constants["transforms"],
Expand Down Expand Up @@ -207,7 +213,7 @@ class LinearObjectiveFromUser(_FixedObjective):
_linear = True
_fixed = True
_units = "(Unknown)"
_print_value_fmt = "Custom linear Objective value: {:10.3e}"
_print_value_fmt = "Custom linear objective value: {:10.3e}"

def __init__(
self,
Expand Down Expand Up @@ -302,8 +308,8 @@ class ObjectiveFromUser(_Objective):
----------
fun : callable
Custom objective function.
eq : Equilibrium
Equilibrium that will be optimized to satisfy the Objective.
thing : Optimizable
Object that will be optimized to satisfy the Objective.
target : {float, ndarray}, optional
Target value(s) of the objective. Only used if bounds is None.
Must be broadcastable to Objective.dim_f. Defaults to ``target=0``.
Expand Down Expand Up @@ -331,17 +337,17 @@ class ObjectiveFromUser(_Objective):
of the objective. Has no effect on self.grad or self.hess which always use
reverse mode and forward over reverse mode respectively.
grid : Grid, optional
Collocation grid containing the nodes to evaluate at.
Defaults to ``QuadratureGrid(eq.L_grid, eq.M_grid, eq.N_grid)``.
Collocation grid containing the nodes to evaluate at. Defaults to
``QuadratureGrid(eq.L_grid, eq.M_grid, eq.N_grid)`` if thing is an Equilibrium.
name : str, optional
Name of the objective function.
Examples
--------
.. code-block:: python
from desc.compute.utils import surface_averages
def myfun(grid, data):
# This will compute the flux surface average of the function
# R*B_T from the Grad-Shafranov equation
Expand All @@ -351,17 +357,17 @@ def myfun(grid, data):
# the unique values:
return grid.compress(f_fsa)
myobj = ObjectiveFromUser(myfun)
myobj = ObjectiveFromUser(fun=myfun, thing=eq)
"""

_units = "(Unknown)"
_print_value_fmt = "Custom Objective value: {:10.3e}"
_print_value_fmt = "Custom objective value: {:10.3e}"

def __init__(
self,
fun,
eq,
thing,
target=None,
bounds=None,
weight=1,
Expand All @@ -371,13 +377,20 @@ def __init__(
deriv_mode="auto",
grid=None,
name="custom",
**kwargs,
):
errorif(
isinstance(thing, OptimizableCollection),
NotImplementedError,
"thing must be of type Optimizable and not OptimizableCollection.",
)
thing = parse_argname_change(thing, kwargs, "eq", "thing")
if target is None and bounds is None:
target = 0
self._fun = fun
self._grid = grid
super().__init__(
things=eq,
things=thing,
target=target,
bounds=bounds,
weight=weight,
Expand All @@ -387,6 +400,7 @@ def __init__(
deriv_mode=deriv_mode,
name=name,
)
self._p = _parse_parameterization(thing)

def build(self, use_jit=True, verbose=1):
"""Build constant arrays.
Expand All @@ -399,9 +413,14 @@ def build(self, use_jit=True, verbose=1):
Level of output.
"""
eq = self.things[0]
thing = self.things[0]
if self._grid is None:
grid = QuadratureGrid(eq.L_grid, eq.M_grid, eq.N_grid, eq.NFP)
errorif(
self._p != "desc.equilibrium.equilibrium.Equilibrium",
ValueError,
"grid must be supplied for things besides an Equilibrium.",
)
grid = QuadratureGrid(thing.L_grid, thing.M_grid, thing.N_grid, thing.NFP)
else:
grid = self._grid

Expand All @@ -414,23 +433,22 @@ def get_vars(fun):

self._data_keys = get_vars(self._fun)
dummy_data = {}
p = "desc.equilibrium.equilibrium.Equilibrium"
for key in self._data_keys:
assert key in data_index[p], f"Don't know how to compute {key}"
if data_index[p][key]["dim"] == 0:
assert key in data_index[self._p], f"Don't know how to compute {key}."
if data_index[self._p][key]["dim"] == 0:
dummy_data[key] = jnp.array(0.0)
else:
dummy_data[key] = jnp.empty(
(grid.num_nodes, data_index[p][key]["dim"])
(grid.num_nodes, data_index[self._p][key]["dim"])
).squeeze()

self._fun_wrapped = lambda data: self._fun(grid, data)
import jax

self._dim_f = jax.eval_shape(self._fun_wrapped, dummy_data).size
self._scalar = self._dim_f == 1
profiles = get_profiles(self._data_keys, obj=eq, grid=grid)
transforms = get_transforms(self._data_keys, obj=eq, grid=grid)
profiles = get_profiles(self._data_keys, obj=thing, grid=grid)
transforms = get_transforms(self._data_keys, obj=thing, grid=grid)
self._constants = {
"transforms": transforms,
"profiles": profiles,
Expand Down Expand Up @@ -459,7 +477,7 @@ def compute(self, params, constants=None):
if constants is None:
constants = self.constants
data = compute_fun(
"desc.equilibrium.equilibrium.Equilibrium",
self._p,
self._data_keys,
params=params,
transforms=constants["transforms"],
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/tutorials/omnigenity.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@
"objective = ObjectiveFunction(\n",
" (\n",
" # target major radius of R0=1 m\n",
" GenericObjective(\"R0\", eq=eq, target=1.0, name=\"major radius\"),\n",
" GenericObjective(\"R0\", thing=eq, target=1.0, name=\"major radius\"),\n",
" # target aspect ratio R0/a<=10\n",
" AspectRatio(eq=eq, bounds=(0, 10)),\n",
" # omnigenity on the rho=0.5 surface\n",
Expand Down Expand Up @@ -1113,7 +1113,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ def mirrorRatio(params):

objective = ObjectiveFunction(
(
GenericObjective("R0", eq=eq, target=1.0, name="major radius"),
GenericObjective("R0", thing=eq, target=1.0, name="major radius"),
AspectRatio(eq=eq, bounds=(0, 10)),
Omnigenity(
eq=eq,
Expand Down Expand Up @@ -896,7 +896,7 @@ def test_omnigenity_proximal():
# first, test optimizing the equilibrium with the field fixed
objective = ObjectiveFunction(
(
GenericObjective("R0", eq=eq, target=1.0, name="major radius"),
GenericObjective("R0", thing=eq, target=1.0, name="major radius"),
AspectRatio(eq=eq, bounds=(0, 10)),
Omnigenity(eq=eq, field=field, field_fixed=True), # field is fixed
)
Expand All @@ -908,12 +908,12 @@ def test_omnigenity_proximal():
FixPsi(eq=eq),
)
optimizer = Optimizer("proximal-lsq-exact")
eq, _ = optimizer.optimize(eq, objective, constraints, maxiter=2, verbose=3)
[eq], _ = optimizer.optimize(eq, objective, constraints, maxiter=2, verbose=3)

# second, test optimizing both the equilibrium and the field simultaneously
objective = ObjectiveFunction(
(
GenericObjective("R0", eq=eq, target=1.0, name="major radius"),
GenericObjective("R0", thing=eq, target=1.0, name="major radius"),
AspectRatio(eq=eq, bounds=(0, 10)),
Omnigenity(eq=eq, field=field), # field is not fixed
)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def test_correct_indexing_passed_modes():
objective = ObjectiveFunction(
(
# just need dummy objective for factorizing constraints
GenericObjective("0", eq=eq),
GenericObjective("0", thing=eq),
),
use_jit=False,
)
Expand Down Expand Up @@ -467,7 +467,7 @@ def test_correct_indexing_passed_modes_and_passed_target():
eq.surface = eq.get_surface_at(1.0)

objective = ObjectiveFunction(
(GenericObjective("0", eq=eq),),
(GenericObjective("0", thing=eq),),
use_jit=False,
)
objective.build()
Expand Down Expand Up @@ -530,7 +530,7 @@ def test_correct_indexing_passed_modes_axis():
eq.axis = eq.get_axis()

objective = ObjectiveFunction(
(GenericObjective("0", eq=eq),),
(GenericObjective("0", thing=eq),),
use_jit=False,
)
objective.build()
Expand Down Expand Up @@ -590,7 +590,7 @@ def test_correct_indexing_passed_modes_and_passed_target_axis():
eq.axis = eq.get_axis()

objective = ObjectiveFunction(
(GenericObjective("0", eq=eq),),
(GenericObjective("0", thing=eq),),
use_jit=False,
)
objective.build()
Expand Down
Loading

0 comments on commit f738381

Please sign in to comment.