Skip to content

Commit

Permalink
Merge branch 'master' into dp/quad-grid
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Aug 20, 2024
2 parents 890cb62 + dd3f472 commit e140509
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 72 deletions.
36 changes: 22 additions & 14 deletions desc/objectives/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,32 @@ def get_lowest_mode(basis, coeffs):

scales["R0"] = R00
scales["a"] = np.sqrt(np.abs(R10 * Z10))
scales["Psi"] = abs(thing.Psi)
scales["A"] = np.pi * scales["a"] ** 2
scales["V"] = 2 * np.pi * scales["R0"] * scales["A"]
scales["B_T"] = abs(thing.Psi) / scales["A"]
iota_avg = np.mean(np.abs(thing.get_profile("iota")(np.linspace(0, 1, 20))))
if np.isclose(iota_avg, 0):
scales["B_P"] = scales["B_T"]
else:
scales["B_P"] = scales["B_T"] * iota_avg
scales["B"] = np.sqrt(scales["B_T"] ** 2 + scales["B_P"] ** 2)
scales["I"] = scales["B_P"] * 2 * np.pi / mu_0
scales["p"] = scales["B"] ** 2 / (2 * mu_0)
scales["W"] = scales["p"] * scales["V"]
scales["B"] = scales["Psi"] / scales["A"] * 1.25
B_pressure = scales["B"] ** 2 / (2 * mu_0)
scales["I"] = scales["B"] * 2 * np.pi / mu_0
scales["W"] = B_pressure * scales["V"]
scales["J"] = scales["B"] / scales["a"] / mu_0
scales["F"] = scales["p"] / scales["a"]
scales["F"] = B_pressure / scales["a"]
scales["f"] = scales["F"] * scales["V"]
scales["Psi"] = abs(thing.Psi)
scales["n"] = 1e19
scales["T"] = scales["p"] / (scales["n"] * elementary_charge)

if thing.pressure is not None:
p0 = float(thing.pressure(0)[0])
else:
scales["n"] = float(
((thing.atomic_number(0) + 1) / 2 * thing.electron_density(0))[0]
)
scales["T"] = np.mean(
[thing.electron_temperature(0), thing.ion_temperature(0)]
)
p0 = elementary_charge * 2 * scales["n"] * scales["T"]
if p0 < 1: # vacuum
scales["p"] = B_pressure
else:
scales["p"] = p0
scales["W_p"] = scales["p"] * scales["V"] / 2

elif isinstance(thing, FourierRZToroidalSurface):
R00 = thing.R_lmn[thing.R_basis.get_idx(M=0, N=0)]
Expand Down
51 changes: 39 additions & 12 deletions desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif


