diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index 572b13b..8570898 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -5,6 +5,7 @@ We demonstrate several features that may come in handy for some users. =# using ComponentArrays +using Enzyme #src using ForwardDiff using ImplicitDifferentiation using Krylov @@ -67,6 +68,8 @@ J = ForwardDiff.jacobian(forward_components, x) #src Zygote.jacobian(implicit_components, x)[1] @test Zygote.jacobian(implicit_components, x)[1] ≈ J #src +@test Enzyme.jacobian(Enzyme.Forward, implicit_components, x) ≈ J #src + #- The full differentiable pipeline looks like this function full_pipeline(a, b, m) diff --git a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl index 9d1d595..4f2f257 100644 --- a/ext/ImplicitDifferentiationEnzymeExt.jl +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -26,7 +26,7 @@ function EnzymeRules.forward( B = build_B(implicit, x, y_or_yz, args...; suggested_backend) dx_batch = reduce(hcat, dx) - dc_batch = mapreduce(hcat, eachcol(dx_batch)) do dₖx + dc_batch = mapreduce(hcat, dx_batch) do dₖx B * X(dₖx) end dy_batch = implicit.linear_solver(A, -dc_batch) @@ -52,47 +52,4 @@ function EnzymeRules.forward( end end -function EnzymeRules.forward( - func::Const{<:ImplicitFunction}, - RT::Type{<:Union{Duplicated,DuplicatedNoNeed}}, - func_x::Union{Duplicated{T},DuplicatedNoNeed{T}}, - func_args::Vararg{Const,P}, -) where {T,P} - implicit = func.val - x = func_x.val - dx = func_x.dval - args = map(a -> a.val, func_args) - - y_or_yz = implicit(x, args...) - y = output(y_or_yz) - Y = typeof(y) - - suggested_backend = AutoEnzyme(Enzyme.Forward) - A = build_A(implicit, x, y_or_yz, args...; suggested_backend) - B = build_B(implicit, x, y_or_yz, args...; suggested_backend) - - dc = B * dx - - dy = convert(Y, implicit.linear_solver(A, -dc)) - - - if y_or_yz isa AbstractArray - if RT <: Duplicated - return Duplicated(y, dy) - elseif RT <: DuplicatedNoNeed - return dy - end - elseif y_or_yz isa Tuple - yz = y_or_yz - z = byproduct(yz) - Z = typeof(z) - dyz::NTuple{N,Tuple{Y,Z}} = ntuple(k -> (dy[k], make_zero(z)), Val(N)) - if RT <: Duplicated - return Duplicated(yz, dyz) - elseif RT <: DuplicatedNoNeed - return dyz - end - end -end - end