Skip to content

Commit

Permalink
Apply formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Nov 24, 2024
1 parent a47709d commit e89a038
Show file tree
Hide file tree
Showing 77 changed files with 3,100 additions and 2,237 deletions.
65 changes: 38 additions & 27 deletions bench/run_benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
using Pkg
Pkg.develop(path=joinpath(@__DIR__, ".."))
Pkg.develop(; path=joinpath(@__DIR__, ".."))

using
AbstractGPs,
using AbstractGPs,
Chairmarks,
CSV,
DataFrames,
Expand All @@ -28,13 +27,13 @@ using Mooncake:

using Mooncake.TestUtils: _deepcopy

function to_benchmark(__rrule!!::R, dx::Vararg{CoDual, N}) where {R, N}
function to_benchmark(__rrule!!::R, dx::Vararg{CoDual,N}) where {R,N}
dx_f = Mooncake.tuple_map(x -> CoDual(primal(x), Mooncake.fdata(tangent(x))), dx)
out, pb!! = __rrule!!(dx_f...)
return pb!!(Mooncake.zero_rdata(primal(out)))
end

function zygote_to_benchmark(ctx, x::Vararg{Any, N}) where {N}
function zygote_to_benchmark(ctx, x::Vararg{Any,N}) where {N}
out, pb = Zygote._pullback(ctx, x...)
return pb(out)
end
Expand Down Expand Up @@ -107,7 +106,7 @@ end
@model broadcast_demo(x) = begin
μ ~ truncated(Normal(1, 2), 0.1, 10)
σ ~ truncated(Normal(1, 2), 0.1, 10)
x .~ LogNormal(μ, σ)
x .~ LogNormal(μ, σ)
end

function build_turing_problem()
Expand All @@ -122,17 +121,21 @@ function build_turing_problem()
return test_function, randn(rng, d)
end

run_turing_problem(f::F, x::X) where {F, X} = f(x)
run_turing_problem(f::F, x::X) where {F,X} = f(x)

should_run_benchmark(
function should_run_benchmark(
::Val{:zygote}, ::Base.Fix1{<:typeof(DynamicPPL.LogDensityProblems.logdensity)}, x...
) = false
should_run_benchmark(
)
return false
end
function should_run_benchmark(
::Val{:enzyme}, ::Base.Fix1{<:typeof(DynamicPPL.LogDensityProblems.logdensity)}, x...
) = false
)
return false
end
should_run_benchmark(::Val{:enzyme}, x...) = false

@inline g(x, a, ::Val{N}) where {N} = N > 0 ? g(x * a, a, Val(N-1)) : x
@inline g(x, a, ::Val{N}) where {N} = N > 0 ? g(x * a, a, Val(N - 1)) : x

large_single_block(x::AbstractVector{<:Real}) = g(x[1], x[2], Val(400))

Expand Down Expand Up @@ -168,14 +171,12 @@ function generate_inter_framework_tests()
end

function benchmark_rules!!(test_case_data, default_ratios, include_other_frameworks::Bool)

test_cases = reduce(vcat, map(first, test_case_data))
memory = map(x -> x[2], test_case_data)
ranges = reduce(vcat, map(x -> x[3], test_case_data))
tags = reduce(vcat, map(x -> x[4], test_case_data))
GC.@preserve memory begin
return map(enumerate(test_cases)) do (n, args)

@info "$n / $(length(test_cases))", _typeof(args)
suite = Dict()

Expand All @@ -186,7 +187,7 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
() -> primals,
primals -> (primals[1], _deepcopy(primals[2:end])),
(a -> a[1]((a[2]...))),
_ -> true,
_ -> true;
evals=1,
)

Expand All @@ -199,17 +200,19 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
() -> (rule, coduals),
identity,
a -> to_benchmark(a[1], a[2]...),
_ -> true,
_ -> true;
evals=1,
)

if include_other_frameworks

if should_run_benchmark(Val(:zygote), args...)
@info "Zygote"
suite["zygote"] = @be(
_, _, zygote_to_benchmark($(Zygote.Context()), $primals...), _,
evals=1,
_,
_,
zygote_to_benchmark($(Zygote.Context()), $primals...),
_,
evals = 1,
)
end

Expand All @@ -219,21 +222,27 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
compiled_tape = ReverseDiff.compile(tape)
result = map(x -> randn(size(x)), primals[2:end])
suite["rd"] = @be(
_, _, rd_to_benchmark!($result, $compiled_tape, $primals[2:end]), _,
evals=1,
_,
_,
rd_to_benchmark!($result, $compiled_tape, $primals[2:end]),
_,
evals = 1,
)
end

