Here the `implicit` object returns a tuple `(y, z)`, where `y` is the part we actually care about and `z` is a byproduct that we don't differentiate with respect to (it is ignored in the pullback, defined here).
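- When I pass a zero tangent for `z` manually (`NoTangent()` in the session below), the pullback succeeds.
- When `test_rrule` from ChainRulesTestUtils receives the same output tangent, it fails with a `DimensionMismatch` raised during its finite-difference check.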
```julia
julia> using ChainRulesCore

julia> using ChainRulesTestUtils

julia> using ImplicitDifferentiation  # use the main branch

julia> using Zygote

julia> forward(x) = sqrt.(abs.(x)), 2;

julia> conditions(x, y, z) = abs.(y) .^ z .- abs.(x);

julia> implicit = ImplicitFunction(forward, conditions)
ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing)

julia> x = rand(3)
3-element Vector{Float64}:
 0.15182038809215614
 0.4724558277117631
 0.6889531602780976

julia> y, z = implicit(x)
([0.3896413582926691, 0.6873542228805779, 0.8300320236461347], 2)

julia> dy = similar(y);

julia> dy .= 1;

julia> rc = Zygote.ZygoteRuleConfig();

julia> _, back = rrule_via_ad(rc, implicit, x);

julia> back((dy, NoTangent()))
(NoTangent(), [1.2832313340424137, 0.7274269704848686, 0.6023863968568562])

julia> test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, 0))  # ok
Test Summary:                                                                                             | Pass  Total  Time
test_rrule: ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing) on Vector{Float64} |    7      7  0.0s
Test.DefaultTestSet("test_rrule: ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing) on Vector{Float64}", Any[], 7, false, false, true, 1.691675539371241e9, 1.691675539408112e9, false, "/home/guillaume/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl")

julia> test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, NoTangent()))  # not ok
test_rrule: ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing) on Vector{Float64}: Error During Test at /home/guillaume/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:202
  Got exception outside of a @test
  DimensionMismatch: second dimension of A, 4, does not match length of x, 3
  Stacktrace:
    [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:404
    [2] generic_matvecmul!
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:71 [inlined]
    [3] mul!
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66 [inlined]
    [4] mul!
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
    [5] *(A::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:53
    [6] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/BSR84/src/grad.jl:84
    [7] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#53"{ChainRulesTestUtils.var"#call#63"{@NamedTuple{}}, Tuple{ImplicitFunction{typeof(forward), typeof(conditions), IterativeLinearSolver, Nothing}, Vector{Float64}}, Tuple{Bool, Bool}}, ȳ::Tuple{Vector{Float64}, NoTangent}, x::Vector{Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/BSR84/src/grad.jl:77
    [8] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/finite_difference_calls.jl:51
    [9] macro expansion
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:233 [inlined]
   [10] macro expansion
      @ ChainRulesTestUtils ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [11] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, testset_name::Any, kwargs...)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:205
```
This is the broken test added in #111.
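The `j′vp` frames in the stacktrace hint at where the manual call and `test_rrule` diverge: the finite-difference reference flattens both the primal output and the output tangent into vectors before multiplying by the transposed Jacobian. The sketch below assumes that flattening is done by `FiniteDifferences.to_vec` (which the `j′vp` frames point to) and uses stand-in values with the same types as in the session, not the actual numbers. It compares the flattened lengths of the output `(y, z)`, the working tangent `(dy, 0)` and the failing tangent `(dy, NoTangent())`; the `4` vs `3` in the `DimensionMismatch` message is what this reading would predict.

```julia
using ChainRulesCore: NoTangent
using FiniteDifferences: to_vec

# Stand-in values (assumption): same types as in the session,
# y is a Vector{Float64} of length 3 and z is the Int byproduct 2.
y = [0.39, 0.69, 0.83]
z = 2
dy = ones(3)

# Flatten the primal output and the two candidate output tangents.
out_vec, _ = to_vec((y, z))              # primal output (y, z)
good_vec, _ = to_vec((dy, 0))            # tangent with a plain 0 for z
bad_vec, _ = to_vec((dy, NoTangent()))   # tangent with NoTangent() for z

# The finite-difference j′vp needs the flattened tangent to have the same
# length as the flattened output; print the lengths to compare them.
println("(y, z):            ", length(out_vec))
println("(dy, 0):           ", length(good_vec))
println("(dy, NoTangent()): ", length(bad_vec))
```

If this reading is correct, the integer `z` contributes one entry to the flattened output while `NoTangent()` contributes none to the flattened tangent, so the mismatch would come from the finite-difference reference rather than from the pullback itself, which is consistent with the manual `back((dy, NoTangent()))` call succeeding.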