Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Enzyme extension and add new broken test #151

Merged
merged 8 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
version = "0.6.0"
version = "0.6.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -21,9 +21,9 @@ ImplicitDifferentiationEnzymeExt = "Enzyme"
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"

[compat]
ADTypes = "1.0"
ADTypes = "1.7.1"
ChainRulesCore = "1.23.0"
DifferentiationInterface = "0.5"
DifferentiationInterface = "0.5.12"
Enzyme = "0.11.20,0.12"
ForwardDiff = "0.10.36"
Krylov = "0.9.5"
Expand Down
3 changes: 3 additions & 0 deletions examples/3_tricks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ We demonstrate several features that may come in handy for some users.
=#

using ComponentArrays
using Enzyme #src
using ForwardDiff
using ImplicitDifferentiation
using Krylov
Expand Down Expand Up @@ -67,6 +68,8 @@ J = ForwardDiff.jacobian(forward_components, x) #src
Zygote.jacobian(implicit_components, x)[1]
@test Zygote.jacobian(implicit_components, x)[1] ≈ J #src

@test_broken Enzyme.jacobian(Enzyme.Forward, implicit_components, x) ≈ J #src

#- The full differentiable pipeline looks like this

function full_pipeline(a, b, m)
Expand Down
7 changes: 4 additions & 3 deletions ext/ImplicitDifferentiationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using Enzyme
using Enzyme.EnzymeCore
using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output

const FORWARD_BACKEND = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)

function EnzymeRules.forward(
func::Const{<:ImplicitFunction},
RT::Type{<:Union{BatchDuplicated,BatchDuplicatedNoNeed}},
Expand All @@ -20,12 +22,11 @@ function EnzymeRules.forward(
y = output(y_or_yz)
Y = typeof(y)

suggested_backend = AutoEnzyme(Enzyme.Forward)
suggested_backend = FORWARD_BACKEND
A = build_A(implicit, x, y_or_yz, args...; suggested_backend)
B = build_B(implicit, x, y_or_yz, args...; suggested_backend)

dx_batch = reduce(hcat, dx)
dc_batch = mapreduce(hcat, eachcol(dx_batch)) do dₖx
dc_batch = mapreduce(hcat, dx) do dₖx
B * dₖx
end
dy_batch = implicit.linear_solver(A, -dc_batch)
Expand Down
41 changes: 23 additions & 18 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@ The value of `lazy` must be chosen together with the `linear_solver`, see below.
- `conditions_x_backend`: how the conditions will be differentiated w.r.t. the first argument `x`
- `conditions_y_backend`: how the conditions will be differentiated w.r.t. the second argument `y`

# Constructors

ImplicitFunction(
forward, conditions;
linear_solver=KrylovLinearSolver(),
conditions_x_backend=nothing,
conditions_x_backend=nothing,
)

Picks the `lazy` parameter automatically based on the `linear_solver`, using the following heuristic: `lazy = linear_solver != \\`.

ImplicitFunction{lazy}(
forward, conditions;
linear_solver=lazy ? KrylovLinearSolver() : \\,
conditions_x_backend=nothing,
conditions_y_backend=nothing,
)

Picks the `linear_solver` automatically based on the `lazy` parameter.

# Function signatures

There are two possible signatures for `forward` and `conditions`, which must be consistent with one another:
Expand Down Expand Up @@ -87,8 +107,10 @@ Typically, direct solvers work best with dense Jacobians (`lazy = false`) while
# Condition backends

The provided `conditions_x_backend` and `conditions_y_backend` can be either:
- `nothing` (the default), in which case the outer backend (the one differentiating through the `ImplicitFunction`) is used.
- an object subtyping `AbstractADType` from [ADTypes.jl](https://github.com/SciML/ADTypes.jl);
- `nothing`, in which case the outer backend (the one differentiating through the `ImplicitFunction`) is used.

When differentiating with Enzyme as an outer backend, the default setting assumes that `conditions` does not contain writeable data involved in derivatives.
"""
struct ImplicitFunction{
lazy,F,C,L,B1<:Union{Nothing,AbstractADType},B2<:Union{Nothing,AbstractADType}
Expand All @@ -101,14 +123,7 @@ struct ImplicitFunction{
end

"""
ImplicitFunction{lazy}(
forward, conditions;
linear_solver=lazy ? KrylovLinearSolver() : \\,
conditions_x_backend=nothing,
conditions_y_backend=nothing,
)

Constructor for an [`ImplicitFunction`](@ref) which picks the `linear_solver` automatically based on the `lazy` parameter.
"""
function ImplicitFunction{lazy}(
forward::F,
Expand All @@ -126,16 +141,6 @@ function ImplicitFunction{lazy}(
)
end

"""
ImplicitFunction(
forward, conditions;
linear_solver=KrylovLinearSolver(),
conditions_x_backend=nothing,
conditions_x_backend=nothing,
)

Constructor for an [`ImplicitFunction`](@ref) which picks the `lazy` parameter automatically based on the `linear_solver`, using the following heuristic: `lazy = linear_solver != \\`.
"""
function ImplicitFunction(
forward,
conditions;
Expand Down
4 changes: 2 additions & 2 deletions test/systematic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include("utils.jl")

backends = [
AutoForwardDiff(; chunksize=1), #
AutoEnzyme(Enzyme.Forward),
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
AutoZygote(),
]

Expand All @@ -24,7 +24,7 @@ linear_solver_candidates = (
conditions_backend_candidates = (
nothing, #
AutoForwardDiff(; chunksize=1),
AutoEnzyme(Enzyme.Forward),
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
);

x_candidates = (
Expand Down
Loading