-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
238: add ishigami example & compatible with new EKP r=odunbar a=odunbar <!--- THESE LINES ARE COMMENTED --> ## Purpose <!--- One sentence to describe the purpose of this PR, refer to any linked issues: #14 -- this will link to issue 14 Closes #2 -- this will automatically close issue 2 on PR merge --> - Adds a new example that explores off the training points. Also apparently this is a challenging emulation problem as the GP emulator training fails with a default kernel and internal training methods. - Example also has an analytic form of the "sensitivity" to its parameters; only by learning the right function on the domain can one recover the right sensitivities. ## Content <!--- specific tasks that are currently complete - Solution implemented --> - The example in `examples/Emulator/Ishigami` - Also added optimizer options "accelerator" and "n_features_opt" removing the complex defaulting from before - edits to buildkite to stop it taking an old EKP version <!--- Review checklist I have: - followed the codebase contribution guide: https://clima.github.io/ClimateMachine.jl/latest/Contributing/ - followed the style guide: https://clima.github.io/ClimateMachine.jl/latest/DevDocs/CodeStyle/ - followed the documentation policy: https://github.com/CliMA/policies/wiki/Documentation-Policy - checked that this PR does not duplicate an open PR. In the Content, I have included - relevant unit tests, and integration tests, - appropriate docstrings on all functions, structs, and modules, and included relevant documentation. --> ---- - [ ] I have read and checked the items on the review checklist. Co-authored-by: odunbar <[email protected]>
- Loading branch information
Showing
7 changed files
with
275 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
[deps] | ||
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" | ||
CalibrateEmulateSample = "95e48a1f-0bec-4818-9538-3db4340308e3" | ||
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" | ||
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
GlobalSensitivityAnalysis = "1b10255b-6da3-57ce-9089-d24e8517b87e" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
|
||
using GlobalSensitivityAnalysis | ||
const GSA = GlobalSensitivityAnalysis | ||
using Distributions | ||
using DataStructures | ||
using Random | ||
using LinearAlgebra | ||
|
||
using CalibrateEmulateSample.EnsembleKalmanProcesses | ||
using CalibrateEmulateSample.Emulators | ||
using CalibrateEmulateSample.DataContainers | ||
|
||
using CairoMakie, ColorSchemes #for plots | ||
# Seed for the random number generator created in `main()`.
# Declared `const` so that reading it inside `main()` is not an untyped global access.
const seed = 2589456
#=
Reference implementation of the Ishigami function (not used below; `GSA.ishigami` is called instead).
Takes the parameters `x` as a [3 x pts] matrix.
Classical values (a, b) = (7, 0.05) are from Sobol & Levitan (1999);
(a, b) = (7, 0.1) are from Marrel et al. (2009).

function ishigami(x::MM; a = 7.0, b = 0.05) where {MM <: AbstractMatrix}
    @assert size(x, 1) == 3
    return (1 .+ b * x[3, :] .^ 4) .* sin.(x[1, :]) + a * sin.(x[2, :]) .^ 2
end
=#
# Analytic first- and total-order Sobol indices of the Ishigami function with
# coefficients `a` and `b` and inputs uniform on (-π, π).
# Returns `(firstorder, totalorder)`, each a 3-element vector ordered (x1, x2, x3).
function _analytic_ishigami_sobol(a, b)
    V = a^2 / 8 + b * π^4 / 5 + b^2 * π^8 / 18 + 1 / 2 # total variance
    V1 = 0.5 * (1 + b * π^4 / 5)^2
    V2 = a^2 / 8
    V3 = 0
    V13 = 8 * b^2 * π^8 / 225 # contribution shared by the total-order terms of x1 and x3
    firstorder = [V1 / V, V2 / V, V3 / V]
    totalorder = [(V1 + V13) / V, V2 / V, V13 / V]
    return firstorder, totalorder
end

