Skip to content

Commit

Permalink
Dd/optimizable (#956)
Browse files Browse the repository at this point in the history
Resolves #860 

- Generalizes `FixParameter` to work for an `OptimizableCollection` so
collections (like coil sets) can be constrained in optimizations.
- Replaces inheritance from `_FixedObjective` and `_FixProfile` to
inherit from `FixParameter` instead for many of the linear objectives,
to reduce redundant code.
- Other miscellaneous changes to make coil optimization work, like
adding `FixCurveShift` and `FixCurveRotation` to
`maybe_add_self_consistency`

To-Do:

- [x] Replace `_FixedObjective` and `_FixProfile` with `FixParameter`
- [x] Make sure tests pass
- [x] Add test to check order of residuals with specified modes input
  • Loading branch information
ddudt authored May 9, 2024
2 parents 082be4e + 0aff196 commit d9f4a8d
Show file tree
Hide file tree
Showing 17 changed files with 888 additions and 1,442 deletions.
5 changes: 5 additions & 0 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
tree_map,
tree_structure,
tree_unflatten,
treedef_is_leaf,
)

def put(arr, inds, vals):
Expand Down Expand Up @@ -412,6 +413,10 @@ def tree_leaves(*args, **kwargs):
"""Get leaves of pytree for numpy backend."""
raise NotImplementedError

def treedef_is_leaf(*args, **kwargs):
"""Check is leaf of pytree for numpy backend."""
raise NotImplementedError

def register_pytree_node(foo, *args):
"""Dummy decorator for non-jax pytrees."""
return foo
Expand Down
7 changes: 4 additions & 3 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,12 @@ def get_idx(self, L=0, M=0, N=0, error=True):
N : int
Toroidal mode number.
error : bool
whether to raise exception if mode is not in basis, or return empty array
Whether to raise exception if the mode is not in the basis (default),
or to return an empty array.
Returns
-------
idx : ndarray of int
idx : int
Index of given mode numbers.
"""
Expand All @@ -130,7 +131,7 @@ def get_idx(self, L=0, M=0, N=0, error=True):
"mode ({}, {}, {}) is not in basis {}".format(L, M, N, str(self))
) from e
else:
return np.array([]).astype(int)
return np.array([], dtype=int)

@abstractmethod
def _get_modes(self):
Expand Down
3 changes: 2 additions & 1 deletion desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@
FixOmniBmax,
FixOmniMap,
FixOmniWell,
FixParameter,
FixParameters,
FixPressure,
FixPsi,
FixSheetCurrent,
FixSumModesLambda,
FixSumModesR,
FixSumModesZ,
Expand Down
3 changes: 1 addition & 2 deletions desc/objectives/_free_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,11 @@ def build(self, use_jit=True, verbose=1):
if self._source_grid is None:
# for axisymmetry we still need to know about toroidal effects, so its
# cheapest to pretend there are extra field periods
source_NFP = eq.NFP if eq.N > 0 else 64
source_grid = LinearGrid(
rho=np.array([1.0]),
M=eq.M_grid,
N=eq.N_grid,
NFP=source_NFP,
NFP=eq.NFP if eq.N > 0 else 64,
sym=False,
)
else:
Expand Down
109 changes: 43 additions & 66 deletions desc/objectives/getters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Utilities for getting standard groups of objectives and constraints."""

import numpy as np

from desc.utils import is_any_instance
from desc.utils import flatten_list, is_any_instance, unique_list

from ._equilibrium import Energy, ForceBalance, HelicalForceBalance, RadialForceBalance
from .linear_objectives import (
Expand All @@ -24,9 +22,9 @@
FixIonTemperature,
FixIota,
FixLambdaGauge,
FixParameter,
FixPressure,
FixPsi,
FixSheetCurrent,
)
from .nae_utils import calc_zeroth_order_lambda, make_RZ_cons_1st_order
from .objective_funs import ObjectiveFunction
Expand Down Expand Up @@ -61,18 +59,15 @@ def get_equilibrium_objective(eq, mode="force", normalize=True):
-------
objective, ObjectiveFunction
An objective function with default force balance objectives.
"""
kwargs = {"eq": eq, "normalize": normalize, "normalize_target": normalize}
if mode == "energy":
objectives = Energy(eq=eq, normalize=normalize, normalize_target=normalize)
objectives = Energy(**kwargs)
elif mode == "force":
objectives = ForceBalance(
eq=eq, normalize=normalize, normalize_target=normalize
)
objectives = ForceBalance(**kwargs)
elif mode == "forces":
objectives = (
RadialForceBalance(eq=eq, normalize=normalize, normalize_target=normalize),
HelicalForceBalance(eq=eq, normalize=normalize, normalize_target=normalize),
)
objectives = (RadialForceBalance(**kwargs), HelicalForceBalance(**kwargs))
else:
raise ValueError("got an unknown equilibrium objective type '{}'".format(mode))
return ObjectiveFunction(objectives)
Expand All @@ -96,20 +91,13 @@ def get_fixed_axis_constraints(eq, profiles=True, normalize=True):
A list of the linear constraints used in fixed-axis problems.
"""
constraints = (
FixAxisR(eq=eq, normalize=normalize, normalize_target=normalize),
FixAxisZ(eq=eq, normalize=normalize, normalize_target=normalize),
FixPsi(eq=eq, normalize=normalize, normalize_target=normalize),
)
kwargs = {"eq": eq, "normalize": normalize, "normalize_target": normalize}
constraints = (FixAxisR(**kwargs), FixAxisZ(**kwargs), FixPsi(**kwargs))
if profiles:
for name, con in _PROFILE_CONSTRAINTS.items():
if getattr(eq, name) is not None:
constraints += (
con(eq=eq, normalize=normalize, normalize_target=normalize),
)
for param in ["I", "G", "Phi_mn"]:
if np.array(getattr(eq, param, [])).size:
constraints += (FixParameter(eq, param),)
constraints += (con(**kwargs),)
constraints += (FixSheetCurrent(**kwargs),)

