diff --git a/desc/optimize/aug_lagrangian_ls.py b/desc/optimize/aug_lagrangian_ls.py index 65c7de50a0..27a1d9f8ba 100644 --- a/desc/optimize/aug_lagrangian_ls.py +++ b/desc/optimize/aug_lagrangian_ls.py @@ -2,7 +2,7 @@ from scipy.optimize import NonlinearConstraint, OptimizeResult -from desc.backend import jnp +from desc.backend import jnp, qr from desc.utils import errorif, setdefault from .bound_utils import ( @@ -25,6 +25,7 @@ inequality_to_bounds, print_header_nonlinear, print_iteration_nonlinear, + solve_triangular_regularized, ) @@ -368,6 +369,15 @@ def lagjac(z, y, mu, *args): U, s, Vt = jnp.linalg.svd(J_a, full_matrices=False) elif tr_method == "cho": B_h = jnp.dot(J_a.T, J_a) + elif tr_method == "qr": + # try full newton step + tall = J_a.shape[0] >= J_a.shape[1] + if tall: + Q, R = qr(J_a, mode="economic") + p_newton = solve_triangular_regularized(R, -Q.T @ L_a) + else: + Q, R = qr(J_a.T, mode="economic") + p_newton = Q @ solve_triangular_regularized(R.T, -L_a, lower=True) actual_reduction = -1 Lactual_reduction = -1 @@ -390,7 +400,7 @@ def lagjac(z, y, mu, *args): ) elif tr_method == "qr": step_h, hits_boundary, alpha = trust_region_step_exact_qr( - L_a, J_a, trust_radius, alpha + p_newton, L_a, J_a, trust_radius, alpha ) step = d * step_h # Trust-region solution in the original space. diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index 0ed35a506a..227cd93f70 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -2,7 +2,7 @@ from scipy.optimize import OptimizeResult -from desc.backend import jnp +from desc.backend import jnp, qr from desc.utils import errorif, setdefault from .bound_utils import ( @@ -24,6 +24,7 @@ compute_jac_scale, print_header_nonlinear, print_iteration_nonlinear, + solve_triangular_regularized, ) @@ -268,6 +269,15 @@ def lsqtr( # noqa: C901 - FIXME: simplify this U, s, Vt = jnp.linalg.svd(J_a, full_matrices=False) elif tr_method == "cho": B_h = jnp.dot(J_a.T, J_a) + elif tr_method == "qr": + # try full newton step + tall = J_a.shape[0] >= J_a.shape[1] + if tall: + Q, R = qr(J_a, mode="economic") + p_newton = solve_triangular_regularized(R, -Q.T @ f_a) + else: + Q, R = qr(J_a.T, mode="economic") + p_newton = Q @ solve_triangular_regularized(R.T, -f_a, lower=True) actual_reduction = -1 @@ -289,7 +299,7 @@ def lsqtr( # noqa: C901 - FIXME: simplify this ) elif tr_method == "qr": step_h, hits_boundary, alpha = trust_region_step_exact_qr( - f_a, J_a, trust_radius, alpha + p_newton, f_a, J_a, trust_radius, alpha ) step = d * step_h # Trust-region solution in the original space. diff --git a/desc/optimize/tr_subproblems.py b/desc/optimize/tr_subproblems.py index b107bca0a4..8c39e82295 100644 --- a/desc/optimize/tr_subproblems.py +++ b/desc/optimize/tr_subproblems.py @@ -378,7 +378,7 @@ def loop_body(state): @jit def trust_region_step_exact_qr( - f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10 + p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10 ): """Solve a trust-region problem using a semi-exact method. @@ -414,14 +414,6 @@ def trust_region_step_exact_qr( Sometimes called Levenberg-Marquardt parameter. """ - # try full newton step - tall = J.shape[0] >= J.shape[1] - if tall: - Q, R = qr(J, mode="economic") - p_newton = solve_triangular_regularized(R, -Q.T @ f) - else: - Q, R = qr(J.T, mode="economic") - p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True) def truefun(*_): return p_newton, False, 0.0 @@ -453,6 +445,7 @@ def loop_body(state): Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) # Ji is always tall since its padded by alpha*I Q, R = qr(Ji, mode="economic") + p = solve_triangular_regularized(R, -Q.T @ fp) p_norm = jnp.linalg.norm(p) phi = p_norm - trust_radius @@ -474,6 +467,7 @@ def loop_body(state): alpha, *_ = while_loop( loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k) ) + Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) Q, R = qr(Ji, mode="economic") p = solve_triangular(R, -Q.T @ fp)