Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dd/optimizable #956

Merged
merged 69 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from 66 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
6402ab0
create FixCollectionParameters objective
daniel-dudt Mar 25, 2024
9dc9a38
allow fixing of non-default params
daniel-dudt Mar 26, 2024
9a78600
bug fix: use np instead of jnp
daniel-dudt Mar 26, 2024
b0df4e4
update maybe_add_self_consistency
daniel-dudt Mar 26, 2024
4d73c2e
Merge branch 'master' into dd/optimizable
ddudt Mar 26, 2024
bd1eedc
fix NFP int bug
daniel-dudt Mar 26, 2024
b29ec75
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt Mar 26, 2024
0bdc245
Merge branch 'master' into dd/optimizable
ddudt Mar 29, 2024
03dfe05
allow fixing params for only some things in collection
daniel-dudt Mar 29, 2024
19a8955
add test for second stage optimization
daniel-dudt Mar 29, 2024
ef63085
Merge branch 'master' into dd/optimizable
f0uriest Apr 2, 2024
f6aa6c1
Merge branch 'master' into dd/optimizable
ddudt Apr 2, 2024
672dd48
merge with master
daniel-dudt Apr 12, 2024
efb5429
Merge branch 'master' into dd/optimizable
ddudt Apr 12, 2024
bfb5f61
update tests
daniel-dudt Apr 12, 2024
d5f8fc2
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt Apr 12, 2024
4f75638
Merge branch 'master' into dd/optimizable
ddudt Apr 18, 2024
8db2c3f
broadcast_tree util function
daniel-dudt Apr 19, 2024
135c0a9
broadcast_tree until function + tests
daniel-dudt Apr 20, 2024
fafa8ca
FixCollectionParameters working with custom params input
daniel-dudt Apr 20, 2024
566c810
tiny typos
daniel-dudt Apr 20, 2024
c39faa2
repair 2nd stage opt test
daniel-dudt Apr 20, 2024
80e7dbe
more test cases for broadcast_tree
daniel-dudt Apr 20, 2024
65f6f6e
combine FixCollectionParameters into FixParameter
daniel-dudt Apr 20, 2024
3477a96
fix params_leaves sorting issue
daniel-dudt Apr 20, 2024
b41e761
fix list vs array bug
daniel-dudt Apr 21, 2024
31c94c6
clean up a few lines
daniel-dudt Apr 21, 2024
d41a1fc
fix missing tree_leaves call
daniel-dudt Apr 21, 2024
f0dbd83
add assert statement to fix later
daniel-dudt Apr 22, 2024
be981f6
remove debugging print statement
daniel-dudt Apr 24, 2024
34cb43d
proximal projection hack
daniel-dudt Apr 24, 2024
2818ecb
cast indices to array in test
daniel-dudt Apr 29, 2024
414f296
refactor broadcast_tree
daniel-dudt Apr 29, 2024
1a5ae51
refactor FixParameter objective
daniel-dudt Apr 29, 2024
9115d8b
Merge branch 'master' into dd/optimizable
ddudt Apr 29, 2024
70fd03f
add dtype option to broadcast_tree
daniel-dudt Apr 29, 2024
b997d63
replace some FixedObjectives with FixParameter
daniel-dudt Apr 29, 2024
2d7c033
add FixSheetCurrent objective
daniel-dudt Apr 29, 2024
f88eb4b
use FixParameter for axis/boundary/etc. objectives
daniel-dudt Apr 29, 2024
9dc6a17
replace FixProfile with FixParameter
daniel-dudt Apr 29, 2024
eb6dc48
FixTheatSFL inherit from FixParameter
daniel-dudt Apr 29, 2024
42d7042
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt Apr 29, 2024
e9d95a8
syntax error
daniel-dudt Apr 30, 2024
900012c
bug fixes
daniel-dudt Apr 30, 2024
5fb1614
remove outdated FIXME comment
daniel-dudt Apr 30, 2024
1193017
Merge branch 'master' into dd/optimizable
ddudt Apr 30, 2024
7a314b9
repairing tests
daniel-dudt Apr 30, 2024
259a16d
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt Apr 30, 2024
6917eb6
copy and paste is dangerous
daniel-dudt Apr 30, 2024
32731e6
another copy/paste mistake
daniel-dudt Apr 30, 2024
3f3d3c2
syntax issue in freeb test
daniel-dudt Apr 30, 2024
a288df6
Merge branch 'master' into dd/optimizable
ddudt Apr 30, 2024
94532a6
syntax issue in freeb notebook
daniel-dudt Apr 30, 2024
13305b5
make Rory's suggested changes
daniel-dudt May 2, 2024
e0b746c
Merge branch 'master' into dd/optimizable
ddudt May 2, 2024
f524cc0
repair getter funs
daniel-dudt May 2, 2024
68de431
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt May 2, 2024
255f242
add test to check input order is preserved
daniel-dudt May 2, 2024
1752190
re-add normalize kwargs
daniel-dudt May 2, 2024
0615d56
update 2nd stage opt test with note
daniel-dudt May 2, 2024
c690222
update FixParameters example
daniel-dudt May 2, 2024
022861d
I hate debugging code
daniel-dudt May 2, 2024
847e355
Merge branch 'master' into dd/optimizable
ddudt May 3, 2024
8f9acfc
set vacuum=True for second_stage test
daniel-dudt May 3, 2024
fee6990
remove old comment
dpanici May 8, 2024
d025324
add note about True to FixParameters
dpanici May 8, 2024
4015aff
and note of False
dpanici May 8, 2024
7315d76
Merge branch 'master' into dd/optimizable
dpanici May 8, 2024
0aff196
Merge branch 'master' into dd/optimizable
ddudt May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
tree_map,
tree_structure,
tree_unflatten,
treedef_is_leaf,
)

