Skip to content

Commit

Permalink
probably a stupid idea but anyway
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed Nov 23, 2024
1 parent 0c63bf7 commit fa6a733
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 37 deletions.
10 changes: 3 additions & 7 deletions desc/optimize/least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from scipy.optimize import OptimizeResult

from desc.backend import jax, jnp, qr
from desc.backend import jnp, qr
from desc.utils import errorif, setdefault

from .bound_utils import (
Expand Down Expand Up @@ -272,7 +272,7 @@ def lsqtr( # noqa: C901
U, s, Vt = jnp.linalg.svd(J_a, full_matrices=False)
elif tr_method == "cho":
B_h = jnp.dot(J_a.T, J_a)
elif tr_method == "qr":
elif tr_method == "qr" or tr_method == "direct":

Check warning on line 275 in desc/optimize/least_squares.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/least_squares.py#L275

Added line #L275 was not covered by tests
# try full newton step
tall = J_a.shape[0] >= J_a.shape[1]
if tall:
Expand All @@ -281,10 +281,6 @@ def lsqtr( # noqa: C901
else:
Q, R = qr(J_a.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, -f_a, lower=True)
elif tr_method == "direct":
JTJ = J_a.T @ J_a
fp = -J_a.T @ f_a
p_newton = jax.scipy.linalg.solve(JTJ, fp, assume_a="sym")

actual_reduction = -1

Expand All @@ -310,7 +306,7 @@ def lsqtr( # noqa: C901
)
elif tr_method == "direct":
step_h, hits_boundary, alpha = trust_region_step_exact_direct(

Check warning on line 308 in desc/optimize/least_squares.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/least_squares.py#L307-L308

Added lines #L307 - L308 were not covered by tests
p_newton, fp, JTJ, trust_radius, alpha
p_newton, f_a, Q, R, trust_radius, alpha
)
step = d * step_h # Trust-region solution in the original space.

Expand Down
55 changes: 25 additions & 30 deletions desc/optimize/tr_subproblems.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
cho_factor,
cho_solve,
cond,
jax,
jit,
jnp,
qr,
Expand Down Expand Up @@ -485,19 +484,21 @@ def loop_body(state):

@jit
def trust_region_step_exact_direct(
p_newton, fp, JTJ, trust_radius, initial_alpha=None, rtol=0.005, max_iter=10
p_newton, fa, Q, R, trust_radius, initial_alpha=None, rtol=0.005, max_iter=10
):
"""Solve a trust-region problem using a semi-exact method.
Solves problems of the form
min_p ||J*p + f||^2, ||p|| < trust_radius
min_p ||QR*p + f||^2, ||p|| < trust_radius
Parameters
----------
fp : ndarray
p_newton : ndarray
The step found by the Newton method.
fa : ndarray
Vector of residuals. fp=-J.T@f
JTJ : ndarray
Jacobian matrix. JTJ=J.T@J
Q, R : ndarray
QR decomposition of J.
trust_radius : float
Radius of a trust region.
initial_alpha : float, optional
Expand Down Expand Up @@ -526,32 +527,28 @@ def truefun(*_):
return p_newton, False, 0.0

Check warning on line 527 in desc/optimize/tr_subproblems.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/tr_subproblems.py#L526-L527

Added lines #L526 - L527 were not covered by tests

def falsefun(*_):
alpha_upper = jnp.linalg.norm(fp) / trust_radius
QTf = Q.T @ fa
alpha_upper = jnp.linalg.norm(QTf) / trust_radius
alpha_lower = 0.0
alpha = setdefault(
initial_alpha,
0.001 * alpha_upper,
)
alpha_prev = 0.9 * alpha
p = jax.scipy.linalg.solve(
JTJ + alpha_prev * jnp.eye(JTJ.shape[0]), fp, assume_a="sym"
)
p_norm = jnp.linalg.norm(p)
phi_prev = p_norm - trust_radius
alpha = setdefault(initial_alpha, 0.001 * alpha_upper)
alpha_prev = 0.8 * alpha
p = solve_triangular(R + alpha_prev * jnp.eye(R.shape[0]), QTf)
phi_prev = jnp.linalg.norm(p) - trust_radius
k = 0

Check warning on line 537 in desc/optimize/tr_subproblems.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/tr_subproblems.py#L529-L537

Added lines #L529 - L537 were not covered by tests

def loop_cond(state):
alpha, alpha_prev, alpha_lower, alpha_upper, phi, phi_prev, k = state
return (jnp.abs(phi) > rtol * trust_radius) & (k < max_iter)
alpha, alpha_prev, alpha_lower, alpha_upper, phi_prev, k = state
return (jnp.abs(phi_prev) > rtol * trust_radius) & (k < max_iter)

Check warning on line 541 in desc/optimize/tr_subproblems.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/tr_subproblems.py#L539-L541

Added lines #L539 - L541 were not covered by tests

def loop_body(state):
alpha, alpha_prev, alpha_lower, alpha_upper, phi, phi_prev, k = state

# In future, maybe try to find an update to inverse instead of
# resolving from scratch
p = jax.scipy.linalg.solve(
JTJ + alpha * jnp.eye(JTJ.shape[0]), fp, assume_a="sym"
alpha, alpha_prev, alpha_lower, alpha_upper, phi_prev, k = state
alpha = jnp.where(

Check warning on line 545 in desc/optimize/tr_subproblems.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/tr_subproblems.py#L543-L545

Added lines #L543 - L545 were not covered by tests
(alpha < alpha_lower) | (alpha > alpha_upper),
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
alpha,
)

p = solve_triangular(R + alpha * jnp.eye(R.shape[0]), QTf)
p_norm = jnp.linalg.norm(p)
phi = p_norm - trust_radius
alpha_upper = jnp.where(phi < 0, alpha, alpha_upper)
Expand All @@ -567,17 +564,15 @@ def loop_body(state):
)

k += 1
return alpha_new, alpha_prev, alpha_lower, alpha_upper, phi, phi_prev, k
return alpha_new, alpha_prev, alpha_lower, alpha_upper, phi, k

Check warning on line 567 in desc/optimize/tr_subproblems.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/tr_subproblems.py#L566-L567

Added lines #L566 - L567 were not covered by tests

alpha, *_ = while_loop(

Check warning on line 569 in desc/optimize/tr_subproblems.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/tr_subproblems.py#L569

Added line #L569 was not covered by tests
loop_cond,
loop_body,
(alpha, alpha_prev, alpha_lower, alpha_upper, jnp.inf, phi_prev, k),
(alpha, alpha_prev, alpha_lower, alpha_upper, phi_prev, k),
)

p = jax.scipy.linalg.solve(
JTJ + alpha * jnp.eye(JTJ.shape[0]), fp, assume_a="sym"
)
p = solve_triangular(R + alpha * jnp.eye(R.shape[0]), QTf)

Check warning on line 575 in desc/optimize/tr_subproblems.py

View check run for this annotation

Codecov / codecov/patch

desc/optimize/tr_subproblems.py#L575

Added line #L575 was not covered by tests

# Make the norm of p equal to trust_radius; p is changed only slightly.
# This is done to prevent p from lying outside the trust region
Expand Down

0 comments on commit fa6a733

Please sign in to comment.