Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add field to things for boundary error objectives #935

Merged
merged 9 commits into from
Mar 19, 2024
35 changes: 24 additions & 11 deletions desc/io/equilibrium_io.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why were these changes necessary for this PR?

Copy link
Member Author

@f0uriest f0uriest Mar 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SplineMagneticField requires some custom behavior to work correctly as a pytree (required for it to be part of things), and the previous method of defining a custom flattening method didn't consider all the extra attributes that get added with the Optimizable base class. This refactors things a bit so that each class can use the default flattening method but also define specific attributes as static or dynamic without having to re-define the entire flattening/unflattening method (and then having to account for every attribute individually)

Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@


def _make_hashable(x):
# turn unhashable ndarray of ints nto a hashable tuple
# turn unhashable ndarray of ints into a hashable tuple
if hasattr(x, "shape"):
return ("ndarray", x.shape, tuple(x.flatten()))
return x
Expand All @@ -114,16 +114,29 @@
# use subclass method
return obj.tree_flatten()

children = {
key: val for key, val in obj.__dict__.items() if not _unjittable(val)
}
aux_data = tuple(
[
(key, _make_hashable(val))
for key, val in obj.__dict__.items()
if _unjittable(val)
]
)
# in jax parlance, "children" of a pytree are things like arrays etc
# that get traced and can change. "aux_data" is metadata that is assumed
# static and must be hashable. By default we assume floating point arrays
# are children, and int/bool arrays are metadata that should be static
children = {}
aux_data = []

# this allows classes to override the default static/dynamic stuff
# if they need certain floats to be static or ints to by dynamic etc.
static_attrs = getattr(obj, "_static_attrs", [])
dynamic_attrs = getattr(obj, "_dynamic_attrs", [])
assert set(static_attrs).isdisjoint(set(dynamic_attrs))

for key, val in obj.__dict__.items():
if key in static_attrs:
aux_data += [(key, _make_hashable(val))]
elif key in dynamic_attrs:
children[key] = val

Check warning on line 134 in desc/io/equilibrium_io.py

View check run for this annotation

Codecov / codecov/patch

desc/io/equilibrium_io.py#L134

Added line #L134 was not covered by tests
elif _unjittable(val):
aux_data += [(key, _make_hashable(val))]
else:
children[key] = val

return ((children,), aux_data)

def _generic_tree_unflatten(aux_data, children):
Expand Down
28 changes: 4 additions & 24 deletions desc/magnetic_fields/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,9 @@ class SplineMagneticField(_MagneticField, Optimizable):
"_currents",
"_NFP",
]
# by default floats are considered dynamic but for this to work with jit these
# need to be static
_static_attrs = ["_extrap", "_period"]