def factorize_linear_constraints(objective, constraint): # noqa: C901
def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa: C901
"""Compute and factorize A to get pseudoinverse and nullspace.
Given constraints of the form Ax=b, factorize A to find a particular solution xp
Expand All @@ -22,6 +22,10 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
Objective function to optimize.
constraint : ObjectiveFunction
Objective function of linear constraints to enforce.
x_scale : array_like or ``'auto'``, optional
Characteristic scale of each variable. Setting ``x_scale`` is equivalent
to reformulating the problem in scaled variables ``xs = x / x_scale``.
If set to ``'auto'``, the scale is determined from the initial state vector.
Returns
-------
Expand All @@ -33,6 +37,8 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
Combined RHS vector.
Z : ndarray
Null space operator for full combined A such that A @ Z == 0.
D : ndarray
Scale of the full state vector x, as set by the parameter ``x_scale``.
unfixed_idx : ndarray
Indices of x that correspond to non-fixed values.
project, recover : function
Expand Down Expand Up @@ -130,32 +136,53 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
)
A = A[unfixed_rows][:, unfixed_idx]
b = b[unfixed_rows]

unfixed_idx = indices_idx
fixed_idx = np.delete(np.arange(xp.size), unfixed_idx)

# compute x_scale if not provided
if x_scale == "auto":
x_scale = objective.x(*objective.things)
errorif(
x_scale.shape != xp.shape,
ValueError,
"x_scale must be the same size as the full state vector. "
+ f"Got size {x_scale.size} for state vector of size {xp.size}.",
)
D = np.where(np.abs(x_scale) < 1e2, 1, np.abs(x_scale))

# null space & particular solution
A = A * D[None, unfixed_idx]
if A.size:
Ainv_full, Z = svd_inv_null(A)
A_inv, Z = svd_inv_null(A)
else:
Ainv_full = A.T
A_inv = A.T
Z = np.eye(A.shape[1])
Ainv_full = jnp.asarray(Ainv_full)
Z = jnp.asarray(Z)
b = jnp.asarray(b)
xp = put(xp, unfixed_idx, Ainv_full @ b)
xp = put(xp, unfixed_idx, A_inv @ b)
xp = put(xp, fixed_idx, ((1 / D) * xp)[fixed_idx])

# cast to jnp arrays
xp = jnp.asarray(xp)
A = jnp.asarray(A)
b = jnp.asarray(b)
Z = jnp.asarray(Z)
D = jnp.asarray(D)

@jit
def project(x):
def project(x_full):
"""Project a full state vector into the reduced optimization vector."""
x_reduced = Z.T @ ((x - xp)[unfixed_idx])
x_reduced = Z.T @ ((1 / D) * x_full - xp)[unfixed_idx]
return jnp.atleast_1d(jnp.squeeze(x_reduced))

@jit
def recover(x_reduced):
"""Recover the full state vector from the reduced optimization vector."""
dx = put(jnp.zeros(objective.dim_x), unfixed_idx, Z @ x_reduced)
return jnp.atleast_1d(jnp.squeeze(xp + dx))
x_full = D * (xp + dx)
return jnp.atleast_1d(jnp.squeeze(x_full))

# check that all constraints are actually satisfiable
params = objective.unpack_state(xp, False)
params = objective.unpack_state(D * xp, False)
for con in constraint.objectives:
xpi = [params[i] for i, t in enumerate(objective.things) if t in con.things]
y1 = con.compute_unscaled(*xpi)
Expand Down Expand Up @@ -197,7 +224,7 @@ def recover(x_reduced):
"or be due to floating point error.",
)

return xp, A, b, Z, unfixed_idx, project, recover
return xp, A, b, Z, D, unfixed_idx, project, recover


def softmax(arr, alpha):
Expand Down
36 changes: 21 additions & 15 deletions desc/optimize/_constraint_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def build(self, use_jit=None, verbose=1):
self._A,
self._b,
self._Z,
self._D,
self._unfixed_idx,
self._project,
self._recover,
Expand All @@ -113,10 +114,8 @@ def build(self, use_jit=None, verbose=1):
self._dim_x = self._objective.dim_x
self._dim_x_reduced = self._Z.shape[1]

# equivalent matrix for A[unfixed_idx]@Z == A@unfixed_idx_mat
self._unfixed_idx_mat = (
jnp.eye(self._objective.dim_x)[:, self._unfixed_idx] @ self._Z
)
# equivalent matrix for A[unfixed_idx] @ D @ Z == A @ unfixed_idx_mat
self._unfixed_idx_mat = jnp.diag(self._D)[:, self._unfixed_idx] @ self._Z

self._built = True
timer.stop("Linear constraint projection build")
Expand Down Expand Up @@ -261,7 +260,7 @@ def grad(self, x_reduced, constants=None):
"""
x = self.recover(x_reduced)
df = self._objective.grad(x, constants)
return df[self._unfixed_idx] @ self._Z
return df[self._unfixed_idx] @ (self._Z * self._D[self._unfixed_idx, None])

def hess(self, x_reduced, constants=None):
"""Compute Hessian of self.compute_scalar.
Expand All @@ -281,13 +280,19 @@ def hess(self, x_reduced, constants=None):
"""
x = self.recover(x_reduced)
df = self._objective.hess(x, constants)
return self._Z.T @ df[self._unfixed_idx, :][:, self._unfixed_idx] @ self._Z
return (
(self._Z.T * (1 / self._D)[None, self._unfixed_idx])
@ df[self._unfixed_idx, :][:, self._unfixed_idx]
@ (self._Z * self._D[self._unfixed_idx, None])
)

def _jac(self, x_reduced, constants=None, op="scaled"):
x = self.recover(x_reduced)
if self._objective._deriv_mode == "blocked":
fun = getattr(self._objective, "jac_" + op)
return fun(x, constants)[:, self._unfixed_idx] @ self._Z
return fun(x, constants)[:, self._unfixed_idx] @ (
self._Z * self._D[self._unfixed_idx, None]
)