"""
    main()

Run the Ishigami emulation experiment.

Steps:
1. Generate Sobol samples on (-π, π)^3, evaluate `GSA.ishigami` on them and compute the
   sampled Sobol indices of the truth.
2. Subsample `n_train_pts` noisy input-output pairs as training data.
3. For each of `n_repeats` repetitions, build and train an emulator for the selected
   `case`, predict on all Sobol samples, and compute Sobol indices of the prediction.
4. Print analytic, sampled-truth and emulated Sobol indices, and save slice plots to
   `ishigami_slices_truth.{png,pdf}` and `ishigami_slices_<case>.{png,pdf}` in the
   current working directory.

Reads the file-level `seed` to initialise the random number generator.
"""
function main()

    rng = MersenneTwister(seed)

    n_repeats = 20 # repeat exp with same data.

    # To create the sampling
    n_data_gen = 2000

    data = SobolData(
        params = OrderedDict(:x1 => Uniform(-π, π), :x2 => Uniform(-π, π), :x3 => Uniform(-π, π)),
        N = n_data_gen,
    )

    # To perform global analysis,
    # one must generate samples using Sobol sequence (i.e. creates more than N points)
    samples = GSA.sample(data)
    n_data = size(samples, 1) # [n_samples x 3]
    # run model (example)
    y = GSA.ishigami(samples)
    # perform Sobol Analysis
    result = analyze(data, y)

    f1 = Figure(resolution = (1.618 * 900, 300), markersize = 4)
    axx = Axis(f1[1, 1], xlabel = "x1", ylabel = "f")
    axy = Axis(f1[1, 2], xlabel = "x2", ylabel = "f")
    axz = Axis(f1[1, 3], xlabel = "x3", ylabel = "f")

    scatter!(axx, samples[:, 1], y[:], color = :orange)
    scatter!(axy, samples[:, 2], y[:], color = :orange)
    scatter!(axz, samples[:, 3], y[:], color = :orange)

    save("ishigami_slices_truth.png", f1, px_per_unit = 3)
    save("ishigami_slices_truth.pdf", f1, px_per_unit = 3)

    # now subsample the samples data
    n_train_pts = 300
    ind = shuffle!(rng, Vector(1:n_data))[1:n_train_pts]

    # Γ is the observational noise *variance*: it is passed below as `obs_noise_cov = Γ * I`
    # and reported as "noise var". `Normal` takes a standard deviation, hence the `sqrt`.
    Γ = 1e-2
    noise = rand(rng, Normal(0, sqrt(Γ)), n_train_pts)

    # training pairs stored column-wise: inputs [3 x n_train_pts], outputs [1 x n_train_pts]
    input = zeros(3, n_train_pts)
    output = zeros(1, n_train_pts)
    input .= permutedims(samples[ind, :])
    output .= permutedims(y[ind] .+ noise)
    iopairs = PairedDataContainer(input, output)

    cases = ["Prior", "GP", "RF-scalar"]
    case = cases[3]
    decorrelate = true
    nugget = Float64(1e-12)

    overrides = Dict(
        "scheduler" => DataMisfitController(terminate_at = 1e4),
        "cov_sample_multiplier" => 1.0,
        "n_features_opt" => 100,
        "n_iteration" => 10,
    )
    if case == "Prior"
        # don't do anything
        overrides["n_iteration"] = 0
        overrides["cov_sample_multiplier"] = 0.1
    end

    y_preds = []
    result_preds = []

    for rep_idx in 1:n_repeats

        # Build ML tools
        if case == "GP"
            gppackage = Emulators.GPJL()
            pred_type = Emulators.YType()
            mlt = GaussianProcess(
                gppackage;
                kernel = nothing, # use default squared exponential kernel
                prediction_type = pred_type,
                noise_learn = false,
            )

        elseif case ∈ ["RF-scalar", "Prior"]

            kernel_structure = SeparableKernel(LowRankFactor(3, nugget), OneDimFactor())
            n_features = 500
            mlt = ScalarRandomFeatureInterface(
                n_features,
                3,
                rng = rng,
                kernel_structure = kernel_structure,
                optimizer_options = overrides,
            )

        else
            # fail early instead of hitting an undefined `mlt` further down
            throw(ArgumentError("unknown case \"$case\", expected one of $cases"))
        end

        # Emulate
        emulator = Emulator(mlt, iopairs; obs_noise_cov = Γ * I, decorrelate = decorrelate)
        optimize_hyperparameters!(emulator)

        # predict on all Sobol points with emulator (example)
        y_pred, y_var = predict(emulator, samples', transform_to_real = true)

        # obtain emulated Sobol indices
        result_pred = analyze(data, y_pred')
        push!(y_preds, y_pred)
        push!(result_preds, result_pred)

    end

    # analytic sobol indices
    # NOTE(review): (a, b) = (7, 0.1) is assumed to match the coefficients used by
    # `GSA.ishigami` — confirm against the GlobalSensitivityAnalysis source.
    firstorder_true, totalorder_true = _analytic_ishigami_sobol(7, 0.1)

    println(" ")
    println("True Sobol Indices")
    println("******************")
    println("    firstorder: ", firstorder_true)
    println("    totalorder: ", totalorder_true)
    println(" ")
    println("Sampled truth Sobol Indices (# points $n_data)")
    println("***************************")
    println("    firstorder: ", result[:firstorder])
    println("    totalorder: ", result[:totalorder])
    println(" ")

    println("Sampled Emulated Sobol Indices (# obs $n_train_pts, noise var $Γ)")
    println("***************************************************************")
    if n_repeats == 1
        println("    firstorder: ", result_preds[1][:firstorder])
        println("    totalorder: ", result_preds[1][:totalorder])
    else
        # element-wise statistics of the index vectors across repetitions
        firstorders = [rp[:firstorder] for rp in result_preds]
        totalorders = [rp[:totalorder] for rp in result_preds]

        println("(mean) firstorder: ", mean(firstorders))
        println("(std)  firstorder: ", std(firstorders))
        println("(mean) totalorder: ", mean(totalorders))
        println("(std)  totalorder: ", std(totalorders))
    end

    # plots: prediction of the first repetition (blue) with the noisy training data (red)
    f2 = Figure(resolution = (1.618 * 900, 300), markersize = 4)
    axx_em = Axis(f2[1, 1], xlabel = "x1", ylabel = "f")
    axy_em = Axis(f2[1, 2], xlabel = "x2", ylabel = "f")
    axz_em = Axis(f2[1, 3], xlabel = "x3", ylabel = "f")
    scatter!(axx_em, samples[:, 1], y_preds[1][:], color = :blue)
    scatter!(axy_em, samples[:, 2], y_preds[1][:], color = :blue)
    scatter!(axz_em, samples[:, 3], y_preds[1][:], color = :blue)
    scatter!(axx_em, samples[ind, 1], y[ind] + noise, color = :red, markersize = 8)
    scatter!(axy_em, samples[ind, 2], y[ind] + noise, color = :red, markersize = 8)
    scatter!(axz_em, samples[ind, 3], y[ind] + noise, color = :red, markersize = 8)

    save("ishigami_slices_$(case).png", f2, px_per_unit = 3)
    save("ishigami_slices_$(case).pdf", f2, px_per_unit = 3)

end

main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.