How to use PyTreeLinearOperator #54

Open
ToshiyukiBandai opened this issue Oct 25, 2023 · 8 comments
Labels
question User queries

Comments

@ToshiyukiBandai

Hi lineax community,

I came here from Patrick's comment (jax-ml/jax#17203 (comment)). I believe PyTreeLinearOperator will do the job, but I am struggling to use it correctly. In the example below, I want to solve a Newton system $Jx = -F$, where the Jacobian matrix $J$ is a PyTree. How can I use PyTreeLinearOperator correctly?

import jax
from jax import jacfwd
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Float, Array, Bool
import lineax as lx

class Parameter(eqx.Module):
    alpha: Float[Array, ""]
    beta: Float[Array, ""]
    def __init__(self, alpha, beta):
        self.alpha = alpha
        self.beta = beta

class State(eqx.Module):
    x_0: Array
    x_1: Array
    def __init__(self, x_guess):
        self.x_0 = x_guess[0]
        self.x_1 = x_guess[1]

class Model(eqx.Module):
    parameters: eqx.Module
    def __init__(self, parameters):
        self.parameters = parameters
    def residual(self, state):
        F_0 = state.x_0**2 + state.x_1**2 - self.parameters.alpha
        F_1 = self.parameters.beta*state.x_0**3 - state.x_1
        return jnp.array([F_0, F_1])

alpha = 4.0
beta = 1.0
x_test = jnp.array([2.0, 3.0])

parameter = Parameter(alpha, beta)
state = State(x_test)
model = Model(parameter)

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
# lx.linear_solve(J, F) # failed
@packquickly
Collaborator

packquickly commented Oct 25, 2023

So the issue in this example is that F is a JAX array, but the output structure of J has the PyTree structure of State, as set by the output_structure argument in lx.PyTreeLinearOperator(pytree, output_structure). So $Jx$ has the PyTree structure of State but $F$ has the PyTree structure of a standard JAX array. Throwing an error is the correct thing to do here, since $Jx = F$ doesn't make sense when $Jx$ and $F$ have different PyTree structures.

There are a number of ways to fix this. The cleanest is probably to use lx.JacobianLinearOperator, which exists specifically to simplify cases like this. In this case, the final code block becomes:

F = model.residual(state)
J = lx.JacobianLinearOperator(lambda x, a: model.residual(x), state)
lx.linear_solve(J, F)

The lambda is there because JacobianLinearOperator expects a function of the form fn(x, args), where args is an optional extra argument. There's no need to define Jacobian_JAX_class here; lx.JacobianLinearOperator takes care of computing the Jacobian.

If you want to stick with PyTreeLinearOperator, then either replace the eval_shape argument with jax.eval_shape(lambda: F), so that the output structure matches F:

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: F))
soln = lx.linear_solve(J, F)
soln.value # A JAX array with the same shape as F

or give F the PyTree structure of State:

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = State(model.residual(state))
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
soln = lx.linear_solve(J, F)
soln.value # A PyTree with the same structure as State

However, I'm assuming this third option is not what you had in mind.

@packquickly packquickly added the question User queries label Oct 25, 2023
@ToshiyukiBandai
Author

ToshiyukiBandai commented Oct 25, 2023

Hi Jason,

Thank you so much! It worked well. As you might have guessed, I'm using it to solve some PDEs. In fact, I'm now thinking of dumping my own solver and switching to optimistix, since it can do the whole job under the hood (provided it supports Newton with backtracking line search, which it seems it does not).

@patrick-kidger
Owner

It's true, Newton with backtracking line search isn't something we've implemented in Optimistix yet. (patrick-kidger/optimistix#4)

It should be essentially straightforward to do, though: Newton and Gauss-Newton are basically the same algorithm, just applied to different problems. As such, copy-pasting optx.AbstractGaussNewton would get us 95% of the way there. (If you feel up to it, we'd be happy to take a PR on that.)

@ToshiyukiBandai
Author

ToshiyukiBandai commented Oct 26, 2023 via email

@packquickly
Collaborator

Just to be clear, are you handling a general minimisation problem, or a nonlinear least-squares problem? In your example problem, $F$ is a residual and $J$ its Jacobian, so the solution $x = -J^{-1}F$ is actually the Gauss-Newton step, assuming your loss is $F_0^2 + F_1^2$.

If your actual PDE is also of this form, i.e. you have some vector of residuals and you'd like to minimise the sum of their squares, then check out optx.AbstractGaussNewton and optx.BacktrackingArmijo with the mix-and-match API.

@ToshiyukiBandai
Author

I am solving a system of nonlinear equations (resulting from nonlinear PDEs) by Newton's method with backtracking line search. So, it's a root-finding problem.
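
For concreteness, a minimal plain-JAX sketch of this kind of solver on the toy system from the first post. Everything here is an illustrative assumption rather than my actual PDE code: the function names, the Armijo constant c, the shrink factor, and the use of $\frac{1}{2}\Vert F(x)\Vert_2^2$ as the merit function for the line search.

```python
import jax
import jax.numpy as jnp


def residual(x):
    # Toy system from the first post with alpha = 4 and beta = 1.
    f0 = x[0] ** 2 + x[1] ** 2 - 4.0
    f1 = x[0] ** 3 - x[1]
    return jnp.stack([f0, f1])


def merit(x):
    # Merit function 0.5 * ||F(x)||^2 used to accept or reject a step.
    r = residual(x)
    return 0.5 * jnp.dot(r, r)


def newton_backtracking(x, max_steps=50, tol=1e-5, c=1e-4, shrink=0.5):
    jac = jax.jacfwd(residual)
    for _ in range(max_steps):
        F = residual(x)
        if float(jnp.linalg.norm(F)) < tol:
            break
        # Newton direction: solve J dx = -F.
        dx = jnp.linalg.solve(jac(x), -F)
        # Directional derivative of the merit function along dx:
        # F^T J dx = -||F||^2 when dx is the exact Newton direction.
        slope = -jnp.dot(F, F)
        f_old = merit(x)
        t = 1.0
        for _ in range(30):
            # Armijo sufficient-decrease condition.
            if float(merit(x + t * dx)) <= float(f_old + c * t * slope):
                break
            t *= shrink
        x = x + t * dx
    return x


root = newton_backtracking(jnp.array([2.0, 3.0]))
```

The Python loops keep the sketch readable; a version meant to run under jit would need jax.lax.while_loop instead.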

@packquickly
Collaborator

Gauss-Newton and backtracking Armijo with the mix-and-match API should work then. Gauss-Newton is mathematically equivalent to Newton for square nonlinear systems with an invertible Jacobian. As Patrick said, they're just applied a little differently.

The aim is to solve $F(x) = 0$. Recasting this as the nonlinear least-squares problem $\min_x \frac{1}{2} \Vert F(x) \Vert_2^2$, the residual vector is $F(x)$. With $J$ the Jacobian of $F$ at $x_k$, assumed square and invertible, the Gauss-Newton update is

$$ \begin{aligned} x_{k + 1} &= x_k - (J^T J)^{-1} J^T F(x_k) \\ &= x_k - J^{-1} (J^T)^{-1} J^T F(x_k) \\ &= x_k - J^{-1} F(x_k), \end{aligned}$$

which is the Newton update. You don't have to do this conversion manually: calling optx.root_find(fn, YourFancyNewSolver, ...) will automatically convert to the least-squares problem $\min_x \frac{1}{2} \Vert F(x) \Vert_2^2$ and solve it as described above.
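
The identity above can be checked numerically on the toy system from the first post. This is only a sketch in plain JAX (no Optimistix involved), and the variable names are illustrative; it forms the normal equations explicitly, which is fine for a 2x2 check but not how one would implement Gauss-Newton in practice.

```python
import jax
import jax.numpy as jnp


def residual(x):
    # Toy system from the first post with alpha = 4 and beta = 1.
    f0 = x[0] ** 2 + x[1] ** 2 - 4.0
    f1 = x[0] ** 3 - x[1]
    return jnp.stack([f0, f1])


x = jnp.array([2.0, 3.0])
F = residual(x)
J = jax.jacfwd(residual)(x)  # square 2x2 Jacobian, invertible at this x

# Newton step: solve J dx = -F.
newton_step = jnp.linalg.solve(J, -F)

# Gauss-Newton step: solve the normal equations (J^T J) dx = -J^T F.
gauss_newton_step = jnp.linalg.solve(J.T @ J, -(J.T @ F))

# The two steps agree up to floating-point error.
steps_match = bool(jnp.allclose(newton_step, gauss_newton_step, atol=1e-3))
```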

@ToshiyukiBandai
Author

Okay, that sounds good too. I will give it a shot. I will keep you posted.