def put(arr, inds, vals):
Expand Down Expand Up @@ -412,6 +413,10 @@ def tree_leaves(*args, **kwargs):
"""Get leaves of pytree for numpy backend."""
raise NotImplementedError

def treedef_is_leaf(*args, **kwargs):
"""Check is leaf of pytree for numpy backend."""
raise NotImplementedError

def register_pytree_node(foo, *args):
"""Dummy decorator for non-jax pytrees."""
return foo
Expand Down
7 changes: 4 additions & 3 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,12 @@ def get_idx(self, L=0, M=0, N=0, error=True):
N : int
Toroidal mode number.
error : bool
whether to raise exception if mode is not in basis, or return empty array
Whether to raise exception if the mode is not in the basis (default),
or to return an empty array.
Returns
-------
idx : ndarray of int
idx : int
Index of given mode numbers.
"""
Expand All @@ -130,7 +131,7 @@ def get_idx(self, L=0, M=0, N=0, error=True):
"mode ({}, {}, {}) is not in basis {}".format(L, M, N, str(self))
) from e
else:
return np.array([]).astype(int)
return np.array([], dtype=int)

@abstractmethod
def _get_modes(self):
Expand Down
3 changes: 2 additions & 1 deletion desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@
FixOmniBmax,
FixOmniMap,
FixOmniWell,
FixParameter,
FixParameters,
FixPressure,
FixPsi,
FixSheetCurrent,
FixSumModesLambda,
FixSumModesR,
FixSumModesZ,
Expand Down
3 changes: 1 addition & 2 deletions desc/objectives/_free_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,11 @@ def build(self, use_jit=True, verbose=1):
if self._source_grid is None:
# for axisymmetry we still need to know about toroidal effects, so its
# cheapest to pretend there are extra field periods
source_NFP = eq.NFP if eq.N > 0 else 64
source_grid = LinearGrid(
rho=np.array([1.0]),
M=eq.M_grid,
N=eq.N_grid,
NFP=source_NFP,
NFP=eq.NFP if eq.N > 0 else 64,
sym=False,
)
else:
Expand Down
109 changes: 43 additions & 66 deletions desc/objectives/getters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Utilities for getting standard groups of objectives and constraints."""

