How to use PyTreeLinearOperator #54
So the issue in this example is … There are a number of ways to fix this. The cleanest is probably to use

```python
import jax
import lineax as lx

F = model.residual(state)
# JacobianLinearOperator calls its function as fn(x, args).
J = lx.JacobianLinearOperator(lambda x, a: model.residual(x), state)
lx.linear_solve(J, F)
```

where the lambda is there because in general we anticipate that the function in `JacobianLinearOperator` takes two arguments, `(x, args)`. If you wanted to stick with `PyTreeLinearOperator`, you could do

```python
from jax import jacfwd

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
# The second argument is the structure of the operator's *output*, i.e. of F.
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: F))
soln = lx.linear_solve(J, F)
soln.value  # A PyTree with the same structure as state (the operator's input structure)
```

or make `F` a `State`:

```python
Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = State(model.residual(state))
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
soln = lx.linear_solve(J, F)
soln.value  # A JAX array
```

However, I'm assuming this third option is not what you had in mind.
Hi Jason, thank you so much! It worked pretty well. As you might have guessed, it is used to solve some PDEs. Now I am thinking of dropping my own solver and switching to Optimistix, since it can do the whole job under the hood (if it supports Newton with backtracking line search, but it seems not to).
It's true, Newton with backtracking line search isn't something we've implemented in Optimistix yet (patrick-kidger/optimistix#4). It should be essentially straightforward to do, though: Newton and Gauss-Newton are basically the same algorithm, just applied to different problems. As such, copy-pasting `optx.AbstractGaussNewton` would get us 95% of the way there. (If you feel up to it, we'd be happy to take a PR on that.)
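The "same algorithm, different problems" point can be checked numerically. In this sketch the two-equation system `F` is a hypothetical example. Newton's root-finding step solves $J\,\delta = -F$, while the Gauss-Newton step for $\min \tfrac12\|F\|^2$ solves the normal equations $J^\top J\,\delta = -J^\top F$. For a square, nonsingular $J$ the two steps coincide.

```python
import jax
import jax.numpy as jnp


def F(x):
    # Hypothetical square system: two equations in two unknowns.
    return jnp.array([x[0] ** 2 + x[1] - 3.0, x[0] - jnp.sin(x[1])])


x = jnp.array([1.0, 0.5])
J = jax.jacfwd(F)(x)  # 2x2 Jacobian at the current iterate
Fx = F(x)

# Newton (root finding): solve J dx = -F(x).
newton_step = jnp.linalg.solve(J, -Fx)

# Gauss-Newton (least squares on 0.5*||F||^2): solve J^T J dx = -J^T F(x).
gauss_newton_step = jnp.linalg.solve(J.T @ J, -J.T @ Fx)

print(newton_step, gauss_newton_step)
```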
Sounds good! I am learning how you implemented Newton. I will update you once I make something useful.
On Wed, Oct 25, 2023 at 6:05 PM Patrick Kidger wrote:
> It's true, Newton with backtracking line search isn't something we've implemented in Optimistix yet. (patrick-kidger/optimistix#4) It should be essentially straightforward to do, though: Newton and Gauss-Newton are basically the same algorithm, just applied to different problems. As such copy-pasting optx.AbstractGaussNewton would get us 95% of the way there. (If you feel up to we'd be happy to take a PR on that.)
Best,
Toshi
Just to be clear, are you handling a general minimisation problem, or a nonlinear least-squares problem? In your example problem, … If your actual PDE is also of this form, i.e. you have some vector of residuals and you'd like to minimise the sum of their squares, then check out …
I am solving the system of nonlinear equations (resulting from nonlinear PDEs) by Newton's method with backtracking line search, so it's a root-finding problem.
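Here is a minimal dense sketch of that algorithm in plain JAX, for reference. It applies Newton's method to $F(x) = 0$ with an Armijo backtracking condition on the merit function $\phi(x) = \tfrac12\|F(x)\|^2$. The function name `newton_backtracking`, its parameters, and the `arctan` test problem are hypothetical choices for illustration. This is not Optimistix's implementation.

```python
import jax
import jax.numpy as jnp


def newton_backtracking(F, x0, max_iter=50, tol=1e-6, c=1e-4, beta=0.5, max_halvings=30):
    """Damped Newton for F(x) = 0 with Armijo backtracking on phi(x) = 0.5*||F(x)||^2."""
    jac = jax.jacfwd(F)
    x = x0
    for _ in range(max_iter):
        Fx = F(x)
        phi = 0.5 * jnp.dot(Fx, Fx)
        if jnp.linalg.norm(Fx) < tol:  # converged
            break
        dx = jnp.linalg.solve(jac(x), -Fx)  # Newton direction: J dx = -F
        # Along the Newton direction grad(phi) . dx = -||F||^2 = -2*phi, so the
        # Armijo condition phi(x + t dx) <= phi + c*t*grad(phi).dx becomes:
        t = 1.0
        for _ in range(max_halvings):
            F_trial = F(x + t * dx)
            if 0.5 * jnp.dot(F_trial, F_trial) <= (1.0 - 2.0 * c * t) * phi:
                break
            t *= beta  # backtrack: shrink the step
        x = x + t * dx
    return x


x = newton_backtracking(jnp.arctan, jnp.array([3.0, 0.5]))
print(x)
```

With full Newton steps, `arctan` iterates started at 3 overshoot and diverge. The backtracking loop instead shrinks the step until the merit function decreases enough, after which the iteration converges to the root at 0. The Python-level control flow keeps the sketch readable but means it is not `jax.jit`-compatible as written.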
Gauss-Newton and backtracking Armijo with the mix-and-match API should work then. Gauss-Newton is mathematically equivalent to Newton for nonlinear systems; as Patrick said, they're just applied a little differently. The aim is to solve … which is the Newton update. You don't have to do this conversion manually; calling …
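For reference, here is a short derivation of that equivalence. It assumes the Jacobian $J = \partial F/\partial x$ is square and nonsingular at the current iterate, with $\delta$ denoting the step:

$$
\begin{aligned}
&\min_{x}\ \tfrac12\|F(x)\|_2^2, \qquad F(x+\delta) \approx F(x) + J(x)\,\delta, \\
&\delta_{\mathrm{GN}} = \arg\min_{\delta}\ \tfrac12\|F(x) + J(x)\,\delta\|_2^2
\;\Longrightarrow\; J^\top J\,\delta_{\mathrm{GN}} = -J^\top F, \\
&J \text{ square and nonsingular} \;\Longrightarrow\; J\,\delta_{\mathrm{GN}} = -F.
\end{aligned}
$$

The last line is the Newton system from the opening post, $Jx = -F$, where the step is written as $x$. Without those assumptions (e.g. more residuals than unknowns), the Newton system is not square and only the least-squares form makes sense.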
Okay, that sounds good too. I will give it a shot and keep you posted.
Hi lineax community,
I came from Patrick's comment (jax-ml/jax#17203 (comment)). I believe `PyTreeLinearOperator` will do the job, but I am struggling to use it correctly. In the example below, I want to solve a Newton system $Jx = -F$, where the Jacobian matrix $J$ of $F$ is a PyTree. How can I use `PyTreeLinearOperator` correctly?