Skip to content

Commit

Permalink
Orad/fix emulator bug repeat call (#281)
Browse files Browse the repository at this point in the history
* bugfix repeated calls, by skipping build

* typo in test

plots, repeats, more tweaks

format

varying dimension is easier

ens propto dim

adds recompute_cov and inflation/localization for scalar RF

different plots and configs with localization etc.

exp to fit paper results

add statsbase

add plotting when no repeats

add coeffs, hardcode some restart options in scalar RF

add coeffs, hardcode some restart options in scalar RF

kron add regularization matrix

save states and replotting from file

prior plots for l63

added save+plot for GFunc

add prior case

add save+plot to ishigami

updated plotting for EDMF

remove coeffs from loss again

JLD2 in projects

config update

format

tests passing
  • Loading branch information
odunbar committed Jul 31, 2024
1 parent f2e95cf commit f798b27
Show file tree
Hide file tree
Showing 14 changed files with 1,021 additions and 105 deletions.
120 changes: 111 additions & 9 deletions examples/EDMF_data/plot_posterior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,112 @@ using Dates
# CES
using CalibrateEmulateSample.ParameterDistributions

#####
# Creates 1 plots: One for a specific case, One with 2 cases, and One with all cases (final case being the prior).


# date = Date(year,month,day)

# 2-parameter calibration exp
#exp_name = "ent-det-calibration"
#date_of_run = Date(2023, 10, 17)
#date_of_run = Date(2023, 10, 5)

# 5-parameter calibration exp
exp_name = "ent-det-tked-tkee-stab-calibration"
date_of_run = Date(2024, 2, 2)
date_of_run = Date(2024, 06, 14)

# Output figure read/write directory
figure_save_directory = joinpath(@__DIR__, "output", exp_name, string(date_of_run))
data_save_directory = joinpath(@__DIR__, "output", exp_name, string(date_of_run))

#case:
cases = [
"GP", # diagonalize, train scalar GP, assume diag inputs
"RF-prior",
"RF-vector-svd-nonsep",
]
case_rf = cases[3]

# load
posterior_filepath = joinpath(data_save_directory, "$(case_rf)_posterior.jld2")
if !isfile(posterior_filepath)
throw(ArgumentError(posterior_filepath * " not found. Please check experiment name and date"))
else
println("Loading posterior distribution from: " * posterior_filepath)
posterior = load(posterior_filepath)["posterior"]
end
# get samples explicitly (may be easier to work with)
posterior_samples = vcat([get_distribution(posterior)[name] for name in get_name(posterior)]...) #samples are columns
transformed_posterior_samples =
mapslices(x -> transform_unconstrained_to_constrained(posterior, x), posterior_samples, dims = 1)

# histograms
nparam_plots = sum(get_dimensions(posterior)) - 1
density_filepath = joinpath(figure_save_directory, "$(case_rf)_posterior_dist_comp.png")
transformed_density_filepath = joinpath(figure_save_directory, "$(case_rf)_posterior_dist_phys.png")
labels = get_name(posterior)

burnin = 50_000

data_rf = (; [(Symbol(labels[i]), posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
transformed_data_rf =
(; [(Symbol(labels[i]), transformed_posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)

p = pairplot(data_rf => (PairPlots.Contourf(sigmas = 1:1:3),))
trans_p = pairplot(transformed_data_rf => (PairPlots.Contourf(sigmas = 1:1:3),))

save(density_filepath, p)
save(transformed_density_filepath, trans_p)

#
#
#

case_gp = cases[1]
# load
posterior_filepath = joinpath(data_save_directory, "$(case_gp)_posterior.jld2")
if !isfile(posterior_filepath)
throw(ArgumentError(posterior_filepath * " not found. Please check experiment name and date"))
else
println("Loading posterior distribution from: " * posterior_filepath)
posterior = load(posterior_filepath)["posterior"]
end
# get samples explicitly (may be easier to work with)
posterior_samples = vcat([get_distribution(posterior)[name] for name in get_name(posterior)]...) #samples are columns
transformed_posterior_samples =
mapslices(x -> transform_unconstrained_to_constrained(posterior, x), posterior_samples, dims = 1)

# histograms
nparam_plots = sum(get_dimensions(posterior)) - 1
density_filepath = joinpath(figure_save_directory, "$(case_rf)_$(case_gp)_posterior_dist_comp.png")
transformed_density_filepath = joinpath(figure_save_directory, "$(case_rf)_$(case_gp)_posterior_dist_phys.png")
labels = get_name(posterior)
data_gp = (; [(Symbol(labels[i]), posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
transformed_data_gp =
(; [(Symbol(labels[i]), transformed_posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
#
#
#
gp_smoothing = 1 # >1 = smoothing KDE in plotting

p = pairplot(
data_rf => (PairPlots.Contourf(sigmas = 1:1:3),),
data_gp => (PairPlots.Contourf(sigmas = 1:1:3, bandwidth = gp_smoothing),),
)
trans_p = pairplot(
transformed_data_rf => (PairPlots.Contourf(sigmas = 1:1:3),),
transformed_data_gp => (PairPlots.Contourf(sigmas = 1:1:3, bandwidth = gp_smoothing),),
)

save(density_filepath, p)
save(transformed_density_filepath, trans_p)



# Finally include the prior too
case_prior = cases[2]
# load
posterior_filepath = joinpath(data_save_directory, "posterior.jld2")
posterior_filepath = joinpath(data_save_directory, "$(case_prior)_posterior.jld2")
if !isfile(posterior_filepath)
throw(ArgumentError(posterior_filepath * " not found. Please check experiment name and date"))
else
Expand All @@ -39,15 +128,28 @@ transformed_posterior_samples =

# histograms
nparam_plots = sum(get_dimensions(posterior)) - 1
density_filepath = joinpath(figure_save_directory, "posterior_dist_comp.png")
transformed_density_filepath = joinpath(figure_save_directory, "posterior_dist_phys.png")
density_filepath = joinpath(figure_save_directory, "all_posterior_dist_comp.png")
transformed_density_filepath = joinpath(figure_save_directory, "all_posterior_dist_phys.png")
labels = get_name(posterior)

data = (; [(Symbol(labels[i]), posterior_samples[i, :]) for i in 1:length(labels)]...)
transformed_data = (; [(Symbol(labels[i]), transformed_posterior_samples[i, :]) for i in 1:length(labels)]...)
data_prior = (; [(Symbol(labels[i]), posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
transformed_data_prior =
(; [(Symbol(labels[i]), transformed_posterior_samples[i, burnin:end]) for i in 1:length(labels)]...)
#
#
#

p = pairplot(
data_rf => (PairPlots.Contourf(sigmas = 1:1:3),),
data_gp => (PairPlots.Contourf(sigmas = 1:1:3, bandwidth = gp_smoothing),),
data_prior => (PairPlots.Scatter(),),
)
trans_p = pairplot(
transformed_data_rf => (PairPlots.Contourf(sigmas = 1:1:3),),
transformed_data_gp => (PairPlots.Contourf(sigmas = 1:1:3, bandwidth = gp_smoothing),),
transformed_data_prior => (PairPlots.Scatter(),),
)

p = pairplot(data => (PairPlots.Scatter(),))
trans_p = pairplot(transformed_data => (PairPlots.Scatter(),))
save(density_filepath, p)
save(transformed_density_filepath, trans_p)

Expand Down
58 changes: 33 additions & 25 deletions examples/EDMF_data/uq_for_edmf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ function main()
# 5-parameter calibration exp
exp_name = "ent-det-tked-tkee-stab-calibration"

cases = [
"GP", # diagonalize, train scalar GP, assume diag inputs
"RF-prior",
"RF-vector-svd-nonsep",
]
case = cases[1]

# Output figure save directory
figure_save_directory = joinpath(@__DIR__, "output", exp_name, string(Dates.today()))
Expand Down Expand Up @@ -120,8 +126,7 @@ function main()
println("plotting ensembles...")
for plot_i in 1:size(outputs, 1)
p = scatter(inputs_constrained[1, :], inputs_constrained[2, :], zcolor = outputs[plot_i, :])
savefig(p, joinpath(figure_save_directory, "output_" * string(plot_i) * ".png"))
savefig(p, joinpath(figure_save_directory, "output_" * string(plot_i) * ".pdf"))
savefig(p, joinpath(figure_save_directory, "$(case)_output_" * string(plot_i) * ".png"))
end
println("finished plotting ensembles.")
end
Expand Down Expand Up @@ -200,17 +205,24 @@ function main()
println("Begin Emulation stage")
# Create GP object

cases = [
"GP", # diagonalize, train scalar GP, assume diag inputs
"RF-vector-svd-nonsep",
"RF-vector-nosvd-nonsep", # don't perform decorrelation
]
case = cases[3]
n_repeats = 2

opt_diagnostics = []
emulators = []
for rep_idx in 1:n_repeats
overrides = Dict(
"verbose" => true,
"train_fraction" => 0.85,
"scheduler" => DataMisfitController(terminate_at = 100),
"cov_sample_multiplier" => 1.0,
"n_iteration" => 15,
"n_features_opt" => 200,
"localization" => SEC(0.05),
)
if case == "RF-prior"
overrides = Dict("verbose" => true, "cov_sample_multiplier" => 0.01, "n_iteration" => 0)
end
nugget = 1e-6
rng_seed = 99330
rng = Random.MersenneTwister(rng_seed)
input_dim = size(get_inputs(input_output_pairs), 1)
output_dim = size(get_outputs(input_output_pairs), 1)
if case == "GP"

overrides = Dict(
"verbose" => true,
Expand All @@ -222,13 +234,9 @@ function main()
# "n_ensemble" => 20,
# "localization" => SEC(1.0, 0.01), # localization / sample error correction for small ensembles
)
nugget = 1e-10#1e-12#0.01
rng_seed = 99330
rng = Random.MersenneTwister(rng_seed)
input_dim = size(get_inputs(input_output_pairs), 1)
output_dim = size(get_outputs(input_output_pairs), 1)
decorrelate = true
if case == "GP"
elseif case ["RF-vector-svd-nonsep", "RF-prior"]
kernel_structure = NonseparableKernel(LowRankFactor(1, nugget))
n_features = 500

gppackage = Emulators.SKLJL()
pred_type = Emulators.YType()
Expand Down Expand Up @@ -312,7 +320,7 @@ function main()

end

emulator_filepath = joinpath(data_save_directory, "emulator.jld2")
emulator_filepath = joinpath(data_save_directory, "$(case)_emulator.jld2")
save(emulator_filepath, "emulator", emulator)

println("Finished Emulation stage")
Expand All @@ -328,17 +336,17 @@ function main()
# determine a good step size
yt_sample = y_truth
mcmc = MCMCWrapper(RWMHSampling(), yt_sample, prior, emulator; init_params = u0)
new_step = optimize_stepsize(mcmc; init_stepsize = 0.1, N = 2000, discard_initial = 0)
new_step = optimize_stepsize(mcmc; init_stepsize = 0.1, N = 5000, discard_initial = 0)

# Now begin the actual MCMC
println("Begin MCMC - with step size ", new_step)
chain = MarkovChainMonteCarlo.sample(mcmc, 100_000; stepsize = new_step, discard_initial = 2_000)
chain = MarkovChainMonteCarlo.sample(mcmc, 300_000; stepsize = new_step, discard_initial = 2_000)
posterior = MarkovChainMonteCarlo.get_posterior(mcmc, chain)

mcmc_filepath = joinpath(data_save_directory, "mcmc_and_chain.jld2")
mcmc_filepath = joinpath(data_save_directory, "$(case)_mcmc_and_chain.jld2")
save(mcmc_filepath, "mcmc", mcmc, "chain", chain)

posterior_filepath = joinpath(data_save_directory, "posterior.jld2")
posterior_filepath = joinpath(data_save_directory, "$(case)_posterior.jld2")
save(posterior_filepath, "posterior", posterior)

println("Finished Sampling stage")
Expand Down
13 changes: 13 additions & 0 deletions examples/Emulator/G-function/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
GlobalSensitivityAnalysis = "1b10255b-6da3-57ce-9089-d24e8517b87e"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomFeatures = "36c3bae2-c0c3-419d-b3b4-eebadd35c5e5"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Loading

0 comments on commit f798b27

Please sign in to comment.