diff --git a/desc/backend.py b/desc/backend.py index 22c5c60c6f..c26213b045 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -76,7 +76,6 @@ repeat = jnp.repeat take = jnp.take scan = jax.lax.scan - rsqrt = jax.lax.rsqrt from jax import custom_jvp from jax.experimental.ode import odeint from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular @@ -812,7 +811,3 @@ def take( else: out = np.take(a, indices, axis, out, mode) return out - - def rsqrt(x): - """Reciprocal square root.""" - return 1 / np.sqrt(x) diff --git a/desc/optimize/aug_lagrangian_ls.py b/desc/optimize/aug_lagrangian_ls.py index fb222d776c..a5d75416b2 100644 --- a/desc/optimize/aug_lagrangian_ls.py +++ b/desc/optimize/aug_lagrangian_ls.py @@ -371,18 +371,13 @@ def lagjac(z, y, mu, *args): B_h = jnp.dot(J_a.T, J_a) elif tr_method == "qr": # try full newton step - tall = J_a.shape[0] >= J_a.shape[1] + tall = J.shape[0] >= J.shape[1] if tall: - Q, R = qr(J_a) - p_newton = solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ L_a - ) + Q, R = qr(J, mode="economic") + p_newton = solve_triangular_regularized(R, -Q.T @ f) else: - Q, R = qr(J_a.T) - p_newton = Q[:, : R.shape[1]] @ solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]].T, L_a, lower=True - ) - Q, R = qr(J_a) + Q, R = qr(J.T, mode="economic") + p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True) actual_reduction = -1 Lactual_reduction = -1 @@ -405,7 +400,7 @@ def lagjac(z, y, mu, *args): ) elif tr_method == "qr": step_h, hits_boundary, alpha = trust_region_step_exact_qr( - Q, R, p_newton, L_a, J_a, trust_radius, alpha + p_newton, L_a, J_a, trust_radius, alpha ) step = d * step_h # Trust-region solution in the original space. diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index a459e7f760..d373642995 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -271,18 +271,13 @@ def lsqtr( # noqa: C901 - FIXME: simplify this B_h = jnp.dot(J_a.T, J_a) elif tr_method == "qr": # try full newton step - tall = J_a.shape[0] >= J_a.shape[1] + tall = J.shape[0] >= J.shape[1] if tall: - Q, R = qr(J_a) - p_newton = solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ f_a - ) + Q, R = qr(J, mode="economic") + p_newton = solve_triangular_regularized(R, -Q.T @ f) else: - Q, R = qr(J_a.T) - p_newton = Q[:, : R.shape[1]] @ solve_triangular_regularized( - R[: R.shape[1], : R.shape[1]].T, f_a, lower=True - ) - Q, R = qr(J_a) + Q, R = qr(J.T, mode="economic") + p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True) actual_reduction = -1 @@ -304,7 +299,7 @@ def lsqtr( # noqa: C901 - FIXME: simplify this ) elif tr_method == "qr": step_h, hits_boundary, alpha = trust_region_step_exact_qr( - Q, R, p_newton, f_a, J_a, trust_radius, alpha + p_newton, f_a, J_a, trust_radius, alpha ) step = d * step_h # Trust-region solution in the original space. diff --git a/desc/optimize/tr_subproblems.py b/desc/optimize/tr_subproblems.py index 20d874644a..8c39e82295 100644 --- a/desc/optimize/tr_subproblems.py +++ b/desc/optimize/tr_subproblems.py @@ -14,7 +14,7 @@ ) from desc.utils import setdefault -from .utils import chol, solve_triangular_regularized, update_qr_jax_eco +from .utils import chol, solve_triangular_regularized @jit @@ -378,7 +378,7 @@ def loop_body(state): @jit def trust_region_step_exact_qr( - Q, R, p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10 + p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10 ): """Solve a trust-region problem using a semi-exact method. @@ -441,15 +441,18 @@ def loop_body(state): jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), alpha, ) - Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) - p = solve_triangular_regularized(R2, -Q2.T @ fp) + Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) + # Ji is always tall since its padded by alpha*I + Q, R = qr(Ji, mode="economic") + + p = solve_triangular_regularized(R, -Q.T @ fp) p_norm = jnp.linalg.norm(p) phi = p_norm - trust_radius alpha_upper = jnp.where(phi < 0, alpha, alpha_upper) alpha_lower = jnp.where(phi > 0, alpha, alpha_lower) - q = solve_triangular_regularized(R2.T, p, lower=True) + q = solve_triangular_regularized(R.T, p, lower=True) q_norm = jnp.linalg.norm(q) alpha += (p_norm / q_norm) ** 2 * phi / trust_radius @@ -465,9 +468,9 @@ def loop_body(state): loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k) ) - Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R) - - p = solve_triangular(R2, -Q2.T @ fp) + Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) + Q, R = qr(Ji, mode="economic") + p = solve_triangular(R, -Q.T @ fp) # Make the norm of p equal to trust_radius; p is changed only slightly. # This is done to prevent p from lying outside the trust region diff --git a/desc/optimize/utils.py b/desc/optimize/utils.py index 8c31cdc1a2..bbee5a22f7 100644 --- a/desc/optimize/utils.py +++ b/desc/optimize/utils.py @@ -5,7 +5,7 @@ import numpy as np -from desc.backend import cond, fori_loop, jit, jnp, put, rsqrt, solve_triangular +from desc.backend import cond, jit, jnp, put, solve_triangular from desc.utils import Index @@ -551,85 +551,3 @@ def solve_triangular_regularized(R, b, lower=False): Rs = R * dri[:, None] b = dri * b return solve_triangular(Rs, b, unit_diagonal=True, lower=lower) - - -# TODO: add references to the docstrings -def _givens_jax(a, b): - """Compute Givens rotation matrix. - - Compute the Givens rotation matrix G2 that zeros out the second element - of a 2-vector. - G2*[a; b] = [r; 0] - where r = sqrt(a^2 + b^2) - G2 = [[c, -s], [s, c]] - """ - # Taken from jax._src.scipy.sparse.linalg._givens_rotation - b_zero = abs(b) == 0 - a_lt_b = abs(a) < abs(b) - t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a) - r = rsqrt(1 + abs(t) ** 2).astype(t.dtype) - cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r)) - sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t)) - G2 = jnp.array([[cs, -sn], [sn, cs]]) - return G2.astype(float) - - -@jit -def update_qr_jax(A, w, q, r): - """Update QR factorization with a diagonal matrix w at the bottom.""" - m, n = A.shape - Q = jnp.eye(m + n) - Q = Q.at[:m, :m].set(q) - - R = jnp.vstack([r, w]) - - def body_inner(i, jQR): - j, Q, R = jQR - i = m + j - i - a, b = R[i - 1, j], R[i, j] - G2 = _givens_jax(a, b) - R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])]) - Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T) - return j, Q, R - - def body(j, QR): - Q, R = QR - j, Q, R = fori_loop(0, m, body_inner, (j, Q, R)) - return Q, R - - Q, R = fori_loop(0, n, body, (Q, R)) - R = jnp.where(jnp.abs(R) < 1e-10, 0, R) - - return Q, R - - -@jit -def update_qr_jax_eco(A, w, q, r): - """Update QR factorization with a diagonal matrix w at the bottom.""" - m, n = A.shape - Q = jnp.eye(m + n) - Q = Q.at[:m, :m].set(q) - - R = jnp.vstack([r, w]) - - def body_inner(i, jQR): - j, Q, R = jQR - i = m + j - i - a, b = R[i - 1, j], R[i, j] - G2 = _givens_jax(a, b) - R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])]) - Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T) - return j, Q, R - - def body(j, QR): - Q, R = QR - j, Q, R = fori_loop(0, m, body_inner, (j, Q, R)) - return Q, R - - Q, R = fori_loop(0, n, body, (Q, R)) - R = jnp.where(jnp.abs(R) < 1e-10, 0, R) - - Re = R.at[: R.shape[1], : R.shape[1]].get() - Qe = Q.at[:, : R.shape[1]].get() - - return Qe, Re