Skip to content

Commit

Permalink
Adapt to new Enzyme and DifferentiationInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Sep 26, 2024
1 parent 6195cd3 commit aada944
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 109 deletions.
14 changes: 7 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
version = "0.6.1"
version = "0.6.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -21,14 +21,14 @@ ImplicitDifferentiationEnzymeExt = "Enzyme"
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"

[compat]
ADTypes = "1.7.1"
ChainRulesCore = "1.23.0"
DifferentiationInterface = "0.5.12"
Enzyme = "0.11.20,0.12"
ADTypes = "1.9.0"
ChainRulesCore = "1.25.0"
DifferentiationInterface = "0.6.1"
Enzyme = "0.13.3"
ForwardDiff = "0.10.36"
Krylov = "0.9.5"
Krylov = "0.9.6"
LinearAlgebra = "1.10"
LinearOperators = "2.7.0"
LinearOperators = "2.8.0"
julia = "1.10"

[extras]
Expand Down
3 changes: 2 additions & 1 deletion ext/ImplicitDifferentiationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ module ImplicitDifferentiationEnzymeExt

using ADTypes
using Enzyme
using Enzyme.EnzymeCore
using Enzyme.EnzymeRules
using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output

const FORWARD_BACKEND = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)

function EnzymeRules.forward(
config::EnzymeRules.FwdConfig,
func::Const{<:ImplicitFunction},
RT::Type{<:Union{BatchDuplicated,BatchDuplicatedNoNeed}},
func_x::Union{BatchDuplicated{T,N},BatchDuplicatedNoNeed{T,N}},
Expand Down
6 changes: 4 additions & 2 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@ module ImplicitDifferentiation

using ADTypes: AbstractADType
using DifferentiationInterface:
Constant,
jacobian,
prepare_pushforward_same_point,
prepare_pullback_same_point,
pullback!,
pushforward!
pushforward!,
unwrap
using Krylov: block_gmres, gmres
using LinearOperators: LinearOperator
using LinearAlgebra: factorize, lu
using LinearAlgebra: axpby!, factorize, lu

include("implicit_function.jl")
include("operators.jl")
Expand Down
2 changes: 2 additions & 0 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@ end

output(y::AbstractVector) = y
byproduct(::AbstractVector) = error("No byproduct")
rest(::AbstractVector) = ()

output(yz::Tuple{<:Any,<:Any}) = yz[1]
byproduct(yz::Tuple{<:Any,<:Any}) = yz[2]
rest(yz::Tuple) = (byproduct(yz),)

output((y, z)) = y
byproduct((y, z)) = z
163 changes: 65 additions & 98 deletions src/operators.jl
Original file line number Diff line number Diff line change
@@ -1,112 +1,75 @@
## Partial conditions

struct ConditionsXNoByproduct{C,Y,A,K}
struct ConditionsX{C,K}
conditions::C
y::Y
args::A
kwargs::K
end

function (conditions_x_nobyproduct::ConditionsXNoByproduct)(x::AbstractVector)
(; conditions, y, args, kwargs) = conditions_x_nobyproduct
return conditions(x, y, args...; kwargs...)
end

struct ConditionsYNoByproduct{C,X,A,K}
struct ConditionsY{C,K}
conditions::C
x::X
args::A
kwargs::K
end

function (conditions_y_nobyproduct::ConditionsYNoByproduct)(y::AbstractVector)
(; conditions, x, args, kwargs) = conditions_y_nobyproduct
return conditions(x, y, args...; kwargs...)
function (cx::ConditionsX)(x, y, args...)
return cx.conditions(x, y, args...; cx.kwargs...)
end

struct ConditionsXByproduct{C,Y,Z,A,K}
conditions::C
y::Y
z::Z
args::A
kwargs::K
end

function (conditions_x_byproduct::ConditionsXByproduct)(x::AbstractVector)
(; conditions, y, z, args, kwargs) = conditions_x_byproduct
return conditions(x, y, z, args...; kwargs...)
function (cy::ConditionsY)(y, x, args...) # order switch
return cy.conditions(x, y, args...; cy.kwargs...)
end

struct ConditionsYByproduct{C,X,Z,A,K}
conditions::C
struct PushforwardOperator!{F,P,B,X,C,R}
f::F
prep::P
backend::B
x::X
z::Z
args::A
kwargs::K
end

function (conditions_y_byproduct::ConditionsYByproduct)(y::AbstractVector)
(; conditions, x, z, args, kwargs) = conditions_y_byproduct
return conditions(x, y, z, args...; kwargs...)
end

function ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
y = output(y_or_yz)
if y_or_yz isa Tuple
z = byproduct(y_or_yz)
return ConditionsXByproduct(conditions, y, z, args, kwargs)
else
return ConditionsXNoByproduct(conditions, y, args, kwargs)
end
end

function ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
if y_or_yz isa Tuple
z = byproduct(y_or_yz)
return ConditionsYByproduct(conditions, x, z, args, kwargs)
else
return ConditionsYNoByproduct(conditions, x, args, kwargs)
end
contexts::C
res_backup::R
end

## Lazy operators

struct PushforwardOperator!{F,B,X,E,R}
struct PullbackOperator!{F,P,B,X,C,R}
f::F
prep::P
backend::B
x::X
extras::E
contexts::C
res_backup::R
end

function PushforwardOperator!(f, prep, backend, x, contexts)
res_backup = similar(f(x, map(unwrap, contexts)...))
return PushforwardOperator!(f, prep, backend, x, contexts, res_backup)
end

function PullbackOperator!(f, prep, backend, x, contexts)
res_backup = similar(x)
return PullbackOperator!(f, prep, backend, x, contexts, res_backup)
end

function (po::PushforwardOperator!)(res, v, α, β)
(; f, backend, x, contexts, prep, res_backup) = po
if iszero(β)
pushforward!(po.f, res, po.backend, po.x, v, po.extras)
res .= α .* res
pushforward!(f, (res,), prep, backend, x, (v,), contexts...)
if !isone(α)
res .*= α

Check warning on line 52 in src/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/operators.jl#L52

Added line #L52 was not covered by tests
end
else
po.res_backup .= res
pushforward!(po.f, res, po.backend, po.x, v, po.extras)
res .= α .* res .+ β .* po.res_backup
copyto!(res_backup, res)
pushforward!(f, (res,), prep, backend, x, (v,), contexts...)
axpby!(β, res_backup, α, res)

Check warning on line 57 in src/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/operators.jl#L55-L57

Added lines #L55 - L57 were not covered by tests
end
return res
end

struct PullbackOperator!{F,B,X,E,R}
f::F
backend::B
x::X
extras::E
res_backup::R
end

function (po::PullbackOperator!)(res, v, α, β)
(; f, backend, x, contexts, prep, res_backup) = po
if iszero(β)
pullback!(po.f, res, po.backend, po.x, v, po.extras)
res .= α .* res
pullback!(f, (res,), prep, backend, x, (v,), contexts...)
if !isone(α)
res .*= α

Check warning on line 67 in src/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/operators.jl#L67

Added line #L67 was not covered by tests
end
else
po.res_backup .= res
pullback!(po.f, res, po.backend, po.x, v, po.extras)
res .= α .* res .+ β .+ po.res_backup
copyto!(res_backup, res)
pullback!(f, (res,), prep, backend, x, (v,), contexts...)
axpby!(β, res_backup, α, res)

Check warning on line 72 in src/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/operators.jl#L70-L72

Added lines #L70 - L72 were not covered by tests
end
return res
end
Expand All @@ -119,24 +82,25 @@ function build_A(
suggested_backend,
kwargs...,
) where {lazy}
(; conditions, linear_solver, conditions_y_backend) = implicit
(; conditions, conditions_y_backend) = implicit
y = output(y_or_yz)
n, m = length(x), length(y)
back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend
cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
cond_y = ConditionsY(conditions, kwargs)
contexts = (Constant(x), map(Constant, rest(y_or_yz))..., map(Constant, args)...)
if lazy
extras = prepare_pushforward_same_point(cond_y, back_y, y, zero(y))
prep = prepare_pushforward_same_point(cond_y, back_y, y, (zero(y),), contexts...)
A = LinearOperator(
eltype(y),
m,
m,
false,
false,
PushforwardOperator!(cond_y, back_y, y, extras, similar(y)),
PushforwardOperator!(cond_y, prep, back_y, y, contexts),
typeof(y),
)
else
J = jacobian(cond_y, back_y, y)
J = jacobian(cond_y, back_y, y, contexts...)
A = factorize(J)
end
return A
Expand All @@ -150,24 +114,25 @@ function build_Aᵀ(
suggested_backend,
kwargs...,
) where {lazy}
(; conditions, linear_solver, conditions_y_backend) = implicit
(; conditions, conditions_y_backend) = implicit
y = output(y_or_yz)
n, m = length(x), length(y)
back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend
cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
cond_y = ConditionsY(conditions, kwargs)
contexts = (Constant(x), map(Constant, rest(y_or_yz))..., map(Constant, args)...)
if lazy
extras = prepare_pullback_same_point(cond_y, back_y, y, zero(y))
prep = prepare_pullback_same_point(cond_y, back_y, y, (zero(y),), contexts...)
Aᵀ = LinearOperator(
eltype(y),
m,
m,
false,
false,
PullbackOperator!(cond_y, back_y, y, extras, similar(y)),
PullbackOperator!(cond_y, prep, back_y, y, contexts),
typeof(y),
)
else
Jᵀ = transpose(jacobian(cond_y, back_y, y))
Jᵀ = transpose(jacobian(cond_y, back_y, y, contexts...))
Aᵀ = factorize(Jᵀ)
end
return Aᵀ
Expand All @@ -181,24 +146,25 @@ function build_B(
suggested_backend,
kwargs...,
) where {lazy}
(; conditions, linear_solver, conditions_x_backend) = implicit
(; conditions, conditions_x_backend) = implicit
y = output(y_or_yz)
n, m = length(x), length(y)
back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend
cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
cond_x = ConditionsX(conditions, kwargs)
contexts = (Constant(y), map(Constant, rest(y_or_yz))..., map(Constant, args)...)
if lazy
extras = prepare_pushforward_same_point(cond_x, back_x, x, zero(x))
prep = prepare_pushforward_same_point(cond_x, back_x, x, (zero(x),), contexts...)
B = LinearOperator(
eltype(y),
m,
n,
false,
false,
PushforwardOperator!(cond_x, back_x, x, extras, similar(y)),
PushforwardOperator!(cond_x, prep, back_x, x, contexts),
typeof(x),
)
else
B = transpose(jacobian(cond_x, back_x, x))
B = transpose(jacobian(cond_x, back_x, x, contexts...))
end
return B
end
Expand All @@ -211,24 +177,25 @@ function build_Bᵀ(
suggested_backend,
kwargs...,
) where {lazy}
(; conditions, linear_solver, conditions_x_backend) = implicit
(; conditions, conditions_x_backend) = implicit
y = output(y_or_yz)
n, m = length(x), length(y)
back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend
cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
cond_x = ConditionsX(conditions, kwargs)
contexts = (Constant(y), map(Constant, rest(y_or_yz))..., map(Constant, args)...)
if lazy
extras = prepare_pullback_same_point(cond_x, back_x, x, zero(y))
prep = prepare_pullback_same_point(cond_x, back_x, x, (zero(y),), contexts...)
Bᵀ = LinearOperator(
eltype(y),
n,
m,
false,
false,
PullbackOperator!(cond_x, back_x, x, extras, similar(x)),
PullbackOperator!(cond_x, prep, back_x, x, contexts),
typeof(x),
)
else
Bᵀ = transpose(jacobian(cond_x, back_x, x))
Bᵀ = transpose(jacobian(cond_x, back_x, x, contexts...))
end
return Bᵀ
end
2 changes: 1 addition & 1 deletion test/systematic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ conditions_backend_candidates = (

x_candidates = (
Float32[3, 4], #
MVector{2}(Float32[3, 4]), #
# MVector{2}(Float32[3, 4]), #
);

## Test loop
Expand Down

0 comments on commit aada944

Please sign in to comment.