Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup extensions #127

Closed
wants to merge 14 commits into from
22 changes: 22 additions & 0 deletions bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,35 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
BenchmarkTools = "1"
Bijectors = "0.13"
Distributions = "0.25.111"
DistributionsAD = "0.6"
Enzyme = "0.13.7"
FillArrays = "1"
ForwardDiff = "0.10"
InteractiveUtils = "1"
LogDensityProblems = "2"
Mooncake = "0.4.5"
Optimisers = "0.3"
Random = "1"
ReverseDiff = "1"
SimpleUnPack = "1"
StableRNGs = "1"
Zygote = "0.6"
julia = "1.10"
10 changes: 10 additions & 0 deletions bench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,13 @@
This subdirectory contains code for continuous benchmarking of the performance of `AdvancedVI.jl`.
The initial version was heavily inspired by the setup of [Lux.jl](https://github.com/LuxDL/Lux.jl/tree/main).
The Github action and pages integration is provided by https://github.com/benchmark-action/github-action-benchmark/ and [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl).

To run the benchmarks locally, follow the following steps:

```julia
using Pkg
Pkg.activate(".")
Pkg.instantiate()
Pkg.develop("AdvancedVI")
include("benchmarks.jl")
```
87 changes: 59 additions & 28 deletions bench/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@

using ADTypes, ForwardDiff, ReverseDiff, Zygote
using ADTypes
using AdvancedVI
using BenchmarkTools
using Bijectors
using Distributions
using DistributionsAD
using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake
using FillArrays
using InteractiveUtils
using LinearAlgebra
Expand All @@ -17,37 +18,67 @@ BLAS.set_num_threads(min(4, Threads.nthreads()))
@info sprint(versioninfo)
@info "BLAS threads: $(BLAS.get_num_threads())"

include("utils.jl")
include("normallognormal.jl")
include("unconstrdist.jl")

const SUITES = BenchmarkGroup()

# Comment until https://github.com/TuringLang/Bijectors.jl/pull/315 is merged
# SUITES["normal + bijector"]["meanfield"]["Zygote"] =
# @benchmarkable normallognormal(
# ;
# fptype = Float64,
# adtype = AutoZygote(),
# family = :meanfield,
# objective = :RepGradELBO,
# n_montecarlo = 4,
# )

SUITES["normal + bijector"]["meanfield"]["ReverseDiff"] = @benchmarkable normallognormal(;
fptype=Float64,
adtype=AutoReverseDiff(),
family=:meanfield,
objective=:RepGradELBO,
n_montecarlo=4,
)

SUITES["normal + bijector"]["meanfield"]["ForwardDiff"] = @benchmarkable normallognormal(;
fptype=Float64,
adtype=AutoForwardDiff(),
family=:meanfield,
objective=:RepGradELBO,
n_montecarlo=4,
)
function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol)
if family == :meanfield
MeanFieldGaussian(zeros(type, n_dims), Diagonal(ones(type, n_dims)))
else
FullRankGaussian(zeros(type, n_dims), Matrix(type, I, n_dims, n_dims))
end
end

begin
T = Float64

for (probname, prob) in [
("normal + bijector", normallognormal(; n_dims=10, realtype=T))
("normal", normal(; n_dims=10, realtype=T))
]
max_iter = 10^4
d = LogDensityProblems.dimension(prob)
optimizer = Optimisers.Adam(T(1e-3))

for (objname, obj) in [
("RepGradELBO", RepGradELBO(10)),
("RepGradELBO + STL", RepGradELBO(10; entropy=StickingTheLandingEntropy())),
],
(adname, adtype) in [
("Zygote", AutoZygote()),
("ForwardDiff", AutoForwardDiff()),
("ReverseDiff", AutoReverseDiff()),
#("Mooncake", AutoMooncake(; config=Mooncake.Config())),
#("Enzyme", AutoEnzyme()),
],
(familyname, family) in [
("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))),
(
"fullrank",
FullRankGaussian(zeros(T, d), LowerTriangular(Matrix{T}(I, d, d))),
),
]

b = Bijectors.bijector(prob)
binv = inverse(b)
q = Bijectors.TransformedDistribution(family, binv)

SUITES[probname][objname][familyname][adname] = begin
@benchmarkable AdvancedVI.optimize(
$prob,
$obj,
$q,
$max_iter;
adtype=$adtype,
optimizer=$optimizer,
show_progress=false,
)
end
end
end
end

BenchmarkTools.tune!(SUITES; verbose=true)
results = BenchmarkTools.run(SUITES; verbose=true)
Expand Down
32 changes: 6 additions & 26 deletions bench/normallognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,10 @@ function Bijectors.bijector(model::NormalLogNormal)
)
end

