diff --git a/desc/optimize/least_squares.py b/desc/optimize/least_squares.py index 56e7d6e0b..9035b8157 100644 --- a/desc/optimize/least_squares.py +++ b/desc/optimize/least_squares.py @@ -3,7 +3,7 @@ from scipy.optimize import OptimizeResult from desc.backend import jnp, qr -from desc.utils import errorif, setdefault +from desc.utils import errorif, safediv, setdefault from .bound_utils import ( cl_scaling_vector, @@ -208,14 +208,12 @@ def lsqtr( # noqa: C901 # conngould : norm of the cauchy point, as recommended in ch17 of Conn & Gould # scipy : norm of the scaled x, as used in scipy # mix : geometric mean of conngould and scipy + scipy = jnp.linalg.norm(x * scale_inv / v**0.5) + conngould = safediv(jnp.sum(g_h**2), jnp.sum((J_h @ g_h) ** 2)) init_tr = { - "scipy": jnp.linalg.norm(x * scale_inv / v**0.5), - "conngould": jnp.sum(g_h**2) / jnp.sum((J_h @ g_h) ** 2), - "mix": jnp.sqrt( - jnp.sum(g_h**2) - / jnp.sum((J_h @ g_h) ** 2) - * jnp.linalg.norm(x * scale_inv / v**0.5) - ), + "scipy": scipy, + "conngould": conngould, + "mix": jnp.sqrt(conngould * scipy), } trust_radius = options.pop("initial_trust_radius", "scipy") tr_ratio = options.pop("initial_trust_ratio", 1.0)