if should_run_benchmark(Val(:enzyme), args...)
@info "Enzyme"
dup_args = map(x -> Duplicated(x, randn(size(x))), primals[2:end])
suite["enzyme"] = @be(
_, _, autodiff(Reverse, $primals[1], Active, $dup_args...), _,
evals=1,
_,
_,
autodiff(Reverse, $primals[1], Active, $dup_args...),
_,
evals = 1,
)
end
end

return combine_results((args, suite), tags[n], ranges[n], default_ratios)
end
end
Expand Down Expand Up @@ -319,7 +328,7 @@ well-suited to the numbers typically found in this field.
function plot_ratio_histogram!(df::DataFrame)
bin = 10.0 .^ (-1.0:0.05:4.0)
xlim = extrema(bin)
histogram(df.Mooncake; xscale=:log10, xlim, bin, title="log", label="")
return histogram(df.Mooncake; xscale=:log10, xlim, bin, title="log", label="")
end

function create_inter_ad_benchmarks()
Expand All @@ -328,7 +337,7 @@ function create_inter_ad_benchmarks()
df = DataFrame(results)[:, [:tag, tools...]]

# Plot graph of results.
plt = plot(yscale=:log10, legend=:topright, title="AD Time / Primal Time (Log Scale)")
plt = plot(; yscale=:log10, legend=:topright, title="AD Time / Primal Time (Log Scale)")
for label in string.(tools)
plot!(plt, df.tag, df[:, label]; label, marker=:circle, xrotation=45)
end
Expand All @@ -337,7 +346,9 @@ function create_inter_ad_benchmarks()
# Write table of results.
formatted_cols = map(t -> t => string.(round.(df[:, t]; sigdigits=3)), tools)
df_formatted = DataFrame(:Label => df.tag, formatted_cols...)
open(io -> pretty_table(io, df_formatted), "bench/benchmark_results.txt"; write=true)
return open(
io -> pretty_table(io, df_formatted), "bench/benchmark_results.txt"; write=true
)
end

function main()
Expand Down
22 changes: 7 additions & 15 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,18 @@ DocMeta.setdocmeta!(
recursive=true,
)

makedocs(
makedocs(;
sitename="Mooncake.jl",
format=Documenter.HTML(;
mathengine = Documenter.KaTeX(
Dict(
:macros => Dict(
"\\RR" => "\\mathbb{R}",
),
)
),
mathengine=Documenter.KaTeX(Dict(:macros => Dict("\\RR" => "\\mathbb{R}"))),
size_threshold_ignore=[
joinpath("developer_documentation", "internal_docstrings.md"),
joinpath("developer_documentation", "internal_docstrings.md")
],
),
modules=[Mooncake],
checkdocs=:none,
plugins=[
CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"); style=:numeric),
],
pages = [
plugins=[CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"); style=:numeric)],
pages=[
"Mooncake.jl" => "index.md",
"Understanding Mooncake.jl" => [
joinpath("understanding_mooncake", "introduction.md"),
Expand All @@ -46,7 +38,7 @@ makedocs(
joinpath("developer_documentation", "internal_docstrings.md"),
],
"known_limitations.md",
]
],
)

deploydocs(repo="github.com/compintell/Mooncake.jl.git", push_preview=true)
deploydocs(; repo="github.com/compintell/Mooncake.jl.git", push_preview=true)
2 changes: 1 addition & 1 deletion ext/MooncakeAllocCheckExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ module MooncakeAllocCheckExt
using AllocCheck, Mooncake
import Mooncake.TestUtils: check_allocs, Shim

@check_allocs check_allocs(::Shim, f::F, x::Tuple{Vararg{Any, N}}) where {F, N} = f(x...)
@check_allocs check_allocs(::Shim, f::F, x::Tuple{Vararg{Any,N}}) where {F,N} = f(x...)

end
6 changes: 2 additions & 4 deletions ext/MooncakeCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0
_add_to_primal(x::P, y::P, ::Bool) where {P<:CuArray{<:IEEEFloat}} = x + y
_diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y
_dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y))
_scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y
_scale(x::Float64, y::P) where {T<:IEEEFloat,P<:CuArray{T}} = T(x) * y

Check warning on line 41 in ext/MooncakeCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeCUDAExt.jl#L41

Added line #L41 was not covered by tests
function populate_address_map!(m::AddressMap, p::CuArray, t::CuArray)
k = pointer_from_objref(p)
v = pointer_from_objref(t)
Expand All @@ -55,9 +55,7 @@ end

# Basic rules for operating on CuArrays.

@is_primitive(
MinimalCtx, Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N},
)
@is_primitive(MinimalCtx, Tuple{Type{<:CuArray},UndefInitializer,Vararg{Int,N}} where {N},)
function rrule!!(
p::CoDual{Type{P}}, init::CoDual{UndefInitializer}, dims::CoDual{Int}...
) where {P<:CuArray{<:Base.IEEEFloat}}
Expand Down
2 changes: 1 addition & 1 deletion ext/MooncakeDynamicPPLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ using DynamicPPL: DynamicPPL, istrans
using Mooncake: Mooncake