function normallognormal(; fptype, adtype, family, objective, max_iter=10^3, kwargs...)
n_dims = 10
μ_x = fptype(5.0)
σ_x = fptype(0.3)
μ_y = Fill(fptype(5.0), n_dims)
σ_y = Fill(fptype(0.3), n_dims)
model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2))

obj = variational_objective(objective; kwargs...)

d = LogDensityProblems.dimension(model)
q = variational_standard_mvnormal(fptype, d, family)

b = Bijectors.bijector(model)
binv = inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)

return AdvancedVI.optimize(
model,
obj,
q_transformed,
max_iter;
adtype,
optimizer=Optimisers.Adam(fptype(1e-3)),
show_progress=false,
)
function normallognormal(; n_dims=10, realtype=Float64)
μ_x = realtype(5.0)
σ_x = realtype(0.3)
μ_y = Fill(realtype(5.0), n_dims)
σ_y = Fill(realtype(0.3), n_dims)
return model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2))
end
26 changes: 26 additions & 0 deletions bench/unconstrdist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

struct UnconstrDist{D<:ContinuousMultivariateDistribution}
dist::D
end

function LogDensityProblems.logdensity(model::UnconstrDist, x)
return logpdf(model.dist, x)
end

function LogDensityProblems.dimension(model::UnconstrDist)
return length(model.dist)
end

function LogDensityProblems.capabilities(::Type{<:UnconstrDist})
return LogDensityProblems.LogDensityOrder{0}()
end

function Bijectors.bijector(model::UnconstrDist)
return identity
end

function normal(; n_dims=10, realtype=Float64)
μ = fill(realtype(5), n_dims)
Σ = Diagonal(ones(realtype, n_dims))
return UnconstrDist(MvNormal(μ, Σ))
end
20 changes: 0 additions & 20 deletions bench/utils.jl

This file was deleted.

Empty file removed ext/AdvancedVIForwardDiffExt.jl
Empty file.
Empty file removed ext/AdvancedVIMooncakeExt.jl
Empty file.
Empty file removed ext/AdvancedVIReverseDiffExt.jl
Empty file.
Empty file removed ext/AdvancedVIZygoteExt.jl
Empty file.
9 changes: 0 additions & 9 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,6 @@ end
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("../ext/AdvancedVIEnzymeExt.jl")
end
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
include("../ext/AdvancedVIForwardDiffExt.jl")
end
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("../ext/AdvancedVIReverseDiffExt.jl")
end
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("../ext/AdvancedVIZygoteExt.jl")
end
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
@test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2)

samples_ref = rand(StableRNG(1), q, n_montecarlo)
@test samples_ref == rand(StableRNG(1), q, n_montecarlo)
@test samples_ref rand(StableRNG(1), q, n_montecarlo)
end

@testset "rand! AbstractVector" begin
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_distributionsad = Dict(
)

if @isdefined(Mooncake)
AD_distributionsad[:Mooncake] = AutoMooncake(; config=nothing)
AD_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_locationscale = Dict(
)

if @isdefined(Mooncake)
AD_locationscale[:Mooncake] = AutoMooncake(; config=nothing)
AD_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_locationscale_bijectors = Dict(
)

if @isdefined(Mooncake)
AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=nothing)
AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
4 changes: 2 additions & 2 deletions test/inference/scoregradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ AD_scoregradelbo_distributionsad = Dict(
:Zygote => AutoZygote(),
)

if @isdefined(Tapir)
AD_scoregradelbo_distributionsad[:Tapir] = AutoTapir(; safe_mode=false)
if @isdefined(Mooncake)
AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

#if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/inference/scoregradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AD_scoregradelbo_locationscale = Dict(
)

if @isdefined(Mooncake)
AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=nothing)
AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
4 changes: 2 additions & 2 deletions test/interface/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ const interface_ad_backends = Dict(
:Zygote => AutoZygote(),
)

if @isdefined(Tapir)
interface_ad_backends[:Tapir] = AutoTapir(; safe_mode=false)
if @isdefined(Mooncake)
interface_ad_backends[:Mooncake] = AutoMooncake(; config=Mooncake.Config())
end

if @isdefined(Enzyme)
Expand Down
2 changes: 1 addition & 1 deletion test/interface/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end
ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote()
]
if @isdefined(Mooncake)
push!(ad_backends, AutoMooncake(; config=nothing))
push!(ad_backends, AutoMooncake(; config=Mooncake.Config()))
end
if @isdefined(Enzyme)
push!(
Expand Down
Loading