import numpy as np

from desc.utils import is_any_instance
from desc.utils import flatten_list, is_any_instance, unique_list

from ._equilibrium import Energy, ForceBalance, HelicalForceBalance, RadialForceBalance
from .linear_objectives import (
Expand All @@ -24,9 +22,9 @@
FixIonTemperature,
FixIota,
FixLambdaGauge,
FixParameter,
FixPressure,
FixPsi,
FixSheetCurrent,
)
from .nae_utils import calc_zeroth_order_lambda, make_RZ_cons_1st_order
from .objective_funs import ObjectiveFunction
Expand Down Expand Up @@ -61,18 +59,15 @@ def get_equilibrium_objective(eq, mode="force", normalize=True):
-------
objective, ObjectiveFunction
An objective function with default force balance objectives.
"""
kwargs = {"eq": eq, "normalize": normalize, "normalize_target": normalize}
if mode == "energy":
objectives = Energy(eq=eq, normalize=normalize, normalize_target=normalize)
objectives = Energy(**kwargs)
elif mode == "force":
objectives = ForceBalance(
eq=eq, normalize=normalize, normalize_target=normalize
)
objectives = ForceBalance(**kwargs)
elif mode == "forces":
objectives = (
RadialForceBalance(eq=eq, normalize=normalize, normalize_target=normalize),
HelicalForceBalance(eq=eq, normalize=normalize, normalize_target=normalize),
)
objectives = (RadialForceBalance(**kwargs), HelicalForceBalance(**kwargs))
else:
raise ValueError("got an unknown equilibrium objective type '{}'".format(mode))
return ObjectiveFunction(objectives)
Expand All @@ -96,20 +91,13 @@ def get_fixed_axis_constraints(eq, profiles=True, normalize=True):
A list of the linear constraints used in fixed-axis problems.
"""
constraints = (
FixAxisR(eq=eq, normalize=normalize, normalize_target=normalize),
FixAxisZ(eq=eq, normalize=normalize, normalize_target=normalize),
FixPsi(eq=eq, normalize=normalize, normalize_target=normalize),
)
kwargs = {"eq": eq, "normalize": normalize, "normalize_target": normalize}
constraints = (FixAxisR(**kwargs), FixAxisZ(**kwargs), FixPsi(**kwargs))
if profiles:
for name, con in _PROFILE_CONSTRAINTS.items():
if getattr(eq, name) is not None:
constraints += (
con(eq=eq, normalize=normalize, normalize_target=normalize),
)
for param in ["I", "G", "Phi_mn"]:
if np.array(getattr(eq, param, [])).size:
constraints += (FixParameter(eq, param),)
constraints += (con(**kwargs),)
constraints += (FixSheetCurrent(**kwargs),)

return constraints

Expand All @@ -132,20 +120,13 @@ def get_fixed_boundary_constraints(eq, profiles=True, normalize=True):
A list of the linear constraints used in fixed-boundary problems.
"""
constraints = (
FixBoundaryR(eq=eq, normalize=normalize, normalize_target=normalize),
FixBoundaryZ(eq=eq, normalize=normalize, normalize_target=normalize),
FixPsi(eq=eq, normalize=normalize, normalize_target=normalize),
)
kwargs = {"eq": eq, "normalize": normalize, "normalize_target": normalize}
constraints = (FixBoundaryR(**kwargs), FixBoundaryZ(**kwargs), FixPsi(**kwargs))
if profiles:
for name, con in _PROFILE_CONSTRAINTS.items():
if getattr(eq, name) is not None:
constraints += (
con(eq=eq, normalize=normalize, normalize_target=normalize),
)
for param in ["I", "G", "Phi_mn"]:
if np.array(getattr(eq, param, [])).size:
constraints += (FixParameter(eq, param),)
constraints += (con(**kwargs),)
constraints += (FixSheetCurrent(**kwargs),)
ddudt marked this conversation as resolved.
Show resolved Hide resolved

return constraints

Expand Down Expand Up @@ -186,24 +167,18 @@ def get_NAE_constraints(
-------
constraints, tuple of _Objectives
A list of the linear constraints used in fixed-axis problems.
"""
kwargs = {"eq": desc_eq, "normalize": normalize, "normalize_target": normalize}
if not isinstance(fix_lambda, bool):
fix_lambda = int(fix_lambda)
constraints = (
FixAxisR(eq=desc_eq, normalize=normalize, normalize_target=normalize),
FixAxisZ(eq=desc_eq, normalize=normalize, normalize_target=normalize),
FixPsi(eq=desc_eq, normalize=normalize, normalize_target=normalize),
)
constraints = (FixAxisR(**kwargs), FixAxisZ(**kwargs), FixPsi(**kwargs))

