From 6af9c78df9f80cc1af9a9fde1be016d69d863efe Mon Sep 17 00:00:00 2001 From: mastoffel Date: Wed, 16 Oct 2024 13:07:43 +0100 Subject: [PATCH] add plotting --- autoemulate/sensitivity_analysis.py | 69 +++++++++++++++++++---------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/autoemulate/sensitivity_analysis.py b/autoemulate/sensitivity_analysis.py index 02c0e66a..08db4de8 100644 --- a/autoemulate/sensitivity_analysis.py +++ b/autoemulate/sensitivity_analysis.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +import plotnine as p9 from SALib.analyze.sobol import analyze from SALib.sample.sobol import sample @@ -41,7 +42,7 @@ def sobol_analysis(model, problem, N=1024): if Y.ndim == 1: Y = Y.reshape(-1, 1) num_outputs = Y.shape[1] - output_names = [f"y{i}" for i in range(num_outputs)] + output_names = [f"y{i+1}" for i in range(num_outputs)] results = {} for i in range(num_outputs): @@ -63,6 +64,7 @@ def sobol_results_to_df(results): Returns: -------- pd.DataFrame + A DataFrame with columns: 'output', 'parameter', 'index', 'value', 'confidence'. """ rows = [] for output, indices in results.items(): @@ -102,27 +104,48 @@ def sobol_results_to_df(results): return pd.DataFrame(rows) -# def plot_sensitivity_indices(Si, problem): -# """ -# Plot the Sobol sensitivity indices. - -# Parameters: -# ----------- -# Si : dict -# The Sobol indices returned by perform_sobol_analysis. -# problem : dict -# The problem definition used in the analysis. -# """ -# import matplotlib.pyplot as plt - -# fig, ax = plt.subplots(figsize=(10, 6)) +def plot_sensitivity_analysis(results, type="bar"): + """ + Plot the sensitivity analysis results. -# indices = Si["S1"] -# names = problem["names"] + Parameters: + ----------- + results : pd.DataFrame + The results from sobol_results_to_df. + """ -# ax.bar(names, indices) -# ax.set_ylabel("First-order Sobol index") -# ax.set_title("Sensitivity Analysis Results") -# plt.xticks(rotation=45) -# plt.tight_layout() -# plt.show() + if not isinstance(results, pd.DataFrame): + results = sobol_results_to_df(results) + + if type == "bar": + # filter S1 and ST + results = results[results["index"].isin(["S1", "ST"])] + + p = ( + p9.ggplot(results, p9.aes(x="parameter", y="value", fill="index")) + + p9.geom_bar(stat="identity", position="dodge") + + p9.facet_wrap("~output") + + p9.theme_538() + + p9.scale_fill_manual(values=["#5386E4", "#4C4B63"]) + + p9.labs(y="Sobol Index") + + p9.geom_errorbar( + p9.aes(ymin="value-confidence/2", ymax="value+confidence/2"), + position=p9.position_dodge(width=0.9), + width=0.25, + ) + + p9.ggtitle( + "Sensitivity Analysis: First-Order (S1) and \n Total-Order (ST) Indices and 95% CI" + ) + + p9.theme(plot_title=p9.element_text(hjust=0.5)) # Center the title + ) + elif type == "heatmap": + results = results[results["index"].isin(["S2"])] + results[["param1", "param2"]] = results["parameter"].str.split("-", expand=True) + p = ( + p9.ggplot(results, p9.aes("param1", "param2", fill="value")) + + p9.geom_tile() + + p9.scale_fill_gradient(low="#33658A", high="#9B1D20", limits=(0, 1)) + + p9.facet_wrap("~output") + + p9.theme_classic() + ) + return p