Skip to content

Commit

Permalink
Remove Duplicated and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Aug 7, 2024
1 parent b0754df commit 41ccc23
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 44 deletions.
3 changes: 3 additions & 0 deletions examples/3_tricks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ We demonstrate several features that may come in handy for some users.
=#

using ComponentArrays
using Enzyme #src
using ForwardDiff
using ImplicitDifferentiation
using Krylov
Expand Down Expand Up @@ -67,6 +68,8 @@ J = ForwardDiff.jacobian(forward_components, x) #src
Zygote.jacobian(implicit_components, x)[1]
@test Zygote.jacobian(implicit_components, x)[1] J #src

@test Enzyme.jacobian(Enzyme.Forward, implicit_components, x) J #src

#- The full differentiable pipeline looks like this

function full_pipeline(a, b, m)
Expand Down
45 changes: 1 addition & 44 deletions ext/ImplicitDifferentiationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function EnzymeRules.forward(
B = build_B(implicit, x, y_or_yz, args...; suggested_backend)

dx_batch = reduce(hcat, dx)
dc_batch = mapreduce(hcat, eachcol(dx_batch)) do dₖx
dc_batch = mapreduce(hcat, dx_batch) do dₖx
B * X(dₖx)
end
dy_batch = implicit.linear_solver(A, -dc_batch)
Expand All @@ -52,47 +52,4 @@ function EnzymeRules.forward(
end
end

function EnzymeRules.forward(
func::Const{<:ImplicitFunction},
RT::Type{<:Union{Duplicated,DuplicatedNoNeed}},
func_x::Union{Duplicated{T},DuplicatedNoNeed{T}},
func_args::Vararg{Const,P},
) where {T,P}
implicit = func.val
x = func_x.val
dx = func_x.dval
args = map(a -> a.val, func_args)

y_or_yz = implicit(x, args...)
y = output(y_or_yz)
Y = typeof(y)

suggested_backend = AutoEnzyme(Enzyme.Forward)
A = build_A(implicit, x, y_or_yz, args...; suggested_backend)
B = build_B(implicit, x, y_or_yz, args...; suggested_backend)

dc = B * dx

dy = convert(Y, implicit.linear_solver(A, -dc))


if y_or_yz isa AbstractArray
if RT <: Duplicated
return Duplicated(y, dy)
elseif RT <: DuplicatedNoNeed
return dy
end
elseif y_or_yz isa Tuple
yz = y_or_yz
z = byproduct(yz)
Z = typeof(z)
dyz::NTuple{N,Tuple{Y,Z}} = ntuple(k -> (dy[k], make_zero(z)), Val(N))
if RT <: Duplicated
return Duplicated(yz, dyz)
elseif RT <: DuplicatedNoNeed
return dyz
end
end
end

end

0 comments on commit 41ccc23

Please sign in to comment.