if profiles:
for name, con in _PROFILE_CONSTRAINTS.items():
if getattr(desc_eq, name) is not None:
constraints += (
con(eq=desc_eq, normalize=normalize, normalize_target=normalize),
)
for param in ["I", "G", "Phi_mn"]:
if np.array(getattr(desc_eq, param, [])).size:
constraints += (FixParameter(desc_eq, param),)
constraints += (con(**kwargs),)
constraints += (FixSheetCurrent(**kwargs),)

if fix_lambda or (fix_lambda >= 0 and type(fix_lambda) is int):
L_axis_constraints, _, _ = calc_zeroth_order_lambda(
Expand All @@ -222,30 +197,32 @@ def get_NAE_constraints(

def maybe_add_self_consistency(thing, constraints):
"""Add self consistency constraints if needed."""
params = set(unique_list(flatten_list(thing.optimizable_params))[0])

# Equilibrium
if (
hasattr(thing, "Ra_n")
and hasattr(thing, "Za_n")
and hasattr(thing, "Rb_lmn")
and hasattr(thing, "Zb_lmn")
and hasattr(thing, "L_lmn")
if {"R_lmn", "Rb_lmn"} <= params and not is_any_instance(
constraints, BoundaryRSelfConsistency
):
constraints += (BoundaryRSelfConsistency(eq=thing),)
if {"Z_lmn", "Zb_lmn"} <= params and not is_any_instance(
constraints, BoundaryZSelfConsistency
):
constraints += (BoundaryZSelfConsistency(eq=thing),)
if {"L_lmn"} <= params and not is_any_instance(constraints, FixLambdaGauge):
constraints += (FixLambdaGauge(eq=thing),)
if {"R_lmn", "Ra_n"} <= params and not is_any_instance(
constraints, AxisRSelfConsistency
):
constraints += (AxisRSelfConsistency(eq=thing),)
if {"Z_lmn", "Za_n"} <= params and not is_any_instance(
constraints, AxisZSelfConsistency
):
if not is_any_instance(constraints, BoundaryRSelfConsistency):
constraints += (BoundaryRSelfConsistency(eq=thing),)
if not is_any_instance(constraints, BoundaryZSelfConsistency):
constraints += (BoundaryZSelfConsistency(eq=thing),)
if not is_any_instance(constraints, FixLambdaGauge):
constraints += (FixLambdaGauge(eq=thing),)
if not is_any_instance(constraints, AxisRSelfConsistency):
constraints += (AxisRSelfConsistency(eq=thing),)
if not is_any_instance(constraints, AxisZSelfConsistency):
constraints += (AxisZSelfConsistency(eq=thing),)
constraints += (AxisZSelfConsistency(eq=thing),)

# Curve
elif hasattr(thing, "shift") and hasattr(thing, "rotmat"):
if not is_any_instance(constraints, FixCurveShift):
constraints += (FixCurveShift(curve=thing),)
if not is_any_instance(constraints, FixCurveRotation):
constraints += (FixCurveRotation(curve=thing),)
if {"shift"} <= params and not is_any_instance(constraints, FixCurveShift):
constraints += (FixCurveShift(curve=thing),)
if {"rotmat"} <= params and not is_any_instance(constraints, FixCurveRotation):
constraints += (FixCurveRotation(curve=thing),)

return constraints
Loading
Loading