def __init__(
self, R, phi, Z, BR, Bphi, BZ, currents=1.0, NFP=1, method="cubic", extrap=False
Expand Down Expand Up @@ -1201,7 +1204,7 @@ def from_mgrid(cls, mgrid_file, extcur=None, method="cubic", extrap=False):
else: # "raw"
extcur = 1 # coil current scaling factor
nextcur = int(mgrid["nextcur"][()]) # number of coils
extcur = np.broadcast_to(extcur, nextcur)
extcur = np.broadcast_to(extcur, nextcur).astype(float)

# compute grid knots in cylindrical coordinates
ir = int(mgrid["ir"][()]) # number of grid points in the R coordinate
Expand Down Expand Up @@ -1275,29 +1278,6 @@ def from_field(
extrap=extrap,
)

def tree_flatten(self):
"""Convert DESC objects to JAX pytrees."""
# the default flattening method in the IOAble base class assumes all floats
# are non-static, but for the periodic BC to work we need the period to be
# a static value, so we override the default tree flatten/unflatten method
# so that we can pass a SplineMagneticField into a jitted function such as
# an objective.
static = ["_method", "_extrap", "_period", "_axisym"]
children = {key: val for key, val in self.__dict__.items() if key not in static}
aux_data = tuple(
[(key, val) for key, val in self.__dict__.items() if key in static]
)
return ((children,), aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
"""Recreate a DESC object from JAX pytree."""
obj = cls.__new__(cls)
obj.__dict__.update(children[0])
for kv in aux_data:
setattr(obj, kv[0], kv[1])
return obj


class ScalarPotentialField(_MagneticField):
"""Magnetic field due to a scalar magnetic potential in cylindrical coordinates.
Expand Down
79 changes: 51 additions & 28 deletions desc/objectives/_free_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class VacuumBoundaryError(_Objective):
----------
eq : Equilibrium
Equilibrium that will be optimized to satisfy the Objective.
ext_field : MagneticField
External field produced by coils.
field : MagneticField
External field produced by coils or other sources outside the plasma.
target : float, ndarray, optional
Target value(s) of the objective. Only used if bounds is None.
len(target) must be equal to Objective.dim_f
Expand Down Expand Up @@ -69,7 +69,10 @@ class VacuumBoundaryError(_Objective):
grid : Grid, optional
Collocation grid containing the nodes to evaluate error at. Should be at rho=1.
field_grid : Grid, optional
Grid used to discretize ext_field.
Grid used to discretize field.
field_fixed : bool
Whether to assume the field is fixed. For free boundary solve, should
be fixed. For single stage optimization, should be False (default).
name : str
Name of the objective function.

Expand All @@ -84,7 +87,7 @@ class VacuumBoundaryError(_Objective):
def __init__(
self,
eq,
ext_field,
field,
target=None,
bounds=None,
weight=1,
Expand All @@ -94,15 +97,20 @@ def __init__(
deriv_mode="auto",
grid=None,
field_grid=None,
field_fixed=False,
name="Vacuum boundary error",
):
if target is None and bounds is None:
target = 0
self._grid = grid
self._ext_field = ext_field
self._eq = eq
self._field = field
self._field_grid = field_grid
things = [eq]

self._field_fixed = field_fixed
if field_fixed:
things = [eq]
else:
things = [eq, field]
super().__init__(
things=things,
target=target,
Expand Down Expand Up @@ -178,7 +186,7 @@ def build(self, use_jit=True, verbose=1):
self._constants = {
"transforms": transforms,
"profiles": profiles,
"ext_field": self._ext_field,
"field": self._field,
"quad_weights": np.sqrt(np.tile(transforms["grid"].weights, 2)),
}

Expand All @@ -198,13 +206,15 @@ def build(self, use_jit=True, verbose=1):

super().build(use_jit=use_jit, verbose=verbose)

def compute(self, eq_params, constants=None):
def compute(self, eq_params, field_params=None, constants=None):
"""Compute boundary force error.

Parameters
----------
eq_params : dict
Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict
field_params : dict
Dictionary of field parameters, if field is not fixed.
constants : dict
Dictionary of constant data, eg transforms, profiles etc. Defaults to
self.constants
Expand All @@ -226,8 +236,10 @@ def compute(self, eq_params, constants=None):
profiles=constants["profiles"],
)
x = jnp.array([data["R"], data["phi"], data["Z"]]).T
Bext = constants["ext_field"].compute_magnetic_field(
x, source_grid=self._field_grid, basis="rpz"
# can always pass in field params. If they're None, it just uses the
# defaults for the given field.
Bext = constants["field"].compute_magnetic_field(
x, source_grid=self._field_grid, basis="rpz", params=field_params
)
Bex_total = Bext
Bin_total = data["B"]
Expand Down Expand Up @@ -336,7 +348,7 @@ class BoundaryError(_Objective):
----------
eq : Equilibrium
Equilibrium that will be optimized to satisfy the Objective.
ext_field : MagneticField
field : MagneticField
External field produced by coils.
target : float, ndarray, optional
Target value(s) of the objective. Only used if bounds is None.
Expand Down Expand Up @@ -371,7 +383,10 @@ class BoundaryError(_Objective):
Savart integral and where to evaluate errors. source_grid should not be
stellarator symmetric, and both should be at rho=1.
field_grid : Grid, optional
Grid used to discretize ext_field.
Grid used to discretize field.
field_fixed : bool
Whether to assume the field is fixed. For free boundary solve, should
be fixed. For single stage optimization, should be False (default).
loop : bool
If True, evaluate integral using loops, as opposed to vmap. Slower, but uses
less memory.
Expand All @@ -388,11 +403,11 @@ class BoundaryError(_Objective):
from desc.magnetic_fields import FourierCurrentPotentialField
# turn the regular FourierRZToroidalSurface into a current potential on the
# last closed flux surface
eq.surface = FourierCurrentPotentialField.from_suface(eq.surface,
eq.surface = FourierCurrentPotentialField.from_surface(eq.surface,
M_Phi=eq.M,
N_Phi=eq.N,
)
objective = BoundaryError(eq, ext_field)
objective = BoundaryError(eq, field)

"""

Expand All @@ -406,7 +421,7 @@ class BoundaryError(_Objective):
def __init__(
self,
eq,
ext_field,
field,
target=None,
bounds=None,
weight=1,
Expand All @@ -419,6 +434,7 @@ def __init__(
source_grid=None,
eval_grid=None,
field_grid=None,
field_fixed=False,
loop=True,
name="Boundary error",
):
Expand All @@ -428,11 +444,14 @@ def __init__(
self._eval_grid = eval_grid
self._s = s
self._q = q
self._ext_field = ext_field
self._field = field
self._field_grid = field_grid
self._loop = loop
self._sheet_current = hasattr(eq.surface, "Phi_mn")
things = [eq]
if field_fixed:
things = [eq]
else:
things = [eq, field]

super().__init__(
things=things,
Expand Down Expand Up @@ -554,7 +573,7 @@ def build(self, use_jit=True, verbose=1):
"source_transforms": source_transforms,
"source_profiles": source_profiles,
"interpolator": interpolator,
"ext_field": self._ext_field,
"field": self._field,
"quad_weights": np.sqrt(np.tile(eval_transforms["grid"].weights, neq)),
}

Expand Down Expand Up @@ -592,13 +611,15 @@ def build(self, use_jit=True, verbose=1):

super().build(use_jit=use_jit, verbose=verbose)

def compute(self, eq_params, constants=None):
def compute(self, eq_params, field_params=None, constants=None):
"""Compute boundary force error.

Parameters
----------
eq_params : dict
Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict
field_params : dict
Dictionary of field parameters, if field is not fixed.
constants : dict
Dictionary of constant data, eg transforms, profiles etc. Defaults to
self.constants
Expand Down Expand Up @@ -660,8 +681,10 @@ def compute(self, eq_params, constants=None):
# need extra factor of B/2 bc we're evaluating on plasma surface
Bplasma = Bplasma + eval_data["B"] / 2
x = jnp.array([eval_data["R"], eval_data["phi"], eval_data["Z"]]).T
Bext = constants["ext_field"].compute_magnetic_field(
x, source_grid=self._field_grid, basis="rpz"
# can always pass in field params. If they're None, it just uses the
# defaults for the given field.
Bext = constants["field"].compute_magnetic_field(
x, source_grid=self._field_grid, basis="rpz", params=field_params
)
Bex_total = Bext + Bplasma
Bin_total = eval_data["B"]
Expand Down Expand Up @@ -774,7 +797,7 @@ class BoundaryErrorNESTOR(_Objective):
----------
eq : Equilibrium
Equilibrium that will be optimized to satisfy the Objective.
ext_field : MagneticField
field : MagneticField
External field produced by coils.
target : float, ndarray, optional
Target value(s) of the objective. Only used if bounds is None.
Expand All @@ -790,7 +813,7 @@ class BoundaryErrorNESTOR(_Objective):
ntheta, nzeta : int
number of grid points in poloidal, toroidal directions to use in NESTOR.
field_grid : Grid, optional
Grid used to discretize ext_field.
Grid used to discretize field.
normalize : bool
Whether to compute the error in physical units or non-dimensionalize.
normalize_target : bool
Expand Down Expand Up @@ -820,7 +843,7 @@ class BoundaryErrorNESTOR(_Objective):
def __init__(
self,
eq,
ext_field,
field,
target=None,
bounds=None,
weight=1,
Expand All @@ -841,7 +864,7 @@ def __init__(
self.nf = nf
self.ntheta = ntheta
self.nzeta = nzeta
self.ext_field = ext_field
self.field = field
self.field_grid = field_grid
super().__init__(
things=eq,
Expand Down Expand Up @@ -874,7 +897,7 @@ def build(self, use_jit=True, verbose=1):

nest = Nestor(
eq,
self.ext_field,
self.field,
self.mf,
self.nf,
self.ntheta,
Expand All @@ -900,7 +923,7 @@ def build(self, use_jit=True, verbose=1):
self._constants = {
"profiles": profiles,
"transforms": transforms,
"ext_field": self.ext_field,
"field": self.field,
"nestor": nest,
"quad_weights": np.sqrt(transforms["grid"].weights),
}
Expand Down
5 changes: 3 additions & 2 deletions docs/notebooks/tutorials/free_boundary_equilibrium.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@
},
"outputs": [],
"source": [
"objective = ObjectiveFunction(BoundaryError(eq=eq2, ext_field=ext_field))"
"# For a standard free boundary solve, we set field_fixed=True. For single stage optimization, we would set to False\n",
"objective = ObjectiveFunction(BoundaryError(eq=eq2, field=ext_field, field_fixed=True))"
]
},
{
Expand Down Expand Up @@ -689,7 +690,7 @@
},
"outputs": [],
"source": [
"objective = ObjectiveFunction(VacuumBoundaryError(eq=eq2, ext_field=ext_field))"
"objective = ObjectiveFunction(VacuumBoundaryError(eq=eq2, field=ext_field, field_fixed=True))"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions tests/benchmarks/benchmark_cpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def test_proximal_freeb_compute(benchmark):
eq = desc.examples.get("ESTELL")
eq.change_resolution(6, 6, 6, 12, 12, 12)
field = ToroidalMagneticField(1.0, 1.0) # just a dummy field for benchmarking
objective = ObjectiveFunction(BoundaryError(eq, ext_field=field))
objective = ObjectiveFunction(BoundaryError(eq, field=field))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
obj = LinearConstraintProjection(
Expand All @@ -393,7 +393,7 @@ def test_proximal_freeb_jac(benchmark):
eq = desc.examples.get("ESTELL")
eq.change_resolution(6, 6, 6, 12, 12, 12)
field = ToroidalMagneticField(1.0, 1.0) # just a dummy field for benchmarking
objective = ObjectiveFunction(BoundaryError(eq, ext_field=field))
objective = ObjectiveFunction(BoundaryError(eq, field=field))
constraint = ObjectiveFunction(ForceBalance(eq))
prox = ProximalProjection(objective, constraint, eq)
obj = LinearConstraintProjection(
Expand Down
Loading
Loading