From 8956c5969e3d7bff6b7c2ff48adf7542ef8a7ce5 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 3 Oct 2024 20:59:06 -0700 Subject: [PATCH 01/13] add subdirectories to compathelper --- .github/workflows/CompatHelper.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 81c8c3fc1..8e6f5716e 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -13,4 +13,4 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }} - run: julia -e 'using CompatHelper; CompatHelper.main()' + run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs=["", "bench", "test", "docs"])' From 624c82ac76fdde0ba289c5538f9fae21f7c39548 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 3 Oct 2024 22:16:38 -0700 Subject: [PATCH 02/13] add benchmarks --- bench/Project.toml | 2 + bench/benchmarks.jl | 85 +++++++++++++++++++++++++++------------- bench/normallognormal.jl | 22 +---------- bench/unconstrdist.jl | 54 +++++++++++++++++++++++++ bench/utils.jl | 20 ---------- 5 files changed, 114 insertions(+), 69 deletions(-) create mode 100644 bench/unconstrdist.jl delete mode 100644 bench/utils.jl diff --git a/bench/Project.toml b/bench/Project.toml index 3cadd8200..a63fea9bd 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -5,10 +5,12 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 551e12b2d..9c18f1b42 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -1,10 +1,11 @@ -using ADTypes, ForwardDiff, ReverseDiff, Zygote +using ADTypes using AdvancedVI using BenchmarkTools using Bijectors using Distributions using DistributionsAD +using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake using FillArrays using InteractiveUtils using LinearAlgebra @@ -17,37 +18,65 @@ BLAS.set_num_threads(min(4, Threads.nthreads())) @info sprint(versioninfo) @info "BLAS threads: $(BLAS.get_num_threads())" -include("utils.jl") include("normallognormal.jl") +include("unconstrdist.jl") const SUITES = BenchmarkGroup() -# Comment until https://github.com/TuringLang/Bijectors.jl/pull/315 is merged -# SUITES["normal + bijector"]["meanfield"]["Zygote"] = -# @benchmarkable normallognormal( -# ; -# fptype = Float64, -# adtype = AutoZygote(), -# family = :meanfield, -# objective = :RepGradELBO, -# n_montecarlo = 4, -# ) - -SUITES["normal + bijector"]["meanfield"]["ReverseDiff"] = @benchmarkable normallognormal(; - fptype=Float64, - adtype=AutoReverseDiff(), - family=:meanfield, - objective=:RepGradELBO, - n_montecarlo=4, -) - -SUITES["normal + bijector"]["meanfield"]["ForwardDiff"] = @benchmarkable normallognormal(; - fptype=Float64, - adtype=AutoForwardDiff(), - family=:meanfield, - objective=:RepGradELBO, - n_montecarlo=4, -) +function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol) + if family == :meanfield + MeanFieldGaussian(zeros(type, n_dims), Diagonal(ones(type, n_dims))) + else + FullRankGaussian(zeros(type, n_dims), Matrix(type, I, n_dims, n_dims)) + end +end + +begin + fptype = Float64 + + for (probname, prob) in [ + ("normal + bijector", normallognormal(; n_dims=10, fptype)) + ("normal", normal(; n_dims=10, fptype)) + ] + max_iter = 10^4 + d = LogDensityProblems.dimension(prob) + optimizer = Optimisers.Adam(fptype(1e-3)) + + for (objname, obj) in [ + ("RepGradELBO", RepGradELBO(10)), + ("RepGradELBO + STL", RepGradELBO(10; entropy=StickingTheLandingEntropy())), + ], + (adname, adtype) in [ + ("Zygote", AutoZygote()), + ("ForwardDiff", AutoForwardDiff()), + ("ReverseDiff", AutoReverseDiff()), + #("Mooncake", AutoMooncake(; config=nothing)), + #("Enzyme", AutoEnzyme()), + ], + (familyname, family) in [ + ("meanfield", MeanFieldGaussian(zeros(d), Diagonal(ones(d)))), + ( + "fullrank", + FullRankGaussian(zeros(d), LowerTriangular(Matrix{fptype}(I, d, d))), + ), + ] + + b = Bijectors.bijector(prob) + binv = inverse(b) + q = Bijectors.TransformedDistribution(family, binv) + + SUITES[probname][objname][familyname][adname] = @benchmarkable AdvancedVI.optimize( + $prob, + $obj, + $q, + $max_iter; + adtype=$adtype, + optimizer=$optimizer, + show_progress=false, + ) + end + end +end BenchmarkTools.tune!(SUITES; verbose=true) results = BenchmarkTools.run(SUITES; verbose=true) diff --git a/bench/normallognormal.jl b/bench/normallognormal.jl index 075bf3dc8..3cc7f4e57 100644 --- a/bench/normallognormal.jl +++ b/bench/normallognormal.jl @@ -27,30 +27,10 @@ function Bijectors.bijector(model::NormalLogNormal) ) end -function normallognormal(; fptype, adtype, family, objective, max_iter=10^3, kwargs...) - n_dims = 10 +function normallognormal(; n_dims=10, fptype=Float64) μ_x = fptype(5.0) σ_x = fptype(0.3) μ_y = Fill(fptype(5.0), n_dims) σ_y = Fill(fptype(0.3), n_dims) model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)) - - obj = variational_objective(objective; kwargs...) - - d = LogDensityProblems.dimension(model) - q = variational_standard_mvnormal(fptype, d, family) - - b = Bijectors.bijector(model) - binv = inverse(b) - q_transformed = Bijectors.TransformedDistribution(q, binv) - - return AdvancedVI.optimize( - model, - obj, - q_transformed, - max_iter; - adtype, - optimizer=Optimisers.Adam(fptype(1e-3)), - show_progress=false, - ) end diff --git a/bench/unconstrdist.jl b/bench/unconstrdist.jl new file mode 100644 index 000000000..7eef8aef3 --- /dev/null +++ b/bench/unconstrdist.jl @@ -0,0 +1,54 @@ + +using Distributions, Enzyme, DifferentiationInterface, LinearAlgebra, LogDensityProblems + +struct UnconstrDist{D <: ContinuousMultivariateDistribution} + dist::D +end + +function LogDensityProblems.logdensity(model::UnconstrDist, x) + return logpdf(model.dist, x) +end + +function LogDensityProblems.dimension(model::UnconstrDist) + return length(model.dist) +end + +function LogDensityProblems.capabilities(::Type{<:UnconstrDist}) + return LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::UnconstrDist) + return identity +end + +function normal(; n_dims=10, fptype=Float64) + μ = fill(fptype(5), n_dims) + Σ = Diagonal(ones(fptype, n_dims)) + UnconstrDist(MvNormal(μ, Σ)) +end + +function f(x, aux) + LogDensityProblems.logdensity(aux , x) +end + +function main() + n_dims = 10 + x = randn(10) + + for fptype in [Float32, Float64], + aux in [ + UnconstrDist(MvNormal(fill(fptype(5), n_dims), Diagonal(ones(fptype, n_dims)))), + UnconstrDist(MvNormal(fill(fptype(5), n_dims), ones(fptype, n_dims))), + UnconstrDist(MvNormal(fill(fptype(5), n_dims), I)), + ] + ∇x = zeros(n_dims) + _, y = Enzyme.autodiff( + Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), + Enzyme.Const(f), + Enzyme.Active, + Enzyme.Duplicated(x, ∇x), + Enzyme.Const(aux), + ) + println(y) + end +end diff --git a/bench/utils.jl b/bench/utils.jl deleted file mode 100644 index d95741cd4..000000000 --- a/bench/utils.jl +++ /dev/null @@ -1,20 +0,0 @@ - -function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol) - if family == :meanfield - AdvancedVI.MeanFieldGaussian(zeros(type, n_dims), Diagonal(ones(type, n_dims))) - else - AdvancedVI.FullRankGaussian(zeros(type, n_dims), Matrix(type, I, n_dims, n_dims)) - end -end - -function variational_objective(objective::Symbol; kwargs...) - if objective == :RepGradELBO - AdvancedVI.RepGradELBO(kwargs[:n_montecarlo]) - elseif objective == :RepGradELBOSTL - AdvancedVI.RepGradELBO(kwargs[:n_montecarlo]; entropy=StickingTheLandingEntropy()) - elseif objective == :ScoreGradELBO - throw("ScoreGradELBO not supported yet. Please use ScoreGradELBOSTL instead.") - elseif objective == :ScoreGradELBOSTL - AdvancedVI.ScoreGradELBO(kwargs[:n_montecarlo]; entropy=StickingTheLandingEntropy()) - end -end From 66ee6c104f69ef2e7fb0dabab7e93a8fd3ebef89 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 3 Oct 2024 22:26:43 -0700 Subject: [PATCH 03/13] add compat bound to benchmark --- bench/Project.toml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/bench/Project.toml b/bench/Project.toml index a63fea9bd..e748ff059 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -17,3 +17,23 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +ADTypes = "1" +BenchmarkTools = "1" +Bijectors = "0.13" +Distributions = "0.25.111" +DistributionsAD = "0.6" +Enzyme = "0.13.7" +FillArrays = "1" +ForwardDiff = "0.10" +InteractiveUtils = "1" +LogDensityProblems = "2" +Mooncake = "0.4.5" +Optimisers = "0.3" +Random = "1" +ReverseDiff = "1" +SimpleUnPack = "1" +StableRNGs = "1" +Zygote = "0.6" +julia = "1.10" From 02a831cd213f50b45b311cc5ccbd1a6a19f72dbc Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 3 Oct 2024 22:31:22 -0700 Subject: [PATCH 04/13] remove unused code --- bench/unconstrdist.jl | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/bench/unconstrdist.jl b/bench/unconstrdist.jl index 7eef8aef3..b2755c032 100644 --- a/bench/unconstrdist.jl +++ b/bench/unconstrdist.jl @@ -1,6 +1,4 @@ -using Distributions, Enzyme, DifferentiationInterface, LinearAlgebra, LogDensityProblems - struct UnconstrDist{D <: ContinuousMultivariateDistribution} dist::D end @@ -26,29 +24,3 @@ function normal(; n_dims=10, fptype=Float64) Σ = Diagonal(ones(fptype, n_dims)) UnconstrDist(MvNormal(μ, Σ)) end - -function f(x, aux) - LogDensityProblems.logdensity(aux , x) -end - -function main() - n_dims = 10 - x = randn(10) - - for fptype in [Float32, Float64], - aux in [ - UnconstrDist(MvNormal(fill(fptype(5), n_dims), Diagonal(ones(fptype, n_dims)))), - UnconstrDist(MvNormal(fill(fptype(5), n_dims), ones(fptype, n_dims))), - UnconstrDist(MvNormal(fill(fptype(5), n_dims), I)), - ] - ∇x = zeros(n_dims) - _, y = Enzyme.autodiff( - Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true), - Enzyme.Const(f), - Enzyme.Active, - Enzyme.Duplicated(x, ∇x), - Enzyme.Const(aux), - ) - println(y) - end -end From 7f5c7d129d38186931166223eb26c8a5c039f51d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 3 Oct 2024 23:20:22 -0700 Subject: [PATCH 05/13] refactor benchmark code, run formatter --- bench/benchmarks.jl | 34 ++++++++++++++++++---------------- bench/normallognormal.jl | 12 ++++++------ bench/unconstrdist.jl | 10 +++++----- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 9c18f1b42..98fbd38c0 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -32,15 +32,15 @@ function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol) end begin - fptype = Float64 + T = Float64 for (probname, prob) in [ - ("normal + bijector", normallognormal(; n_dims=10, fptype)) - ("normal", normal(; n_dims=10, fptype)) + ("normal + bijector", normallognormal(; n_dims=10, realtype=T)) + ("normal", normal(; n_dims=10, realtype=T)) ] max_iter = 10^4 d = LogDensityProblems.dimension(prob) - optimizer = Optimisers.Adam(fptype(1e-3)) + optimizer = Optimisers.Adam(T(1e-3)) for (objname, obj) in [ ("RepGradELBO", RepGradELBO(10)), @@ -51,13 +51,13 @@ begin ("ForwardDiff", AutoForwardDiff()), ("ReverseDiff", AutoReverseDiff()), #("Mooncake", AutoMooncake(; config=nothing)), - #("Enzyme", AutoEnzyme()), + ("Enzyme", AutoEnzyme()), ], (familyname, family) in [ - ("meanfield", MeanFieldGaussian(zeros(d), Diagonal(ones(d)))), + ("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))), ( "fullrank", - FullRankGaussian(zeros(d), LowerTriangular(Matrix{fptype}(I, d, d))), + FullRankGaussian(zeros(T, d), LowerTriangular(Matrix{T}(I, d, d))), ), ] @@ -65,15 +65,17 @@ begin binv = inverse(b) q = Bijectors.TransformedDistribution(family, binv) - SUITES[probname][objname][familyname][adname] = @benchmarkable AdvancedVI.optimize( - $prob, - $obj, - $q, - $max_iter; - adtype=$adtype, - optimizer=$optimizer, - show_progress=false, - ) + SUITES[probname][objname][familyname][adname] = begin + @benchmarkable AdvancedVI.optimize( + $prob, + $obj, + $q, + $max_iter; + adtype=$adtype, + optimizer=$optimizer, + show_progress=false, + ) + end end end end diff --git a/bench/normallognormal.jl b/bench/normallognormal.jl index 3cc7f4e57..181996960 100644 --- a/bench/normallognormal.jl +++ b/bench/normallognormal.jl @@ -27,10 +27,10 @@ function Bijectors.bijector(model::NormalLogNormal) ) end -function normallognormal(; n_dims=10, fptype=Float64) - μ_x = fptype(5.0) - σ_x = fptype(0.3) - μ_y = Fill(fptype(5.0), n_dims) - σ_y = Fill(fptype(0.3), n_dims) - model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)) +function normallognormal(; n_dims=10, realtype=Float64) + μ_x = realtype(5.0) + σ_x = realtype(0.3) + μ_y = Fill(realtype(5.0), n_dims) + σ_y = Fill(realtype(0.3), n_dims) + return model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2)) end diff --git a/bench/unconstrdist.jl b/bench/unconstrdist.jl index b2755c032..04223757e 100644 --- a/bench/unconstrdist.jl +++ b/bench/unconstrdist.jl @@ -1,5 +1,5 @@ -struct UnconstrDist{D <: ContinuousMultivariateDistribution} +struct UnconstrDist{D<:ContinuousMultivariateDistribution} dist::D end @@ -19,8 +19,8 @@ function Bijectors.bijector(model::UnconstrDist) return identity end -function normal(; n_dims=10, fptype=Float64) - μ = fill(fptype(5), n_dims) - Σ = Diagonal(ones(fptype, n_dims)) - UnconstrDist(MvNormal(μ, Σ)) +function normal(; n_dims=10, realtype=Float64) + μ = fill(realtype(5), n_dims) + Σ = Diagonal(ones(realtype, n_dims)) + return UnconstrDist(MvNormal(μ, Σ)) end From 95a8b558c86c7d78eaa32d491455ee74acd59e0c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 3 Oct 2024 23:26:29 -0700 Subject: [PATCH 06/13] disable enzyme for now --- bench/benchmarks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 98fbd38c0..8c5266c1b 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -51,7 +51,7 @@ begin ("ForwardDiff", AutoForwardDiff()), ("ReverseDiff", AutoReverseDiff()), #("Mooncake", AutoMooncake(; config=nothing)), - ("Enzyme", AutoEnzyme()), + #("Enzyme", AutoEnzyme()), ], (familyname, family) in [ ("meanfield", MeanFieldGaussian(zeros(T, d), Diagonal(ones(T, d)))), From 71dfa33198f57417aade5b49f002210922ba9fbe Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Oct 2024 12:18:07 -0700 Subject: [PATCH 07/13] update benchmark README --- bench/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bench/README.md b/bench/README.md index 8a8f5163b..ee7d5076a 100644 --- a/bench/README.md +++ b/bench/README.md @@ -3,3 +3,12 @@ This subdirectory contains code for continuous benchmarking of the performance of `AdvancedVI.jl`. The initial version was heavily inspired by the setup of [Lux.jl](https://github.com/LuxDL/Lux.jl/tree/main). The Github action and pages integration is provided by https://github.com/benchmark-action/github-action-benchmark/ and [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl). + +To run the benchmarks locally, follow the following steps: +```julia +using Pkg +Pkg.activate(".") +Pkg.instantiate() +Pkg.develop("AdvancedVI") +include("benchmarks.jl") +``` From 75c15912e34c653f106a221e39bf22b4b9deea5b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Oct 2024 12:18:14 -0700 Subject: [PATCH 08/13] fix mooncake options --- bench/benchmarks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 8c5266c1b..75807cb60 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -50,7 +50,7 @@ begin ("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff()), ("ReverseDiff", AutoReverseDiff()), - #("Mooncake", AutoMooncake(; config=nothing)), + #("Mooncake", AutoMooncake(; config=Mooncake.config())), #("Enzyme", AutoEnzyme()), ], (familyname, family) in [ From fd606c997bd27a9bb5f6d343f4259903dd676995 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Oct 2024 12:18:43 -0700 Subject: [PATCH 09/13] try enable mooncake benchmarks --- bench/benchmarks.jl | 2 +- test/inference/repgradelbo_distributionsad.jl | 2 +- test/inference/repgradelbo_locationscale.jl | 2 +- test/inference/repgradelbo_locationscale_bijectors.jl | 2 +- test/inference/scoregradelbo_distributionsad.jl | 4 ++-- test/inference/scoregradelbo_locationscale.jl | 2 +- test/interface/ad.jl | 4 ++-- test/interface/repgradelbo.jl | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 75807cb60..6b06d8d8d 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -50,7 +50,7 @@ begin ("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff()), ("ReverseDiff", AutoReverseDiff()), - #("Mooncake", AutoMooncake(; config=Mooncake.config())), + ("Mooncake", AutoMooncake(; config=Mooncake.config())), #("Enzyme", AutoEnzyme()), ], (familyname, family) in [ diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 4086a2052..753e3cf36 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -6,7 +6,7 @@ AD_distributionsad = Dict( ) if @isdefined(Mooncake) - AD_distributionsad[:Mooncake] = AutoMooncake(; config=nothing) + AD_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.config()) end if @isdefined(Enzyme) diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 1ca318851..4802f3d29 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -6,7 +6,7 @@ AD_locationscale = Dict( ) if @isdefined(Mooncake) - AD_locationscale[:Mooncake] = AutoMooncake(; config=nothing) + AD_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.config()) end if @isdefined(Enzyme) diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 167fe3892..3135501e5 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -6,7 +6,7 @@ AD_locationscale_bijectors = Dict( ) if @isdefined(Mooncake) - AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=nothing) + AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=Mooncake.config()) end if @isdefined(Enzyme) diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl index 1de7af1dc..7903852a8 100644 --- a/test/inference/scoregradelbo_distributionsad.jl +++ b/test/inference/scoregradelbo_distributionsad.jl @@ -5,8 +5,8 @@ AD_scoregradelbo_distributionsad = Dict( :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - AD_scoregradelbo_distributionsad[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.config()) end #if @isdefined(Enzyme) diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl index f0073d7cc..905542a68 100644 --- a/test/inference/scoregradelbo_locationscale.jl +++ b/test/inference/scoregradelbo_locationscale.jl @@ -6,7 +6,7 @@ AD_scoregradelbo_locationscale = Dict( ) if @isdefined(Mooncake) - AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=nothing) + AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.config()) end if @isdefined(Enzyme) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index 713a0f56d..0be749f17 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -7,8 +7,8 @@ const interface_ad_backends = Dict( :Zygote => AutoZygote(), ) -if @isdefined(Tapir) - interface_ad_backends[:Tapir] = AutoTapir(; safe_mode=false) +if @isdefined(Mooncake) + interface_ad_backends[:Mooncake] = AutoMooncake(; config=Mooncake.config()) end if @isdefined(Enzyme) diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index afd6249e9..3614dedfa 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -38,7 +38,7 @@ end ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() ] if @isdefined(Mooncake) - push!(ad_backends, AutoMooncake(; config=nothing)) + push!(ad_backends, AutoMooncake(; config=Mooncake.config())) end if @isdefined(Enzyme) push!( From 930cd6d49e8d8099fb163ed2983c1e5941acc89a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Oct 2024 12:33:11 -0700 Subject: [PATCH 10/13] fix wrong mooncake API --- bench/benchmarks.jl | 2 +- test/inference/repgradelbo_distributionsad.jl | 2 +- test/inference/repgradelbo_locationscale.jl | 2 +- test/inference/repgradelbo_locationscale_bijectors.jl | 2 +- test/inference/scoregradelbo_distributionsad.jl | 2 +- test/inference/scoregradelbo_locationscale.jl | 2 +- test/interface/ad.jl | 2 +- test/interface/repgradelbo.jl | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 6b06d8d8d..22daaaffe 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -50,7 +50,7 @@ begin ("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff()), ("ReverseDiff", AutoReverseDiff()), - ("Mooncake", AutoMooncake(; config=Mooncake.config())), + ("Mooncake", AutoMooncake(; config=Mooncake.Config())), #("Enzyme", AutoEnzyme()), ], (familyname, family) in [ diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 753e3cf36..fbe70ae9d 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -6,7 +6,7 @@ AD_distributionsad = Dict( ) if @isdefined(Mooncake) - AD_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.config()) + AD_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) end if @isdefined(Enzyme) diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 4802f3d29..d1f0d7e41 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -6,7 +6,7 @@ AD_locationscale = Dict( ) if @isdefined(Mooncake) - AD_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.config()) + AD_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) end if @isdefined(Enzyme) diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 3135501e5..e2a69d62e 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -6,7 +6,7 @@ AD_locationscale_bijectors = Dict( ) if @isdefined(Mooncake) - AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=Mooncake.config()) + AD_locationscale_bijectors[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) end if @isdefined(Enzyme) diff --git a/test/inference/scoregradelbo_distributionsad.jl b/test/inference/scoregradelbo_distributionsad.jl index 7903852a8..9a621b402 100644 --- a/test/inference/scoregradelbo_distributionsad.jl +++ b/test/inference/scoregradelbo_distributionsad.jl @@ -6,7 +6,7 @@ AD_scoregradelbo_distributionsad = Dict( ) if @isdefined(Mooncake) - AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.config()) + AD_scoregradelbo_distributionsad[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) end #if @isdefined(Enzyme) diff --git a/test/inference/scoregradelbo_locationscale.jl b/test/inference/scoregradelbo_locationscale.jl index 905542a68..753999dee 100644 --- a/test/inference/scoregradelbo_locationscale.jl +++ b/test/inference/scoregradelbo_locationscale.jl @@ -6,7 +6,7 @@ AD_scoregradelbo_locationscale = Dict( ) if @isdefined(Mooncake) - AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.config()) + AD_scoregradelbo_locationscale[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) end if @isdefined(Enzyme) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index 0be749f17..e23aec580 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -8,7 +8,7 @@ const interface_ad_backends = Dict( ) if @isdefined(Mooncake) - interface_ad_backends[:Mooncake] = AutoMooncake(; config=Mooncake.config()) + interface_ad_backends[:Mooncake] = AutoMooncake(; config=Mooncake.Config()) end if @isdefined(Enzyme) diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 3614dedfa..be835e203 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -38,7 +38,7 @@ end ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() ] if @isdefined(Mooncake) - push!(ad_backends, AutoMooncake(; config=Mooncake.config())) + push!(ad_backends, AutoMooncake(; config=Mooncake.Config())) end if @isdefined(Enzyme) push!( From 68cd211a10e6bfac4346cd285701386748ceb51c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Oct 2024 12:37:42 -0700 Subject: [PATCH 11/13] fix formatting --- bench/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/bench/README.md b/bench/README.md index ee7d5076a..685217bbc 100644 --- a/bench/README.md +++ b/bench/README.md @@ -5,6 +5,7 @@ The initial version was heavily inspired by the setup of [Lux.jl](https://github The Github action and pages integration is provided by https://github.com/benchmark-action/github-action-benchmark/ and [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl). To run the benchmarks locally, follow the following steps: + ```julia using Pkg Pkg.activate(".") From e1b73d1ded8eb6f176c45301c4273a3041061174 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 4 Oct 2024 20:02:29 -0700 Subject: [PATCH 12/13] fix error and failing test (rounding error) --- bench/benchmarks.jl | 2 +- test/families/location_scale.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 22daaaffe..9e18bd91f 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -50,7 +50,7 @@ begin ("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff()), ("ReverseDiff", AutoReverseDiff()), - ("Mooncake", AutoMooncake(; config=Mooncake.Config())), + #("Mooncake", AutoMooncake(; config=Mooncake.Config())), #("Enzyme", AutoEnzyme()), ], (familyname, family) in [ diff --git a/test/families/location_scale.jl b/test/families/location_scale.jl index bd45458dc..c112352e9 100644 --- a/test/families/location_scale.jl +++ b/test/families/location_scale.jl @@ -93,7 +93,7 @@ @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) samples_ref = rand(StableRNG(1), q, n_montecarlo) - @test samples_ref == rand(StableRNG(1), q, n_montecarlo) + @test samples_ref ≈ rand(StableRNG(1), q, n_montecarlo) end @testset "rand! AbstractVector" begin From bec84cc4f0b5f5700084de38fe4496ef8fa6015e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 6 Oct 2024 21:18:51 -0700 Subject: [PATCH 13/13] remove empty extensions --- ext/AdvancedVIForwardDiffExt.jl | 0 ext/AdvancedVIMooncakeExt.jl | 0 ext/AdvancedVIReverseDiffExt.jl | 0 ext/AdvancedVIZygoteExt.jl | 0 src/AdvancedVI.jl | 9 --------- 5 files changed, 9 deletions(-) delete mode 100644 ext/AdvancedVIForwardDiffExt.jl delete mode 100644 ext/AdvancedVIMooncakeExt.jl delete mode 100644 ext/AdvancedVIReverseDiffExt.jl delete mode 100644 ext/AdvancedVIZygoteExt.jl diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index aebe765e9..1d0c4f502 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -262,15 +262,6 @@ end @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin include("../ext/AdvancedVIEnzymeExt.jl") end - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin - include("../ext/AdvancedVIForwardDiffExt.jl") - end - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("../ext/AdvancedVIReverseDiffExt.jl") - end - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("../ext/AdvancedVIZygoteExt.jl") - end end end