Skip to content

Commit

Permalink
add plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
mastoffel committed Oct 16, 2024
1 parent 2783343 commit 6af9c78
Showing 1 changed file with 46 additions and 23 deletions.
69 changes: 46 additions & 23 deletions autoemulate/sensitivity_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
import plotnine as p9
from SALib.analyze.sobol import analyze
from SALib.sample.sobol import sample

Expand Down Expand Up @@ -41,7 +42,7 @@ def sobol_analysis(model, problem, N=1024):
if Y.ndim == 1:
Y = Y.reshape(-1, 1)
num_outputs = Y.shape[1]
output_names = [f"y{i}" for i in range(num_outputs)]
output_names = [f"y{i+1}" for i in range(num_outputs)]

results = {}
for i in range(num_outputs):
Expand All @@ -63,6 +64,7 @@ def sobol_results_to_df(results):
Returns:
--------
pd.DataFrame
A DataFrame with columns: 'output', 'parameter', 'index', 'value', 'confidence'.
"""
rows = []
for output, indices in results.items():
Expand Down Expand Up @@ -102,27 +104,48 @@ def sobol_results_to_df(results):
return pd.DataFrame(rows)


# def plot_sensitivity_indices(Si, problem):
# """
# Plot the Sobol sensitivity indices.

# Parameters:
# -----------
# Si : dict
# The Sobol indices returned by perform_sobol_analysis.
# problem : dict
# The problem definition used in the analysis.
# """
# import matplotlib.pyplot as plt

# fig, ax = plt.subplots(figsize=(10, 6))
def plot_sensitivity_analysis(results, type="bar"):
"""
Plot the sensitivity analysis results.
# indices = Si["S1"]
# names = problem["names"]
Parameters:
-----------
results : pd.DataFrame
The results from sobol_results_to_df.
"""

# ax.bar(names, indices)
# ax.set_ylabel("First-order Sobol index")
# ax.set_title("Sensitivity Analysis Results")
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.show()
if not isinstance(results, pd.DataFrame):
results = sobol_results_to_df(results)

if type == "bar":
# filter S1 and ST
results = results[results["index"].isin(["S1", "ST"])]

p = (
p9.ggplot(results, p9.aes(x="parameter", y="value", fill="index"))
+ p9.geom_bar(stat="identity", position="dodge")
+ p9.facet_wrap("~output")
+ p9.theme_538()
+ p9.scale_fill_manual(values=["#5386E4", "#4C4B63"])
+ p9.labs(y="Sobol Index")
+ p9.geom_errorbar(
p9.aes(ymin="value-confidence/2", ymax="value+confidence/2"),
position=p9.position_dodge(width=0.9),
width=0.25,
)
+ p9.ggtitle(
"Sensitivity Analysis: First-Order (S1) and \n Total-Order (ST) Indices and 95% CI"
)
+ p9.theme(plot_title=p9.element_text(hjust=0.5)) # Center the title
)
elif type == "heatmap":
results = results[results["index"].isin(["S2"])]
results[["param1", "param2"]] = results["parameter"].str.split("-", expand=True)
p = (
p9.ggplot(results, p9.aes("param1", "param2", fill="value"))
+ p9.geom_tile()
+ p9.scale_fill_gradient(low="#33658A", high="#9B1D20", limits=(0, 1))
+ p9.facet_wrap("~output")
+ p9.theme_classic()
)
return p

0 comments on commit 6af9c78

Please sign in to comment.