From e89a038f7acbffb2c1221355c194c464f601bfc4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 24 Nov 2024 19:04:13 +0100 Subject: [PATCH] Apply formatter --- bench/run_benchmarks.jl | 65 +-- docs/make.jl | 22 +- ext/MooncakeAllocCheckExt.jl | 2 +- ext/MooncakeCUDAExt.jl | 6 +- ext/MooncakeDynamicPPLExt.jl | 2 +- ext/MooncakeLuxLibExt.jl | 45 +- ext/MooncakeLuxLibSLEEFPiratesExtension.jl | 60 ++- ext/MooncakeNNlibExt.jl | 40 +- ext/MooncakeSpecialFunctionsExt.jl | 72 ++-- src/Mooncake.jl | 31 +- src/codual.jl | 22 +- src/config.jl | 4 +- src/debug_mode.jl | 4 +- src/developer_tools.jl | 8 +- src/fwds_rvs_data.jl | 104 +++-- src/interface.jl | 24 +- src/interpreter/abstract_interpretation.jl | 80 ++-- src/interpreter/bbcode.jl | 80 ++-- src/interpreter/ir_normalisation.jl | 89 ++-- src/interpreter/ir_utils.jl | 19 +- src/interpreter/s2s_reverse_mode_ad.jl | 180 ++++---- src/interpreter/zero_like_rdata.jl | 2 +- src/rrules/array_legacy.jl | 201 +++++---- .../avoiding_non_differentiable_code.jl | 39 +- src/rrules/blas.jl | 400 ++++++++++-------- src/rrules/builtins.jl | 258 +++++++---- src/rrules/fastmath.jl | 46 +- src/rrules/foreigncall.jl | 198 +++++---- src/rrules/function_wrappers.jl | 64 +-- src/rrules/iddict.jl | 60 ++- src/rrules/lapack.jl | 269 ++++++++---- src/rrules/linear_algebra.jl | 5 +- src/rrules/low_level_maths.jl | 66 +-- src/rrules/memory.jl | 251 ++++++----- src/rrules/misc.jl | 145 ++++--- src/rrules/new.jl | 99 +++-- src/rrules/tasks.jl | 25 +- src/rrules/twice_precision.jl | 183 +++++--- src/tangents.jl | 180 ++++---- src/test_resources.jl | 129 ++++-- src/test_utils.jl | 393 ++++++++++------- src/tools_for_rules.jl | 237 +++++------ src/utils.jl | 55 ++- test/codual.jl | 22 +- test/debug_mode.jl | 8 +- test/developer_tools.jl | 2 +- test/ext/cuda/cuda.jl | 12 +- .../differentiation_interface.jl | 5 +- test/ext/luxlib/luxlib.jl | 75 ++-- test/ext/nnlib/nnlib.jl | 59 ++- test/front_matter.jl | 5 +- test/fwds_rvs_data.jl | 64 +-- test/integration_testing/array/array.jl | 30 +- .../battery_tests/battery_tests.jl | 37 +- .../bijectors/bijectors.jl | 29 +- .../diff_tests/diff_tests.jl | 16 +- .../distributions/distributions.jl | 35 +- test/integration_testing/gp/gp.jl | 8 +- test/integration_testing/lux/lux.jl | 49 ++- .../misc_abstract_array.jl | 146 ++++--- .../temporalgps/temporalgps.jl | 1 - test/integration_testing/turing/turing.jl | 22 +- test/interpreter/abstract_interpretation.jl | 17 +- test/interpreter/bbcode.jl | 37 +- test/interpreter/contexts.jl | 4 +- test/interpreter/ir_normalisation.jl | 22 +- test/interpreter/ir_utils.jl | 16 +- test/interpreter/s2s_reverse_mode_ad.jl | 91 ++-- test/rrules/foreigncall.jl | 27 +- test/rrules/function_wrappers.jl | 4 +- test/rrules/low_level_maths.jl | 22 +- test/rrules/misc.jl | 1 - test/runtests.jl | 2 +- test/tangents.jl | 94 ++-- test/test_utils.jl | 58 ++- test/tools_for_rules.jl | 45 +- test/utils.jl | 8 +- 77 files changed, 3100 insertions(+), 2237 deletions(-) diff --git a/bench/run_benchmarks.jl b/bench/run_benchmarks.jl index a3ff13e78..2858e28db 100644 --- a/bench/run_benchmarks.jl +++ b/bench/run_benchmarks.jl @@ -1,8 +1,7 @@ using Pkg -Pkg.develop(path=joinpath(@__DIR__, "..")) +Pkg.develop(; path=joinpath(@__DIR__, "..")) -using - AbstractGPs, +using AbstractGPs, Chairmarks, CSV, DataFrames, @@ -28,13 +27,13 @@ using Mooncake: using Mooncake.TestUtils: _deepcopy -function to_benchmark(__rrule!!::R, dx::Vararg{CoDual, N}) where {R, N} +function to_benchmark(__rrule!!::R, dx::Vararg{CoDual,N}) where {R,N} dx_f = Mooncake.tuple_map(x -> CoDual(primal(x), Mooncake.fdata(tangent(x))), dx) out, pb!! = __rrule!!(dx_f...) return pb!!(Mooncake.zero_rdata(primal(out))) end -function zygote_to_benchmark(ctx, x::Vararg{Any, N}) where {N} +function zygote_to_benchmark(ctx, x::Vararg{Any,N}) where {N} out, pb = Zygote._pullback(ctx, x...) return pb(out) end @@ -107,7 +106,7 @@ end @model broadcast_demo(x) = begin μ ~ truncated(Normal(1, 2), 0.1, 10) σ ~ truncated(Normal(1, 2), 0.1, 10) - x .~ LogNormal(μ, σ) + x .~ LogNormal(μ, σ) end function build_turing_problem() @@ -122,17 +121,21 @@ function build_turing_problem() return test_function, randn(rng, d) end -run_turing_problem(f::F, x::X) where {F, X} = f(x) +run_turing_problem(f::F, x::X) where {F,X} = f(x) -should_run_benchmark( +function should_run_benchmark( ::Val{:zygote}, ::Base.Fix1{<:typeof(DynamicPPL.LogDensityProblems.logdensity)}, x... -) = false -should_run_benchmark( +) + return false +end +function should_run_benchmark( ::Val{:enzyme}, ::Base.Fix1{<:typeof(DynamicPPL.LogDensityProblems.logdensity)}, x... -) = false +) + return false +end should_run_benchmark(::Val{:enzyme}, x...) = false -@inline g(x, a, ::Val{N}) where {N} = N > 0 ? g(x * a, a, Val(N-1)) : x +@inline g(x, a, ::Val{N}) where {N} = N > 0 ? g(x * a, a, Val(N - 1)) : x large_single_block(x::AbstractVector{<:Real}) = g(x[1], x[2], Val(400)) @@ -168,14 +171,12 @@ function generate_inter_framework_tests() end function benchmark_rules!!(test_case_data, default_ratios, include_other_frameworks::Bool) - test_cases = reduce(vcat, map(first, test_case_data)) memory = map(x -> x[2], test_case_data) ranges = reduce(vcat, map(x -> x[3], test_case_data)) tags = reduce(vcat, map(x -> x[4], test_case_data)) GC.@preserve memory begin return map(enumerate(test_cases)) do (n, args) - @info "$n / $(length(test_cases))", _typeof(args) suite = Dict() @@ -186,7 +187,7 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo () -> primals, primals -> (primals[1], _deepcopy(primals[2:end])), (a -> a[1]((a[2]...))), - _ -> true, + _ -> true; evals=1, ) @@ -199,17 +200,19 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo () -> (rule, coduals), identity, a -> to_benchmark(a[1], a[2]...), - _ -> true, + _ -> true; evals=1, ) if include_other_frameworks - if should_run_benchmark(Val(:zygote), args...) @info "Zygote" suite["zygote"] = @be( - _, _, zygote_to_benchmark($(Zygote.Context()), $primals...), _, - evals=1, + _, + _, + zygote_to_benchmark($(Zygote.Context()), $primals...), + _, + evals = 1, ) end @@ -219,8 +222,11 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo compiled_tape = ReverseDiff.compile(tape) result = map(x -> randn(size(x)), primals[2:end]) suite["rd"] = @be( - _, _, rd_to_benchmark!($result, $compiled_tape, $primals[2:end]), _, - evals=1, + _, + _, + rd_to_benchmark!($result, $compiled_tape, $primals[2:end]), + _, + evals = 1, ) end @@ -228,12 +234,15 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo @info "Enzyme" dup_args = map(x -> Duplicated(x, randn(size(x))), primals[2:end]) suite["enzyme"] = @be( - _, _, autodiff(Reverse, $primals[1], Active, $dup_args...), _, - evals=1, + _, + _, + autodiff(Reverse, $primals[1], Active, $dup_args...), + _, + evals = 1, ) end end - + return combine_results((args, suite), tags[n], ranges[n], default_ratios) end end @@ -319,7 +328,7 @@ well-suited to the numbers typically found in this field. function plot_ratio_histogram!(df::DataFrame) bin = 10.0 .^ (-1.0:0.05:4.0) xlim = extrema(bin) - histogram(df.Mooncake; xscale=:log10, xlim, bin, title="log", label="") + return histogram(df.Mooncake; xscale=:log10, xlim, bin, title="log", label="") end function create_inter_ad_benchmarks() @@ -328,7 +337,7 @@ function create_inter_ad_benchmarks() df = DataFrame(results)[:, [:tag, tools...]] # Plot graph of results. - plt = plot(yscale=:log10, legend=:topright, title="AD Time / Primal Time (Log Scale)") + plt = plot(; yscale=:log10, legend=:topright, title="AD Time / Primal Time (Log Scale)") for label in string.(tools) plot!(plt, df.tag, df[:, label]; label, marker=:circle, xrotation=45) end @@ -337,7 +346,9 @@ function create_inter_ad_benchmarks() # Write table of results. formatted_cols = map(t -> t => string.(round.(df[:, t]; sigdigits=3)), tools) df_formatted = DataFrame(:Label => df.tag, formatted_cols...) - open(io -> pretty_table(io, df_formatted), "bench/benchmark_results.txt"; write=true) + return open( + io -> pretty_table(io, df_formatted), "bench/benchmark_results.txt"; write=true + ) end function main() diff --git a/docs/make.jl b/docs/make.jl index c24fbea32..c6a827a3d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -9,26 +9,18 @@ DocMeta.setdocmeta!( recursive=true, ) -makedocs( +makedocs(; sitename="Mooncake.jl", format=Documenter.HTML(; - mathengine = Documenter.KaTeX( - Dict( - :macros => Dict( - "\\RR" => "\\mathbb{R}", - ), - ) - ), + mathengine=Documenter.KaTeX(Dict(:macros => Dict("\\RR" => "\\mathbb{R}"))), size_threshold_ignore=[ - joinpath("developer_documentation", "internal_docstrings.md"), + joinpath("developer_documentation", "internal_docstrings.md") ], ), modules=[Mooncake], checkdocs=:none, - plugins=[ - CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"); style=:numeric), - ], - pages = [ + plugins=[CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"); style=:numeric)], + pages=[ "Mooncake.jl" => "index.md", "Understanding Mooncake.jl" => [ joinpath("understanding_mooncake", "introduction.md"), @@ -46,7 +38,7 @@ makedocs( joinpath("developer_documentation", "internal_docstrings.md"), ], "known_limitations.md", - ] + ], ) -deploydocs(repo="github.com/compintell/Mooncake.jl.git", push_preview=true) +deploydocs(; repo="github.com/compintell/Mooncake.jl.git", push_preview=true) diff --git a/ext/MooncakeAllocCheckExt.jl b/ext/MooncakeAllocCheckExt.jl index 9ede7467c..e225e5460 100644 --- a/ext/MooncakeAllocCheckExt.jl +++ b/ext/MooncakeAllocCheckExt.jl @@ -3,6 +3,6 @@ module MooncakeAllocCheckExt using AllocCheck, Mooncake import Mooncake.TestUtils: check_allocs, Shim -@check_allocs check_allocs(::Shim, f::F, x::Tuple{Vararg{Any, N}}) where {F, N} = f(x...) +@check_allocs check_allocs(::Shim, f::F, x::Tuple{Vararg{Any,N}}) where {F,N} = f(x...) end diff --git a/ext/MooncakeCUDAExt.jl b/ext/MooncakeCUDAExt.jl index b201c0574..c8ea194ed 100644 --- a/ext/MooncakeCUDAExt.jl +++ b/ext/MooncakeCUDAExt.jl @@ -38,7 +38,7 @@ set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0 _add_to_primal(x::P, y::P, ::Bool) where {P<:CuArray{<:IEEEFloat}} = x + y _diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y _dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y)) -_scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y +_scale(x::Float64, y::P) where {T<:IEEEFloat,P<:CuArray{T}} = T(x) * y function populate_address_map!(m::AddressMap, p::CuArray, t::CuArray) k = pointer_from_objref(p) v = pointer_from_objref(t) @@ -55,9 +55,7 @@ end # Basic rules for operating on CuArrays. -@is_primitive( - MinimalCtx, Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N}, -) +@is_primitive(MinimalCtx, Tuple{Type{<:CuArray},UndefInitializer,Vararg{Int,N}} where {N},) function rrule!!( p::CoDual{Type{P}}, init::CoDual{UndefInitializer}, dims::CoDual{Int}... ) where {P<:CuArray{<:Base.IEEEFloat}} diff --git a/ext/MooncakeDynamicPPLExt.jl b/ext/MooncakeDynamicPPLExt.jl index c8184728e..fed0970bd 100644 --- a/ext/MooncakeDynamicPPLExt.jl +++ b/ext/MooncakeDynamicPPLExt.jl @@ -4,6 +4,6 @@ using DynamicPPL: DynamicPPL, istrans using Mooncake: Mooncake # This is purely an optimisation. -Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans), Vararg} +Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} end # module diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl index b02277974..d3c53f134 100644 --- a/ext/MooncakeLuxLibExt.jl +++ b/ext/MooncakeLuxLibExt.jl @@ -5,20 +5,16 @@ using Base: IEEEFloat import LuxLib: Impl import LuxLib.Utils: static_training_mode_check -import Mooncake: - @from_rrule, - DefaultCtx, - @mooncake_overlay, - CoDual +import Mooncake: @from_rrule, DefaultCtx, @mooncake_overlay, CoDual -@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) +@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul),Array{P},Array{P}} where {P<:IEEEFloat}) @from_rrule( DefaultCtx, - Tuple{typeof(Impl.matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, + Tuple{typeof(Impl.matmuladd),Array{P},Array{P},Vector{P}} where {P<:IEEEFloat}, ) @from_rrule( DefaultCtx, - Tuple{typeof(Impl.batched_matmul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, + Tuple{typeof(Impl.batched_matmul),Array{P,3},Array{P,3}} where {P<:IEEEFloat}, ) # Re-implement a bunch of methods to ensure that Mooncake can differentiate them. @@ -35,15 +31,15 @@ end @mooncake_overlay function LuxLib.Impl.fused_conv( ::LuxLib.Impl.AbstractInternalArrayOpMode, act::F, - weight::AbstractArray{wT, N}, - x::AbstractArray{xT, N}, + weight::AbstractArray{wT,N}, + x::AbstractArray{xT,N}, bias::LuxLib.Optional{<:AbstractVector}, cdims::LuxLib.Impl.ConvDims, -) where {F, wT, xT, N} +) where {F,wT,xT,N} return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias) end -Mooncake.@zero_adjoint DefaultCtx Tuple{typeof(static_training_mode_check), Vararg} +Mooncake.@zero_adjoint DefaultCtx Tuple{typeof(static_training_mode_check),Vararg} # This is a really horrible hack that we need to do until Mooncake is able to support the # call-back-into-ad interface that ChainRules exposes. @@ -61,18 +57,18 @@ function CRC.rrule( ::typeof(batchnorm_affine_normalize_internal), opmode::AbstractInternalArrayOpMode, ::typeof(identity), - x::AbstractArray{T, N}, + x::AbstractArray{T,N}, μ::AbstractVector, σ²::AbstractVector, γ::LuxLib.Optional{<:AbstractVector}, β::LuxLib.Optional{<:AbstractVector}, ϵ::Real, -) where {T, N} +) where {T,N} y = similar( x, promote_type( safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β) - ) + ), ) γ′ = similar( x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1) @@ -111,13 +107,13 @@ end @mooncake_overlay function batchnorm_affine_normalize_internal( opmode::LuxLib.AbstractInternalArrayOpMode, act::F, - x::AbstractArray{xT, 3}, + x::AbstractArray{xT,3}, μ::AbstractVector, σ²::AbstractVector, - γ::Union{Nothing, AbstractVector}, - β::Union{Nothing, AbstractVector}, + γ::Union{Nothing,AbstractVector}, + β::Union{Nothing,AbstractVector}, ϵ::Real, -) where {F, xT} +) where {F,xT} y = batchnorm_affine_normalize_internal(opmode, identity, x, μ, σ², γ, β, ϵ) LuxLib.Impl.activation!(y, opmode, act, y) return y @@ -126,17 +122,18 @@ end @mooncake_overlay function batchnorm_affine_normalize_internal( opmode::LuxLib.AbstractInternalArrayOpMode, ::typeof(identity), - x::AbstractArray{xT, 3}, + x::AbstractArray{xT,3}, μ::AbstractVector, σ²::AbstractVector, - γ::Union{Nothing, AbstractVector}, - β::Union{Nothing, AbstractVector}, + γ::Union{Nothing,AbstractVector}, + β::Union{Nothing,AbstractVector}, ϵ::Real, ) where {xT} - y = similar(x, + y = similar( + x, promote_type( safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β) - ) + ), ) batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ) return y diff --git a/ext/MooncakeLuxLibSLEEFPiratesExtension.jl b/ext/MooncakeLuxLibSLEEFPiratesExtension.jl index 97f55835f..91fc37c70 100644 --- a/ext/MooncakeLuxLibSLEEFPiratesExtension.jl +++ b/ext/MooncakeLuxLibSLEEFPiratesExtension.jl @@ -6,11 +6,34 @@ using Mooncake: @from_rrule, DefaultCtx @static if VERSION >= v"1.11" -# Workaround for package load order problems. See -# https://github.com/JuliaLang/julia/issues/56204#issuecomment-2419553167 for more context. -function __init__() - Base.generating_output() && return nothing + # Workaround for package load order problems. See + # https://github.com/JuliaLang/julia/issues/56204#issuecomment-2419553167 for more context. + function __init__() + Base.generating_output() && return nothing + + for f in Any[ + LuxLib.NNlib.sigmoid_fast, + LuxLib.NNlib.softplus, + LuxLib.NNlib.logsigmoid, + LuxLib.NNlib.swish, + LuxLib.NNlib.lisht, + Base.tanh, + LuxLib.NNlib.tanh_fast, + ] + f_fast = LuxLib.Impl.sleefpirates_fast_act(f) + @eval @from_rrule DefaultCtx Tuple{typeof($f_fast),IEEEFloat} + @eval @from_rrule( + DefaultCtx, + Tuple{ + typeof(Broadcast.broadcasted), + typeof($f_fast), + Union{IEEEFloat,Array{<:IEEEFloat}}, + }, + ) + end + end +else for f in Any[ LuxLib.NNlib.sigmoid_fast, LuxLib.NNlib.softplus, @@ -21,41 +44,16 @@ function __init__() LuxLib.NNlib.tanh_fast, ] f_fast = LuxLib.Impl.sleefpirates_fast_act(f) - @eval @from_rrule DefaultCtx Tuple{typeof($f_fast), IEEEFloat} + @eval @from_rrule DefaultCtx Tuple{typeof($f_fast),IEEEFloat} @eval @from_rrule( DefaultCtx, Tuple{ typeof(Broadcast.broadcasted), typeof($f_fast), - Union{IEEEFloat, Array{<:IEEEFloat}}, + Union{IEEEFloat,Array{<:IEEEFloat}}, }, ) end end -else - -for f in Any[ - LuxLib.NNlib.sigmoid_fast, - LuxLib.NNlib.softplus, - LuxLib.NNlib.logsigmoid, - LuxLib.NNlib.swish, - LuxLib.NNlib.lisht, - Base.tanh, - LuxLib.NNlib.tanh_fast, -] - f_fast = LuxLib.Impl.sleefpirates_fast_act(f) - @eval @from_rrule DefaultCtx Tuple{typeof($f_fast), IEEEFloat} - @eval @from_rrule( - DefaultCtx, - Tuple{ - typeof(Broadcast.broadcasted), - typeof($f_fast), - Union{IEEEFloat, Array{<:IEEEFloat}}, - }, - ) -end - -end - end diff --git a/ext/MooncakeNNlibExt.jl b/ext/MooncakeNNlibExt.jl index 5fedbe7b4..d535ce21e 100644 --- a/ext/MooncakeNNlibExt.jl +++ b/ext/MooncakeNNlibExt.jl @@ -8,59 +8,47 @@ using NNlib: conv, depthwiseconv import Mooncake: @from_rrule, DefaultCtx, MinimalCtx @from_rrule( - MinimalCtx, - Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, + MinimalCtx, Tuple{typeof(batched_mul),Array{P,3},Array{P,3}} where {P<:IEEEFloat}, ) @from_rrule( - MinimalCtx, - Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat}, - true, + MinimalCtx, Tuple{typeof(dropout),AbstractRNG,Array{P},P} where {P<:IEEEFloat}, true, ) -@from_rrule(MinimalCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) -@from_rrule(MinimalCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true) -@from_rrule(MinimalCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) +@from_rrule(MinimalCtx, Tuple{typeof(softmax),Array{<:IEEEFloat}}, true) +@from_rrule(MinimalCtx, Tuple{typeof(logsoftmax),Array{<:IEEEFloat}}, true) +@from_rrule(MinimalCtx, Tuple{typeof(logsumexp),Array{<:IEEEFloat}}, true) @from_rrule( - MinimalCtx, - Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, + MinimalCtx, Tuple{typeof(upsample_nearest),Array{<:IEEEFloat},NTuple{N,Int} where {N}}, ) @from_rrule( MinimalCtx, - Tuple{ - typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims, - }, -) -@from_rrule( - MinimalCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} + Tuple{typeof(NNlib.fold),Array{<:IEEEFloat},NTuple{N,Int} where {N},DenseConvDims}, ) +@from_rrule(MinimalCtx, Tuple{typeof(NNlib.unfold),Array{<:IEEEFloat},DenseConvDims}) @from_rrule( - MinimalCtx, - Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, - true, + MinimalCtx, Tuple{typeof(NNlib.scatter),Any,Array,Array{<:Union{Integer,Tuple}}}, true, ) for conv in [:conv, :depthwiseconv] local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter]) @eval @from_rrule( MinimalCtx, - Tuple{typeof($conv), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + Tuple{typeof($conv),Array{P},Array{P},ConvDims} where {P<:IEEEFloat}, true, ) @eval @from_rrule( MinimalCtx, - Tuple{typeof($∇conv_data), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + Tuple{typeof($∇conv_data),Array{P},Array{P},ConvDims} where {P<:IEEEFloat}, true, ) end @from_rrule( MinimalCtx, - Tuple{typeof(∇conv_filter), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + Tuple{typeof(∇conv_filter),Array{P},Array{P},ConvDims} where {P<:IEEEFloat}, true, ) for pool in [:maxpool, :meanpool] - @eval @from_rrule( - MinimalCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true - ) + @eval @from_rrule(MinimalCtx, Tuple{typeof($pool),Array{<:IEEEFloat},PoolDims}, true) end -@from_rrule(MinimalCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) +@from_rrule(MinimalCtx, Tuple{typeof(pad_constant),Array,Any,Any}, true) end diff --git a/ext/MooncakeSpecialFunctionsExt.jl b/ext/MooncakeSpecialFunctionsExt.jl index dc6fd1b0f..65806ae83 100644 --- a/ext/MooncakeSpecialFunctionsExt.jl +++ b/ext/MooncakeSpecialFunctionsExt.jl @@ -5,42 +5,42 @@ using Base: IEEEFloat import Mooncake: @from_rrule, DefaultCtx, @zero_adjoint -@from_rrule DefaultCtx Tuple{typeof(airyai), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(airyaix), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(airyaiprime), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(airybi), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(airybiprime), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(besselj0), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(besselj1), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(bessely0), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(bessely1), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(dawson), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(digamma), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erf), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erf), IEEEFloat, IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfc), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(logerfc), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfcinv), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfcx), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(logerfcx), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfi), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(erfinv), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(gamma), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(invdigamma), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(trigamma), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(polygamma), Integer, IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(beta), IEEEFloat, IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(logbeta), IEEEFloat, IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(logabsgamma), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(loggamma), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(expint), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(expintx), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(expinti), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(sinint), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(cosint), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(ellipk), IEEEFloat} -@from_rrule DefaultCtx Tuple{typeof(ellipe), IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(airyai),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(airyaix),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(airyaiprime),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(airybi),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(airybiprime),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(besselj0),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(besselj1),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(bessely0),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(bessely1),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(dawson),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(digamma),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erf),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erf),IEEEFloat,IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfc),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(logerfc),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfcinv),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfcx),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(logerfcx),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfi),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(erfinv),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(gamma),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(invdigamma),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(trigamma),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(polygamma),Integer,IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(beta),IEEEFloat,IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(logbeta),IEEEFloat,IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(logabsgamma),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(loggamma),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(expint),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(expintx),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(expinti),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(sinint),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(cosint),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(ellipk),IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(ellipe),IEEEFloat} -@zero_adjoint DefaultCtx Tuple{typeof(logfactorial), Integer} +@zero_adjoint DefaultCtx Tuple{typeof(logfactorial),Integer} end diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 5ee8f9e50..cd268f166 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -2,8 +2,7 @@ module Mooncake const CC = Core.Compiler -using - ADTypes, +using ADTypes, ChainRules, DiffRules, ExprTools, @@ -18,13 +17,30 @@ using import ChainRulesCore as CRC using Base: - IEEEFloat, unsafe_convert, unsafe_pointer_to_objref, pointer_from_objref, arrayref, - arrayset, TwicePrecision, twiceprecision + IEEEFloat, + unsafe_convert, + unsafe_pointer_to_objref, + pointer_from_objref, + arrayref, + arrayset, + TwicePrecision, + twiceprecision using Base.Experimental: @opaque using Base.Iterators: product using Core: - Intrinsics, bitcast, SimpleVector, svec, ReturnNode, GotoNode, GotoIfNot, PhiNode, - PiNode, SSAValue, Argument, OpaqueClosure, compilerbarrier + Intrinsics, + bitcast, + SimpleVector, + svec, + ReturnNode, + GotoNode, + GotoIfNot, + PhiNode, + PiNode, + SSAValue, + Argument, + OpaqueClosure, + compilerbarrier using Core.Compiler: IRCode, NewInstruction using Core.Intrinsics: pointerref, pointerset using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm! @@ -102,8 +118,7 @@ include("interface.jl") include("config.jl") include("developer_tools.jl") -export - primal, +export primal, tangent, randn_tangent, increment!!, diff --git a/src/codual.jl b/src/codual.jl index 6a7efeb22..f440c239e 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -1,15 +1,15 @@ -struct CoDual{Tx, Tdx} +struct CoDual{Tx,Tdx} x::Tx dx::Tdx end # Always sharpen the first thing if it's a type so static dispatch remains possible. function CoDual(x::Type{P}, dx::NoFData) where {P} - return CoDual{@isdefined(P) ? Type{P} : typeof(x), NoFData}(P, dx) + return CoDual{@isdefined(P) ? Type{P} : typeof(x),NoFData}(P, dx) end function CoDual(x::Type{P}, dx::NoTangent) where {P} - return CoDual{@isdefined(P) ? Type{P} : typeof(x), NoTangent}(P, dx) + return CoDual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(P, dx) end primal(x::CoDual) = x.x @@ -38,13 +38,13 @@ The type of the `CoDual` which contains instances of `P` and associated tangents """ function codual_type(::Type{P}) where {P} P == DataType && return CoDual - P isa Union && return Union{codual_type(P.a), codual_type(P.b)} + P isa Union && return Union{codual_type(P.a),codual_type(P.b)} P <: UnionAll && return CoDual # P is abstract, so we don't know its tangent type. - return isconcretetype(P) ? CoDual{P, tangent_type(P)} : CoDual + return isconcretetype(P) ? CoDual{P,tangent_type(P)} : CoDual end function codual_type(p::Type{Type{P}}) where {P} - return @isdefined(P) ? CoDual{Type{P}, NoTangent} : CoDual{_typeof(p), NoTangent} + return @isdefined(P) ? CoDual{Type{P},NoTangent} : CoDual{_typeof(p),NoTangent} end struct NoPullback{R<:Tuple} @@ -66,7 +66,7 @@ for each of the arguments lazily, the `NoPullback` generated will be a singleton means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements. """ -function NoPullback(args::Vararg{CoDual, N}) where {N} +function NoPullback(args::Vararg{CoDual,N}) where {N} return NoPullback(tuple_map(lazy_zero_rdata ∘ primal, args)) end @@ -74,7 +74,7 @@ end to_fwds(x::CoDual) = CoDual(primal(x), fdata(tangent(x))) -to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P}, NoFData}(primal(x), NoFData()) +to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P},NoFData}(primal(x), NoFData()) zero_fcodual(p) = to_fwds(zero_codual(p)) @@ -93,11 +93,11 @@ The type of the `CoDual` which contains instances of `P` and its fdata. """ function fcodual_type(::Type{P}) where {P} P == DataType && return CoDual - P isa Union && return Union{fcodual_type(P.a), fcodual_type(P.b)} + P isa Union && return Union{fcodual_type(P.a),fcodual_type(P.b)} P <: UnionAll && return CoDual - return isconcretetype(P) ? CoDual{P, fdata_type(tangent_type(P))} : CoDual + return isconcretetype(P) ? CoDual{P,fdata_type(tangent_type(P))} : CoDual end function fcodual_type(p::Type{Type{P}}) where {P} - return @isdefined(P) ? CoDual{Type{P}, NoFData} : CoDual{_typeof(p), NoFData} + return @isdefined(P) ? CoDual{Type{P},NoFData} : CoDual{_typeof(p),NoFData} end diff --git a/src/config.jl b/src/config.jl index d4371e1e0..7db7faa42 100644 --- a/src/config.jl +++ b/src/config.jl @@ -4,6 +4,6 @@ Configuration struct for use with ADTypes.AutoMooncake. """ @kwdef struct Config - debug_mode::Bool=false - silence_debug_messages::Bool=false + debug_mode::Bool = false + silence_debug_messages::Bool = false end diff --git a/src/debug_mode.jl b/src/debug_mode.jl index 25342afac..d51e1687a 100644 --- a/src/debug_mode.jl +++ b/src/debug_mode.jl @@ -9,7 +9,7 @@ post-conditions to `pb`. Let `dx = pb.pb(dy)`, for some rdata `dy`, then this fu Reverse pass counterpart to [`DebugRRule`](@ref) """ -struct DebugPullback{Tpb, Ty, Tx} +struct DebugPullback{Tpb,Ty,Tx} pb::Tpb y::Ty x::Tx @@ -84,7 +84,7 @@ _copy(x::P) where {P<:DebugRRule} = P(_copy(x.rule)) Apply type checking to enforce pre- and post-conditions on `rule.rule`. See the docstring for `DebugRRule` for details. """ -@noinline function (rule::DebugRRule)(x::Vararg{CoDual, N}) where {N} +@noinline function (rule::DebugRRule)(x::Vararg{CoDual,N}) where {N} verify_fwds_inputs(rule.rule, x) y, pb = rule.rule(x...) verify_fwds_output(x, y) diff --git a/src/developer_tools.jl b/src/developer_tools.jl index 9c977e7c9..14a83c68e 100644 --- a/src/developer_tools.jl +++ b/src/developer_tools.jl @@ -54,7 +54,9 @@ true """ function fwd_ir( sig::Type{<:Tuple}; - interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true + interp=get_interpreter(), + debug_mode::Bool=false, + do_inline::Bool=true, )::IRCode return generate_ir(interp, sig; debug_mode, do_inline).fwd_ir end @@ -93,7 +95,9 @@ true """ function rvs_ir( sig::Type{<:Tuple}; - interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true + interp=get_interpreter(), + debug_mode::Bool=false, + do_inline::Bool=true, )::IRCode return generate_ir(interp, sig; debug_mode, do_inline).rvs_ir end diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index ade6a4725..d2c5adadf 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -166,12 +166,11 @@ end T == NoTangent && return NoFData # This method can only handle struct types. Tell user to implement their own method. - isprimitivetype(T) && throw(error( - "$T is a primitive type. Implement a method of `fdata_type` for it." - )) + isprimitivetype(T) && + throw(error("$T is a primitive type. Implement a method of `fdata_type` for it.")) # If the type is a Union, then take the union type of its arguments. - T isa Union && return Union{fdata_type(T.a), fdata_type(T.b)} + T isa Union && return Union{fdata_type(T.a),fdata_type(T.b)} # If `P` is a mutable type, then its forwards data is its tangent. ismutabletype(T) && return T @@ -188,7 +187,7 @@ end return fdata_type(fieldtype(Tfields, n)) end all(==(NoFData), fwds_data_field_types) && return NoFData - return FData{NamedTuple{fieldnames(Tfields), Tuple{fwds_data_field_types...}}} + return FData{NamedTuple{fieldnames(Tfields),Tuple{fwds_data_field_types...}}} end return :(error("Unhandled type $T")) @@ -197,20 +196,20 @@ end fdata_type(::Type{T}) where {T<:Ptr} = T @generated function fdata_type(::Type{P}) where {P<:Tuple} - isa(P, Union) && return Union{fdata_type(P.a), fdata_type(P.b)} + isa(P, Union) && return Union{fdata_type(P.a),fdata_type(P.b)} isempty(P.parameters) && return NoFData isa(last(P.parameters), Core.TypeofVararg) && return Any - nofdata_tt = Tuple{Vararg{NoFData, length(P.parameters)}} + nofdata_tt = Tuple{Vararg{NoFData,length(P.parameters)}} fdata_tt = Tuple{map(fdata_type, fieldtypes(P))...} fdata_tt <: nofdata_tt && return NoFData - return nofdata_tt <: fdata_tt ? Union{NoFData, fdata_tt} : fdata_tt + return nofdata_tt <: fdata_tt ? Union{NoFData,fdata_tt} : fdata_tt end -@generated function fdata_type(::Type{NamedTuple{names, T}}) where {names, T<:Tuple} +@generated function fdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple} if fdata_type(T) == NoFData return NoFData elseif isconcretetype(fdata_type(T)) - return NamedTuple{names, fdata_type(T)} + return NamedTuple{names,fdata_type(T)} else return Any end @@ -254,7 +253,7 @@ function fdata(t::T) where {T<:PossiblyUninitTangent} return is_init(t) ? F(fdata(val(t))) : F() end -@generated function fdata(t::Union{Tuple, NamedTuple}) +@generated function fdata(t::Union{Tuple,NamedTuple}) fdata_type(t) == NoFData && return NoFData() return :(tuple_map(fdata, t)) end @@ -299,7 +298,7 @@ invalid. """ function verify_fdata_value(p, f)::Nothing verify_fdata_type(_typeof(p), typeof(f)) - _verify_fdata_value(p, f) + return _verify_fdata_value(p, f) end _verify_fdata_value(::IEEEFloat, ::NoFData) = nothing @@ -385,21 +384,21 @@ fields_type(::Type{RData{T}}) where {T<:NamedTuple} = T @inline increment!!(x::RData{T}, y::RData{T}) where {T} = RData(increment!!(x.data, y.data)) -@inline function increment_field!!(x::RData{T}, y, ::Val{f}) where {T, f} +@inline function increment_field!!(x::RData{T}, y, ::Val{f}) where {T,f} y isa NoRData && return x new_val = fieldtype(T, f) <: PossiblyUninitTangent ? fieldtype(T, f)(y) : y return RData(increment_field!!(x.data, new_val, Val(f))) end -@doc""" - rdata_type(T) +@doc """ + rdata_type(T) -Returns the type of the reverse data of a tangent of type T. + Returns the type of the reverse data of a tangent of type T. -# Extended help + # Extended help -See extended help in [`fdata_type`](@ref) docstring. -""" + See extended help in [`fdata_type`](@ref) docstring. + """ rdata_type(T) rdata_type(x) = throw(error("$x is not a type. Perhaps you meant typeof(x)?")) @@ -416,12 +415,11 @@ end T == NoTangent && return NoRData # This method can only handle struct types. Tell user to implement their own method. - isprimitivetype(T) && throw(error( - "$T is a primitive type. Implement a method of `rdata_type` for it." - )) + isprimitivetype(T) && + throw(error("$T is a primitive type. Implement a method of `rdata_type` for it.")) # If the type is a Union, then take the union type of its arguments. - T isa Union && return Union{rdata_type(T.a), rdata_type(T.b)} + T isa Union && return Union{rdata_type(T.a),rdata_type(T.b)} # If `P` is a mutable type, then all tangent info is propagated on the forwards-pass. ismutabletype(T) && return NoRData @@ -436,27 +434,27 @@ end Tfs = fields_type(T) rvs_types = map(n -> rdata_type(fieldtype(Tfs, n)), 1:fieldcount(Tfs)) all(==(NoRData), rvs_types) && return NoRData - return RData{NamedTuple{fieldnames(Tfs), Tuple{rvs_types...}}} + return RData{NamedTuple{fieldnames(Tfs),Tuple{rvs_types...}}} end end rdata_type(::Type{<:Ptr}) = NoRData @generated function rdata_type(::Type{P}) where {P<:Tuple} - isa(P, Union) && return Union{rdata_type(P.a), rdata_type(P.b)} + isa(P, Union) && return Union{rdata_type(P.a),rdata_type(P.b)} isempty(P.parameters) && return NoRData isa(last(P.parameters), Core.TypeofVararg) && return Any - nordata_tt = Tuple{Vararg{NoRData, length(P.parameters)}} + nordata_tt = Tuple{Vararg{NoRData,length(P.parameters)}} rdata_tt = Tuple{map(rdata_type, fieldtypes(P))...} rdata_tt <: nordata_tt && return NoRData - return nordata_tt <: rdata_tt ? Union{NoRData, rdata_tt} : rdata_tt + return nordata_tt <: rdata_tt ? Union{NoRData,rdata_tt} : rdata_tt end -function rdata_type(::Type{NamedTuple{names, T}}) where {names, T<:Tuple} +function rdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple} if rdata_type(T) == NoRData return NoRData elseif isconcretetype(rdata_type(T)) - return NamedTuple{names, rdata_type(T)} + return NamedTuple{names,rdata_type(T)} else return Any end @@ -503,7 +501,7 @@ function rdata(t::T) where {T<:PossiblyUninitTangent} return is_init(t) ? R(rdata(val(t))) : R() end -@generated function rdata(t::Union{Tuple, NamedTuple}) +@generated function rdata(t::Union{Tuple,NamedTuple}) rdata_type(t) == NoRData && return NoRData() return :(tuple_map(rdata, t)) end @@ -511,7 +509,7 @@ end function rdata_backing_type(::Type{P}) where {P} rdata_field_types = map(n -> rdata_field_type(P, n), 1:fieldcount(P)) all(==(NoRData), rdata_field_types) && return NoRData - return NamedTuple{fieldnames(P), Tuple{rdata_field_types...}} + return NamedTuple{fieldnames(P),Tuple{rdata_field_types...}} end """ @@ -547,7 +545,7 @@ zero_rdata(p::IEEEFloat) = zero(p) return Expr(:call, R, backing_expr) end -@generated function zero_rdata(p::Union{Tuple, NamedTuple}) +@generated function zero_rdata(p::Union{Tuple,NamedTuple}) rdata_type(tangent_type(p)) == NoRData && return NoRData() return :(tuple_map(zero_rdata, p)) end @@ -694,7 +692,7 @@ invalid. function verify_rdata_value(p, r)::Nothing r isa ZeroRData && return nothing verify_rdata_type(_typeof(p), typeof(r)) - _verify_rdata_value(p, r) + return _verify_rdata_value(p, r) end _verify_rdata_value(::P, ::P) where {P<:IEEEFloat} = nothing @@ -729,7 +727,6 @@ function _verify_rdata_value(p, r) return nothing end - """ LazyZeroRData{P, Tdata}() @@ -743,7 +740,7 @@ be stored. For example, `Float64`s do not need any data, so `LazyZeroRData(0.0)` an instance of a singleton type, meaning that various important optimisations can be performed in AD. """ -struct LazyZeroRData{P, Tdata} +struct LazyZeroRData{P,Tdata} data::Tdata end @@ -752,14 +749,14 @@ _copy(x::P) where {P<:LazyZeroRData} = P(_copy(x.data)) # Returns the type which must be output by LazyZeroRData whenever it is passed a `P`. @inline function lazy_zero_rdata_type(::Type{P}) where {P} Tdata = can_produce_zero_rdata_from_type(P) ? Nothing : rdata_type(tangent_type(P)) - return LazyZeroRData{P, Tdata} + return LazyZeroRData{P,Tdata} end # Be lazy if we can compute the zero element given only the type, otherwise just store the # zero element and use it later. L is the precise type of `LazyZeroRData` that you wish to # construct -- very occassionally you need complete control over this, but don't want to # figure out for yourself whether or not construction can be performed lazily. -@inline function lazy_zero_rdata(::Type{L}, p::P) where {S, L<:LazyZeroRData{S}, P} +@inline function lazy_zero_rdata(::Type{L}, p::P) where {S,L<:LazyZeroRData{S},P} return L(can_produce_zero_rdata_from_type(S) ? nothing : zero_rdata(p)) end @@ -769,10 +766,10 @@ end # Ensure proper specialisation on types. @inline function lazy_zero_rdata(p::Type{P}) where {P} Rtype = @isdefined(P) ? Type{P} : _typeof(p) - return LazyZeroRData{Rtype, Nothing}(nothing) + return LazyZeroRData{Rtype,Nothing}(nothing) end -@inline instantiate(::LazyZeroRData{P, Nothing}) where {P} = zero_rdata_from_type(P) +@inline instantiate(::LazyZeroRData{P,Nothing}) where {P} = zero_rdata_from_type(P) @inline instantiate(r::LazyZeroRData) = r.data @inline instantiate(::NoRData) = NoRData() @@ -787,7 +784,7 @@ tangent_type(::Type{NoFData}, ::Type{R}) where {R<:IEEEFloat} = R tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Array} = F # Tuples -function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple, R<:Tuple} +function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple,R<:Tuple} return Tuple{tuple_map(tangent_type, Tuple(F.parameters), Tuple(R.parameters))...} end function tangent_type(::Type{NoFData}, ::Type{R}) where {R<:Tuple} @@ -800,22 +797,22 @@ function tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Tuple} end # NamedTuples -function tangent_type(::Type{F}, ::Type{R}) where {ns, F<:NamedTuple{ns}, R<:NamedTuple{ns}} - return NamedTuple{ns, tangent_type(tuple_type(F), tuple_type(R))} +function tangent_type(::Type{F}, ::Type{R}) where {ns,F<:NamedTuple{ns},R<:NamedTuple{ns}} + return NamedTuple{ns,tangent_type(tuple_type(F), tuple_type(R))} end -function tangent_type(::Type{NoFData}, ::Type{R}) where {ns, R<:NamedTuple{ns}} - return NamedTuple{ns, tangent_type(NoFData, tuple_type(R))} +function tangent_type(::Type{NoFData}, ::Type{R}) where {ns,R<:NamedTuple{ns}} + return NamedTuple{ns,tangent_type(NoFData, tuple_type(R))} end -function tangent_type(::Type{F}, ::Type{NoRData}) where {ns, F<:NamedTuple{ns}} - return NamedTuple{ns, tangent_type(tuple_type(F), NoRData)} +function tangent_type(::Type{F}, ::Type{NoRData}) where {ns,F<:NamedTuple{ns}} + return NamedTuple{ns,tangent_type(tuple_type(F), NoRData)} end -tuple_type(::Type{<:NamedTuple{<:Any, T}}) where {T<:Tuple} = T +tuple_type(::Type{<:NamedTuple{<:Any,T}}) where {T<:Tuple} = T # mutable structs tangent_type(::Type{F}, ::Type{NoRData}) where {F<:MutableTangent} = F # structs -function tangent_type(::Type{F}, ::Type{R}) where {F<:FData, R<:RData} +function tangent_type(::Type{F}, ::Type{R}) where {F<:FData,R<:RData} return Tangent{tangent_type(fields_type(F), fields_type(R))} end function tangent_type(::Type{NoFData}, ::Type{R}) where {R<:RData} @@ -827,14 +824,13 @@ end function tangent_type( ::Type{PossiblyUninitTangent{F}}, ::Type{PossiblyUninitTangent{R}} -) where {F, R} +) where {F,R} return PossiblyUninitTangent{tangent_type(F, R)} end # Abstract types. tangent_type(::Type{Any}, ::Type{Any}) = Any - """ tangent(f, r) @@ -864,7 +860,7 @@ end tangent(f::MutableTangent, r::NoRData) = f # structs -function tangent(f::F, r::R) where {F<:FData, R<:RData} +function tangent(f::F, r::R) where {F<:FData,R<:RData} return tangent_type(F, R)(tangent(f.data, r.data)) end function tangent(::NoFData, r::R) where {R<:RData} @@ -874,7 +870,7 @@ function tangent(f::F, ::NoRData) where {F<:FData} return tangent_type(F, NoRData)(tangent(f.data, NoRData())) end -function tangent(f::PossiblyUninitTangent{F}, r::PossiblyUninitTangent{R}) where {F, R} +function tangent(f::PossiblyUninitTangent{F}, r::PossiblyUninitTangent{R}) where {F,R} T = PossiblyUninitTangent{tangent_type(F, R)} is_init(f) && is_init(r) && return T(tangent(val(f), val(r))) !is_init(f) && !is_init(r) && return T() @@ -907,11 +903,11 @@ Equivalent to `tangent(fdata, rdata(zero_tangent(primal)))`. """ zero_tangent(p, ::NoFData) = zero_tangent(p) -function zero_tangent(p::P, f::F) where {P, F} +function zero_tangent(p::P, f::F) where {P,F} T = tangent_type(P) T == F && return f r = rdata(zero_tangent(p)) return tangent(f, r) end -zero_tangent(p::Tuple, f::Union{Tuple, NamedTuple}) = tuple_map(zero_tangent, p, f) +zero_tangent(p::Tuple, f::Union{Tuple,NamedTuple}) = tuple_map(zero_tangent, p, f) diff --git a/src/interface.jl b/src/interface.jl index 56c02e362..98ecca98f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -8,7 +8,7 @@ In-place version of `value_and_pullback!!` in which the arguments have been wrap if calling this function multiple times with different values of `x`, should be careful to ensure that you zero-out the tangent fields of `x` each time. """ -function __value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R, N, T} +function __value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual,N}) where {R,N,T} fx_fwds = tuple_map(to_fwds, fx) __verify_sig(rule, fx_fwds) out, pb!! = rule(fx_fwds...) @@ -19,8 +19,8 @@ function __value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R end function __verify_sig( - rule::DerivedRule{<:Any, <:MistyClosure{<:OpaqueClosure{sig}}}, fx::Tfx -) where {sig, Tfx} + rule::DerivedRule{<:Any,<:MistyClosure{<:OpaqueClosure{sig}}}, fx::Tfx +) where {sig,Tfx} Pfx = typeof(__unflatten_codual_varargs(rule.isva, fx, rule.nargs)) if sig != Pfx msg = "signature of arguments, $Pfx, not equal to signature required by rule, $sig." @@ -67,16 +67,18 @@ Mooncake.__value_and_gradient!!( (4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0])) ``` """ -function __value_and_gradient!!(rule::R, fx::Vararg{CoDual, N}) where {R, N} +function __value_and_gradient!!(rule::R, fx::Vararg{CoDual,N}) where {R,N} fx_fwds = tuple_map(to_fwds, fx) __verify_sig(rule, fx_fwds) out, pb!! = rule(fx_fwds...) y = primal(out) if !(y isa IEEEFloat) - throw(ValueAndGradientReturnTypeError( - "When calling __value_and_gradient!!, return value of primal must be a " * - "subtype of IEEEFloat. Instead, found value of type $(typeof(y))." - )) + throw( + ValueAndGradientReturnTypeError( + "When calling __value_and_gradient!!, return value of primal must be a " * + "subtype of IEEEFloat. Instead, found value of type $(typeof(y)).", + ), + ) end @assert y isa IEEEFloat @assert tangent(out) isa NoFData @@ -113,7 +115,7 @@ use-case, consider pre-allocating the `CoDual`s and calling the other method of function. The `CoDual`s should be primal-tangent pairs (as opposed to primal-fdata pairs). There are lots of ways to get this wrong though, so we generally advise against doing this. """ -function value_and_pullback!!(rule::R, ȳ, fx::Vararg{Any, N}) where {R, N} +function value_and_pullback!!(rule::R, ȳ, fx::Vararg{Any,N}) where {R,N} return __value_and_pullback!!(rule, ȳ, __create_coduals(fx)...) end @@ -140,7 +142,7 @@ value_and_gradient!!(rule, f, x, y) (4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0])) ``` """ -function value_and_gradient!!(rule::R, fx::Vararg{Any, N}) where {R, N} +function value_and_gradient!!(rule::R, fx::Vararg{Any,N}) where {R,N} return __value_and_gradient!!(rule, __create_coduals(fx)...) end @@ -154,7 +156,7 @@ function __create_coduals(args) "means that Mooncake.jl has encountered a self-referential type. Mooncake.jl " * "is not presently able to handle self-referential types, so if you are " * "indeed using a self-referential type somewhere, you will need to " * - "refactor to avoid it if you wish to use Mooncake.jl." + "refactor to avoid it if you wish to use Mooncake.jl.", ) else rethrow(e) diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index 6442be1d7..bd6c019c0 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -16,10 +16,10 @@ struct ClosureCacheKey end struct MooncakeCache - dict::IdDict{Core.MethodInstance, Core.CodeInstance} + dict::IdDict{Core.MethodInstance,Core.CodeInstance} end -MooncakeCache() = MooncakeCache(IdDict{Core.MethodInstance, Core.CodeInstance}()) +MooncakeCache() = MooncakeCache(IdDict{Core.MethodInstance,Core.CodeInstance}()) # The method table used by `Mooncake.@mooncake_overlay`. Base.Experimental.@MethodTable mooncake_method_table @@ -31,16 +31,16 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter opt_params::CC.OptimizationParams inf_cache::Vector{CC.InferenceResult} code_cache::MooncakeCache - oc_cache::Dict{ClosureCacheKey, Any} + oc_cache::Dict{ClosureCacheKey,Any} function MooncakeInterpreter( ::Type{C}; meta=nothing, world::UInt=Base.get_world_counter(), inf_params::CC.InferenceParams=CC.InferenceParams(), opt_params::CC.OptimizationParams=CC.OptimizationParams(), - inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], + inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], code_cache::MooncakeCache=MooncakeCache(), - oc_cache::Dict{ClosureCacheKey, Any}=Dict{ClosureCacheKey, Any}(), + oc_cache::Dict{ClosureCacheKey,Any}=Dict{ClosureCacheKey,Any}(), ) where {C} return new{C}(meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache) end @@ -54,7 +54,7 @@ end Base.show(io::IO, mc::MooncakeInterpreter) = _show_interp(io, MIME"text/plain"(), mc) function _show_interp(io::IO, ::MIME"text/plain", ::MooncakeInterpreter) - print(io, "MooncakeInterpreter()") + return print(io, "MooncakeInterpreter()") end MooncakeInterpreter() = MooncakeInterpreter(DefaultCtx) @@ -119,7 +119,7 @@ end _type(x::Type) = x _type(x::CC.Const) = _typeof(x.val) _type(x::CC.PartialStruct) = x.typ -_type(x::CC.Conditional) = Union{_type(x.thentype), _type(x.elsetype)} +_type(x::CC.Conditional) = Union{_type(x.thentype),_type(x.elsetype)} _type(::CC.PartialTypeVar) = TypeVar struct NoInlineCallInfo <: CC.CallInfo @@ -161,48 +161,42 @@ function Core.Compiler.abstract_call_gf_by_type( end @static if VERSION < v"1.11-" - -function CC.inlining_policy( - interp::MooncakeInterpreter{C}, - @nospecialize(src), - @nospecialize(info::CC.CallInfo), - stmt_flag::UInt8, - mi::Core.MethodInstance, - argtypes::Vector{Any}, -) where {C} - - # Do not inline away primitives. - info isa NoInlineCallInfo && return nothing - - # If not a primitive, AD doesn't care about it. Use the usual inlining strategy. - return @invoke CC.inlining_policy( - interp::CC.AbstractInterpreter, - src::Any, - info::CC.CallInfo, + function CC.inlining_policy( + interp::MooncakeInterpreter{C}, + @nospecialize(src), + @nospecialize(info::CC.CallInfo), stmt_flag::UInt8, mi::Core.MethodInstance, argtypes::Vector{Any}, - ) -end - -else # 1.11 and up. + ) where {C} -function CC.inlining_policy( - interp::MooncakeInterpreter, - @nospecialize(src), - @nospecialize(info::CC.CallInfo), - stmt_flag::UInt32, -) - # Do not inline away primitives. - info isa NoInlineCallInfo && return nothing + # Do not inline away primitives. + info isa NoInlineCallInfo && return nothing + + # If not a primitive, AD doesn't care about it. Use the usual inlining strategy. + return @invoke CC.inlining_policy( + interp::CC.AbstractInterpreter, + src::Any, + info::CC.CallInfo, + stmt_flag::UInt8, + mi::Core.MethodInstance, + argtypes::Vector{Any}, + ) + end - # If not a primitive, AD doesn't care about it. Use the usual inlining strategy. - return @invoke CC.inlining_policy( - interp::CC.AbstractInterpreter, - src::Any, - info::CC.CallInfo, +else # 1.11 and up. + function CC.inlining_policy( + interp::MooncakeInterpreter, + @nospecialize(src), + @nospecialize(info::CC.CallInfo), stmt_flag::UInt32, ) -end + # Do not inline away primitives. + info isa NoInlineCallInfo && return nothing + # If not a primitive, AD doesn't care about it. Use the usual inlining strategy. + return @invoke CC.inlining_policy( + interp::CC.AbstractInterpreter, src::Any, info::CC.CallInfo, stmt_flag::UInt32 + ) + end end diff --git a/src/interpreter/bbcode.jl b/src/interpreter/bbcode.jl index 4fcd53cd3..46f9920aa 100644 --- a/src/interpreter/bbcode.jl +++ b/src/interpreter/bbcode.jl @@ -1,6 +1,6 @@ # See the docstring for `BBCode` for some context on this file. -const _id_count::Dict{Int, Int32} = Dict{Int, Int32}() +const _id_count::Dict{Int,Int32} = Dict{Int,Int32}() """ ID() @@ -32,7 +32,7 @@ ensure determinism between two runs of the same function which makes use of `ID` This is akin to setting the random seed associated to a random number generator globally. """ function seed_id!() - global _id_count[Threads.threadid()] = 0 + return global _id_count[Threads.threadid()] = 0 end """ @@ -107,7 +107,7 @@ end A Union of the possible types of a terminator node. """ -const Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode} +const Terminator = Union{Switch,IDGotoIfNot,IDGotoNode,ReturnNode} """ BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector) @@ -139,7 +139,7 @@ end """ const IDInstPair = Tuple{ID, NewInstruction} """ -const IDInstPair = Tuple{ID, NewInstruction} +const IDInstPair = Tuple{ID,NewInstruction} """ BBlock(id::ID, inst_pairs::Vector{IDInstPair}) @@ -254,7 +254,7 @@ end Make a new `BBCode` whose `blocks` is given by `new_blocks`, and fresh copies are made of all other fields from `ir`. """ -function BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{BBlock}) +function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock}) return BBCode( new_blocks, CC.copy(ir.argtypes), @@ -272,7 +272,7 @@ Base.copy(ir::BBCode) = BBCode(ir, copy(ir.blocks)) Compute a map from the `ID of each `BBlock` in `ir` to its possible successors. """ -function compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}} +function compute_all_successors(ir::BBCode)::Dict{ID,Vector{ID}} return _compute_all_successors(ir.blocks) end @@ -283,15 +283,15 @@ Internal method implementing [`compute_all_successors`](@ref). This method is ea construct test cases for because it only requires the collection of `BBlocks`, not all of the other stuff that goes into a `BBCode`. """ -function _compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} +function _compute_all_successors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} succs = map(enumerate(blks)) do (n, blk) return successors(terminator(blk), n, blks, n == length(blks)) end - return Dict{ID, Vector{ID}}(zip(map(b -> b.id, blks), succs)) + return Dict{ID,Vector{ID}}(zip(map(b -> b.id, blks), succs)) end function successors(::Nothing, n::Int, blks::Vector{BBlock}, is_final_block::Bool) - return is_final_block ? ID[] : ID[blks[n+1].id] + return is_final_block ? ID[] : ID[blks[n + 1].id] end successors(t::IDGotoNode, ::Int, ::Vector{BBlock}, ::Bool) = [t.label] function successors(t::IDGotoIfNot, n::Int, blks::Vector{BBlock}, is_final_block::Bool) @@ -305,7 +305,7 @@ successors(t::Switch, ::Int, ::Vector{BBlock}, ::Bool) = vcat(t.dests, t.fallthr Compute a map from the `ID of each `BBlock` in `ir` to its possible predecessors. """ -function compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}} +function compute_all_predecessors(ir::BBCode)::Dict{ID,Vector{ID}} return _compute_all_predecessors(ir.blocks) end @@ -316,13 +316,12 @@ Internal method implementing [`compute_all_predecessors`](@ref). This method is construct test cases for because it only requires the collection of `BBlocks`, not all of the other stuff that goes into a `BBCode`. """ -function _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} - +function _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} successor_map = _compute_all_successors(blks) # Initialise predecessor map to be empty. ks = collect(keys(successor_map)) - predecessor_map = Dict{ID, Vector{ID}}(zip(ks, map(_ -> ID[], ks))) + predecessor_map = Dict{ID,Vector{ID}}(zip(ks, map(_ -> ID[], ks))) # Find all predecessors by iterating through the successor map. for (k, succs) in successor_map @@ -381,7 +380,7 @@ function _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG # Construct map from block ID to block number. block_ids = map(b -> b.id, blks) - id_to_num = Dict{ID, Int}(zip(block_ids, collect(eachindex(block_ids)))) + id_to_num = Dict{ID,Int}(zip(block_ids, collect(eachindex(block_ids)))) # Convert predecessor and successor IDs to numbers. preds = map(id -> sort(map(p -> id_to_num[p], preds_ids[id])), block_ids) @@ -389,10 +388,10 @@ function _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG index = vcat(0, cumsum(map(length, blks))) .+ 1 basic_blocks = map(eachindex(blks)) do n - stmt_range = Core.Compiler.StmtRange(index[n], index[n+1] - 1) + stmt_range = Core.Compiler.StmtRange(index[n], index[n + 1] - 1) return Core.Compiler.BasicBlock(stmt_range, preds[n], succs[n]) end - return Core.Compiler.CFG(basic_blocks, index[2:end-1]) + return Core.Compiler.CFG(basic_blocks, index[2:(end - 1)]) end # @@ -426,19 +425,18 @@ function BBCode(ir::IRCode) return BBCode(ir, blocks) end - """ new_inst_vec(x::CC.InstructionStream) Convert an `Compiler.InstructionStream` into a list of `Compiler.NewInstruction`s. """ function new_inst_vec(x::CC.InstructionStream) - return map((v..., ) -> NewInstruction(v...), stmt(x), x.type, x.info, x.line, x.flag) + return map((v...,) -> NewInstruction(v...), stmt(x), x.type, x.info, x.line, x.flag) end # Maps from positional names (SSAValues for nodes, Integers for basic blocks) to IDs. -const SSAToIdDict = Dict{SSAValue, ID} -const BlockNumToIdDict = Dict{Integer, ID} +const SSAToIdDict = Dict{SSAValue,ID} +const BlockNumToIdDict = Dict{Integer,ID} """ _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector} @@ -447,7 +445,7 @@ Assigns an ID to each line in `stmts`, and replaces each instance of an `SSAValu line with the corresponding `ID`. For example, a call statement of the form `Expr(:call, :f, %4)` is be replaced with `Expr(:call, :f, id_assigned_to_%4)`. """ -function _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector} +function _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID},InstVector} ids = map(_ -> ID(), insts) val_id_map = SSAToIdDict(zip(SSAValue.(eachindex(insts)), ids)) return ids, map(Base.Fix1(_ssa_to_ids, val_id_map), insts) @@ -489,7 +487,7 @@ _ssa_to_ids(d::SSAToIdDict, x::GotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), x. Assign to each basic block in `cfg` an `ID`. Replace all integers referencing block numbers in `insts` with the corresponding `ID`. Return the `ID`s and the updated instructions. """ -function _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector} +function _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID},InstVector} ids = map(_ -> ID(), cfg.blocks) block_num_id_map = BlockNumToIdDict(zip(eachindex(cfg.blocks), ids)) return ids, map(Base.Fix1(_block_num_to_ids, block_num_id_map), insts) @@ -525,7 +523,7 @@ function CC.IRCode(bb_code::BBCode) bb_code = _lower_switch_statements(bb_code) bb_code = _remove_double_edges(bb_code) insts = _ids_to_line_numbers(bb_code) - cfg = control_flow_graph(bb_code) + cfg = control_flow_graph(bb_code) insts = _lines_to_blocks(insts, cfg) return IRCode( CC.InstructionStream( @@ -556,7 +554,7 @@ function _lower_switch_statements(bb_code::BBCode) if t isa Switch # Create new block without the `Switch`. - bb = BBlock(block.id, block.inst_ids[1:end-1], block.insts[1:end-1]) + bb = BBlock(block.id, block.inst_ids[1:(end - 1)], block.insts[1:(end - 1)]) push!(new_blocks, bb) # Create new blocks for each `GotoIfNot` from the `Switch`. @@ -586,7 +584,7 @@ function _ids_to_line_numbers(bb_code::BBCode)::InstVector # Construct map from `ID`s to `SSAValue`s. block_ids = [b.id for b in bb_code.blocks] block_lengths = map(length, bb_code.blocks) - block_start_ssas = SSAValue.(vcat(1, cumsum(block_lengths)[1:end-1] .+ 1)) + block_start_ssas = SSAValue.(vcat(1, cumsum(block_lengths)[1:(end - 1)] .+ 1)) line_ids = concatenate_ids(bb_code) line_ssas = SSAValue.(eachindex(line_ids)) id_to_ssa_map = Dict(zip(vcat(block_ids, line_ids), vcat(block_start_ssas, line_ssas))) @@ -630,8 +628,8 @@ in `ir`. function _remove_double_edges(ir::BBCode) new_blks = map(enumerate(ir.blocks)) do (n, blk) t = terminator(blk) - if t isa IDGotoIfNot && t.dest == ir.blocks[n+1].id - new_insts = vcat(blk.insts[1:end-1], NewInstruction(t; stmt=IDGotoNode(t.dest))) + if t isa IDGotoIfNot && t.dest == ir.blocks[n + 1].id + new_insts = vcat(blk.insts[1:(end - 1)], NewInstruction(t; stmt=IDGotoNode(t.dest))) return BBlock(blk.id, blk.inst_ids, new_insts) else return blk @@ -652,7 +650,7 @@ Returns a 2-tuple, whose first element is `g`, and whose second element is a map the `ID` associated to each basic block in `ir`, to the `Int` corresponding to its node index in `g`. """ -function _build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}} +function _build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph,Dict{ID,Int}} node_ints = collect(eachindex(blks)) id_to_int = Dict(zip(map(blk -> blk.id, blks), node_ints)) successors = _compute_all_successors(blks) @@ -725,7 +723,7 @@ block stack at all. """ function characterise_unique_predecessor_blocks( blks::Vector{BBlock} -)::Tuple{Dict{ID, Bool}, Dict{ID, Bool}} +)::Tuple{Dict{ID,Bool},Dict{ID,Bool}} # Obtain the block IDs in order -- this ensures that we get the entry block first. blk_ids = ID[b.id for b in blks] @@ -733,7 +731,7 @@ function characterise_unique_predecessor_blocks( succs = _compute_all_successors(blks) # The bulk of blocks can be hanled by this general loop. - is_unique_pred = Dict{ID, Bool}() + is_unique_pred = Dict{ID,Bool}() for id in blk_ids ss = succs[id] is_unique_pred[id] = !isempty(ss) && all(s -> length(preds[s]) == 1, ss) @@ -754,7 +752,7 @@ function characterise_unique_predecessor_blocks( end # pred_is_unique_pred is true if the unique predecessor to a block is a unique pred. - pred_is_unique_pred = Dict{ID, Bool}() + pred_is_unique_pred = Dict{ID,Bool}() for id in blk_ids pred_is_unique_pred[id] = length(preds[id]) == 1 && is_unique_pred[only(preds[id])] end @@ -774,12 +772,12 @@ For each line in `stmts`, determine whether it is referenced anywhere else in th Returns a dictionary containing the results. An element is `false` if the corresponding `ID` is unused, and `true` if is used. """ -function characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool} +function characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID,Bool} ids = first.(stmts) insts = last.(stmts) # Initialise to false. - is_used = Dict{ID, Bool}(zip(ids, fill(false, length(ids)))) + is_used = Dict{ID,Bool}(zip(ids, fill(false, length(ids)))) # Hunt through the instructions, flipping a value in is_used to true whenever an ID # is encountered which corresponds to an SSA. @@ -797,29 +795,29 @@ the corresponding value of `d` to `true`. For example, if `x = ReturnNode(ID(5))`, then this function sets `d[ID(5)] = true`. """ -function _find_id_uses!(d::Dict{ID, Bool}, x::Expr) +function _find_id_uses!(d::Dict{ID,Bool}, x::Expr) for arg in x.args in(arg, keys(d)) && setindex!(d, true, arg) end end -function _find_id_uses!(d::Dict{ID, Bool}, x::IDGotoIfNot) +function _find_id_uses!(d::Dict{ID,Bool}, x::IDGotoIfNot) return in(x.cond, keys(d)) && setindex!(d, true, x.cond) end -_find_id_uses!(::Dict{ID, Bool}, ::IDGotoNode) = nothing -function _find_id_uses!(d::Dict{ID, Bool}, x::PiNode) +_find_id_uses!(::Dict{ID,Bool}, ::IDGotoNode) = nothing +function _find_id_uses!(d::Dict{ID,Bool}, x::PiNode) return in(x.val, keys(d)) && setindex!(d, true, x.val) end -function _find_id_uses!(d::Dict{ID, Bool}, x::IDPhiNode) +function _find_id_uses!(d::Dict{ID,Bool}, x::IDPhiNode) v = x.values for n in eachindex(v) isassigned(v, n) && in(v[n], keys(d)) && setindex!(d, true, v[n]) end end -function _find_id_uses!(d::Dict{ID, Bool}, x::ReturnNode) +function _find_id_uses!(d::Dict{ID,Bool}, x::ReturnNode) return isdefined(x, :val) && in(x.val, keys(d)) && setindex!(d, true, x.val) end -_find_id_uses!(d::Dict{ID, Bool}, x::QuoteNode) = nothing -_find_id_uses!(d::Dict{ID, Bool}, x) = nothing +_find_id_uses!(d::Dict{ID,Bool}, x::QuoteNode) = nothing +_find_id_uses!(d::Dict{ID,Bool}, x) = nothing """ _is_reachable(blks::Vector{BBlock})::Vector{Bool} diff --git a/src/interpreter/ir_normalisation.jl b/src/interpreter/ir_normalisation.jl index 21873c966..570607457 100644 --- a/src/interpreter/ir_normalisation.jl +++ b/src/interpreter/ir_normalisation.jl @@ -21,7 +21,7 @@ from which the `IRCode` is derived must be consulted. `Mooncake.is_vararg_and_sp provides a convenient way to do this. """ function normalise!(ir::IRCode, spnames::Vector{Symbol}) - sp_map = Dict{Symbol, CC.VarState}(zip(spnames, ir.sptypes)) + sp_map = Dict{Symbol,CC.VarState}(zip(spnames, ir.sptypes)) ir = interpolate_boundschecks!(ir) ir = CC.compact!(ir) for (n, inst) in enumerate(stmt(ir.stmts)) @@ -76,13 +76,13 @@ to be called in the context of an `IRCode`, in which case the values of `sp_map` by the `sptypes` field of said `IRCode`. The keys should generally be obtained from the `Method` from which the `IRCode` is derived. See `Mooncake.normalise!` for more details. """ -function foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState}) +function foreigncall_to_call(inst, sp_map::Dict{Symbol,CC.VarState}) if Meta.isexpr(inst, :foreigncall) # See Julia's AST devdocs for info on `:foreigncall` expressions. args = inst.args name = __extract_foreigncall_name(args[1]) RT = Val(interpolate_sparams(args[2], sp_map)) - AT = (map(x -> Val(interpolate_sparams(x, sp_map)), args[3])..., ) + AT = (map(x -> Val(interpolate_sparams(x, sp_map)), args[3])...,) nreq = Val(args[4]) calling_convention = Val(args[5] isa QuoteNode ? args[5].value : args[5]) x = args[6:end] @@ -113,7 +113,7 @@ end # Copied from Umlaut.jl. Originally, adapted from # https://github.com/JuliaDebug/JuliaInterpreter.jl/blob/aefaa300746b95b75f99d944a61a07a8cb145ef3/src/optimize.jl#L239 -function interpolate_sparams(@nospecialize(t::Type), sparams::Dict{Symbol, CC.VarState}) +function interpolate_sparams(@nospecialize(t::Type), sparams::Dict{Symbol,CC.VarState}) t isa Core.TypeofBottom && return t while t isa UnionAll t = t.body @@ -192,15 +192,20 @@ Does the same for... function lift_getfield_and_others(inst) Meta.isexpr(inst, :call) || return inst f = __get_arg(inst.args[1]) - if f === getfield && length(inst.args) == 3 && inst.args[3] isa Union{QuoteNode, Int} + if f === getfield && length(inst.args) == 3 && inst.args[3] isa Union{QuoteNode,Int} field = inst.args[3] new_field = field isa Int ? Val(field) : Val(field.value) return Expr(:call, lgetfield, inst.args[2], new_field) - elseif f === getfield && length(inst.args) == 4 && inst.args[3] isa Union{QuoteNode, Int} && inst.args[4] isa Bool + elseif f === getfield && + length(inst.args) == 4 && + inst.args[3] isa Union{QuoteNode,Int} && + inst.args[4] isa Bool field = inst.args[3] new_field = field isa Int ? Val(field) : Val(field.value) return Expr(:call, lgetfield, inst.args[2], new_field, Val(inst.args[4])) - elseif f === setfield! && length(inst.args) == 4 && inst.args[3] isa Union{QuoteNode, Int} + elseif f === setfield! && + length(inst.args) == 4 && + inst.args[3] isa Union{QuoteNode,Int} name = inst.args[3] new_name = name isa Int ? Val(name) : Val(name.value) return Expr(:call, lsetfield!, inst.args[2], new_name, inst.args[4]) @@ -215,46 +220,48 @@ __get_arg(x) = x # memoryrefget and memoryrefset! were introduced in 1.11. @static if VERSION >= v"1.11-" - -""" - lift_memoryrefget_and_memoryrefset_builtins(inst) - -Replaces memoryrefget -> lmemoryrefget and memoryrefset! -> lmemoryrefset! if their final -two arguments (`ordering` and `boundscheck`) are constants. See [`lmemoryrefget`] and -[`lmemoryrefset!`](@ref) for more context. -""" -function lift_memoryrefget_and_memoryrefset_builtins(inst) - Meta.isexpr(inst, :call) || return inst - f = __get_arg(inst.args[1]) - if f == Core.memoryrefget && length(inst.args) == 4 - ordering = inst.args[3] - boundscheck = inst.args[4] - if ordering isa QuoteNode && boundscheck isa Bool - new_ordering = Val(ordering.value) - return Expr(:call, lmemoryrefget, inst.args[2], new_ordering, Val(boundscheck)) - else - return inst - end - elseif f == Core.memoryrefset! && length(inst.args) == 5 - ordering = inst.args[4] - boundscheck = inst.args[5] - if ordering isa QuoteNode && boundscheck isa Bool - new_ordering = Val(ordering.value) - bc = Val(boundscheck) - return Expr(:call, lmemoryrefset!, inst.args[2], inst.args[3], new_ordering, bc) + """ + lift_memoryrefget_and_memoryrefset_builtins(inst) + + Replaces memoryrefget -> lmemoryrefget and memoryrefset! -> lmemoryrefset! if their final + two arguments (`ordering` and `boundscheck`) are constants. See [`lmemoryrefget`] and + [`lmemoryrefset!`](@ref) for more context. + """ + function lift_memoryrefget_and_memoryrefset_builtins(inst) + Meta.isexpr(inst, :call) || return inst + f = __get_arg(inst.args[1]) + if f == Core.memoryrefget && length(inst.args) == 4 + ordering = inst.args[3] + boundscheck = inst.args[4] + if ordering isa QuoteNode && boundscheck isa Bool + new_ordering = Val(ordering.value) + return Expr( + :call, lmemoryrefget, inst.args[2], new_ordering, Val(boundscheck) + ) + else + return inst + end + elseif f == Core.memoryrefset! && length(inst.args) == 5 + ordering = inst.args[4] + boundscheck = inst.args[5] + if ordering isa QuoteNode && boundscheck isa Bool + new_ordering = Val(ordering.value) + bc = Val(boundscheck) + return Expr( + :call, lmemoryrefset!, inst.args[2], inst.args[3], new_ordering, bc + ) + else + return inst + end else return inst end - else - return inst end -end else -# memoryrefget and memoryrefset! do not exist before v1.11. -lift_memoryrefget_and_memoryrefset_builtins(inst) = inst - + # memoryrefget and memoryrefset! do not exist before v1.11. + lift_memoryrefget_and_memoryrefset_builtins(inst) = inst end """ @@ -265,7 +272,7 @@ until the pullback that it returns is run. """ @inline gc_preserve(xs...) = nothing -@is_primitive MinimalCtx Tuple{typeof(gc_preserve), Vararg{Any, N}} where {N} +@is_primitive MinimalCtx Tuple{typeof(gc_preserve),Vararg{Any,N}} where {N} function rrule!!(f::CoDual{typeof(gc_preserve)}, xs::CoDual...) pb = NoPullback(f, xs...) diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 56883b76d..d4a518bb5 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -127,7 +127,7 @@ end # already have one with the right argument types. Credit to @oxinabox: # https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54 function __get_toplevel_mi_from_ir(ir, _module::Module) - mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ()); + mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ()) mi.specTypes = Tuple{map(_type, ir.argtypes)...} mi.def = _module return mi @@ -136,7 +136,7 @@ end # Run type inference and constant propagation on the ir. Credit to @oxinabox: # https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54 function __infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance) - method_info = CC.MethodInfo(#=propagate_inbounds=#true, nothing) + method_info = CC.MethodInfo(true, nothing) #=propagate_inbounds=# min_world = world = get_inference_world(interp) max_world = Base.get_world_counter() irsv = CC.IRInterpretationState( @@ -184,7 +184,7 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) inline_state = CC.InliningState(local_interp) CC.verify_ir(ir) if do_inline - ir = CC.ssa_inlining_pass!(ir, inline_state, #=propagate_inbounds=#true) + ir = CC.ssa_inlining_pass!(ir, inline_state, true) #=propagate_inbounds=# ir = CC.compact!(ir) end ir = __strip_coverage!(ir) @@ -222,7 +222,9 @@ there is no code found, or if more than one `IRCode` instance returned. Returns a tuple containing the `IRCode` and its return type. """ -function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_until=nothing) +function lookup_ir( + interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_until=nothing +) matches = CC.findall(tt, CC.method_table(interp)) asts = [] for match in get_matches(matches.matches) @@ -242,7 +244,8 @@ function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_u end end if isempty(asts) - msg = "No methods found for signature: $tt.\n" * + msg = + "No methods found for signature: $tt.\n" * "\n" * "This is often caused by accidentally trying to get Mooncake.jl to " * "differentiate a call (directly or indirectly) which does not exist. For " * @@ -260,7 +263,9 @@ function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_u return only(asts) end -function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance; optimize_until=nothing) +function lookup_ir( + interp::CC.AbstractInterpreter, mi::Core.MethodInstance; optimize_until=nothing +) return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, optimize_until) end @@ -320,7 +325,7 @@ Replace all uses of `def` with `val` in the single statement `stmt`. Note: this function is highly incomplete, really only working correctly for a specific function in `ir_normalisation.jl`. You probably do not want to use it. """ -function replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val) +function replace_uses_with!(stmt, def::Union{Argument,SSAValue}, val) if stmt isa Expr stmt.args = Any[arg == def ? val : arg for arg in stmt.args] return stmt diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 1896981ce..7fcba7d1f 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -11,8 +11,8 @@ is passed to an `OpaqueClosure`, and extracting this data into registers associa corresponding `ID`s. """ struct SharedDataPairs - pairs::Vector{Tuple{ID, Any}} - SharedDataPairs() = new(Tuple{ID, Any}[]) + pairs::Vector{Tuple{ID,Any}} + SharedDataPairs() = new(Tuple{ID,Any}[]) end """ @@ -126,12 +126,12 @@ struct ADInfo block_stack::BlockStack entry_id::ID shared_data_pairs::SharedDataPairs - arg_types::Dict{Argument, Any} - ssa_insts::Dict{ID, NewInstruction} - arg_rdata_ref_ids::Dict{Argument, ID} - ssa_rdata_ref_ids::Dict{ID, ID} + arg_types::Dict{Argument,Any} + ssa_insts::Dict{ID,NewInstruction} + arg_rdata_ref_ids::Dict{Argument,ID} + ssa_rdata_ref_ids::Dict{ID,ID} debug_mode::Bool - is_used_dict::Dict{ID, Bool} + is_used_dict::Dict{ID,Bool} lazy_zero_rdata_ref_id::ID end @@ -139,9 +139,9 @@ end # See the definition of the ADInfo struct for info on the arguments. function ADInfo( interp::MooncakeInterpreter, - arg_types::Dict{Argument, Any}, - ssa_insts::Dict{ID, NewInstruction}, - is_used_dict::Dict{ID, Bool}, + arg_types::Dict{Argument,Any}, + ssa_insts::Dict{ID,NewInstruction}, + is_used_dict::Dict{ID,Bool}, debug_mode::Bool, zero_lazy_rdata_ref::Ref{<:Tuple}, ) @@ -166,14 +166,16 @@ end # The constructor you should use for ADInfo if you _do_ have a BBCode lying around. See the # ADInfo struct for information regarding `interp` and `debug_mode`. function ADInfo(interp::MooncakeInterpreter, ir::BBCode, debug_mode::Bool) - arg_types = Dict{Argument, Any}( + arg_types = Dict{Argument,Any}( map(((n, t),) -> (Argument(n) => _type(t)), enumerate(ir.argtypes)) ) stmts = collect_stmts(ir) - ssa_insts = Dict{ID, NewInstruction}(stmts) + ssa_insts = Dict{ID,NewInstruction}(stmts) is_used_dict = characterise_used_ids(stmts) zero_lazy_rdata_ref = Ref{Tuple{map(lazy_zero_rdata_type ∘ _type, ir.argtypes)...}}() - return ADInfo(interp, arg_types, ssa_insts, is_used_dict, debug_mode, zero_lazy_rdata_ref) + return ADInfo( + interp, arg_types, ssa_insts, is_used_dict, debug_mode, zero_lazy_rdata_ref + ) end """ @@ -190,7 +192,7 @@ Returns `x` if it is a singleton, or the `ID` of the ssa which will contain it o forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR. """ -function add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x) +function add_data_if_not_singleton!(p::Union{ADInfo,SharedDataPairs}, x) return Base.issingletontype(_typeof(x)) ? x : add_data!(p, x) end @@ -215,7 +217,7 @@ function get_primal_type(::ADInfo, x::GlobalRef) end function get_primal_type(::ADInfo, x::Expr) x.head === :boundscheck && return Bool - error("Unrecognised expression $x found in argument slot.") + return error("Unrecognised expression $x found in argument slot.") end """ @@ -283,14 +285,14 @@ end _copy(x::P) where {P<:RRuleZeroWrapper} = P(_copy(x.rule)) -struct RRuleWrapperPb{Tpb!!, Tl} +struct RRuleWrapperPb{Tpb!!,Tl} pb!!::Tpb!! l::Tl end (rule::RRuleWrapperPb)(dy) = rule.pb!!(increment!!(dy, instantiate(rule.l))) -@inline function (rule::RRuleZeroWrapper{R})(f::F, args::Vararg{CoDual, N}) where {R, F, N} +@inline function (rule::RRuleZeroWrapper{R})(f::F, args::Vararg{CoDual,N}) where {R,F,N} y, pb!! = rule.rule(f, args...) l = lazy_zero_rdata(primal(y)) return y::CoDual, (pb!! isa NoPullback ? pb!! : RRuleWrapperPb(pb!!, l)) @@ -311,7 +313,7 @@ Data structure which contains the result of `make_ad_stmts!`. Fields are """ struct ADStmtInfo line::ID - comms_id::Union{ID, Nothing} + comms_id::Union{ID,Nothing} fwds::Vector{IDInstPair} rvs::Vector{IDInstPair} end @@ -322,7 +324,7 @@ end Convenient constructor for `ADStmtInfo`. If either `fwds` or `rvs` is not a vector, `__vec` promotes it to a single-element `Vector`. """ -function ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs) +function ad_stmt_info(line::ID, comms_id::Union{ID,Nothing}, fwds, rvs) if !(comms_id === nothing || in(comms_id, map(first, __vec(line, fwds)))) throw(ArgumentError("comms_id not found in IDs of `fwds` instructions.")) end @@ -331,7 +333,7 @@ end __vec(line::ID, x::Any) = __vec(line, new_inst(x)) __vec(line::ID, x::NewInstruction) = IDInstPair[(line, x)] -__vec(line::ID, x::Vector{Tuple{ID, Any}}) = throw(error("boooo")) +__vec(line::ID, x::Vector{Tuple{ID,Any}}) = throw(error("boooo")) __vec(line::ID, x::Vector{IDInstPair}) = x """ @@ -557,7 +559,11 @@ end inc_or_const(stmt, info::ADInfo) = is_active(stmt) ? __inc(stmt) : const_codual(stmt, info) function inc_or_const_stmt(stmt, info::ADInfo) - return is_active(stmt) ? Expr(:call, identity, __inc(stmt)) : const_codual_stmt(stmt, info) + return if is_active(stmt) + Expr(:call, identity, __inc(stmt)) + else + const_codual_stmt(stmt, info) + end end """ @@ -574,19 +580,19 @@ get_const_primal_value(x) = x # Mooncake does not yet handle `PhiCNode`s. Throw an error if one is encountered. function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo) - unhandled_feature("Encountered PhiCNode: $stmt") + return unhandled_feature("Encountered PhiCNode: $stmt") end # Mooncake does not yet handle `UpsilonNode`s. Throw an error if one is encountered. function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo) - unhandled_feature( + return unhandled_feature( "Encountered UpsilonNode: $stmt. These are generated as part of some try / catch " * "/ finally blocks. At the present time, Mooncake.jl cannot differentiate through " * "these, so they must be avoided. Strategies for resolving this error include " * "re-writing code such that it avoids generating any UpsilonNodes, or writing a " * "rule to differentiate the code by hand. If you are in any doubt as to what to " * "do, please request assistance by opening an issue at " * - "github.com/compintell/Mooncake.jl." + "github.com/compintell/Mooncake.jl.", ) end @@ -597,7 +603,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) if Meta.isexpr(stmt, :call) || is_invoke # Find the types of all arguments to this call / invoke. - args = ((is_invoke ? stmt.args[2:end] : stmt.args)..., ) + args = ((is_invoke ? stmt.args[2:end] : stmt.args)...,) arg_types = map(arg -> get_primal_type(info, arg), args) # Special case: if the result of a call to getfield is un-used, then leave the @@ -748,7 +754,8 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) return ad_stmt_info(line, nothing, stmt, nothing) elseif stmt.head == :(=) && stmt.args[1] isa GlobalRef - msg = "Encountered assignment to global variable: $(stmt.args[1]). " * + msg = + "Encountered assignment to global variable: $(stmt.args[1]). " * "Cannot differentiate through assignments to globals. " * "Please refactor your code to avoid assigning to a global, for example by " * "passing the variable in to the function as an argument." @@ -759,7 +766,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) end end -is_active(::Union{Argument, ID}) = true +is_active(::Union{Argument,ID}) = true is_active(::Any) = false """ @@ -768,16 +775,16 @@ is_active(::Any) = false Get a bound on the pullback type, given a rule and associated primal types. """ function pullback_type(Trule, arg_types) - T = Core.Compiler.return_type(Tuple{Trule, map(fcodual_type, arg_types)...}) + T = Core.Compiler.return_type(Tuple{Trule,map(fcodual_type, arg_types)...}) return T <: Tuple ? _pullback_type(T) : Any end _pullback_type(::Core.TypeofBottom) = Any _pullback_type(T::DataType) = T.parameters[2] -_pullback_type(T::Union) = Union{_pullback_type(T.a), _pullback_type(T.b)} +_pullback_type(T::Union) = Union{_pullback_type(T.a),_pullback_type(T.b)} # Used by the getfield special-case in call / invoke statments. -@inline function __fwds_pass_no_ad!(f::F, raw_args::Vararg{Any, N}) where {F, N} +@inline function __fwds_pass_no_ad!(f::F, raw_args::Vararg{Any,N}) where {F,N} return tuple_splat(__get_primal(f), tuple_map(__get_primal, raw_args)) end @@ -806,8 +813,8 @@ end @inline increment_ref!(x::Ref, t) = setindex!(x, increment!!(x[], t)) @inline increment_ref!(::Base.RefValue{NoRData}, t) = nothing -@inline function set_ret_ref_to_zero!!(::Type{P}, r::Ref{R}) where {P, R} - r[] = zero_like_rdata_from_type(P) +@inline function set_ret_ref_to_zero!!(::Type{P}, r::Ref{R}) where {P,R} + return r[] = zero_like_rdata_from_type(P) end @inline set_ret_ref_to_zero!!(::Type{P}, r::Base.RefValue{NoRData}) where {P} = nothing @@ -816,7 +823,7 @@ end # between differing varargs conventions. # -struct Pullback{Tprimal, Tpb_oc, Tisva<:Val, Tnvargs<:Val} +struct Pullback{Tprimal,Tpb_oc,Tisva<:Val,Tnvargs<:Val} pb_oc::Tpb_oc isva::Tisva nvargs::Tnvargs @@ -824,21 +831,21 @@ end function Pullback( Tprimal, pb_oc::Tpb_oc, isva::Tisva, nvargs::Tnvargs -) where {Tpb_oc, Tisva, Tnvargs} - return Pullback{Tprimal, Tpb_oc, Tisva, Tnvargs}(pb_oc, isva, nvargs) +) where {Tpb_oc,Tisva,Tnvargs} + return Pullback{Tprimal,Tpb_oc,Tisva,Tnvargs}(pb_oc, isva, nvargs) end @inline (pb::Pullback)(dy) = __flatten_varargs(pb.isva, pb.pb_oc[].oc(dy), pb.nvargs) -struct DerivedRule{Tprimal, Tfwds_oc, Tpb, Tisva<:Val, Tnargs<:Val} +struct DerivedRule{Tprimal,Tfwds_oc,Tpb,Tisva<:Val,Tnargs<:Val} fwds_oc::Tfwds_oc pb::Tpb isva::Tisva nargs::Tnargs end -function DerivedRule(Tprimal, fwds_oc::T, pb::U, isva::V, nargs::W) where {T, U, V, W} - return DerivedRule{Tprimal, T, U, V, W}(fwds_oc, pb, isva, nargs) +function DerivedRule(Tprimal, fwds_oc::T, pb::U, isva::V, nargs::W) where {T,U,V,W} + return DerivedRule{Tprimal,T,U,V,W}(fwds_oc, pb, isva, nargs) end # Extends functionality defined for debug_mode. @@ -870,7 +877,7 @@ _copy(x::Type) = x _copy(x) = copy(x) -@inline function (fwds::DerivedRule{P, Q, S})(args::Vararg{CoDual, N}) where {P, Q, S, N} +@inline function (fwds::DerivedRule{P,Q,S})(args::Vararg{CoDual,N}) where {P,Q,S,N} uf_args = __unflatten_codual_varargs(fwds.isva, args, fwds.nargs) return fwds.fwds_oc.oc(uf_args...)::CoDual, fwds.pb end @@ -880,10 +887,10 @@ end If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0). """ -function __flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs} +function __flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva,nvargs} isva || return args last_el = isa(args[end], NoRData) ? ntuple(n -> NoRData(), nvargs) : args[end] - return (args[1:end-1]..., last_el...) + return (args[1:(end - 1)]..., last_el...) end """ @@ -892,7 +899,7 @@ end If isva and nargs=2, then inputs `(CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0))` are transformed into `(CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0)))`. """ -function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs} +function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva,nargs} isva || return args group_primal = map(primal, args[nargs:end]) if fdata_type(tangent_type(_typeof(group_primal))) == NoFData @@ -900,7 +907,7 @@ function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva else grouped_args = CoDual(group_primal, map(tangent, args[nargs:end])) end - return (args[1:nargs-1]..., grouped_args) + return (args[1:(nargs - 1)]..., grouped_args) end # @@ -910,7 +917,7 @@ end _is_primitive(C::Type, mi::Core.MethodInstance) = is_primitive(C, mi.specTypes) _is_primitive(C::Type, sig::Type) = is_primitive(C, sig) -const RuleMC{A, R} = MistyClosure{OpaqueClosure{A, R}} +const RuleMC{A,R} = MistyClosure{OpaqueClosure{A,R}} """ rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C} @@ -920,7 +927,6 @@ important for performance in dynamic dispatch, and to ensure that recursion work properly. """ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C} - if _is_primitive(C, sig_or_mi) return debug_mode ? DebugRRule{typeof(rrule!!)} : typeof(rrule!!) end @@ -934,23 +940,23 @@ function rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where arg_fwds_types = Tuple{map(fcodual_type, arg_types)...} arg_rvs_types = Tuple{map(rdata_type ∘ tangent_type, arg_types)...} rvs_return_type = rdata_type(tangent_type(Treturn)) - pb_oc_type = MistyClosure{OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}} - pb_type = Pullback{sig, Base.RefValue{pb_oc_type}, Val{isva}, nvargs(isva, sig)} + pb_oc_type = MistyClosure{OpaqueClosure{Tuple{rvs_return_type},arg_rvs_types}} + pb_type = Pullback{sig,Base.RefValue{pb_oc_type},Val{isva},nvargs(isva, sig)} nargs = Val{length(ir.argtypes)} if isconcretetype(Treturn) Tderived_rule = DerivedRule{ - sig, RuleMC{arg_fwds_types, fcodual_type(Treturn)}, pb_type, Val{isva}, nargs, + sig,RuleMC{arg_fwds_types,fcodual_type(Treturn)},pb_type,Val{isva},nargs } return debug_mode ? DebugRRule{Tderived_rule} : Tderived_rule else if debug_mode - return DebugRRule{DerivedRule{ - sig, RuleMC{arg_fwds_types, P}, pb_type, Val{isva}, nargs, - }} where {P<:fcodual_type(Treturn)} + return DebugRRule{ + DerivedRule{sig,RuleMC{arg_fwds_types,P},pb_type,Val{isva},nargs} + } where {P<:fcodual_type(Treturn)} else return DerivedRule{ - sig, RuleMC{arg_fwds_types, P}, pb_type, Val{isva}, nargs, + sig,RuleMC{arg_fwds_types,P},pb_type,Val{isva},nargs } where {P<:fcodual_type(Treturn)} end end @@ -965,18 +971,22 @@ struct MooncakeRuleCompilationError <: Exception end function Base.showerror(io::IO, err::MooncakeRuleCompilationError) - msg = "MooncakeRuleCompilationError: an error occured while Mooncake was compiling a " * + msg = + "MooncakeRuleCompilationError: an error occured while Mooncake was compiling a " * "rule to differentiate something. If the `caused by` error " * "message below does not make it clear to you how the problem can be fixed, " * "please open an issue at github.com/compintell/Mooncake.jl describing your " * "problem.\n" * "To replicate this error run the following:\n" println(io, msg) - println(io, "Mooncake.build_rrule(Mooncake.$(err.interp), $(err.sig); debug_mode=$(err.debug_mode))") println( + io, + "Mooncake.build_rrule(Mooncake.$(err.interp), $(err.sig); debug_mode=$(err.debug_mode))", + ) + return println( io, "\nNote that you may need to `using` some additional packages if not all of the " * - "names printed in the above signature are available currently in your environment." + "names printed in the above signature are available currently in your environment.", ) end @@ -1023,10 +1033,12 @@ function build_rrule( # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater # than the current world age. if Base.get_world_counter() > interp.world - throw(ArgumentError( - "World age associated to interp is behind current world age. Please " * - "a new interpreter for the current world age." - )) + throw( + ArgumentError( + "World age associated to interp is behind current world age. Please " * + "a new interpreter for the current world age.", + ), + ) end # If we're compiling in debug mode, let the user know by default. @@ -1055,7 +1067,9 @@ function build_rrule( sig = sig_or_mi isa Core.MethodInstance ? sig_or_mi.specTypes : sig_or_mi nargs = num_args(dri.info) if dri.isva - sig = Tuple{sig.parameters[1:nargs-1]..., Tuple{sig.parameters[nargs:end]...}} + sig = Tuple{ + sig.parameters[1:(nargs - 1)]...,Tuple{sig.parameters[nargs:end]...} + } end pb = Pullback(sig, Ref(rvs_oc), Val(dri.isva), nvargs(dri.isva, sig)()) @@ -1152,7 +1166,7 @@ function replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure} return Tmc(replace_captures(mc.oc, new_captures), mc.ir) end -const ADStmts = Vector{Tuple{ID, Vector{ADStmtInfo}}} +const ADStmts = Vector{Tuple{ID,Vector{ADStmtInfo}}} """ create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo) @@ -1223,7 +1237,6 @@ Produce the IR associated to the `OpaqueClosure` which runs most of the forwards function forwards_pass_ir( ir::BBCode, ad_stmts_blocks::ADStmts, fwds_comms_insts, info::ADInfo, Tshared_data ) - is_unique_pred, pred_is_unique_pred = characterise_unique_predecessor_blocks(ir.blocks) # Insert a block at the start which extracts all items from the captures field of the @@ -1295,8 +1308,8 @@ straightforward to figure out much time is spent pushing to the block stack when @inline __push_blk_stack!(block_stack::BlockStack, id::Int32) = push!(block_stack, id) @inline function __assemble_lazy_zero_rdata( - r::Ref{T}, args::Vararg{CoDual, N} -) where {T<:Tuple, N} + r::Ref{T}, args::Vararg{CoDual,N} +) where {T<:Tuple,N} r[] = __make_tuples(T, args) return nothing end @@ -1542,7 +1555,9 @@ on. The same ideas apply if `pred_id` were `#3`. The block would end with `#3`, and there would be two `increment_ref!` calls because both `%5` and `_2` are not constants. """ -function rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo) +function rvs_phi_block( + pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo +) @assert length(rdata_ids) == length(values) inc_stmts = map(rdata_ids, values) do id, val stmt = Expr(:call, increment_if_ref!, get_rev_data_id(info, val), id) @@ -1595,12 +1610,12 @@ function make_switch_stmts( end # Compare predecessor from primal with all possible predecessors. - conds = map(pred_ids[1:end-1]) do id + conds = map(pred_ids[1:(end - 1)]) do id return (ID(), new_inst(Expr(:call, __switch_case, id.id, prev_blk_id))) end # Switch statement to change to the predecessor. - switch_stmt = Switch(Any[c[1] for c in conds], target_ids[1:end-1], target_ids[end]) + switch_stmt = Switch(Any[c[1] for c in conds], target_ids[1:(end - 1)], target_ids[end]) switch = (ID(), new_inst(switch_stmt)) return vcat((prev_blk_id, prev_blk), conds, switch) @@ -1640,11 +1655,11 @@ struct DynamicDerivedRule{V} debug_mode::Bool end -DynamicDerivedRule(debug_mode::Bool) = DynamicDerivedRule(Dict{Any, Any}(), debug_mode) +DynamicDerivedRule(debug_mode::Bool) = DynamicDerivedRule(Dict{Any,Any}(), debug_mode) -_copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any, Any}(), x.debug_mode) +_copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any,Any}(), x.debug_mode) -function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N} +function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any,N}) where {N} sig = Tuple{map(_typeof ∘ primal, args)...} rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing @@ -1672,24 +1687,24 @@ reason to keep this around is for debugging -- it is very helpful to have this t in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit. """ -mutable struct LazyDerivedRule{primal_sig, Trule} +mutable struct LazyDerivedRule{primal_sig,Trule} debug_mode::Bool mi::Core.MethodInstance rule::Trule function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool) interp = get_interpreter() - return new{mi.specTypes, rule_type(interp, mi; debug_mode)}(debug_mode, mi) + return new{mi.specTypes,rule_type(interp, mi; debug_mode)}(debug_mode, mi) end - function LazyDerivedRule{Tprimal_sig, Trule}( + function LazyDerivedRule{Tprimal_sig,Trule}( mi::Core.MethodInstance, debug_mode::Bool - ) where {Tprimal_sig, Trule} - return new{Tprimal_sig, Trule}(debug_mode, mi) + ) where {Tprimal_sig,Trule} + return new{Tprimal_sig,Trule}(debug_mode, mi) end end _copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.debug_mode) -@inline function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N} +@inline function (rule::LazyDerivedRule)(args::Vararg{Any,N}) where {N} return isdefined(rule, :rule) ? rule.rule(args...) : _build_rule!(rule, args) end @@ -1712,20 +1727,21 @@ function Base.showerror(io::IO, err::BadRuleTypeException) println(io, "This error occured for $(err.mi) with signature:") println(io, err.sig) println(io) - msg = "Usually this error is indicative of something having gone wrong in the " * + msg = + "Usually this error is indicative of something having gone wrong in the " * "compilation of the rule in question. Look at the error message for the error " * "which caused this error (below) for more details. If the error below does not " * "immediately give you enough information to debug what is going on, consider " * "building the rule for the signature above, and inspecting the IR." - println(io, msg) + return println(io, msg) end -_rtype(::Type{<:DebugRRule}) = Tuple{CoDual, DebugPullback} +_rtype(::Type{<:DebugRRule}) = Tuple{CoDual,DebugPullback} _rtype(T::Type{<:MistyClosure}) = _rtype(fieldtype(T, :oc)) -_rtype(::Type{<:OpaqueClosure{<:Any, <:R}}) where {R} = (@isdefined R) ? R : CoDual -_rtype(T::Type{<:DerivedRule}) = Tuple{_rtype(fieldtype(T, :fwds_oc)), fieldtype(T, :pb)} +_rtype(::Type{<:OpaqueClosure{<:Any,<:R}}) where {R} = (@isdefined R) ? R : CoDual +_rtype(T::Type{<:DerivedRule}) = Tuple{_rtype(fieldtype(T, :fwds_oc)),fieldtype(T, :pb)} -@noinline function _build_rule!(rule::LazyDerivedRule{sig, Trule}, args) where {sig, Trule} +@noinline function _build_rule!(rule::LazyDerivedRule{sig,Trule}, args) where {sig,Trule} derived_rule = build_rrule(get_interpreter(), rule.mi; debug_mode=rule.debug_mode) if derived_rule isa Trule rule.rule = derived_rule diff --git a/src/interpreter/zero_like_rdata.jl b/src/interpreter/zero_like_rdata.jl index 38a990cf2..882f5d9f1 100644 --- a/src/interpreter/zero_like_rdata.jl +++ b/src/interpreter/zero_like_rdata.jl @@ -21,7 +21,7 @@ of `R` and `ZeroRData` if an instance of `P` is needed. """ function zero_like_rdata_type(::Type{P}) where {P} R = rdata_type(tangent_type(P)) - return can_produce_zero_rdata_from_type(P) ? R : Union{R, ZeroRData} + return can_produce_zero_rdata_from_type(P) ? R : Union{R,ZeroRData} end """ diff --git a/src/rrules/array_legacy.jl b/src/rrules/array_legacy.jl index 30b700b4a..010733792 100644 --- a/src/rrules/array_legacy.jl +++ b/src/rrules/array_legacy.jl @@ -1,27 +1,31 @@ -@inline function zero_tangent_internal(x::Array{P, N}, stackdict::IdDict) where {P, N} +@inline function zero_tangent_internal(x::Array{P,N}, stackdict::IdDict) where {P,N} haskey(stackdict, x) && return stackdict[x]::tangent_type(typeof(x)) - zt = Array{tangent_type(P), N}(undef, size(x)...) + zt = Array{tangent_type(P),N}(undef, size(x)...) stackdict[x] = zt - return _map_if_assigned!(Base.Fix2(zero_tangent_internal, stackdict), zt, x)::Array{tangent_type(P), N} + return _map_if_assigned!( + Base.Fix2(zero_tangent_internal, stackdict), zt, x + )::Array{tangent_type(P),N} end -function randn_tangent_internal(rng::AbstractRNG, x::Array{T, N}, stackdict::IdDict) where {T, N} +function randn_tangent_internal( + rng::AbstractRNG, x::Array{T,N}, stackdict::IdDict +) where {T,N} haskey(stackdict, x) && return stackdict[x]::tangent_type(typeof(x)) - dx = Array{tangent_type(T), N}(undef, size(x)...) + dx = Array{tangent_type(T),N}(undef, size(x)...) stackdict[x] = dx return _map_if_assigned!(x -> randn_tangent_internal(rng, x, stackdict), dx, x) end -function increment!!(x::T, y::T) where {P, N, T<:Array{P, N}} +function increment!!(x::T, y::T) where {P,N,T<:Array{P,N}} return x === y ? x : _map_if_assigned!(increment!!, x, x, y) end set_to_zero!!(x::Array) = _map_if_assigned!(set_to_zero!!, x, x) -function _scale(a::Float64, t::Array{T, N}) where {T, N} - t′ = Array{T, N}(undef, size(t)...) +function _scale(a::Float64, t::Array{T,N}) where {T,N} + t′ = Array{T,N}(undef, size(t)...) return _map_if_assigned!(Base.Fix1(_scale, a), t′, t) end @@ -35,23 +39,23 @@ function _dot(t::T, s::T) where {T<:Array} ) end -function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}, unsafe::Bool) where {P, N} - x′ = Array{P, N}(undef, size(x)...) +function _add_to_primal(x::Array{P,N}, t::Array{<:Any,N}, unsafe::Bool) where {P,N} + x′ = Array{P,N}(undef, size(x)...) return _map_if_assigned!((x, t) -> _add_to_primal(x, t, unsafe), x′, x, t) end -function _diff(p::P, q::P) where {V, N, P<:Array{V, N}} - t = Array{tangent_type(V), N}(undef, size(p)) +function _diff(p::P, q::P) where {V,N,P<:Array{V,N}} + t = Array{tangent_type(V),N}(undef, size(p)) return _map_if_assigned!(_diff, t, p, q) end -@zero_adjoint MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), Vararg} where {T, N} -@zero_adjoint MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), Tuple{}} where {T, N} -@zero_adjoint MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), NTuple{N}} where {T, N} +@zero_adjoint MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),Vararg} where {T,N} +@zero_adjoint MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),Tuple{}} where {T,N} +@zero_adjoint MinimalCtx Tuple{Type{<:Array{T,N}},typeof(undef),NTuple{N}} where {T,N} -@is_primitive MinimalCtx Tuple{typeof(Base._deletebeg!), Vector, Integer} +@is_primitive MinimalCtx Tuple{typeof(Base._deletebeg!),Vector,Integer} function rrule!!( - ::CoDual{typeof(Base._deletebeg!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer}, + ::CoDual{typeof(Base._deletebeg!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer} ) delta = primal(_delta) a = primal(_a) @@ -71,7 +75,7 @@ function rrule!!( return zero_fcodual(nothing), _deletebeg!_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base._deleteend!), Vector, Integer} +@is_primitive MinimalCtx Tuple{typeof(Base._deleteend!),Vector,Integer} function rrule!!( ::CoDual{typeof(Base._deleteend!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer} ) @@ -81,27 +85,26 @@ function rrule!!( delta = primal(_delta) # Store the section to be cut for later. - primal_tail = a[end-delta+1:end] - tangent_tail = da[end-delta+1:end] + primal_tail = a[(end - delta + 1):end] + tangent_tail = da[(end - delta + 1):end] # Cut the end off the primal and tangent. Base._deleteend!(a, delta) Base._deleteend!(da, delta) function _deleteend!_pb!!(::NoRData) - Base._growend!(a, delta) - a[end-delta+1:end] .= primal_tail + a[(end - delta + 1):end] .= primal_tail Base._growend!(da, delta) - da[end-delta+1:end] .= tangent_tail + da[(end - delta + 1):end] .= tangent_tail return NoRData(), NoRData(), NoRData() end return zero_fcodual(nothing), _deleteend!_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base._deleteat!), Vector, Integer, Integer} +@is_primitive MinimalCtx Tuple{typeof(Base._deleteat!),Vector,Integer,Integer} function rrule!!( ::CoDual{typeof(Base._deleteat!)}, _a::CoDual{<:Vector}, @@ -113,25 +116,25 @@ function rrule!!( da = tangent(_a) # Store the cut section for later. - primal_mem = a[i:i+delta-1] - tangent_mem = da[i:i+delta-1] + primal_mem = a[i:(i + delta - 1)] + tangent_mem = da[i:(i + delta - 1)] # Run the primal. Base._deleteat!(a, i, delta) Base._deleteat!(da, i, delta) function _deleteat!_pb!!(::NoRData) - splice!(a, i:i-1, primal_mem) - splice!(da, i:i-1, tangent_mem) + splice!(a, i:(i - 1), primal_mem) + splice!(da, i:(i - 1), tangent_mem) return NoRData(), NoRData(), NoRData(), NoRData() end return zero_fcodual(nothing), _deleteat!_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base._growbeg!), Vector, Integer} +@is_primitive MinimalCtx Tuple{typeof(Base._growbeg!),Vector,Integer} function rrule!!( - ::CoDual{typeof(Base._growbeg!)}, _a::CoDual{<:Vector{T}}, _delta::CoDual{<:Integer}, + ::CoDual{typeof(Base._growbeg!)}, _a::CoDual{<:Vector{T}}, _delta::CoDual{<:Integer} ) where {T} d = primal(_delta) a = primal(_a) @@ -146,9 +149,9 @@ function rrule!!( return zero_fcodual(nothing), _growbeg!_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base._growend!), Vector, Integer} +@is_primitive MinimalCtx Tuple{typeof(Base._growend!),Vector,Integer} function rrule!!( - ::CoDual{typeof(Base._growend!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer}, + ::CoDual{typeof(Base._growend!)}, _a::CoDual{<:Vector}, _delta::CoDual{<:Integer} ) d = primal(_delta) a = primal(_a) @@ -163,7 +166,7 @@ function rrule!!( return zero_fcodual(nothing), _growend!_pullback!! end -@is_primitive MinimalCtx Tuple{typeof(Base._growat!), Vector, Integer, Integer} +@is_primitive MinimalCtx Tuple{typeof(Base._growat!),Vector,Integer,Integer} function rrule!!( ::CoDual{typeof(Base._growat!)}, _a::CoDual{<:Vector}, @@ -179,14 +182,14 @@ function rrule!!( Base._growat!(da, i, delta) function _growat!_pb!!(::NoRData) - deleteat!(a, i:i+delta-1) - deleteat!(da, i:i+delta-1) + deleteat!(a, i:(i + delta - 1)) + deleteat!(da, i:(i + delta - 1)) return NoRData(), NoRData(), NoRData(), NoRData() end return zero_fcodual(nothing), _growat!_pb!! end -@is_primitive MinimalCtx Tuple{typeof(sizehint!), Vector, Integer} +@is_primitive MinimalCtx Tuple{typeof(sizehint!),Vector,Integer} function rrule!!(f::CoDual{typeof(sizehint!)}, x::CoDual{<:Vector}, sz::CoDual{<:Integer}) sizehint!(primal(x), primal(sz)) sizehint!(tangent(x), primal(sz)) @@ -200,16 +203,18 @@ function rrule!!( ::CoDual{Tuple{Val{Any}}}, ::CoDual, # nreq ::CoDual, # calling convention - a::CoDual{<:Array{T}, <:Array{V}}, -) where {T, V} + a::CoDual{<:Array{T},<:Array{V}}, +) where {T,V} y = CoDual( - ccall(:jl_array_ptr, Ptr{T}, (Any, ), primal(a)), - ccall(:jl_array_ptr, Ptr{V}, (Any, ), tangent(a)), + ccall(:jl_array_ptr, Ptr{T}, (Any,), primal(a)), + ccall(:jl_array_ptr, Ptr{V}, (Any,), tangent(a)), ) return y, NoPullback(ntuple(_ -> NoRData(), 7)) end -@is_primitive MinimalCtx Tuple{typeof(unsafe_copyto!), Array{T}, Any, Array{T}, Any, Any} where {T} +@is_primitive MinimalCtx Tuple{ + typeof(unsafe_copyto!),Array{T},Any,Array{T},Any,Any +} where {T} function rrule!!( ::CoDual{typeof(unsafe_copyto!)}, dest::CoDual{<:Array{T}}, @@ -222,7 +227,7 @@ function rrule!!( # Record values that will be overwritten. _doffs = primal(doffs) - dest_idx = _doffs:_doffs + _n - 1 + dest_idx = _doffs:(_doffs + _n - 1) _soffs = primal(soffs) pdest = primal(dest) ddest = tangent(dest) @@ -237,7 +242,7 @@ function rrule!!( function unsafe_copyto_pb!!(::NoRData) # Increment dsrc. - src_idx = _soffs:_soffs + _n - 1 + src_idx = _soffs:(_soffs + _n - 1) dsrc[src_idx] .= increment!!.(view(dsrc, src_idx), view(ddest, dest_idx)) # Restore initial state. @@ -257,7 +262,7 @@ Base.@propagate_inbounds function rrule!!( ::CoDual{typeof(Core.arrayref)}, checkbounds::CoDual{Bool}, x::CoDual{<:Array}, - inds::Vararg{CoDual{Int}, N}, + inds::Vararg{CoDual{Int},N}, ) where {N} # Convert to linear indices to reduce amount of data required on the reverse-pass, to @@ -279,10 +284,10 @@ end function rrule!!( ::CoDual{typeof(Core.arrayset)}, inbounds::CoDual{Bool}, - A::CoDual{<:Array{P}, TdA}, + A::CoDual{<:Array{P},TdA}, v::CoDual, inds::CoDual{Int}..., -) where {P, V, TdA <: Array{V}} +) where {P,V,TdA<:Array{V}} _inbounds = primal(inbounds) _inds = map(primal, inds) @@ -291,7 +296,7 @@ function rrule!!( end to_save = isassigned(primal(A), _inds...) - old_A = Ref{Tuple{P, V}}() + old_A = Ref{Tuple{P,V}}() if to_save old_A[] = ( arrayref(_inbounds, primal(A), _inds...), @@ -314,8 +319,8 @@ function rrule!!( end function isbits_arrayset_rrule( - boundscheck, _inds, A::CoDual{<:Array{P}, TdA}, v::CoDual{P} -) where {P, V, TdA <: Array{V}} + boundscheck, _inds, A::CoDual{<:Array{P},TdA}, v::CoDual{P} +) where {P,V,TdA<:Array{V}} # Convert to linear indices lin_inds = LinearIndices(size(primal(A)))[_inds...] @@ -339,8 +344,8 @@ end function rrule!!(f::CoDual{typeof(Core.arraysize)}, X, dim) return zero_fcodual(Core.arraysize(primal(X), primal(dim))), NoPullback(f, X, dim) end - -@is_primitive MinimalCtx Tuple{typeof(copy), Array} + +@is_primitive MinimalCtx Tuple{typeof(copy),Array} function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array}) dx = tangent(a) dy = copy(dx) @@ -352,10 +357,10 @@ function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array}) return y, copy_pullback!! end -@is_primitive MinimalCtx Tuple{typeof(fill!), Array{<:Union{UInt8, Int8}}, Integer} +@is_primitive MinimalCtx Tuple{typeof(fill!),Array{<:Union{UInt8,Int8}},Integer} function rrule!!( ::CoDual{typeof(fill!)}, a::CoDual{T}, x::CoDual{<:Integer} -) where {V<:Union{UInt8, Int8}, T<:Array{V}} +) where {V<:Union{UInt8,Int8},T<:Array{V}} pa = primal(a) old_value = copy(pa) fill!(pa, primal(x)) @@ -376,15 +381,15 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:array_legacy} test_cases = Any[ # Old foreigncall wrappers. - (true, :stability, nothing, Array{Float64, 0}, undef), - (true, :stability, nothing, Array{Float64, 1}, undef, 5), - (true, :stability, nothing, Array{Float64, 2}, undef, 5, 4), - (true, :stability, nothing, Array{Float64, 3}, undef, 5, 4, 3), - (true, :stability, nothing, Array{Float64, 4}, undef, 5, 4, 3, 2), - (true, :stability, nothing, Array{Float64, 5}, undef, 5, 4, 3, 2, 1), - (true, :stability, nothing, Array{Float64, 0}, undef, ()), - (true, :stability, nothing, Array{Float64, 4}, undef, (2, 3, 4, 5)), - (true, :stability, nothing, Array{Float64, 5}, undef, (2, 3, 4, 5, 6)), + (true, :stability, nothing, Array{Float64,0}, undef), + (true, :stability, nothing, Array{Float64,1}, undef, 5), + (true, :stability, nothing, Array{Float64,2}, undef, 5, 4), + (true, :stability, nothing, Array{Float64,3}, undef, 5, 4, 3), + (true, :stability, nothing, Array{Float64,4}, undef, 5, 4, 3, 2), + (true, :stability, nothing, Array{Float64,5}, undef, 5, 4, 3, 2, 1), + (true, :stability, nothing, Array{Float64,0}, undef, ()), + (true, :stability, nothing, Array{Float64,4}, undef, (2, 3, 4, 5)), + (true, :stability, nothing, Array{Float64,5}, undef, (2, 3, 4, 5, 6)), (false, :stability, nothing, copy, randn(5, 4)), (false, :stability, nothing, Base._deletebeg!, randn(5), 0), (false, :stability, nothing, Base._deletebeg!, randn(5), 2), @@ -403,19 +408,35 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:array_legacy} (false, :stability, nothing, sizehint!, randn(5), 10), (false, :stability, nothing, unsafe_copyto!, randn(4), 2, randn(3), 1, 2), ( - false, :stability, nothing, - unsafe_copyto!, [rand(3) for _ in 1:5], 2, [rand(4) for _ in 1:4], 1, 3, + false, + :stability, + nothing, + unsafe_copyto!, + [rand(3) for _ in 1:5], + 2, + [rand(4) for _ in 1:4], + 1, + 3, ), ( - false, :none, nothing, - unsafe_copyto!, Vector{Any}(undef, 5), 2, Any[rand() for _ in 1:4], 1, 3, + false, + :none, + nothing, + unsafe_copyto!, + Vector{Any}(undef, 5), + 2, + Any[rand() for _ in 1:4], + 1, + 3, ), ( - true, :none, nothing, + true, + :none, + nothing, _foreigncall_, Val(:jl_array_ptr), Val(Ptr{Float64}), - (Val(Any), ), + (Val(Any),), Val(0), # nreq Val(:ccall), # calling convention randn(5), @@ -436,9 +457,27 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:array_legacy} (false, :stability, nothing, Base.arrayset, false, randn(5, 4), 3.0, 1, 3), (false, :stability, nothing, Base.arrayset, true, randn(5), 4.0, 3), (false, :stability, nothing, Base.arrayset, true, randn(5, 4), 3.0, 1, 3), - (false, :stability, nothing, Base.arrayset, false, [randn(3) for _ in 1:5], randn(4), 1), + ( + false, + :stability, + nothing, + Base.arrayset, + false, + [randn(3) for _ in 1:5], + randn(4), + 1, + ), (false, :stability, nothing, Base.arrayset, false, _a, randn(4), 1), - (false, :stability, nothing, Base.arrayset, true, [(5.0, rand(1))], (4.0, rand(1)), 1), + ( + false, + :stability, + nothing, + Base.arrayset, + true, + [(5.0, rand(1))], + (4.0, rand(1)), + 1, + ), ( false, :stability, @@ -467,16 +506,16 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:array_legacy} end function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:array_legacy}) - test_cases = Any[ - ( - false, :none, nothing, - Base._unsafe_copyto!, - fill!(Matrix{Real}(undef, 5, 4), 1.0), - 3, - randn(10), - 2, - 4, - ), - ] + test_cases = Any[( + false, + :none, + nothing, + Base._unsafe_copyto!, + fill!(Matrix{Real}(undef, 5, 4), 1.0), + 3, + randn(10), + 2, + 4, + ),] return test_cases, Any[] end diff --git a/src/rrules/avoiding_non_differentiable_code.jl b/src/rrules/avoiding_non_differentiable_code.jl index 519979390..e37455605 100644 --- a/src/rrules/avoiding_non_differentiable_code.jl +++ b/src/rrules/avoiding_non_differentiable_code.jl @@ -1,18 +1,18 @@ # Avoid troublesome bitcast magic -- we can't handle converting from pointer to UInt, # because we drop the gradient, because the tangent type of integers is NoTangent. # https://github.com/JuliaLang/julia/blob/9f9e989f241fad1ae03c3920c20a93d8017a5b8f/base/pointer.jl#L282 -@is_primitive MinimalCtx Tuple{typeof(Base.:(+)), Ptr, Integer} +@is_primitive MinimalCtx Tuple{typeof(Base.:(+)),Ptr,Integer} function rrule!!(f::CoDual{typeof(Base.:(+))}, x::CoDual{<:Ptr}, y::CoDual{<:Integer}) return CoDual(primal(x) + primal(y), tangent(x) + primal(y)), NoPullback(f, x, y) end -@zero_adjoint MinimalCtx Tuple{typeof(randn), AbstractRNG, Vararg} -@zero_adjoint MinimalCtx Tuple{typeof(string), Vararg} -@zero_adjoint MinimalCtx Tuple{Type{Symbol}, Vararg} -@zero_adjoint MinimalCtx Tuple{Type{Float64}, Any, RoundingMode} -@zero_adjoint MinimalCtx Tuple{Type{Float32}, Any, RoundingMode} -@zero_adjoint MinimalCtx Tuple{Type{Float16}, Any, RoundingMode} -@zero_adjoint MinimalCtx Tuple{typeof(==), Type, Type} +@zero_adjoint MinimalCtx Tuple{typeof(randn),AbstractRNG,Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(string),Vararg} +@zero_adjoint MinimalCtx Tuple{Type{Symbol},Vararg} +@zero_adjoint MinimalCtx Tuple{Type{Float64},Any,RoundingMode} +@zero_adjoint MinimalCtx Tuple{Type{Float32},Any,RoundingMode} +@zero_adjoint MinimalCtx Tuple{Type{Float16},Any,RoundingMode} +@zero_adjoint MinimalCtx Tuple{typeof(==),Type,Type} function generate_hand_written_rrule!!_test_cases( rng_ctor, ::Val{:avoiding_non_differentiable_code} @@ -21,17 +21,18 @@ function generate_hand_written_rrule!!_test_cases( _dx = Ref(4.0) test_cases = vcat( Any[ - # Rules to avoid pointer type conversions. - ( - true, :stability_and_allocs, nothing, - +, - CoDual( - bitcast(Ptr{Float64}, pointer_from_objref(_x)), - bitcast(Ptr{Float64}, pointer_from_objref(_dx)), - ), - 2, + # Rules to avoid pointer type conversions. + ( + true, + :stability_and_allocs, + nothing, + +, + CoDual( + bitcast(Ptr{Float64}, pointer_from_objref(_x)), + bitcast(Ptr{Float64}, pointer_from_objref(_dx)), ), - ], + 2, + ),], # Rules in order to avoid introducing determinism. reduce( @@ -63,7 +64,7 @@ function generate_hand_written_rrule!!_test_cases( end function generate_derived_rrule!!_test_cases( - rng_ctor, ::Val{:avoiding_non_differentiable_code}, + rng_ctor, ::Val{:avoiding_non_differentiable_code} ) return Any[], Any[] end diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index abcd8131f..cd2d96466 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -3,7 +3,7 @@ function blas_name(name::Symbol) end function wrap_ptr_as_view(ptr::Ptr{T}, N::Int, inc::Int) where {T} - return view(unsafe_wrap(Vector{T}, ptr, N * inc), 1:inc:N*inc) + return view(unsafe_wrap(Vector{T}, ptr, N * inc), 1:inc:(N * inc)) end function wrap_ptr_as_view(ptr::Ptr{T}, buffer_nrows::Int, nrows::Int, ncols::Int) where {T} @@ -21,7 +21,7 @@ function tri!(A, u::Char, d::Char) return u == 'L' ? tril!(A, d == 'U' ? -1 : 0) : triu!(A, d == 'U' ? 1 : 0) end -const MatrixOrView{T} = Union{Matrix{T}, SubArray{T, 2, Matrix{T}}} +const MatrixOrView{T} = Union{Matrix{T},SubArray{T,2,Matrix{T}}} # # Utility @@ -29,8 +29,8 @@ const MatrixOrView{T} = Union{Matrix{T}, SubArray{T, 2, Matrix{T}}} @zero_adjoint MinimalCtx Tuple{typeof(BLAS.get_num_threads)} @zero_adjoint MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} -@zero_adjoint MinimalCtx Tuple{typeof(BLAS.set_num_threads), Union{Integer, Nothing}} -@zero_adjoint MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads), Any} +@zero_adjoint MinimalCtx Tuple{typeof(BLAS.set_num_threads),Union{Integer,Nothing}} +@zero_adjoint MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads),Any} # # LEVEL 1 @@ -49,13 +49,13 @@ for (fname, elty) in ((:cblas_ddot, :Float64), (:cblas_sdot, :Float32)) _incx::CoDual{BLAS.BlasInt}, _DY::CoDual{Ptr{$elty}}, _incy::CoDual{BLAS.BlasInt}, - args::Vararg{Any, N}, + args::Vararg{Any,N}, ) where {N} GC.@preserve args begin # Load in values from pointers. n, incx, incy = map(primal, (_n, _incx, _incy)) - xinds = 1:incx:incx * n - yinds = 1:incy:incy * n + xinds = 1:incx:(incx * n) + yinds = 1:incy:(incy * n) DX = view(unsafe_wrap(Vector{$elty}, primal(_DX), n * incx), xinds) DY = view(unsafe_wrap(Vector{$elty}, primal(_DY), n * incy), yinds) @@ -90,9 +90,8 @@ for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32)) DA::CoDual{Ptr{$elty}}, DX::CoDual{Ptr{$elty}}, incx::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any, N}, + args::Vararg{Any,N}, ) where {N} - GC.@preserve args begin # Load in values from pointers, and turn pointers to memory buffers into Vectors. @@ -111,7 +110,6 @@ for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32)) end function dscal_pullback!!(::NoRData) - GC.@preserve args begin # Set primal to previous state. @@ -130,8 +128,6 @@ for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32)) end end - - # # LEVEL 2 # @@ -155,9 +151,8 @@ for (gemv, elty) in ((:dgemv_, :Float64), (:sgemv_, :Float32)) _beta::CoDual{Ptr{$elty}}, _y::CoDual{Ptr{$elty}}, _incy::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any, Nargs} + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin # Load in data. @@ -170,8 +165,8 @@ for (gemv, elty) in ((:dgemv_, :Float64), (:sgemv_, :Float32)) A = wrap_ptr_as_view(primal(_A), lda, M, N) Nx = tA == 'N' ? N : M Ny = tA == 'N' ? M : N - x = view(unsafe_wrap(Vector{$elty}, primal(_x), incx * Nx), 1:incx:incx * Nx) - y = view(unsafe_wrap(Vector{$elty}, primal(_y), incy * Ny), 1:incy:incy * Ny) + x = view(unsafe_wrap(Vector{$elty}, primal(_x), incx * Nx), 1:incx:(incx * Nx)) + y = view(unsafe_wrap(Vector{$elty}, primal(_y), incy * Ny), 1:incy:(incy * Ny)) y_copy = copy(y) BLAS.gemv!(tA, alpha, A, x, beta, y) @@ -184,13 +179,12 @@ for (gemv, elty) in ((:dgemv_, :Float64), (:sgemv_, :Float32)) end function gemv_pb!!(::NoRData) - GC.@preserve args begin # Load up the tangents. dA = wrap_ptr_as_view(_dA, lda, M, N) - dx = view(unsafe_wrap(Vector{$elty}, _dx, incx * Nx), 1:incx:incx * Nx) - dy = view(unsafe_wrap(Vector{$elty}, _dy, incy * Ny), 1:incy:incy * Ny) + dx = view(unsafe_wrap(Vector{$elty}, _dx, incx * Nx), 1:incx:(incx * Nx)) + dy = view(unsafe_wrap(Vector{$elty}, _dy, incy * Ny), 1:incy:(incy * Ny)) # Increment the tangents. unsafe_store!(dalpha, unsafe_load(dalpha) + dot(dy, _trans(tA, A), x)) @@ -212,14 +206,8 @@ end @is_primitive( MinimalCtx, Tuple{ - typeof(BLAS.symv!), - Char, - T, - MatrixOrView{T}, - Vector{T}, - T, - Vector{T}, - } where {T<:Union{Float32, Float64}}, + typeof(BLAS.symv!),Char,T,MatrixOrView{T},Vector{T},T,Vector{T} + } where {T<:Union{Float32,Float64}}, ) function rrule!!( @@ -230,7 +218,7 @@ function rrule!!( x_dx::CoDual{Vector{T}}, beta::CoDual{T}, y_dy::CoDual{Vector{T}}, -) where {T<:Union{Float32, Float64}} +) where {T<:Union{Float32,Float64}} # Extract primals. ul = primal(uplo) @@ -254,7 +242,6 @@ function rrule!!( end function symv!_adjoint(::NoRData) - if (α == 1 && β == 0) dα = dot(dy, y) BLAS.copyto!(y, y_copy) @@ -310,9 +297,8 @@ for (trmv, elty) in ((:dtrmv_, :Float64), (:strmv_, :Float32)) _lda::CoDual{Ptr{BLAS.BlasInt}}, _x::CoDual{Ptr{$elty}}, _incx::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin # Load in data. uplo, trans, diag = map(Char ∘ unsafe_load ∘ primal, (_uplo, _trans, _diag)) @@ -329,7 +315,6 @@ for (trmv, elty) in ((:dtrmv_, :Float64), (:strmv_, :Float32)) end function trmv_pb!!(::NoRData) - GC.@preserve args begin # Load up the tangents. @@ -350,8 +335,6 @@ for (trmv, elty) in ((:dtrmv_, :Float64), (:strmv_, :Float32)) end end - - # # LEVEL 3 # @@ -359,15 +342,8 @@ end @is_primitive( MinimalCtx, Tuple{ - typeof(BLAS.gemm!), - Char, - Char, - T, - MatrixOrView{T}, - MatrixOrView{T}, - T, - Matrix{T}, - } where {T<:Union{Float32, Float64}}, + typeof(BLAS.gemm!),Char,Char,T,MatrixOrView{T},MatrixOrView{T},T,Matrix{T} + } where {T<:Union{Float32,Float64}}, ) function rrule!!( @@ -379,7 +355,7 @@ function rrule!!( B::CoDual{<:MatrixOrView{T}}, beta::CoDual{T}, C::CoDual{Matrix{T}}, -) where {T<:Union{Float32, Float64}} +) where {T<:Union{Float32,Float64}} tA = primal(transA) tB = primal(transB) a = primal(alpha) @@ -458,9 +434,8 @@ for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) beta::CoDual{Ptr{$elty}}, C::CoDual{Ptr{$elty}}, LDC::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin _tA = Char(unsafe_load(primal(tA))) _tB = Char(unsafe_load(primal(tB))) @@ -476,8 +451,12 @@ for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) _C = primal(C) _LDC = unsafe_load(primal(LDC)) - A_mat = wrap_ptr_as_view(primal(A), _LDA, (_tA == 'N' ? (_m, _ka) : (_ka, _m))...) - B_mat = wrap_ptr_as_view(primal(B), _LDB, (_tB == 'N' ? (_ka, _n) : (_n, _ka))...) + A_mat = wrap_ptr_as_view( + primal(A), _LDA, (_tA == 'N' ? (_m, _ka) : (_ka, _m))... + ) + B_mat = wrap_ptr_as_view( + primal(B), _LDB, (_tB == 'N' ? (_ka, _n) : (_n, _ka))... + ) C_mat = wrap_ptr_as_view(primal(C), _LDC, _m, _n) C_copy = collect(C_mat) @@ -491,7 +470,6 @@ for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) end function gemm!_pullback!!(::NoRData) - GC.@preserve args begin # Restore previous state. C_mat .= C_copy @@ -505,8 +483,10 @@ for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) unsafe_store!(dbeta, unsafe_load(dbeta) + tr(dC_mat' * C_mat)) dalpha_inc = tr(dC_mat' * _trans(_tA, A_mat) * _trans(_tB, B_mat)) unsafe_store!(dalpha, unsafe_load(dalpha) + dalpha_inc) - dA_mat .+= _alpha * transpose(_trans(_tA, _trans(_tB, B_mat) * transpose(dC_mat))) - dB_mat .+= _alpha * transpose(_trans(_tB, transpose(dC_mat) * _trans(_tA, A_mat))) + dA_mat .+= + _alpha * transpose(_trans(_tA, _trans(_tB, B_mat) * transpose(dC_mat))) + dB_mat .+= + _alpha * transpose(_trans(_tB, transpose(dC_mat) * _trans(_tA, A_mat))) dC_mat .*= _beta end @@ -519,15 +499,8 @@ end @is_primitive( MinimalCtx, Tuple{ - typeof(BLAS.symm!), - Char, - Char, - T, - MatrixOrView{T}, - MatrixOrView{T}, - T, - Matrix{T}, - } where {T<:Union{Float32, Float64}}, + typeof(BLAS.symm!),Char,Char,T,MatrixOrView{T},MatrixOrView{T},T,Matrix{T} + } where {T<:Union{Float32,Float64}}, ) function rrule!!( @@ -539,7 +512,7 @@ function rrule!!( B_dB::CoDual{<:MatrixOrView{T}}, beta::CoDual{T}, C_dC::CoDual{Matrix{T}}, -) where {T<:Union{Float32, Float64}} +) where {T<:Union{Float32,Float64}} # Extract primals. s = primal(side) @@ -564,7 +537,6 @@ function rrule!!( end function symm!_adjoint(::NoRData) - if (α == 1 && β == 0) dα = dot(dC, C) BLAS.copyto!(C, C_copy) @@ -622,7 +594,7 @@ for (syrk, elty) in ((:dsyrk_, :Float64), (:ssyrk_, :Float32)) beta::CoDual{Ptr{$elty}}, C::CoDual{Ptr{$elty}}, LDC::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} GC.@preserve args begin _uplo = Char(unsafe_load(primal(uplo))) @@ -649,7 +621,6 @@ for (syrk, elty) in ((:dsyrk_, :Float64), (:ssyrk_, :Float32)) end function syrk!_pullback!!(::NoRData) - GC.@preserve args begin # Restore previous state. C_mat .= C_copy @@ -664,7 +635,8 @@ for (syrk, elty) in ((:dsyrk_, :Float64), (:ssyrk_, :Float32)) dalpha_inc = tr(B' * _trans(_t, A_mat) * _trans(_t, A_mat)') unsafe_store!(dalpha, unsafe_load(dalpha) + dalpha_inc) dA_mat .+= _alpha * (_t == 'N' ? (B + B') * A_mat : A_mat * (B + B')) - dC_mat .= (_uplo == 'U' ? tril!(dC_mat, -1) : triu!(dC_mat, 1)) .+ _beta .* B + dC_mat .= + (_uplo == 'U' ? tril!(dC_mat, -1) : triu!(dC_mat, 1)) .+ _beta .* B end return tuple_fill(NoRData(), Val(16 + Nargs)) @@ -692,13 +664,14 @@ for (trmm, elty) in ((:dtrmm_, :Float64), (:strmm_, :Float32)) _lda::CoDual{Ptr{BLAS.BlasInt}}, _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin # Load in data and store B for the reverse-pass. - side, ul, tA, diag = map(Char ∘ unsafe_load ∘ primal, (_side, _uplo, _trans, _diag)) + side, ul, tA, diag = map( + Char ∘ unsafe_load ∘ primal, (_side, _uplo, _trans, _diag) + ) M, N, lda, ldb = map(unsafe_load ∘ primal, (_M, _N, _lda, _ldb)) alpha = unsafe_load(primal(_alpha)) R = side == 'L' ? M : N @@ -715,7 +688,6 @@ for (trmm, elty) in ((:dtrmm_, :Float64), (:strmm_, :Float32)) end function trmm!_pullback!!(::NoRData) - GC.@preserve args begin # Convert pointers to views. dA = wrap_ptr_as_view(_dA, lda, R, R) @@ -764,9 +736,8 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32)) _lda::CoDual{Ptr{BLAS.BlasInt}}, _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin side = Char(unsafe_load(primal(_side))) uplo = Char(unsafe_load(primal(_uplo))) @@ -790,7 +761,6 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32)) end function trsm_pb!!(::NoRData) - GC.@preserve args begin # Convert pointers to views. dA = wrap_ptr_as_view(_dA, lda, R, R) @@ -839,59 +809,97 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) test_cases = vcat( # symv! - vec(reduce( - vcat, - vec(map(product(['L', 'U'], alphas, betas)) do (uplo, α, β) - A = randn(5, 5) - vA = view(randn(15, 15), 1:5, 1:5) - x = randn(5) - y = randn(5) - return Any[ - (false, :stability, nothing, BLAS.symv!, uplo, α, A, x, β, y), - (false, :stability, nothing, BLAS.symv!, uplo, α, vA, x, β, y), - ] - end) - )), + vec( + reduce( + vcat, + vec( + map(product(['L', 'U'], alphas, betas)) do (uplo, α, β) + A = randn(5, 5) + vA = view(randn(15, 15), 1:5, 1:5) + x = randn(5) + y = randn(5) + return Any[ + (false, :stability, nothing, BLAS.symv!, uplo, α, A, x, β, y), + (false, :stability, nothing, BLAS.symv!, uplo, α, vA, x, β, y), + ] + end, + ), + ), + ), # gemm! - vec(reduce( - vcat, - vec(map(product(t_flags, t_flags, alphas, betas)) do (tA, tB, a, b) - A = tA == 'N' ? randn(3, 4) : randn(4, 3) - B = tB == 'N' ? randn(4, 5) : randn(5, 4) - As = if tA == 'N' - [randn(3, 4), view(randn(15, 15), 2:4, 3:6)] - else - [randn(4, 3), view(randn(15, 15), 2:5, 3:5)] - end - Bs = if tB == 'N' - [randn(4, 5), view(randn(15, 15), 1:4, 2:6)] - else - [randn(5, 4), view(randn(15, 15), 1:5, 3:6)] - end - C = randn(3, 5) - return map(product(As, Bs)) do (A, B) - (false, :stability, nothing, BLAS.gemm!, tA, tB, a, A, B, b, C) - end - end), - )), + vec( + reduce( + vcat, + vec( + map(product(t_flags, t_flags, alphas, betas)) do (tA, tB, a, b) + A = tA == 'N' ? randn(3, 4) : randn(4, 3) + B = tB == 'N' ? randn(4, 5) : randn(5, 4) + As = if tA == 'N' + [randn(3, 4), view(randn(15, 15), 2:4, 3:6)] + else + [randn(4, 3), view(randn(15, 15), 2:5, 3:5)] + end + Bs = if tB == 'N' + [randn(4, 5), view(randn(15, 15), 1:4, 2:6)] + else + [randn(5, 4), view(randn(15, 15), 1:5, 3:6)] + end + C = randn(3, 5) + return map(product(As, Bs)) do (A, B) + (false, :stability, nothing, BLAS.gemm!, tA, tB, a, A, B, b, C) + end + end, + ), + ), + ), # symm! - vec(reduce( - vcat, - vec(map(product(['L', 'R'], ['L', 'U'], alphas, betas)) do (side, uplo, α, β) - nA = side == 'L' ? 5 : 7 - A = randn(nA, nA) - vA = view(randn(15, 15), 1:nA, 1:nA) - B = randn(5, 7) - vB = view(randn(15, 15), 1:5, 1:7) - C = randn(5, 7) - return Any[ - (false, :stability, nothing, BLAS.symm!, side, uplo, α, A, B, β, C), - (false, :stability, nothing, BLAS.symm!, side, uplo, α, vA, vB, β, C), - ] - end) - )), + vec( + reduce( + vcat, + vec( + map( + product(['L', 'R'], ['L', 'U'], alphas, betas) + ) do (side, uplo, α, β) + nA = side == 'L' ? 5 : 7 + A = randn(nA, nA) + vA = view(randn(15, 15), 1:nA, 1:nA) + B = randn(5, 7) + vB = view(randn(15, 15), 1:5, 1:7) + C = randn(5, 7) + return Any[ + ( + false, + :stability, + nothing, + BLAS.symm!, + side, + uplo, + α, + A, + B, + β, + C, + ), + ( + false, + :stability, + nothing, + BLAS.symm!, + side, + uplo, + α, + vA, + vB, + β, + C, + ), + ] + end, + ), + ), + ), ) memory = Any[] @@ -927,85 +935,105 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) # # gemv! - vec(reduce( - vcat, - map(product(t_flags, [1, 3], [1, 2])) do (tA, M, N) - t = tA == 'N' - As = [ - t ? randn(M, N) : randn(N, M), - view(randn(15, 15), t ? (3:M+2) : (2:N+1), t ? (2:N+1) : (3:M+2)), - ] - xs = [randn(N), view(randn(15), 3:N+2), view(randn(30), 1:2:2N)] - ys = [randn(M), view(randn(15), 2:M+1), view(randn(30), 2:2:2M)] - return map(Iterators.product(As, xs, ys)) do (A, x, y) - (false, :none, nothing, BLAS.gemv!, tA, randn(), A, x, randn(), y) - end - end, - )), + vec( + reduce( + vcat, + map(product(t_flags, [1, 3], [1, 2])) do (tA, M, N) + t = tA == 'N' + As = [ + t ? randn(M, N) : randn(N, M), + view( + randn(15, 15), + t ? (3:(M + 2)) : (2:(N + 1)), + t ? (2:(N + 1)) : (3:(M + 2)), + ), + ] + xs = [randn(N), view(randn(15), 3:(N + 2)), view(randn(30), 1:2:(2N))] + ys = [randn(M), view(randn(15), 2:(M + 1)), view(randn(30), 2:2:(2M))] + return map(Iterators.product(As, xs, ys)) do (A, x, y) + (false, :none, nothing, BLAS.gemv!, tA, randn(), A, x, randn(), y) + end + end, + ), + ), # trmv! - vec(reduce( - vcat, - map(product(['L', 'U'], t_flags, ['N', 'U'], [1, 3])) do (ul, tA, dA, N) - As = [randn(N, N), view(randn(15, 15), 3:N+2, 4:N+3)] - bs = [randn(N), view(randn(14), 4:N+3)] - return map(product(As, bs)) do (A, b) - (false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b) - end - end, - )), + vec( + reduce( + vcat, + map(product(['L', 'U'], t_flags, ['N', 'U'], [1, 3])) do (ul, tA, dA, N) + As = [randn(N, N), view(randn(15, 15), 3:(N + 2), 4:(N + 3))] + bs = [randn(N), view(randn(14), 4:(N + 3))] + return map(product(As, bs)) do (A, b) + (false, :none, nothing, BLAS.trmv!, ul, tA, dA, A, b) + end + end, + ), + ), # # BLAS LEVEL 3 # # aliased gemm! - vec(map(product(t_flags, t_flags)) do (tA, tB) - A = randn(5, 5) - B = randn(5, 5) - (false, :none, nothing, aliased_gemm!, tA, tB, randn(), randn(), A, B) - end), + vec( + map(product(t_flags, t_flags)) do (tA, tB) + A = randn(5, 5) + B = randn(5, 5) + (false, :none, nothing, aliased_gemm!, tA, tB, randn(), randn(), A, B) + end, + ), # syrk! - vec(map(product(['U', 'L'], t_flags)) do (uplo, t) - A = t == 'N' ? randn(3, 4) : randn(4, 3) - C = randn(3, 3) - Any[false, :none, nothing, BLAS.syrk!, uplo, t, randn(), A, randn(), C] - end), + vec( + map(product(['U', 'L'], t_flags)) do (uplo, t) + A = t == 'N' ? randn(3, 4) : randn(4, 3) + C = randn(3, 3) + Any[false, :none, nothing, BLAS.syrk!, uplo, t, randn(), A, randn(), C] + end, + ), # trmm! - vec(reduce( - vcat, - map( - product(['L', 'R'], ['U', 'L'], t_flags, ['N', 'U'], [1, 3], [1, 2]), - ) do (side, ul, tA, dA, M, N) - t = tA == 'N' - R = side == 'L' ? M : N - As = [randn(R, R), view(randn(15, 15), 3:R+2, 4:R+3)] - Bs = [randn(M, N), view(randn(15, 15), 2:M+1, 5:N+4)] - return map(product(As, Bs)) do (A, B) - alpha = randn() - Any[false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, alpha, A, B] - end - end, - )), + vec( + reduce( + vcat, + map( + product(['L', 'R'], ['U', 'L'], t_flags, ['N', 'U'], [1, 3], [1, 2]) + ) do (side, ul, tA, dA, M, N) + t = tA == 'N' + R = side == 'L' ? M : N + As = [randn(R, R), view(randn(15, 15), 3:(R + 2), 4:(R + 3))] + Bs = [randn(M, N), view(randn(15, 15), 2:(M + 1), 5:(N + 4))] + return map(product(As, Bs)) do (A, B) + alpha = randn() + Any[ + false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, alpha, A, B + ] + end + end, + ), + ), # trsm! - vec(reduce( - vcat, - map( - product(['L', 'R'], ['U', 'L'], t_flags, ['N', 'U'], [1, 3], [1, 2]), - ) do (side, ul, tA, dA, M, N) - t = tA == 'N' - R = side == 'L' ? M : N - As = [randn(R, R) + 5I, view(randn(15, 15), 3:R+2, 4:R+3) + 5I] - Bs = [randn(M, N), view(randn(15, 15), 2:M+1, 5:N+4)] - return map(product(As, Bs)) do (A, B) - alpha = randn() - Any[false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, alpha, A, B] - end - end, - )), + vec( + reduce( + vcat, + map( + product(['L', 'R'], ['U', 'L'], t_flags, ['N', 'U'], [1, 3], [1, 2]) + ) do (side, ul, tA, dA, M, N) + t = tA == 'N' + R = side == 'L' ? M : N + As = [randn(R, R) + 5I, view(randn(15, 15), 3:(R + 2), 4:(R + 3)) + 5I] + Bs = [randn(M, N), view(randn(15, 15), 2:(M + 1), 5:(N + 4))] + return map(product(As, Bs)) do (A, B) + alpha = randn() + Any[ + false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, alpha, A, B + ] + end + end, + ), + ), ) memory = Any[] return test_cases, memory diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index c862aaefe..66e055d57 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -9,8 +9,7 @@ # As of version 1.9.2 of Julia, there are exactly 139 examples of `Core.Builtin`s. # - -@is_primitive MinimalCtx Tuple{Core.Builtin, Vararg} +@is_primitive MinimalCtx Tuple{Core.Builtin,Vararg} struct MissingRuleForBuiltinException <: Exception msg::String @@ -18,26 +17,28 @@ end function rrule!!(f::CoDual{<:Core.Builtin}, args...) T_args = map(typeof ∘ primal, args) - throw(MissingRuleForBuiltinException( - "All built-in functions are primitives by default, as they do not have any Julia " * - "code to recurse into. This means that they must all have methods of `rrule!!` " * - "written for them by hand. " * - "The built-in $(primal(f)) has been called with arguments with types $T_args, " * - "but there is no specialised method of `rrule!!` for this built-in and these " * - "types. In order to fix this problem, you will either need to modify your code " * - "to avoid hitting this built-in function, or implement a method of `rrule!!` " * - "which is specialised to this case. " * - "Either way, please consider commenting on " * - "https://github.com/compintell/Mooncake.jl/issues/208/ so that the issue can be " * - "fixed more widely.\n" * - "For reproducibility, note that the full signature is:\n" * - "$(typeof((f, args...)))" - )) + throw( + MissingRuleForBuiltinException( + "All built-in functions are primitives by default, as they do not have any Julia " * + "code to recurse into. This means that they must all have methods of `rrule!!` " * + "written for them by hand. " * + "The built-in $(primal(f)) has been called with arguments with types $T_args, " * + "but there is no specialised method of `rrule!!` for this built-in and these " * + "types. In order to fix this problem, you will either need to modify your code " * + "to avoid hitting this built-in function, or implement a method of `rrule!!` " * + "which is specialised to this case. " * + "Either way, please consider commenting on " * + "https://github.com/compintell/Mooncake.jl/issues/208/ so that the issue can be " * + "fixed more widely.\n" * + "For reproducibility, note that the full signature is:\n" * + "$(typeof((f, args...)))", + ), + ) end function Base.showerror(io::IO, err::MissingRuleForBuiltinException) print(io, "MissingRuleForBuiltinException: ") - println(io, err.msg) + return println(io, err.msg) end module IntrinsicsWrappers @@ -46,9 +47,26 @@ using Base: IEEEFloat using Core: Intrinsics using Mooncake import ..Mooncake: - rrule!!, CoDual, primal, tangent, zero_tangent, NoPullback, - tangent_type, increment!!, @is_primitive, MinimalCtx, is_primitive, NoFData, - zero_rdata, NoRData, tuple_map, fdata, NoRData, rdata, increment_rdata!!, zero_fcodual + rrule!!, + CoDual, + primal, + tangent, + zero_tangent, + NoPullback, + tangent_type, + increment!!, + @is_primitive, + MinimalCtx, + is_primitive, + NoFData, + zero_rdata, + NoRData, + tuple_map, + fdata, + NoRData, + rdata, + increment_rdata!!, + zero_fcodual using Core.Intrinsics: atomic_pointerref @@ -57,7 +75,8 @@ struct MissingIntrinsicWrapperException <: Exception end function translate(f) - msg = "Unable to translate the intrinsic $f into a regular Julia function. " * + msg = + "Unable to translate the intrinsic $f into a regular Julia function. " * "Please see github.com/compintell/Mooncake.jl/issues/208 for more discussion." throw(MissingIntrinsicWrapperException(msg)) end @@ -70,7 +89,7 @@ end macro intrinsic(name) expr = quote $name(x...) = Intrinsics.$name(x...) - (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name), Vararg}}) = true + (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name),Vararg}}) = true translate(::Val{Intrinsics.$name}) = $name end return esc(expr) @@ -79,9 +98,9 @@ end macro inactive_intrinsic(name) expr = quote $name(x...) = Intrinsics.$name(x...) - (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name), Vararg}}) = true + (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name),Vararg}}) = true translate(::Val{Intrinsics.$name}) = $name - function rrule!!(f::CoDual{typeof($name)}, args::Vararg{Any, N}) where {N} + function rrule!!(f::CoDual{typeof($name)}, args::Vararg{Any,N}) where {N} return Mooncake.zero_adjoint(f, args...) end end @@ -147,11 +166,12 @@ end @intrinsic bitcast function rrule!!(f::CoDual{typeof(bitcast)}, t::CoDual{Type{T}}, x) where {T} if T <: IEEEFloat - msg = "It is not permissible to bitcast to a differentiable type during AD, as " * - "this risks dropping tangents, and therefore risks silently giving the wrong " * - "answer. If this call to bitcast appears as part of the implementation of a " * - "differentiable function, you should write a rule for this function, or modify " * - "its implementation to avoid the bitcast." + msg = + "It is not permissible to bitcast to a differentiable type during AD, as " * + "this risks dropping tangents, and therefore risks silently giving the wrong " * + "answer. If this call to bitcast appears as part of the implementation of a " * + "differentiable function, you should write a rule for this function, or modify " * + "its implementation to avoid the bitcast." throw(ArgumentError(msg)) end _x = primal(x) @@ -184,10 +204,10 @@ named `Mooncake.IntrinsicsWrappers.__cglobal`, rather than If you examine the code associated with `Mooncake.intrinsic_to_function`, you will see that special handling of `cglobal` is used. """ -__cglobal(::Val{s}, x::Vararg{Any, N}) where {s, N} = cglobal(s, x...) +__cglobal(::Val{s}, x::Vararg{Any,N}) where {s,N} = cglobal(s, x...) translate(::Val{Intrinsics.cglobal}) = __cglobal -Mooncake.is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(__cglobal), Vararg}}) = true +Mooncake.is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(__cglobal),Vararg}}) = true function rrule!!(f::CoDual{typeof(__cglobal)}, args...) return Mooncake.uninit_fcodual(__cglobal(map(primal, args)...)), NoPullback(f, args...) end @@ -253,7 +273,7 @@ end @intrinsic fpext function rrule!!( ::CoDual{typeof(fpext)}, ::CoDual{Type{Pext}}, x::CoDual{P} -) where {Pext<:IEEEFloat, P<:IEEEFloat} +) where {Pext<:IEEEFloat,P<:IEEEFloat} fpext_adjoint!!(dy::Pext) = NoRData(), NoRData(), fptrunc(P, dy) return zero_fcodual(fpext(Pext, primal(x))), fpext_adjoint!! end @@ -265,7 +285,7 @@ end @intrinsic fptrunc function rrule!!( ::CoDual{typeof(fptrunc)}, ::CoDual{Type{Ptrunc}}, x::CoDual{P} -) where {Ptrunc<:IEEEFloat, P<:IEEEFloat} +) where {Ptrunc<:IEEEFloat,P<:IEEEFloat} fptrunc_adjoint!!(dy::Ptrunc) = NoRData(), NoRData(), convert(P, dy) return zero_fcodual(fptrunc(Ptrunc, primal(x))), fptrunc_adjoint!! end @@ -430,8 +450,8 @@ end end # IntrinsicsWrappers -@zero_adjoint MinimalCtx Tuple{typeof(<:), Any, Any} -@zero_adjoint MinimalCtx Tuple{typeof(===), Any, Any} +@zero_adjoint MinimalCtx Tuple{typeof(<:),Any,Any} +@zero_adjoint MinimalCtx Tuple{typeof(===),Any,Any} # Core._abstracttype @@ -443,7 +463,7 @@ end # IntrinsicsWrappers # a pre-processing step. # A function with the same semantics as `Core._apply_iterate`, but which is differentiable. -function _apply_iterate_equivalent(itr, f::F, args::Vararg{Any, N}) where {F, N} +function _apply_iterate_equivalent(itr, f::F, args::Vararg{Any,N}) where {F,N} vec_args = reduce(vcat, map(collect, args)) tuple_args = __vec_to_tuple(vec_args) return tuple_splat(f, tuple_args) @@ -452,12 +472,12 @@ end # A primitive used to avoid exposing `_apply_iterate_equivalent` to `Core._apply_iterate`. __vec_to_tuple(v::Vector) = Tuple(v) -@is_primitive MinimalCtx Tuple{typeof(__vec_to_tuple), Vector} +@is_primitive MinimalCtx Tuple{typeof(__vec_to_tuple),Vector} function rrule!!(::CoDual{typeof(__vec_to_tuple)}, v::CoDual{<:Vector}) dv = tangent(v) y = CoDual(Tuple(primal(v)), fdata(Tuple(dv))) - function vec_to_tuple_pb!!(dy::Union{Tuple, NoRData}) + function vec_to_tuple_pb!!(dy::Union{Tuple,NoRData}) if dy isa Tuple for n in eachindex(dy) dv[n] = increment_rdata!!(dv[n], dy[n]) @@ -474,7 +494,7 @@ end # Core._call_latest # Doesn't do anything differentiable. -@zero_adjoint MinimalCtx Tuple{typeof(Core._compute_sparams), Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(Core._compute_sparams),Vararg} # Core._equiv_typedef # Core._expr @@ -501,7 +521,7 @@ end function rrule!!(f::CoDual{typeof(Core.apply_type)}, args...) T = Core.apply_type(tuple_map(primal, args)...) - return CoDual{_typeof(T), NoFData}(T, NoFData()), NoPullback(f, args...) + return CoDual{_typeof(T),NoFData}(T, NoFData()), NoPullback(f, args...) end function rrule!!(::CoDual{typeof(compilerbarrier)}, setting::CoDual{Symbol}, val::CoDual) @@ -513,21 +533,22 @@ end # Core.finalizer # Core.get_binding_type -function rrule!!(f::CoDual{typeof(Core.ifelse)}, cond, a::A, b::B) where {A, B} +function rrule!!(f::CoDual{typeof(Core.ifelse)}, cond, a::A, b::B) where {A,B} _cond = primal(cond) p_a = primal(a) p_b = primal(b) - pb!! = if rdata_type(tangent_type(A)) == NoRData && rdata_type(tangent_type(B)) == NoRData - NoPullback(f, cond, a, b) - else - lazy_da = lazy_zero_rdata(p_a) - lazy_db = lazy_zero_rdata(p_b) - function ifelse_pullback!!(dc) - da = ifelse(_cond, dc, instantiate(lazy_da)) - db = ifelse(_cond, instantiate(lazy_db), dc) - return NoRData(), NoRData(), da, db + pb!! = + if rdata_type(tangent_type(A)) == NoRData && rdata_type(tangent_type(B)) == NoRData + NoPullback(f, cond, a, b) + else + lazy_da = lazy_zero_rdata(p_a) + lazy_db = lazy_zero_rdata(p_b) + function ifelse_pullback!!(dc) + da = ifelse(_cond, dc, instantiate(lazy_da)) + db = ifelse(_cond, instantiate(lazy_db), dc) + return NoRData(), NoRData(), da, db + end end - end # It's a good idea to split up applying ifelse to the primal and tangent. This is # because if you push a `CoDual` through ifelse, it _forces_ the construction of the @@ -537,17 +558,17 @@ function rrule!!(f::CoDual{typeof(Core.ifelse)}, cond, a::A, b::B) where {A, B} return CoDual(ifelse(_cond, p_a, p_b), ifelse(_cond, tangent(a), tangent(b))), pb!! end -@zero_adjoint MinimalCtx Tuple{typeof(Core.sizeof), Any} +@zero_adjoint MinimalCtx Tuple{typeof(Core.sizeof),Any} # Core.svec -@zero_adjoint MinimalCtx Tuple{typeof(applicable), Vararg} -@zero_adjoint MinimalCtx Tuple{typeof(fieldtype), Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(applicable),Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(fieldtype),Vararg} -const StandardFDataType = Union{Tuple, NamedTuple, FData, MutableTangent, NoFData} +const StandardFDataType = Union{Tuple,NamedTuple,FData,MutableTangent,NoFData} function rrule!!( - f::CoDual{typeof(getfield)}, x::CoDual{P, <:StandardFDataType}, name::CoDual + f::CoDual{typeof(getfield)}, x::CoDual{P,<:StandardFDataType}, name::CoDual ) where {P} if tangent_type(P) == NoTangent y = uninit_fcodual(getfield(primal(x), primal(name))) @@ -567,8 +588,8 @@ function rrule!!( end function rrule!!( - f::CoDual{typeof(getfield)}, x::CoDual{P, F}, name::CoDual, order::CoDual -) where {P, F<:StandardFDataType} + f::CoDual{typeof(getfield)}, x::CoDual{P,F}, name::CoDual, order::CoDual +) where {P,F<:StandardFDataType} if tangent_type(P) == NoTangent y = uninit_fcodual(getfield(primal(x), primal(name))) return y, NoPullback(f, x, name, order) @@ -605,16 +626,16 @@ is_homogeneous_and_immutable(::Any) = false # return y, pb!! # end -@zero_adjoint MinimalCtx Tuple{typeof(getglobal), Any, Any} +@zero_adjoint MinimalCtx Tuple{typeof(getglobal),Any,Any} # invoke -@zero_adjoint MinimalCtx Tuple{typeof(isa), Any, Any} -@zero_adjoint MinimalCtx Tuple{typeof(isdefined), Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(isa),Any,Any} +@zero_adjoint MinimalCtx Tuple{typeof(isdefined),Vararg} # modifyfield! -@zero_adjoint MinimalCtx Tuple{typeof(nfields), Any} +@zero_adjoint MinimalCtx Tuple{typeof(nfields),Any} # replacefield! @@ -638,7 +659,7 @@ end @inline tuple_pullback(dy::NoRData) = NoRData() -function rrule!!(f::CoDual{typeof(tuple)}, args::Vararg{Any, N}) where {N} +function rrule!!(f::CoDual{typeof(tuple)}, args::Vararg{Any,N}) where {N} primal_output = tuple(map(primal, args)...) if tangent_type(_typeof(primal_output)) == NoTangent return zero_fcodual(primal_output), NoPullback(f, args...) @@ -656,10 +677,9 @@ function rrule!!(::CoDual{typeof(typeassert)}, x::CoDual, type::CoDual) return CoDual(typeassert(primal(x), primal(type)), tangent(x)), typeassert_pullback end -@zero_adjoint MinimalCtx Tuple{typeof(typeof), Any} +@zero_adjoint MinimalCtx Tuple{typeof(typeof),Any} function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) - _x = Ref(5.0) # data used in tests which aren't protected by GC. _dx = Ref(4.0) _a = Vector{Vector{Float64}}(undef, 3) @@ -687,14 +707,26 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :stability, nothing, IntrinsicsWrappers.add_float_fast, 4.0, 5.0), (false, :stability, nothing, IntrinsicsWrappers.add_int, 1, 2), (false, :stability, nothing, IntrinsicsWrappers.and_int, 2, 3), - (false, :stability, nothing, IntrinsicsWrappers.ashr_int, 123456, 0x0000000000000020), + ( + false, + :stability, + nothing, + IntrinsicsWrappers.ashr_int, + 123456, + 0x0000000000000020, + ), # atomic_fence -- NEEDS IMPLEMENTING AND TESTING # atomic_pointermodify -- NEEDS IMPLEMENTING AND TESTING # atomic_pointerref -- NEEDS IMPLEMENTING AND TESTING # atomic_pointerreplace -- NEEDS IMPLEMENTING AND TESTING ( - true, :none, nothing, - IntrinsicsWrappers.atomic_pointerset, CoDual(p, dp), 1.0, :monotonic, + true, + :none, + nothing, + IntrinsicsWrappers.atomic_pointerset, + CoDual(p, dp), + 1.0, + :monotonic, ), # atomic_pointerswap -- NEEDS IMPLEMENTING AND TESTING (false, :stability, nothing, IntrinsicsWrappers.bitcast, Int64, 5.0), @@ -734,7 +766,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :stability, nothing, IntrinsicsWrappers.flipsign_int, 4, -3), (false, :stability, nothing, IntrinsicsWrappers.floor_llvm, 4.1), (false, :stability, nothing, IntrinsicsWrappers.fma_float, 5.0, 4.0, 3.0), - (true, :stability_and_allocs, nothing, IntrinsicsWrappers.fpext, Float64, 5f0), + (true, :stability_and_allocs, nothing, IntrinsicsWrappers.fpext, Float64, 5.0f0), (false, :stability, nothing, IntrinsicsWrappers.fpiseq, 4.1, 4.0), (false, :stability, nothing, IntrinsicsWrappers.fptosi, UInt32, 4.1), (false, :stability, nothing, IntrinsicsWrappers.fptoui, Int32, 4.1), @@ -743,7 +775,14 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :stability, nothing, IntrinsicsWrappers.le_float, 4.1, 4.0), (false, :stability, nothing, IntrinsicsWrappers.le_float_fast, 4.1, 4.0), # llvm_call -- NEEDS IMPLEMENTING AND TESTING - (false, :stability, nothing, IntrinsicsWrappers.lshr_int, 1308622848, 0x0000000000000018), + ( + false, + :stability, + nothing, + IntrinsicsWrappers.lshr_int, + 1308622848, + 0x0000000000000018, + ), (false, :stability, nothing, IntrinsicsWrappers.lt_float, 4.1, 4.0), (false, :stability, nothing, IntrinsicsWrappers.lt_float_fast, 4.1, 4.0), (false, :stability, nothing, IntrinsicsWrappers.mul_float, 5.0, 4.0), @@ -761,14 +800,30 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :stability, nothing, IntrinsicsWrappers.or_int, 5, 5), (true, :stability, nothing, IntrinsicsWrappers.pointerref, CoDual(p, dp), 2, 1), (true, :stability, nothing, IntrinsicsWrappers.pointerref, CoDual(q, dq), 2, 1), - (true, :stability, nothing, IntrinsicsWrappers.pointerset, CoDual(p, dp), 5.0, 2, 1), + ( + true, + :stability, + nothing, + IntrinsicsWrappers.pointerset, + CoDual(p, dp), + 5.0, + 2, + 1, + ), (true, :stability, nothing, IntrinsicsWrappers.pointerset, CoDual(q, dq), 1, 2, 1), # rem_float -- untested and unimplemented because seemingly unused on master # rem_float_fast -- untested and unimplemented because seemingly unused on master (false, :stability, nothing, IntrinsicsWrappers.rint_llvm, 5), (false, :stability, nothing, IntrinsicsWrappers.sdiv_int, 5, 4), (false, :stability, nothing, IntrinsicsWrappers.sext_int, Int64, Int32(1308622848)), - (false, :stability, nothing, IntrinsicsWrappers.shl_int, 1308622848, 0xffffffffffffffe8), + ( + false, + :stability, + nothing, + IntrinsicsWrappers.shl_int, + 1308622848, + 0xffffffffffffffe8, + ), (false, :stability, nothing, IntrinsicsWrappers.sitofp, Float64, 0), (false, :stability, nothing, IntrinsicsWrappers.sle_int, 5, 4), (false, :stability, nothing, IntrinsicsWrappers.slt_int, 4, 5), @@ -850,11 +905,11 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :none, _range, getfield, MutableFoo(5.0, randn(5)), :b), (false, :stability_and_allocs, nothing, getfield, UnitRange{Int}(5:9), :start), (false, :stability_and_allocs, nothing, getfield, UnitRange{Int}(5:9), :stop), - (false, :stability_and_allocs, nothing, getfield, (5.0, ), 1), + (false, :stability_and_allocs, nothing, getfield, (5.0,), 1), (false, :stability_and_allocs, nothing, getfield, (5.0, 4.0), 1), - (false, :stability_and_allocs, nothing, getfield, (5.0, ), 1, false), + (false, :stability_and_allocs, nothing, getfield, (5.0,), 1, false), (false, :stability_and_allocs, nothing, getfield, (5.0, 4.0), 1, false), - (false, :stability_and_allocs, nothing, getfield, (1, ), 1, false), + (false, :stability_and_allocs, nothing, getfield, (1,), 1, false), (false, :stability_and_allocs, nothing, getfield, (1, 2), 1), (false, :stability_and_allocs, nothing, getfield, (a=5, b=4), 1), (false, :stability_and_allocs, nothing, getfield, (a=5, b=4), 2), @@ -892,7 +947,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :stability_and_allocs, nothing, tuple), (false, :stability_and_allocs, nothing, tuple, 1), (false, :stability_and_allocs, nothing, tuple, 1, 5), - (false, :stability_and_allocs, nothing, tuple, 1.0, (5, )), + (false, :stability_and_allocs, nothing, tuple, 1.0, (5,)), (false, :stability, nothing, typeassert, 5.0, Float64), (false, :stability, nothing, typeassert, randn(5), Vector{Float64}), (false, :stability, nothing, typeof, 5.0), @@ -906,37 +961,52 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :none, nothing, _apply_iterate_equivalent, Base.iterate, *, 5.0, 4.0), (false, :none, nothing, _apply_iterate_equivalent, Base.iterate, *, (5.0, 4.0)), (false, :none, nothing, _apply_iterate_equivalent, Base.iterate, *, [5.0, 4.0]), - (false, :none, nothing, _apply_iterate_equivalent, Base.iterate, *, [5.0], (4.0, )), - (false, :none, nothing, _apply_iterate_equivalent, Base.iterate, *, 3, (4.0, )), + (false, :none, nothing, _apply_iterate_equivalent, Base.iterate, *, [5.0], (4.0,)), + (false, :none, nothing, _apply_iterate_equivalent, Base.iterate, *, 3, (4.0,)), ( # 33 arguments is the critical length at which splatting gives up on inferring, # and backs off to `Core._apply_iterate`. It's important to check this in order # to verify that we don't wind up in an infinite recursion. - false, :none, nothing, - _apply_iterate_equivalent, Base.iterate, +, randn(33), + false, + :none, + nothing, + _apply_iterate_equivalent, + Base.iterate, + +, + randn(33), ), ( # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent. - false, :none, nothing, - x -> +(x...), randn(33), + false, + :none, + nothing, + x -> +(x...), + randn(33), ), ( - false, :none, nothing, - ( - function (x) - rx = Ref(x) - pointerref(bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1) - end - ), + false, + :none, + nothing, + (function (x) + rx = Ref(x) + return pointerref(bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1) + end), 5.0, ), ( - false, :none, nothing, - (v, x) -> (pointerset(pointer(x), v, 2, 1); x), 3.0, randn(5), + false, + :none, + nothing, + (v, x) -> (pointerset(pointer(x), v, 2, 1); x), + 3.0, + randn(5), ), ( - false, :none, nothing, - x -> (pointerset(pointer(x), UInt8(3), 2, 1); x), rand(UInt8, 5), + false, + :none, + nothing, + x -> (pointerset(pointer(x), UInt8(3), 2, 1); x), + rand(UInt8, 5), ), (false, :none, nothing, getindex, randn(5), [1, 1]), (false, :none, nothing, getindex, randn(5), [1, 2, 2]), diff --git a/src/rrules/fastmath.jl b/src/rrules/fastmath.jl index 80d014e0e..935c0947f 100644 --- a/src/rrules/fastmath.jl +++ b/src/rrules/fastmath.jl @@ -1,33 +1,39 @@ -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp_fast), IEEEFloat} -function rrule!!(::CoDual{typeof(Base.FastMath.exp_fast)}, x::CoDual{P}) where {P<:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp_fast),IEEEFloat} +function rrule!!( + ::CoDual{typeof(Base.FastMath.exp_fast)}, x::CoDual{P} +) where {P<:IEEEFloat} yp = Base.FastMath.exp_fast(primal(x)) exp_fast_pb!!(dy::P) = NoRData(), dy * yp return CoDual(yp, NoFData()), exp_fast_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp2_fast), IEEEFloat} -function rrule!!(::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{P}) where {P<:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp2_fast),IEEEFloat} +function rrule!!( + ::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{P} +) where {P<:IEEEFloat} yp = Base.FastMath.exp2_fast(primal(x)) exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(2) return CoDual(yp, NoFData()), exp2_fast_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp10_fast), IEEEFloat} -function rrule!!(::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{P}) where {P<:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp10_fast),IEEEFloat} +function rrule!!( + ::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{P} +) where {P<:IEEEFloat} yp = Base.FastMath.exp10_fast(primal(x)) exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(10) return CoDual(yp, NoFData()), exp2_fast_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.sincos), IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.sincos),IEEEFloat} function rrule!!(::CoDual{typeof(Base.FastMath.sincos)}, x::CoDual{P}) where {P<:IEEEFloat} y = Base.FastMath.sincos(primal(x)) - sincos_fast_adj!!(dy::Tuple{P, P}) = NoRData(), dy[1] * y[2] - dy[2] * y[1] + sincos_fast_adj!!(dy::Tuple{P,P}) = NoRData(), dy[1] * y[2] - dy[2] * y[1] return CoDual(y, NoFData()), sincos_fast_adj!! end -@is_primitive MinimalCtx Tuple{typeof(Base.log), Union{IEEEFloat, Int}} -@zero_adjoint MinimalCtx Tuple{typeof(log), Int} +@is_primitive MinimalCtx Tuple{typeof(Base.log),Union{IEEEFloat,Int}} +@zero_adjoint MinimalCtx Tuple{typeof(log),Int} function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:fastmath}) test_cases = Any[ @@ -82,15 +88,25 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:fastmath}) (false, :allocs, nothing, Base.FastMath.lt_fast, 5.0, 0.4), (false, :allocs, nothing, Base.FastMath.max_fast, 5.0, 4.0), ( - false, :none, nothing, - Base.FastMath.maximum!_fast, sin, [0.0, 0.0], [5.0 4.0; 3.0 2.0], + false, + :none, + nothing, + Base.FastMath.maximum!_fast, + sin, + [0.0, 0.0], + [5.0 4.0; 3.0 2.0], ), (false, :allocs, nothing, Base.FastMath.maximum_fast, [5.0, 4.0, 3.0]), (false, :allocs, nothing, Base.FastMath.min_fast, 5.0, 4.0), (false, :allocs, nothing, Base.FastMath.min_fast, 4.0, 5.0), ( - false, :none, nothing, - Base.FastMath.minimum!_fast, sin, [0.0, 0.0], [5.0 4.0; 3.0 2.0], + false, + :none, + nothing, + Base.FastMath.minimum!_fast, + sin, + [0.0, 0.0], + [5.0 4.0; 3.0 2.0], ), (false, :allocs, nothing, Base.FastMath.minimum_fast, [5.0, 3.0, 4.0]), (false, :allocs, nothing, Base.FastMath.minmax_fast, 5.0, 4.0), @@ -113,5 +129,3 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:fastmath}) memory = Any[] return test_cases, memory end - - diff --git a/src/rrules/foreigncall.jl b/src/rrules/foreigncall.jl index 5f042a321..d9cf49416 100644 --- a/src/rrules/foreigncall.jl +++ b/src/rrules/foreigncall.jl @@ -9,18 +9,20 @@ Base.showerror(io::IO, err::MissingForeigncallRuleError) = print(io, err.msg) # creating an informative error message, so that users have some chance of knowing why # they're not able to differentiate a piece of code. function rrule!!(::CoDual{typeof(_foreigncall_)}, args...) - throw(MissingForeigncallRuleError( - "No rrule!! available for foreigncall with primal argument types " * - "$(typeof(map(primal, args))). " * - "This problem has most likely arisen because there is a ccall somewhere in the " * - "function you are trying to differentiate, for which an rrule!! has not been " * - "explicitly written." * - "You have three options: write an rrule!! for this foreigncall, write an rrule!! " * - "for a Julia function that calls this foreigncall, or re-write your code to " * - "avoid this foreigncall entirely. " * - "If you believe that this error has arisen for some other reason than the above, " * - "or the above does not help you to workaround this problem, please open an issue." - )) + throw( + MissingForeigncallRuleError( + "No rrule!! available for foreigncall with primal argument types " * + "$(typeof(map(primal, args))). " * + "This problem has most likely arisen because there is a ccall somewhere in the " * + "function you are trying to differentiate, for which an rrule!! has not been " * + "explicitly written." * + "You have three options: write an rrule!! for this foreigncall, write an rrule!! " * + "for a Julia function that calls this foreigncall, or re-write your code to " * + "avoid this foreigncall entirely. " * + "If you believe that this error has arisen for some other reason than the above, " * + "or the above does not help you to workaround this problem, please open an issue.", + ), + ) end _get_arg_type(::Type{Val{T}}) where {T} = T @@ -46,8 +48,13 @@ Credit: Umlaut.jl has the original implementation of this function. This is larg over from there. """ @generated function _foreigncall_( - ::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x::Vararg{Any, N} -) where {name, RT, nreq, calling_convention, N} + ::Val{name}, + ::Val{RT}, + AT::Tuple, + ::Val{nreq}, + ::Val{calling_convention}, + x::Vararg{Any,N}, +) where {name,RT,nreq,calling_convention,N} return Expr( :foreigncall, QuoteNode(name), @@ -59,17 +66,17 @@ over from there. ) end -@is_primitive MinimalCtx Tuple{typeof(_foreigncall_), Vararg} +@is_primitive MinimalCtx Tuple{typeof(_foreigncall_),Vararg} # # Rules to handle / avoid foreigncall nodes # -@zero_adjoint MinimalCtx Tuple{typeof(Base.allocatedinline), Type} +@zero_adjoint MinimalCtx Tuple{typeof(Base.allocatedinline),Type} -@zero_adjoint MinimalCtx Tuple{typeof(objectid), Any} +@zero_adjoint MinimalCtx Tuple{typeof(objectid),Any} -@is_primitive MinimalCtx Tuple{typeof(pointer_from_objref), Any} +@is_primitive MinimalCtx Tuple{typeof(pointer_from_objref),Any} function rrule!!(f::CoDual{typeof(pointer_from_objref)}, x) y = CoDual( pointer_from_objref(primal(x)), @@ -78,16 +85,16 @@ function rrule!!(f::CoDual{typeof(pointer_from_objref)}, x) return y, NoPullback(f, x) end -@zero_adjoint MinimalCtx Tuple{typeof(CC.return_type), Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(CC.return_type),Vararg} -@is_primitive MinimalCtx Tuple{typeof(Base.unsafe_pointer_to_objref), Ptr} +@is_primitive MinimalCtx Tuple{typeof(Base.unsafe_pointer_to_objref),Ptr} function rrule!!(f::CoDual{typeof(Base.unsafe_pointer_to_objref)}, x::CoDual{<:Ptr}) y = CoDual(unsafe_pointer_to_objref(primal(x)), unsafe_pointer_to_objref(tangent(x))) return y, NoPullback(f, x) end @zero_adjoint MinimalCtx Tuple{typeof(Threads.threadid)} -@zero_adjoint MinimalCtx Tuple{typeof(typeintersect), Any, Any} +@zero_adjoint MinimalCtx Tuple{typeof(typeintersect),Any,Any} function _increment_pointer!(x::Ptr{T}, y::Ptr{T}, N::Integer) where {T} increment!!(unsafe_wrap(Vector{T}, x, N), unsafe_wrap(Vector{T}, y, N)) @@ -97,7 +104,7 @@ end # unsafe_copyto! is the only function in Julia that appears to rely on a ccall to `memmove`. # Since we can't differentiate `memmove` (due to a lack of type information), it is # necessary to work with `unsafe_copyto!` instead. -@is_primitive MinimalCtx Tuple{typeof(unsafe_copyto!), Ptr{T}, Ptr{T}, Any} where {T} +@is_primitive MinimalCtx Tuple{typeof(unsafe_copyto!),Ptr{T},Ptr{T},Any} where {T} function rrule!!( ::CoDual{typeof(unsafe_copyto!)}, dest::CoDual{Ptr{T}}, src::CoDual{Ptr{T}}, n::CoDual ) where {T} @@ -133,18 +140,18 @@ end function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_reshape_array}}, - ::CoDual{Val{Array{P, M}}}, - ::CoDual{Tuple{Val{Any}, Val{Any}, Val{Any}}}, + ::CoDual{Val{Array{P,M}}}, + ::CoDual{Tuple{Val{Any},Val{Any},Val{Any}}}, ::CoDual, # nreq ::CoDual, # calling convention - x::CoDual{Type{Array{P, M}}}, - a::CoDual{Array{P, N}, Array{T, N}}, + x::CoDual{Type{Array{P,M}}}, + a::CoDual{Array{P,N},Array{T,N}}, dims::CoDual, -) where {P, T, M, N} +) where {P,T,M,N} d = primal(dims) y = CoDual( - ccall(:jl_reshape_array, Array{P, M}, (Any, Any, Any), Array{P, M}, primal(a), d), - ccall(:jl_reshape_array, Array{T, M}, (Any, Any, Any), Array{T, M}, tangent(a), d), + ccall(:jl_reshape_array, Array{P,M}, (Any, Any, Any), Array{P,M}, primal(a), d), + ccall(:jl_reshape_array, Array{T,M}, (Any, Any, Any), Array{T,M}, tangent(a), d), ) return y, NoPullback(ntuple(_ -> NoRData(), 9)) end @@ -159,7 +166,7 @@ function rrule!!( a::CoDual{<:Array}, ii::CoDual{UInt}, args..., -) where {RT, AT, nreq, calling_convention} +) where {RT,AT,nreq,calling_convention} GC.@preserve args begin y = ccall(:jl_array_isassigned, Cint, (Any, UInt), primal(a), primal(ii)) end @@ -170,7 +177,7 @@ function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_type_unionall}}, ::CoDual{Val{Any}}, # return type - ::CoDual{Tuple{Val{Any}, Val{Any}}}, # arg types + ::CoDual{Tuple{Val{Any},Val{Any}}}, # arg types ::CoDual{Val{0}}, # number of required args ::CoDual{Val{:ccall}}, a::CoDual, @@ -180,7 +187,7 @@ function rrule!!( return zero_fcodual(y), NoPullback(ntuple(_ -> NoRData(), 8)) end -@is_primitive MinimalCtx Tuple{typeof(deepcopy), Any} +@is_primitive MinimalCtx Tuple{typeof(deepcopy),Any} function rrule!!(::CoDual{typeof(deepcopy)}, x::CoDual) fdx = tangent(x) dx = zero_rdata(primal(x)) @@ -193,13 +200,13 @@ function rrule!!(::CoDual{typeof(deepcopy)}, x::CoDual) return y, deepcopy_pb!! end -@zero_adjoint MinimalCtx Tuple{typeof(fieldoffset), DataType, Integer} -@zero_adjoint MinimalCtx Tuple{Type{UnionAll}, TypeVar, Any} -@zero_adjoint MinimalCtx Tuple{Type{UnionAll}, TypeVar, Type} -@zero_adjoint MinimalCtx Tuple{typeof(hash), Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(fieldoffset),DataType,Integer} +@zero_adjoint MinimalCtx Tuple{Type{UnionAll},TypeVar,Any} +@zero_adjoint MinimalCtx Tuple{Type{UnionAll},TypeVar,Type} +@zero_adjoint MinimalCtx Tuple{typeof(hash),Vararg} function rrule!!( - f::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_string_ptr}}, args::Vararg{CoDual, N} + f::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_string_ptr}}, args::Vararg{CoDual,N} ) where {N} x = tuple_map(primal, args) pb!! = NoPullback((NoRData(), NoRData(), tuple_map(_ -> NoRData(), args)...)) @@ -207,36 +214,54 @@ function rrule!!( end function unexepcted_foreigncall_error(name) - throw(error( - "AD has hit a :($name) ccall. This should not happen. " * - "Please open an issue with a minimal working example in order to reproduce. ", - "This is true unless you have intentionally written a ccall to :$(name), ", - "in which case you must write a :foreigncall rule. It may not be possible ", - "to implement a :foreigncall rule if too much type information has been lost ", - "in which case your only recourse is to write a rule for whichever Julia ", - "function calls this one (and retains enough type information).", - )) + throw( + error( + "AD has hit a :($name) ccall. This should not happen. " * + "Please open an issue with a minimal working example in order to reproduce. ", + "This is true unless you have intentionally written a ccall to :$(name), ", + "in which case you must write a :foreigncall rule. It may not be possible ", + "to implement a :foreigncall rule if too much type information has been lost ", + "in which case your only recourse is to write a rule for whichever Julia ", + "function calls this one (and retains enough type information).", + ), + ) end for name in [ - :(:jl_alloc_array_1d), :(:jl_alloc_array_2d), :(:jl_alloc_array_3d), :(:jl_new_array), - :(:jl_array_grow_end), :(:jl_array_del_end), :(:jl_array_copy), :(:jl_object_id), - :(:jl_type_intersection), :(:memset), :(:jl_get_tls_world_age), :(:memmove), - :(:jl_array_sizehint), :(:jl_array_del_at), :(:jl_array_grow_at), :(:jl_array_del_beg), - :(:jl_array_grow_beg), :(:jl_value_ptr), :(:jl_type_unionall), :(:jl_threadid), - :(:memhash_seed), :(:memhash32_seed), :(:jl_get_field_offset), + :(:jl_alloc_array_1d), + :(:jl_alloc_array_2d), + :(:jl_alloc_array_3d), + :(:jl_new_array), + :(:jl_array_grow_end), + :(:jl_array_del_end), + :(:jl_array_copy), + :(:jl_object_id), + :(:jl_type_intersection), + :(:memset), + :(:jl_get_tls_world_age), + :(:memmove), + :(:jl_array_sizehint), + :(:jl_array_del_at), + :(:jl_array_grow_at), + :(:jl_array_del_beg), + :(:jl_array_grow_beg), + :(:jl_value_ptr), + :(:jl_type_unionall), + :(:jl_threadid), + :(:memhash_seed), + :(:memhash32_seed), + :(:jl_get_field_offset), ] @eval function _foreigncall_( - ::Val{$name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x..., - ) where {RT, nreq, calling_convention} - unexepcted_foreigncall_error($name) + ::Val{$name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x... + ) where {RT,nreq,calling_convention} + return unexepcted_foreigncall_error($name) end @eval function rrule!!(::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$name}}, args...) - unexepcted_foreigncall_error($name) + return unexepcted_foreigncall_error($name) end end - function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) _x = Ref(5.0) _dx = randn_tangent(Xoshiro(123456), _x) @@ -264,14 +289,22 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) ), (false, :none, nothing, Core.Compiler.return_type, sin, Tuple{Float64}), ( - false, :none, (lb=1e-3, ub=100.0), - Core.Compiler.return_type, Tuple{typeof(sin), Float64}, + false, + :none, + (lb=1e-3, ub=100.0), + Core.Compiler.return_type, + Tuple{typeof(sin),Float64}, ), (false, :stability, nothing, Threads.threadid), (false, :stability, nothing, typeintersect, Float64, Int), ( - true, :stability, nothing, - unsafe_copyto!, CoDual(ptr_a, ptr_da), CoDual(ptr_b, ptr_db), 4, + true, + :stability, + nothing, + unsafe_copyto!, + CoDual(ptr_a, ptr_da), + CoDual(ptr_b, ptr_db), + 4, ), (false, :stability, nothing, deepcopy, 5.0), (false, :stability, nothing, deepcopy, randn(5)), @@ -291,7 +324,6 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) end function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) - _x = Ref(5.0) function unsafe_copyto_tester(x::Vector{T}, y::Vector{T}, n::Int) where {T} @@ -314,26 +346,48 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) (false, :none, nothing, unsafe_copyto_tester, randn(5), randn(3), 2), (false, :none, nothing, unsafe_copyto_tester, randn(5), randn(6), 4), ( - false, :none, nothing, - unsafe_copyto_tester, [randn(3) for _ in 1:5], [randn(4) for _ in 1:6], 4, + false, + :none, + nothing, + unsafe_copyto_tester, + [randn(3) for _ in 1:5], + [randn(4) for _ in 1:6], + 4, ), ( - false, :none, (lb=0.1, ub=150), - x -> unsafe_pointer_to_objref(pointer_from_objref(x)), _x, + false, + :none, + (lb=0.1, ub=150), + x -> unsafe_pointer_to_objref(pointer_from_objref(x)), + _x, ), (false, :none, nothing, isassigned, randn(5), 4), (false, :none, nothing, x -> (Base._growbeg!(x, 2); x[1:2] .= 2.0), randn(5)), ( - false, :none, nothing, - (t, v) -> ccall(:jl_type_unionall, Any, (Any, Any), t, v), TypeVar(:a), Real, + false, + :none, + nothing, + (t, v) -> ccall(:jl_type_unionall, Any, (Any, Any), t, v), + TypeVar(:a), + Real, ), ( - true, :none, nothing, - unsafe_copyto!, CoDual(ptr_a, ptr_da), CoDual(ptr_b, ptr_db), 4, + true, + :none, + nothing, + unsafe_copyto!, + CoDual(ptr_a, ptr_da), + CoDual(ptr_b, ptr_db), + 4, ), ( - true, :none, nothing, - unsafe_copyto!, CoDual(ptr_a, ptr_da), CoDual(ptr_b, ptr_db), 4, + true, + :none, + nothing, + unsafe_copyto!, + CoDual(ptr_a, ptr_da), + CoDual(ptr_b, ptr_db), + 4, ), ] return test_cases, memory diff --git a/src/rrules/function_wrappers.jl b/src/rrules/function_wrappers.jl index 0665f8a02..e605d2629 100644 --- a/src/rrules/function_wrappers.jl +++ b/src/rrules/function_wrappers.jl @@ -8,38 +8,36 @@ end function _construct_types(R, A) # Convert signature into a tuple of types. - primal_arg_types = (A.parameters..., ) + primal_arg_types = (A.parameters...,) # Signature and OpaqueClosure type for reverse pass. rvs_sig = Tuple{rdata_type(tangent_type(R))} primal_rdata_sig = Tuple{map(rdata_type ∘ tangent_type, primal_arg_types)...} - pb_ret_type = Tuple{NoRData, primal_rdata_sig.parameters...} - rvs_oc_type = Core.OpaqueClosure{rvs_sig, pb_ret_type} + pb_ret_type = Tuple{NoRData,primal_rdata_sig.parameters...} + rvs_oc_type = Core.OpaqueClosure{rvs_sig,pb_ret_type} # Signature and OpaqueClosure type for forwards pass. fwd_sig = Tuple{map(fcodual_type, primal_arg_types)...} - fwd_oc_type = Core.OpaqueClosure{fwd_sig, Tuple{fcodual_type(R), rvs_oc_type}} + fwd_oc_type = Core.OpaqueClosure{fwd_sig,Tuple{fcodual_type(R),rvs_oc_type}} return fwd_oc_type, rvs_oc_type, fwd_sig, rvs_sig end -function tangent_type(::Type{FunctionWrapper{R, A}}) where {R, A<:Tuple} +function tangent_type(::Type{FunctionWrapper{R,A}}) where {R,A<:Tuple} return FunctionWrapperTangent{_construct_types(R, A)[1]} end import .TestUtils: has_equal_data_internal function has_equal_data_internal( - p::P, q::P, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool} + p::P, q::P, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} ) where {P<:FunctionWrapper} return has_equal_data_internal(p.obj, q.obj, equal_undefs, d) end function has_equal_data_internal( - t::T, s::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool} + t::T, s::T, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} ) where {T<:FunctionWrapperTangent} return has_equal_data_internal(t.dobj_ref[], s.dobj_ref[], equal_undefs, d) end - - function _function_wrapper_tangent(R, obj::Tobj, A, obj_tangent) where {Tobj} # Analyse types. @@ -49,7 +47,7 @@ function _function_wrapper_tangent(R, obj::Tobj, A, obj_tangent) where {Tobj} obj_tangent_ref = Ref{tangent_type(Tobj)}(obj_tangent) # Contruct a rule for `obj`, applied to its declared argument types. - rule = build_rrule(Tuple{Tobj, A.parameters...}) + rule = build_rrule(Tuple{Tobj,A.parameters...}) # Construct stack which can hold pullbacks generated by `rule`. The forwards-pass will # run `rule` and push the pullback to `pb_stack`. The reverse-pass will pop and run it. @@ -74,8 +72,8 @@ function _function_wrapper_tangent(R, obj::Tobj, A, obj_tangent) where {Tobj} end function zero_tangent_internal( - p::FunctionWrapper{R, A}, stackdict::Union{Nothing, IdDict} -) where {R, A} + p::FunctionWrapper{R,A}, stackdict::Union{Nothing,IdDict} +) where {R,A} # If we've seen this primal before, then we must return that tangent. haskey(stackdict, p) && return stackdict[p]::tangent_type(typeof(p)) @@ -88,8 +86,8 @@ function zero_tangent_internal( end function randn_tangent_internal( - rng::AbstractRNG, p::FunctionWrapper{R, A}, stackdict::Union{Nothing, IdDict} -) where {R, A} + rng::AbstractRNG, p::FunctionWrapper{R,A}, stackdict::Union{Nothing,IdDict} +) where {R,A} # If we've seen this primal before, then we must return that tangent. haskey(stackdict, p) && return stackdict[p]::tangent_type(typeof(p)) @@ -115,7 +113,7 @@ function _add_to_primal(p::FunctionWrapper, t::FunctionWrapperTangent, unsafe::B return typeof(p)(_add_to_primal(p.obj[], t.dobj_ref[], unsafe)) end -function _diff(p::P, q::P) where {R, A, P<:FunctionWrapper{R, A}} +function _diff(p::P, q::P) where {R,A,P<:FunctionWrapper{R,A}} return first(_function_wrapper_tangent(R, p.obj[], A, _diff(p.obj[], q.obj[]))) end @@ -147,14 +145,14 @@ _verify_fdata_value(p::FunctionWrapper, t::FunctionWrapperTangent) = nothing # meaningful way inside of ChainRules, but it seems unlikely that this will ever happen. to_cr_tangent(t::FunctionWrapperTangent) = t -@is_primitive MinimalCtx Tuple{Type{<:FunctionWrapper}, Any} -function rrule!!(::CoDual{Type{FunctionWrapper{R, A}}}, obj::CoDual{P}) where {R, A, P} +@is_primitive MinimalCtx Tuple{Type{<:FunctionWrapper},Any} +function rrule!!(::CoDual{Type{FunctionWrapper{R,A}}}, obj::CoDual{P}) where {R,A,P} t, obj_tangent_ref = _function_wrapper_tangent(R, obj.x, A, zero_tangent(obj.x, obj.dx)) function_wrapper_pb(::NoRData) = NoRData(), rdata(obj_tangent_ref[]) - return CoDual(FunctionWrapper{R, A}(obj.x), t), function_wrapper_pb + return CoDual(FunctionWrapper{R,A}(obj.x), t), function_wrapper_pb end -@is_primitive MinimalCtx Tuple{<:FunctionWrapper, Vararg} +@is_primitive MinimalCtx Tuple{<:FunctionWrapper,Vararg} function rrule!!(f::CoDual{<:FunctionWrapper}, x::Vararg{CoDual}) y, pb = f.dx.fwds_wrapper(x...) function_wrapper_eval_pb(dy) = pb(dy) @@ -163,8 +161,8 @@ end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:function_wrappers}) test_cases = Any[ - (false, :none, nothing, FunctionWrapper{Float64, Tuple{Float64}}, sin), - (false, :none, nothing, FunctionWrapper{Float64, Tuple{Float64}}(sin), 5.0), + (false, :none, nothing, FunctionWrapper{Float64,Tuple{Float64}}, sin), + (false, :none, nothing, FunctionWrapper{Float64,Tuple{Float64}}(sin), 5.0), ] memory = Any[] return test_cases, memory @@ -173,29 +171,35 @@ end function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:function_wrappers}) test_cases = Any[ ( - false, :none, nothing, - function(x, y) - p = FunctionWrapper{Float64, Tuple{Float64}}(x -> x * y) + false, + :none, + nothing, + function (x, y) + p = FunctionWrapper{Float64,Tuple{Float64}}(x -> x * y) out = 0.0 for _ in 1:1_000 out += p(x) end return out end, - 5.0, 4.0, + 5.0, + 4.0, ), ( - false, :none, nothing, - function(x::Vector{Float64}, y::Float64) - p = FunctionWrapper{Float64, Tuple{Float64}}(x -> x * y) + false, + :none, + nothing, + function (x::Vector{Float64}, y::Float64) + p = FunctionWrapper{Float64,Tuple{Float64}}(x -> x * y) out = 0.0 for _x in x out += p(_x) end return out end, - randn(100), randn(), + randn(100), + randn(), ), - ] + ] return test_cases, Any[] end diff --git a/src/rrules/iddict.jl b/src/rrules/iddict.jl index 74a2ea67b..e61291d84 100644 --- a/src/rrules/iddict.jl +++ b/src/rrules/iddict.jl @@ -1,11 +1,11 @@ # We're going to use `IdDict`s to represent tangents for `IdDict`s. -tangent_type(::Type{<:IdDict{K, V}}) where {K, V} = IdDict{K, tangent_type(V)} -function randn_tangent(rng::AbstractRNG, d::IdDict{K, V}) where {K, V} - return IdDict{K, tangent_type(V)}([k => randn_tangent(rng, v) for (k, v) in d]) +tangent_type(::Type{<:IdDict{K,V}}) where {K,V} = IdDict{K,tangent_type(V)} +function randn_tangent(rng::AbstractRNG, d::IdDict{K,V}) where {K,V} + return IdDict{K,tangent_type(V)}([k => randn_tangent(rng, v) for (k, v) in d]) end -function zero_tangent(d::IdDict{K, V}) where {K, V} - return IdDict{K, tangent_type(V)}([k => zero_tangent(v) for (k, v) in d]) +function zero_tangent(d::IdDict{K,V}) where {K,V} + return IdDict{K,tangent_type(V)}([k => zero_tangent(v) for (k, v) in d]) end function increment!!(p::T, q::T) where {T<:IdDict} @@ -20,17 +20,17 @@ function set_to_zero!!(t::IdDict) end return t end -function _scale(a::Float64, t::IdDict{K, V}) where {K, V} - return IdDict{K, V}([k => _scale(a, v) for (k, v) in t]) +function _scale(a::Float64, t::IdDict{K,V}) where {K,V} + return IdDict{K,V}([k => _scale(a, v) for (k, v) in t]) end _dot(p::T, q::T) where {T<:IdDict} = sum([_dot(p[k], q[k]) for k in keys(p)]; init=0.0) -function _add_to_primal(p::IdDict{K, V}, t::IdDict{K}, unsafe::Bool) where {K, V} +function _add_to_primal(p::IdDict{K,V}, t::IdDict{K}, unsafe::Bool) where {K,V} ks = intersect(keys(p), keys(t)) - return IdDict{K, V}([k => _add_to_primal(p[k], t[k], unsafe) for k in ks]) + return IdDict{K,V}([k => _add_to_primal(p[k], t[k], unsafe) for k in ks]) end -function _diff(p::P, q::P) where {K, V, P<:IdDict{K, V}} +function _diff(p::P, q::P) where {K,V,P<:IdDict{K,V}} @assert union(keys(p), keys(q)) == keys(p) - return IdDict{K, tangent_type(V)}([k => _diff(p[k], q[k]) for k in keys(p)]) + return IdDict{K,tangent_type(V)}([k => _diff(p[k], q[k]) for k in keys(p)]) end function TestUtils.populate_address_map!(m::TestUtils.AddressMap, p::IdDict, t::IdDict) k = pointer_from_objref(p) @@ -41,7 +41,7 @@ function TestUtils.populate_address_map!(m::TestUtils.AddressMap, p::IdDict, t:: return m end function TestUtils.has_equal_data_internal( - p::P, q::P, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool} + p::P, q::P, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} ) where {P<:IdDict} ks = union(keys(p), keys(q)) ks != keys(p) && return false @@ -61,16 +61,15 @@ tangent(f::IdDict, ::NoRData) = f # All of the rules in here are provided in order to avoid nasty `:ccall`s, and to support # standard built-in functionality on `IdDict`s. -@is_primitive MinimalCtx Tuple{typeof(Base.rehash!), IdDict, Any} +@is_primitive MinimalCtx Tuple{typeof(Base.rehash!),IdDict,Any} function rrule!!(::CoDual{typeof(Base.rehash!)}, d::CoDual{<:IdDict}, newsz::CoDual) Base.rehash!(primal(d), primal(newsz)) Base.rehash!(tangent(d), primal(newsz)) return d, NoPullback((NoRData(), NoRData(), NoRData())) end -@is_primitive MinimalCtx Tuple{typeof(setindex!), IdDict, Any, Any} -function rrule!!(::CoDual{typeof(setindex!)}, d::CoDual{IdDict{K,V}}, val, key) where {K, V} - +@is_primitive MinimalCtx Tuple{typeof(setindex!),IdDict,Any,Any} +function rrule!!(::CoDual{typeof(setindex!)}, d::CoDual{IdDict{K,V}}, val, key) where {K,V} k = primal(key) restore_state = in(k, keys(primal(d))) if restore_state @@ -102,10 +101,10 @@ function rrule!!(::CoDual{typeof(setindex!)}, d::CoDual{IdDict{K,V}}, val, key) return d, setindex_pb!! end -@is_primitive MinimalCtx Tuple{typeof(get), IdDict, Any, Any} +@is_primitive MinimalCtx Tuple{typeof(get),IdDict,Any,Any} function rrule!!( - ::CoDual{typeof(get)}, d::CoDual{IdDict{K, V}}, key::CoDual, default::CoDual -) where {K, V} + ::CoDual{typeof(get)}, d::CoDual{IdDict{K,V}}, key::CoDual, default::CoDual +) where {K,V} k = primal(key) has_key = in(k, keys(primal(d))) y = has_key ? CoDual(primal(d)[k], fdata(tangent(d)[k])) : default @@ -125,32 +124,31 @@ function rrule!!( return y, get_pb!! end -@is_primitive MinimalCtx Tuple{typeof(getindex), IdDict, Any} +@is_primitive MinimalCtx Tuple{typeof(getindex),IdDict,Any} function rrule!!( - ::CoDual{typeof(getindex)}, d::CoDual{IdDict{K, V}}, key::CoDual -) where {K, V} + ::CoDual{typeof(getindex)}, d::CoDual{IdDict{K,V}}, key::CoDual +) where {K,V} k = primal(key) y = CoDual(getindex(primal(d), k), fdata(getindex(tangent(d), k))) dkey = lazy_zero_rdata(primal(key)) dd = tangent(d) function getindex_pb!!(dy) - dd[k] = increment_rdata!!(dd[k], dy) + dd[k] = increment_rdata!!(dd[k], dy) return NoRData(), NoRData(), instantiate(dkey) end return y, getindex_pb!! end -for name in [ - :(:jl_idtable_rehash), :(:jl_eqtable_put), :(:jl_eqtable_get), :(:jl_eqtable_nextind), -] +for name in + [:(:jl_idtable_rehash), :(:jl_eqtable_put), :(:jl_eqtable_get), :(:jl_eqtable_nextind)] @eval function rrule!!(::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{$name}}, args...) - unexepcted_foreigncall_error($name) + return unexepcted_foreigncall_error($name) end end -@is_primitive MinimalCtx Tuple{Type{IdDict{K, V}} where {K, V}} -function rrule!!(f::CoDual{Type{IdDict{K, V}}}) where {K, V} - return CoDual(IdDict{K, V}(), IdDict{K, tangent_type(V)}()), NoPullback(f) +@is_primitive MinimalCtx Tuple{Type{IdDict{K,V}} where {K,V}} +function rrule!!(f::CoDual{Type{IdDict{K,V}}}) where {K,V} + return CoDual(IdDict{K,V}(), IdDict{K,tangent_type(V)}()), NoPullback(f) end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:iddict}) @@ -161,7 +159,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:iddict}) (false, :none, nothing, get, IdDict(true => 5.0, false => 4.0), false, 2.0), (false, :none, nothing, get, IdDict(true => 5.0), false, 2.0), (false, :none, nothing, getindex, IdDict(true => 5.0, false => 4.0), true), - (false, :none, nothing, IdDict{Any, Any}), + (false, :none, nothing, IdDict{Any,Any}), ] memory = Any[] return test_cases, memory diff --git a/src/rrules/lapack.jl b/src/rrules/lapack.jl index 9578a07ef..48571a8e3 100644 --- a/src/rrules/lapack.jl +++ b/src/rrules/lapack.jl @@ -13,7 +13,7 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) _LDA::CoDual{$TInt}, # leading dimension of A _IPIV::CoDual{$TInt}, # pivot indices _INFO::CoDual{$TInt}, # some info of some kind - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} GC.@preserve args begin # Extract names. @@ -34,8 +34,15 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) # Run the primal. ccall( - $(blas_name(fname)), Cvoid, ($TInt, $TInt, Ptr{$elty}, $TInt, $TInt, $TInt), - M, N, A, LDA, IPIV, INFO, + $(blas_name(fname)), + Cvoid, + ($TInt, $TInt, Ptr{$elty}, $TInt, $TInt, $TInt), + M, + N, + A, + LDA, + IPIV, + INFO, ) ipiv_vec = copy(unsafe_wrap(Array, IPIV, N_val)) @@ -46,7 +53,6 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) dA = tangent(_A) function getrf_pb!!(::NoRData) - GC.@preserve args begin # Run reverse-pass. L, U = UnitLowerTriangular(A_mat), UpperTriangular(A_mat) @@ -71,7 +77,6 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) end for (fname, elty) in ((:dtrtrs_, :Float64), (:strtrs_, :Float32)) - TInt = :(Ptr{BLAS.BlasInt}) @eval @inline function rrule!!( ::CoDual{typeof(_foreigncall_)}, @@ -90,15 +95,14 @@ for (fname, elty) in ((:dtrtrs_, :Float64), (:strtrs_, :Float32)) _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BLAS.BlasInt}}, _info::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin # Load in data. ul_p, tA_p, diag_p = map(primal, (_ul, _tA, _diag)) N_p, Nrhs_p, lda_p, ldb_p, info_p = map(primal, (_N, _Nrhs, _lda, _ldb, _info)) ul, tA, diag, N, Nrhs, lda, ldb, info = map( - unsafe_load, (ul_p, tA_p, diag_p, N_p, Nrhs_p, lda_p, ldb_p, info_p), + unsafe_load, (ul_p, tA_p, diag_p, N_p, Nrhs_p, lda_p, ldb_p, info_p) ) A = wrap_ptr_as_view(primal(_A), lda, N, N) @@ -110,19 +114,39 @@ for (fname, elty) in ((:dtrtrs_, :Float64), (:strtrs_, :Float32)) $(blas_name(fname)), Cvoid, ( - Ptr{UInt8}, Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, - Clong, Clong, Clong, + Ptr{UInt8}, + Ptr{UInt8}, + Ptr{UInt8}, + Ptr{BlasInt}, + Ptr{BlasInt}, + Ptr{$elty}, + Ptr{BlasInt}, + Ptr{$elty}, + Ptr{BlasInt}, + Ptr{BlasInt}, + Clong, + Clong, + Clong, ), - ul_p, tA_p, diag_p, N_p, Nrhs_p, primal(_A), lda_p, primal(_B),ldb_p, info_p, - 1, 1, 1, + ul_p, + tA_p, + diag_p, + N_p, + Nrhs_p, + primal(_A), + lda_p, + primal(_B), + ldb_p, + info_p, + 1, + 1, + 1, ) end _dA = tangent(_A) _dB = tangent(_B) function trtrs_pb!!(::NoRData) - GC.@preserve args begin # Compute cotangent of B. dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) @@ -163,13 +187,14 @@ for (fname, elty) in ((:dgetrs_, :Float64), (:sgetrs_, :Float32)) _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BlasInt}}, _info::CoDual{Ptr{BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin # Load in values. tA = Char(unsafe_load(primal(_tA))) - N, Nrhs, lda, ldb, info = map(unsafe_load ∘ primal, (_N, _Nrhs, _lda, _ldb, _info)) + N, Nrhs, lda, ldb, info = map( + unsafe_load ∘ primal, (_N, _Nrhs, _lda, _ldb, _info) + ) ipiv = unsafe_wrap(Vector{BlasInt}, primal(_ipiv), N) A = wrap_ptr_as_view(primal(_A), lda, N, N) B = wrap_ptr_as_view(primal(_B), ldb, N, Nrhs) @@ -209,7 +234,6 @@ for (fname, elty) in ((:dgetrs_, :Float64), (:sgetrs_, :Float32)) _dA = tangent(_A) _dB = tangent(_B) function getrs_pb!!(::NoRData) - GC.@preserve args begin dA = wrap_ptr_as_view(_dA, lda, N, N) dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) @@ -266,9 +290,8 @@ for (fname, elty) in ((:dgetri_, :Float64), (:sgetri_, :Float32)) _work::CoDual{Ptr{$elty}}, _lwork::CoDual{Ptr{BlasInt}}, _info::CoDual{Ptr{BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin # Pull out data. N_p, lda_p, lwork_p, info_p = map(primal, (_N, _lda, _lwork, _info)) @@ -279,12 +302,24 @@ for (fname, elty) in ((:dgetri_, :Float64), (:sgetri_, :Float32)) # Run forwards-pass. ccall( - $(blas_name(fname)), Cvoid, + $(blas_name(fname)), + Cvoid, ( - Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, - Ptr{BlasInt}, Ptr{BlasInt}, + Ptr{BlasInt}, + Ptr{$elty}, + Ptr{BlasInt}, + Ptr{BlasInt}, + Ptr{$elty}, + Ptr{BlasInt}, + Ptr{BlasInt}, ), - N_p, A_p, lda_p, primal(_ipiv), primal(_work), lwork_p, info_p, + N_p, + A_p, + lda_p, + primal(_ipiv), + primal(_work), + lwork_p, + info_p, ) p = LinearAlgebra.ipiv2perm(unsafe_wrap(Array, primal(_ipiv), N), N) @@ -329,9 +364,8 @@ for (fname, elty) in ((:dpotrf_, :Float64), (:spotrf_, :Float32)) _A::CoDual{Ptr{$elty}}, _lda::CoDual{Ptr{BlasInt}}, _info::CoDual{Ptr{BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin # Pull out the data. uplo_p, N_p, A_p, lda_p, info_p = map(primal, (_uplo, _N, _A, _lda, _info)) @@ -343,15 +377,19 @@ for (fname, elty) in ((:dpotrf_, :Float64), (:spotrf_, :Float32)) # Run forwards-pass. ccall( - $(blas_name(fname)), Cvoid, + $(blas_name(fname)), + Cvoid, (Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}), - uplo_p, N_p, A_p, lda_p, info_p, + uplo_p, + N_p, + A_p, + lda_p, + info_p, ) end _dA = tangent(_A) function potrf_pb!!(::NoRData) - GC.@preserve args begin dA = wrap_ptr_as_view(_dA, lda, N, N) dA2 = dA @@ -395,9 +433,8 @@ for (fname, elty) in ((:dpotrs_, :Float64), (:spotrs_, :Float32)) _B::CoDual{Ptr{$elty}}, _ldb::CoDual{Ptr{BlasInt}}, _info::CoDual{Ptr{BlasInt}}, - args::Vararg{Any, Nargs}, + args::Vararg{Any,Nargs}, ) where {Nargs} - GC.@preserve args begin # Pull out the data. uplo_p, N_p, Nrhs_p, A_p, lda_p, B_p, ldb_p, info_p = map( @@ -412,19 +449,32 @@ for (fname, elty) in ((:dpotrs_, :Float64), (:spotrs_, :Float32)) # Run forwards-pass. ccall( - $(blas_name(fname)), Cvoid, + $(blas_name(fname)), + Cvoid, ( - Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, - Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{BlasInt}, + Ptr{UInt8}, + Ptr{BlasInt}, + Ptr{BlasInt}, + Ptr{$elty}, + Ptr{BlasInt}, + Ptr{$elty}, + Ptr{BlasInt}, + Ptr{BlasInt}, ), - uplo_p, N_p, Nrhs_p, A_p, lda_p, B_p, ldb_p, info_p, + uplo_p, + N_p, + Nrhs_p, + A_p, + lda_p, + B_p, + ldb_p, + info_p, ) end _dA = tangent(_A) _dB = tangent(_B) function potrs_pb!!(::NoRData) - GC.@preserve args begin dA = wrap_ptr_as_view(_dA, lda, N, N) dB = wrap_ptr_as_view(_dB, ldb, N, Nrhs) @@ -467,72 +517,103 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) ], # trtrs - vec(reduce( - vcat, - map(product( - ['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2]) - ) do (ul, tA, diag, N, Nrhs) - As = [randn(N, N) + 10I, view(randn(15, 15) + 10I, 2:N+1, 2:N+1)] - Bs = [randn(N, Nrhs), view(randn(15, 15), 4:N+3, 3:N+2)] - return map(product(As, Bs)) do (A, B) - (false, :none, nothing, trtrs!, ul, tA, diag, A, B) - end - end, - )), + vec( + reduce( + vcat, + map( + product(['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2]) + ) do (ul, tA, diag, N, Nrhs) + As = [ + randn(N, N) + 10I, view(randn(15, 15) + 10I, 2:(N + 1), 2:(N + 1)) + ] + Bs = [randn(N, Nrhs), view(randn(15, 15), 4:(N + 3), 3:(N + 2))] + return map(product(As, Bs)) do (A, B) + (false, :none, nothing, trtrs!, ul, tA, diag, A, B) + end + end, + ), + ), # getrs - vec(reduce( - vcat, - map(product(['N', 'T'], [1, 9], [1, 2])) do (trans, N, Nrhs) - As = getrf!.([ - randn(N, N) + 5I, - view(randn(15, 15) + 5I, 2:N+1, 2:N+1), - ]) - Bs = [randn(N, Nrhs), view(randn(15, 15), 4:N+3, 3:Nrhs+2)] - return map(product(As, Bs)) do ((A, ipiv), B) - (false, :none, nothing, getrs!, trans, A, ipiv, B) - end - end, - )), + vec( + reduce( + vcat, + map(product(['N', 'T'], [1, 9], [1, 2])) do (trans, N, Nrhs) + As = + getrf!.([ + randn(N, N) + 5I, view(randn(15, 15) + 5I, 2:(N + 1), 2:(N + 1)) + ]) + Bs = [randn(N, Nrhs), view(randn(15, 15), 4:(N + 3), 3:(Nrhs + 2))] + return map(product(As, Bs)) do ((A, ipiv), B) + (false, :none, nothing, getrs!, trans, A, ipiv, B) + end + end, + ), + ), # getri - vec(reduce( - vcat, - map([1, 9]) do N - As = getrf!.([randn(N, N) + 5I, view(randn(15, 15) + I, 2:N+1, 2:N+1)]) - As = getrf!.([randn(N, N) + 5I]) - return map(As) do (A, ipiv) - (false, :none, nothing, getri!, A, ipiv) - end - end, - )), + vec( + reduce( + vcat, + map([1, 9]) do N + As = + getrf!.([ + randn(N, N) + 5I, view(randn(15, 15) + I, 2:(N + 1), 2:(N + 1)) + ]) + As = getrf!.([randn(N, N) + 5I]) + return map(As) do (A, ipiv) + (false, :none, nothing, getri!, A, ipiv) + end + end, + ), + ), # potrf - vec(reduce( - vcat, - map([1, 3, 9]) do N - X = randn(N, N) - A = X * X' + I - return Any[ - (false, :none, nothing, potrf!, 'L', A), - (false, :none, nothing, potrf!, 'U', A), - ] - end, - )), + vec( + reduce( + vcat, + map([1, 3, 9]) do N + X = randn(N, N) + A = X * X' + I + return Any[ + (false, :none, nothing, potrf!, 'L', A), + (false, :none, nothing, potrf!, 'U', A), + ] + end, + ), + ), # potrs - vec(reduce( - vcat, - map(product([1, 3, 9], [1, 2])) do (N, Nrhs) - X = randn(N, N) - A = X * X' + I - B = randn(N, Nrhs) - return Any[ - (false, :none, nothing, potrs!, 'L', potrf!('L', copy(A))[1], copy(B)), - (false, :none, nothing, potrs!, 'U', potrf!('U', copy(A))[1], copy(B)), - ] - end, - )), + vec( + reduce( + vcat, + map(product([1, 3, 9], [1, 2])) do (N, Nrhs) + X = randn(N, N) + A = X * X' + I + B = randn(N, Nrhs) + return Any[ + ( + false, + :none, + nothing, + potrs!, + 'L', + potrf!('L', copy(A))[1], + copy(B), + ), + ( + false, + :none, + nothing, + potrs!, + 'U', + potrf!('U', copy(A))[1], + copy(B), + ), + ] + end, + ), + ), ) memory = Any[] return test_cases, memory diff --git a/src/rrules/linear_algebra.jl b/src/rrules/linear_algebra.jl index dc628d2bc..ce5e65ae1 100644 --- a/src/rrules/linear_algebra.jl +++ b/src/rrules/linear_algebra.jl @@ -1,4 +1,4 @@ -@is_primitive MinimalCtx Tuple{typeof(exp), Matrix{<:IEEEFloat}} +@is_primitive MinimalCtx Tuple{typeof(exp),Matrix{<:IEEEFloat}} struct ExpPullback{P} pb @@ -20,8 +20,7 @@ end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:linear_algebra}) test_cases = Any[ - (false, :none, nothing, exp, randn(3, 3)), - (false, :none, nothing, exp, randn(7, 7)), + (false, :none, nothing, exp, randn(3, 3)), (false, :none, nothing, exp, randn(7, 7)) ] memory = Any[] return test_cases, memory diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index 9569ccfe8..2c297ff83 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -10,7 +10,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) dx = DiffRules.diffrule(M, f, :x) pb_name = Symbol("$(M).$(f)_pb!!") @eval begin - @is_primitive MinimalCtx Tuple{typeof($M.$f), P} where {P<:IEEEFloat} + @is_primitive MinimalCtx Tuple{typeof($M.$f),P} where {P<:IEEEFloat} function rrule!!(::CoDual{typeof($M.$f)}, _x::CoDual{P}) where {P<:IEEEFloat} x = primal(_x) # needed for dx expression $pb_name(ȳ::P) = NoRData(), ȳ * $dx @@ -21,7 +21,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) da, db = DiffRules.diffrule(M, f, :a, :b) pb_name = Symbol("$(M).$(f)_pb!!") @eval begin - @is_primitive MinimalCtx Tuple{typeof($M.$f), P, P} where {P<:IEEEFloat} + @is_primitive MinimalCtx Tuple{typeof($M.$f),P,P} where {P<:IEEEFloat} function rrule!!( ::CoDual{typeof($M.$f)}, _a::CoDual{P}, _b::CoDual{P} ) where {P<:IEEEFloat} @@ -34,21 +34,21 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) end end -@is_primitive MinimalCtx Tuple{typeof(sin), <:IEEEFloat} -function rrule!!(::CoDual{typeof(sin), NoFData}, x::CoDual{P, NoFData}) where {P<:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(sin),<:IEEEFloat} +function rrule!!(::CoDual{typeof(sin),NoFData}, x::CoDual{P,NoFData}) where {P<:IEEEFloat} s, c = sincos(primal(x)) sin_pullback!!(dy::P) = NoRData(), dy * c return CoDual(s, NoFData()), sin_pullback!! end -@is_primitive MinimalCtx Tuple{typeof(cos), <:IEEEFloat} -function rrule!!(::CoDual{typeof(cos), NoFData}, x::CoDual{P, NoFData}) where {P<:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(cos),<:IEEEFloat} +function rrule!!(::CoDual{typeof(cos),NoFData}, x::CoDual{P,NoFData}) where {P<:IEEEFloat} s, c = sincos(primal(x)) cos_pullback!!(dy::P) = NoRData(), -dy * s return CoDual(c, NoFData()), cos_pullback!! end -@is_primitive MinimalCtx Tuple{typeof(exp), <:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(exp),<:IEEEFloat} function rrule!!(::CoDual{typeof(exp)}, x::CoDual{P}) where {P<:IEEEFloat} y = exp(primal(x)) exp_pb!!(dy::P) = NoRData(), dy * y @@ -56,23 +56,23 @@ function rrule!!(::CoDual{typeof(exp)}, x::CoDual{P}) where {P<:IEEEFloat} end rand_inputs(rng, P::Type{<:IEEEFloat}, f, arity) = randn(rng, P, arity) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acosh), _) = (rand(rng) + 1 + 1e-3, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asech), _) = (rand(rng) * 0.9, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log), _) = (rand(rng) + 1e-3, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asin), _) = (rand(rng) * 0.9, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asecd), _) = (rand(rng) + 1, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log2), _) = (rand(rng) + 1e-3, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log10), _) = (rand(rng) + 1e-3, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acscd), _) = (rand(rng) + 1 + 1e-3, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log1p), _) = (rand(rng) + 1e-3, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acsc), _) = (rand(rng) + 1 + 1e-3, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(atanh), _) = (2 * 0.9 * rand(rng) - 0.9, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acoth), _) = (rand(rng) + 1 + 1e-3, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asind), _) = (0.9 * rand(rng), ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asec), _) = (rand(rng) + 1.001, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acosd), _) = (2 * 0.9 * rand(rng) - 0.9, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acos), _) = (2 * 0.9 * rand(rng) - 0.9, ) -rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(sqrt), _) = (rand(rng) + 1e-3, ) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acosh), _) = (rand(rng) + 1 + 1e-3,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asech), _) = (rand(rng) * 0.9,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log), _) = (rand(rng) + 1e-3,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asin), _) = (rand(rng) * 0.9,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asecd), _) = (rand(rng) + 1,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log2), _) = (rand(rng) + 1e-3,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log10), _) = (rand(rng) + 1e-3,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acscd), _) = (rand(rng) + 1 + 1e-3,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log1p), _) = (rand(rng) + 1e-3,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acsc), _) = (rand(rng) + 1 + 1e-3,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(atanh), _) = (2 * 0.9 * rand(rng) - 0.9,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acoth), _) = (rand(rng) + 1 + 1e-3,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asind), _) = (0.9 * rand(rng),) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asec), _) = (rand(rng) + 1.001,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acosd), _) = (2 * 0.9 * rand(rng) - 0.9,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acos), _) = (2 * 0.9 * rand(rng) - 0.9,) +rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(sqrt), _) = (rand(rng) + 1e-3,) function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:low_level_maths}) rng = Xoshiro(123) @@ -80,14 +80,20 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:low_level_mat foreach(DiffRules.diffrules(; filter_modules=nothing)) do (M, f, arity) if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) || M == :SpecialFunctions - return # Skip rules for methods not defined in the current scope + return nothing # Skip rules for methods not defined in the current scope end - arity > 2 && return - (f == :rem2pi || f == :ldexp || f == :(^)) && return - (f in [:+, :*, :sin, :cos, :exp, :-, :abs2, :inv, :abs, :/, :\]) && return # use other functionality to implement these + arity > 2 && return nothing + (f == :rem2pi || f == :ldexp || f == :(^)) && return nothing + (f in [:+, :*, :sin, :cos, :exp, :-, :abs2, :inv, :abs, :/, :\]) && return nothing # use other functionality to implement these f = @eval $M.$f - push!(test_cases, (false, :stability, nothing, f, rand_inputs(rng, Float64, f, arity)...)) - push!(test_cases, (true, :stability, nothing, f, rand_inputs(rng, Float32, f, arity)...)) + push!( + test_cases, + (false, :stability, nothing, f, rand_inputs(rng, Float64, f, arity)...), + ) + push!( + test_cases, + (true, :stability, nothing, f, rand_inputs(rng, Float32, f, arity)...), + ) end # test cases for additional rules written in this file. diff --git a/src/rrules/memory.jl b/src/rrules/memory.jl index 86cce8818..c60850f1d 100644 --- a/src/rrules/memory.jl +++ b/src/rrules/memory.jl @@ -12,12 +12,11 @@ # Tangent Interface Implementation -const Maybe{T} = Union{Nothing, T} +const Maybe{T} = Union{Nothing,T} tangent_type(::Type{<:Memory{P}}) where {P} = Memory{tangent_type(P)} function zero_tangent_internal(x::Memory{P}, stackdict::Maybe{IdDict}) where {P} - T = tangent_type(typeof(x)) # If no stackdict is provided, then the caller promises that there is no need for it. @@ -47,7 +46,7 @@ function randn_tangent_internal(rng::AbstractRNG, x::Memory, stackdict::Maybe{Id end function TestUtils.has_equal_data_internal( - x::Memory{P}, y::Memory{P}, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool} + x::Memory{P}, y::Memory{P}, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} ) where {P} length(x) == length(y) || return false id_pair = (objectid(x), objectid(y)) @@ -79,7 +78,7 @@ function _add_to_primal(p::Memory{P}, t::Memory, unsafe::Bool) where {P} end function _diff(p::Memory{P}, q::Memory{P}) where {P} - return _map_if_assigned!(_diff, Memory{tangent_type(P)}(undef, length(p)), p ,q) + return _map_if_assigned!(_diff, Memory{tangent_type(P)}(undef, length(p)), p, q) end function _dot(t::Memory{T}, s::Memory{T}) where {T} @@ -112,9 +111,10 @@ tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Memory} = F tangent(f::Memory, ::NoRData) = f -function _verify_fdata_value(p::Memory{P}, f::Memory{F}) where {P, F} +function _verify_fdata_value(p::Memory{P}, f::Memory{F}) where {P,F} if length(p) != length(f) - msg = "length(p) == $(length(p)) but length(f) == $(length(f)). " * + msg = + "length(p) == $(length(p)) but length(f) == $(length(f)). " * "p isa Memory{$P} and f isa Memory{$F}" throw(error(msg)) end @@ -174,8 +174,8 @@ function _dot(t::T, s::T) where {T<:Array} ) end -function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}, unsafe::Bool) where {P, N} - x′ = Array{P, N}(undef, size(x)...) +function _add_to_primal(x::Array{P,N}, t::Array{<:Any,N}, unsafe::Bool) where {P,N} + x′ = Array{P,N}(undef, size(x)...) return _map_if_assigned!((x, t) -> _add_to_primal(x, t, unsafe), x′, x, t) end @@ -186,7 +186,7 @@ end # Rules @is_primitive( - MinimalCtx, Tuple{typeof(unsafe_copyto!), MemoryRef{P}, MemoryRef{P}, Int} where {P} + MinimalCtx, Tuple{typeof(unsafe_copyto!),MemoryRef{P},MemoryRef{P},Int} where {P} ) function rrule!!( ::CoDual{typeof(unsafe_copyto!)}, @@ -261,7 +261,7 @@ function randn_tangent_internal(rng::AbstractRNG, x::MemoryRef, stackdict::Maybe end function TestUtils.has_equal_data_internal( - x::MemoryRef{P}, y::MemoryRef{P}, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool} + x::MemoryRef{P}, y::MemoryRef{P}, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} ) where {P} equal_refs = Core.memoryrefoffset(x) == Core.memoryrefoffset(y) equal_data = TestUtils.has_equal_data_internal(x.mem, y.mem, equal_undefs, d) @@ -305,8 +305,8 @@ tangent_type(::Type{<:MemoryRef{T}}, ::Type{NoRData}) where {T} = MemoryRef{T} tangent(f::MemoryRef, ::NoRData) = f -function _verify_fdata_value(p::MemoryRef{P}, f::MemoryRef{T}) where {P, T} - _verify_fdata_value(p.mem, f.mem) +function _verify_fdata_value(p::MemoryRef{P}, f::MemoryRef{T}) where {P,T} + return _verify_fdata_value(p.mem, f.mem) end # @@ -317,17 +317,15 @@ _val(::Val{c}) where {c} = c using Core: memoryref_isassigned, memoryrefget, memoryrefset!, memoryrefnew, memoryrefoffset -@zero_adjoint( - MinimalCtx, Tuple{typeof(memoryref_isassigned), GenericMemoryRef, Symbol, Bool} -) +@zero_adjoint(MinimalCtx, Tuple{typeof(memoryref_isassigned),GenericMemoryRef,Symbol,Bool}) @inline function lmemoryrefget( x::MemoryRef, ::Val{ordering}, ::Val{boundscheck} -) where {ordering, boundscheck} +) where {ordering,boundscheck} return memoryrefget(x, ordering, boundscheck) end -@is_primitive MinimalCtx Tuple{typeof(lmemoryrefget), MemoryRef, Val, Val} +@is_primitive MinimalCtx Tuple{typeof(lmemoryrefget),MemoryRef,Val,Val} @inline function rrule!!( ::CoDual{typeof(lmemoryrefget)}, x::CoDual{<:MemoryRef}, @@ -386,32 +384,32 @@ end return CoDual(y, dy), NoPullback(f, x, ii, boundscheck) end -@zero_adjoint MinimalCtx Tuple{typeof(memoryrefoffset), GenericMemoryRef} +@zero_adjoint MinimalCtx Tuple{typeof(memoryrefoffset),GenericMemoryRef} # Core.memoryrefreplace! @inline function lmemoryrefset!( x::MemoryRef, value, ::Val{ordering}, ::Val{boundscheck} -) where {ordering, boundscheck} +) where {ordering,boundscheck} return memoryrefset!(x, value, ordering, boundscheck) end -@is_primitive MinimalCtx Tuple{typeof(lmemoryrefset!), MemoryRef, Any, Val, Val} +@is_primitive MinimalCtx Tuple{typeof(lmemoryrefset!),MemoryRef,Any,Val,Val} @inline function rrule!!( ::CoDual{typeof(lmemoryrefset!)}, - x::CoDual{<:MemoryRef{P}, <:MemoryRef{V}}, + x::CoDual{<:MemoryRef{P},<:MemoryRef{V}}, value::CoDual, _ordering::CoDual{<:Val}, _boundscheck::CoDual{<:Val}, -) where {P, V} +) where {P,V} ordering = primal(_ordering) bc = primal(_boundscheck) isbitstype(P) && return isbits_lmemoryrefset!_rule(x, value, ordering, bc) to_save = isassigned(x.x) - old_x = Ref{Tuple{P, V}}() + old_x = Ref{Tuple{P,V}}() if to_save old_x[] = ( memoryrefget(x.x, _val(ordering), _val(bc)), @@ -452,11 +450,11 @@ end @inline function rrule!!( ::CoDual{typeof(memoryrefset!)}, - x::CoDual{<:MemoryRef{P}, <:MemoryRef{V}}, + x::CoDual{<:MemoryRef{P},<:MemoryRef{V}}, value::CoDual, ordering::CoDual{Symbol}, boundscheck::CoDual{Bool}, -) where {P, V} +) where {P,V} y, adj = rrule!!( zero_fcodual(lmemoryrefset!), x, @@ -474,11 +472,9 @@ end # _new_ and _new_-adjacent rules for Memory, MemoryRef, and Array. -@is_primitive MinimalCtx Tuple{Type{<:Memory}, UndefInitializer, Int} +@is_primitive MinimalCtx Tuple{Type{<:Memory},UndefInitializer,Int} function rrule!!( - ::CoDual{Type{Memory{P}}}, - ::CoDual{UndefInitializer}, - n::CoDual{Int}, + ::CoDual{Type{Memory{P}}}, ::CoDual{UndefInitializer}, n::CoDual{Int} ) where {P} x = Memory{P}(undef, primal(n)) dx = zero_tangent_internal(x, nothing) @@ -498,12 +494,12 @@ end function rrule!!( ::CoDual{typeof(_new_)}, - ::CoDual{Type{Array{P, N}}}, + ::CoDual{Type{Array{P,N}}}, ref::CoDual{MemoryRef{P}}, - size::CoDual{<:NTuple{N, Int}}, -) where {P, N} - y = _new_(Array{P, N}, ref.x, size.x) - dy = _new_(Array{tangent_type(P), N}, ref.dx, size.x) + size::CoDual{<:NTuple{N,Int}}, +) where {P,N} + y = _new_(Array{P,N}, ref.x, size.x) + dy = _new_(Array{tangent_type(P),N}, ref.dx, size.x) return CoDual(y, dy), NoPullback(ntuple(_ -> NoRData(), 4)) end @@ -511,10 +507,10 @@ end function rrule!!( ::CoDual{typeof(lgetfield)}, - x::CoDual{<:Memory, <:Memory}, + x::CoDual{<:Memory,<:Memory}, ::CoDual{Val{name}}, ::CoDual{Val{order}}, -) where {name, order} +) where {name,order} y = getfield(primal(x), name, order) wants_length = name === 1 || name === :length dy = wants_length ? NoFData() : bitcast(Ptr{NoTangent}, x.dx.ptr) @@ -523,10 +519,10 @@ end function rrule!!( ::CoDual{typeof(lgetfield)}, - x::CoDual{<:MemoryRef, <:MemoryRef}, + x::CoDual{<:MemoryRef,<:MemoryRef}, ::CoDual{Val{name}}, ::CoDual{Val{order}}, -) where {name, order} +) where {name,order} y = getfield(primal(x), name, order) wants_offset = name === 1 || name === :ptr_or_offset dy = wants_offset ? bitcast(Ptr{NoTangent}, x.dx.ptr_or_offset) : x.dx.mem @@ -535,20 +531,20 @@ end function rrule!!( ::CoDual{typeof(lgetfield)}, - x::CoDual{<:Array, <:Array}, + x::CoDual{<:Array,<:Array}, ::CoDual{Val{name}}, ::CoDual{Val{order}}, -) where {name, order} +) where {name,order} y = getfield(primal(x), name, order) wants_size = name === 2 || name === :size dy = wants_size ? NoFData() : x.dx.ref return CoDual(y, dy), NoPullback(ntuple(_ -> NoRData(), 4)) end -const _MemTypes = Union{Memory, MemoryRef, Array} +const _MemTypes = Union{Memory,MemoryRef,Array} function rrule!!( - f::CoDual{typeof(lgetfield)}, x::CoDual{<:_MemTypes, <:_MemTypes}, name::CoDual{<:Val} + f::CoDual{typeof(lgetfield)}, x::CoDual{<:_MemTypes,<:_MemTypes}, name::CoDual{<:Val} ) y, adj = rrule!!(f, x, name, zero_fcodual(Val(:not_atomic))) ternary_lgetfield_adjoint(dy) = adj(dy)[1:3] @@ -556,8 +552,9 @@ function rrule!!( end function rrule!!( - ::CoDual{typeof(getfield)}, x::CoDual{<:_MemTypes, <:_MemTypes}, - name::CoDual{<:Union{Int, Symbol}}, + ::CoDual{typeof(getfield)}, + x::CoDual{<:_MemTypes,<:_MemTypes}, + name::CoDual{<:Union{Int,Symbol}}, order::CoDual{Symbol}, ) y, adj = rrule!!( @@ -572,8 +569,8 @@ end function rrule!!( f::CoDual{typeof(getfield)}, - x::CoDual{<:_MemTypes, <:_MemTypes}, - name::CoDual{<:Union{Int, Symbol}}, + x::CoDual{<:_MemTypes,<:_MemTypes}, + name::CoDual{<:Union{Int,Symbol}}, ) y, adj = rrule!!(f, x, name, zero_fcodual(:not_atomic)) ternary_getfield_adjoint(dy) = adj(dy)[1:3] @@ -582,7 +579,7 @@ end @inline function rrule!!( ::CoDual{typeof(lsetfield!)}, - value::CoDual{<:Array, <:Array}, + value::CoDual{<:Array,<:Array}, ::CoDual{Val{name}}, x::CoDual, ) where {name} @@ -600,7 +597,7 @@ end # Misc. other rules which are required for correctness. -@is_primitive MinimalCtx Tuple{typeof(copy), Array} +@is_primitive MinimalCtx Tuple{typeof(copy),Array} function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array}) dx = tangent(a) dy = copy(dx) @@ -612,11 +609,11 @@ function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array}) return y, copy_pullback!! end -@is_primitive MinimalCtx Tuple{typeof(fill!), Array{<:Union{UInt8, Int8}}, Integer} -@is_primitive MinimalCtx Tuple{typeof(fill!), Memory{<:Union{UInt8, Int8}}, Integer} +@is_primitive MinimalCtx Tuple{typeof(fill!),Array{<:Union{UInt8,Int8}},Integer} +@is_primitive MinimalCtx Tuple{typeof(fill!),Memory{<:Union{UInt8,Int8}},Integer} function rrule!!( ::CoDual{typeof(fill!)}, a::CoDual{T}, x::CoDual{<:Integer} -) where {V<:Union{UInt8, Int8}, T<:Union{Array{V}, Memory{V}}} +) where {V<:Union{UInt8,Int8},T<:Union{Array{V},Memory{V}}} pa = primal(a) old_value = copy(pa) fill!(pa, primal(x)) @@ -689,41 +686,64 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:memory}) [(false, :none, nothing, getfield, m, 1) for m in mems], # Rules for `MemoryRef` - [(false, :none, nothing, memoryref_isassigned, mem_ref, :not_atomic, bc) for + [ + (false, :none, nothing, memoryref_isassigned, mem_ref, :not_atomic, bc) for mem_ref in mem_refs for bc in [false, true] ], - [(false, :none, nothing, memoryrefget, mem_ref, :not_atomic, bc) for + [ + (false, :none, nothing, memoryrefget, mem_ref, :not_atomic, bc) for mem_ref in filter(isassigned, mem_refs) for bc in [false, true] ], [(false, :none, nothing, memoryrefnew, mem) for mem in mems], - [(false, :none, nothing, memoryrefnew, mem, 1) for + [ + (false, :none, nothing, memoryrefnew, mem, 1) for mem in filter(x -> length(x.mem) > Core.memoryrefoffset(x), mem_refs) ], - [(false, :none, nothing, memoryrefnew, mem, 1, bc) for + [ + (false, :none, nothing, memoryrefnew, mem, 1, bc) for mem in filter(x -> length(x.mem) > Core.memoryrefoffset(x), mem_refs) for bc in [false, true] ], [(false, :none, nothing, memoryrefoffset, mem_ref) for mem_ref in mem_refs], [ - (false, :none, nothing, lmemoryrefset!, mem_ref, sample_value, Val(:not_atomic), bc) for - (mem_ref, sample_value) in assignable_refs for + ( + false, + :none, + nothing, + lmemoryrefset!, + mem_ref, + sample_value, + Val(:not_atomic), + bc, + ) for (mem_ref, sample_value) in assignable_refs for bc in [Val(false), Val(true)] ], [ - (false, :none, nothing, memoryrefset!, mem_ref, sample_value, :not_atomic, bc) for - (mem_ref, sample_value) in assignable_refs for - bc in [false, true] + (false, :none, nothing, memoryrefset!, mem_ref, sample_value, :not_atomic, bc) + for (mem_ref, sample_value) in assignable_refs for bc in [false, true] ], - (false, :stability, nothing, unsafe_copyto!, randn(rng, 10).ref, randn(rng, 8).ref, 5), ( - false, :stability, nothing, + false, + :stability, + nothing, + unsafe_copyto!, + randn(rng, 10).ref, + randn(rng, 8).ref, + 5, + ), + ( + false, + :stability, + nothing, unsafe_copyto!, memoryref(randn(rng, 10).ref, 2), memoryref(randn(rng, 8).ref, 3), 4, ), ( - false, :stability, nothing, + false, + :stability, + nothing, unsafe_copyto!, [randn(rng, 10), randn(rng, 5)].ref, [randn(rng, 10), randn(rng, 3)].ref, @@ -731,39 +751,64 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:memory}) ), # Rules for `Array` - (false, :stability, nothing, _new_, Vector{Float64}, randn(rng, 10).ref, (10, )), + (false, :stability, nothing, _new_, Vector{Float64}, randn(rng, 10).ref, (10,)), ( - false, :stability, nothing, + false, + :stability, + nothing, _new_, Vector{Vector{Float64}}, [randn(rng, 10), randn(rng, 5)].ref, - (2, ), + (2,), ), + (false, :none, nothing, _new_, Vector{Any}, [1, randn(rng, 5)].ref, (2,)), + (false, :stability, nothing, _new_, Matrix{Float64}, randn(rng, 12).ref, (4, 3)), ( - false, :none, nothing, + false, + :stability, + nothing, _new_, - Vector{Any}, - [1, randn(rng, 5)].ref, - (2, ), + Array{Float64,3}, + randn(rng, 12).ref, + (4, 1, 3), ), - (false, :stability, nothing, _new_, Matrix{Float64}, randn(rng, 12).ref, (4, 3)), - (false, :stability, nothing, _new_, Array{Float64, 3}, randn(rng, 12).ref, (4, 1, 3)), [ (false, :stability, nothing, lgetfield, randn(rng, 10), f) for - f in [Val(:ref), Val(:size), Val(1), Val(2)] + f in [Val(:ref), Val(:size), Val(1), Val(2)] ], - [ - (false, :none, nothing, getfield, randn(rng, 10), f) for - f in [:ref, :size, 1, 2] - ], - (false, :stability_and_allocs, nothing, lsetfield!, randn(rng, 10), Val(:ref), randn(rng, 10).ref), - (false, :stability_and_allocs, nothing, lsetfield!, randn(rng, 10), Val(1), randn(rng, 10).ref), - (false, :stability_and_allocs, nothing, lsetfield!, randn(rng, 10), Val(:size), (10, )), - (false, :stability_and_allocs, nothing, lsetfield!, randn(rng, 10), Val(2), (10, )), + [(false, :none, nothing, getfield, randn(rng, 10), f) for f in [:ref, :size, 1, 2]], + ( + false, + :stability_and_allocs, + nothing, + lsetfield!, + randn(rng, 10), + Val(:ref), + randn(rng, 10).ref, + ), + ( + false, + :stability_and_allocs, + nothing, + lsetfield!, + randn(rng, 10), + Val(1), + randn(rng, 10).ref, + ), + ( + false, + :stability_and_allocs, + nothing, + lsetfield!, + randn(rng, 10), + Val(:size), + (10,), + ), + (false, :stability_and_allocs, nothing, lsetfield!, randn(rng, 10), Val(2), (10,)), (false, :none, nothing, setfield!, randn(rng, 10), :ref, randn(rng, 10).ref), (false, :none, nothing, setfield!, randn(rng, 10), 1, randn(rng, 10).ref), - (false, :none, nothing, setfield!, randn(rng, 10), :size, (10, )), - (false, :none, nothing, setfield!, randn(rng, 10), 2, (10, )), + (false, :none, nothing, setfield!, randn(rng, 10), :size, (10,)), + (false, :none, nothing, setfield!, randn(rng, 10), 2, (10,)), ) memory = Any[] return test_cases, memory @@ -773,15 +818,15 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:memory}) rng = rng_ctor(123) x = Memory{Float64}(randn(rng, 10)) test_cases = Any[ - (true, :none, nothing, Array{Float64, 0}, undef), - (true, :none, nothing, Array{Float64, 1}, undef, 5), - (true, :none, nothing, Array{Float64, 2}, undef, 5, 4), - (true, :none, nothing, Array{Float64, 3}, undef, 5, 4, 3), - (true, :none, nothing, Array{Float64, 4}, undef, 5, 4, 3, 2), - (true, :none, nothing, Array{Float64, 5}, undef, 5, 4, 3, 2, 1), - (true, :none, nothing, Array{Float64, 0}, undef, ()), - (true, :none, nothing, Array{Float64, 4}, undef, (2, 3, 4, 5)), - (true, :none, nothing, Array{Float64, 5}, undef, (2, 3, 4, 5, 6)), + (true, :none, nothing, Array{Float64,0}, undef), + (true, :none, nothing, Array{Float64,1}, undef, 5), + (true, :none, nothing, Array{Float64,2}, undef, 5, 4), + (true, :none, nothing, Array{Float64,3}, undef, 5, 4, 3), + (true, :none, nothing, Array{Float64,4}, undef, 5, 4, 3, 2), + (true, :none, nothing, Array{Float64,5}, undef, 5, 4, 3, 2, 1), + (true, :none, nothing, Array{Float64,0}, undef, ()), + (true, :none, nothing, Array{Float64,4}, undef, (2, 3, 4, 5)), + (true, :none, nothing, Array{Float64,5}, undef, (2, 3, 4, 5, 6)), (false, :none, nothing, copy, randn(5, 4)), (false, :none, nothing, Base._deletebeg!, randn(5), 0), (false, :none, nothing, Base._deletebeg!, randn(5), 2), @@ -802,12 +847,26 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:memory}) (false, :none, nothing, sizehint!, randn(5), 10), (false, :none, nothing, unsafe_copyto!, randn(4), 2, randn(3), 1, 2), ( - false, :none, nothing, - unsafe_copyto!, [rand(3) for _ in 1:5], 2, [rand(4) for _ in 1:4], 1, 3, + false, + :none, + nothing, + unsafe_copyto!, + [rand(3) for _ in 1:5], + 2, + [rand(4) for _ in 1:4], + 1, + 3, ), ( - false, :none, nothing, - unsafe_copyto!, Vector{Any}(undef, 5), 2, Any[rand() for _ in 1:4], 1, 3, + false, + :none, + nothing, + unsafe_copyto!, + Vector{Any}(undef, 5), + 2, + Any[rand() for _ in 1:4], + 1, + 3, ), (false, :none, nothing, x -> unsafe_copyto!(memoryref(x, 1), memoryref(x), 3), x), (false, :none, nothing, x -> unsafe_copyto!(memoryref(x), memoryref(x), 3), x), diff --git a/src/rrules/misc.jl b/src/rrules/misc.jl index 7fe309c31..f6f4de643 100644 --- a/src/rrules/misc.jl +++ b/src/rrules/misc.jl @@ -6,37 +6,37 @@ # deduce that these bits of code are inactive though. # -@zero_adjoint DefaultCtx Tuple{typeof(in), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(iszero), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(isempty), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(isbitstype), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(sizeof), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(promote_type), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.elsize), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Core.Compiler.sizeof_nothrow), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_haspadding), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_nfields), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_pointerfree), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_alignment), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_fielddesc_type), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(LinearAlgebra.chkstride1), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Threads.nthreads), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.depwarn), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.reduced_indices), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.check_reducedims), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.throw_boundserror), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.Broadcast.eltypes), Vararg} -@zero_adjoint DefaultCtx Tuple{typeof(Base.eltype), Vararg} -@zero_adjoint MinimalCtx Tuple{typeof(Base.padding), DataType} -@zero_adjoint MinimalCtx Tuple{typeof(Base.padding), DataType, Int} -@zero_adjoint MinimalCtx Tuple{Type, TypeVar, Type} +@zero_adjoint DefaultCtx Tuple{typeof(in),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(iszero),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(isempty),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(isbitstype),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(sizeof),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(promote_type),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.elsize),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Core.Compiler.sizeof_nothrow),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_haspadding),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_nfields),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_pointerfree),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_alignment),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.datatype_fielddesc_type),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(LinearAlgebra.chkstride1),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Threads.nthreads),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.depwarn),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.reduced_indices),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.check_reducedims),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.throw_boundserror),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.Broadcast.eltypes),Vararg} +@zero_adjoint DefaultCtx Tuple{typeof(Base.eltype),Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(Base.padding),DataType} +@zero_adjoint MinimalCtx Tuple{typeof(Base.padding),DataType,Int} +@zero_adjoint MinimalCtx Tuple{Type,TypeVar,Type} # Required to avoid an ambiguity. -@zero_adjoint MinimalCtx Tuple{Type{Symbol}, TypeVar, Type} +@zero_adjoint MinimalCtx Tuple{Type{Symbol},TypeVar,Type} @static if VERSION >= v"1.11-" - @zero_adjoint MinimalCtx Tuple{typeof(Random.hash_seed), Vararg} - @zero_adjoint MinimalCtx Tuple{typeof(Base.dataids), Memory} + @zero_adjoint MinimalCtx Tuple{typeof(Random.hash_seed),Vararg} + @zero_adjoint MinimalCtx Tuple{typeof(Base.dataids),Memory} end """ @@ -57,10 +57,10 @@ This approach is identical to the one taken by `Zygote.jl` to circumvent the sam """ lgetfield(x, ::Val{f}) where {f} = getfield(x, f) -@is_primitive MinimalCtx Tuple{typeof(lgetfield), Any, Val} +@is_primitive MinimalCtx Tuple{typeof(lgetfield),Any,Val} @inline function rrule!!( - ::CoDual{typeof(lgetfield)}, x::CoDual{P, F}, ::CoDual{Val{f}} -) where {P, F<:StandardFDataType, f} + ::CoDual{typeof(lgetfield)}, x::CoDual{P,F}, ::CoDual{Val{f}} +) where {P,F<:StandardFDataType,f} pb!! = if ismutabletype(P) dx = tangent(x) function mutable_lgetfield_pb!!(dy) @@ -78,14 +78,14 @@ lgetfield(x, ::Val{f}) where {f} = getfield(x, f) return y, pb!! end -@inline _get_fdata_field(_, t::Union{Tuple, NamedTuple}, f) = getfield(t, f) +@inline _get_fdata_field(_, t::Union{Tuple,NamedTuple}, f) = getfield(t, f) @inline _get_fdata_field(_, data::FData, f) = val(getfield(data.data, f)) @inline _get_fdata_field(primal, ::NoFData, f) = uninit_fdata(getfield(primal, f)) @inline _get_fdata_field(_, t::MutableTangent, f) = fdata(val(getfield(t.fields, f))) increment_field_rdata!(dx::MutableTangent, ::NoRData, ::Val) = dx increment_field_rdata!(dx::NoFData, ::NoRData, ::Val) = dx -function increment_field_rdata!(dx::T, dy_rdata, ::Val{f}) where {T<:MutableTangent, f} +function increment_field_rdata!(dx::T, dy_rdata, ::Val{f}) where {T<:MutableTangent,f} set_tangent_field!(dx, f, increment_rdata!!(get_tangent_field(dx, f), dy_rdata)) return dx end @@ -97,10 +97,10 @@ end # This is largely copy + pasted from the above. Attempts were made to refactor to avoid # code duplication, but it wound up not being any cleaner than this copy + pasted version. -@is_primitive MinimalCtx Tuple{typeof(lgetfield), Any, Val, Val} +@is_primitive MinimalCtx Tuple{typeof(lgetfield),Any,Val,Val} @inline function rrule!!( - ::CoDual{typeof(lgetfield)}, x::CoDual{P, F}, ::CoDual{Val{f}}, ::CoDual{Val{order}} -) where {P, F<:StandardFDataType, f, order} + ::CoDual{typeof(lgetfield)}, x::CoDual{P,F}, ::CoDual{Val{f}}, ::CoDual{Val{order}} +) where {P,F<:StandardFDataType,f,order} pb!! = if ismutabletype(P) dx = tangent(x) function mutable_lgetfield_pb!!(dy) @@ -118,16 +118,16 @@ end return y, pb!! end -@is_primitive MinimalCtx Tuple{typeof(lsetfield!), Any, Any, Any} +@is_primitive MinimalCtx Tuple{typeof(lsetfield!),Any,Any,Any} @inline function rrule!!( - ::CoDual{typeof(lsetfield!)}, value::CoDual{P, F}, name::CoDual, x::CoDual -) where {P, F<:StandardFDataType} + ::CoDual{typeof(lsetfield!)}, value::CoDual{P,F}, name::CoDual, x::CoDual +) where {P,F<:StandardFDataType} return lsetfield_rrule(value, name, x) end function lsetfield_rrule( - value::CoDual{P, F}, ::CoDual{Val{name}}, x::CoDual -) where {P, F, name} + value::CoDual{P,F}, ::CoDual{Val{name}}, x::CoDual +) where {P,F,name} save = isdefined(primal(value), name) old_x = save ? getfield(primal(value), name) : nothing old_dx = if F == NoFData @@ -149,7 +149,11 @@ function lsetfield_rrule( return NoRData(), NoRData(), NoRData(), new_dx end end - yf = F == NoFData ? NoFData() : fdata(set_tangent_field!(dvalue, name, zero_tangent(primal(x), tangent(x)))) + yf = if F == NoFData + NoFData() + else + fdata(set_tangent_field!(dvalue, name, zero_tangent(primal(x), tangent(x)))) + end y = CoDual(lsetfield!(primal(value), Val(name), primal(x)), yf) return y, pb!! end @@ -164,7 +168,9 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) specific_test_cases = Any[ # Rules to avoid pointer type conversions. ( - true, :stability, nothing, + true, + :stability, + nothing, +, CoDual( bitcast(Ptr{Float64}, pointer_from_objref(_x)), @@ -188,8 +194,12 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) (false, :stability_and_allocs, nothing, promote_type, Float64, Float64), (false, :stability_and_allocs, nothing, LinearAlgebra.chkstride1, randn(3, 3)), ( - false, :stability_and_allocs, nothing, - LinearAlgebra.chkstride1, randn(3, 3), randn(2, 2), + false, + :stability_and_allocs, + nothing, + LinearAlgebra.chkstride1, + randn(3, 3), + randn(2, 2), ), (false, :allocs, nothing, Threads.nthreads), (false, :none, nothing, Base.eltype, randn(1)), @@ -198,21 +208,41 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) # Literal replacement for setfield!. ( - false, :stability_and_allocs, nothing, - lsetfield!, MutableFoo(5.0, [1.0, 2.0]), Val(:a), 4.0, + false, + :stability_and_allocs, + nothing, + lsetfield!, + MutableFoo(5.0, [1.0, 2.0]), + Val(:a), + 4.0, ), ( - false, :stability_and_allocs, nothing, - lsetfield!, FullyInitMutableStruct(5.0, [1.0, 2.0]), Val(:y), [1.0, 3.0, 4.0], + false, + :stability_and_allocs, + nothing, + lsetfield!, + FullyInitMutableStruct(5.0, [1.0, 2.0]), + Val(:y), + [1.0, 3.0, 4.0], ), ( - false, :stability_and_allocs, nothing, - lsetfield!, NonDifferentiableFoo(5, false), Val(:x), 4, + false, + :stability_and_allocs, + nothing, + lsetfield!, + NonDifferentiableFoo(5, false), + Val(:x), + 4, ), ( - false, :stability_and_allocs, nothing, - lsetfield!, NonDifferentiableFoo(5, false), Val(:y), true, - ) + false, + :stability_and_allocs, + nothing, + lsetfield!, + NonDifferentiableFoo(5, false), + Val(:y), + true, + ), ] # Some specific test cases for lgetfield to test the basics. @@ -264,8 +294,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) _, primal = TestTypes.instantiate((interface_only, P, args)) names = fieldnames(P)[1:length(args)] # only query fields which get initialised return Any[ - (interface_only, :none, nothing, lgetfield, primal, Val(name)) for - name in names + (interface_only, :none, nothing, lgetfield, primal, Val(name)) for name in names ] end @@ -273,7 +302,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) all_lgetfield_test_cases = Any[ (case..., order...) for case in vcat(specific_lgetfield_test_cases, general_lgetfield_test_cases...) for - order in Any[(), (Val(false), )] + order in Any[(), (Val(false),)] ] # Create `lsetfield` testsfor each type in TestTypes in order to increase coverage. @@ -288,9 +317,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) end test_cases = vcat( - specific_test_cases, - all_lgetfield_test_cases..., - general_lsetfield_test_cases..., + specific_test_cases, all_lgetfield_test_cases..., general_lsetfield_test_cases... ) return test_cases, memory end diff --git a/src/rrules/new.jl b/src/rrules/new.jl index 5941f3f86..8e95cfc72 100644 --- a/src/rrules/new.jl +++ b/src/rrules/new.jl @@ -1,18 +1,22 @@ -@is_primitive MinimalCtx Tuple{typeof(_new_), Vararg} +@is_primitive MinimalCtx Tuple{typeof(_new_),Vararg} function rrule!!( - f::CoDual{typeof(_new_)}, p::CoDual{Type{P}}, x::Vararg{CoDual, N} -) where {P, N} + f::CoDual{typeof(_new_)}, p::CoDual{Type{P}}, x::Vararg{CoDual,N} +) where {P,N} y = _new_(P, tuple_map(primal, x)...) F = fdata_type(tangent_type(P)) R = rdata_type(tangent_type(P)) - dy = F == NoFData ? NoFData() : build_fdata(P, tuple_map(primal, x), tuple_map(tangent, x)) + dy = if F == NoFData + NoFData() + else + build_fdata(P, tuple_map(primal, x), tuple_map(tangent, x)) + end pb!! = if ismutabletype(P) if F == NoFData NoPullback(f, p, x...) else function _mutable_new_pullback!!(::NoRData) - rdatas = tuple_map(rdata ∘ val, Tuple(dy.fields)[1:N]) + rdatas = tuple_map(rdata ∘ val, Tuple(dy.fields)[1:N]) return NoRData(), NoRData(), rdatas... end end @@ -69,29 +73,58 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:new}) (false, :stability_and_allocs, nothing, _new_, @NamedTuple{y::Float64}, 5.0), (false, :stability_and_allocs, nothing, _new_, @NamedTuple{y::Int, x::Int}, 5, 4), ( - false, :stability_and_allocs, nothing, - _new_, @NamedTuple{y::Float64, x::Int}, 5.0, 4, + false, + :stability_and_allocs, + nothing, + _new_, + @NamedTuple{y::Float64, x::Int}, + 5.0, + 4, ), ( - false, :stability_and_allocs, nothing, - _new_, @NamedTuple{y::Vector{Float64}, x::Int}, randn(2), 4, + false, + :stability_and_allocs, + nothing, + _new_, + @NamedTuple{y::Vector{Float64}, x::Int}, + randn(2), + 4, ), ( - false, :stability_and_allocs, nothing, - _new_, @NamedTuple{y::Vector{Float64}}, randn(2), + false, + :stability_and_allocs, + nothing, + _new_, + @NamedTuple{y::Vector{Float64}}, + randn(2), ), ( - false, :stability_and_allocs, nothing, - _new_, TestResources.TypeStableStruct{Float64}, 5, 4.0, + false, + :stability_and_allocs, + nothing, + _new_, + TestResources.TypeStableStruct{Float64}, + 5, + 4.0, ), (false, :stability_and_allocs, nothing, _new_, UnitRange{Int64}, 5, 4), ( - false, :stability_and_allocs, nothing, - _new_, TestResources.TypeStableMutableStruct{Float64}, 5.0, 4.0, + false, + :stability_and_allocs, + nothing, + _new_, + TestResources.TypeStableMutableStruct{Float64}, + 5.0, + 4.0, ), ( - false, :none, nothing, - _new_, TestResources.TypeStableMutableStruct{Any}, 5.0, 4.0, + false, + :none, + nothing, + _new_, + TestResources.TypeStableMutableStruct{Any}, + 5.0, + 4.0, ), (false, :none, nothing, _new_, TestResources.StructFoo, 6.0, [1.0, 2.0]), (false, :none, nothing, _new_, TestResources.StructFoo, 6.0), @@ -100,20 +133,36 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:new}) (false, :stability_and_allocs, nothing, _new_, TestResources.StructNoFwds, 5.0), (false, :stability_and_allocs, nothing, _new_, TestResources.StructNoRvs, [5.0]), ( - false, :stability_and_allocs, nothing, - _new_, LowerTriangular{Float64, Matrix{Float64}}, randn(2, 2), + false, + :stability_and_allocs, + nothing, + _new_, + LowerTriangular{Float64,Matrix{Float64}}, + randn(2, 2), ), ( - false, :stability_and_allocs, nothing, - _new_, UpperTriangular{Float64, Matrix{Float64}}, randn(2, 2), + false, + :stability_and_allocs, + nothing, + _new_, + UpperTriangular{Float64,Matrix{Float64}}, + randn(2, 2), ), ( - false, :stability_and_allocs, nothing, - _new_, UnitLowerTriangular{Float64, Matrix{Float64}}, randn(2, 2), + false, + :stability_and_allocs, + nothing, + _new_, + UnitLowerTriangular{Float64,Matrix{Float64}}, + randn(2, 2), ), ( - false, :stability_and_allocs, nothing, - _new_, UnitUpperTriangular{Float64, Matrix{Float64}}, randn(2, 2), + false, + :stability_and_allocs, + nothing, + _new_, + UnitUpperTriangular{Float64,Matrix{Float64}}, + randn(2, 2), ), ] general_test_cases = map(TestTypes.PRIMALS) do (interface_only, P, args) diff --git a/src/rrules/tasks.jl b/src/rrules/tasks.jl index 3b21aad1a..36504b0e6 100644 --- a/src/rrules/tasks.jl +++ b/src/rrules/tasks.jl @@ -53,7 +53,7 @@ function get_tangent_field(t::TaskTangent, f) throw(error("Unhandled field $f")) end -const TaskCoDual = CoDual{Task, TaskTangent} +const TaskCoDual = CoDual{Task,TaskTangent} function rrule!!(::CoDual{typeof(lgetfield)}, x::TaskCoDual, ::CoDual{Val{f}}) where {f} dx = x.dx @@ -83,7 +83,15 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:tasks}) test_cases = Any[ (false, :none, nothing, lgetfield, Task(() -> nothing), Val(:rngState1)), (false, :none, nothing, getfield, Task(() -> nothing), :rngState1), - (false, :none, nothing, lsetfield!, Task(() -> nothing), Val(:rngState1), UInt64(5)), + ( + false, + :none, + nothing, + lsetfield!, + Task(() -> nothing), + Val(:rngState1), + UInt64(5), + ), (false, :stability_and_allocs, nothing, current_task), ] memory = Any[] @@ -91,12 +99,13 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:tasks}) end function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:tasks}) - test_cases = Any[ - ( - false, :none, nothing, - (rng) -> (Random.seed!(rng, 0); rand(rng)), Random.default_rng(), - ), - ] + test_cases = Any[( + false, + :none, + nothing, + (rng) -> (Random.seed!(rng, 0); rand(rng)), + Random.default_rng(), + ),] memory = Any[] return test_cases, memory end diff --git a/src/rrules/twice_precision.jl b/src/rrules/twice_precision.jl index 3cbcbccb0..3fda79ef2 100644 --- a/src/rrules/twice_precision.jl +++ b/src/rrules/twice_precision.jl @@ -22,7 +22,7 @@ end import .TestUtils: has_equal_data_internal function has_equal_data_internal( - p::P, q::P, ::Bool, ::Dict{Tuple{UInt, UInt}, Bool} + p::P, q::P, ::Bool, ::Dict{Tuple{UInt,UInt},Bool} ) where {P<:TWP} return Float64(p) ≈ Float64(q) end @@ -61,7 +61,7 @@ zero_rdata_from_type(P::Type{<:TWP{F}}) where {F} = P(zero(F), zero(F)) # Rules. These are required for a lot of functionality in this case. # -@is_primitive MinimalCtx Tuple{typeof(_new_), <:TWP, IEEEFloat, IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(_new_),<:TWP,IEEEFloat,IEEEFloat} function rrule!!( ::CoDual{typeof(_new_)}, ::CoDual{Type{TWP{P}}}, hi::CoDual{P}, lo::CoDual{P} ) where {P<:IEEEFloat} @@ -69,7 +69,7 @@ function rrule!!( return zero_fcodual(_new_(TWP{P}, hi.x, lo.x)), _new_twice_precision_pb end -@is_primitive MinimalCtx Tuple{typeof(twiceprecision), IEEEFloat, Integer} +@is_primitive MinimalCtx Tuple{typeof(twiceprecision),IEEEFloat,Integer} function rrule!!( ::CoDual{typeof(twiceprecision)}, val::CoDual{P}, nb::CoDual{<:Integer} ) where {P<:IEEEFloat} @@ -77,7 +77,7 @@ function rrule!!( return zero_fcodual(twiceprecision(val.x, nb.x)), twiceprecision_float_pb end -@is_primitive MinimalCtx Tuple{typeof(twiceprecision), TWP, Integer} +@is_primitive MinimalCtx Tuple{typeof(twiceprecision),TWP,Integer} function rrule!!( ::CoDual{typeof(twiceprecision)}, val::CoDual{P}, nb::CoDual{<:Integer} ) where {P<:TWP} @@ -85,52 +85,58 @@ function rrule!!( return zero_fcodual(twiceprecision(val.x, nb.x)), twiceprecision_pb end -@is_primitive MinimalCtx Tuple{Type{<:IEEEFloat}, TWP} -function rrule!!(::CoDual{Type{P}}, x::CoDual{S}) where {P<:IEEEFloat, S<:TWP} +@is_primitive MinimalCtx Tuple{Type{<:IEEEFloat},TWP} +function rrule!!(::CoDual{Type{P}}, x::CoDual{S}) where {P<:IEEEFloat,S<:TWP} float_from_twice_precision_pb(dy::P) = NoRData(), S(dy) return zero_fcodual(P(x.x)), float_from_twice_precision_pb end -@is_primitive MinimalCtx Tuple{typeof(-), TWP} +@is_primitive MinimalCtx Tuple{typeof(-),TWP} function rrule!!(::CoDual{typeof(-)}, x::CoDual{P}) where {P<:TWP} negate_twice_precision_pb(dy::P) = NoRData(), -dy return zero_fcodual(-(x.x)), negate_twice_precision_pb end -@is_primitive MinimalCtx Tuple{typeof(+), TWP, IEEEFloat} -function rrule!!(::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{S}) where {P<:TWP, S<:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(+),TWP,IEEEFloat} +function rrule!!( + ::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{S} +) where {P<:TWP,S<:IEEEFloat} plus_pullback(dz::P) = NoRData(), dz, S(dz) return zero_fcodual(x.x + y.x), plus_pullback end -@is_primitive(MinimalCtx, Tuple{typeof(+), P, P} where {P<:TWP}) +@is_primitive(MinimalCtx, Tuple{typeof(+),P,P} where {P<:TWP}) function rrule!!(::CoDual{typeof(+)}, x::CoDual{P}, y::CoDual{P}) where {P<:TWP} plus_pullback(dz::P) = NoRData(), dz, dz return zero_fcodual(x.x + y.x), plus_pullback end -@is_primitive MinimalCtx Tuple{typeof(*), TWP, IEEEFloat} -function rrule!!(::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{S}) where {P<:TWP, S<:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(*),TWP,IEEEFloat} +function rrule!!( + ::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{S} +) where {P<:TWP,S<:IEEEFloat} _x, _y = x.x, y.x mul_twice_precision_and_float_pb(dz::P) = NoRData(), dz * _y, S(dz * _x) return zero_fcodual(_x * _y), mul_twice_precision_and_float_pb end -@is_primitive MinimalCtx Tuple{typeof(*), TWP, Integer} +@is_primitive MinimalCtx Tuple{typeof(*),TWP,Integer} function rrule!!(::CoDual{typeof(*)}, x::CoDual{P}, y::CoDual{<:Integer}) where {P<:TWP} _y = y.x mul_twice_precision_and_int_pb(dz::P) = NoRData(), dz * _y, NoRData() return zero_fcodual(x.x * _y), mul_twice_precision_and_int_pb end -@is_primitive MinimalCtx Tuple{typeof(/), TWP, IEEEFloat} -function rrule!!(::CoDual{typeof(/)}, x::CoDual{P}, y::CoDual{S}) where {P<:TWP, S<:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(/),TWP,IEEEFloat} +function rrule!!( + ::CoDual{typeof(/)}, x::CoDual{P}, y::CoDual{S} +) where {P<:TWP,S<:IEEEFloat} _x, _y = x.x, y.x div_twice_precision_and_float_pb(dz::P) = NoRData(), dz / _y, S(-dz * _x / _y^2) return zero_fcodual(_x / _y), div_twice_precision_and_float_pb end -@is_primitive MinimalCtx Tuple{typeof(/), TWP, Integer} +@is_primitive MinimalCtx Tuple{typeof(/),TWP,Integer} function rrule!!(::CoDual{typeof(/)}, x::CoDual{P}, y::CoDual{<:Integer}) where {P<:TWP} _y = y.x div_twice_precision_and_int_pb(dz::P) = NoRData(), dz / _y, NoRData() @@ -139,20 +145,20 @@ end # Primitives -@zero_adjoint MinimalCtx Tuple{Type{<:TwicePrecision}, Tuple{Integer, Integer}, Integer} -@zero_adjoint MinimalCtx Tuple{typeof(Base.splitprec), Type, Integer} +@zero_adjoint MinimalCtx Tuple{Type{<:TwicePrecision},Tuple{Integer,Integer},Integer} +@zero_adjoint MinimalCtx Tuple{typeof(Base.splitprec),Type,Integer} @zero_adjoint( MinimalCtx, - Tuple{typeof(Base.floatrange), Type{<:IEEEFloat}, Integer, Integer, Integer, Integer}, + Tuple{typeof(Base.floatrange),Type{<:IEEEFloat},Integer,Integer,Integer,Integer}, ) @zero_adjoint( MinimalCtx, - Tuple{typeof(Base._linspace), Type{<:IEEEFloat}, Integer, Integer, Integer, Integer}, + Tuple{typeof(Base._linspace),Type{<:IEEEFloat},Integer,Integer,Integer,Integer}, ) using Base: range_start_step_length @is_primitive( - MinimalCtx, Tuple{typeof(range_start_step_length), T, T, Integer} where {T<:IEEEFloat} + MinimalCtx, Tuple{typeof(range_start_step_length),T,T,Integer} where {T<:IEEEFloat} ) function rrule!!( ::CoDual{typeof(range_start_step_length)}, @@ -165,8 +171,8 @@ function rrule!!( end using Base: unsafe_getindex -const TWPStepRangeLen = StepRangeLen{<:Any, <:TWP, <:TWP} -@is_primitive(MinimalCtx, Tuple{typeof(unsafe_getindex), TWPStepRangeLen, Integer}) +const TWPStepRangeLen = StepRangeLen{<:Any,<:TWP,<:TWP} +@is_primitive(MinimalCtx, Tuple{typeof(unsafe_getindex),TWPStepRangeLen,Integer}) function rrule!!( ::CoDual{typeof(unsafe_getindex)}, r::CoDual{P}, i::CoDual{<:Integer} ) where {P<:TWPStepRangeLen} @@ -183,7 +189,7 @@ function rrule!!( end using Base: _getindex_hiprec -@is_primitive(MinimalCtx, Tuple{typeof(_getindex_hiprec), TWPStepRangeLen, Integer}) +@is_primitive(MinimalCtx, Tuple{typeof(_getindex_hiprec),TWPStepRangeLen,Integer}) function rrule!!( ::CoDual{typeof(_getindex_hiprec)}, r::CoDual{P}, i::CoDual{<:Integer} ) where {P<:TWPStepRangeLen} @@ -198,7 +204,7 @@ function rrule!!( return zero_fcodual(_getindex_hiprec(r.x, i.x)), unsafe_getindex_pb end -@is_primitive MinimalCtx Tuple{typeof(:), P, P, P} where {P<:IEEEFloat} +@is_primitive MinimalCtx Tuple{typeof(:),P,P,P} where {P<:IEEEFloat} function rrule!!( ::CoDual{typeof(:)}, start::CoDual{P}, step::CoDual{P}, stop::CoDual{P} ) where {P<:IEEEFloat} @@ -206,7 +212,7 @@ function rrule!!( return zero_fcodual((:)(start.x, step.x, stop.x)), colon_pb end -@is_primitive MinimalCtx Tuple{typeof(sum), TWPStepRangeLen} +@is_primitive MinimalCtx Tuple{typeof(sum),TWPStepRangeLen} function rrule!!(::CoDual{typeof(sum)}, x::CoDual{P}) where {P<:TWPStepRangeLen} l = x.x.len offset = x.x.offset @@ -222,7 +228,7 @@ end @is_primitive( MinimalCtx, - Tuple{typeof(Base.range_start_stop_length), P, P, Integer} where {P<:IEEEFloat}, + Tuple{typeof(Base.range_start_stop_length),P,P,Integer} where {P<:IEEEFloat}, ) function rrule!!( ::CoDual{typeof(Base.range_start_stop_length)}, @@ -241,30 +247,35 @@ function rrule!!( end @static if VERSION >= v"1.11" + @is_primitive MinimalCtx Tuple{ + typeof(Base._exp_allowing_twice64),TwicePrecision{Float64} + } + function rrule!!( + ::CoDual{typeof(Base._exp_allowing_twice64)}, x::CoDual{TwicePrecision{Float64}} + ) + y = Base._exp_allowing_twice64(x.x) + _exp_allowing_twice64_pb(dy::Float64) = NoRData(), TwicePrecision(dy * y) + return zero_fcodual(y), _exp_allowing_twice64_pb + end -@is_primitive MinimalCtx Tuple{typeof(Base._exp_allowing_twice64), TwicePrecision{Float64}} -function rrule!!( - ::CoDual{typeof(Base._exp_allowing_twice64)}, x::CoDual{TwicePrecision{Float64}} -) - y = Base._exp_allowing_twice64(x.x) - _exp_allowing_twice64_pb(dy::Float64) = NoRData(), TwicePrecision(dy * y) - return zero_fcodual(y), _exp_allowing_twice64_pb -end - -@is_primitive(MinimalCtx, Tuple{typeof(Base._log_twice64_unchecked), Float64}) -function rrule!!(::CoDual{typeof(Base._log_twice64_unchecked)}, x::CoDual{Float64}) - _x = x.x - _log_twice64_pb(dy::TwicePrecision{Float64}) = NoRData(), Float64(dy) / _x - return zero_fcodual(Base._log_twice64_unchecked(_x)), _log_twice64_pb -end - + @is_primitive(MinimalCtx, Tuple{typeof(Base._log_twice64_unchecked),Float64}) + function rrule!!(::CoDual{typeof(Base._log_twice64_unchecked)}, x::CoDual{Float64}) + _x = x.x + _log_twice64_pb(dy::TwicePrecision{Float64}) = NoRData(), Float64(dy) / _x + return zero_fcodual(Base._log_twice64_unchecked(_x)), _log_twice64_pb + end end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:twice_precision}) test_cases = Any[ ( - false, :stability_and_allocs, nothing, - _new_, TwicePrecisionFloat{Float64}, 5.0, 4.0 + false, + :stability_and_allocs, + nothing, + _new_, + TwicePrecisionFloat{Float64}, + 5.0, + 4.0, ), (false, :stability_and_allocs, nothing, twiceprecision, 5.0, 4), (false, :stability_and_allocs, nothing, twiceprecision, TwicePrecision(5.0), 4), @@ -272,33 +283,44 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:twice_precisi (false, :stability_and_allocs, nothing, -, TwicePrecision(5.0, 3.0)), (false, :stability_and_allocs, nothing, +, TwicePrecision(5.0, 3.0), 4.0), ( - false, :stability_and_allocs, nothing, - +, TwicePrecision(5.0, 3.0), TwicePrecision(4.0, 5.0), + false, + :stability_and_allocs, + nothing, + +, + TwicePrecision(5.0, 3.0), + TwicePrecision(4.0, 5.0), ), (false, :stability_and_allocs, nothing, *, TwicePrecision(5.0, 1e-12), 3.0), (false, :stability_and_allocs, nothing, *, TwicePrecision(5.0, 1e-12), 3), (false, :stability_and_allocs, nothing, /, TwicePrecision(5.0, 1e-12), 3.0), (false, :stability_and_allocs, nothing, /, TwicePrecision(5.0, 1e-12), 3), - (false, :stability_and_allocs, nothing, Base.splitprec, Float64, 5), (false, :stability_and_allocs, nothing, Base.splitprec, Float32, 5), (false, :stability_and_allocs, nothing, Base.splitprec, Float16, 5), - (false, :stability_and_allocs, nothing, Base.floatrange, Float64, 5, 6, 7, 8), (false, :stability_and_allocs, nothing, Base._linspace, Float64, 5, 6, 7, 8), (false, :stability_and_allocs, nothing, Base.range_start_step_length, 5.0, 6.0, 10), ( - false, :stability_and_allocs, nothing, - Base.range_start_step_length, 5.0, Float64(π), 10, + false, + :stability_and_allocs, + nothing, + Base.range_start_step_length, + 5.0, + Float64(π), + 10, ), ( - false, :stability_and_allocs, nothing, + false, + :stability_and_allocs, + nothing, unsafe_getindex, StepRangeLen(TwicePrecision(-0.45), TwicePrecision(0.98), 10, 3), 5, ), ( - false, :stability_and_allocs, nothing, + false, + :stability_and_allocs, + nothing, _getindex_hiprec, StepRangeLen(TwicePrecision(-0.45), TwicePrecision(0.98), 10, 3), 5, @@ -306,19 +328,32 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:twice_precisi (false, :stability_and_allocs, nothing, (:), -0.1, 0.99, 5.1), (false, :stability_and_allocs, nothing, sum, range(-0.1, 9.9; length=51)), ( - false, :stability_and_allocs, nothing, - Base.range_start_stop_length, -0.5, 11.7, 7, + false, + :stability_and_allocs, + nothing, + Base.range_start_stop_length, + -0.5, + 11.7, + 7, ), ( - false, :stability_and_allocs, nothing, - Base.range_start_stop_length, -0.5, -11.7, 11, + false, + :stability_and_allocs, + nothing, + Base.range_start_stop_length, + -0.5, + -11.7, + 11, ), ] @static if VERSION >= v"1.11" extra_test_cases = Any[ ( - false, :stability_and_allocs, nothing, - Base._exp_allowing_twice64, TwicePrecision(2.0), + false, + :stability_and_allocs, + nothing, + Base._exp_allowing_twice64, + TwicePrecision(2.0), ), (false, :stability_and_allocs, nothing, Base._log_twice64_unchecked, 3.0), ] @@ -333,7 +368,14 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:twice_precision}) # Functionality in base/twiceprecision.jl (false, :allocs, nothing, TwicePrecision{Float64}, 5.0, 0.3), - (false, :allocs, nothing, (x, y) -> Float64(TwicePrecision{Float64}(x, y)), 5.0, 0.3), + ( + false, + :allocs, + nothing, + (x, y) -> Float64(TwicePrecision{Float64}(x, y)), + 5.0, + 0.3, + ), (false, :allocs, nothing, TwicePrecision, 5.0, 0.3), (false, :allocs, nothing, (x, y) -> Float64(TwicePrecision(x, y)), 5.0, 0.3), (false, :allocs, nothing, TwicePrecision{Float64}, 5.0), @@ -345,7 +387,14 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:twice_precision}) (false, :none, nothing, TwicePrecision{Float64}, (5, 4)), (false, :none, nothing, x -> Float64(TwicePrecision{Float64}(x)), (5, 4)), (false, :none, nothing, TwicePrecision{Float64}, (5, 4), 3), - (false, :none, nothing, (x, y) -> Float64(TwicePrecision{Float64}(x, y)), (5, 4), 3), + ( + false, + :none, + nothing, + (x, y) -> Float64(TwicePrecision{Float64}(x, y)), + (5, 4), + 3, + ), (false, :allocs, nothing, +, TwicePrecision(5.0), TwicePrecision(4.0)), (false, :allocs, nothing, +, 5.0, TwicePrecision(4.0)), (false, :allocs, nothing, +, TwicePrecision(5.0), 4.0), @@ -355,14 +404,20 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:twice_precision}) (false, :allocs, nothing, *, 3.0, TwicePrecision(5.0, 1e-12)), (false, :allocs, nothing, *, 3, TwicePrecision(5.0, 1e-12)), ( - false, :allocs, nothing, + false, + :allocs, + nothing, getindex, StepRangeLen(TwicePrecision(-0.45), TwicePrecision(0.98), 10, 3), 2:2:6, ), ( - false, :allocs, nothing, - +, range(0.0, 5.0; length=44), range(-33.0, 4.5; length=44), + false, + :allocs, + nothing, + +, + range(0.0, 5.0; length=44), + range(-33.0, 4.5; length=44), ), # Functionality in base/range.jl diff --git a/src/tangents.jl b/src/tangents.jl index f28564f22..43155529d 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -46,7 +46,7 @@ end _wrap_type(::Type{T}) where {T} = PossiblyUninitTangent{T} -_wrap_field(::Type{Q}, x::T) where {Q, T} = PossiblyUninitTangent{Q}(x) +_wrap_field(::Type{Q}, x::T) where {Q,T} = PossiblyUninitTangent{Q}(x) _wrap_field(x::T) where {T} = _wrap_field(T, x) struct Tangent{Tfields<:NamedTuple} @@ -59,16 +59,20 @@ mutable struct MutableTangent{Tfields<:NamedTuple} fields::Tfields MutableTangent{Tfields}() where {Tfields} = new{Tfields}() MutableTangent(fields::Tfields) where {Tfields} = MutableTangent{Tfields}(fields) - MutableTangent{Tfields}(fields::NamedTuple{names}) where {names, Tfields<:NamedTuple{names}} = new{Tfields}(fields) + function MutableTangent{Tfields}( + fields::NamedTuple{names} + ) where {names,Tfields<:NamedTuple{names}} + return new{Tfields}(fields) + end end Base.:(==)(x::MutableTangent, y::MutableTangent) = x.fields == y.fields fields_type(::Type{MutableTangent{Tfields}}) where {Tfields<:NamedTuple} = Tfields fields_type(::Type{Tangent{Tfields}}) where {Tfields<:NamedTuple} = Tfields -fields_type(::Type{<:Union{MutableTangent, Tangent}}) = NamedTuple +fields_type(::Type{<:Union{MutableTangent,Tangent}}) = NamedTuple -const PossiblyMutableTangent{T} = Union{MutableTangent{T}, Tangent{T}} +const PossiblyMutableTangent{T} = Union{MutableTangent{T},Tangent{T}} """ get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int) @@ -106,15 +110,17 @@ were actually fields of `t`. This is the moral equivalent of `setfield!` for return x end -@inline function set_tangent_field!(t::MutableTangent{Tfields}, s::Symbol, x) where {Tfields} +@inline function set_tangent_field!( + t::MutableTangent{Tfields}, s::Symbol, x +) where {Tfields} return set_tangent_field!(t, _sym_to_int(Tfields, Val(s)), x) end -@generated function _sym_to_int(::Type{Tfields}, ::Val{s}) where {Tfields, s} +@generated function _sym_to_int(::Type{Tfields}, ::Val{s}) where {Tfields,s} return findfirst(==(s), fieldnames(Tfields)) end -@generated function build_tangent(::Type{P}, fields::Vararg{Any, N}) where {P, N} +@generated function build_tangent(::Type{P}, fields::Vararg{Any,N}) where {P,N} tangent_values_exprs = map(enumerate(fieldtypes(P))) do (n, field_type) if tangent_field_type(P, n) <: PossiblyUninitTangent tt = PossiblyUninitTangent{tangent_type(field_type)} @@ -134,7 +140,9 @@ end ) end -function build_tangent(::Type{P}, fields::Vararg{Any, N}) where {P<:Union{Tuple, NamedTuple}, N} +function build_tangent( + ::Type{P}, fields::Vararg{Any,N} +) where {P<:Union{Tuple,NamedTuple},N} T = tangent_type(P) if T == NoTangent return NoTangent() @@ -146,7 +154,7 @@ function build_tangent(::Type{P}, fields::Vararg{Any, N}) where {P<:Union{Tuple, end __tangent_from_non_concrete(::Type{P}, fields) where {P<:Tuple} = Tuple(fields) -function __tangent_from_non_concrete(::Type{P}, fields) where {names, P<:NamedTuple{names}} +function __tangent_from_non_concrete(::Type{P}, fields) where {names,P<:NamedTuple{names}} return NamedTuple{names}(fields) end @@ -306,9 +314,9 @@ tangent_type(::Type{Core.TypeofVararg}) = NoTangent tangent_type(::Type{SimpleVector}) = Vector{Any} -tangent_type(::Type{P}) where {P<:Union{UInt8, UInt16, UInt32, UInt64, UInt128}} = NoTangent +tangent_type(::Type{P}) where {P<:Union{UInt8,UInt16,UInt32,UInt64,UInt128}} = NoTangent -tangent_type(::Type{P}) where {P<:Union{Int8, Int16, Int32, Int64, Int128}} = NoTangent +tangent_type(::Type{P}) where {P<:Union{Int8,Int16,Int32,Int64,Int128}} = NoTangent tangent_type(::Type{<:Core.Builtin}) = NoTangent @@ -318,9 +326,9 @@ tangent_type(::Type{<:Core.LLVMPtr}) = NoTangent tangent_type(::Type{String}) = NoTangent -tangent_type(::Type{<:Array{P, N}}) where {P, N} = Array{tangent_type(P), N} +tangent_type(::Type{<:Array{P,N}}) where {P,N} = Array{tangent_type(P),N} -tangent_type(::Type{<:Array{P, N} where {P}}) where {N} = Array +tangent_type(::Type{<:Array{P,N} where {P}}) where {N} = Array tangent_type(::Type{<:MersenneTwister}) = NoTangent @@ -334,9 +342,9 @@ tangent_type(::Type{Method}) = NoTangent tangent_type(::Type{<:Enum}) = NoTangent -function tangent_type(::Type{P}) where {N, P<:Tuple{Vararg{Any, N}}} +function tangent_type(::Type{P}) where {N,P<:Tuple{Vararg{Any,N}}} # As with other types, tangent type of Union is Union of tangent types. - P isa Union && return Union{tangent_type(P.a), tangent_type(P.b)} + P isa Union && return Union{tangent_type(P.a),tangent_type(P.b)} # Determine whether P isa a Tuple with a Vararg, e.g, Tuple, or Tuple{Float64, Vararg}. # Need to exclude `UnionAll`s from this, by checking `isa(P, DataType)`, in order to @@ -354,7 +362,7 @@ function tangent_type(::Type{P}) where {N, P<:Tuple{Vararg{Any, N}}} # and return `NoTangent`. tangent_types = tuple_map(tangent_type, fieldtypes(P)) T = Tuple{tangent_types...} - T_all_notangent = Tuple{Vararg{NoTangent, length(tangent_types)}} + T_all_notangent = Tuple{Vararg{NoTangent,length(tangent_types)}} T <: T_all_notangent && return NoTangent # If it's _possible_ for a subtype of `P` to have tangent type `NoTangent`, then we must @@ -363,33 +371,32 @@ function tangent_type(::Type{P}) where {N, P<:Tuple{Vararg{Any, N}}} # tangent type `NoTangent`, it must be true that `NoTangent <: tangent_type(P)`. # If, on the other hand, it's not possible for `NoTangent` to be the tangent type, e.g. # for `Tuple{Float64, Any}`, then there's no need to take the union. - return T_all_notangent <: T ? Union{T, NoTangent} : T + return T_all_notangent <: T ? Union{T,NoTangent} : T end -function tangent_type(::Type{P}) where {N, P<:NamedTuple{N}} - P isa Union && return Union{tangent_type(P.a), tangent_type(P.b)} - !isconcretetype(P) && return Union{NoTangent, NamedTuple{N}} +function tangent_type(::Type{P}) where {N,P<:NamedTuple{N}} + P isa Union && return Union{tangent_type(P.a),tangent_type(P.b)} + !isconcretetype(P) && return Union{NoTangent,NamedTuple{N}} TT = tangent_type(Tuple{fieldtypes(P)...}) TT == NoTangent && return NoTangent - return isconcretetype(TT) ? NamedTuple{N, TT} : Any + return isconcretetype(TT) ? NamedTuple{N,TT} : Any end @generated function tangent_type(::Type{P}) where {P} # This method can only handle struct types. Tell user to implement tangent type # directly for primitive types. - isprimitivetype(P) && throw(error( - "$P is a primitive type. Implement a method of `tangent_type` for it." - )) + isprimitivetype(P) && + throw(error("$P is a primitive type. Implement a method of `tangent_type` for it.")) # If the type is a Union, then take the union type of its arguments. - P isa Union && return Union{tangent_type(P.a), tangent_type(P.b)} + P isa Union && return Union{tangent_type(P.a),tangent_type(P.b)} # If the type is itself abstract, it's tangent could be anything. # The same goes for if the type has any undetermined type parameters. (isabstracttype(P) || !isconcretetype(P)) && return Any # If all fields are definitely NoTangents, then the overall tangent type is NoTangent. - T_all_notangent = Tuple{Vararg{NoTangent, fieldcount(P)}} + T_all_notangent = Tuple{Vararg{NoTangent,fieldcount(P)}} Tuple{tangent_field_types(P)...} <: T_all_notangent && return NoTangent # Derive tangent type. @@ -398,12 +405,12 @@ end end @inline function tangent_field_types(P) - return tuple_map(Base.Fix1(tangent_field_type, P), (1:fieldcount(P)..., )) + return tuple_map(Base.Fix1(tangent_field_type, P), (1:fieldcount(P)...,)) end backing_type(P::Type{<:Tuple}) = Tuple{tangent_field_types(P)...} -backing_type(P::Type) = NamedTuple{fieldnames(P), Tuple{tangent_field_types(P)...}} +backing_type(P::Type) = NamedTuple{fieldnames(P),Tuple{tangent_field_types(P)...}} """ tangent_field_type(::Type{P}, n::Int) where {P} @@ -436,20 +443,30 @@ function zero_tangent(x::P) where {P} return zero_tangent_internal(x, isbitstype(P) ? nothing : IdDict()) end -const StackDict = Union{Nothing, IdDict} +const StackDict = Union{Nothing,IdDict} # the `stackdict` naming following convention of Julia's `deepcopy` and `deepcopy_internal` # https://github.com/JuliaLang/julia/blob/48d4fd48430af58502699fdf3504b90589df3852/base/deepcopy.jl#L35 -@inline zero_tangent_internal(::Union{Int8, Int16, Int32, Int64, Int128}, ::Any) = NoTangent() +@inline zero_tangent_internal(::Union{Int8,Int16,Int32,Int64,Int128}, ::Any) = NoTangent() @inline zero_tangent_internal(x::IEEEFloat, ::Any) = zero(x) -@inline function zero_tangent_internal(x::P, stackdict::Any) where {P<:Union{Tuple, NamedTuple}} - return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(Base.Fix2(zero_tangent_internal, stackdict), x) +@inline function zero_tangent_internal( + x::P, stackdict::Any +) where {P<:Union{Tuple,NamedTuple}} + return if tangent_type(P) == NoTangent + NoTangent() + else + tuple_map(Base.Fix2(zero_tangent_internal, stackdict), x) + end end function zero_tangent_internal(x::Ptr, ::Any) return throw(ArgumentError("zero_tangent not available for pointers.")) end @inline function zero_tangent_internal(x::SimpleVector, stackdict::IdDict) - return map!(n -> zero_tangent_internal(x[n], stackdict), Vector{Any}(undef, length(x)), eachindex(x)) + return map!( + n -> zero_tangent_internal(x[n], stackdict), + Vector{Any}(undef, length(x)), + eachindex(x), + ) end function zero_tangent_internal(x::P, stackdict) where {P} tangent_type(P) == NoTangent && return NoTangent() @@ -458,8 +475,8 @@ function zero_tangent_internal(x::P, stackdict) where {P} if !(stackdict isa IdDict) throw( ArgumentError( - "Internal error: stackdict must be an IdDict for mutable structs, not $(typeof(stackdict)). Please report this issue." - ) + "Internal error: stackdict must be an IdDict for mutable structs, not $(typeof(stackdict)). Please report this issue.", + ), ) end if haskey(stackdict, x) @@ -480,7 +497,13 @@ end tangent_field_zeros_exprs = ntuple(fieldcount(P)) do n if tangent_field_type(P, n) <: PossiblyUninitTangent V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))} - return :(isdefined(x, $n) ? $V(zero_tangent_internal(getfield(x, $n), stackdict)) : $V()) + return :( + if isdefined(x, $n) + $V(zero_tangent_internal(getfield(x, $n), stackdict)) + else + $V() + end + ) else return :(zero_tangent_internal(getfield(x, $n), stackdict)) end @@ -511,8 +534,14 @@ end randn_tangent_internal(::AbstractRNG, ::NoTangent, ::Any) = NoTangent() randn_tangent_internal(rng::AbstractRNG, ::T, ::Any) where {T<:IEEEFloat} = randn(rng, T) -function randn_tangent_internal(rng::AbstractRNG, x::P, stackdict::Any) where {P<:Union{Tuple, NamedTuple}} - return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(x -> randn_tangent_internal(rng, x, stackdict), x) +function randn_tangent_internal( + rng::AbstractRNG, x::P, stackdict::Any +) where {P<:Union{Tuple,NamedTuple}} + return if tangent_type(P) == NoTangent + NoTangent() + else + tuple_map(x -> randn_tangent_internal(rng, x, stackdict), x) + end end function randn_tangent_internal(rng::AbstractRNG, x::SimpleVector, stackdict::IdDict) return map!(Vector{Any}(undef, length(x)), eachindex(x)) do n @@ -528,8 +557,8 @@ function randn_tangent_internal(rng::AbstractRNG, x::P, stackdict) where {P} if !(stackdict isa IdDict) throw( ArgumentError( - "Internal error: stackdict must be an IdDict for mutable structs, not $(typeof(stackdict)). Please report this issue." - ) + "Internal error: stackdict must be an IdDict for mutable structs, not $(typeof(stackdict)). Please report this issue.", + ), ) end if haskey(stackdict, x) @@ -544,11 +573,16 @@ function randn_tangent_internal(rng::AbstractRNG, x::P, stackdict) where {P} end @generated function randn_tangent_struct_field(rng::AbstractRNG, x::P, stackdict) where {P} - tangent_field_exprs = map(1:fieldcount(P)) do n if tangent_field_type(P, n) <: PossiblyUninitTangent V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))} - return :(isdefined(x, $n) ? $V(randn_tangent_internal(rng, getfield(x, $n), stackdict)) : $V()) + return :( + if isdefined(x, $n) + $V(randn_tangent_internal(rng, getfield(x, $n), stackdict)) + else + $V() + end + ) else return :(randn_tangent_internal(rng, getfield(x, $n), stackdict)) end @@ -589,7 +623,7 @@ Set `x` to its zero element (`x` should be a tangent, so the zero must exist). """ set_to_zero!!(::NoTangent) = NoTangent() set_to_zero!!(x::Base.IEEEFloat) = zero(x) -set_to_zero!!(x::Union{Tuple, NamedTuple}) = map(set_to_zero!!, x) +set_to_zero!!(x::Union{Tuple,NamedTuple}) = map(set_to_zero!!, x) function set_to_zero!!(x::T) where {T<:PossiblyUninitTangent} return is_init(x) ? T(set_to_zero!!(val(x))) : x end @@ -610,11 +644,11 @@ correspond to a vector field. Not using `*` in order to avoid piracy. """ _scale(::Float64, ::NoTangent) = NoTangent() _scale(a::Float64, t::T) where {T<:IEEEFloat} = T(a * t) -_scale(a::Float64, t::Union{Tuple, NamedTuple}) = map(Base.Fix1(_scale, a), t) +_scale(a::Float64, t::Union{Tuple,NamedTuple}) = map(Base.Fix1(_scale, a), t) function _scale(a::Float64, t::T) where {T<:PossiblyUninitTangent} return is_init(t) ? T(_scale(a, val(t))) : T() end -_scale(a::Float64, t::T) where {T<:Union{Tangent, MutableTangent}} = T(_scale(a, t.fields)) +_scale(a::Float64, t::T) where {T<:Union{Tangent,MutableTangent}} = T(_scale(a, t.fields)) struct FieldUndefined end @@ -629,12 +663,12 @@ Always available because all tangent types correspond to finite-dimensional vect """ _dot(::NoTangent, ::NoTangent) = 0.0 _dot(t::T, s::T) where {T<:IEEEFloat} = Float64(t * s) -_dot(t::T, s::T) where {T<:Union{Tuple, NamedTuple}} = sum(map(_dot, t, s); init=0.0) +_dot(t::T, s::T) where {T<:Union{Tuple,NamedTuple}} = sum(map(_dot, t, s); init=0.0) function _dot(t::T, s::T) where {T<:PossiblyUninitTangent} is_init(t) && is_init(s) && return _dot(val(t), val(s)) return 0.0 end -function _dot(t::T, s::T) where {T<:Union{Tangent, MutableTangent}} +function _dot(t::T, s::T) where {T<:Union{Tangent,MutableTangent}} return sum(_map(_dot, t.fields, s.fields); init=0.0) end @@ -681,7 +715,8 @@ struct AddToPrimalException <: Exception end function Base.showerror(io::IO, err::AddToPrimalException) - msg = "Attempted to construct an instance of $(err.primal_type) using the default " * + msg = + "Attempted to construct an instance of $(err.primal_type) using the default " * "constuctor. In most cases, this error is caused by the lack of existence of the " * "default constructor for this type. There are two approaches to dealing with " * "this problem. The first is to avoid having to call `_add_to_primal` on this " * @@ -692,17 +727,19 @@ function Base.showerror(io::IO, err::AddToPrimalException) "by setting the `unsafe_perturb` setting to `true` -- check the docstring " * "for `Mooncake._add_to_primal` to ensure that your use case is unlikely to " * "cause problems." - println(io, msg) + return println(io, msg) end -function _add_to_primal(p::P, t::T, unsafe::Bool) where {P, T<:Union{Tangent, MutableTangent}} +function _add_to_primal(p::P, t::T, unsafe::Bool) where {P,T<:Union{Tangent,MutableTangent}} Tt = tangent_type(P) if Tt != typeof(t) throw(ArgumentError("p of type $P has tangent_type $Tt, but t is of type $T")) end tmp = map(fieldnames(P)) do f tf = getfield(t.fields, f) - isdefined(p, f) && is_init(tf) && return _add_to_primal(getfield(p, f), val(tf), unsafe) + isdefined(p, f) && + is_init(tf) && + return _add_to_primal(getfield(p, f), val(tf), unsafe) !isdefined(p, f) && !is_init(tf) && return FieldUndefined() throw(error("unable to handle undefined-ness")) end @@ -711,13 +748,13 @@ function _add_to_primal(p::P, t::T, unsafe::Bool) where {P, T<:Union{Tangent, Mu # If unsafe mode is enabled, then call `_new_` directly, and avoid the possibility that # the default inner constructor for `P` does not exist. if unsafe - return i === nothing ? _new_(P, tmp...) : _new_(P, tmp[1:i-1]...) + return i === nothing ? _new_(P, tmp...) : _new_(P, tmp[1:(i - 1)]...) end # If unsafe mode is disabled, try to use the default constructor for `P`. If this does # not work, then throw an informative error message. - try - return i === nothing ? P(tmp...) : P(tmp[1:i-1]...) + try + return i === nothing ? P(tmp...) : P(tmp[1:(i - 1)]...) catch e if e isa MethodError throw(AddToPrimalException(P)) @@ -737,7 +774,7 @@ Returns a tangent of type `tangent_type(P)`. """ function _diff(p::P, q::P) where {P} tangent_type(P) === NoTangent && return NoTangent() - T = Tangent{NamedTuple{(), Tuple{}}} + T = Tangent{NamedTuple{(),Tuple{}}} tangent_type(P) === T && return T((;)) return _containerlike_diff(p, q) end @@ -745,7 +782,7 @@ _diff(p::P, q::P) where {P<:IEEEFloat} = p - q function _diff(p::P, q::P) where {P<:SimpleVector} return Any[_diff(a, b) for (a, b) in zip(p, q)] end -function _diff(p::P, q::P) where {P<:Union{Tuple, NamedTuple}} +function _diff(p::P, q::P) where {P<:Union{Tuple,NamedTuple}} return tangent_type(P) == NoTangent ? NoTangent() : _map(_diff, p, q) end @@ -756,7 +793,7 @@ function _containerlike_diff(p::P, q::P) where {P} throw(error("Unhandleable undefinedness")) end i = findfirst(==(FieldUndefined()), diffed_fields) - diffed_fields = i === nothing ? diffed_fields : diffed_fields[1:i-1] + diffed_fields = i === nothing ? diffed_fields : diffed_fields[1:(i - 1)] return build_tangent(P, diffed_fields...) end @@ -775,7 +812,7 @@ function increment_field!!(x::Tuple, y, i::Int) return ntuple(n -> n == i ? increment!!(x[n], y) : x[n], length(x)) end -@inline @generated function increment_field!!(x::T, y, ::Val{f}) where {T<:NamedTuple, f} +@inline @generated function increment_field!!(x::T, y, ::Val{f}) where {T<:NamedTuple,f} i = f isa Symbol ? findfirst(==(f), fieldnames(T)) : f new_fields = Expr(:call, increment_field!!, :(Tuple(x)), :y, :(Val($i))) return Expr(:call, T, new_fields) @@ -789,12 +826,12 @@ function increment_field!!(x::T, y, s::Symbol) where {T<:NamedTuple} return T(tuple_map(n -> n == s ? increment!!(x[n], y) : x[n], fieldnames(T))) end -function increment_field!!(x::Tangent{T}, y, f::Val{F}) where {T, F} +function increment_field!!(x::Tangent{T}, y, f::Val{F}) where {T,F} y isa NoTangent && return x new_val = fieldtype(T, F) <: PossiblyUninitTangent ? fieldtype(T, F)(y) : y return Tangent(increment_field!!(x.fields, new_val, f)) end -function increment_field!!(x::MutableTangent{T}, y, f::V) where {T, F, V<:Val{F}} +function increment_field!!(x::MutableTangent{T}, y, f::V) where {T,F,V<:Val{F}} y isa NoTangent && return x new_val = fieldtype(T, F) <: PossiblyUninitTangent ? fieldtype(T, F)(y) : y setfield!(x, :fields, increment_field!!(x.fields, new_val, f)) @@ -829,7 +866,6 @@ Test cases in the first format make use of `zero_tangent` / `randn_tangent` etc tangents, but they're unable to check that `increment!!` is correct in an absolute sense. """ function tangent_test_cases() - N_large = 33 _names = Tuple(map(n -> Symbol("x$n"), 1:N_large)) @@ -837,7 +873,7 @@ function tangent_test_cases() [ (sin, NoTangent(), NoTangent(), NoTangent()), (map(Float16, (5.0, 4.0, 3.1, 7.1))...), - (5f0, 4f0, 3f0, 7f0), + (5.0f0, 4.0f0, 3.0f0, 7.0f0), (5.1, 4.0, 3.0, 7.0), (svec(5.0), Any[4.0], Any[3.0], Any[7.0]), ([3.0, 2.0], [1.0, 2.0], [2.0, 3.0], [3.0, 5.0]), @@ -866,12 +902,7 @@ function tangent_test_cases() setindex!(Vector{Vector{Float64}}(undef, 2), [3.0], 2), setindex!(Vector{Vector{Float64}}(undef, 2), [5.0], 2), ), - ( - (6.0, [1.0, 2.0]), - (5.0, [3.0, 4.0]), - (4.0, [4.0, 3.0]), - (9.0, [7.0, 7.0]), - ), + ((6.0, [1.0, 2.0]), (5.0, [3.0, 4.0]), (4.0, [4.0, 3.0]), (9.0, [7.0, 7.0])), ((), NoTangent(), NoTangent(), NoTangent()), ((1,), NoTangent(), NoTangent(), NoTangent()), ((2, 3), NoTangent(), NoTangent(), NoTangent()), @@ -939,10 +970,10 @@ function tangent_test_cases() (UnitRange{Int}(5, 7), NoTangent(), NoTangent(), NoTangent()), ], map([ - LowerTriangular{Float64, Matrix{Float64}}, - UpperTriangular{Float64, Matrix{Float64}}, - UnitLowerTriangular{Float64, Matrix{Float64}}, - UnitUpperTriangular{Float64, Matrix{Float64}}, + LowerTriangular{Float64,Matrix{Float64}}, + UpperTriangular{Float64,Matrix{Float64}}, + UnitLowerTriangular{Float64,Matrix{Float64}}, + UnitUpperTriangular{Float64,Matrix{Float64}}, ]) do T return ( T(randn(2, 2)), @@ -952,9 +983,8 @@ function tangent_test_cases() ) end, [ - (p, NoTangent(), NoTangent(), NoTangent()) for p in - [Array, Float64, Union{Float64, Float32}, Union, UnionAll, - typeof(<:)] + (p, NoTangent(), NoTangent(), NoTangent()) for + p in [Array, Float64, Union{Float64,Float32}, Union, UnionAll, typeof(<:)] ], ) rel_test_cases = Any[ diff --git a/src/test_resources.jl b/src/test_resources.jl index 38a223579..914232d86 100644 --- a/src/test_resources.jl +++ b/src/test_resources.jl @@ -10,12 +10,18 @@ module TestResources using ..Mooncake using ..Mooncake: - CoDual, Tangent, MutableTangent, NoTangent, PossiblyUninitTangent, ircode, - @is_primitive, MinimalCtx, val + CoDual, + Tangent, + MutableTangent, + NoTangent, + PossiblyUninitTangent, + ircode, + @is_primitive, + MinimalCtx, + val using DiffTests, LinearAlgebra, Random, Setfield - # # Types used for testing purposes # @@ -177,7 +183,7 @@ function bar(x, y) end function unused_expression(x, n) - y = getfield((Float64, ), n) + y = getfield((Float64,), n) return x end @@ -207,7 +213,7 @@ new_tester_2(x) = StableFoo(x, :symbol) @eval function new_tester_3(x::Ref{Any}) y = x[] - $(Expr(:new, :y, 5.0)) + return $(Expr(:new, :y, 5.0)) end @eval splatnew_tester(x::Ref{Tuple}) = $(Expr(:splatnew, StableFoo, :(x[]))) @@ -291,7 +297,7 @@ end simple_foreigncall_tester(s::String) = ccall(:jl_string_ptr, Ptr{UInt8}, (Any,), s) function simple_foreigncall_tester_2(a::TypeVar, b::Type) - ccall(:jl_type_unionall, Any, (Any, Any), a, b) + return ccall(:jl_type_unionall, Any, (Any, Any), a, b) end function no_primitive_inlining_tester(x) @@ -302,13 +308,13 @@ function no_primitive_inlining_tester(x) return X end -@noinline varargs_tester(x::Vararg{Any, N}) where {N} = x +@noinline varargs_tester(x::Vararg{Any,N}) where {N} = x varargs_tester_2(x) = varargs_tester(x) varargs_tester_2(x, y) = varargs_tester(x, y) varargs_tester_2(x, y, z) = varargs_tester(x, y, z) -@noinline varargs_tester_3(x, y::Vararg{Any, N}) where {N} = sin(x), y +@noinline varargs_tester_3(x, y::Vararg{Any,N}) where {N} = sin(x), y varargs_tester_4(x) = varargs_tester_3(x...) varargs_tester_4(x, y) = varargs_tester_3(x...) @@ -478,11 +484,11 @@ function test_union_of_arrays(x::Vector{Float64}, b::Bool) return 2z end -function test_union_of_types(x::Ref{Union{Type{Float64}, Type{Int}}}) +function test_union_of_types(x::Ref{Union{Type{Float64},Type{Int}}}) return x[] end -function test_small_union(x::Ref{Union{Float64, Vector{Float64}}}) +function test_small_union(x::Ref{Union{Float64,Vector{Float64}}}) v = x[] return v isa Float64 ? v : v[1] end @@ -494,7 +500,7 @@ end @noinline edge_case_tester(x::Float32) = 6.0 @noinline edge_case_tester(x::Int) = 10 @noinline edge_case_tester(x::String) = "hi" -@is_primitive MinimalCtx Tuple{typeof(edge_case_tester), Float64} +@is_primitive MinimalCtx Tuple{typeof(edge_case_tester),Float64} function Mooncake.rrule!!(::CoDual{typeof(edge_case_tester)}, x::CoDual{Float64}) edge_case_tester_pb!!(dy) = Mooncake.NoRData(), 5 * dy return Mooncake.zero_fcodual(5 * primal(x)), edge_case_tester_pb!! @@ -552,12 +558,12 @@ test_for_invoke(x) = 5x inlinable_invoke_call(x::Float64) = invoke(test_for_invoke, Tuple{Float64}, x) -vararg_test_for_invoke(n::Tuple{Int, Int}, x...) = sum(x) + n[1] +vararg_test_for_invoke(n::Tuple{Int,Int}, x...) = sum(x) + n[1] function inlinable_vararg_invoke_call( rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N} ) where {N} - return invoke(vararg_test_for_invoke, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...) + return invoke(vararg_test_for_invoke, Tuple{typeof(rows),Vararg{N}}, rows, n1, ns...) end # build_rrule should error for this function, because it references a non-const global ref. @@ -595,8 +601,12 @@ function generate_test_functions() (false, :allocs, nothing, unused_expression, 5.0, 1), (false, :none, nothing, type_unstable_argument_eval, sin, 5.0), ( - false, :none, nothing, - abstractly_typed_unused_container, StructFoo(5.0, [4.0]), 5.0, + false, + :none, + nothing, + abstractly_typed_unused_container, + StructFoo(5.0, [4.0]), + 5.0, ), (false, :none, (lb=1, ub=1_000), pi_node_tester, Ref{Any}(5.0)), (false, :none, (lb=1, ub=1_000), pi_node_tester, Ref{Any}(5)), @@ -641,8 +651,12 @@ function generate_test_functions() (false, :allocs, nothing, avoid_throwing_path_tester, 5.0), (true, :allocs, nothing, simple_foreigncall_tester, "hello"), ( - false, :none, nothing, - simple_foreigncall_tester_2, TypeVar(:T, Union{}, Any), Vector{T} where {T} + false, + :none, + nothing, + simple_foreigncall_tester_2, + TypeVar(:T, Union{}, Any), + Vector{T} where {T}, ), (false, :none, nothing, no_primitive_inlining_tester, 5.0), (false, :allocs, nothing, varargs_tester, 5.0), @@ -668,14 +682,19 @@ function generate_test_functions() (false, :none, (lb=1, ub=1_000), datatype_slot_tester, 2), (false, :none, (lb=1, ub=100_000_000), test_union_of_arrays, randn(5), true), ( - false, :none, nothing, - test_union_of_types, Ref{Union{Type{Float64}, Type{Int}}}(Float64), + false, + :none, + nothing, + test_union_of_types, + Ref{Union{Type{Float64},Type{Int}}}(Float64), ), (false, :allocs, nothing, test_self_reference, 1.1, 1.5), (false, :allocs, nothing, test_self_reference, 1.5, 1.1), (false, :none, nothing, test_recursive_sum, randn(2)), ( - false, :none, nothing, + false, + :none, + nothing, LinearAlgebra._modify!, LinearAlgebra.MulAddMul(5.0, 4.0), 5.0, @@ -685,12 +704,20 @@ function generate_test_functions() (false, :allocs, nothing, getfield_tester, (5.0, 5)), (false, :allocs, nothing, getfield_tester_2, (5.0, 5)), ( - false, :allocs, nothing, - setfield_tester_left!, FullyInitMutableStruct(5.0, randn(3)), 4.0, + false, + :allocs, + nothing, + setfield_tester_left!, + FullyInitMutableStruct(5.0, randn(3)), + 4.0, ), ( - false, :none, nothing, - setfield_tester_right!, FullyInitMutableStruct(5.0, randn(3)), randn(5), + false, + :none, + nothing, + setfield_tester_right!, + FullyInitMutableStruct(5.0, randn(3)), + randn(5), ), (false, :none, nothing, mul!, randn(3, 5)', randn(5, 5), randn(5, 3), 4.0, 3.0), (false, :none, nothing, Random.SHA.digest!, Random.SHA.SHA2_256_CTX()), @@ -719,29 +746,47 @@ function generate_test_functions() (false, :none, nothing, test_struct_partial_init, 3.5), (false, :none, nothing, test_mutable_partial_init, 3.3), ( - false, :allocs, nothing, - test_naive_mat_mul!, randn(100, 50), randn(100, 30), randn(30, 50), + false, + :allocs, + nothing, + test_naive_mat_mul!, + randn(100, 50), + randn(100, 30), + randn(30, 50), ), ( - false, :allocs, nothing, - (A, C) -> test_naive_mat_mul!(C, A, A), randn(25, 25), randn(25, 25), + false, + :allocs, + nothing, + (A, C) -> test_naive_mat_mul!(C, A, A), + randn(25, 25), + randn(25, 25), ), (false, :allocs, nothing, sum, randn(32)), (false, :none, nothing, test_diagonal_to_matrix, Diagonal(randn(30))), ( - false, :allocs, nothing, - ldiv!, randn(20, 20), Diagonal(rand(20) .+ 1), randn(20, 20), - ), - ( - false, :allocs, nothing, - LinearAlgebra._kron!, randn(25, 25), randn(5, 5), randn(5, 5), + false, + :allocs, + nothing, + ldiv!, + randn(20, 20), + Diagonal(rand(20) .+ 1), + randn(20, 20), ), ( - false, :allocs, nothing, - kron!, randn(25, 25), Diagonal(randn(5)), randn(5, 5), + false, + :allocs, + nothing, + LinearAlgebra._kron!, + randn(25, 25), + randn(5, 5), + randn(5, 5), ), + (false, :allocs, nothing, kron!, randn(25, 25), Diagonal(randn(5)), randn(5, 5)), ( - false, :none, nothing, + false, + :none, + nothing, test_mlp, randn(sr(1), 50, 20), randn(sr(2), 70, 50), @@ -756,8 +801,14 @@ function generate_test_functions() (false, :none, nothing, _broadcast_sin_cos_exp, randn(10, 10)), (false, :none, nothing, _map_sin_cos_exp, randn(10, 10)), (false, :none, nothing, ArgumentError, "hi"), - (false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}(5.0)), - (false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}([1.0])), + (false, :none, nothing, test_small_union, Ref{Union{Float64,Vector{Float64}}}(5.0)), + ( + false, + :none, + nothing, + test_small_union, + Ref{Union{Float64,Vector{Float64}}}([1.0]), + ), (false, :allocs, nothing, inlinable_invoke_call, 5.0), (false, :none, nothing, inlinable_vararg_invoke_call, (2, 2), 5.0, 4.0, 3.0, 2.0), (false, :none, nothing, hvcat, (2, 2), 3.0, 2.0, 0.0, 1.0), diff --git a/src/test_utils.jl b/src/test_utils.jl index 0568da6e1..6aa62d95d 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -13,7 +13,7 @@ using Core: svec using ExprTools: combinedef using ..Mooncake: NoTangent, tangent_type, _typeof -const PRIMALS = Tuple{Bool, Any, Tuple}[] +const PRIMALS = Tuple{Bool,Any,Tuple}[] # Generate all of the composite types against which we might wish to test. function generate_primals() @@ -32,7 +32,6 @@ function generate_primals() ns_always_def = 0:n_fields for fields in field_combinations, n_always_def in ns_always_def - mutable_str = is_mutable ? "Mutable" : "" field_types = map(x -> x.type, fields) type_string = join(map(string, field_types), "_") @@ -52,12 +51,14 @@ function generate_primals() # Specify inner constructors. map(n_always_def:n_fields) do n - return combinedef(Dict( - :head => :function, - :name => name, - :args => field_names[1:n], - :body => Expr(:call, :new, field_names[1:n]...), - )) + return combinedef( + Dict( + :head => :function, + :name => name, + :args => field_names[1:n], + :body => Expr(:call, :new, field_names[1:n]...), + ), + ) end..., ), ) @@ -65,7 +66,7 @@ function generate_primals() t = @eval $name for n in n_always_def:n_fields - interface_only = any(x -> isbitstype(x.type), fields[n+1:end]) + interface_only = any(x -> isbitstype(x.type), fields[(n + 1):end]) fields_copies = map(x -> deepcopy(x.primal), fields[1:n]) push!(PRIMALS, (interface_only, t, fields_copies)) end @@ -88,13 +89,41 @@ module TestUtils using Random, Mooncake, Test, InteractiveUtils using Mooncake: - CoDual, NoTangent, rrule!!, is_init, zero_codual, DefaultCtx, @is_primitive, val, - is_always_fully_initialised, get_tangent_field, set_tangent_field!, MutableTangent, - Tangent, _typeof, rdata, NoFData, to_fwds, uninit_fdata, zero_rdata, - zero_rdata_from_type, CannotProduceZeroRDataFromType, lazy_zero_rdata, instantiate, - can_produce_zero_rdata_from_type, increment_rdata!!, fcodual_type, - verify_fdata_type, verify_rdata_type, verify_fdata_value, verify_rdata_value, - InvalidFDataException, InvalidRDataException, uninit_codual, lgetfield, lsetfield! + CoDual, + NoTangent, + rrule!!, + is_init, + zero_codual, + DefaultCtx, + @is_primitive, + val, + is_always_fully_initialised, + get_tangent_field, + set_tangent_field!, + MutableTangent, + Tangent, + _typeof, + rdata, + NoFData, + to_fwds, + uninit_fdata, + zero_rdata, + zero_rdata_from_type, + CannotProduceZeroRDataFromType, + lazy_zero_rdata, + instantiate, + can_produce_zero_rdata_from_type, + increment_rdata!!, + fcodual_type, + verify_fdata_type, + verify_rdata_type, + verify_fdata_value, + verify_rdata_value, + InvalidFDataException, + InvalidRDataException, + uninit_codual, + lgetfield, + lsetfield! struct Shim end @@ -118,20 +147,42 @@ that takes an additional `visited` dictionary to track visited objects and avoid recursion in cases of circular references. """ function has_equal_data(x, y; equal_undefs=true) - return has_equal_data_internal(x, y, equal_undefs, Dict{Tuple{UInt, UInt}, Bool}()) + return has_equal_data_internal(x, y, equal_undefs, Dict{Tuple{UInt,UInt},Bool}()) end -has_equal_data_internal(x::Type, y::Type, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) = x == y -has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) where {T<:String} = x == y -has_equal_data_internal(x::Core.TypeName, y::Core.TypeName, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) = x == y -function has_equal_data_internal(x::Float64, y::Float64, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) +function has_equal_data_internal( + x::Type, y::Type, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) + return x == y +end +function has_equal_data_internal( + x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) where {T<:String} + return x == y +end +function has_equal_data_internal( + x::Core.TypeName, y::Core.TypeName, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) + return x == y +end +function has_equal_data_internal( + x::Float64, y::Float64, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) return (isapprox(x, y) && !isnan(x)) || (isnan(x) && isnan(y)) end -has_equal_data_internal(x::Module, y::Module, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) = x == y -function has_equal_data_internal(x::GlobalRef, y::GlobalRef; equal_undefs=true, d::Dict{Tuple{UInt, UInt}, Bool}) +function has_equal_data_internal( + x::Module, y::Module, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) + return x == y +end +function has_equal_data_internal( + x::GlobalRef, y::GlobalRef; equal_undefs=true, d::Dict{Tuple{UInt,UInt},Bool} +) return x.mod == y.mod && x.name == y.name end -function has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) where {T<:Array} +function has_equal_data_internal( + x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) where {T<:Array} size(x) != size(y) && return false # The dictionary is used to detect circular references in the data structures. @@ -163,10 +214,14 @@ function has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{U end return all(equality) end -function has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) where {T<:Core.SimpleVector} +function has_equal_data_internal( + x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) where {T<:Core.SimpleVector} return all(map((a, b) -> has_equal_data_internal(a, b, equal_undefs, d), x, y)) end -function has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) where {T} +function has_equal_data_internal( + x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) where {T} isprimitivetype(T) && return isequal(x, y) id_pair = (objectid(x), objectid(y)) @@ -177,9 +232,17 @@ function has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{U d[id_pair] = true if ismutabletype(x) - return all(map(fieldnames(T)) do n - isdefined(x, n) ? has_equal_data_internal(getfield(x, n), getfield(y, n), equal_undefs, d) : true - end) + return all( + map(fieldnames(T)) do n + if isdefined(x, n) + has_equal_data_internal( + getfield(x, n), getfield(y, n), equal_undefs, d + ) + else + true + end + end, + ) else for n in fieldnames(T) if !isdefined(x, n) && !isdefined(y, n) @@ -197,11 +260,15 @@ function has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{U return true end end -has_equal_data_internal(x::T, y::P, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool}) where {T,P} = false +function has_equal_data_internal( + x::T, y::P, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool} +) where {T,P} + return false +end has_equal_data_up_to_undefs(x::T, y::T) where {T} = has_equal_data(x, y; equal_undefs=false) -const AddressMap = Dict{Ptr{Nothing}, Ptr{Nothing}} +const AddressMap = Dict{Ptr{Nothing},Ptr{Nothing}} """ populate_address_map(primal, tangent) @@ -217,7 +284,7 @@ Fills `m` with pairs mapping from memory addresses in `primal` to corresponding addresses in `tangent`. If the same memory address appears multiple times in `primal`, throws an `AssertionError` if the same address is not mapped to in `tangent` each time. """ -function populate_address_map!(m::AddressMap, primal::P, tangent::T) where {P, T} +function populate_address_map!(m::AddressMap, primal::P, tangent::T) where {P,T} isprimitivetype(P) && return m T === NoTangent && return m T === NoFData && return m @@ -241,10 +308,10 @@ function populate_address_map!(m::AddressMap, primal::P, tangent::T) where {P, T return m end -__get_data_field(t::Union{Tangent, MutableTangent}, n) = getfield(t.fields, n) -__get_data_field(t::Union{Mooncake.FData, Mooncake.RData}, n) = getfield(t.data, n) +__get_data_field(t::Union{Tangent,MutableTangent}, n) = getfield(t.fields, n) +__get_data_field(t::Union{Mooncake.FData,Mooncake.RData}, n) = getfield(t.data, n) -function populate_address_map!(m::AddressMap, p::P, t) where {P<:Union{Tuple, NamedTuple}} +function populate_address_map!(m::AddressMap, p::P, t) where {P<:Union{Tuple,NamedTuple}} t isa NoFData && return m t isa NoTangent && return m foreach(n -> populate_address_map!(m, getfield(p, n), getfield(t, n)), fieldnames(P)) @@ -269,7 +336,7 @@ function populate_address_map!(m::AddressMap, p::Core.SimpleVector, t::Vector{An return m end -populate_address_map!(m::AddressMap, p::Union{Core.TypeName, Type, Symbol, String}, t) = m +populate_address_map!(m::AddressMap, p::Union{Core.TypeName,Type,Symbol,String}, t) = m """ address_maps_are_consistent(x::AddressMap, y::AddressMap) @@ -309,7 +376,9 @@ function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb: x̄_zero = map(zero_tangent, x) x̄_fwds = map(Mooncake.fdata, x̄_zero) x_x̄_rule = map((x, x̄_f) -> fcodual_type(_typeof(x))(_deepcopy(x), x̄_f), x, x̄_fwds) - inputs_address_map = populate_address_map(map(primal, x_x̄_rule), map(tangent, x_x̄_rule)) + inputs_address_map = populate_address_map( + map(primal, x_x̄_rule), map(tangent, x_x̄_rule) + ) y_ȳ_rule, pb!! = rule(x_x̄_rule...) # Verify that inputs / outputs are the same under `f` and its rrule. @@ -339,7 +408,8 @@ function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb: @test all(map(has_equal_data_up_to_undefs, x, map(primal, x_x̄_rule))) # pullbacks increment, so have to compare to the incremented quantity. - @test _dot(ȳ_delta, ẏ) + _dot(x̄_delta, ẋ_post) ≈ _dot(x̄, ẋ) rtol=1e-3 atol=1e-3 + @test _dot(ȳ_delta, ẏ) + _dot(x̄_delta, ẋ_post) ≈ _dot(x̄, ẋ) rtol = 1e-3 atol = + 1e-3 end get_address(x) = ismutable(x) ? pointer_from_objref(x) : nothing @@ -347,7 +417,7 @@ get_address(x) = ismutable(x) ? pointer_from_objref(x) : nothing _deepcopy(x) = deepcopy(x) _deepcopy(x::Module) = x -rrule_output_type(::Type{Ty}) where {Ty} = Tuple{Mooncake.fcodual_type(Ty), Any} +rrule_output_type(::Type{Ty}) where {Ty} = Tuple{Mooncake.fcodual_type(Ty),Any} function test_rrule_interface(f_f̄, x_x̄...; rule) @nospecialize f_f̄ x_x̄ @@ -386,9 +456,11 @@ function test_rrule_interface(f_f̄, x_x̄...; rule) catch e display(e) println() - throw(ArgumentError( - "rule for $(_typeof(f_fwds)) with argument types $(_typeof(x_fwds)) does not run." - )) + throw( + ArgumentError( + "rule for $(_typeof(f_fwds)) with argument types $(_typeof(x_fwds)) does not run.", + ), + ) end @test rrule_ret isa rrule_output_type(_typeof(y)) y_ȳ, pb!! = rrule_ret @@ -404,9 +476,11 @@ function test_rrule_interface(f_f̄, x_x̄...; rule) catch e display(e) println() - throw(ArgumentError( - "pullback for $(_typeof(f_f̄)) with argument types $(_typeof(x_x̄)) does not run." - )) + throw( + ArgumentError( + "pullback for $(_typeof(f_f̄)) with argument types $(_typeof(x_x̄)) does not run.", + ), + ) end # Check that the pullback returns the correct number of things. @@ -422,21 +496,23 @@ function test_rrule_interface(f_f̄, x_x̄...; rule) @test all(map((a, b) -> _typeof(a) == _typeof(rdata(b)), x̄_new, x̄)) end -function __forwards_and_backwards(rule, x_x̄::Vararg{Any, N}) where {N} +function __forwards_and_backwards(rule, x_x̄::Vararg{Any,N}) where {N} out, pb!! = rule(x_x̄...) return pb!!(Mooncake.zero_rdata(primal(out))) end function test_rrule_performance( - performance_checks_flag::Symbol, rule::R, f_f̄::F, x_x̄::Vararg{Any, N} -) where {R, F, N} + performance_checks_flag::Symbol, rule::R, f_f̄::F, x_x̄::Vararg{Any,N} +) where {R,F,N} # Verify that a valid performance flag has been passed. valid_flags = (:none, :stability, :allocs, :stability_and_allocs) if !in(performance_checks_flag, valid_flags) - throw(ArgumentError( - "performance_checks=$performance_checks_flag. Must be one of $valid_flags" - )) + throw( + ArgumentError( + "performance_checks=$performance_checks_flag. Must be one of $valid_flags" + ), + ) end performance_checks_flag == :none && return nothing @@ -451,7 +527,7 @@ function test_rrule_performance( # Test reverse-pass stability. y_ȳ, pb!! = rule(to_fwds(f_f̄), map(to_fwds, _deepcopy(x_x̄))...) rvs_data = Mooncake.rdata(zero_tangent(primal(y_ȳ), tangent(y_ȳ))) - test_opt(Shim(), pb!!, (_typeof(rvs_data), )) + test_opt(Shim(), pb!!, (_typeof(rvs_data),)) end if performance_checks_flag in (:allocs, :stability_and_allocs) @@ -472,67 +548,68 @@ end __get_primals(xs) = map(x -> x isa CoDual ? primal(x) : x, xs) -@doc""" - test_rule( - rng, x...; - interface_only=false, - is_primitive::Bool=true, - perf_flag::Symbol=:none, - interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), - debug_mode::Bool=false, - unsafe_perturb::Bool=false, - ) - -Run standardised tests on the `rule` for `x`. -The first element of `x` should be the primal function to test, and each other element a -positional argument. -In most cases, elements of `x` can just be the primal values, and `randn_tangent` can be -relied upon to generate an appropriate tangent to test. Some notable exceptions exist -though, in partcular `Ptr`s. In this case, the argument for which `randn_tangent` cannot be -readily defined should be a `CoDual` containing the primal, and a _manually_ constructed -tangent field. - -This function uses [`Mooncake.build_rrule`](@ref) to construct a rule. This will use an -`rrule!!` if one exists, and derive a rule otherwise. - -# Arguments -- `rng::AbstractRNG`: a random number generator -- `x...`: the function (first element) and its arguments (the remainder) - -# Keyword Arguments -- `interface_only::Bool=false`: test only that the interface is satisfied, without testing - correctness. This should generally be set to `false` (the default value), and only - enabled if the testing infrastructure is unable to test correctness for some reason - e.g. the returned value of the function is a `Ptr`, and appropriate tangents cannot, - therefore, be generated for it automatically. -- `is_primitive::Bool=true`: check whether the thing that you are testing has a hand-written - `rrule!!`. This option is helpful if you are testing a new `rrule!!`, as it enables you - to verify that your method of `is_primitive` has returned the correct value, and that - you are actually testing a method of the `rrule!!` function -- a common mistake when - authoring a new `rrule!!` is to implement `is_primitive` incorrectly and to accidentally - wind up testing a rule which Mooncake has derived, as opposed to the one that you have - written. If you are testing something for which you have not - hand-written an `rrule!!`, or which you do not care whether it has a hand-written - `rrule!!` or not, you should set it to `false`. -- `perf_flag::Symbol=:none`: the value of this symbol determines what kind of performance - tests should be performed. By default, none are performed. If you believe that a rule - should be allocation-free (iff the primal is allocation free), set this to `:allocs`. If - you hand-write an `rrule!!` and believe that your test case should be type stable, set - this to `:stability` (at present we cannot verify whether a derived rule is type stable - for technical reasons). If you believe that a hand-written rule should be _both_ - allocation-free and type-stable, set this to `:stability_and_allocs`. -- `interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter()`: the abstract - interpreter to be used when testing this rule. The default should generally be used. -- `debug_mode::Bool=false`: whether or not the rule should be tested in debug mode. - Typically this should be left at its default `false` value, but if you are finding that - the tests are failing for a given rule, you may wish to temporarily set it to `true` in - order to get access to additional information and automated testing. -- `unsafe_perturb::Bool=false`: value passed as the third argument to `_add_to_primal`. - Should usually be left `false` -- consult the docstring for `_add_to_primal` for more - info on when you might wish to set it to `true`. -""" +@doc """ + test_rule( + rng, x...; + interface_only=false, + is_primitive::Bool=true, + perf_flag::Symbol=:none, + interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), + debug_mode::Bool=false, + unsafe_perturb::Bool=false, + ) + + Run standardised tests on the `rule` for `x`. + The first element of `x` should be the primal function to test, and each other element a + positional argument. + In most cases, elements of `x` can just be the primal values, and `randn_tangent` can be + relied upon to generate an appropriate tangent to test. Some notable exceptions exist + though, in partcular `Ptr`s. In this case, the argument for which `randn_tangent` cannot be + readily defined should be a `CoDual` containing the primal, and a _manually_ constructed + tangent field. + + This function uses [`Mooncake.build_rrule`](@ref) to construct a rule. This will use an + `rrule!!` if one exists, and derive a rule otherwise. + + # Arguments + - `rng::AbstractRNG`: a random number generator + - `x...`: the function (first element) and its arguments (the remainder) + + # Keyword Arguments + - `interface_only::Bool=false`: test only that the interface is satisfied, without testing + correctness. This should generally be set to `false` (the default value), and only + enabled if the testing infrastructure is unable to test correctness for some reason + e.g. the returned value of the function is a `Ptr`, and appropriate tangents cannot, + therefore, be generated for it automatically. + - `is_primitive::Bool=true`: check whether the thing that you are testing has a hand-written + `rrule!!`. This option is helpful if you are testing a new `rrule!!`, as it enables you + to verify that your method of `is_primitive` has returned the correct value, and that + you are actually testing a method of the `rrule!!` function -- a common mistake when + authoring a new `rrule!!` is to implement `is_primitive` incorrectly and to accidentally + wind up testing a rule which Mooncake has derived, as opposed to the one that you have + written. If you are testing something for which you have not + hand-written an `rrule!!`, or which you do not care whether it has a hand-written + `rrule!!` or not, you should set it to `false`. + - `perf_flag::Symbol=:none`: the value of this symbol determines what kind of performance + tests should be performed. By default, none are performed. If you believe that a rule + should be allocation-free (iff the primal is allocation free), set this to `:allocs`. If + you hand-write an `rrule!!` and believe that your test case should be type stable, set + this to `:stability` (at present we cannot verify whether a derived rule is type stable + for technical reasons). If you believe that a hand-written rule should be _both_ + allocation-free and type-stable, set this to `:stability_and_allocs`. + - `interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter()`: the abstract + interpreter to be used when testing this rule. The default should generally be used. + - `debug_mode::Bool=false`: whether or not the rule should be tested in debug mode. + Typically this should be left at its default `false` value, but if you are finding that + the tests are failing for a given rule, you may wish to temporarily set it to `true` in + order to get access to additional information and automated testing. + - `unsafe_perturb::Bool=false`: value passed as the third argument to `_add_to_primal`. + Should usually be left `false` -- consult the docstring for `_add_to_primal` for more + info on when you might wish to set it to `true`. + """ function test_rule( - rng::AbstractRNG, x...; + rng::AbstractRNG, + x...; interface_only::Bool=false, is_primitive::Bool=true, perf_flag::Symbol=:none, @@ -550,7 +627,13 @@ function test_rule( is_primitive && @test rule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!) # Generate random tangents for anything that is not already a CoDual. - x_x̄ = map(x -> x isa CoDual ? x : interface_only ? uninit_codual(x) : zero_codual(x), x) + x_x̄ = map(x -> if x isa CoDual + x + elseif interface_only + uninit_codual(x) + else + zero_codual(x) + end, x) # Test that the interface is basically satisfied (checks types / memory addresses). test_rrule_interface(x_x̄...; rule) @@ -562,28 +645,30 @@ function test_rule( test_rrule_performance(perf_flag, rule, x_x̄...) # Test the interface again, in order to verify that caching is working correctly. - test_rrule_interface(x_x̄..., rule=Mooncake.build_rrule(interp, sig; debug_mode)) + return test_rrule_interface(x_x̄...; rule=Mooncake.build_rrule(interp, sig; debug_mode)) end - function run_hand_written_rrule!!_test_cases(rng_ctor, v::Val) test_cases, memory = Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, v) - GC.@preserve memory @testset "$f, $(_typeof(x))" for (interface_only, perf_flag, _, f, x...) in test_cases + GC.@preserve memory @testset "$f, $(_typeof(x))" for ( + interface_only, perf_flag, _, f, x... + ) in test_cases test_rule(rng_ctor(123), f, x...; interface_only, perf_flag) end end function run_derived_rrule!!_test_cases(rng_ctor, v::Val) test_cases, memory = Mooncake.generate_derived_rrule!!_test_cases(rng_ctor, v) - GC.@preserve memory @testset "$f, $(typeof(x))" for - (interface_only, perf_flag, _, f, x...) in test_cases + GC.@preserve memory @testset "$f, $(typeof(x))" for ( + interface_only, perf_flag, _, f, x... + ) in test_cases test_rule(rng_ctor(123), f, x...; interface_only, perf_flag, is_primitive=false) end end function run_rrule!!_test_cases(rng_ctor, v::Val) run_hand_written_rrule!!_test_cases(rng_ctor, v) - run_derived_rrule!!_test_cases(rng_ctor, v) + return run_derived_rrule!!_test_cases(rng_ctor, v) end # @@ -594,8 +679,8 @@ generate_args(::typeof(===), x) = [(x, 0.0), (1.0, x)] function generate_args(::typeof(Core.ifelse), x) return [(true, x, 0.0), (false, x, 0.0), (true, 0.0, x), (false, 0.0, x)] end -generate_args(::typeof(Core.sizeof), x) = [(x, )] -generate_args(::typeof(Core.svec), x) = [(x, ), (x, x)] +generate_args(::typeof(Core.sizeof), x) = [(x,)] +generate_args(::typeof(Core.svec), x) = [(x,), (x, x)] function generate_args(::typeof(getfield), x) syms = filter(f -> isdefined(x, f), fieldnames(_typeof(x))) fs = vcat(syms..., eachindex(syms)...) @@ -621,7 +706,7 @@ else # Consequently, it does not make sense to call `_new_` on them -- while this _can_ be # made to work, it typically yields segfaults in very short order, and I _believe_ it # should never occur in practice. - _new_excluded(::Type{<:Union{Memory, MemoryRef}}) = true + _new_excluded(::Type{<:Union{Memory,MemoryRef}}) = true end function generate_args(::typeof(Mooncake._new_), x) @@ -643,9 +728,9 @@ function generate_args(::typeof(lsetfield!), x) end return map(n -> (x, Val(n), getfield(x, n)), vcat(names..., eachindex(names)...)) end -generate_args(::typeof(tuple), x) = [(x, ), (x, x), (x, x, x)] +generate_args(::typeof(tuple), x) = [(x,), (x, x), (x, x, x)] generate_args(::typeof(typeassert), x) = [(x, _typeof(x))] -generate_args(::typeof(typeof), x) = [(x, )] +generate_args(::typeof(typeof), x) = [(x,)] function functions_for_all_types() return [===, Core.ifelse, Core.sizeof, isa, tuple, typeassert, typeof] @@ -690,7 +775,9 @@ function test_rule_and_type_interactions(rng::AbstractRNG, p::P) where {P} arg_sets = generate_args(f, p) @testset for args in arg_sets test_rule( - rng, f, args...; + rng, + f, + args...; interface_only=true, is_primitive=true, perf_flag=:none, @@ -708,7 +795,7 @@ infers / optimises away. """ function test_tangent_type(primal_type::Type, expected_tangent_type::Type) @test tangent_type(primal_type) == expected_tangent_type - test_opt(Shim(), tangent_type, Tuple{_typeof(primal_type)}) + return test_opt(Shim(), tangent_type, Tuple{_typeof(primal_type)}) end """ @@ -832,7 +919,7 @@ function test_set_tangent_field!_correctness(t1::T, t2::T) where {T<:MutableTang end end -function check_allocs(::Any, f::F, x::Tuple{Vararg{Any, N}}) where {F, N} +function check_allocs(::Any, f::F, x::Tuple{Vararg{Any,N}}) where {F,N} throw(error("Load AllocCheck.jl to use this functionality.")) end @@ -863,8 +950,8 @@ function test_tangent_performance(rng::AbstractRNG, p::P) where {P} # Check there are no allocations when there ought not to be. if !__tangent_generation_should_allocate(P) - test_opt(Shim(), Tuple{typeof(zero_tangent), P}) - test_opt(Shim(), Tuple{typeof(randn_tangent), Xoshiro, P}) + test_opt(Shim(), Tuple{typeof(zero_tangent),P}) + test_opt(Shim(), Tuple{typeof(randn_tangent),Xoshiro,P}) end # `increment!!` should always infer. @@ -879,30 +966,27 @@ function test_tangent_performance(rng::AbstractRNG, p::P) where {P} # set_tangent_field! should never allocate. t isa MutableTangent && test_set_tangent_field!_performance(t, z) - t isa Union{MutableTangent, Tangent} && test_get_tangent_field_performance(t) + return t isa Union{MutableTangent,Tangent} && test_get_tangent_field_performance(t) end function test_allocations(t::T, z::T) where {T} check_allocs(Shim(), increment!!, (t, t)) check_allocs(Shim(), increment!!, (t, z)) check_allocs(Shim(), increment!!, (z, t)) - check_allocs(Shim(), increment!!, (z, z)) + return check_allocs(Shim(), increment!!, (z, z)) end _set_tangent_field!(x, ::Val{i}, v) where {i} = set_tangent_field!(x, i, v) _get_tangent_field(x, ::Val{i}) where {i} = get_tangent_field(x, i) -function test_set_tangent_field!_performance(t1::T, t2::T) where {V, T<:MutableTangent{V}} +function test_set_tangent_field!_performance(t1::T, t2::T) where {V,T<:MutableTangent{V}} for n in 1:fieldcount(V) !is_init(t2.fields[n]) && continue v = get_tangent_field(t2, n) # Int mode. _set_tangent_field!(t1, Val(n), v) - report_opt( - Shim(), - Tuple{typeof(_set_tangent_field!), typeof(t1), Val{n}, typeof(v)}, - ) + report_opt(Shim(), Tuple{typeof(_set_tangent_field!),typeof(t1),Val{n},typeof(v)}) if all(n -> !(fieldtype(V, n) <: Mooncake.PossiblyUninitTangent), 1:fieldcount(V)) i = Val(n) @@ -914,8 +998,7 @@ function test_set_tangent_field!_performance(t1::T, t2::T) where {V, T<:MutableT s = Val(fieldname(V, n)) @inferred _set_tangent_field!(t1, s, v) report_opt( - Shim(), - Tuple{typeof(_set_tangent_field!), typeof(t1), typeof(s), typeof(v)}, + Shim(), Tuple{typeof(_set_tangent_field!),typeof(t1),typeof(s),typeof(v)} ) if all(n -> !(fieldtype(V, n) <: Mooncake.PossiblyUninitTangent), 1:fieldcount(V)) @@ -925,7 +1008,7 @@ function test_set_tangent_field!_performance(t1::T, t2::T) where {V, T<:MutableT end end -function test_get_tangent_field_performance(t::Union{MutableTangent, Tangent}) +function test_get_tangent_field_performance(t::Union{MutableTangent,Tangent}) V = Mooncake._typeof(t.fields) for n in 1:fieldcount(V) !is_init(t.fields[n]) && continue @@ -934,20 +1017,20 @@ function test_get_tangent_field_performance(t::Union{MutableTangent, Tangent}) # Int mode. i = Val(n) - report_opt(Shim(), Tuple{typeof(_get_tangent_field), typeof(t), typeof(i)}) + report_opt(Shim(), Tuple{typeof(_get_tangent_field),typeof(t),typeof(i)}) @inferred _get_tangent_field(t, i) @test count_allocs(_get_tangent_field, t, i) == 0 # Symbol mode. s = Val(fieldname(V, n)) - report_opt(Shim(), Tuple{typeof(_get_tangent_field), typeof(t), typeof(s)}) + report_opt(Shim(), Tuple{typeof(_get_tangent_field),typeof(t),typeof(s)}) @inferred _get_tangent_field(t, s) @test count_allocs(_get_tangent_field, t, s) == 0 end end # Function barrier to ensure inference in value types. -function count_allocs(f::F, x::Vararg{Any, N}) where {F, N} +function count_allocs(f::F, x::Vararg{Any,N}) where {F,N} @allocations f(x...) end @@ -974,21 +1057,21 @@ function __is_completely_stable_type(::Type{P}) where {P} return all(__is_completely_stable_type, fieldtypes(P)) end -@doc""" - test_tangent(rng::AbstractRNG, p::P, x::T, y::T, z_target::T) where {P, T} +@doc """ + test_tangent(rng::AbstractRNG, p::P, x::T, y::T, z_target::T) where {P, T} -Verify that primal `p` with tangents `z_target`, `x`, and `y`, satisfies the tangent -interface. If these tests pass, then it should be possible to write rules for primals -of type `P`, and to test them using [`test_rule`](@ref). + Verify that primal `p` with tangents `z_target`, `x`, and `y`, satisfies the tangent + interface. If these tests pass, then it should be possible to write rules for primals + of type `P`, and to test them using [`test_rule`](@ref). -It should be the case that `z_target` == `increment!!(x, y)`. + It should be the case that `z_target` == `increment!!(x, y)`. -As always, there are limits to the errors that these tests can identify -- they form -necessary but not sufficient conditions for the correctness of your code. -""" + As always, there are limits to the errors that these tests can identify -- they form + necessary but not sufficient conditions for the correctness of your code. + """ function test_tangent( rng::AbstractRNG, p::P, x::T, y::T, z_target::T; interface_only, perf=true -) where {P, T} +) where {P,T} @nospecialize rng p x y z_target # Check the interface. @@ -1008,12 +1091,12 @@ function test_tangent( end # Check performance is as expected. - perf && test_tangent_performance(rng, p) + return perf && test_tangent_performance(rng, p) end function test_tangent(rng::AbstractRNG, p::P; interface_only=false, perf=true) where {P} test_tangent_consistency(rng, p; interface_only) - perf && test_tangent_performance(rng, p) + return perf && test_tangent_performance(rng, p) end function test_equality_comparison(x) @@ -1115,7 +1198,7 @@ written in Mooncake itself. function test_data(rng::AbstractRNG, p::P; interface_only=false) where {P} test_tangent_consistency(rng, p; interface_only) test_fwds_rvs_data(rng, p) - test_rule_and_type_interactions(rng, p) + return test_rule_and_type_interactions(rng, p) end end diff --git a/src/tools_for_rules.jl b/src/tools_for_rules.jl index b51ffa423..65ee34368 100644 --- a/src/tools_for_rules.jl +++ b/src/tools_for_rules.jl @@ -134,7 +134,7 @@ anything in `f` or `x`. This is always the case if the result is a bits type, bu may be required if it is not. ``` """ -@inline function zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N} +@inline function zero_adjoint(f::CoDual, x::Vararg{CoDual,N}) where {N} return zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...) end @@ -203,13 +203,11 @@ macro zero_adjoint(ctx, sig) is_vararg = arg_type_symbols[end] === :Vararg if is_vararg arg_types = vcat( - map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols[1:end-1]), + map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols[1:(end - 1)]), :(Vararg{Mooncake.CoDual}), ) splat_symbol = Expr(Symbol("..."), arg_names[end]) - body = Expr( - :call, Mooncake.zero_adjoint, arg_names[1:end-1]..., splat_symbol, - ) + body = Expr(:call, Mooncake.zero_adjoint, arg_names[1:(end - 1)]..., splat_symbol) else arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols) body = Expr(:call, Mooncake.zero_adjoint, arg_names...) @@ -256,31 +254,31 @@ function increment_and_get_rdata!(f, r, t::CRC.Thunk) return increment_and_get_rdata!(f, r, CRC.unthunk(t)) end -@doc""" - rrule_wrapper(f::CoDual, args::CoDual...) +@doc """ + rrule_wrapper(f::CoDual, args::CoDual...) -Used to implement `rrule!!`s via `ChainRulesCore.rrule`. + Used to implement `rrule!!`s via `ChainRulesCore.rrule`. -Given a function `foo`, argument types `arg_types`, and a method of `ChainRulesCore.rrule` -which applies to these, you can make use of this function as follows: -```julia -Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...} -function Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...) - return rrule_wrapper(f, args...) -end -``` -Assumes that methods of `to_cr_tangent` and `to_mooncake_tangent` are defined such that you -can convert between the different representations of tangents that Mooncake and -ChainRulesCore expect. + Given a function `foo`, argument types `arg_types`, and a method of `ChainRulesCore.rrule` + which applies to these, you can make use of this function as follows: + ```julia + Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...} + function Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...) + return rrule_wrapper(f, args...) + end + ``` + Assumes that methods of `to_cr_tangent` and `to_mooncake_tangent` are defined such that you + can convert between the different representations of tangents that Mooncake and + ChainRulesCore expect. -Furthermore, it is _essential_ that -1. `f(args)` does not mutate `f` or `args`, and -2. the result of `f(args)` does not alias any data stored in `f` or `args`. + Furthermore, it is _essential_ that + 1. `f(args)` does not mutate `f` or `args`, and + 2. the result of `f(args)` does not alias any data stored in `f` or `args`. -Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the -amount of boilerplate code that you are required to write even further. -""" -function rrule_wrapper(fargs::Vararg{CoDual, N}) where {N} + Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the + amount of boilerplate code that you are required to write even further. + """ +function rrule_wrapper(fargs::Vararg{CoDual,N}) where {N} # Run forwards-pass. primals = tuple_map(primal, fargs) @@ -304,7 +302,7 @@ function rrule_wrapper(fargs::Vararg{CoDual, N}) where {N} return CoDual(y_primal, y_fdata), pb!! end -function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) where {N} +function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual,N}) where {N} # Run forwards-pass. primals = tuple_map(primal, fargs) @@ -335,121 +333,120 @@ function construct_rrule_wrapper_def(arg_names, arg_types, where_params) return construct_def(arg_names, arg_types, where_params, body) end -@doc""" - @from_rrule ctx sig [has_kwargs=false] +@doc """ + @from_rrule ctx sig [has_kwargs=false] -Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. + Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. -# Arguments + # Arguments -- `ctx`: A Mooncake context type -- `sig`: the signature which you wish to assert should be a primitive in `Mooncake.jl`, and - use an existing `ChainRulesCore.rrule` to implement this functionality. -- `has_kwargs`: a `Bool` state whether or not the function has keyword arguments. This - feature has the same limitations as `ChainRulesCore.rrule` -- the derivative w.r.t. all - kwargs must be zero. + - `ctx`: A Mooncake context type + - `sig`: the signature which you wish to assert should be a primitive in `Mooncake.jl`, and + use an existing `ChainRulesCore.rrule` to implement this functionality. + - `has_kwargs`: a `Bool` state whether or not the function has keyword arguments. This + feature has the same limitations as `ChainRulesCore.rrule` -- the derivative w.r.t. all + kwargs must be zero. -# Example Usage + # Example Usage -## A Basic Example + ## A Basic Example -```jldoctest -julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils + ```jldoctest + julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils -julia> using ChainRulesCore + julia> using ChainRulesCore -julia> foo(x::Real) = 5x; + julia> foo(x::Real) = 5x; -julia> function ChainRulesCore.rrule(::typeof(foo), x::Real) - foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω - return foo(x), foo_pb - end; + julia> function ChainRulesCore.rrule(::typeof(foo), x::Real) + foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω + return foo(x), foo_pb + end; -julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} + julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} -julia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0) -(NoRData(), 5.0) + julia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0) + (NoRData(), 5.0) -julia> # Check that the rule works as intended. - TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true) -Test Passed -``` - -## An Example with Keyword Arguments + julia> # Check that the rule works as intended. + TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true) + Test Passed + ``` -```jldoctest -julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils + ## An Example with Keyword Arguments -julia> using ChainRulesCore + ```jldoctest + julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils -julia> foo(x::Real; cond::Bool) = cond ? 5x : 4x; + julia> using ChainRulesCore -julia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool) - foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω - return foo(x; cond), foo_pb - end; + julia> foo(x::Real; cond::Bool) = cond ? 5x : 4x; -julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true + julia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool) + foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω + return foo(x; cond), foo_pb + end; -julia> _, pb = rrule!!( - zero_fcodual(Core.kwcall), - zero_fcodual((cond=false, )), - zero_fcodual(foo), - zero_fcodual(5.0), - ); + julia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true -julia> pb(3.0) -(NoRData(), NoRData(), NoRData(), 12.0) + julia> _, pb = rrule!!( + zero_fcodual(Core.kwcall), + zero_fcodual((cond=false, )), + zero_fcodual(foo), + zero_fcodual(5.0), + ); -julia> # Check that the rule works as intended. - TestUtils.test_rule( - Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true - ) -Test Passed -``` -Notice that, in order to access the kwarg method we must call the method of `Core.kwcall`, -as Mooncake's `rrule!!` does not itself permit the use of kwargs. + julia> pb(3.0) + (NoRData(), NoRData(), NoRData(), 12.0) -# Limitations - -It is your responsibility to ensure that -1. calls with signature `sig` do not mutate their arguments, -2. the output of calls with signature `sig` does not alias any of the inputs. - -As with all hand-written rules, you should definitely make use of -[`TestUtils.test_rule`](@ref) to verify correctness on some test cases. - -# Argument Type Constraints - -Many methods of `ChainRuleCore.rrule` are implemented with very loose type constraints. -For example, it would not be surprising to see a method of rrule with the signature -```julia -Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}} -``` -There are a variety of reasons for this way of doing things, and whether it is a good idea -to write rules for such generic objects has been debated at length. - -Suffice it to say, you should not write rules for _this_ package which are so generically -typed. -Rather, you should create rules for the subset of types for which you believe that the -`ChainRulesCore.rrule` will work correctly, and leave this package to derive rules for the -rest. -For example, it is quite common to be confident that a given rule will work correctly for -any `Base.IEEEFloat` argument, i.e. `Union{Float16, Float32, Float64}`, but it is usually -not possible to know that the rule is correct for all possible subtypes of `Real` that -someone might define. - -# Conversions Between Different Tangent Type Systems - -Under the hood, this functionality relies on two functions: `Mooncake.to_cr_tangent`, and -`Mooncake.increment_and_get_rdata!`. These two functions handle conversion to / from -`Mooncake` tangent types and `ChainRulesCore` tangent types. This functionality is known to -work well for simple types, but has not been tested to a great extent on complicated -composite types. If `@from_rrule` does not work in your case because the required method of -either of these functions does not exist, please open an issue. -""" + julia> # Check that the rule works as intended. + TestUtils.test_rule( + Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true + ) + Test Passed + ``` + Notice that, in order to access the kwarg method we must call the method of `Core.kwcall`, + as Mooncake's `rrule!!` does not itself permit the use of kwargs. + + # Limitations + + It is your responsibility to ensure that + 1. calls with signature `sig` do not mutate their arguments, + 2. the output of calls with signature `sig` does not alias any of the inputs. + + As with all hand-written rules, you should definitely make use of + [`TestUtils.test_rule`](@ref) to verify correctness on some test cases. + + # Argument Type Constraints + + Many methods of `ChainRuleCore.rrule` are implemented with very loose type constraints. + For example, it would not be surprising to see a method of rrule with the signature + ```julia + Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}} + ``` + There are a variety of reasons for this way of doing things, and whether it is a good idea + to write rules for such generic objects has been debated at length. + + Suffice it to say, you should not write rules for _this_ package which are so generically + typed. + Rather, you should create rules for the subset of types for which you believe that the + `ChainRulesCore.rrule` will work correctly, and leave this package to derive rules for the + rest. + For example, it is quite common to be confident that a given rule will work correctly for + any `Base.IEEEFloat` argument, i.e. `Union{Float16, Float32, Float64}`, but it is usually + not possible to know that the rule is correct for all possible subtypes of `Real` that + someone might define. + + # Conversions Between Different Tangent Type Systems + + Under the hood, this functionality relies on two functions: `Mooncake.to_cr_tangent`, and + `Mooncake.increment_and_get_rdata!`. These two functions handle conversion to / from + `Mooncake` tangent types and `ChainRulesCore` tangent types. This functionality is known to + work well for simple types, but has not been tested to a great extent on complicated + composite types. If `@from_rrule` does not work in your case because the required method of + either of these functions does not exist, please open an issue. + """ macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) - arg_type_syms, where_params = parse_signature_expr(sig) arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_syms)) arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_syms) diff --git a/src/utils.jl b/src/utils.jl index d52f6f9f3..6374a2ae9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,7 +5,7 @@ Central definition of typeof, which is specific to the use-required in this pack """ _typeof(x) = Base._stable_typeof(x) _typeof(x::Tuple) = Tuple{tuple_map(_typeof, x)...} -_typeof(x::NamedTuple{names}) where {names} = NamedTuple{names, _typeof(Tuple(x))} +_typeof(x::NamedTuple{names}) where {names} = NamedTuple{names,_typeof(Tuple(x))} """ tuple_map(f::F, x::Tuple) where {F} @@ -37,28 +37,34 @@ end end for N in 1:128 - @eval @inline function tuple_map(f::F, x::Tuple{Vararg{Any, $N}}) where {F} + @eval @inline function tuple_map(f::F, x::Tuple{Vararg{Any,$N}}) where {F} return $(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:N)...)) end - @eval @inline function tuple_map(f::F, x::NamedTuple{names, <:Tuple{Vararg{Any, $N}}}) where {F, names} - return NamedTuple{names}($(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:N)...))) - end @eval @inline function tuple_map( - f, x::Tuple{Vararg{Any, $N}}, y::Tuple{Vararg{Any, $N}} - ) - return $(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)...)) + f::F, x::NamedTuple{names,<:Tuple{Vararg{Any,$N}}} + ) where {F,names} + return NamedTuple{names}( + $(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:N)...)) + ) + end + @eval @inline function tuple_map(f, x::Tuple{Vararg{Any,$N}}, y::Tuple{Vararg{Any,$N}}) + return $(Expr( + :call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)... + )) end @eval @inline function tuple_map( f::F, - x::NamedTuple{names, <:Tuple{Vararg{Any, $N}}}, - y::NamedTuple{names, <:Tuple{Vararg{Any, $N}}}, - ) where {F, names} - return NamedTuple{names}($(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)...))) + x::NamedTuple{names,<:Tuple{Vararg{Any,$N}}}, + y::NamedTuple{names,<:Tuple{Vararg{Any,$N}}}, + ) where {F,names} + return NamedTuple{names}( + $(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)...)) + ) end end for N in 1:256 - @eval @inline function tuple_splat(f, x::Tuple{Vararg{Any, $N}}) + @eval @inline function tuple_splat(f, x::Tuple{Vararg{Any,$N}}) return $(Expr(:call, :f, map(n -> :(getfield(x, $n)), 1:N)...)) end end @@ -71,11 +77,10 @@ end return Expr(:call, :f, :v, map(n -> :(x[$n]), 1:length(x.parameters))...) end -@inline @generated function tuple_fill(val ,::Val{N}) where {N} +@inline @generated function tuple_fill(val, ::Val{N}) where {N} return Expr(:call, :tuple, map(_ -> :val, 1:N)...) end - """ _map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P} @@ -86,7 +91,7 @@ Equivalent to `map!(f, y, x)` if `P` is a bits type as element will always be as Requires that `y` and `x` have the same size. """ -function _map_if_assigned!(f::F, y::DenseArray, x::DenseArray{P}) where {F, P} +function _map_if_assigned!(f::F, y::DenseArray, x::DenseArray{P}) where {F,P} @assert size(y) == size(x) @inbounds for n in eachindex(y) if isbitstype(P) || isassigned(x, n) @@ -104,7 +109,9 @@ writes `f(x1[n], x2[n])` to `y[n]`, otherwise leaves `y[n]` unchanged. Requires that `y`, `x1`, and `x2` have the same size. """ -function _map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray) where {F, P} +function _map_if_assigned!( + f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray +) where {F,P} @assert size(y) == size(x1) @assert size(y) == size(x2) @inbounds for n in eachindex(y) @@ -121,7 +128,7 @@ end Same as `map` but requires all elements of `x` to have equal length. The usual function `map` doesn't enforce this for `Array`s. """ -@inline function _map(f::F, x::Vararg{Any, N}) where {F, N} +@inline function _map(f::F, x::Vararg{Any,N}) where {F,N} @assert allequal(map(length, x)) return map(f, x...) end @@ -139,11 +146,13 @@ is_vararg_and_sparam_names(m::Method) = m.isva, sparam_names(m) Finds the method associated to `sig`, and calls `is_vararg_and_sparam_names` on it. """ -function is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}} +function is_vararg_and_sparam_names(sig)::Tuple{Bool,Vector{Symbol}} world = Base.get_world_counter() min = Base.RefValue{UInt}(typemin(UInt)) max = Base.RefValue{UInt}(typemax(UInt)) - ms = Base._methods_by_ftype(sig, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))::Vector + ms = Base._methods_by_ftype( + sig, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL) + )::Vector return is_vararg_and_sparam_names(only(ms).method) end @@ -152,7 +161,7 @@ end Calls `is_vararg_and_sparam_names` on `mi.def::Method`. """ -function is_vararg_and_sparam_names(mi::Core.MethodInstance)::Tuple{Bool, Vector{Symbol}} +function is_vararg_and_sparam_names(mi::Core.MethodInstance)::Tuple{Bool,Vector{Symbol}} return is_vararg_and_sparam_names(mi.def) end @@ -198,7 +207,7 @@ end Like `getfield`, but with the field and access order encoded as types. """ -lgetfield(x, ::Val{f}, ::Val{order}) where {f, order} = getfield(x, f, order) +lgetfield(x, ::Val{f}, ::Val{order}) where {f,order} = getfield(x, f, order) """ lsetfield!(value, name::Val, x, [order::Val]) @@ -216,6 +225,6 @@ lsetfield!(value, ::Val{name}, x) where {name} = setfield!(value, name, x) One-liner which calls the `:new` instruction with type `T` with arguments `x`. """ -@inline @generated function _new_(::Type{T}, x::Vararg{Any, N}) where {T, N} +@inline @generated function _new_(::Type{T}, x::Vararg{Any,N}) where {T,N} return Expr(:new, :T, map(n -> :(x[$n]), 1:N)...) end diff --git a/test/codual.jl b/test/codual.jl index 4ae349298..9383ab41e 100644 --- a/test/codual.jl +++ b/test/codual.jl @@ -1,21 +1,23 @@ @testset "codual" begin - @test CoDual(5.0, 4.0) isa CoDual{Float64, Float64} - @test CoDual(Float64, NoTangent()) isa CoDual{Type{Float64}, NoTangent} + @test CoDual(5.0, 4.0) isa CoDual{Float64,Float64} + @test CoDual(Float64, NoTangent()) isa CoDual{Type{Float64},NoTangent} @test zero_codual(5.0) == CoDual(5.0, 0.0) - @test codual_type(Float64) == CoDual{Float64, Float64} - @test codual_type(Int) == CoDual{Int, NoTangent} + @test codual_type(Float64) == CoDual{Float64,Float64} + @test codual_type(Int) == CoDual{Int,NoTangent} @test codual_type(Real) == CoDual @test codual_type(Any) == CoDual - @test codual_type(Type{UnitRange{Int}}) == CoDual{Type{UnitRange{Int}}, NoTangent} + @test codual_type(Type{UnitRange{Int}}) == CoDual{Type{UnitRange{Int}},NoTangent} @test codual_type(Type{Tuple{T}} where {T}) <: CoDual @test Mooncake.fcodual_type(Type{Tuple{T}} where {T}) <: CoDual - @test(==( - codual_type(Union{Float64, Int}), - Union{CoDual{Float64, Float64}, CoDual{Int, NoTangent}}, - )) + @test( + ==( + codual_type(Union{Float64,Int}), + Union{CoDual{Float64,Float64},CoDual{Int,NoTangent}}, + ) + ) @test codual_type(UnionAll) == CoDual @testset "NoPullback" begin @test Base.issingletontype(typeof(NoPullback(zero_fcodual(5.0)))) - @test NoPullback(zero_codual(5.0))(4.0) == (0.0, ) + @test NoPullback(zero_codual(5.0))(4.0) == (0.0,) end end diff --git a/test/debug_mode.jl b/test/debug_mode.jl index 3e5f25245..7d352d64b 100644 --- a/test/debug_mode.jl +++ b/test/debug_mode.jl @@ -5,7 +5,7 @@ @testset "argument checking" begin f = x -> 5x rule = build_rrule(f, 5.0; debug_mode=true) - @test_throws ErrorException rule(zero_fcodual(f), CoDual(0f0, 1f0)) + @test_throws ErrorException rule(zero_fcodual(f), CoDual(0.0f0, 1.0f0)) end # Forwards-pass tests. @@ -13,7 +13,7 @@ @test_throws(ErrorException, Mooncake.DebugRRule(rrule!!)(x...)) x = (CoDual(sin, NoFData()), CoDual(5.0, NoFData())) @test_throws( - ErrorException, Mooncake.DebugRRule((x..., ) -> (CoDual(1.0, 0.0), nothing))(x...) + ErrorException, Mooncake.DebugRRule((x...,) -> (CoDual(1.0, 0.0), nothing))(x...) ) # Basic type checking. @@ -24,7 +24,7 @@ # just by looking at the array. x = ( CoDual(size, NoFData()), - CoDual(Any[rand() for _ in 1:10], Any[rand(Float16) for _ in 1:10]) + CoDual(Any[rand() for _ in 1:10], Any[rand(Float16) for _ in 1:10]), ) @test_throws ErrorException Mooncake.DebugRRule(rrule!!)(x...) @@ -33,7 +33,7 @@ @test_throws(InvalidRDataException, pb!!(5)) # Test that bad rdata is caught as a post-condition. - rule_with_bad_pb(x::CoDual{Float64}) = x, dy -> (5, ) # returns the wrong type + rule_with_bad_pb(x::CoDual{Float64}) = x, dy -> (5,) # returns the wrong type y, pb!! = Mooncake.DebugRRule(rule_with_bad_pb)(zero_fcodual(5.0)) @test_throws InvalidRDataException pb!!(1.0) diff --git a/test/developer_tools.jl b/test/developer_tools.jl index f3b63fac0..7c65bfd67 100644 --- a/test/developer_tools.jl +++ b/test/developer_tools.jl @@ -1,5 +1,5 @@ @testset "developer_tools" begin - sig = Tuple{typeof(sin), Float64} + sig = Tuple{typeof(sin),Float64} @test Mooncake.primal_ir(sig) isa CC.IRCode @test Mooncake.fwd_ir(sig) isa CC.IRCode @test Mooncake.rvs_ir(sig) isa CC.IRCode diff --git a/test/ext/cuda/cuda.jl b/test/ext/cuda/cuda.jl index 82e4fc9a5..e26933332 100644 --- a/test/ext/cuda/cuda.jl +++ b/test/ext/cuda/cuda.jl @@ -9,13 +9,19 @@ using Mooncake.TestUtils: test_tangent, test_rule # Check we can operate on CuArrays. test_tangent( - StableRNG(123456), CuArray{Float32, 2, CUDA.DeviceMemory}(undef, 8, 8); + StableRNG(123456), + CuArray{Float32,2,CUDA.DeviceMemory}(undef, 8, 8); interface_only=false, ) # Check we can instantiate a CuArray. test_rule( - StableRNG(123456), CuArray{Float32, 1, CUDA.DeviceMemory}, undef, 256; - interface_only=true, is_primitive=true, debug_mode=true, + StableRNG(123456), + CuArray{Float32,1,CUDA.DeviceMemory}, + undef, + 256; + interface_only=true, + is_primitive=true, + debug_mode=true, ) end diff --git a/test/ext/differentiation_interface/differentiation_interface.jl b/test/ext/differentiation_interface/differentiation_interface.jl index 6b1874c22..ea74e84b6 100644 --- a/test/ext/differentiation_interface/differentiation_interface.jl +++ b/test/ext/differentiation_interface/differentiation_interface.jl @@ -3,9 +3,10 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using DifferentiationInterface, DifferentiationInterfaceTest -import Mooncake +using Mooncake: Mooncake test_differentiation( [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())]; - excluded=SECOND_ORDER, logging=true, + excluded=SECOND_ORDER, + logging=true, ) diff --git a/test/ext/luxlib/luxlib.jl b/test/ext/luxlib/luxlib.jl index 304c0a63f..e725d3e6b 100644 --- a/test/ext/luxlib/luxlib.jl +++ b/test/ext/luxlib/luxlib.jl @@ -7,36 +7,50 @@ using LuxLib.Impl: sleefpirates_fast_act using Mooncake.TestUtils: test_rule @testset "luxlib" begin - @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in vcat( + @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in + vcat( Any[ (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)), (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)), - (false, :none, true, LuxLib.Impl.batched_matmul, randn(5, 4, 3), randn(4, 3, 3)), + ( + false, + :none, + true, + LuxLib.Impl.batched_matmul, + randn(5, 4, 3), + randn(4, 3, 3), + ), (false, :none, false, LuxLib.Impl.activation, Lux.relu, randn(5, 4)), ], - map(Any[ - LuxLib.NNlib.sigmoid_fast, - LuxLib.NNlib.softplus, - LuxLib.NNlib.logsigmoid, - LuxLib.NNlib.swish, - LuxLib.NNlib.lisht, - Base.tanh, - LuxLib.NNlib.tanh_fast, - ]) do f + map( + Any[ + LuxLib.NNlib.sigmoid_fast, + LuxLib.NNlib.softplus, + LuxLib.NNlib.logsigmoid, + LuxLib.NNlib.swish, + LuxLib.NNlib.lisht, + Base.tanh, + LuxLib.NNlib.tanh_fast, + ], + ) do f return (false, :stability_and_allocs, true, sleefpirates_fast_act(f), randn()) end, Any[ ( - false, :stability_and_allocs, true, + false, + :stability_and_allocs, + true, LuxLib.Utils.static_training_mode_check, nothing, LuxLib.Utils.True(), LuxLib.Utils.True(), ), ( - false, :none, false, - function(opmode, act, x, m, sigma2, gamma, beta) - LuxLib.Impl.batchnorm_affine_normalize_internal( + false, + :none, + false, + function (opmode, act, x, m, sigma2, gamma, beta) + return LuxLib.Impl.batchnorm_affine_normalize_internal( opmode, act, x, m, sigma2, gamma, beta, 1e-3 ) end, @@ -49,16 +63,27 @@ using Mooncake.TestUtils: test_rule nothing, ), ], - vec(map(Iterators.product( - [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], - [randn(5), nothing], - [Lux.relu, tanh, NNlib.gelu], - )) do (opmode, bias, activation) - ( - false, :none, false, - LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias, - ) - end), + vec( + map( + Iterators.product( + [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], + [randn(5), nothing], + [Lux.relu, tanh, NNlib.gelu], + ), + ) do (opmode, bias, activation) + ( + false, + :none, + false, + LuxLib.Impl.fused_dense, + opmode, + activation, + randn(5, 4), + randn(4, 2), + bias, + ) + end, + ), ) test_rule(StableRNG(123), fargs...; perf_flag, is_primitive, interface_only) end diff --git a/test/ext/nnlib/nnlib.jl b/test/ext/nnlib/nnlib.jl index 8ee065e5a..c43c4fe53 100644 --- a/test/ext/nnlib/nnlib.jl +++ b/test/ext/nnlib/nnlib.jl @@ -22,33 +22,46 @@ using NNlib: dropout grid[:, 1, 2, 1] .= (-1, 1) grid[:, 2, 2, 1] .= (1, 1) - @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in Any[ + @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in + Any[ # batched_mul (false, :none, true, batched_mul, randn(3, 2, 3), randn(2, 5, 3)), # dropout ( - true, :none, false, - (x, p) -> dropout(StableRNG(1), x, p; dims=1), randn(2, 2), 0.5, + true, + :none, + false, + (x, p) -> dropout(StableRNG(1), x, p; dims=1), + randn(2, 2), + 0.5, ), ( - true, :none, false, - (x, p) -> dropout(StableRNG(1), x, p; dims=2), randn(2, 2), 0.1, + true, + :none, + false, + (x, p) -> dropout(StableRNG(1), x, p; dims=2), + randn(2, 2), + 0.1, ), ( - true, :none, false, - (x, p) -> dropout(StableRNG(1), x, p; dims=(1, 2)), randn(2, 2), 0.4, + true, + :none, + false, + (x, p) -> dropout(StableRNG(1), x, p; dims=(1, 2)), + randn(2, 2), + 0.4, ), # softmax (false, :stability, true, softmax, randn(2)), (false, :stability, true, softmax, randn(2, 2)), - (false, :stability, true, Core.kwcall, (dims=1, ), softmax, randn(2,)), - (false, :stability, true, Core.kwcall, (dims=1, ), softmax, randn(3, 3)), - (false, :stability, true, Core.kwcall, (dims=2, ), softmax, randn(3, 3)), - (false, :stability, true, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3)), - (false, :stability, true, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3, 2)), + (false, :stability, true, Core.kwcall, (dims=1,), softmax, randn(2)), + (false, :stability, true, Core.kwcall, (dims=1,), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2,), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2),), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2),), softmax, randn(3, 3, 2)), (false, :none, false, x -> softmax(5x), randn(3, 2)), (false, :none, false, x -> softmax(x; dims=1), randn(3, 2)), (false, :none, false, x -> softmax(x; dims=2), randn(3, 2)), @@ -58,21 +71,21 @@ using NNlib: dropout (false, :stability, true, logsoftmax, randn(2)), (false, :stability, true, logsoftmax, randn(2, 3)), (false, :stability, true, logsoftmax, randn(2, 3, 2)), - (false, :stability, true, Core.kwcall, (dims=1, ), logsoftmax, randn(2,)), - (false, :stability, true, Core.kwcall, (dims=1, ), logsoftmax, randn(3, 3)), - (false, :stability, true, Core.kwcall, (dims=2, ), logsoftmax, randn(3, 3)), - (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3)), - (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3, 2)), + (false, :stability, true, Core.kwcall, (dims=1,), logsoftmax, randn(2)), + (false, :stability, true, Core.kwcall, (dims=1,), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2,), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2),), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2),), logsoftmax, randn(3, 3, 2)), # logsumexp - (false, :stability, true, logsumexp, randn(2,)), + (false, :stability, true, logsumexp, randn(2)), (false, :stability, true, logsumexp, randn(3, 3)), (false, :stability, true, logsumexp, randn(3, 3, 2)), - (false, :stability, true, Core.kwcall, (dims=1, ), logsumexp, randn(2,)), - (false, :stability, true, Core.kwcall, (dims=1, ), logsumexp, randn(3, 3)), - (false, :stability, true, Core.kwcall, (dims=2, ), logsumexp, randn(3, 3)), - (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3)), - (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3, 2)), + (false, :stability, true, Core.kwcall, (dims=1,), logsumexp, randn(2)), + (false, :stability, true, Core.kwcall, (dims=1,), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2,), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2),), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2),), logsumexp, randn(3, 3, 2)), # upsample_nearest (false, :stability, true, upsample_nearest, randn(3), (2,)), diff --git a/test/front_matter.jl b/test/front_matter.jl index ada88403e..be56bdfe9 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -1,5 +1,4 @@ -using - AllocCheck, +using AllocCheck, Aqua, BenchmarkTools, DiffRules, @@ -11,7 +10,7 @@ using Mooncake, Test -import ChainRulesCore +using ChainRulesCore: ChainRulesCore using Base: unsafe_load, pointer_from_objref, IEEEFloat, TwicePrecision using Base.Iterators: product diff --git a/test/fwds_rvs_data.jl b/test/fwds_rvs_data.jl index ce3126524..660748c06 100644 --- a/test/fwds_rvs_data.jl +++ b/test/fwds_rvs_data.jl @@ -1,19 +1,15 @@ module FwdsRvsDataTestResources - struct Foo{A} end +struct Foo{A} end end @testset "fwds_rvs_data" begin @testset "fdata_type / rdata_type($P)" for (P, F, R) in Any[ ( - Tuple{Any, Vector{Float64}}, - Tuple{Any, Vector{Float64}}, - Union{NoRData, Tuple{Any, NoRData}}, - ), - ( - Tuple{Any, Float64}, - Union{NoFData, Tuple{Any, NoFData}}, - Tuple{Any, Float64}, + Tuple{Any,Vector{Float64}}, + Tuple{Any,Vector{Float64}}, + Union{NoRData,Tuple{Any,NoRData}}, ), + (Tuple{Any,Float64}, Union{NoFData,Tuple{Any,NoFData}}, Tuple{Any,Float64}), ] @test fdata_type(tangent_type(P)) == F @test rdata_type(tangent_type(P)) == R @@ -25,14 +21,14 @@ end @test Mooncake.can_produce_zero_rdata_from_type(Vector) == true @test Mooncake.zero_rdata_from_type(Vector) == NoRData() @test !Mooncake.can_produce_zero_rdata_from_type(FwdsRvsDataTestResources.Foo) - @test Mooncake.can_produce_zero_rdata_from_type(Tuple{Float64, Type{Float64}}) + @test Mooncake.can_produce_zero_rdata_from_type(Tuple{Float64,Type{Float64}}) @test ==( Mooncake.zero_rdata_from_type(FwdsRvsDataTestResources.Foo), Mooncake.CannotProduceZeroRDataFromType(), ) @test !Mooncake.can_produce_zero_rdata_from_type(Tuple) - @test !Mooncake.can_produce_zero_rdata_from_type(Union{Tuple{Float64}, Tuple{Int}}) - @test !Mooncake.can_produce_zero_rdata_from_type(Tuple{T, T} where {T<:Integer}) + @test !Mooncake.can_produce_zero_rdata_from_type(Union{Tuple{Float64},Tuple{Int}}) + @test !Mooncake.can_produce_zero_rdata_from_type(Tuple{T,T} where {T<:Integer}) @test Mooncake.can_produce_zero_rdata_from_type(Type{Float64}) # Edge case: Types with unbound type parameters. @@ -51,14 +47,18 @@ end (Int, 5, true), (Int32, Int32(5), true), (Float64, 5.0, true), - (Float32, 5f0, true), + (Float32, 5.0f0, true), (Float16, Float16(5.0), true), (StructFoo, StructFoo(5.0), false), (StructFoo, StructFoo(5.0, randn(4)), false), (Type{Bool}, Bool, true), - (Type{Mooncake.TestResources.StableFoo}, Mooncake.TestResources.StableFoo, true), - (Tuple{Float64, Float64}, (5.0, 4.0), true), - (Tuple{Float64, Vararg{Float64}}, (5.0, 4.0, 3.0), false), + ( + Type{Mooncake.TestResources.StableFoo}, + Mooncake.TestResources.StableFoo, + true, + ), + (Tuple{Float64,Float64}, (5.0, 4.0), true), + (Tuple{Float64,Vararg{Float64}}, (5.0, 4.0, 3.0), false), (Type{Type{Tuple{T}} where {T}}, Type{Tuple{T}} where {T}, true), ] L = Mooncake.lazy_zero_rdata_type(P) @@ -75,19 +75,21 @@ end ) end @testset "misc fdata / rdata type checking" begin - @test(==( - Mooncake.rdata_type(tangent_type(Tuple{Union{Float32, Float64}})), - Tuple{Union{Float32, Float64}}, - )) - @test(==( - Mooncake.rdata_type(tangent_type(Tuple{Union{Int32, Int}})), NoRData - )) - @test(==( - Mooncake.rdata_type(tangent_type( - Tuple{Union{Vector{Float32}, Vector{Float64}}} - )), - NoRData, - )) + @test( + ==( + Mooncake.rdata_type(tangent_type(Tuple{Union{Float32,Float64}})), + Tuple{Union{Float32,Float64}}, + ) + ) + @test(==(Mooncake.rdata_type(tangent_type(Tuple{Union{Int32,Int}})), NoRData)) + @test( + ==( + Mooncake.rdata_type( + tangent_type(Tuple{Union{Vector{Float32},Vector{Float64}}}) + ), + NoRData, + ) + ) end # Tests that the static type of an fdata / rdata is correct happen in @@ -100,9 +102,9 @@ end end @testset "Tuple" begin @test_throws InvalidFDataException verify_fdata_value((), ()) - @test_throws InvalidFDataException verify_fdata_value((5,), (NoFData(), )) + @test_throws InvalidFDataException verify_fdata_value((5,), (NoFData(),)) @test_throws InvalidRDataException verify_rdata_value((), ()) - @test_throws InvalidRDataException verify_rdata_value((5,), (NoRData(), )) + @test_throws InvalidRDataException verify_rdata_value((5,), (NoRData(),)) end @testset "Ptr" begin @test verify_fdata_value(Ptr{Float64}(), Ptr{Float64}()) === nothing diff --git a/test/integration_testing/array/array.jl b/test/integration_testing/array/array.jl index 84c7fc4a9..d3f47bb01 100644 --- a/test/integration_testing/array/array.jl +++ b/test/integration_testing/array/array.jl @@ -1,6 +1,6 @@ using Pkg Pkg.activate(@__DIR__) -Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using LinearAlgebra, Mooncake, StableRNGs, Test using Mooncake.TestUtils: test_rule @@ -119,7 +119,7 @@ _getter() = 5.0 ( false, *, - randn(sr(0), ), + randn(sr(0)), adjoint(randn(sr(1), 2, 1)), randn(sr(2), 2, 3), randn(sr(3), 3), @@ -400,16 +400,20 @@ _getter() = 5.0 (false, x -> minimum(cos, x; dims=2), randn(sr(7), 3, 2)), (false, minimum!, sin, randn(sr(9), 2), randn(sr(8), 2, 3)), ], - vec(reduce( - vcat, - map(Iterators.product( - [adjoint(randn(sr(0), 2, 3)), transpose(randn(sr(1), 2, 3))], - [randn(sr(3), 2), randn(sr(2), 2, 3)], - [randn(sr(4)), randn(sr(5), 1), randn(sr(6), 3)], - )) do (A, b, z) - (false, muladd, A, b, z) - end, - )), + vec( + reduce( + vcat, + map( + Iterators.product( + [adjoint(randn(sr(0), 2, 3)), transpose(randn(sr(1), 2, 3))], + [randn(sr(3), 2), randn(sr(2), 2, 3)], + [randn(sr(4)), randn(sr(5), 1), randn(sr(6), 3)], + ), + ) do (A, b, z) + (false, muladd, A, b, z) + end, + ), + ), Any[ (false, ndims, randn(sr(7), 2)), (false, ndims, randn(sr(8), 1, 2, 1, 1, 1)), @@ -507,7 +511,7 @@ _getter() = 5.0 (false, view, randn(sr(0), 3, 2), :, :), (false, zero, randn(sr(1), 3)), (false, zero, randn(sr(2), 2, 3)), - ] + ], ) @testset for (interface_only, f, x...) in test_cases @info typeof((f, x...)) diff --git a/test/integration_testing/battery_tests/battery_tests.jl b/test/integration_testing/battery_tests/battery_tests.jl index caafbcdf8..266d164f9 100644 --- a/test/integration_testing/battery_tests/battery_tests.jl +++ b/test/integration_testing/battery_tests/battery_tests.jl @@ -1,6 +1,6 @@ using Pkg Pkg.activate(@__DIR__) -Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using JET, LinearAlgebra, Mooncake, Random, StableRNGs, Test using Mooncake: TestResources @@ -8,17 +8,28 @@ using Mooncake: TestResources @testset "battery_tests" begin @testset "$(typeof(p))" for p in vcat( [ - true, false, - UInt8(0), UInt8(3), - UInt16(0), UInt16(5), - UInt32(0), UInt32(7), - UInt64(0), UInt64(9), - UInt128(0), UInt128(3), - Int8(0), Int8(3), - Int16(0), Int16(-1), - Int32(0), Int32(-3), - Int64(0), Int64(5), - Int128(0), Int128(24), + true, + false, + UInt8(0), + UInt8(3), + UInt16(0), + UInt16(5), + UInt32(0), + UInt32(7), + UInt64(0), + UInt64(9), + UInt128(0), + UInt128(3), + Int8(0), + Int8(3), + Int16(0), + Int16(-1), + Int32(0), + Int32(-3), + Int64(0), + Int64(5), + Int128(0), + Int128(24), "hello", ], randn(Float64, 5), @@ -46,7 +57,7 @@ using Mooncake: TestResources UpperTriangular(randn(3, 3)), UnitLowerTriangular(randn(3, 3)), UnitUpperTriangular(randn(2, 2)), - ] + ], ) Mooncake.TestUtils.test_data(StableRNG(123), p) end diff --git a/test/integration_testing/bijectors/bijectors.jl b/test/integration_testing/bijectors/bijectors.jl index 54d67bc44..50ee4f640 100644 --- a/test/integration_testing/bijectors/bijectors.jl +++ b/test/integration_testing/bijectors/bijectors.jl @@ -1,6 +1,6 @@ using Pkg Pkg.activate(@__DIR__) -Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using Bijectors, LinearAlgebra, Mooncake, StableRNGs, Test using Mooncake.TestUtils: test_rule @@ -15,16 +15,16 @@ struct TestCase broken::Bool end -TestCase(f, arg; name = nothing, broken=false) = TestCase(f, arg, name, broken) +TestCase(f, arg; name=nothing, broken=false) = TestCase(f, arg, name, broken) """ A helper function that returns a TestCase that evaluates bijector(inverse(bijector)(x)) """ -function b_binv_test_case(bijector, dim; name = nothing, rng = StableRNG(23)) +function b_binv_test_case(bijector, dim; name=nothing, rng=StableRNG(23)) if name === nothing name = string(bijector) end - return TestCase(x -> bijector(inverse(bijector)(x)), randn(rng, dim); name = name) + return TestCase(x -> bijector(inverse(bijector)(x)), randn(rng, dim); name=name) end @testset "Bijectors integration tests" begin @@ -38,8 +38,7 @@ end b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 3), b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 0), b_binv_test_case( - Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])), - 3, + Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])), 3 ), b_binv_test_case(Bijectors.InvertibleBatchNorm(3; eps=1e-5, mtm=1e-1), (3, 3)), b_binv_test_case(Bijectors.LeakyReLU(0.2), 3), @@ -65,15 +64,13 @@ end TestCase( function (x) b = Bijectors.RationalQuadraticSpline( - [-0.2, 0.1, 0.5], - [-0.3, 0.3, 0.9], - [1.0, 0.2, 1.0], + [-0.2, 0.1, 0.5], [-0.3, 0.3, 0.9], [1.0, 0.2, 1.0] ) binv = Bijectors.inverse(b) return binv(b(x)) end, randn(StableRNG(23)); - name = "RationalQuadraticSpline on scalar", + name="RationalQuadraticSpline on scalar", ), TestCase( function (x) @@ -82,21 +79,20 @@ end return binv(b(x)) end, randn(StableRNG(23), 7); - name = "OrderedBijector", + name="OrderedBijector", ), TestCase( function (x) layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) flow = Bijectors.transformed( - Bijectors.MvNormal(zeros(2), LinearAlgebra.I), - layer, + Bijectors.MvNormal(zeros(2), LinearAlgebra.I), layer ) x = x[6:7] return Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x) end, randn(StableRNG(23), 7); - name = "PlanarLayer7", + name="PlanarLayer7", # TODO(mhauru) Broken on v1.11 due to # https://github.com/compintell/Mooncake.jl/issues/319 broken=(VERSION >= v"1.11"), @@ -105,8 +101,7 @@ end function (x) layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) flow = Bijectors.transformed( - Bijectors.MvNormal(zeros(2), LinearAlgebra.I), - layer, + Bijectors.MvNormal(zeros(2), LinearAlgebra.I), layer ) x = reshape(x[6:end], 2, :) return sum( @@ -115,7 +110,7 @@ end ) end, randn(StableRNG(23), 11); - name = "PlanarLayer11", + name="PlanarLayer11", ), ] diff --git a/test/integration_testing/diff_tests/diff_tests.jl b/test/integration_testing/diff_tests/diff_tests.jl index 228352112..8c0b1e0b2 100644 --- a/test/integration_testing/diff_tests/diff_tests.jl +++ b/test/integration_testing/diff_tests/diff_tests.jl @@ -1,17 +1,19 @@ using Pkg Pkg.activate(@__DIR__) -Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using Mooncake, Random, StableRNGs, Test using Mooncake.TestUtils: test_rule @testset "diff_tests" begin - @testset "$f, $(typeof(x))" for (n, (interface_only, f, x...)) in enumerate(vcat( - Mooncake.TestResources.DIFFTESTS_FUNCTIONS[1:6], # skipping DiffTests.num2arr_1. See https://github.com/JuliaLang/julia/issues/56193 - Mooncake.TestResources.DIFFTESTS_FUNCTIONS[8:66], # skipping sparse_ldiv - Mooncake.TestResources.DIFFTESTS_FUNCTIONS[68:89], # skipping sparse_ldiv - Mooncake.TestResources.DIFFTESTS_FUNCTIONS[91:end], # skipping sparse_ldiv - )) + @testset "$f, $(typeof(x))" for (n, (interface_only, f, x...)) in enumerate( + vcat( + Mooncake.TestResources.DIFFTESTS_FUNCTIONS[1:6], # skipping DiffTests.num2arr_1. See https://github.com/JuliaLang/julia/issues/56193 + Mooncake.TestResources.DIFFTESTS_FUNCTIONS[8:66], # skipping sparse_ldiv + Mooncake.TestResources.DIFFTESTS_FUNCTIONS[68:89], # skipping sparse_ldiv + Mooncake.TestResources.DIFFTESTS_FUNCTIONS[91:end], # skipping sparse_ldiv + ), + ) @info "$n: $(typeof((f, x...)))" test_rule(StableRNG(123456), f, x...; is_primitive=false) end diff --git a/test/integration_testing/distributions/distributions.jl b/test/integration_testing/distributions/distributions.jl index 19d72fa37..ebf96fe80 100644 --- a/test/integration_testing/distributions/distributions.jl +++ b/test/integration_testing/distributions/distributions.jl @@ -2,16 +2,8 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using - AllocCheck, - JET, - Distributions, - FillArrays, - Mooncake, - LinearAlgebra, - PDMats, - StableRNGs, - Test +using AllocCheck, + JET, Distributions, FillArrays, Mooncake, LinearAlgebra, PDMats, StableRNGs, Test using Mooncake.TestUtils: test_rule @@ -224,8 +216,18 @@ sr(n::Int) = StableRNG(n) (:none, LKJ(5, 1.1), rand(sr(123456), LKJ(5, 1.1))), ] work_around_test_cases = Any[ - (:allocs, "InverseGamma", (a, b, x) -> logpdf(InverseGamma(a, b), x), (1.5, 1.4, 0.4)), - (:allocs, "NormalCanon", (m, s, x) -> logpdf(NormalCanon(m, s), x), (0.1, 1.0, -0.5)), + ( + :allocs, + "InverseGamma", + (a, b, x) -> logpdf(InverseGamma(a, b), x), + (1.5, 1.4, 0.4), + ), + ( + :allocs, + "NormalCanon", + (m, s, x) -> logpdf(NormalCanon(m, s), x), + (0.1, 1.0, -0.5), + ), (:none, "Categorical", x -> logpdf(Categorical(x, 1 - x), 1), 0.3), ( :none, @@ -254,10 +256,10 @@ sr(n::Int) = StableRNG(n) ( :none, "left-truncated Beta", - (a, α, β, x) -> logpdf(truncated(Beta(α, β), lower=a), x), + (a, α, β, x) -> logpdf(truncated(Beta(α, β); lower=a), x), (0.1, 1.1, 1.3, 0.4), ), - (:none, "Dirichlet", (a, x) -> logpdf(Dirichlet(a), [x, 1-x]), ([1.5, 1.1], 0.6)), + (:none, "Dirichlet", (a, x) -> logpdf(Dirichlet(a), [x, 1 - x]), ([1.5, 1.1], 0.6)), ( :none, "reshape", @@ -266,8 +268,9 @@ sr(n::Int) = StableRNG(n) ), (:none, "vec", x -> logpdf(vec(LKJ(2, 1.1)), x), ([1.0, 0.489, 0.489, 1.0],)), ( - :none, "LKJCholesky", - function(X, v) + :none, + "LKJCholesky", + function (X, v) # LKJCholesky distributes over the Cholesky factorisation of correlation # matrices, so the argument to `logpdf` must be such a matrix. S = X'X diff --git a/test/integration_testing/gp/gp.jl b/test/integration_testing/gp/gp.jl index b7b70bf6e..01358badf 100644 --- a/test/integration_testing/gp/gp.jl +++ b/test/integration_testing/gp/gp.jl @@ -32,12 +32,12 @@ using Mooncake.TestUtils: test_rule Any[(with_lengthscale(k, 1.1), x1, x2) for k in ks for (x1, x2) in xs], Any[(with_lengthscale(k, rand(rng, 2)), x1, x2) for k in ks for (x1, x2) in d_2_xs], Any[ - (k ∘ LinearTransform(randn(rng, 2, 2)), x1, x2) for - k in ks for (x1, x2) in d_2_xs + (k ∘ LinearTransform(randn(rng, 2, 2)), x1, x2) for k in ks for + (x1, x2) in d_2_xs ], Any[ - (k ∘ LinearTransform(Diagonal(randn(rng, 2))), x1, x2) for - k in ks for (x1, x2) in d_2_xs + (k ∘ LinearTransform(Diagonal(randn(rng, 2))), x1, x2) for k in ks for + (x1, x2) in d_2_xs ], ) fx = GP(k)(x1, 1.1) diff --git a/test/integration_testing/lux/lux.jl b/test/integration_testing/lux/lux.jl index 27d6e0b7c..74078f4a7 100644 --- a/test/integration_testing/lux/lux.jl +++ b/test/integration_testing/lux/lux.jl @@ -14,10 +14,22 @@ using Mooncake.TestUtils: test_rule (Scale(2), randn(Float32, 2, 3)), (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), - (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), - (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), - (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), - (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), + ( + Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), + randn(Float32, 3, 3, 2, 2), + ), + ( + Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), + rand(Float32, 5, 5, 2, 2), + ), + ( + Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), + rand(Float32, 5, 5, 2, 2), + ), + ( + Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), + rand(Float32, 5, 5, 2, 2), + ), (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), @@ -25,11 +37,29 @@ using Mooncake.TestUtils: test_rule (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), - (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), + ( + Chain( + StatefulRecurrentCell(RNNCell(3 => 5)), + StatefulRecurrentCell(RNNCell(5 => 3)), + ), + rand(Float32, 3, 2), + ), (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), - (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), + ( + Chain( + StatefulRecurrentCell(LSTMCell(3 => 5)), + StatefulRecurrentCell(LSTMCell(5 => 3)), + ), + rand(Float32, 3, 2), + ), (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), - (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), + ( + Chain( + StatefulRecurrentCell(GRUCell(3 => 5)), + StatefulRecurrentCell(GRUCell(5 => 3)), + ), + rand(Float32, 3, 10), + ), (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), @@ -39,7 +69,10 @@ using Mooncake.TestUtils: test_rule (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), - (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), + ( + Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), + randn(Float32, 4, 4, 2, 2), + ), (InstanceNorm(6), randn(Float32, 6, 6, 2, 2)), (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), diff --git a/test/integration_testing/misc_abstract_array/misc_abstract_array.jl b/test/integration_testing/misc_abstract_array/misc_abstract_array.jl index 61fd1d6cb..2e04d7a21 100644 --- a/test/integration_testing/misc_abstract_array/misc_abstract_array.jl +++ b/test/integration_testing/misc_abstract_array/misc_abstract_array.jl @@ -1,6 +1,6 @@ using Pkg Pkg.activate(@__DIR__) -Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) using LinearAlgebra, Mooncake, Random, StableRNGs, Test using Mooncake.TestUtils: test_rule @@ -13,82 +13,100 @@ using Mooncake.TestUtils: test_rule (false, setindex!, randn(5), 4.0, 3), (false, setindex!, randn(5, 4), 3.0, 1, 3), (false, x -> getglobal(Main, :sin)(x), 5.0), - (false, x -> Base.pointerref(Base.bitcast(Ptr{Float64}, pointer_from_objref(Ref(x))), 1, 1), 5.0), + ( + false, + x -> Base.pointerref( + Base.bitcast(Ptr{Float64}, pointer_from_objref(Ref(x))), 1, 1 + ), + 5.0, + ), (false, (v, x) -> (Base.pointerset(pointer(x), v, 2, 1); x), 3.0, randn(5)), (false, x -> (Base.pointerset(pointer(x), UInt8(3), 2, 1); x), rand(UInt8, 5)), (false, x -> Ref(x)[], 5.0), - (false, x -> unsafe_load(Base.bitcast(Ptr{Float64}, pointer_from_objref(Ref(x)))), 5.0), + ( + false, + x -> unsafe_load(Base.bitcast(Ptr{Float64}, pointer_from_objref(Ref(x)))), + 5.0, + ), (false, x -> unsafe_load(Base.unsafe_convert(Ptr{Float64}, x)), randn(5)), (false, view, randn(5, 4), 1, 1), (false, view, randn(5, 4), 2:3, 1), (false, view, randn(5, 4), 1, 2:3), (false, view, randn(5, 4), 2:3, 2:4), - (true, Array{Float64, 1}, undef, (1, )), - (true, Array{Float64, 2}, undef, (2, 3)), - (true, Array{Float64, 3}, undef, (2, 3, 4)), - (false, Array{Vector{Float64}, 1}, undef, (1, )), - (false, Array{Vector{Float64}, 2}, undef, (2, 3)), - (false, Array{Vector{Float64}, 3}, undef, (2, 3, 4)), + (true, Array{Float64,1}, undef, (1,)), + (true, Array{Float64,2}, undef, (2, 3)), + (true, Array{Float64,3}, undef, (2, 3, 4)), + (false, Array{Vector{Float64},1}, undef, (1,)), + (false, Array{Vector{Float64},2}, undef, (2, 3)), + (false, Array{Vector{Float64},3}, undef, (2, 3, 4)), (false, push!, randn(5), 3.0), (false, x -> (a=x, b=x), 5.0), ], - map(n -> (false, map, sin, (randn(n)..., )), 1:7), + map(n -> (false, map, sin, (randn(n)...,)), 1:7), map(n -> (false, map, sin, randn(n)), 1:7), - map(n -> (false, x -> sin.(x), (randn(n)..., )), 1:7), + map(n -> (false, x -> sin.(x), (randn(n)...,)), 1:7), map(n -> (false, x -> sin.(x), randn(n)), 1:7), - vec(map(Iterators.product( - Any[ - randn(3, 5), - transpose(randn(5, 3)), - adjoint(randn(5, 3)), - view(randn(5, 5), 1:3, 1:5), - transpose(view(randn(5, 5), 1:5, 1:3)), - adjoint(view(randn(5, 5), 1:5, 1:3)), - ], - Any[ - randn(3, 4), - transpose(randn(4, 3)), - adjoint(randn(4, 3)), - view(randn(5, 5), 1:3, 1:4), - transpose(view(randn(5, 5), 1:4, 1:3)), - adjoint(view(randn(5, 5), 1:4, 1:3)), - ], - Any[ - randn(4, 5), - transpose(randn(5, 4)), - adjoint(randn(5, 4)), - view(randn(5, 5), 1:4, 1:5), - transpose(view(randn(5, 5), 1:5, 1:4)), - adjoint(view(randn(5, 5), 1:5, 1:4)), - ], - )) do (A, B, C) - (false, mul!, A, B, C, randn(), randn()) - end), - vec(map(Iterators.product( - Any[ - LowerTriangular(randn(3, 3)), - UpperTriangular(randn(3, 3)), - UnitLowerTriangular(randn(3, 3)), - UnitUpperTriangular(randn(3, 3)), - LowerTriangular(view(randn(5, 5), 2:4, 2:4)), - UpperTriangular(view(randn(5, 5), 2:4, 2:4)), - UnitLowerTriangular(view(randn(5, 5), 2:4, 2:4)), - UnitUpperTriangular(view(randn(5, 5), 2:4, 2:4)), - ], - Any[ - LowerTriangular(randn(3, 3)), - UpperTriangular(randn(3, 3)), - UnitLowerTriangular(randn(3, 3)), - UnitUpperTriangular(randn(3, 3)), - LowerTriangular(view(randn(5, 5), 2:4, 2:4)), - UpperTriangular(view(randn(5, 5), 2:4, 2:4)), - UnitLowerTriangular(view(randn(5, 5), 2:4, 2:4)), - UnitUpperTriangular(view(randn(5, 5), 2:4, 2:4)), - ], - )) do (B, C) - A = randn(3, 3) - (false, mul!, A, B, C, randn(), randn()) - end), + vec( + map( + Iterators.product( + Any[ + randn(3, 5), + transpose(randn(5, 3)), + adjoint(randn(5, 3)), + view(randn(5, 5), 1:3, 1:5), + transpose(view(randn(5, 5), 1:5, 1:3)), + adjoint(view(randn(5, 5), 1:5, 1:3)), + ], + Any[ + randn(3, 4), + transpose(randn(4, 3)), + adjoint(randn(4, 3)), + view(randn(5, 5), 1:3, 1:4), + transpose(view(randn(5, 5), 1:4, 1:3)), + adjoint(view(randn(5, 5), 1:4, 1:3)), + ], + Any[ + randn(4, 5), + transpose(randn(5, 4)), + adjoint(randn(5, 4)), + view(randn(5, 5), 1:4, 1:5), + transpose(view(randn(5, 5), 1:5, 1:4)), + adjoint(view(randn(5, 5), 1:5, 1:4)), + ], + ), + ) do (A, B, C) + (false, mul!, A, B, C, randn(), randn()) + end, + ), + vec( + map( + Iterators.product( + Any[ + LowerTriangular(randn(3, 3)), + UpperTriangular(randn(3, 3)), + UnitLowerTriangular(randn(3, 3)), + UnitUpperTriangular(randn(3, 3)), + LowerTriangular(view(randn(5, 5), 2:4, 2:4)), + UpperTriangular(view(randn(5, 5), 2:4, 2:4)), + UnitLowerTriangular(view(randn(5, 5), 2:4, 2:4)), + UnitUpperTriangular(view(randn(5, 5), 2:4, 2:4)), + ], + Any[ + LowerTriangular(randn(3, 3)), + UpperTriangular(randn(3, 3)), + UnitLowerTriangular(randn(3, 3)), + UnitUpperTriangular(randn(3, 3)), + LowerTriangular(view(randn(5, 5), 2:4, 2:4)), + UpperTriangular(view(randn(5, 5), 2:4, 2:4)), + UnitLowerTriangular(view(randn(5, 5), 2:4, 2:4)), + UnitUpperTriangular(view(randn(5, 5), 2:4, 2:4)), + ], + ), + ) do (B, C) + A = randn(3, 3) + (false, mul!, A, B, C, randn(), randn()) + end, + ), ) @info "$(typeof((f, x...)))" test_rule(StableRNG(123456), f, x...; interface_only, is_primitive=false) diff --git a/test/integration_testing/temporalgps/temporalgps.jl b/test/integration_testing/temporalgps/temporalgps.jl index b9e244b58..a9d7e9122 100644 --- a/test/integration_testing/temporalgps/temporalgps.jl +++ b/test/integration_testing/temporalgps/temporalgps.jl @@ -10,7 +10,6 @@ build_gp(k) = to_sde(GP(k), SArrayStorage(Float64)) temporalgps_logpdf_tester(k, x, y, s) = logpdf(build_gp(k)(x, s), y) @testset "temporalgps" begin - xs = Any[ collect(range(-5.0; step=0.1, length=1_000)), RegularSpacing(0.0, 0.1, 1_000), diff --git a/test/integration_testing/turing/turing.jl b/test/integration_testing/turing/turing.jl index 227150954..02d9505ae 100644 --- a/test/integration_testing/turing/turing.jl +++ b/test/integration_testing/turing/turing.jl @@ -6,7 +6,7 @@ using Distributions, DynamicPPL, Mooncake, StableRNGs, Test using Mooncake.TestUtils: test_rule @model function simple_model() - y ~ Normal() + return y ~ Normal() end @model function demo() @@ -14,16 +14,16 @@ end σ2 ~ LogNormal() # tweaked from InverseGamma due to control flow issues. σ = sqrt(σ2 + 1e-3) μ ~ Normal(0.0, σ) - + # Observations x ~ Normal(μ, σ) - y ~ Normal(μ, σ) + return y ~ Normal(μ, σ) end @model broadcast_demo(x) = begin μ ~ truncated(Normal(1, 2), 0.1, 10) σ ~ truncated(Normal(1, 2), 0.1, 10) - x .~ LogNormal(μ, σ) + x .~ LogNormal(μ, σ) end # LDA example -- copied over from @@ -31,7 +31,7 @@ end function _make_data(D, K, V, N, α, η) β = Matrix{Float64}(undef, V, K) for k in 1:K - β[:,k] .= rand(Dirichlet(η)) + β[:, k] .= rand(Dirichlet(η)) end θ = Matrix{Float64}(undef, K, D) @@ -40,7 +40,7 @@ function _make_data(D, K, V, N, α, η) doc = Vector{Int}(undef, D * N) i = 0 for d in 1:D - θ[:,d] .= rand(Dirichlet(α)) + θ[:, d] .= rand(Dirichlet(α)) for n in 1:N i += 1 z[i] = rand(Categorical(θ[:, d])) @@ -56,9 +56,7 @@ data = let D = 2, K = 2, V = 160, N = 290 end # LDA with vectorization and manual log-density accumulation -@model function LatentDirichletAllocationVectorizedCollapsedMannual( - D, K, V, α, η, w, doc -) +@model function LatentDirichletAllocationVectorizedCollapsedMannual(D, K, V, α, η, w, doc) β ~ filldist(Dirichlet(η), K) θ ~ filldist(Dirichlet(α), D) @@ -72,9 +70,9 @@ end function make_large_model() num_tildes = 50 - expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums! + expr = Base.remove_linenums!(:(function $(Symbol(:demo, num_tildes))() end)) mainbody = last(expr.args) - append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]) + append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j in 1:num_tildes]) f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false)) return invokelatest(f) end @@ -114,7 +112,7 @@ end ], Any[ (false, "demo_$n", m, DynamicPPL.TestUtils.rand_prior_true(m)) for - (n, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) + (n, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS) ], ) @info name diff --git a/test/interpreter/abstract_interpretation.jl b/test/interpreter/abstract_interpretation.jl index e756f2788..59a2a7792 100644 --- a/test/interpreter/abstract_interpretation.jl +++ b/test/interpreter/abstract_interpretation.jl @@ -1,14 +1,17 @@ a_primitive(x) = sin(x) non_primitive(x) = sin(x) -Mooncake.is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(a_primitive), Any}}) = true -Mooncake.is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(non_primitive), Any}}) = false +Mooncake.is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(a_primitive),Any}}) = true +function Mooncake.is_primitive( + ::Type{DefaultCtx}, ::Type{<:Tuple{typeof(non_primitive),Any}} +) + return false +end contains_primitive(x) = @inline a_primitive(x) contains_non_primitive(x) = @inline non_primitive(x) contains_primitive_behind_call(x) = @inline contains_primitive(x) - @testset "abstract_interpretation" begin # Check that inlining doesn't / does happen as expected. @testset "MooncakeInterpreter" begin @@ -17,7 +20,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) # A non-primitive is present in the IR for contains_non_primitive. It is # inlined away under usual interpretation, and should also be inlined away # when doing AD. - sig = Tuple{typeof(contains_non_primitive), Float64} + sig = Tuple{typeof(contains_non_primitive),Float64} # Pre-condition: must inline away under usual compilation. usual_ir = Base.code_ircode_by_type(sig)[1][1] @@ -34,7 +37,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) # A primitive is present in the IR for contains_primitive. It is inlined away # under usual interpretation, but should not be when doing AD. - sig = Tuple{typeof(contains_primitive), Float64} + sig = Tuple{typeof(contains_primitive),Float64} # Pre-condition: must inline away under usual compilation. usual_ir = Base.code_ircode_by_type(sig)[1][1] @@ -53,7 +56,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) # usually inlined away to reveal a primitive. This primitive is _also_ usually # inlined away, but should not be when doing AD. This case is not handled if # various bits of information are not properly propagated in the compiler. - sig = Tuple{typeof(contains_primitive_behind_call), Float64} + sig = Tuple{typeof(contains_primitive_behind_call),Float64} # Pre-condition: both functions should be inlined away under usual conditions. usual_ir = Base.code_ircode_by_type(sig)[1][1] @@ -71,4 +74,4 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @test _type(CC.Const(5.0)) === Float64 @test _type(CC.PartialTypeVar(TypeVar(:a, Union{}, Any), true, true)) === TypeVar end -end \ No newline at end of file +end diff --git a/test/interpreter/bbcode.jl b/test/interpreter/bbcode.jl index e844cc57b..492df0c3b 100644 --- a/test/interpreter/bbcode.jl +++ b/test/interpreter/bbcode.jl @@ -1,5 +1,5 @@ module BBCodeTestCases - test_phi_node(x::Ref{Union{Float32, Float64}}) = sin(x[]) +test_phi_node(x::Ref{Union{Float32,Float64}}) = sin(x[]) end @testset "bbcode" begin @@ -36,18 +36,19 @@ end # Final statment is regular instruction, so newly inserted instruction should go at # the end of the block. - @test Mooncake.insert_before_terminator!(bb, ID(), new_inst(ReturnNode(5))) === nothing + @test Mooncake.insert_before_terminator!(bb, ID(), new_inst(ReturnNode(5))) === + nothing @test bb.insts[end].stmt === ReturnNode(5) # Final statement is now a Terminator, so insertion should happen before it. @test Mooncake.insert_before_terminator!(bb, ID(), new_inst(nothing)) === nothing @test bb.insts[end].stmt === ReturnNode(5) - @test bb.insts[end-1].stmt === nothing + @test bb.insts[end - 1].stmt === nothing end @testset "BBCode $f" for (f, P) in [ (TestResources.test_while_loop, Tuple{Float64}), (sin, Tuple{Float64}), - (BBCodeTestCases.test_phi_node, Tuple{Ref{Union{Float32, Float64}}}), + (BBCodeTestCases.test_phi_node, Tuple{Ref{Union{Float32,Float64}}}), ] ir = Base.code_ircode(f, P)[1][1] bb_code = BBCode(ir) @@ -61,10 +62,10 @@ end @test all(map(==, ir.stmts.line, new_ir.stmts.line)) @test all(map(==, ir.stmts.flag, new_ir.stmts.flag)) @test length(Mooncake.collect_stmts(bb_code)) == length(stmt(ir.stmts)) - @test Mooncake.id_to_line_map(bb_code) isa Dict{ID, Int} + @test Mooncake.id_to_line_map(bb_code) isa Dict{ID,Int} end @testset "control_flow_graph" begin - ir = Base.code_ircode_by_type(Tuple{typeof(sin), Float64})[1][1] + ir = Base.code_ircode_by_type(Tuple{typeof(sin),Float64})[1][1] bb = BBCode(ir) new_ir = Core.Compiler.IRCode(bb) cfg = Mooncake.control_flow_graph(bb) @@ -150,7 +151,7 @@ end @testset "_find_id_uses!" begin @testset "Expr" begin id = ID() - d = Dict{ID, Bool}(id => false) + d = Dict{ID,Bool}(id => false) Mooncake._find_id_uses!(d, Expr(:call, sin, 5)) @test d[id] == false Mooncake._find_id_uses!(d, Expr(:call, sin, id)) @@ -158,7 +159,7 @@ end end @testset "IDGotoIfNot" begin id = ID() - d = Dict{ID, Bool}(id => false) + d = Dict{ID,Bool}(id => false) Mooncake._find_id_uses!(d, IDGotoIfNot(ID(), ID())) @test d[id] == false Mooncake._find_id_uses!(d, IDGotoIfNot(true, ID())) @@ -168,13 +169,13 @@ end end @testset "IDGotoNode" begin id = ID() - d = Dict{ID, Bool}(id => false) + d = Dict{ID,Bool}(id => false) Mooncake._find_id_uses!(d, IDGotoNode(ID())) @test d[id] == false end @testset "IDPhiNode" begin id = ID() - d = Dict{ID, Bool}(id => false) + d = Dict{ID,Bool}(id => false) Mooncake._find_id_uses!(d, IDPhiNode([ID()], Vector{Any}(undef, 1))) @test d[id] == false Mooncake._find_id_uses!(d, IDPhiNode([ID()], Any[id])) @@ -182,7 +183,7 @@ end end @testset "PiNode" begin id = ID() - d = Dict{ID, Bool}(id => false) + d = Dict{ID,Bool}(id => false) Mooncake._find_id_uses!(d, PiNode(false, Bool)) @test d[id] == false Mooncake._find_id_uses!(d, PiNode(id, Bool)) @@ -190,7 +191,7 @@ end end @testset "ReturnNode" begin id = ID() - d = Dict{ID, Bool}(id => false) + d = Dict{ID,Bool}(id => false) Mooncake._find_id_uses!(d, ReturnNode()) @test d[id] == false Mooncake._find_id_uses!(d, ReturnNode(5)) @@ -203,7 +204,7 @@ end id_1 = ID() id_2 = ID() id_3 = ID() - stmts = Tuple{ID, Core.Compiler.NewInstruction}[ + stmts = Tuple{ID,Core.Compiler.NewInstruction}[ (id_1, new_inst(Expr(:call, sin, Argument(1)))), (id_2, new_inst(Expr(:call, cos, id_1))), (id_3, new_inst(ReturnNode(id_2))), @@ -261,17 +262,17 @@ end # Get the IRCode, and ensure that the statements in it agree with what is expected. new_ir = CC.IRCode(new_bb_ir) expected_stmts = Any[ - GotoNode(2), - PhiNode(Int32[1], Any[true]), - ReturnNode(SSAValue(2)), + GotoNode(2), PhiNode(Int32[1], Any[true]), ReturnNode(SSAValue(2)) ] @test Mooncake.stmt(new_ir.stmts) == expected_stmts end @testset "inc_args" begin - @test Mooncake.inc_args(Expr(:call, sin, Argument(4))) == Expr(:call, sin, Argument(5)) + @test Mooncake.inc_args(Expr(:call, sin, Argument(4))) == + Expr(:call, sin, Argument(5)) @test Mooncake.inc_args(ReturnNode(Argument(2))) == ReturnNode(Argument(3)) id = ID() - @test Mooncake.inc_args(IDGotoIfNot(Argument(1), id)) == IDGotoIfNot(Argument(2), id) + @test Mooncake.inc_args(IDGotoIfNot(Argument(1), id)) == + IDGotoIfNot(Argument(2), id) @test Mooncake.inc_args(IDGotoNode(id)) == IDGotoNode(id) ids = [id, ID()] @test ==( diff --git a/test/interpreter/contexts.jl b/test/interpreter/contexts.jl index 5f332d425..43fab68e9 100644 --- a/test/interpreter/contexts.jl +++ b/test/interpreter/contexts.jl @@ -1,3 +1 @@ -@testset "contexts" begin - -end +@testset "contexts" begin end diff --git a/test/interpreter/ir_normalisation.jl b/test/interpreter/ir_normalisation.jl index 559d75558..f78170660 100644 --- a/test/interpreter/ir_normalisation.jl +++ b/test/interpreter/ir_normalisation.jl @@ -1,9 +1,6 @@ @testset "ir_normalisation" begin @testset "interpolate_boundschecks" begin - statements = Any[ - Expr(:boundscheck, true), - Expr(:call, sin, SSAValue(1)), - ] + statements = Any[Expr(:boundscheck, true), Expr(:call, sin, SSAValue(1))] Mooncake._interpolate_boundschecks!(statements) @test statements[2].args[2] == true end @@ -19,7 +16,7 @@ 0x0000000000000001, 0x0000000000000001, ) - sp_map = Dict{Symbol, CC.VarState}() + sp_map = Dict{Symbol,CC.VarState}() call = Mooncake.foreigncall_to_call(foreigncall, sp_map) @test Meta.isexpr(call, :call) @test call.args[1] == Mooncake._foreigncall_ @@ -89,10 +86,7 @@ Expr(:call, setfield!, SSAValue(1), QuoteNode(:a), SSAValue(3)), Expr(:call, lsetfield!, SSAValue(1), Val(:a), SSAValue(3)), ), - ( - Expr(:call, sin, SSAValue(1)), - Expr(:call, sin, SSAValue(1)), - ), + (Expr(:call, sin, SSAValue(1)), Expr(:call, sin, SSAValue(1))), ] @test Mooncake.lift_getfield_and_others(ex) == target end @@ -117,14 +111,16 @@ x = FinalizerObject() ptr = convert(Ptr{Bool}, Base.pointer_from_objref(x)) GC.gc(true) - unsafe_load(ptr) + return unsafe_load(ptr) end @test test_no_preserve() # Check that if you insert a call to `gc_preserve`, the object is not finalised. function test_preserved() x = FinalizerObject() - _, pb!! = Mooncake.rrule!!(zero_fcodual(Mooncake.gc_preserve), Mooncake.zero_fcodual(x)) + _, pb!! = Mooncake.rrule!!( + zero_fcodual(Mooncake.gc_preserve), Mooncake.zero_fcodual(x) + ) ptr = convert(Ptr{Bool}, Base.pointer_from_objref(x)) GC.gc(true) return unsafe_load(ptr), pb!! @@ -134,7 +130,9 @@ # Check that translation of expressions happens correctly. @test ==( - Mooncake.lift_gc_preservation(Expr(:gc_preserve_begin, Argument(1), SSAValue(2))), + Mooncake.lift_gc_preservation( + Expr(:gc_preserve_begin, Argument(1), SSAValue(2)) + ), Expr(:call, Mooncake.gc_preserve, Argument(1), SSAValue(2)), ) @test Mooncake.lift_gc_preservation(Expr(:gc_preserve_end, SSAValue(2))) === nothing diff --git a/test/interpreter/ir_utils.jl b/test/interpreter/ir_utils.jl index ea420ce8d..5bbc382eb 100644 --- a/test/interpreter/ir_utils.jl +++ b/test/interpreter/ir_utils.jl @@ -1,18 +1,16 @@ module IRUtilsGlobalRefs - __x_1 = 5.0 - const __x_2 = 5.0 - __x_3::Float64 = 5.0 - const __x_4::Float64 = 5.0 +__x_1 = 5.0 +const __x_2 = 5.0 +__x_3::Float64 = 5.0 +const __x_4::Float64 = 5.0 end @testset "ir_utils" begin - @testset "ircode $(typeof(fargs))" for fargs in Any[ - (sin, 5.0), (cos, 1.0), - ] + @testset "ircode $(typeof(fargs))" for fargs in Any[(sin, 5.0), (cos, 1.0)] # Construct a vector of instructions from known function. f, args... = fargs insts = only(code_typed(f, _typeof(args)))[1].code - + # Use Mooncake.ircode to build an `IRCode`. argtypes = Any[map(_typeof, fargs)...] ir = Mooncake.ircode(insts, argtypes) @@ -43,7 +41,7 @@ end @test Core.OpaqueClosure(ir)(5.0) == cos(sin(5.0)) end @testset "lookup_ir" begin - tt = Tuple{typeof(sin), Float64} + tt = Tuple{typeof(sin),Float64} @test isa( Mooncake.lookup_ir(CC.NativeInterpreter(), tt; optimize_until=nothing)[1], CC.IRCode, diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 40b8be7d5..c3514d6dc 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -1,23 +1,23 @@ module S2SGlobals - using LinearAlgebra, Mooncake +using LinearAlgebra, Mooncake - non_const_global = 5.0 - const const_float = 5.0 - const const_int = 5 - const const_bool = true +non_const_global = 5.0 +const const_float = 5.0 +const const_int = 5 +const const_bool = true - # used for regression test for issue 184 - struct A - data - end - f(a, x) = dot(a.data, x) +# used for regression test for issue 184 +struct A + data +end +f(a, x) = dot(a.data, x) - # Test cases designed to cause `LazyDerivedRule` to throw an error when attempting to - # construct a rule for `bar`. - foo(x) = x - @noinline bar(x) = foo(x) - baz(x) = bar(x) - Mooncake.@is_primitive Mooncake.MinimalCtx Tuple{typeof(foo), Any} +# Test cases designed to cause `LazyDerivedRule` to throw an error when attempting to +# construct a rule for `bar`. +foo(x) = x +@noinline bar(x) = foo(x) +baz(x) = bar(x) +Mooncake.@is_primitive Mooncake.MinimalCtx Tuple{typeof(foo),Any} end @testset "s2s_reverse_mode_ad" begin @@ -29,16 +29,18 @@ end @test m.pairs[1][2] == 5.0 end @testset "ADInfo" begin - arg_types = Dict{Argument, Any}(Argument(1) => Float64, Argument(2) => Int) + arg_types = Dict{Argument,Any}(Argument(1) => Float64, Argument(2) => Int) id_ssa_1 = ID() id_ssa_2 = ID() - ssa_insts = Dict{ID, CC.NewInstruction}( + ssa_insts = Dict{ID,CC.NewInstruction}( id_ssa_1 => CC.NewInstruction(nothing, Float64), id_ssa_2 => CC.NewInstruction(nothing, Any), ) - is_used_dict = Dict{ID, Bool}(id_ssa_1 => true, id_ssa_2 => true) + is_used_dict = Dict{ID,Bool}(id_ssa_1 => true, id_ssa_2 => true) rdata_ref = Ref{Tuple{map(Mooncake.lazy_zero_rdata_type, (Float64, Int))...}}() - info = ADInfo(get_interpreter(), arg_types, ssa_insts, is_used_dict, false, rdata_ref) + info = ADInfo( + get_interpreter(), arg_types, ssa_insts, is_used_dict, false, rdata_ref + ) # Verify that we can access the interpreter and terminator block ID. @test info.interp isa Mooncake.MooncakeInterpreter @@ -70,12 +72,12 @@ end id_line_2 = ID() info = ADInfo( get_interpreter(), - Dict{Argument, Any}(Argument(1) => typeof(sin), Argument(2) => Float64), - Dict{ID, CC.NewInstruction}( + Dict{Argument,Any}(Argument(1) => typeof(sin), Argument(2) => Float64), + Dict{ID,CC.NewInstruction}( id_line_1 => new_inst(Expr(:invoke, nothing, cos, Argument(2)), Float64), id_line_2 => new_inst(nothing, Any), ), - Dict{ID, Bool}(id_line_1=>true, id_line_2=>true), + Dict{ID,Bool}(id_line_1 => true, id_line_2 => true), false, Ref{Tuple{map(Mooncake.lazy_zero_rdata_type, (typeof(sin), Float64))...}}(), ) @@ -153,7 +155,7 @@ end @test last(stmt_info.fwds)[1] == id_line_1 fwds_stmt = last(stmt_info.fwds)[2].stmt @test fwds_stmt isa PiNode - @test fwds_stmt.typ == CoDual{Nothing, NoFData} + @test fwds_stmt.typ == CoDual{Nothing,NoFData} @test only(stmt_info.rvs)[2].stmt === nothing end @testset "π (nothing, CC.Const(nothing))" begin @@ -163,7 +165,7 @@ end @test last(stmt_info.fwds)[1] == id_line_1 fwds_stmt = last(stmt_info.fwds)[2].stmt @test fwds_stmt isa PiNode - @test fwds_stmt.typ == CoDual{Nothing, NoFData} + @test fwds_stmt.typ == CoDual{Nothing,NoFData} @test only(stmt_info.rvs)[2].stmt === nothing end @testset "π (GlobalRef, Type)" begin @@ -233,9 +235,7 @@ end ad_stmt_info(line, nothing, fwds, nothing), ) end - @testset "$stmt" for stmt in [ - Expr(:gc_preserve_begin), - ] + @testset "$stmt" for stmt in [Expr(:gc_preserve_begin)] line = ID() @test TestUtils.has_equal_data( make_ad_stmts!(stmt, line, info), @@ -244,12 +244,11 @@ end end end end - @testset "rule_type $sig, $debug_mode" for - sig in Any[ - Tuple{typeof(getfield), Tuple{Float64}, 1}, - Tuple{typeof(TestResources.foo), Float64}, - Tuple{typeof(TestResources.type_unstable_tester_0), Ref{Any}}, - Tuple{typeof(TestResources.tuple_with_union), Bool}, + @testset "rule_type $sig, $debug_mode" for sig in Any[ + Tuple{typeof(getfield),Tuple{Float64},1}, + Tuple{typeof(TestResources.foo),Float64}, + Tuple{typeof(TestResources.type_unstable_tester_0),Ref{Any}}, + Tuple{typeof(TestResources.tuple_with_union),Bool}, ], debug_mode in [true, false] @@ -270,16 +269,22 @@ end @test_throws(Mooncake.MooncakeRuleCompilationError, Mooncake.build_rrule(sin)) end @testset "$(_typeof((f, x...)))" for (n, (interface_only, perf_flag, bnds, f, x...)) in - collect(enumerate(TestResources.generate_test_functions())) - + collect( + enumerate(TestResources.generate_test_functions()) + ) sig = _typeof((f, x...)) @info "$n: $sig" TestUtils.test_rule( Xoshiro(123456), f, x...; perf_flag, interface_only, is_primitive=false ) TestUtils.test_rule( - Xoshiro(123456), f, x...; - perf_flag=:none, interface_only, is_primitive=false, debug_mode=true, + Xoshiro(123456), + f, + x...; + perf_flag=:none, + interface_only, + is_primitive=false, + debug_mode=true, ) # interp = Mooncake.get_interpreter() @@ -307,7 +312,7 @@ end @test_throws( Mooncake.Mooncake.MooncakeRuleCompilationError, Mooncake.build_rrule( - Tuple{typeof(Mooncake.TestResources.non_const_global_ref), Float64}, + Tuple{typeof(Mooncake.TestResources.non_const_global_ref),Float64} ) ) end @@ -317,8 +322,12 @@ end # 184 TestUtils.test_rule( - Xoshiro(123456), S2SGlobals.f, S2SGlobals.A(2 * ones(3)), ones(3); - interface_only=false, is_primitive=false, + Xoshiro(123456), + S2SGlobals.f, + S2SGlobals.A(2 * ones(3)), + ones(3); + interface_only=false, + is_primitive=false, ) # BenchmarkTools not working due to world age problems. Provided that this code diff --git a/test/rrules/foreigncall.jl b/test/rrules/foreigncall.jl index 62196d90c..73ad7e499 100644 --- a/test/rrules/foreigncall.jl +++ b/test/rrules/foreigncall.jl @@ -2,11 +2,28 @@ TestUtils.run_rrule!!_test_cases(StableRNG, Val(:foreigncall)) @testset "foreigncalls that should never be hit: $name" for name in [ - :jl_alloc_array_1d, :jl_alloc_array_2d, :jl_alloc_array_3d, :jl_new_array, - :jl_array_copy, :jl_type_intersection, :memset, :jl_get_tls_world_age, :memmove, - :jl_object_id, :jl_array_sizehint, :jl_array_grow_beg, :jl_array_grow_end, - :jl_array_grow_at, :jl_array_del_beg, :jl_array_del_end, :jl_array_del_at, - :jl_value_ptr, :jl_threadid, :memhash_seed, :memhash32_seed, :jl_get_field_offset, + :jl_alloc_array_1d, + :jl_alloc_array_2d, + :jl_alloc_array_3d, + :jl_new_array, + :jl_array_copy, + :jl_type_intersection, + :memset, + :jl_get_tls_world_age, + :memmove, + :jl_object_id, + :jl_array_sizehint, + :jl_array_grow_beg, + :jl_array_grow_end, + :jl_array_grow_at, + :jl_array_del_beg, + :jl_array_del_end, + :jl_array_del_at, + :jl_value_ptr, + :jl_threadid, + :memhash_seed, + :memhash32_seed, + :jl_get_field_offset, ] @test_throws( ErrorException, diff --git a/test/rrules/function_wrappers.jl b/test/rrules/function_wrappers.jl index 8289a649b..3946e706d 100644 --- a/test/rrules/function_wrappers.jl +++ b/test/rrules/function_wrappers.jl @@ -2,8 +2,8 @@ rng = Xoshiro(123) _data = Ref{Float64}(5.0) @testset "$p" for p in Any[ - FunctionWrapper{Float64, Tuple{Float64}}(sin), - FunctionWrapper{Float64, Tuple{Float64}}(x -> x * _data[]), + FunctionWrapper{Float64,Tuple{Float64}}(sin), + FunctionWrapper{Float64,Tuple{Float64}}(x -> x * _data[]), ] TestUtils.test_tangent_consistency(rng, p) TestUtils.test_fwds_rvs_data(rng, p) diff --git a/test/rrules/low_level_maths.jl b/test/rrules/low_level_maths.jl index ce6834a41..1b5a13f95 100644 --- a/test/rrules/low_level_maths.jl +++ b/test/rrules/low_level_maths.jl @@ -5,16 +5,16 @@ # because they are very shallow wrappers around lower-level primitives for which we # already have rules. @testset "$T, $C" for T in [Float16, Float32, Float64], C in [DefaultCtx, MinimalCtx] - @test !is_primitive(C, Tuple{typeof(+), T}) - @test !is_primitive(C, Tuple{typeof(-), T}) - @test !is_primitive(C, Tuple{typeof(abs2), T}) - @test !is_primitive(C, Tuple{typeof(inv), T}) - @test !is_primitive(C, Tuple{typeof(abs), T}) + @test !is_primitive(C, Tuple{typeof(+),T}) + @test !is_primitive(C, Tuple{typeof(-),T}) + @test !is_primitive(C, Tuple{typeof(abs2),T}) + @test !is_primitive(C, Tuple{typeof(inv),T}) + @test !is_primitive(C, Tuple{typeof(abs),T}) - @test !is_primitive(C, Tuple{typeof(+), T, T}) - @test !is_primitive(C, Tuple{typeof(-), T, T}) - @test !is_primitive(C, Tuple{typeof(*), T, T}) - @test !is_primitive(C, Tuple{typeof(/), T, T}) - @test !is_primitive(C, Tuple{typeof(\), T, T}) + @test !is_primitive(C, Tuple{typeof(+),T,T}) + @test !is_primitive(C, Tuple{typeof(-),T,T}) + @test !is_primitive(C, Tuple{typeof(*),T,T}) + @test !is_primitive(C, Tuple{typeof(/),T,T}) + @test !is_primitive(C, Tuple{typeof(\),T,T}) end -end +end diff --git a/test/rrules/misc.jl b/test/rrules/misc.jl index 1257b389f..8f2c4a88a 100644 --- a/test/rrules/misc.jl +++ b/test/rrules/misc.jl @@ -1,5 +1,4 @@ @testset "misc" begin - @testset "misc utility" begin x = randn(4, 5) p = Base.unsafe_convert(Ptr{Float64}, x) diff --git a/test/runtests.jl b/test/runtests.jl index 7a2bd76d5..f28eadd7d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,7 +29,7 @@ include("front_matter.jl") include(joinpath("rrules", "avoiding_non_differentiable_code.jl")) elseif test_group == "rrules/blas" include(joinpath("rrules", "blas.jl")) - elseif test_group == "rrules/builtins" + elseif test_group == "rrules/builtins" include(joinpath("rrules", "builtins.jl")) elseif test_group == "rrules/fastmath" include(joinpath("rrules", "fastmath.jl")) diff --git a/test/tangents.jl b/test/tangents.jl index c097cf771..8803106dd 100644 --- a/test/tangents.jl +++ b/test/tangents.jl @@ -1,71 +1,73 @@ @testset "tangents" begin - @testset "tangent_type($primal_type)" for (primal_type, expected_tangent_type) in Any[ ## Tuples # Unions of Tuples. - (Union{Tuple{Float64}, Tuple{Float32}}, Union{Tuple{Float64}, Tuple{Float32}}), - (Union{Tuple{Float64}, Tuple{Int}}, Union{Tuple{Float64}, NoTangent}), - (Union{Tuple{}, Tuple{Int}}, NoTangent), + (Union{Tuple{Float64},Tuple{Float32}}, Union{Tuple{Float64},Tuple{Float32}}), + (Union{Tuple{Float64},Tuple{Int}}, Union{Tuple{Float64},NoTangent}), + (Union{Tuple{},Tuple{Int}}, NoTangent), ( - Union{Tuple{Float64}, Tuple{Int}, Tuple{Float64, Int}}, - Union{Tuple{Float64}, NoTangent, Tuple{Float64, NoTangent}}, + Union{Tuple{Float64},Tuple{Int},Tuple{Float64,Int}}, + Union{Tuple{Float64},NoTangent,Tuple{Float64,NoTangent}}, ), - (Union{Tuple{Float64}, Tuple{Any}}, Union{NoTangent, Tuple{Any}}), + (Union{Tuple{Float64},Tuple{Any}}, Union{NoTangent,Tuple{Any}}), # UnionAlls of Tuples. - (Tuple{T} where {T}, Union{NoTangent, Tuple{Any}}), - (Tuple{T, T} where {T<:Real}, Union{NoTangent, Tuple{Any, Any}}), - (Tuple{Float64, T} where {T<:Int}, Tuple{Float64, NoTangent}), - (Union{Tuple{T}, Tuple{T, T}} where {T<:Real}, Any), + (Tuple{T} where {T}, Union{NoTangent,Tuple{Any}}), + (Tuple{T,T} where {T<:Real}, Union{NoTangent,Tuple{Any,Any}}), + (Tuple{Float64,T} where {T<:Int}, Tuple{Float64,NoTangent}), + (Union{Tuple{T},Tuple{T,T}} where {T<:Real}, Any), # Edge case: (provably) empty Tuple. (Tuple{}, NoTangent), # Vararg Tuples (Tuple, Any), - (Tuple{Float64, Vararg}, Any), - (Tuple{Float64, Vararg{Int}}, Any), + (Tuple{Float64,Vararg}, Any), + (Tuple{Float64,Vararg{Int}}, Any), (Tuple{Vararg{Int}}, Any), - (Tuple{Int, Vararg{Int}}, Any), + (Tuple{Int,Vararg{Int}}, Any), # Simple Tuples. (Tuple{Int}, NoTangent), - (Tuple{Vararg{Int, 250}}, NoTangent), - (Tuple{Int, Int}, NoTangent), - (Tuple{DataType, Int}, NoTangent), - (Tuple{DataType, Vararg{Int, 100}}, NoTangent), - (Tuple{DataType, Type{Float64}}, NoTangent), - (Tuple{DataType, Vararg{Type{Float64}, 100}}, NoTangent), - (Tuple{Any}, Union{NoTangent, Tuple{Any}}), - (Tuple{Any, Any}, Union{NoTangent, Tuple{Any, Any}}), - (Tuple{Int, Any}, Union{NoTangent, Tuple{NoTangent, Any}}), - (Tuple{Int, Float64}, Tuple{NoTangent, Float64}), - (Tuple{Int, Vararg{Float64, 100}}, Tuple{NoTangent, Vararg{Float64, 100}}), - (Tuple{Type{Float64}, Float64}, Tuple{NoTangent, Float64}), - (Tuple{DataType, Vararg{Float32, 100}}, Tuple{NoTangent, Vararg{Float32, 100}}), - (Tuple{Tuple{Type{Int}}, Float64}, Tuple{NoTangent, Float64}), + (Tuple{Vararg{Int,250}}, NoTangent), + (Tuple{Int,Int}, NoTangent), + (Tuple{DataType,Int}, NoTangent), + (Tuple{DataType,Vararg{Int,100}}, NoTangent), + (Tuple{DataType,Type{Float64}}, NoTangent), + (Tuple{DataType,Vararg{Type{Float64},100}}, NoTangent), + (Tuple{Any}, Union{NoTangent,Tuple{Any}}), + (Tuple{Any,Any}, Union{NoTangent,Tuple{Any,Any}}), + (Tuple{Int,Any}, Union{NoTangent,Tuple{NoTangent,Any}}), + (Tuple{Int,Float64}, Tuple{NoTangent,Float64}), + (Tuple{Int,Vararg{Float64,100}}, Tuple{NoTangent,Vararg{Float64,100}}), + (Tuple{Type{Float64},Float64}, Tuple{NoTangent,Float64}), + (Tuple{DataType,Vararg{Float32,100}}, Tuple{NoTangent,Vararg{Float32,100}}), + (Tuple{Tuple{Type{Int}},Float64}, Tuple{NoTangent,Float64}), ## NamedTuple # Unions of NamedTuples. ( - Union{@NamedTuple{a::Float64}, @NamedTuple{b::Float64}}, - Union{@NamedTuple{a::Float64}, @NamedTuple{b::Float64}}, + Union{@NamedTuple{a::Float64},@NamedTuple{b::Float64}}, + Union{@NamedTuple{a::Float64},@NamedTuple{b::Float64}}, ), ( - Union{@NamedTuple{a::Float64}, @NamedTuple{}}, - Union{@NamedTuple{a::Float64}, NoTangent}, + Union{@NamedTuple{a::Float64},@NamedTuple{}}, + Union{@NamedTuple{a::Float64},NoTangent}, ), - (Union{@NamedTuple{a::Float64}, @NamedTuple{a::Any}}, Any), + (Union{@NamedTuple{a::Float64},@NamedTuple{a::Any}}, Any), # UnionAlls of NamedTuples. - (@NamedTuple{a::T} where {T}, Union{NoTangent, NamedTuple{(:a,)}}), - (@NamedTuple{a::T, b::T} where {T<:Real}, Union{NoTangent, NamedTuple{(:a, :b)}}), - (@NamedTuple{a::Float64, b::T} where {T<:Int}, Union{NoTangent, NamedTuple{(:a, :b)}}), - (Union{@NamedTuple{a::T}, @NamedTuple{b::T, c::T}} where {T<:Any}, Any), - (Union{@NamedTuple{T, Float64}, @NamedTuple{T, Float64, Int}} where {T}, Any), + (@NamedTuple{a::T} where {T}, Union{NoTangent,NamedTuple{(:a,)}}), + (@NamedTuple{a::T, b::T} where {T<:Real}, Union{NoTangent,NamedTuple{(:a, :b)}}), + ( + @NamedTuple{a::Float64, b::T} where {T<:Int}, + Union{NoTangent,NamedTuple{(:a, :b)}}, + ), + (Union{@NamedTuple{a::T},@NamedTuple{b::T, c::T}} where {T<:Any}, Any), + (Union{@NamedTuple{T, Float64},@NamedTuple{T, Float64, Int}} where {T}, Any), # Edge case (@NamedTuple{}, NoTangent), @@ -80,12 +82,16 @@ (@NamedTuple{a::Int, b::Any}, Any), (@NamedTuple{b::Int, a::Float64}, @NamedTuple{b::NoTangent, a::Float64}), (@NamedTuple{a::Type{Float64}, b::Float64}, @NamedTuple{a::NoTangent, b::Float64}), - (@NamedTuple{a::Tuple{Type{Int}}, b::Float64}, @NamedTuple{a::NoTangent, b::Float64}), + ( + @NamedTuple{a::Tuple{Type{Int}}, b::Float64}, + @NamedTuple{a::NoTangent, b::Float64} + ), ] TestUtils.test_tangent_type(primal_type, expected_tangent_type) end - @testset "$(typeof(data))" for (interface_only, data...) in Mooncake.tangent_test_cases() + @testset "$(typeof(data))" for (interface_only, data...) in + Mooncake.tangent_test_cases() test_tangent(Xoshiro(123456), data...; interface_only) end @@ -166,7 +172,7 @@ end @testset "restricted inner constructor" begin p = TestResources.NoDefaultCtor(5.0) - t = Mooncake.Tangent((x=5.0, )) + t = Mooncake.Tangent((x=5.0,)) @test_throws Mooncake.AddToPrimalException Mooncake._add_to_primal(p, t) @test Mooncake._add_to_primal(p, t, true) isa typeof(p) end @@ -186,7 +192,9 @@ end bar = Mooncake.TestResources.TypeUnstableStruct(5.0, 1.0) @test ==( Mooncake.zero_tangent(bar), - Tangent{@NamedTuple{a::Float64, b::PossiblyUninitTangent{Any}}}((a=0.0, b=PossiblyUninitTangent{Any}(0.0))) + Tangent{@NamedTuple{a::Float64, b::PossiblyUninitTangent{Any}}}(( + a=0.0, b=PossiblyUninitTangent{Any}(0.0) + )), ) end @@ -269,7 +277,6 @@ end # return filter(!isabstracttype, types_in(m)) # end - # # Primitives are required to explicitly declare a method of `zero_tangent` which applies # # to them. They must not hit the generic fallback. This function checks that there are no # # primitives within the specified module which don't hit a generic fallback. @@ -277,7 +284,6 @@ end # return filter(t -> isprimitivetype(t), types_in(m)) # end - # # A toy type on which to test tangent stuff in a variety of situations. # struct Foo{T, V} # x::T diff --git a/test/test_utils.jl b/test/test_utils.jl index b0e4c54d7..688a3dc96 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -22,18 +22,44 @@ @test has_equal_data(Diagonal(ones(5)), Diagonal(ones(5))) @test has_equal_data("hello", "hello") @test !has_equal_data("hello", "goodbye") - @test has_equal_data(TypeUnstableMutableStruct(4.0, 5), TypeUnstableMutableStruct(4.0, 5)) - @test !has_equal_data(TypeUnstableMutableStruct(4.0, 5), TypeUnstableMutableStruct(4.0, 6)) + @test has_equal_data( + TypeUnstableMutableStruct(4.0, 5), TypeUnstableMutableStruct(4.0, 5) + ) + @test !has_equal_data( + TypeUnstableMutableStruct(4.0, 5), TypeUnstableMutableStruct(4.0, 6) + ) @test has_equal_data(TypeUnstableStruct(4.0, 5), TypeUnstableStruct(4.0, 5)) @test !has_equal_data(TypeUnstableStruct(0.0), TypeUnstableStruct(4.0)) - @test has_equal_data(make_circular_reference_struct(), make_circular_reference_struct()) - @test has_equal_data(make_indirect_circular_reference_struct(), make_indirect_circular_reference_struct()) - @test has_equal_data(make_circular_reference_array(), make_circular_reference_array()) - @test has_equal_data(make_indirect_circular_reference_array(), make_indirect_circular_reference_array()) - @test !has_equal_data((s = make_circular_reference_struct(); s.a = 1.0; s), (t = make_circular_reference_struct(); t.a = 2.0; t)) - @test !has_equal_data((a = make_indirect_circular_reference_array(); a[1][1] = 1.0; a), (b = make_indirect_circular_reference_array(); b[1][1] = 2.0; b)) - @test !has_equal_data((s = make_indirect_circular_reference_struct(); s.b.a = 1.0; s), (t = make_indirect_circular_reference_struct(); t.b.a = 2.0; t)) - @test !has_equal_data((a = make_indirect_circular_reference_array(); a[1][1] = 1.0; a), (b = make_indirect_circular_reference_array(); b[1][1] = 2.0; b)) + @test has_equal_data( + make_circular_reference_struct(), make_circular_reference_struct() + ) + @test has_equal_data( + make_indirect_circular_reference_struct(), + make_indirect_circular_reference_struct(), + ) + @test has_equal_data( + make_circular_reference_array(), make_circular_reference_array() + ) + @test has_equal_data( + make_indirect_circular_reference_array(), + make_indirect_circular_reference_array(), + ) + @test !has_equal_data( + (s = make_circular_reference_struct(); s.a = 1.0; s), + (t = make_circular_reference_struct(); t.a = 2.0; t), + ) + @test !has_equal_data( + (a = make_indirect_circular_reference_array(); a[1][1] = 1.0; a), + (b = make_indirect_circular_reference_array(); b[1][1] = 2.0; b), + ) + @test !has_equal_data( + (s = make_indirect_circular_reference_struct(); s.b.a = 1.0; s), + (t = make_indirect_circular_reference_struct(); t.b.a = 2.0; t), + ) + @test !has_equal_data( + (a = make_indirect_circular_reference_array(); a[1][1] = 1.0; a), + (b = make_indirect_circular_reference_array(); b[1][1] = 2.0; b), + ) end @testset "populate_address_map" begin @testset "primitive types" begin @@ -56,7 +82,9 @@ p2 = [p, p] t2 = [t, t] @test length(populate_address_map(p2, t2)) == 4 - @test_throws AssertionError populate_address_map(p2, [zero_tangent(p), zero_tangent(p)]) + @test_throws AssertionError populate_address_map( + p2, [zero_tangent(p), zero_tangent(p)] + ) end @testset "immutable type" begin p = TestResources.StructFoo(5.0, randn(2)) @@ -75,11 +103,15 @@ @test length(populate_address_map(p, t)) == 1 p2 = (p, p) - @test_throws AssertionError populate_address_map(p2, (zero_tangent(p), zero_tangent(p))) + @test_throws AssertionError populate_address_map( + p2, (zero_tangent(p), zero_tangent(p)) + ) p = TestResources.MutableFoo(5.0, randn(2)) p2 = (p, p) - @test_throws AssertionError populate_address_map(p2, (zero_tangent(p), zero_tangent(p))) + @test_throws AssertionError populate_address_map( + p2, (zero_tangent(p), zero_tangent(p)) + ) end @testset "views" begin p = view(randn(5, 4), 1:2, 1:3) diff --git a/test/tools_for_rules.jl b/test/tools_for_rules.jl index 9c18e1a1f..2b7a2af9b 100644 --- a/test/tools_for_rules.jl +++ b/test/tools_for_rules.jl @@ -8,10 +8,10 @@ overlay_tester(x) = 2x @mooncake_overlay overlay_tester(x) = 3x zero_tester(x) = 0 -@zero_adjoint MinimalCtx Tuple{typeof(zero_tester), Float64} +@zero_adjoint MinimalCtx Tuple{typeof(zero_tester),Float64} vararg_zero_tester(x...) = 0 -@zero_adjoint MinimalCtx Tuple{typeof(vararg_zero_tester), Vararg} +@zero_adjoint MinimalCtx Tuple{typeof(vararg_zero_tester),Vararg} # Test case with isbits data. @@ -21,7 +21,7 @@ function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) return x * y, dz -> (ChainRulesCore.NoTangent(), dz * y, ChainRulesCore.NoTangent()) end -@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} false +@from_rrule DefaultCtx Tuple{typeof(bleh),Float64,Int} false # Test case with heap-allocated input. @@ -32,7 +32,7 @@ function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) return test_sum(x), test_sum_pb end -@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} false +@from_rrule DefaultCtx Tuple{typeof(test_sum),Array{<:Base.IEEEFloat}} false # Test case with heap-allocated output. @@ -46,7 +46,7 @@ function ChainRulesCore.rrule(::typeof(test_scale), x::Real, y::AbstractVector{< end @from_rrule( - DefaultCtx, Tuple{typeof(test_scale), Base.IEEEFloat, Vector{<:Base.IEEEFloat}}, false + DefaultCtx, Tuple{typeof(test_scale),Base.IEEEFloat,Vector{<:Base.IEEEFloat}}, false ) # Test case with non-differentiable type as output. @@ -71,7 +71,7 @@ function ChainRulesCore.rrule(::typeof(test_bad_rdata), x::Float64) return 5x, test_bad_rdata_pb end -@from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} false +@from_rrule DefaultCtx Tuple{typeof(test_bad_rdata),Float64} false # Test case for rule with diagonal dispatch. test_add(x, y) = x + y @@ -79,7 +79,7 @@ function ChainRulesCore.rrule(::typeof(test_add), x, y) test_add_pb(dout) = ChainRulesCore.NoTangent(), dout, dout return x + y, test_add_pb end -@from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat} false +@from_rrule DefaultCtx Tuple{typeof(test_add),T,T} where {T<:IEEEFloat} false # Test case for rule with non-differentiable kwargs. test_kwargs(x; y::Bool=false) = y ? x : 2x @@ -89,7 +89,7 @@ function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool=false) return y ? x : 2x, test_kwargs_pb end -@from_rrule(DefaultCtx, Tuple{typeof(test_kwargs), Float64}, true) +@from_rrule(DefaultCtx, Tuple{typeof(test_kwargs),Float64}, true) # Test case for rule with differentiable types used in a non-differentiable way. test_kwargs_conditional(x; y::Float64=1.0) = y > 0 ? x : 2x @@ -99,27 +99,34 @@ function ChainRulesCore.rrule(::typeof(test_kwargs_conditional), x::Float64; y:: return y > 0 ? x : 2x, test_kwargs_cond_pb end -@from_rrule(DefaultCtx, Tuple{typeof(test_kwargs_conditional), Float64}, true) +@from_rrule(DefaultCtx, Tuple{typeof(test_kwargs_conditional),Float64}, true) end @testset "tools_for_rules" begin @testset "mooncake_overlay" begin f = ToolsForRulesResources.overlay_tester - rule = Mooncake.build_rrule(Tuple{typeof(f), Float64}) + rule = Mooncake.build_rrule(Tuple{typeof(f),Float64}) @test value_and_gradient!!(rule, f, 5.0) == (15.0, (NoTangent(), 3.0)) end @testset "zero_adjoint" begin f_zero = ToolsForRulesResources test_rule( - sr(123), ToolsForRulesResources.zero_tester, 5.0; - is_primitive=true, perf_flag=:stability_and_allocs, + sr(123), + ToolsForRulesResources.zero_tester, + 5.0; + is_primitive=true, + perf_flag=:stability_and_allocs, ) test_rule( - sr(123), ToolsForRulesResources.vararg_zero_tester, 5.0, 4.0; - is_primitive=true, perf_flag=:stability_and_allocs, + sr(123), + ToolsForRulesResources.vararg_zero_tester, + 5.0, + 4.0; + is_primitive=true, + perf_flag=:stability_and_allocs, ) - end + end @testset "chain_rules_macro" begin @testset "to_cr_tangent" for (t, t_cr) in Any[ (5.0, 5.0), @@ -143,11 +150,11 @@ end (ToolsForRulesResources.test_sum, ones(5)), (ToolsForRulesResources.test_scale, 5.0, randn(3)), (ToolsForRulesResources.test_nothing,), - (Core.kwcall, (y=true, ), ToolsForRulesResources.test_kwargs, 5.0), - (Core.kwcall, (y=false, ), ToolsForRulesResources.test_kwargs, 5.0), + (Core.kwcall, (y=true,), ToolsForRulesResources.test_kwargs, 5.0), + (Core.kwcall, (y=false,), ToolsForRulesResources.test_kwargs, 5.0), (ToolsForRulesResources.test_kwargs, 5.0), - (Core.kwcall, (y=-1.0, ), ToolsForRulesResources.test_kwargs_conditional, 5.0), - (Core.kwcall, (y=1.0, ), ToolsForRulesResources.test_kwargs_conditional, 5.0), + (Core.kwcall, (y=-1.0,), ToolsForRulesResources.test_kwargs_conditional, 5.0), + (Core.kwcall, (y=1.0,), ToolsForRulesResources.test_kwargs_conditional, 5.0), (ToolsForRulesResources.test_kwargs_conditional, 5.0), ] test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) diff --git a/test/utils.jl b/test/utils.jl index 835386a39..7305d8e0b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -5,7 +5,7 @@ @test _typeof(Float64) == Type{Float64} @test _typeof(Vector{Int}) == Type{Vector{Int}} @test _typeof(Vector{T} where {T}) == Type{Vector} - @test _typeof((5.0, Float64)) == Tuple{Float64, Type{Float64}} + @test _typeof((5.0, Float64)) == Tuple{Float64,Type{Float64}} @test _typeof((a=5.0, b=Float64)) == @NamedTuple{a::Float64, b::Type{Float64}} end @testset "tuple_map" begin @@ -22,8 +22,8 @@ ) # Require that length of arguments are equal. - @test_throws ArgumentError Mooncake.tuple_map(*, (5.0, 4.0), (4.0, )) - @test_throws ArgumentError Mooncake.tuple_map(*, (4.0, ), (5.0, 4.0)) + @test_throws ArgumentError Mooncake.tuple_map(*, (5.0, 4.0), (4.0,)) + @test_throws ArgumentError Mooncake.tuple_map(*, (4.0,), (5.0, 4.0)) end @testset "_map_if_assigned!" begin @testset "unary bits type" begin @@ -74,7 +74,7 @@ y = randn(10) @test Mooncake._map(*, x, y) == map(*, x, y) @assert length(map(*, x, randn(11))) == 10 - @test_throws AssertionError Mooncake._map(*, x, randn(11)) + @test_throws AssertionError Mooncake._map(*, x, randn(11)) end @testset "is_always_initialised" begin @test Mooncake.is_always_initialised(TestResources.StructFoo, 1)