Skip to content

Commit

Permalink
Take p_newton out of inner while loop (#1165)
Browse files Browse the repository at this point in the history
Resolves #1078 

Some performance improvements for QR decomposition used in optimization
which was first introduced in #1050.

- Take the `p_newton` calculation out of inner while loop, since it is
basically calculating the same QR over and over again
- ~Use proper QR update procedure for the `falsefun` in
`trust_region_step_exact_qr`. That is we already now QR decomposition of
`J=QR`, if we stack a diagonal matrix `aI` to `J` then instead of taking
the whole QR decomposition again, there is a more clever way of updating
the QR.There are methods for updating a QR factorization when you add
rows. Suppose we have~

$$
QR = J
$$

what we want is 

$$
\tilde{Q} \tilde{R} = \begin{pmatrix} J \\ 
\alpha I \end{pmatrix}
$$

The QR update procedure can be implemented on a later PR with
Householder matrices, but for now, it seems a bit inefficient to
implement using JAX since QR is calculated by Fortran package LAPACK on
Scipy and Jax, our custom QR'ish thing will be slow.
  • Loading branch information
ddudt authored Aug 21, 2024
2 parents 13108f6 + 7f3858d commit 425fb02
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
14 changes: 12 additions & 2 deletions desc/optimize/aug_lagrangian_ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from scipy.optimize import NonlinearConstraint, OptimizeResult

from desc.backend import jnp
from desc.backend import jnp, qr
from desc.utils import errorif, setdefault

from .bound_utils import (
Expand All @@ -25,6 +25,7 @@
inequality_to_bounds,
print_header_nonlinear,
print_iteration_nonlinear,
solve_triangular_regularized,
)


Expand Down Expand Up @@ -368,6 +369,15 @@ def lagjac(z, y, mu, *args):
U, s, Vt = jnp.linalg.svd(J_a, full_matrices=False)
elif tr_method == "cho":
B_h = jnp.dot(J_a.T, J_a)
elif tr_method == "qr":
# try full newton step
tall = J_a.shape[0] >= J_a.shape[1]
if tall:
Q, R = qr(J_a, mode="economic")
p_newton = solve_triangular_regularized(R, -Q.T @ L_a)
else:
Q, R = qr(J_a.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, -L_a, lower=True)

actual_reduction = -1
Lactual_reduction = -1
Expand All @@ -390,7 +400,7 @@ def lagjac(z, y, mu, *args):
)
elif tr_method == "qr":
step_h, hits_boundary, alpha = trust_region_step_exact_qr(
L_a, J_a, trust_radius, alpha
p_newton, L_a, J_a, trust_radius, alpha
)

step = d * step_h # Trust-region solution in the original space.
Expand Down
14 changes: 12 additions & 2 deletions desc/optimize/least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from scipy.optimize import OptimizeResult

from desc.backend import jnp
from desc.backend import jnp, qr
from desc.utils import errorif, setdefault

from .bound_utils import (
Expand All @@ -24,6 +24,7 @@
compute_jac_scale,
print_header_nonlinear,
print_iteration_nonlinear,
solve_triangular_regularized,
)


Expand Down Expand Up @@ -268,6 +269,15 @@ def lsqtr( # noqa: C901 - FIXME: simplify this
U, s, Vt = jnp.linalg.svd(J_a, full_matrices=False)
elif tr_method == "cho":
B_h = jnp.dot(J_a.T, J_a)
elif tr_method == "qr":
# try full newton step
tall = J_a.shape[0] >= J_a.shape[1]
if tall:
Q, R = qr(J_a, mode="economic")
p_newton = solve_triangular_regularized(R, -Q.T @ f_a)
else:
Q, R = qr(J_a.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, -f_a, lower=True)

actual_reduction = -1

Expand All @@ -289,7 +299,7 @@ def lsqtr( # noqa: C901 - FIXME: simplify this
)
elif tr_method == "qr":
step_h, hits_boundary, alpha = trust_region_step_exact_qr(
f_a, J_a, trust_radius, alpha
p_newton, f_a, J_a, trust_radius, alpha
)
step = d * step_h # Trust-region solution in the original space.

Expand Down
12 changes: 3 additions & 9 deletions desc/optimize/tr_subproblems.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def loop_body(state):

@jit
def trust_region_step_exact_qr(
f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10
p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10
):
"""Solve a trust-region problem using a semi-exact method.
Expand Down Expand Up @@ -414,14 +414,6 @@ def trust_region_step_exact_qr(
Sometimes called Levenberg-Marquardt parameter.
"""
# try full newton step
tall = J.shape[0] >= J.shape[1]
if tall:
Q, R = qr(J, mode="economic")
p_newton = solve_triangular_regularized(R, -Q.T @ f)
else:
Q, R = qr(J.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True)

def truefun(*_):
return p_newton, False, 0.0
Expand Down Expand Up @@ -453,6 +445,7 @@ def loop_body(state):
Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
# Ji is always tall since its padded by alpha*I
Q, R = qr(Ji, mode="economic")

p = solve_triangular_regularized(R, -Q.T @ fp)
p_norm = jnp.linalg.norm(p)
phi = p_norm - trust_radius
Expand All @@ -474,6 +467,7 @@ def loop_body(state):
alpha, *_ = while_loop(
loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k)
)

Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
Q, R = qr(Ji, mode="economic")
p = solve_triangular(R, -Q.T @ fp)
Expand Down

0 comments on commit 425fb02

Please sign in to comment.