I'm implementing the interior-reflective Newton step method of Coleman and Li, for bound-constrained optimisation, in optimistix.
The core of this is to scale and translate the Hessian, so that the resulting scaled matrix defines the operator for the Newton step, where the diagonal elements of $D_k$ are given by a simple function of $y$ (the distance to the bounds, or $\pm 1$, depending on some conditions on the gradient, all computed element-wise). For this to work for a general PyTree, I believe (?) the best thing to do would be to define a DiagonalPyTreeLinearOperator.
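For concreteness, the scaled system I have in mind is (up to the exact sign conventions of the paper) the usual Coleman–Li one; here $B_k$ denotes the Hessian approximation, $g_k$ the gradient, and $v_k$ the element-wise "distance to bound or $\pm 1$" vector:

$$
\bigl(D_k B_k D_k + \operatorname{diag}(g_k)\, J^v_k\bigr)\,\hat p_k = -D_k g_k,
\qquad
D_k = \operatorname{diag}\!\bigl(|v_k|^{1/2}\bigr),
$$

where $J^v_k$ is the diagonal matrix of derivatives of $|v_k|$ (entries in $\{0, \pm 1\}$), and the step in the original variables is $p_k = D_k \hat p_k$.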
I noticed that the Hessian approximation we're computing in optx.BFGS is actually "square", by which I mean that it does not have the redundant entries one finds in, for example, jax.hessian, and its as_matrix returns a square matrix (which is not true for jax.hessian). So maybe I'm making things more complicated than they need to be?
So I think I can imagine two different kinds of 'diagonal pytree': one is where the pytree is the diagonal, and e.g. the diagonal array inside DiagonalLinearOperator corresponds to its ravelled form. The second is where each element of the diagonal is a pytree, i.e. it's really a block-diagonal operator.
If I understand your case correctly, I think you're in the first one? If so then I think that should be fairly straightforward to implement: pretty much copy-paste the existing DiagonalLinearOperator implementation, where the internal storage is now a pytree, mv is given by tree_map(operator.mul, ...) (I think), and as_matrix involves allocating a zero output matrix whose diagonal is filled in via .at[].set(). Does that make sense?
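Roughly along these lines; this is just a standalone sketch of the idea (the class name and structure are made up here, and the real thing would subclass lineax's AbstractLinearOperator and implement its full interface):

```python
# Standalone sketch only, assuming the "pytree is the diagonal" case; the class
# name is hypothetical, not the actual lineax API.
import operator

import jax.numpy as jnp
import jax.tree_util as jtu
from jax.flatten_util import ravel_pytree


class DiagonalPyTreeOperator:
    def __init__(self, diag):
        # `diag` is a pytree with the same structure as the vectors it acts on.
        self.diag = diag

    def mv(self, vector):
        # Leaf-wise multiplication: (D v)_i = d_i * v_i.
        return jtu.tree_map(operator.mul, self.diag, vector)

    def as_matrix(self):
        # Ravel the pytree diagonal and scatter it onto the diagonal of a dense
        # zero matrix via .at[].set().
        flat_diag, _ = ravel_pytree(self.diag)
        (n,) = flat_diag.shape
        indices = jnp.arange(n)
        return jnp.zeros((n, n)).at[indices, indices].set(flat_diag)


# Example usage with a small pytree:
diag = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
vector = {"a": jnp.array([10.0, 10.0]), "b": jnp.array(10.0)}
op = DiagonalPyTreeOperator(diag)
print(op.mv(vector))    # {'a': [10., 20.], 'b': 30.}
print(op.as_matrix())   # 3x3 matrix with [1., 2., 3.] on the diagonal
```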
FWIW I'm not sure what you mean by 'redundant entries' in jax.hessian? Can you give an example?
> FWIW I'm not sure what you mean by 'redundant entries' in jax.hessian? Can you give an example?
Making an MWE I think I got to the bottom of this - I was taking a hessian of a function that accidentally returned a list of two values, rather than a scalar. This caused my matrix sizes to go to (n, 2n), for n elements in a pytree.
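i.e. something like this (hypothetical) minimal repro:

```python
# Hypothetical repro of the mistake: taking jax.hessian of a function that
# accidentally returns a list of two values instead of a scalar.
import jax
import jax.numpy as jnp


def good_fn(x):
    return jnp.sum(x**2)  # scalar output


def bad_fn(x):
    # Accidentally returns a list of two values rather than a scalar.
    return [jnp.sum(x**2), jnp.sum(x**3)]


x = jnp.arange(3.0)
print(jax.hessian(good_fn)(x).shape)              # (3, 3): square, as expected
print([h.shape for h in jax.hessian(bad_fn)(x)])  # [(3, 3), (3, 3)]: one Hessian per
                                                  # output, hence non-square once ravelled
```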
> If I understand your case correctly, I think you're in the first one? ... Does that make sense?
Hi Patrick,