Skip to content

Commit

Permalink
remove qr_update stuff, only take p_newton out
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed Aug 15, 2024
1 parent af2b502 commit 5e30fdc
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 118 deletions.
5 changes: 0 additions & 5 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
repeat = jnp.repeat
take = jnp.take
scan = jax.lax.scan
rsqrt = jax.lax.rsqrt
from jax import custom_jvp
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
Expand Down Expand Up @@ -812,7 +811,3 @@ def take(
else:
out = np.take(a, indices, axis, out, mode)
return out

def rsqrt(x):
"""Reciprocal square root."""
return 1 / np.sqrt(x)
17 changes: 6 additions & 11 deletions desc/optimize/aug_lagrangian_ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,18 +371,13 @@ def lagjac(z, y, mu, *args):
B_h = jnp.dot(J_a.T, J_a)
elif tr_method == "qr":
# try full newton step
tall = J_a.shape[0] >= J_a.shape[1]
tall = J.shape[0] >= J.shape[1]
if tall:
Q, R = qr(J_a)
p_newton = solve_triangular_regularized(
R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ L_a
)
Q, R = qr(J, mode="economic")
p_newton = solve_triangular_regularized(R, -Q.T @ f)
else:
Q, R = qr(J_a.T)
p_newton = Q[:, : R.shape[1]] @ solve_triangular_regularized(
R[: R.shape[1], : R.shape[1]].T, L_a, lower=True
)
Q, R = qr(J_a)
Q, R = qr(J.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True)

actual_reduction = -1
Lactual_reduction = -1
Expand All @@ -405,7 +400,7 @@ def lagjac(z, y, mu, *args):
)
elif tr_method == "qr":
step_h, hits_boundary, alpha = trust_region_step_exact_qr(
Q, R, p_newton, L_a, J_a, trust_radius, alpha
p_newton, L_a, J_a, trust_radius, alpha
)

