Skip to content

Commit

Permalink
Merge branch 'rc/stopping_criteria' into rc/examples
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Dec 3, 2024
2 parents c66a34b + 24da16f commit bcfe01c
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ New Features
* use of both this and the ``QuadraticFlux`` objective allows for REGCOIL solutions to be obtained through the optimization framework, and combined with other objectives as well.
- Changes local area weighting of Bn in QuadraticFlux objective to be the square root of the local area element (Note that any existing optimizations using this objective may need different weights to achieve the same result now.)
- Adds a new tutorial showing how to use ``REGCOIL`` features.
- Adds an option ``scaled_termination`` (defaults to ``True``) to all of the DESC optimizers, which evaluates the ``xtol`` and ``gtol`` termination criteria in the scaled norm defined by ``x_scale`` (which defaults to an adaptive scaling based on the Jacobian or Hessian). This should make optimization more robust when the parameters being optimized have widely different magnitudes. The old behavior can be recovered by passing ``options={"scaled_termination": False}``.


Bug Fixes
Expand Down
50 changes: 30 additions & 20 deletions desc/optimize/aug_lagrangian.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def fmin_auglag( # noqa: C901
problem dimension. Set it to ``"auto"`` in order to use an automatic heuristic
for choosing the initial scale. The heuristic is described in [2]_, p.143.
By default uses ``"auto"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.
Returns
-------
Expand Down Expand Up @@ -312,18 +313,6 @@ def laghess(z, y, mu, *args):
y = jnp.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
y, mu, c = jnp.broadcast_arrays(y, mu, c)

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", 1.0)
eta = options.pop("eta", 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

L = lagfun(f, c, y, mu)
g = laggrad(z, y, mu, *args)
ngev += 1
Expand All @@ -338,6 +327,7 @@ def laghess(z, y, mu, *args):
maxiter = setdefault(maxiter, z.size * 100)
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

hess_scale = isinstance(x_scale, str) and x_scale in ["hess", "auto"]
if hess_scale:
Expand All @@ -353,7 +343,9 @@ def laghess(z, y, mu, *args):

g_h = g * d
H_h = d * H * d[:, None]
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -399,7 +391,7 @@ def laghess(z, y, mu, *args):
)
subproblem = methods[tr_method]

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(((z * scale_inv) if scaled_termination else z), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand All @@ -412,6 +404,18 @@ def laghess(z, y, mu, *args):
if g_norm < gtol and constr_violation < ctol:
success, message = True, STATUS_MESSAGES["gtol"]

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", min(g_norm, 1e-2) if scaled_termination else 1.0)
eta = options.pop("eta", min(constr_violation, 1e-2) if scaled_termination else 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

if verbose > 1:
print_header_nonlinear(True, "Penalty param", "max(|mltplr|)")
print_iteration_nonlinear(
Expand Down Expand Up @@ -493,7 +497,7 @@ def laghess(z, y, mu, *args):
success, message = check_termination(
actual_reduction,
f,
step_norm,
(step_h_norm if scaled_termination else step_norm),
z_norm,
g_norm,
Lreduction_ratio,
Expand Down Expand Up @@ -536,7 +540,9 @@ def laghess(z, y, mu, *args):
scale, scale_inv = compute_hess_scale(H)
v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# updating augmented lagrangian params
if g_norm < gtolk:
Expand Down Expand Up @@ -565,9 +571,13 @@ def laghess(z, y, mu, *args):

v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(
((z * scale_inv) if scaled_termination else z), ord=2
)
d = v**0.5 * scale
diag_h = g * dv * scale
g_h = g * d
Expand All @@ -580,7 +590,7 @@ def laghess(z, y, mu, *args):
success, message = False, STATUS_MESSAGES["callback"]

else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
49 changes: 30 additions & 19 deletions desc/optimize/aug_lagrangian_ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def lsq_auglag( # noqa: C901
value decomposition. ``"cho"`` is generally the fastest for large systems,
especially on GPU, but may be less accurate for badly scaled systems.
``"svd"`` is the most accurate but significantly slower. Default ``"qr"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.
Returns
-------
Expand Down Expand Up @@ -254,18 +256,6 @@ def lagjac(z, y, mu, *args):
y = jnp.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
y, mu, c = jnp.broadcast_arrays(y, mu, c)

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", 1.0)
eta = options.pop("eta", 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

L = lagfun(f, c, y, mu)
J = lagjac(z, y, mu, *args)
Lcost = 1 / 2 * jnp.dot(L, L)
Expand All @@ -276,6 +266,7 @@ def lagjac(z, y, mu, *args):
maxiter = setdefault(maxiter, z.size * 100)
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

jac_scale = isinstance(x_scale, str) and x_scale in ["jac", "auto"]
if jac_scale:
Expand All @@ -291,7 +282,9 @@ def lagjac(z, y, mu, *args):

g_h = g * d
J_h = J * d
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -332,7 +325,7 @@ def lagjac(z, y, mu, *args):

callback = setdefault(callback, lambda *args: False)

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(((z * scale_inv) if scaled_termination else z), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand All @@ -345,6 +338,18 @@ def lagjac(z, y, mu, *args):
if g_norm < gtol and constr_violation < ctol:
success, message = True, STATUS_MESSAGES["gtol"]

# notation following Conn & Gould, algorithm 14.4.2, but with our mu = their mu^-1
omega = options.pop("omega", min(g_norm, 1e-2) if scaled_termination else 1.0)
eta = options.pop("eta", min(constr_violation, 1e-2) if scaled_termination else 1.0)
alpha_omega = options.pop("alpha_omega", 1.0)
beta_omega = options.pop("beta_omega", 1.0)
alpha_eta = options.pop("alpha_eta", 0.1)
beta_eta = options.pop("beta_eta", 0.9)
tau = options.pop("tau", 10)

gtolk = max(omega / jnp.mean(mu) ** alpha_omega, gtol)
ctolk = max(eta / jnp.mean(mu) ** alpha_eta, ctol)

if verbose > 1:
print_header_nonlinear(True, "Penalty param", "max(|mltplr|)")
print_iteration_nonlinear(
Expand Down Expand Up @@ -454,7 +459,7 @@ def lagjac(z, y, mu, *args):
success, message = check_termination(
actual_reduction,
cost,
step_norm,
(step_h_norm if scaled_termination else step_norm),
z_norm,
g_norm,
Lreduction_ratio,
Expand Down Expand Up @@ -492,7 +497,9 @@ def lagjac(z, y, mu, *args):
scale, scale_inv = compute_jac_scale(J, scale_inv)
v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# updating augmented lagrangian params
if g_norm < gtolk:
Expand All @@ -516,9 +523,13 @@ def lagjac(z, y, mu, *args):

v, dv = cl_scaling_vector(z, g, lb, ub)
v = jnp.where(dv != 0, v * scale_inv, v)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

z_norm = jnp.linalg.norm(z, ord=2)
z_norm = jnp.linalg.norm(
((z * scale_inv) if scaled_termination else z), ord=2
)
d = v**0.5 * scale
diag_h = g * dv * scale
g_h = g * d
Expand All @@ -531,7 +542,7 @@ def lagjac(z, y, mu, *args):
success, message = False, STATUS_MESSAGES["callback"]

else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
21 changes: 15 additions & 6 deletions desc/optimize/fmin_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def fmintr( # noqa: C901
problem dimension. Set it to ``"auto"`` in order to use an automatic heuristic
for choosing the initial scale. The heuristic is described in [2]_, p.143.
By default uses ``"auto"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.
Returns
-------
Expand Down Expand Up @@ -222,6 +224,7 @@ def fmintr( # noqa: C901
maxiter = N * 100
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

hess_scale = isinstance(x_scale, str) and x_scale in ["hess", "auto"]
if hess_scale:
Expand All @@ -237,7 +240,9 @@ def fmintr( # noqa: C901

g_h = g * d
H_h = d * H * d[:, None]
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -283,7 +288,7 @@ def fmintr( # noqa: C901
)
subproblem = methods[tr_method]

x_norm = jnp.linalg.norm(x, ord=2)
x_norm = jnp.linalg.norm(((x * scale_inv) if scaled_termination else x), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand Down Expand Up @@ -366,7 +371,7 @@ def fmintr( # noqa: C901
success, message = check_termination(
actual_reduction,
f,
step_norm,
(step_h_norm if scaled_termination else step_norm),
x_norm,
g_norm,
reduction_ratio,
Expand Down Expand Up @@ -410,8 +415,12 @@ def fmintr( # noqa: C901
g_h = g * d
H_h = d * H * d[:, None]

x_norm = jnp.linalg.norm(x, ord=2)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
x_norm = jnp.linalg.norm(
((x * scale_inv) if scaled_termination else x), ord=2
)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

if g_norm < gtol:
success, message = True, STATUS_MESSAGES["gtol"]
Expand All @@ -421,7 +430,7 @@ def fmintr( # noqa: C901

allx.append(x)
else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
23 changes: 16 additions & 7 deletions desc/optimize/least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def lsqtr( # noqa: C901
value decomposition. ``"cho"`` is generally the fastest for large systems,
especially on GPU, but may be less accurate for badly scaled systems.
``"svd"`` is the most accurate but significantly slower. Default ``"qr"``.
- ``"scaled_termination"`` : Whether to evaluate termination criteria for
``xtol`` and ``gtol`` in scaled / normalized units (default) or base units.
Returns
-------
Expand Down Expand Up @@ -183,6 +185,7 @@ def lsqtr( # noqa: C901
maxiter = setdefault(maxiter, n * 100)
max_nfev = options.pop("max_nfev", 5 * maxiter + 1)
max_dx = options.pop("max_dx", jnp.inf)
scaled_termination = options.pop("scaled_termination", True)

jac_scale = isinstance(x_scale, str) and x_scale in ["jac", "auto"]
if jac_scale:
Expand All @@ -198,7 +201,9 @@ def lsqtr( # noqa: C901

g_h = g * d
J_h = J * d
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

# conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould
# scipy : norm of the scaled x, as used in scipy
Expand Down Expand Up @@ -239,7 +244,7 @@ def lsqtr( # noqa: C901

callback = setdefault(callback, lambda *args: False)

x_norm = jnp.linalg.norm(x, ord=2)
x_norm = jnp.linalg.norm(((x * scale_inv) if scaled_termination else x), ord=2)
success = None
message = None
step_norm = jnp.inf
Expand Down Expand Up @@ -344,11 +349,11 @@ def lsqtr( # noqa: C901
)
alltr.append(trust_radius)
alpha *= tr_old / trust_radius
# TODO (#1395): does this need to move to the outer loop?

success, message = check_termination(
actual_reduction,
cost,
step_norm,
(step_h_norm if scaled_termination else step_norm),
x_norm,
g_norm,
reduction_ratio,
Expand Down Expand Up @@ -386,8 +391,12 @@ def lsqtr( # noqa: C901

g_h = g * d
J_h = J * d
x_norm = jnp.linalg.norm(x, ord=2)
g_norm = jnp.linalg.norm(g * v, ord=jnp.inf)
x_norm = jnp.linalg.norm(
((x * scale_inv) if scaled_termination else x), ord=2
)
g_norm = jnp.linalg.norm(
(g * v * scale if scaled_termination else g * v), ord=jnp.inf
)

if g_norm < gtol:
success, message = True, STATUS_MESSAGES["gtol"]
Expand All @@ -396,7 +405,7 @@ def lsqtr( # noqa: C901
success, message = False, STATUS_MESSAGES["callback"]

else:
step_norm = actual_reduction = 0
step_norm = step_h_norm = actual_reduction = 0

iteration += 1
if verbose > 1:
Expand Down
Loading

0 comments on commit bcfe01c

Please sign in to comment.