# This is purely an optimisation.
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans), Vararg}
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}

end # module
45 changes: 21 additions & 24 deletions ext/MooncakeLuxLibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,16 @@ using Base: IEEEFloat

import LuxLib: Impl
import LuxLib.Utils: static_training_mode_check
import Mooncake:
@from_rrule,
DefaultCtx,
@mooncake_overlay,
CoDual
import Mooncake: @from_rrule, DefaultCtx, @mooncake_overlay, CoDual

@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat})
@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul),Array{P},Array{P}} where {P<:IEEEFloat})
@from_rrule(
DefaultCtx,
Tuple{typeof(Impl.matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat},
Tuple{typeof(Impl.matmuladd),Array{P},Array{P},Vector{P}} where {P<:IEEEFloat},
)
@from_rrule(
DefaultCtx,
Tuple{typeof(Impl.batched_matmul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat},
Tuple{typeof(Impl.batched_matmul),Array{P,3},Array{P,3}} where {P<:IEEEFloat},
)

# Re-implement a bunch of methods to ensure that Mooncake can differentiate them.
Expand All @@ -35,15 +31,15 @@ end
@mooncake_overlay function LuxLib.Impl.fused_conv(
::LuxLib.Impl.AbstractInternalArrayOpMode,
act::F,
weight::AbstractArray{wT, N},
x::AbstractArray{xT, N},
weight::AbstractArray{wT,N},
x::AbstractArray{xT,N},
bias::LuxLib.Optional{<:AbstractVector},
cdims::LuxLib.Impl.ConvDims,
) where {F, wT, xT, N}
) where {F,wT,xT,N}
return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias)
end

Mooncake.@zero_adjoint DefaultCtx Tuple{typeof(static_training_mode_check), Vararg}
Mooncake.@zero_adjoint DefaultCtx Tuple{typeof(static_training_mode_check),Vararg}

# This is a really horrible hack that we need to do until Mooncake is able to support the
# call-back-into-ad interface that ChainRules exposes.
Expand All @@ -61,18 +57,18 @@ function CRC.rrule(
::typeof(batchnorm_affine_normalize_internal),
opmode::AbstractInternalArrayOpMode,
::typeof(identity),
x::AbstractArray{T, N},
x::AbstractArray{T,N},
μ::AbstractVector,
σ²::AbstractVector,
γ::LuxLib.Optional{<:AbstractVector},
β::LuxLib.Optional{<:AbstractVector},
ϵ::Real,
) where {T, N}
) where {T,N}
y = similar(
x,
promote_type(
safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β)
)
),
)
γ′ = similar(
x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1)
Expand Down Expand Up @@ -111,13 +107,13 @@ end
@mooncake_overlay function batchnorm_affine_normalize_internal(
opmode::LuxLib.AbstractInternalArrayOpMode,
act::F,
x::AbstractArray{xT, 3},
x::AbstractArray{xT,3},
μ::AbstractVector,
σ²::AbstractVector,
γ::Union{Nothing, AbstractVector},
β::Union{Nothing, AbstractVector},
γ::Union{Nothing,AbstractVector},
β::Union{Nothing,AbstractVector},
ϵ::Real,
) where {F, xT}
) where {F,xT}
y = batchnorm_affine_normalize_internal(opmode, identity, x, μ, σ², γ, β, ϵ)
LuxLib.Impl.activation!(y, opmode, act, y)
return y
Expand All @@ -126,17 +122,18 @@ end
@mooncake_overlay function batchnorm_affine_normalize_internal(
opmode::LuxLib.AbstractInternalArrayOpMode,
::typeof(identity),
x::AbstractArray{xT, 3},
x::AbstractArray{xT,3},
μ::AbstractVector,
σ²::AbstractVector,
γ::Union{Nothing, AbstractVector},
β::Union{Nothing, AbstractVector},
γ::Union{Nothing,AbstractVector},
β::Union{Nothing,AbstractVector},
ϵ::Real,
) where {xT}
y = similar(x,
y = similar(

Check warning on line 132 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L132

Added line #L132 was not covered by tests
x,
promote_type(
safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β)
)
),
)
batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ)
return y
Expand Down
Loading

0 comments on commit e89a038

Please sign in to comment.