step = d * step_h # Trust-region solution in the original space.
Expand Down
17 changes: 6 additions & 11 deletions desc/optimize/least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,13 @@ def lsqtr( # noqa: C901 - FIXME: simplify this
B_h = jnp.dot(J_a.T, J_a)
elif tr_method == "qr":
# try full newton step
tall = J_a.shape[0] >= J_a.shape[1]
tall = J.shape[0] >= J.shape[1]
if tall:
Q, R = qr(J_a)
p_newton = solve_triangular_regularized(
R[: R.shape[1], : R.shape[1]], -Q[:, : R.shape[1]].T @ f_a
)
Q, R = qr(J, mode="economic")
p_newton = solve_triangular_regularized(R, -Q.T @ f)
else:
Q, R = qr(J_a.T)
p_newton = Q[:, : R.shape[1]] @ solve_triangular_regularized(
R[: R.shape[1], : R.shape[1]].T, f_a, lower=True
)
Q, R = qr(J_a)
Q, R = qr(J.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True)

actual_reduction = -1

Expand All @@ -304,7 +299,7 @@ def lsqtr( # noqa: C901 - FIXME: simplify this
)
elif tr_method == "qr":
step_h, hits_boundary, alpha = trust_region_step_exact_qr(
Q, R, p_newton, f_a, J_a, trust_radius, alpha
p_newton, f_a, J_a, trust_radius, alpha
)
step = d * step_h # Trust-region solution in the original space.

Expand Down
19 changes: 11 additions & 8 deletions desc/optimize/tr_subproblems.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from desc.utils import setdefault

from .utils import chol, solve_triangular_regularized, update_qr_jax_eco
from .utils import chol, solve_triangular_regularized


@jit
Expand Down Expand Up @@ -378,7 +378,7 @@ def loop_body(state):

@jit
def trust_region_step_exact_qr(
Q, R, p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10
p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10
):
"""Solve a trust-region problem using a semi-exact method.
Expand Down Expand Up @@ -441,15 +441,18 @@ def loop_body(state):
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
alpha,
)
Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R)

p = solve_triangular_regularized(R2, -Q2.T @ fp)
Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
# Ji is always tall since its padded by alpha*I
Q, R = qr(Ji, mode="economic")

p = solve_triangular_regularized(R, -Q.T @ fp)
p_norm = jnp.linalg.norm(p)
phi = p_norm - trust_radius
alpha_upper = jnp.where(phi < 0, alpha, alpha_upper)
alpha_lower = jnp.where(phi > 0, alpha, alpha_lower)

q = solve_triangular_regularized(R2.T, p, lower=True)
q = solve_triangular_regularized(R.T, p, lower=True)
q_norm = jnp.linalg.norm(q)

alpha += (p_norm / q_norm) ** 2 * phi / trust_radius
Expand All @@ -465,9 +468,9 @@ def loop_body(state):
loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k)
)

Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R)

p = solve_triangular(R2, -Q2.T @ fp)
Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
Q, R = qr(Ji, mode="economic")
p = solve_triangular(R, -Q.T @ fp)

# Make the norm of p equal to trust_radius; p is changed only slightly.
# This is done to prevent p from lying outside the trust region
Expand Down
84 changes: 1 addition & 83 deletions desc/optimize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from desc.backend import cond, fori_loop, jit, jnp, put, rsqrt, solve_triangular
from desc.backend import cond, jit, jnp, put, solve_triangular
from desc.utils import Index


Expand Down Expand Up @@ -551,85 +551,3 @@ def solve_triangular_regularized(R, b, lower=False):
Rs = R * dri[:, None]
b = dri * b
return solve_triangular(Rs, b, unit_diagonal=True, lower=lower)


# TODO: add references to the docstrings
def _givens_jax(a, b):
"""Compute Givens rotation matrix.
Compute the Givens rotation matrix G2 that zeros out the second element
of a 2-vector.
G2*[a; b] = [r; 0]
where r = sqrt(a^2 + b^2)
G2 = [[c, -s], [s, c]]
"""
# Taken from jax._src.scipy.sparse.linalg._givens_rotation
b_zero = abs(b) == 0
a_lt_b = abs(a) < abs(b)
t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
r = rsqrt(1 + abs(t) ** 2).astype(t.dtype)
cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
G2 = jnp.array([[cs, -sn], [sn, cs]])
return G2.astype(float)


@jit
def update_qr_jax(A, w, q, r):
"""Update QR factorization with a diagonal matrix w at the bottom."""
m, n = A.shape
Q = jnp.eye(m + n)
Q = Q.at[:m, :m].set(q)

R = jnp.vstack([r, w])

def body_inner(i, jQR):
j, Q, R = jQR
i = m + j - i
a, b = R[i - 1, j], R[i, j]
G2 = _givens_jax(a, b)
R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])])
Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T)
return j, Q, R

def body(j, QR):
Q, R = QR
j, Q, R = fori_loop(0, m, body_inner, (j, Q, R))
return Q, R

Q, R = fori_loop(0, n, body, (Q, R))
R = jnp.where(jnp.abs(R) < 1e-10, 0, R)

return Q, R


@jit
def update_qr_jax_eco(A, w, q, r):
"""Update QR factorization with a diagonal matrix w at the bottom."""
m, n = A.shape
Q = jnp.eye(m + n)
Q = Q.at[:m, :m].set(q)

R = jnp.vstack([r, w])

def body_inner(i, jQR):
j, Q, R = jQR
i = m + j - i
a, b = R[i - 1, j], R[i, j]
G2 = _givens_jax(a, b)
R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])])
Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T)
return j, Q, R

def body(j, QR):
Q, R = QR
j, Q, R = fori_loop(0, m, body_inner, (j, Q, R))
return Q, R

Q, R = fori_loop(0, n, body, (Q, R))
R = jnp.where(jnp.abs(R) < 1e-10, 0, R)

Re = R.at[: R.shape[1], : R.shape[1]].get()
Qe = Q.at[:, : R.shape[1]].get()

return Qe, Re

0 comments on commit 5e30fdc

Please sign in to comment.