v = self._unfixed_idx_mat
df = getattr(self._objective, "jvp_" + op)(v.T, x, constants)
Expand Down Expand Up @@ -401,7 +406,7 @@ def jvp_unscaled(self, v, x_reduced, constants=None):
def _vjp(self, v, x_reduced, constants=None, op="vjp_scaled"):
x = self.recover(x_reduced)
df = getattr(self._objective, op)(v, x, constants)
return df[self._unfixed_idx] @ self._Z
return df[self._unfixed_idx] @ (self._Z * self._D[self._unfixed_idx, None])

def vjp_scaled(self, v, x_reduced, constants=None):
"""Compute vector-Jacobian product of self.compute_scaled.
Expand Down Expand Up @@ -533,8 +538,8 @@ def _set_eq_state_vector(self):
self._args.remove(arg)
linear_constraint = ObjectiveFunction(self._linear_constraints)
linear_constraint.build()
_, A, _, self._Z, self._unfixed_idx, _, _ = factorize_linear_constraints(
self._constraint, linear_constraint
_, _, _, self._Z, self._D, self._unfixed_idx, _, _ = (
factorize_linear_constraints(self._constraint, linear_constraint)
)

# dx/dc - goes from the full state to optimization variables for eq
Expand Down Expand Up @@ -618,14 +623,14 @@ def build(self, use_jit=None, verbose=1): # noqa: C901
)
self._dimx_per_thing = [t.dim_x for t in self.things]

# equivalent matrix for A[unfixed_idx]@Z == A@unfixed_idx_mat
# equivalent matrix for A[unfixed_idx] @ D @ Z == A @ unfixed_idx_mat
self._unfixed_idx_mat = jnp.eye(self._objective.dim_x)
self._unfixed_idx_mat = jnp.split(
self._unfixed_idx_mat, np.cumsum([t.dim_x for t in self.things]), axis=-1
)
self._unfixed_idx_mat[self._eq_idx] = (
self._unfixed_idx_mat[self._eq_idx][:, self._unfixed_idx] @ self._Z
)
self._unfixed_idx_mat[self._eq_idx] = self._unfixed_idx_mat[self._eq_idx][
:, self._unfixed_idx
] @ (self._Z * self._D[self._unfixed_idx, None])
self._unfixed_idx_mat = np.concatenate(
[np.atleast_2d(foo) for foo in self._unfixed_idx_mat], axis=-1
)
Expand Down Expand Up @@ -1018,7 +1023,8 @@ def jvp_unscaled(self, v, x, constants=None):
@functools.partial(jit, static_argnames=("self", "op"))
def _jvp_f(self, xf, dc, constants, op):
Fx = getattr(self._constraint, "jac_" + op)(xf, constants)
Fx_reduced = Fx[:, self._unfixed_idx] @ self._Z
# TODO: replace with self._unfixed_idx_mat?
Fx_reduced = Fx @ jnp.diag(self._D)[:, self._unfixed_idx] @ self._Z
Fc = Fx @ (self._dxdc @ dc)
Fxh = Fx_reduced
cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape)
Expand Down
20 changes: 8 additions & 12 deletions desc/perturbations.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces
if verbose > 0:
print("Factorizing linear constraints")
timer.start("linear constraint factorize")
xp, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
xp, _, _, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
)
timer.stop("linear constraint factorize")
Expand Down Expand Up @@ -291,7 +291,7 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces
print("Computing df")
timer.start("df computation")
Jx = objective.jac_scaled_error(x)
Jx_reduced = Jx[:, unfixed_idx] @ Z @ scale
Jx_reduced = Jx @ jnp.diag(D)[:, unfixed_idx] @ Z @ scale
RHS1 = objective.jvp_scaled(tangents, x)
if include_f:
f = objective.compute_scaled_error(x)
Expand Down Expand Up @@ -388,9 +388,7 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces
con.update_target(eq_new)
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)
xp, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
)
_, _, _, _, _, _, _, recover = factorize_linear_constraints(objective, constraint)

# update other attributes
dx_reduced = dx1_reduced + dx2_reduced + dx3_reduced
Expand Down Expand Up @@ -547,7 +545,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)

_, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
_, _, _, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective_f, constraint
)

Expand All @@ -564,7 +562,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
dx2_reduced = 0

# dx/dx_reduced
dxdx_reduced = jnp.eye(eq.dim_x)[:, unfixed_idx] @ Z
dxdx_reduced = jnp.diag(D)[:, unfixed_idx] @ Z

# dx/dc
dxdc = []
Expand Down Expand Up @@ -612,8 +610,8 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
timer.disp("dg computation")

# projections onto optimization space
Fx_reduced = Fx[:, unfixed_idx] @ Z
Gx_reduced = Gx[:, unfixed_idx] @ Z
Fx_reduced = Fx @ jnp.diag(D)[:, unfixed_idx] @ Z
Gx_reduced = Gx @ jnp.diag(D)[:, unfixed_idx] @ Z
Fc = Fx @ dxdc
Gc = Gx @ dxdc

Expand Down Expand Up @@ -752,9 +750,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
con.update_target(eq_new)
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)
_, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective_f, constraint
)
_, _, _, _, _, _, _, recover = factorize_linear_constraints(objective_f, constraint)

# update other attributes
dx_reduced = dx1_reduced + dx2_reduced
Expand Down
7 changes: 6 additions & 1 deletion desc/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from desc.backend import sign
from desc.basis import fourier, zernike_radial_poly
from desc.coils import CoilSet
from desc.coils import CoilSet, _Coil
from desc.compute import data_index, get_transforms
from desc.compute.utils import _parse_parameterization, surface_averages_map
from desc.equilibrium.coords import map_coordinates
Expand Down Expand Up @@ -2394,6 +2394,11 @@ def plot_coils(coils, grid=None, fig=None, return_data=False, **kwargs):
ValueError,
f"plot_coils got unexpected keyword argument: {kwargs.keys()}",
)
errorif(
not isinstance(coils, _Coil),
ValueError,
"Expected `coils` to be of type `_Coil`, instead got type" f" {type(coils)}",
)

if not isinstance(lw, (list, tuple)):
lw = [lw]
Expand Down
2 changes: 1 addition & 1 deletion desc/vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def load(
constraints = maybe_add_self_consistency(eq, constraints)
objective = ObjectiveFunction(constraints)
objective.build(verbose=0)
_, _, _, _, _, project, recover = factorize_linear_constraints(
_, _, _, _, _, _, project, recover = factorize_linear_constraints(
objective, objective
)
args = objective.unpack_state(recover(project(objective.x(eq))), False)[0]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,7 +1589,7 @@ def test_bootstrap_optimization_comparison_qa():
objective=objective,
constraints=constraints,
optimizer="proximal-lsq-exact",
maxiter=4,
maxiter=5,
gtol=1e-16,
verbose=3,
)
Expand Down Expand Up @@ -1622,5 +1622,5 @@ def test_bootstrap_optimization_comparison_qa():
grid.compress(data2["<J*B>"]), grid.compress(data2["<J*B> Redl"]), rtol=1.8e-2
)
np.testing.assert_allclose(
grid.compress(data1["<J*B>"]), grid.compress(data2["<J*B>"]), rtol=1.8e-2
grid.compress(data1["<J*B>"]), grid.compress(data2["<J*B>"]), rtol=1.9e-2
)
11 changes: 8 additions & 3 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,7 +1627,7 @@ def test_signed_PlasmaVesselDistance():
eq = Equilibrium(M=1, N=1)
surf = eq.surface.copy()
surf.change_resolution(M=1, N=1)
grid = LinearGrid(M=10, N=2, NFP=eq.NFP)
grid = LinearGrid(M=20, N=8, NFP=eq.NFP)

obj = PlasmaVesselDistance(
surface=surf,
Expand All @@ -1637,7 +1637,7 @@ def test_signed_PlasmaVesselDistance():
plasma_grid=grid,
use_signed_distance=True,
)
objective = ObjectiveFunction((obj,))
objective = ObjectiveFunction(obj)

optimizer = Optimizer("lsq-exact")
(eq, surf), _ = optimizer.optimize(
Expand All @@ -1650,4 +1650,9 @@ def test_signed_PlasmaVesselDistance():
xtol=1e-9,
)

np.testing.assert_allclose(obj.compute(*obj.xs(eq, surf)), target_dist, atol=1e-2)
np.testing.assert_allclose(
obj.compute(*obj.xs(eq, surf)),
target_dist,
atol=1e-2,
err_msg="allowing eq to change",
)
Loading

0 comments on commit e140509

Please sign in to comment.