
Issue with second derivative #26

Closed · rveltz opened this issue Feb 15, 2023 · 23 comments · Fixed by #31
Labels: feature (New feature or request)

Comments

@rveltz commented Feb 15, 2023

Hi,
I have been trying to use ImplicitDifferentiation.jl for higher-order JVPs, but without luck.

I want to use it in BifurcationKit.jl, where I need third-order JVPs.

Basic example:

using ImplicitDifferentiation
using Optim
using Random
using Zygote


Random.seed!(63)

# Forward pass: minimize ||y - x||² with L-BFGS, whose solution is y = x
function dumb_identity(x)
    f(y) = sum(abs2, y - x)
    y0 = zero(x)
    res = optimize(f, y0, LBFGS(); autodiff=:forward)
    y = Optim.minimizer(res)
    return y
end;

# Optimality conditions: the gradient of f with respect to y vanishes at the solution
zero_gradient(x, y) = 2(y - x);
implicit = ImplicitFunction(dumb_identity, zero_gradient);
x = rand(3, 2)
h = rand(3, 2)
# First-order directional derivative (JVP) via Zygote
D(x, h) = Zygote.jacobian(t -> implicit(x .+ t .* h), 0)[1]
D(x, h) # works
# Second-order: differentiate D once more
D2(x, h1, h2) = Zygote.jacobian(t -> D(x .+ t .* h2, h1), 0)[1]
D2(x, h, h) # does not work

I also have this code using ForwardDiff; I am not sure whether the problem is the same:

using ForwardDiffChainRules, ForwardDiff
@ForwardDiff_frule (f::typeof(implicit))(x::AbstractMatrix{<:ForwardDiff.Dual})
D(x, h) = ForwardDiff.derivative(t -> implicit(x .+ t .* h), 0)
D(x, h) # works
D2(x, h1, h2) = ForwardDiff.derivative(t -> D(x .+ t .* h2, h1), 0)
D2(x, h, h) # does not work
@mohamed82008 (Collaborator)

What's the error?

@rveltz (Author) commented Feb 15, 2023

For Zygote:

julia> D2(x,h,h) # does not work
ERROR: Compiling Tuple{typeof(Optim.perform_linesearch!), Optim.LBFGSState{Matrix{Float64}, Vector{Matrix{Float64}}, Vector{Matrix{Float64}}, Float64, Matrix{Float64}}, LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Optim.var"#19#21"}, Optim.ManifoldObjective{OnceDifferentiable{Float64, Matrix{Float64}, Matrix{Float64}}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

For ForwardDiff:

julia> D2(x,h,h) # does not work
ERROR: MethodError: no method matching gmres(::LinearOperators.LinearOperator{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}, Int64, ImplicitDifferentiation.var"#mul_A!#11"{ImplicitDifferentiation.var"#pushforward_A#9"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ForwardDiffChainRules.ForwardDiffRuleConfig, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}}}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}}}, Nothing, Nothing, Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}}}, ::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#29#30"{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Int64}, Float64, 1}})

It seems the second derivative bypasses the machinery of ImplicitDifferentiation.

@rveltz (Author) commented Mar 1, 2023

Do you think this requires implementing the second derivative within the implicit function theorem, or should the algorithm recursively call the first derivative?

@gdalle (Member) commented Mar 1, 2023

Do you think this requires implementing the second derivative within the implicit function theorem, or should the algorithm recursively call the first derivative?

I've been asking myself the same question for a project of mine. The problem is, ChainRulesCore only supports specification of first derivatives, so I'm not sure how to go about it in a way that's easy to use.

@gdalle (Member) commented Mar 1, 2023

it seems the second derivative bypasses the machinery of ImplicitDifferentiation

That's definitely worth investigating. Maybe we should define a rule on the pushforward / pullback?
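
For concreteness, here is a toy of what a rule on the pullback could look like: expose the pullback as a named function and give it its own rrule, so second-order AD hits that rule instead of retracing the forward pass. Everything below is hypothetical and not ImplicitDifferentiation internals.

using ChainRulesCore

# Toy pullback of f(x) = 2x, written as a named function so it can carry a rule
pullback_f(x, ȳ) = 2 .* ȳ

# Rule for the pullback itself: it is constant in x and linear in ȳ
function ChainRulesCore.rrule(::typeof(pullback_f), x, ȳ)
    second_order_pullback(Δ) = (NoTangent(), ZeroTangent(), 2 .* Δ)
    return pullback_f(x, ȳ), second_order_pullback
end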

@rveltz (Author) commented Mar 1, 2023

I'm adding the link to the Discourse post.

@mohamed82008 (Collaborator)

The problem is, ChainRulesCore only supports specification of first derivatives, so I'm not sure how to go about it in a way that's easy to use.

An example is here https://github.com/JuliaNonconvex/NonconvexUtils.jl/blob/main/src/custom.jl

@mohamed82008 (Collaborator)

Changing this line https://github.com/gdalle/ImplicitDifferentiation.jl/blob/1581d2e3b1b1ddf083f0d370c9f7b323aa98f610/src/implicit_function.jl#L64 to

y = implicit(x; kwargs...)

and using a linear solver that's compatible with ForwardDiff makes the FD case work:

# Direct dense solve that ForwardDiff duals can propagate through
linear_solver(A, x) = (Matrix(A) \ x, (solved = true,))
implicit = ImplicitFunction(dumb_identity, zero_gradient, linear_solver);

Zygote still fails, but I suspect that's because Zygote.jacobian mutates, which trips up Zygote's second-order differentiation. I have run into this issue in the past and partially avoided it by defining a non-mutating version of jacobian. Perhaps using AbstractDifferentiation's implementation would be better here. Anyway, I will leave the rest to you :)
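
For reference, a minimal sketch of the non-mutating jacobian idea mentioned above, assuming vector input and output (the function name is hypothetical, not package API):

using Zygote

# Build each Jacobian row from a pullback and assemble them functionally,
# with no in-place writes for Zygote to trip over on the second pass.
function jacobian_nonmutating(f, x::AbstractVector)
    y, back = Zygote.pullback(f, x)
    rows = map(1:length(y)) do i
        e = [j == i ? 1.0 : 0.0 for j in 1:length(y)]
        only(back(e))  # pullback of the i-th basis vector = i-th Jacobian row
    end
    return permutedims(reduce(hcat, rows))
end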

@gdalle (Member) commented Mar 1, 2023

Changing this line

Seems easy enough to fix in the source code, should we?

@mohamed82008 (Collaborator)

Seems easy enough to fix in the source code, should we?

Let's

gdalle added a commit that referenced this issue Mar 1, 2023
@gdalle (Member) commented Mar 10, 2023

Random thought: the fact that we need a dual-compatible linear solver means we still autodiff through the iterations of the solver. Is it possible to avoid that altogether for second-order?

@mohamed82008 (Collaborator)

In theory we could. That would require defining the linear solver as an implicit function and using ForwardDiffChainRules on that. We would need to think of a good way to make this work for direct and iterative linear solvers.
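
As a minimal sketch of that idea, assuming a fixed matrix A and differentiation with respect to the right-hand side b (names are illustrative, mirroring the two-function constructor used earlier in this thread):

using ImplicitDifferentiation

A = [2.0 1.0; 0.0 3.0]

# Forward solve and its optimality conditions: y solves A * y = b,
# so the residual A * y - b vanishes at the solution.
solve_rhs(b) = A \ b
residual(b, y) = A * y - b
implicit_solve = ImplicitFunction(solve_rhs, residual)

Derivatives of implicit_solve would then come from the implicit function theorem rather than from differentiating through factorization or iteration steps, and the construction nests.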

@mohamed82008 (Collaborator)

Do you know if the Python package has higher-order implicit derivatives? If not, this could be an interesting conference paper.

@gdalle (Member) commented Mar 10, 2023

Do you know if the Python package has higher-order implicit derivatives?

No, it only focuses on first-order differentiation. But it outsources the actual autodiff to JAX, which may handle second-order differentiation better than Zygote, for example.

@gdalle (Member) commented Mar 10, 2023

That would require defining the linear solver as an implicit function and using ForwardDiffChainRules on that.

My thought was actually to differentiate the implicit function theorem a second time, like this.

gdalle reopened this Mar 10, 2023
gdalle added the feature (New feature or request) label Mar 10, 2023
@mohamed82008 (Collaborator)

I believe the derivations in this link are computationally identical to making our linear solver an implicit function and letting ImplicitDifferentiation automate the rest for us. They just define the higher-order rule manually, but then it gets ugly with multiple inputs and outputs, Hessian rules for functions, etc. We have a nice abstraction which can be nested; let's take advantage of that.

@gdalle (Member) commented Mar 10, 2023

the derivations in this link are computationally identical to making our linear solver an implicit function

I believe you're right. Although I think the SciML solvers from https://github.com/SciML/LinearSolve.jl are already differentiable via similar implicit machinery, it can't hurt to roll our own simpler version.

@mohamed82008 (Collaborator)

I recall you tried LinearSolve before but didn't end up using it because it was still not mature enough. Maybe we can revisit it or roll our own; it shouldn't be too hard for simple solvers.

@mohamed82008 (Collaborator)

I don't think we need that many solvers, so a direct solver and GMRES should be enough.
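
For illustration, both flavors in the two-argument (A, b) -> (solution, stats) form used earlier in this thread; the gmres call assumes Krylov.jl as the iterative backend:

using Krylov  # assumed dependency for the iterative variant

# Dense direct solve, dual-compatible as discussed above
direct_solver(A, b) = (Matrix(A) \ b, (solved = true,))

# Matrix-free GMRES for larger problems
function gmres_solver(A, b)
    x, stats = Krylov.gmres(A, b)
    return (x, (solved = stats.solved,))
end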

@gdalle (Member) commented Mar 10, 2023

I recall you tried LinearSolve before but didn't end up using it because it was still not mature enough

It should definitely be more mature now, but my main beef is that it has a truckload of dependencies, which cannot be made conditional until Julia 1.9:

https://github.com/SciML/LinearSolve.jl/blob/main/Project.toml

gdalle changed the title from "issue with second derivative" to "Issue with second derivative" Jul 30, 2023
@mohamed82008 (Collaborator)

Now (on main), you can pass DirectLinearSolver() as the linear_solver of the implicit function to do second-order differentiation. For the more general case of making the linear solver an implicit function, I think we should open another issue.
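
A usage sketch following that comment, reusing the ForwardDiff example from the top of the thread; the positional linear_solver argument mirrors the earlier constructor call and may differ from the exact signature on main:

implicit = ImplicitFunction(dumb_identity, zero_gradient, DirectLinearSolver());
D(x, h) = ForwardDiff.derivative(t -> implicit(x .+ t .* h), 0)
D2(x, h1, h2) = ForwardDiff.derivative(t -> D(x .+ t .* h2, h1), 0)
D2(x, h, h)  # second-order now goes through the direct solver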

@gdalle (Member) commented Jul 30, 2023

See #77

@gdalle (Member) commented Feb 21, 2024

Closing in favor of #77

gdalle closed this as completed Feb 21, 2024