Guidance for writing a DiagonalPyTreeLinearOperator #124

Open
johannahaffner opened this issue Jan 2, 2025 · 2 comments
Labels
question User queries

Comments

@johannahaffner

Hi Patrick,

I'm implementing the interior-reflective Newton step method of Coleman and Li, for bound-constrained optimisation, in optimistix.
The core of this thing is to scale and translate the Hessian, such that

$B_k = D_k \cdot H_k \cdot D_k + \operatorname{diag}(g_k) \cdot \operatorname{Jac}(|v|)$

defines the operator for the Newton step, where the diagonal elements of $D_k$ are defined by a simple function on y (distance to bounds or $\pm1$ depending on some conditions on the gradient, all computed element-wise). For this to work for a general PyTree, I believe (?) the best thing to do would be to define a DiagonalPyTreeLinearOperator.
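For readers unfamiliar with the method, here is a minimal sketch of the element-wise rule behind $v$, applied leaf by leaf to a PyTree. It assumes the usual Coleman-Li definition of $v$ (distance to the upper bound when the gradient is negative, distance to the lower bound otherwise, and $\mp 1$ when the relevant bound is infinite). The function name `coleman_li_v` and the arguments `lower` and `upper` are hypothetical and not part of optimistix. How $D_k$ is then built from $|v|$ is left out, since the post does not specify it.

```python
import jax
import jax.numpy as jnp


def coleman_li_v(y, grad, lower, upper):
    """Element-wise Coleman-Li vector v for PyTrees sharing one structure.

    Assumed rule (not taken from the post):
      grad < 0 and upper finite   -> y - upper
      grad >= 0 and lower finite  -> y - lower
      grad < 0 and upper infinite -> -1
      grad >= 0 and lower infinite -> +1
    """

    def _leaf(y_i, g_i, l_i, u_i):
        toward_upper = g_i < 0
        # Signed distance to whichever bound the negative gradient points at.
        dist = jnp.where(toward_upper, y_i - u_i, y_i - l_i)
        bound_is_finite = jnp.where(
            toward_upper, jnp.isfinite(u_i), jnp.isfinite(l_i)
        )
        unit = jnp.where(toward_upper, -1.0, 1.0)
        return jnp.where(bound_is_finite, dist, unit)

    return jax.tree_util.tree_map(_leaf, y, grad, lower, upper)


y = {"a": jnp.array([0.5, 0.5]), "b": jnp.array(2.0)}
grad = {"a": jnp.array([-1.0, 1.0]), "b": jnp.array(-3.0)}
lower = {"a": jnp.array([0.0, 0.0]), "b": jnp.array(-jnp.inf)}
upper = {"a": jnp.array([1.0, jnp.inf]), "b": jnp.array(jnp.inf)}

v = coleman_li_v(y, grad, lower, upper)
```

Because everything is a `tree_map` over leaves, the result has the same structure as `y`, which is what makes a diagonal operator stored as a PyTree a natural fit.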

I noticed that the Hessian approximation we're computing in optx.BFGS is actually "square", by which I mean that it does not have the redundant entries one finds, for example, in jax.hessian, and as_matrix returns a square matrix (this is not true for jax.hessian). So maybe I'm making things more complicated than they need to be?

@patrick-kidger
Owner

So I think I can imagine two different kinds of 'diagonal pytree': one is where the pytree is the diagonal, and e.g. the diagonal array inside of DiagonalLinearOperator corresponds to its ravelled form. The second is where each element of the diagonal is a pytree, i.e. it's really a block diagonal operator.

If I understand your case correctly, I think you're in the first one? If so then I think that should be fairly straightforward to implement: pretty much copy-paste the existing DiagonalLinearOperator implementation, where the internal storage is now a pytree, mv is given by tree_map(operator.mul, ...) (I think), and as_matrix involves allocating a zero output matrix whose diagonal is filled in via .at[].set(). Does that make sense?
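The two core pieces described above can be sketched as plain functions, deliberately not wrapped in a lineax operator class, so that no lineax API is assumed. The names `diagonal_mv` and `diagonal_as_matrix` are hypothetical. The sketch assumes that the diagonal and the vector share one PyTree structure, and that the dense matrix follows the ordering of `jax.flatten_util.ravel_pytree`.

```python
import operator

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def diagonal_mv(diagonal, vector):
    # Matrix-vector product: leaf-wise multiplication of two PyTrees
    # that share the same structure.
    return jax.tree_util.tree_map(operator.mul, diagonal, vector)


def diagonal_as_matrix(diagonal):
    # Ravel the PyTree into a 1D array, then write it onto the diagonal
    # of a freshly allocated zero matrix.
    flat, _ = ravel_pytree(diagonal)
    n = flat.size
    idx = jnp.arange(n)
    return jnp.zeros((n, n), dtype=flat.dtype).at[idx, idx].set(flat)


diagonal = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
vector = {"a": jnp.array([10.0, 10.0]), "b": jnp.array(10.0)}

out = diagonal_mv(diagonal, vector)
matrix = diagonal_as_matrix(diagonal)
```

A quick consistency check is that `matrix @ ravel_pytree(vector)[0]` matches `ravel_pytree(out)[0]`. A real operator would additionally need the remaining methods that the existing DiagonalLinearOperator implements.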

FWIW I'm not sure what you mean by 'redundant entries' in jax.hessian? Can you give an example?

@patrick-kidger patrick-kidger added the question User queries label Jan 3, 2025
@johannahaffner
Author

johannahaffner commented Jan 4, 2025

> FWIW I'm not sure what you mean by 'redundant entries' in jax.hessian? Can you give an example?

While making an MWE, I think I got to the bottom of this: I was taking the Hessian of a function that accidentally returned a list of two values, rather than a scalar. This caused my matrix sizes to go to (n, 2n), for n elements in a pytree.
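A small illustration of this pitfall, using made-up functions `f_scalar` and `f_list` rather than the original MWE: when the function returns a list of two scalars, `jax.hessian` returns one Hessian block per output, so the result is no longer a single square matrix.

```python
import jax
import jax.numpy as jnp


def f_scalar(y):
    # Scalar-valued: the Hessian is a single (n, n) array.
    return jnp.sum(y**2)


def f_list(y):
    # Accidentally returns two values: the "Hessian" becomes a list
    # holding one (n, n) block per output.
    return [jnp.sum(y**2), jnp.sum(y**3)]


y = jnp.array([1.0, 2.0, 3.0])

hess_scalar = jax.hessian(f_scalar)(y)
hess_list = jax.hessian(f_list)(y)
```

Here `hess_scalar` has shape `(3, 3)`, while `hess_list` is a list of two `(3, 3)` arrays, which is where the extra entries come from once the result is flattened into one matrix.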

> If I understand your case correctly, I think you're in the first one? ... Does that make sense?

Yes and it does!
