Skip to content

Commit

Permalink
Merge branch 'master' into ku/fourier_bounce
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Dec 5, 2024
2 parents 877d7db + 005c4f8 commit 1d03f2f
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ New Features
- Adds ``eq_fixed`` flag to ``ToroidalFlux`` to allow for the equilibrium/QFM surface to vary during optimization, useful for single-stage optimizations.
- Adds tutorial notebook showcasing QFM surface capability.
- Adds ``rotate_zeta`` function to ``desc.compat`` to rotate an ``Equilibrium`` around Z axis.
- Adds an option ``scaled_termination`` (defaults to ``True``) to all of the DESC optimizers, which evaluates the ``xtol`` and ``gtol`` termination criteria using norms scaled by ``x_scale`` (which defaults to an adaptive scaling based on the Jacobian or Hessian). This should make optimization more robust when the parameters have widely different magnitudes. The old behavior can be recovered by passing ``options={"scaled_termination": False}``.

Bug Fixes

Expand Down
50 changes: 30 additions & 20 deletions desc/optimize/aug_lagrangian.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def fmin_auglag( # noqa: C901
problem dimension. Set it to ``"auto"`` in order to use an automatic heuristic
for choosing the initial scale. The heuristic is described in [2]_, p.143.
By default uses ``"auto"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.
Returns
-------
Expand Down Expand Up @@ -312,18 +313,6 @@ def laghess(z, y, mu, *args):
y = jnp.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
y, mu, c = jnp.broadcast_arrays(y, mu, c)

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", 1.0)
eta = options.pop("eta", 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

L = lagfun(f, c, y, mu)
g = laggrad(z, y, mu, *args)
ngev += 1
Expand All @@ -338,6 +327,7 @@ def laghess(z, y, mu, *args):
maxiter = setdefault(maxiter, z.size * 100)
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

hess_scale = isinstance(x_scale, str) and x_scale in ["hess", "auto"]
if hess_scale:
Expand All @@ -353,7 +343,9 @@ def laghess(z, y, mu, *args):

g_h = g * d
H_h = d * H * d[:, None]
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -399,7 +391,7 @@ def laghess(z, y, mu, *args):
)
subproblem = methods[tr_method]

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(((z * scale_inv) if scaled_termination else z), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand All @@ -412,6 +404,18 @@ def laghess(z, y, mu, *args):
if g_norm < gtol and constr_violation < ctol:
success, message = True, STATUS_MESSAGES["gtol"]

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", min(g_norm, 1e-2) if scaled_termination else 1.0)
eta = options.pop("eta", min(constr_violation, 1e-2) if scaled_termination else 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

if verbose > 1:
print_header_nonlinear(True, "Penalty param", "max(|mltplr|)")
print_iteration_nonlinear(
Expand Down Expand Up @@ -493,7 +497,7 @@ def laghess(z, y, mu, *args):
success, message = check_termination(
actual_reduction,
f,
step_norm,
(step_h_norm if scaled_termination else step_norm),
z_norm,
g_norm,
Lreduction_ratio,
Expand Down Expand Up @@ -536,7 +540,9 @@ def laghess(z, y, mu, *args):
scale, scale_inv = compute_hess_scale(H)
v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# updating augmented lagrangian params
if g_norm < gtolk:
Expand Down Expand Up @@ -565,9 +571,13 @@ def laghess(z, y, mu, *args):

v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(
((z * scale_inv) if scaled_termination else z), ord=2
)
d = v**0.5 * scale
diag_h = g * dv * scale
g_h = g * d
Expand All @@ -580,7 +590,7 @@ def laghess(z, y, mu, *args):
success, message = False, STATUS_MESSAGES["callback"]

else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
49 changes: 30 additions & 19 deletions desc/optimize/aug_lagrangian_ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def lsq_auglag( # noqa: C901
value decomposition. ``"cho"`` is generally the fastest for large systems,
especially on GPU, but may be less accurate for badly scaled systems.
``"svd"`` is the most accurate but significantly slower. Default ``"qr"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.
Returns
-------
Expand Down Expand Up @@ -254,18 +256,6 @@ def lagjac(z, y, mu, *args):
y = jnp.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
y, mu, c = jnp.broadcast_arrays(y, mu, c)

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", 1.0)
eta = options.pop("eta", 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

L = lagfun(f, c, y, mu)
J = lagjac(z, y, mu, *args)
Lcost = 1 / 2 * jnp.dot(L, L)
Expand All @@ -276,6 +266,7 @@ def lagjac(z, y, mu, *args):
maxiter = setdefault(maxiter, z.size * 100)
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

jac_scale = isinstance(x_scale, str) and x_scale in ["jac", "auto"]
if jac_scale:
Expand All @@ -291,7 +282,9 @@ def lagjac(z, y, mu, *args):

g_h = g * d
J_h = J * d
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -332,7 +325,7 @@ def lagjac(z, y, mu, *args):

callback = setdefault(callback, lambda *args: False)

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(((z * scale_inv) if scaled_termination else z), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand All @@ -345,6 +338,18 @@ def lagjac(z, y, mu, *args):
if g_norm < gtol and constr_violation < ctol:
success, message = True, STATUS_MESSAGES["gtol"]

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", min(g_norm, 1e-2) if scaled_termination else 1.0)
eta = options.pop("eta", min(constr_violation, 1e-2) if scaled_termination else 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

if verbose > 1:
print_header_nonlinear(True, "Penalty param", "max(|mltplr|)")
print_iteration_nonlinear(
Expand Down Expand Up @@ -454,7 +459,7 @@ def lagjac(z, y, mu, *args):
success, message = check_termination(
actual_reduction,
cost,
step_norm,
(step_h_norm if scaled_termination else step_norm),
z_norm,
g_norm,
Lreduction_ratio,
Expand Down Expand Up @@ -492,7 +497,9 @@ def lagjac(z, y, mu, *args):
scale, scale_inv = compute_jac_scale(J, scale_inv)
v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# updating augmented lagrangian params
if g_norm < gtolk:
Expand All @@ -516,9 +523,13 @@ def lagjac(z, y, mu, *args):

v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(
((z * scale_inv) if scaled_termination else z), ord=2
)
d = v**0.5 * scale
diag_h = g * dv * scale
g_h = g * d
Expand All @@ -531,7 +542,7 @@ def lagjac(z, y, mu, *args):
success, message = False, STATUS_MESSAGES["callback"]

else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
21 changes: 15 additions & 6 deletions desc/optimize/fmin_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def fmintr( # noqa: C901
problem dimension. Set it to ``"auto"`` in order to use an automatic heuristic
for choosing the initial scale. The heuristic is described in [2]_, p.143.
By default uses ``"auto"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.
Returns
-------
Expand Down Expand Up @@ -222,6 +224,7 @@ def fmintr( # noqa: C901
maxiter = N * 100
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

hess_scale = isinstance(x_scale, str) and x_scale in ["hess", "auto"]
if hess_scale:
Expand All @@ -237,7 +240,9 @@ def fmintr( # noqa: C901

g_h = g * d
H_h = d * H * d[:, None]
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -283,7 +288,7 @@ def fmintr( # noqa: C901
)
subproblem = methods[tr_method]

x_norm = jnp.linalg.norm(x, ord=2)
x_norm = jnp.linalg.norm(((x * scale_inv) if scaled_termination else x), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand Down Expand Up @@ -366,7 +371,7 @@ def fmintr( # noqa: C901
success, message = check_termination(
actual_reduction,
f,
step_norm,
(step_h_norm if scaled_termination else step_norm),
x_norm,
g_norm,
reduction_ratio,
Expand Down Expand Up @@ -410,8 +415,12 @@ def fmintr( # noqa: C901
g_h = g * d
H_h = d * H * d[:, None]

x_norm = jnp.linalg.norm(x, ord=2)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
x_norm = jnp.linalg.norm(
((x * scale_inv) if scaled_termination else x), ord=2
)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

if g_norm < gtol:
success, message = True, STATUS_MESSAGES["gtol"]
Expand All @@ -421,7 +430,7 @@ def fmintr( # noqa: C901

allx.append(x)
else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
23 changes: 16 additions & 7 deletions desc/optimize/least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def lsqtr( # noqa: C901
value decomposition. ``"cho"`` is generally the fastest for large systems,
especially on GPU, but may be less accurate for badly scaled systems.
``"svd"`` is the most accurate but significantly slower. Default ``"qr"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.
Returns
-------
Expand Down Expand Up @@ -183,6 +185,7 @@ def lsqtr( # noqa: C901
maxiter = setdefault(maxiter, n * 100)
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

jac_scale = isinstance(x_scale, str) and x_scale in ["jac", "auto"]
if jac_scale:
Expand All @@ -198,7 +201,9 @@ def lsqtr( # noqa: C901

g_h = g * d
J_h = J * d
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -239,7 +244,7 @@ def lsqtr( # noqa: C901

callback = setdefault(callback, lambda *args: False)

x_norm = jnp.linalg.norm(x, ord=2)
x_norm = jnp.linalg.norm(((x * scale_inv) if scaled_termination else x), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand Down Expand Up @@ -344,11 +349,11 @@ def lsqtr( # noqa: C901
)
alltr.append(trust_radius)
alpha *= tr_old / trust_radius
# TODO (#1395): does this need to move to the outer loop?

success, message = check_termination(
actual_reduction,
cost,
step_norm,
(step_h_norm if scaled_termination else step_norm),
x_norm,
g_norm,
reduction_ratio,
Expand Down Expand Up @@ -386,8 +391,12 @@ def lsqtr( # noqa: C901

g_h = g * d
J_h = J * d
x_norm = jnp.linalg.norm(x, ord=2)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
x_norm = jnp.linalg.norm(
((x * scale_inv) if scaled_termination else x), ord=2
)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

if g_norm < gtol:
success, message = True, STATUS_MESSAGES["gtol"]
Expand All @@ -396,7 +405,7 @@ def lsqtr( # noqa: C901
success, message = False, STATUS_MESSAGES["callback"]

else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
Loading

0 comments on commit 1d03f2f

Please sign in to comment.