Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and extend RF interface for more flexible kernels. #220

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/API/RandomFeatures.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ get_n_features
get_input_dim
get_output_dim
get_rng
get_diagonalize_input
get_kernel_structure
get_feature_decomposition
get_optimizer_options
optimize_hyperparameters!(::ScalarRandomFeatureInterface)
Expand Down
2 changes: 2 additions & 0 deletions examples/EDMF_data/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
23 changes: 13 additions & 10 deletions examples/EDMF_data/plot_posterior.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Import modules
using StatsPlots
ENV["GKSwstype"] = "100"
using Plots

using CairoMakie, PairPlots
using JLD2
using Dates

Expand All @@ -11,11 +14,11 @@ using CalibrateEmulateSample.ParameterDistributions

# 2-parameter calibration exp
exp_name = "ent-det-calibration"
date_of_run = Date(2022, 7, 15)
date_of_run = Date(2023, 10, 5)

# 5-parameter calibration exp
#exp_name = "ent-det-tked-tkee-stab-calibration"
#date_of_run = Date(2022,7,14)
#date_of_run = Date(2023,10,4)

# Output figure read/write directory
figure_save_directory = joinpath(@__DIR__, "output", exp_name, string(date_of_run))
Expand All @@ -24,7 +27,7 @@ data_save_directory = joinpath(@__DIR__, "output", exp_name, string(date_of_run)
# load
posterior_filepath = joinpath(data_save_directory, "posterior.jld2")
if !isfile(posterior_filepath)
LoadError(posterior_filepath * " not found. Please check experiment name and date")
throw(ArgumentError(posterior_filepath * " not found. Please check experiment name and date"))
else
println("Loading posterior distribution from: " * posterior_filepath)
posterior = load(posterior_filepath)["posterior"]
Expand All @@ -40,10 +43,10 @@ density_filepath = joinpath(figure_save_directory, "posterior_dist_comp.png")
transformed_density_filepath = joinpath(figure_save_directory, "posterior_dist_phys.png")
labels = get_name(posterior)

gr(dpi = 300, size = (nparam_plots * 300, nparam_plots * 300))
# lower triangular marginal cornerplot
p = cornerplot(permutedims(posterior_samples, (2, 1)), label = labels, compact = true)
trans_p = cornerplot(permutedims(transformed_posterior_samples, (2, 1)), label = labels, compact = true)
data = (; [(Symbol(labels[i]), posterior_samples[i, :]) for i in 1:length(labels)]...)
transformed_data = (; [(Symbol(labels[i]), transformed_posterior_samples[i, :]) for i in 1:length(labels)]...)

savefig(p, density_filepath)
savefig(trans_p, transformed_density_filepath)
p = pairplot(data => (PairPlots.Scatter(),))
trans_p = pairplot(transformed_data => (PairPlots.Scatter(),))
save(density_filepath, p)
save(transformed_density_filepath, trans_p)
127 changes: 117 additions & 10 deletions examples/EDMF_data/uq_for_edmf.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#include(joinpath(@__DIR__, "..", "ci", "linkfig.jl"))
PLOT_FLAG = true
PLOT_FLAG = false

# Import modules
using Distributions # probability distributions and associated functions
using LinearAlgebra
ENV["GKSwstype"] = "100"
using Plots
using Random
using JLD2
Expand All @@ -17,6 +18,7 @@ using CalibrateEmulateSample.MarkovChainMonteCarlo
using CalibrateEmulateSample.ParameterDistributions
using CalibrateEmulateSample.DataContainers
using CalibrateEmulateSample.EnsembleKalmanProcesses
using CalibrateEmulateSample.EnsembleKalmanProcesses.Localizers
using CalibrateEmulateSample.Utilities

rng_seed = 42424242
Expand Down Expand Up @@ -126,13 +128,55 @@ function main()
end

# load and create prior distributions
#=
prior_filepath = joinpath(exp_dir, "prior.jld2")
if !isfile(prior_filepath)
LoadError("prior file \"prior.jld2\" not found in directory \"" * exp_dir * "/\"")
else
prior_dict_raw = load(prior_filepath) #using JLD2
prior = prior_dict_raw["prior"]
end
=#
# build prior if jld2 does not work
function get_prior_config(s::SS) where {SS <: AbstractString}
config = Dict()
if s == "ent-det-calibration"
config["constraints"] =
Dict("entrainment_factor" => [bounded(0.0, 1.0)], "detrainment_factor" => [bounded(0.0, 1.0)])
config["prior_mean"] = Dict("entrainment_factor" => 0.13, "detrainment_factor" => 0.51)
config["unconstrained_σ"] = 1.0
elseif s == "ent-det-tked-tkee-stab-calibration"
config["constraints"] = Dict(
"entrainment_factor" => [bounded(0.0, 1.0)],
"detrainment_factor" => [bounded(0.0, 1.0)],
"tke_ed_coeff" => [bounded(0.01, 1.0)],
"tke_diss_coeff" => [bounded(0.01, 1.0)],
"static_stab_coeff" => [bounded(0.01, 1.0)],
)
config["prior_mean"] = Dict(
"entrainment_factor" => 0.13,
"detrainment_factor" => 0.51,
"tke_ed_coeff" => 0.14,
"tke_diss_coeff" => 0.22,
"static_stab_coeff" => 0.4,
)
config["unconstrained_σ"] = 1.0
else
throw(ArgumentError("prior for experiment $s not found, please implement in uq_for_edmf.jl"))
end
return config
end
prior_config = get_prior_config(exp_name)
means = prior_config["prior_mean"]
std = prior_config["unconstrained_σ"]
constraints = prior_config["constraints"]


prior = combine_distributions([
ParameterDistribution(
Dict("name" => name, "distribution" => Parameterized(Normal(mean, std)), "constraint" => constraints[name]),
) for (name, mean) in means
])

# Option (ii) load EKP object
# max_ekp_it = 10 # use highest available iteration file
Expand All @@ -145,27 +189,90 @@ function main()
# end
# input_output_pairs = Utilities.get_training_points(ekpobj, max_ekp_it)

println("Completed calibration stage")
println("Completed calibration loading stage")
println(" ")
##############################################
# [3. ] Build Emulator from calibration data #
##############################################
println("Begin Emulation stage")
# Create GP object

gppackage = Emulators.SKLJL()
pred_type = Emulators.YType()
gauss_proc = GaussianProcess(
gppackage;
kernel = nothing, # use default squared exponential kernel
prediction_type = pred_type,
noise_learn = false,
cases = [
"GP", # diagonalize, train scalar GP, assume diag inputs
"RF-scalar", # diagonalize, train scalar RF, don't asume diag inputs
"RF-vector-svd-diag",
"RF-vector-svd-nondiag",
"RF-vector-svd-nonsep",
]
case = cases[5]

overrides = Dict(
"verbose" => true,
"train_fraction" => 0.95,
"scheduler" => DataMisfitController(terminate_at = 100),
"cov_sample_multiplier" => 0.5,
"n_iteration" => 5,
# "n_ensemble" => 20,
# "localization" => SEC(0.1), # localization / sample error correction for small ensembles
)
nugget = 0.01
rng_seed = 99330
rng = Random.MersenneTwister(rng_seed)
input_dim = size(get_inputs(input_output_pairs), 1)
output_dim = size(get_outputs(input_output_pairs), 1)
if case == "GP"

gppackage = Emulators.SKLJL()
pred_type = Emulators.YType()
mlt = GaussianProcess(
gppackage;
kernel = nothing, # use default squared exponential kernel
prediction_type = pred_type,
noise_learn = false,
)
elseif case ["RF-scalar"]
n_features = 100
kernel_structure = SeparableKernel(CholeskyFactor(nugget), OneDimFactor())
mlt = ScalarRandomFeatureInterface(
n_features,
input_dim,
rng = rng,
kernel_structure = kernel_structure,
optimizer_options = overrides,
)
elseif case ["RF-vector-svd-diag", "RF-vector-svd-nondiag"]
# do we want to assume that the outputs are decorrelated in the machine-learning problem?
kernel_structure =
case ["RF-vector-svd-diag"] ? SeparableKernel(LowRankFactor(1, nugget), DiagonalFactor(nugget)) :
SeparableKernel(LowRankFactor(2, nugget), LowRankFactor(2, nugget))
n_features = 500

mlt = VectorRandomFeatureInterface(
n_features,
input_dim,
output_dim,
rng = rng,
kernel_structure = kernel_structure,
optimizer_options = overrides,
)
elseif case ["RF-vector-svd-nonsep"]
kernel_structure = NonseparableKernel(LowRankFactor(3, nugget))
n_features = 500

mlt = VectorRandomFeatureInterface(
n_features,
input_dim,
output_dim,
rng = rng,
kernel_structure = kernel_structure,
optimizer_options = overrides,
)
end

# Fit an emulator to the data
normalized = true

emulator = Emulator(gauss_proc, input_output_pairs; obs_noise_cov = truth_cov, normalize_inputs = normalized)
emulator = Emulator(mlt, input_output_pairs; obs_noise_cov = truth_cov, normalize_inputs = normalized)

# Optimize the GP hyperparameters for better fit
optimize_hyperparameters!(emulator)
Expand Down
3 changes: 2 additions & 1 deletion examples/Emulator/GaussianProcess/plot_GP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using CalibrateEmulateSample.DataContainers

plot_flag = true
if plot_flag
ENV["GKSwstype"] = "100"
using Plots
gr(size = (1500, 700))
Plots.scalefontsizes(1.3)
Expand Down Expand Up @@ -78,7 +79,7 @@ gaussian_process = GaussianProcess(gppackage, noise_learn = false)
# The observables y are related to the parameters x by:
# y = G(x1, x2) + η,
# where G(x1, x2) := [sin(x1) + cos(x2), sin(x1) - cos(x2)], and η ~ N(0, Σ)
n = 100 # number of training points
n = 150 # number of training points
p = 2 # input dim
d = 2 # output dim

Expand Down
9 changes: 9 additions & 0 deletions examples/Emulator/L63/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Loading
Loading