return constraints

Expand All @@ -132,20 +120,13 @@ def get_fixed_boundary_constraints(eq, profiles=True, normalize=True):
A list of the linear constraints used in fixed-boundary problems.
"""
constraints = (
FixBoundaryR(eq=eq, normalize=normalize, normalize_target=normalize),
FixBoundaryZ(eq=eq, normalize=normalize, normalize_target=normalize),
FixPsi(eq=eq, normalize=normalize, normalize_target=normalize),
)
kwargs = {"eq": eq, "normalize": normalize, "normalize_target": normalize}
constraints = (FixBoundaryR(**kwargs), FixBoundaryZ(**kwargs), FixPsi(**kwargs))
if profiles:
for name, con in _PROFILE_CONSTRAINTS.items():
if getattr(eq, name) is not None:
constraints += (
con(eq=eq, normalize=normalize, normalize_target=normalize),
)
for param in ["I", "G", "Phi_mn"]:
if np.array(getattr(eq, param, [])).size:
constraints += (FixParameter(eq, param),)
constraints += (con(**kwargs),)
constraints += (FixSheetCurrent(**kwargs),)

return constraints

Expand Down Expand Up @@ -186,24 +167,18 @@ def get_NAE_constraints(
-------
constraints, tuple of _Objectives
A list of the linear constraints used in fixed-axis problems.
"""
kwargs = {"eq": desc_eq, "normalize": normalize, "normalize_target": normalize}
if not isinstance(fix_lambda, bool):
fix_lambda = int(fix_lambda)
constraints = (
FixAxisR(eq=desc_eq, normalize=normalize, normalize_target=normalize),
FixAxisZ(eq=desc_eq, normalize=normalize, normalize_target=normalize),
FixPsi(eq=desc_eq, normalize=normalize, normalize_target=normalize),
)
constraints = (FixAxisR(**kwargs), FixAxisZ(**kwargs), FixPsi(**kwargs))

if profiles:
for name, con in _PROFILE_CONSTRAINTS.items():
if getattr(desc_eq, name) is not None:
constraints += (
con(eq=desc_eq, normalize=normalize, normalize_target=normalize),
)
for param in ["I", "G", "Phi_mn"]:
if np.array(getattr(desc_eq, param, [])).size:
constraints += (FixParameter(desc_eq, param),)
constraints += (con(**kwargs),)
constraints += (FixSheetCurrent(**kwargs),)

if fix_lambda or (fix_lambda >= 0 and type(fix_lambda) is int):
L_axis_constraints, _, _ = calc_zeroth_order_lambda(
Expand All @@ -222,30 +197,32 @@ def get_NAE_constraints(

def maybe_add_self_consistency(thing, constraints):
"""Add self consistency constraints if needed."""
params = set(unique_list(flatten_list(thing.optimizable_params))[0])

# Equilibrium
if (
hasattr(thing, "Ra_n")
and hasattr(thing, "Za_n")
and hasattr(thing, "Rb_lmn")
and hasattr(thing, "Zb_lmn")
and hasattr(thing, "L_lmn")
if {"R_lmn", "Rb_lmn"} <= params and not is_any_instance(
constraints, BoundaryRSelfConsistency
):
constraints += (BoundaryRSelfConsistency(eq=thing),)
if {"Z_lmn", "Zb_lmn"} <= params and not is_any_instance(
constraints, BoundaryZSelfConsistency
):
constraints += (BoundaryZSelfConsistency(eq=thing),)
if {"L_lmn"} <= params and not is_any_instance(constraints, FixLambdaGauge):
constraints += (FixLambdaGauge(eq=thing),)
if {"R_lmn", "Ra_n"} <= params and not is_any_instance(
constraints, AxisRSelfConsistency
):
constraints += (AxisRSelfConsistency(eq=thing),)
if {"Z_lmn", "Za_n"} <= params and not is_any_instance(
constraints, AxisZSelfConsistency
):
if not is_any_instance(constraints, BoundaryRSelfConsistency):
constraints += (BoundaryRSelfConsistency(eq=thing),)
if not is_any_instance(constraints, BoundaryZSelfConsistency):
constraints += (BoundaryZSelfConsistency(eq=thing),)
if not is_any_instance(constraints, FixLambdaGauge):
constraints += (FixLambdaGauge(eq=thing),)
if not is_any_instance(constraints, AxisRSelfConsistency):
constraints += (AxisRSelfConsistency(eq=thing),)
if not is_any_instance(constraints, AxisZSelfConsistency):
constraints += (AxisZSelfConsistency(eq=thing),)
constraints += (AxisZSelfConsistency(eq=thing),)

# Curve
elif hasattr(thing, "shift") and hasattr(thing, "rotmat"):
if not is_any_instance(constraints, FixCurveShift):
constraints += (FixCurveShift(curve=thing),)
if not is_any_instance(constraints, FixCurveRotation):
constraints += (FixCurveRotation(curve=thing),)
if {"shift"} <= params and not is_any_instance(constraints, FixCurveShift):
constraints += (FixCurveShift(curve=thing),)
if {"rotmat"} <= params and not is_any_instance(constraints, FixCurveRotation):
constraints += (FixCurveRotation(curve=thing),)

return constraints
Loading

0 comments on commit d9f4a8d

Please sign in to comment.