diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e9eb196d..433182ab 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,6 +25,12 @@ jobs: with: python-version: ${{ matrix.python-version }} + # Clean python cache files + - name: Clean cache files + run: | + find . -type f -name "*.pyc" -delete + find . -type d -name "__pycache__" -delete + rm -f .coverage* coverage.xml # Cache Poetry dependencies # - name: Cache dependencies # uses: actions/cache@v2 @@ -56,10 +62,15 @@ jobs: run: | poetry run python -c "import torch; print(torch.__version__); print('CUDA available:', torch.cuda.is_available())" + - name: Debug Coverage Config + run: | + cat pyproject.toml + poetry run coverage debug config + - name: Run Tests with Coverage run: | - poetry run coverage run -m pytest - poetry run coverage xml -o coverage.xml + poetry run coverage run --source=. -m pytest + poetry run coverage xml -i -o coverage.xml env: COVERAGE_FILE: ".coverage.${{ matrix.python-version }}" diff --git a/.gitignore b/.gitignore index 87b71f6b..25755c20 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,8 @@ Thumbs.db *.DS_Store # VScode settings -.vscode/ \ No newline at end of file +.vscode/ + +# Quarto +README.html +README_files/ \ No newline at end of file diff --git a/README.md b/README.md index e9e6e192..37d10f89 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ There's currently a lot of development, so we recommend installing the most curr pip install git+https://github.com/alan-turing-institute/autoemulate.git ``` -There's also a release available on PyPI (will not contain the most recent features and models) +There's also a release available on PyPI (note: currently an older version that is out of date with the documentation) ```bash pip install autoemulate ``` @@ -47,19 +47,27 @@ from autoemulate.simulations.projectile import simulate_projectile lhd = LatinHypercube([(-5., 1.), (0., 1000.)]) X = lhd.sample(100) y = 
np.array([simulate_projectile(x) for x in X]) + # compare emulator models ae = AutoEmulate() ae.setup(X, y) -best_model = ae.compare() +best_emulator = ae.compare() + # training set cross-validation results ae.summarise_cv() ae.plot_cv() + # test set results for the best model -ae.evaluate(best_model) -ae.plot_eval(best_model) +ae.evaluate(best_emulator) +ae.plot_eval(best_emulator) + # refit on full data and emulate! -best_model = ae.refit(best_model) -best_model.predict(X) +emulator = ae.refit(best_emulator) +emulator.predict(X) + +# global sensitivity analysis +si = ae.sensitivity_analysis(emulator) +ae.plot_sensitivity_analysis(si) ``` ## documentation diff --git a/autoemulate/compare.py b/autoemulate/compare.py index 7307ef03..f4b596e9 100644 --- a/autoemulate/compare.py +++ b/autoemulate/compare.py @@ -20,6 +20,9 @@ from autoemulate.plotting import _plot_model from autoemulate.printing import _print_setup from autoemulate.save import ModelSerialiser +from autoemulate.sensitivity_analysis import plot_sensitivity_analysis +from autoemulate.sensitivity_analysis import sensitivity_analysis +from autoemulate.utils import _ensure_2d from autoemulate.utils import _get_full_model_name from autoemulate.utils import _redirect_warnings from autoemulate.utils import get_model_name @@ -522,3 +525,64 @@ def plot_eval( ) return fig + + def sensitivity_analysis( + self, model=None, problem=None, N=1024, conf_level=0.95, as_df=True + ): + """Perform Sobol sensitivity analysis on a fitted emulator. + + Parameters + ---------- + model : object, optional + Fitted model. If None, uses the best model from cross-validation. + problem : dict, optional + The problem definition, including 'num_vars', 'names', and 'bounds', optional 'output_names'. + If None, the problem is generated from X using minimum and maximum values of the features as bounds. 
+ + Example: + ```python + problem = { + "num_vars": 2, + "names": ["x1", "x2"], + "bounds": [[0, 1], [0, 1]], + } + ``` + N : int, optional + Number of samples to generate. Default is 1024. + conf_level : float, optional + Confidence level for the confidence intervals. Default is 0.95. + as_df : bool, optional + If True, return a long-format pandas DataFrame (default is True). + """ + if model is None: + if not hasattr(self, "best_model"): + raise RuntimeError("Must run compare() before sensitivity_analysis()") + model = self.refit(self.best_model) + self.logger.info( + f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score, refitted on full data." + ) + + Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df) + return Si + + def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None): + """ + Plot the sensitivity analysis results. + + Parameters: + ----------- + results : pd.DataFrame + The results from sobol_results_to_df. + index : str, default "S1" + The type of sensitivity index to plot. + - "S1": first-order indices + - "S2": second-order/interaction indices + - "ST": total-order indices + n_cols : int, optional + The number of columns in the plot. Defaults to 3 if there are 3 or more outputs, + otherwise the number of outputs. + figsize : tuple, optional + Figure size as (width, height) in inches. If None, automatically calculated. 
+ + """ + return plot_sensitivity_analysis(results, index, n_cols, figsize) diff --git a/autoemulate/sensitivity_analysis.py b/autoemulate/sensitivity_analysis.py new file mode 100644 index 00000000..846d0e31 --- /dev/null +++ b/autoemulate/sensitivity_analysis.py @@ -0,0 +1,304 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from SALib.analyze.sobol import analyze +from SALib.sample.sobol import sample + +from autoemulate.utils import _ensure_2d + + +def sensitivity_analysis( + model, problem=None, X=None, N=1024, conf_level=0.95, as_df=True +): + """Perform Sobol sensitivity analysis on a fitted emulator. + + Parameters: + ----------- + model : fitted emulator model + The emulator model to analyze. + problem : dict + The problem definition, including 'num_vars', 'names', and 'bounds', optional 'output_names'. + Example: + ```python + problem = { + "num_vars": 2, + "names": ["x1", "x2"], + "bounds": [[0, 1], [0, 1]], + } + ``` + N : int, optional + The number of samples to generate (default is 1024). + conf_level : float, optional + The confidence level for the confidence intervals (default is 0.95). + as_df : bool, optional + If True, return a pandas DataFrame (default is True). + + Returns: + -------- + pd.DataFrame or dict + If as_df is True, returns a long-format DataFrame with the sensitivity indices. + Otherwise, returns a dictionary where each key is the name of an output variable and each value is a dictionary + containing the Sobol indices keys ‘S1’, ‘S1_conf’, ‘ST’, and ‘ST_conf’, where each entry + is a list of length corresponding to the number of parameters. + """ + Si = sobol_analysis(model, problem, X, N, conf_level) + + if as_df: + return sobol_results_to_df(Si) + else: + return Si + + +def _check_problem(problem): + """ + Check that the problem definition is valid. 
+ """ + if not isinstance(problem, dict): + raise ValueError("problem must be a dictionary.") + + if "num_vars" not in problem: + raise ValueError("problem must contain 'num_vars'.") + if "names" not in problem: + raise ValueError("problem must contain 'names'.") + if "bounds" not in problem: + raise ValueError("problem must contain 'bounds'.") + + if len(problem["names"]) != problem["num_vars"]: + raise ValueError("Length of 'names' must match 'num_vars'.") + if len(problem["bounds"]) != problem["num_vars"]: + raise ValueError("Length of 'bounds' must match 'num_vars'.") + + return problem + + +def _get_output_names(problem, num_outputs): + """ + Get the output names from the problem definition or generate default names. + """ + # check if output_names is given + if "output_names" not in problem: + output_names = [f"y{i+1}" for i in range(num_outputs)] + else: + if isinstance(problem["output_names"], list): + output_names = problem["output_names"] + else: + raise ValueError("'output_names' must be a list of strings.") + + return output_names + + +def _generate_problem(X): + """ + Generate a problem definition from a design matrix. + """ + if X.ndim == 1: + raise ValueError("X must be a 2D array.") + + return { + "num_vars": X.shape[1], + "names": [f"x{i+1}" for i in range(X.shape[1])], + "bounds": [[X[:, i].min(), X[:, i].max()] for i in range(X.shape[1])], + } + + +def sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95): + """ + Perform Sobol sensitivity analysis on a fitted emulator. + + Parameters: + ----------- + model : fitted emulator model + The emulator model to analyze. + problem : dict + The problem definition, including 'num_vars', 'names', and 'bounds'. + N : int, optional + The number of samples to generate (default is 1024). 
+ + Returns: + -------- + dict + A dictionary where each key is the name of an output variable and each value is a dictionary + containing the Sobol indices keys ‘S1’, ‘S1_conf’, ‘ST’, and ‘ST_conf’, where each entry + is a list of length corresponding to the number of parameters. + """ + # get problem + if problem is not None: + problem = _check_problem(problem) + elif X is not None: + problem = _generate_problem(X) + else: + raise ValueError("Either problem or X must be provided.") + + # saltelli sampling + param_values = sample(problem, N) + + # evaluate + Y = model.predict(param_values) + Y = _ensure_2d(Y) + + num_outputs = Y.shape[1] + output_names = _get_output_names(problem, num_outputs) + + # single or multiple output sobol analysis + results = {} + for i in range(num_outputs): + Si = analyze(problem, Y[:, i], conf_level=conf_level) + results[output_names[i]] = Si + + return results + + +def sobol_results_to_df(results): + """ + Convert Sobol results to a (long-format) pandas DataFrame. + + Parameters: + ----------- + results : dict + The Sobol indices returned by sobol_analysis. + + Returns: + -------- + pd.DataFrame + A DataFrame with columns: 'output', 'parameter', 'index', 'value', 'confidence'. 
+ """ + rows = [] + for output, indices in results.items(): + for index_type in ["S1", "ST", "S2"]: + values = indices.get(index_type) + conf_values = indices.get(f"{index_type}_conf") + if values is None or conf_values is None: + continue + + if index_type in ["S1", "ST"]: + rows.extend( + { + "output": output, + "parameter": f"X{i+1}", + "index": index_type, + "value": value, + "confidence": conf, + } + for i, (value, conf) in enumerate(zip(values, conf_values)) + ) + + elif index_type == "S2": + n = values.shape[0] + rows.extend( + { + "output": output, + "parameter": f"X{i+1}-X{j+1}", + "index": index_type, + "value": values[i, j], + "confidence": conf_values[i, j], + } + for i in range(n) + for j in range(i + 1, n) + if not np.isnan(values[i, j]) + ) + + return pd.DataFrame(rows) + + +# plotting -------------------------------------------------------------------- + + +def _validate_input(results, index): + if not isinstance(results, pd.DataFrame): + results = sobol_results_to_df(results) + # we only want to plot one index type at a time + valid_indices = ["S1", "S2", "ST"] + if index not in valid_indices: + raise ValueError( + f"Invalid index type: {index}. Must be one of {valid_indices}." 
+ ) + return results[results["index"].isin([index])] + + +def _calculate_layout(n_outputs, n_cols=None): + if n_cols is None: + n_cols = 3 if n_outputs >= 3 else n_outputs + n_rows = int(np.ceil(n_outputs / n_cols)) + return n_rows, n_cols + + +def _create_bar_plot(ax, output_data, output_name): + """Create a bar plot for a single output.""" + bar_color = "#4C4B63" + x_pos = np.arange(len(output_data)) + + bars = ax.bar( + x_pos, + output_data["value"], + color=bar_color, + yerr=output_data["confidence"].values / 2, + capsize=3, + ) + + ax.set_xticks(x_pos) + ax.set_xticklabels(output_data["parameter"], rotation=45, ha="right") + ax.set_ylabel("Sobol Index") + ax.set_title(f"Output: {output_name}") + + +def plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None): + """ + Plot the sensitivity analysis results. + + Parameters: + ----------- + results : pd.DataFrame + The results from sobol_results_to_df. + index : str, default "S1" + The type of sensitivity index to plot. + - "S1": first-order indices + - "S2": second-order/interaction indices + - "ST": total-order indices + n_cols : int, optional + The number of columns in the plot. Defaults to 3 if there are 3 or more outputs, + otherwise the number of outputs. + figsize : tuple, optional + Figure size as (width, height) in inches. If None, automatically calculated. 
+ + """ + with plt.style.context("fast"): + # prepare data + results = _validate_input(results, index) + unique_outputs = results["output"].unique() + n_outputs = len(unique_outputs) + + # layout + n_rows, n_cols = _calculate_layout(n_outputs, n_cols) + figsize = figsize or (4.5 * n_cols, 4 * n_rows) + + fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize) + if isinstance(axes, np.ndarray): + axes = axes.flatten() + elif n_outputs == 1: + axes = [axes] + + for ax, output in zip(axes, unique_outputs): + output_data = results[results["output"] == output] + _create_bar_plot(ax, output_data, output) + + # remove any empty subplots + for idx in range(len(unique_outputs), len(axes)): + fig.delaxes(axes[idx]) + + index_names = { + "S1": "First-Order", + "S2": "Second-order/Interaction", + "ST": "Total-Order", + } + + # title + fig.suptitle( + f"{index_names[index]} indices and 95% CI", + fontsize=14, + ) + + plt.tight_layout() + # prevent double plotting in notebooks + plt.close(fig) + + return fig diff --git a/poetry.lock b/poetry.lock index 56e95794..34dd0522 100644 --- a/poetry.lock +++ b/poetry.lock @@ -446,63 +446,73 @@ test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] [[package]] name = "coverage" -version = "7.4.3" +version = "7.6.4" description = "Code coverage measurement for Python" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "coverage-7.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8580b827d4746d47294c0e0b92854c85a92c2227927433998f0d3320ae8a71b6"}, - {file = "coverage-7.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:718187eeb9849fc6cc23e0d9b092bc2348821c5e1a901c9f8975df0bc785bfd4"}, - {file = "coverage-7.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:767b35c3a246bcb55b8044fd3a43b8cd553dd1f9f2c1eeb87a302b1f8daa0524"}, - {file = "coverage-7.4.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash 
= "sha256:ae7f19afe0cce50039e2c782bff379c7e347cba335429678450b8fe81c4ef96d"}, - {file = "coverage-7.4.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba3a8aaed13770e970b3df46980cb068d1c24af1a1968b7818b69af8c4347efb"}, - {file = "coverage-7.4.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ee866acc0861caebb4f2ab79f0b94dbfbdbfadc19f82e6e9c93930f74e11d7a0"}, - {file = "coverage-7.4.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:506edb1dd49e13a2d4cac6a5173317b82a23c9d6e8df63efb4f0380de0fbccbc"}, - {file = "coverage-7.4.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd6545d97c98a192c5ac995d21c894b581f1fd14cf389be90724d21808b657e2"}, - {file = "coverage-7.4.3-cp310-cp310-win32.whl", hash = "sha256:f6a09b360d67e589236a44f0c39218a8efba2593b6abdccc300a8862cffc2f94"}, - {file = "coverage-7.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:18d90523ce7553dd0b7e23cbb28865db23cddfd683a38fb224115f7826de78d0"}, - {file = "coverage-7.4.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cbbe5e739d45a52f3200a771c6d2c7acf89eb2524890a4a3aa1a7fa0695d2a47"}, - {file = "coverage-7.4.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:489763b2d037b164846ebac0cbd368b8a4ca56385c4090807ff9fad817de4113"}, - {file = "coverage-7.4.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:451f433ad901b3bb00184d83fd83d135fb682d780b38af7944c9faeecb1e0bfe"}, - {file = "coverage-7.4.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fcc66e222cf4c719fe7722a403888b1f5e1682d1679bd780e2b26c18bb648cdc"}, - {file = "coverage-7.4.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3ec74cfef2d985e145baae90d9b1b32f85e1741b04cd967aaf9cfa84c1334f3"}, - {file = "coverage-7.4.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = 
"sha256:abbbd8093c5229c72d4c2926afaee0e6e3140de69d5dcd918b2921f2f0c8baba"}, - {file = "coverage-7.4.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:35eb581efdacf7b7422af677b92170da4ef34500467381e805944a3201df2079"}, - {file = "coverage-7.4.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8249b1c7334be8f8c3abcaaa996e1e4927b0e5a23b65f5bf6cfe3180d8ca7840"}, - {file = "coverage-7.4.3-cp311-cp311-win32.whl", hash = "sha256:cf30900aa1ba595312ae41978b95e256e419d8a823af79ce670835409fc02ad3"}, - {file = "coverage-7.4.3-cp311-cp311-win_amd64.whl", hash = "sha256:18c7320695c949de11a351742ee001849912fd57e62a706d83dfc1581897fa2e"}, - {file = "coverage-7.4.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b51bfc348925e92a9bd9b2e48dad13431b57011fd1038f08316e6bf1df107d10"}, - {file = "coverage-7.4.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d6cdecaedea1ea9e033d8adf6a0ab11107b49571bbb9737175444cea6eb72328"}, - {file = "coverage-7.4.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b2eccb883368f9e972e216c7b4c7c06cabda925b5f06dde0650281cb7666a30"}, - {file = "coverage-7.4.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c00cdc8fa4e50e1cc1f941a7f2e3e0f26cb2a1233c9696f26963ff58445bac7"}, - {file = "coverage-7.4.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9a4a8dd3dcf4cbd3165737358e4d7dfbd9d59902ad11e3b15eebb6393b0446e"}, - {file = "coverage-7.4.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:062b0a75d9261e2f9c6d071753f7eef0fc9caf3a2c82d36d76667ba7b6470003"}, - {file = "coverage-7.4.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:ebe7c9e67a2d15fa97b77ea6571ce5e1e1f6b0db71d1d5e96f8d2bf134303c1d"}, - {file = "coverage-7.4.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c0a120238dd71c68484f02562f6d446d736adcc6ca0993712289b102705a9a3a"}, - {file = 
"coverage-7.4.3-cp312-cp312-win32.whl", hash = "sha256:37389611ba54fd6d278fde86eb2c013c8e50232e38f5c68235d09d0a3f8aa352"}, - {file = "coverage-7.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:d25b937a5d9ffa857d41be042b4238dd61db888533b53bc76dc082cb5a15e914"}, - {file = "coverage-7.4.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:28ca2098939eabab044ad68850aac8f8db6bf0b29bc7f2887d05889b17346454"}, - {file = "coverage-7.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:280459f0a03cecbe8800786cdc23067a8fc64c0bd51dc614008d9c36e1659d7e"}, - {file = "coverage-7.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0cdedd3500e0511eac1517bf560149764b7d8e65cb800d8bf1c63ebf39edd2"}, - {file = "coverage-7.4.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a9babb9466fe1da12417a4aed923e90124a534736de6201794a3aea9d98484e"}, - {file = "coverage-7.4.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dec9de46a33cf2dd87a5254af095a409ea3bf952d85ad339751e7de6d962cde6"}, - {file = "coverage-7.4.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:16bae383a9cc5abab9bb05c10a3e5a52e0a788325dc9ba8499e821885928968c"}, - {file = "coverage-7.4.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2c854ce44e1ee31bda4e318af1dbcfc929026d12c5ed030095ad98197eeeaed0"}, - {file = "coverage-7.4.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ce8c50520f57ec57aa21a63ea4f325c7b657386b3f02ccaedeccf9ebe27686e1"}, - {file = "coverage-7.4.3-cp38-cp38-win32.whl", hash = "sha256:708a3369dcf055c00ddeeaa2b20f0dd1ce664eeabde6623e516c5228b753654f"}, - {file = "coverage-7.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:1bf25fbca0c8d121a3e92a2a0555c7e5bc981aee5c3fdaf4bb7809f410f696b9"}, - {file = "coverage-7.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b253094dbe1b431d3a4ac2f053b6d7ede2664ac559705a704f621742e034f1f"}, - {file = 
"coverage-7.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77fbfc5720cceac9c200054b9fab50cb2a7d79660609200ab83f5db96162d20c"}, - {file = "coverage-7.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6679060424faa9c11808598504c3ab472de4531c571ab2befa32f4971835788e"}, - {file = "coverage-7.4.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4af154d617c875b52651dd8dd17a31270c495082f3d55f6128e7629658d63765"}, - {file = "coverage-7.4.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8640f1fde5e1b8e3439fe482cdc2b0bb6c329f4bb161927c28d2e8879c6029ee"}, - {file = "coverage-7.4.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:69b9f6f66c0af29642e73a520b6fed25ff9fd69a25975ebe6acb297234eda501"}, - {file = "coverage-7.4.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0842571634f39016a6c03e9d4aba502be652a6e4455fadb73cd3a3a49173e38f"}, - {file = "coverage-7.4.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a78ed23b08e8ab524551f52953a8a05d61c3a760781762aac49f8de6eede8c45"}, - {file = "coverage-7.4.3-cp39-cp39-win32.whl", hash = "sha256:c0524de3ff096e15fcbfe8f056fdb4ea0bf497d584454f344d59fce069d3e6e9"}, - {file = "coverage-7.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:0209a6369ccce576b43bb227dc8322d8ef9e323d089c6f3f26a597b09cb4d2aa"}, - {file = "coverage-7.4.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:7cbde573904625509a3f37b6fecea974e363460b556a627c60dc2f47e2fffa51"}, - {file = "coverage-7.4.3.tar.gz", hash = "sha256:276f6077a5c61447a48d133ed13e759c09e62aff0dc84274a68dc18660104d52"}, + {file = "coverage-7.6.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5f8ae553cba74085db385d489c7a792ad66f7f9ba2ee85bfa508aeb84cf0ba07"}, + {file = "coverage-7.6.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8165b796df0bd42e10527a3f493c592ba494f16ef3c8b531288e3d0d72c1f6f0"}, + {file = 
"coverage-7.6.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7c8b95bf47db6d19096a5e052ffca0a05f335bc63cef281a6e8fe864d450a72"}, + {file = "coverage-7.6.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ed9281d1b52628e81393f5eaee24a45cbd64965f41857559c2b7ff19385df51"}, + {file = "coverage-7.6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0809082ee480bb8f7416507538243c8863ac74fd8a5d2485c46f0f7499f2b491"}, + {file = "coverage-7.6.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d541423cdd416b78626b55f123412fcf979d22a2c39fce251b350de38c15c15b"}, + {file = "coverage-7.6.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58809e238a8a12a625c70450b48e8767cff9eb67c62e6154a642b21ddf79baea"}, + {file = "coverage-7.6.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c9b8e184898ed014884ca84c70562b4a82cbc63b044d366fedc68bc2b2f3394a"}, + {file = "coverage-7.6.4-cp310-cp310-win32.whl", hash = "sha256:6bd818b7ea14bc6e1f06e241e8234508b21edf1b242d49831831a9450e2f35fa"}, + {file = "coverage-7.6.4-cp310-cp310-win_amd64.whl", hash = "sha256:06babbb8f4e74b063dbaeb74ad68dfce9186c595a15f11f5d5683f748fa1d172"}, + {file = "coverage-7.6.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:73d2b73584446e66ee633eaad1a56aad577c077f46c35ca3283cd687b7715b0b"}, + {file = "coverage-7.6.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:51b44306032045b383a7a8a2c13878de375117946d68dcb54308111f39775a25"}, + {file = "coverage-7.6.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b3fb02fe73bed561fa12d279a417b432e5b50fe03e8d663d61b3d5990f29546"}, + {file = "coverage-7.6.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed8fe9189d2beb6edc14d3ad19800626e1d9f2d975e436f84e19efb7fa19469b"}, + {file = 
"coverage-7.6.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b369ead6527d025a0fe7bd3864e46dbee3aa8f652d48df6174f8d0bac9e26e0e"}, + {file = "coverage-7.6.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ade3ca1e5f0ff46b678b66201f7ff477e8fa11fb537f3b55c3f0568fbfe6e718"}, + {file = "coverage-7.6.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:27fb4a050aaf18772db513091c9c13f6cb94ed40eacdef8dad8411d92d9992db"}, + {file = "coverage-7.6.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4f704f0998911abf728a7783799444fcbbe8261c4a6c166f667937ae6a8aa522"}, + {file = "coverage-7.6.4-cp311-cp311-win32.whl", hash = "sha256:29155cd511ee058e260db648b6182c419422a0d2e9a4fa44501898cf918866cf"}, + {file = "coverage-7.6.4-cp311-cp311-win_amd64.whl", hash = "sha256:8902dd6a30173d4ef09954bfcb24b5d7b5190cf14a43170e386979651e09ba19"}, + {file = "coverage-7.6.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:12394842a3a8affa3ba62b0d4ab7e9e210c5e366fbac3e8b2a68636fb19892c2"}, + {file = "coverage-7.6.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b6b4c83d8e8ea79f27ab80778c19bc037759aea298da4b56621f4474ffeb117"}, + {file = "coverage-7.6.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d5b8007f81b88696d06f7df0cb9af0d3b835fe0c8dbf489bad70b45f0e45613"}, + {file = "coverage-7.6.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b57b768feb866f44eeed9f46975f3d6406380275c5ddfe22f531a2bf187eda27"}, + {file = "coverage-7.6.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5915fcdec0e54ee229926868e9b08586376cae1f5faa9bbaf8faf3561b393d52"}, + {file = "coverage-7.6.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b58c672d14f16ed92a48db984612f5ce3836ae7d72cdd161001cc54512571f2"}, + {file = "coverage-7.6.4-cp312-cp312-musllinux_1_2_i686.whl", 
hash = "sha256:2fdef0d83a2d08d69b1f2210a93c416d54e14d9eb398f6ab2f0a209433db19e1"}, + {file = "coverage-7.6.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8cf717ee42012be8c0cb205dbbf18ffa9003c4cbf4ad078db47b95e10748eec5"}, + {file = "coverage-7.6.4-cp312-cp312-win32.whl", hash = "sha256:7bb92c539a624cf86296dd0c68cd5cc286c9eef2d0c3b8b192b604ce9de20a17"}, + {file = "coverage-7.6.4-cp312-cp312-win_amd64.whl", hash = "sha256:1032e178b76a4e2b5b32e19d0fd0abbce4b58e77a1ca695820d10e491fa32b08"}, + {file = "coverage-7.6.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:023bf8ee3ec6d35af9c1c6ccc1d18fa69afa1cb29eaac57cb064dbb262a517f9"}, + {file = "coverage-7.6.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b0ac3d42cb51c4b12df9c5f0dd2f13a4f24f01943627120ec4d293c9181219ba"}, + {file = "coverage-7.6.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8fe4984b431f8621ca53d9380901f62bfb54ff759a1348cd140490ada7b693c"}, + {file = "coverage-7.6.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5fbd612f8a091954a0c8dd4c0b571b973487277d26476f8480bfa4b2a65b5d06"}, + {file = "coverage-7.6.4-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dacbc52de979f2823a819571f2e3a350a7e36b8cb7484cdb1e289bceaf35305f"}, + {file = "coverage-7.6.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:dab4d16dfef34b185032580e2f2f89253d302facba093d5fa9dbe04f569c4f4b"}, + {file = "coverage-7.6.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:862264b12ebb65ad8d863d51f17758b1684560b66ab02770d4f0baf2ff75da21"}, + {file = "coverage-7.6.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5beb1ee382ad32afe424097de57134175fea3faf847b9af002cc7895be4e2a5a"}, + {file = "coverage-7.6.4-cp313-cp313-win32.whl", hash = "sha256:bf20494da9653f6410213424f5f8ad0ed885e01f7e8e59811f572bdb20b8972e"}, + {file = 
"coverage-7.6.4-cp313-cp313-win_amd64.whl", hash = "sha256:182e6cd5c040cec0a1c8d415a87b67ed01193ed9ad458ee427741c7d8513d963"}, + {file = "coverage-7.6.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a181e99301a0ae128493a24cfe5cfb5b488c4e0bf2f8702091473d033494d04f"}, + {file = "coverage-7.6.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:df57bdbeffe694e7842092c5e2e0bc80fff7f43379d465f932ef36f027179806"}, + {file = "coverage-7.6.4-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bcd1069e710600e8e4cf27f65c90c7843fa8edfb4520fb0ccb88894cad08b11"}, + {file = "coverage-7.6.4-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99b41d18e6b2a48ba949418db48159d7a2e81c5cc290fc934b7d2380515bd0e3"}, + {file = "coverage-7.6.4-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6b1e54712ba3474f34b7ef7a41e65bd9037ad47916ccb1cc78769bae324c01a"}, + {file = "coverage-7.6.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:53d202fd109416ce011578f321460795abfe10bb901b883cafd9b3ef851bacfc"}, + {file = "coverage-7.6.4-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:c48167910a8f644671de9f2083a23630fbf7a1cb70ce939440cd3328e0919f70"}, + {file = "coverage-7.6.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cc8ff50b50ce532de2fa7a7daae9dd12f0a699bfcd47f20945364e5c31799fef"}, + {file = "coverage-7.6.4-cp313-cp313t-win32.whl", hash = "sha256:b8d3a03d9bfcaf5b0141d07a88456bb6a4c3ce55c080712fec8418ef3610230e"}, + {file = "coverage-7.6.4-cp313-cp313t-win_amd64.whl", hash = "sha256:f3ddf056d3ebcf6ce47bdaf56142af51bb7fad09e4af310241e9db7a3a8022e1"}, + {file = "coverage-7.6.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cb7fa111d21a6b55cbf633039f7bc2749e74932e3aa7cb7333f675a58a58bf3"}, + {file = "coverage-7.6.4-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:11a223a14e91a4693d2d0755c7a043db43d96a7450b4f356d506c2562c48642c"}, + {file = "coverage-7.6.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a413a096c4cbac202433c850ee43fa326d2e871b24554da8327b01632673a076"}, + {file = "coverage-7.6.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00a1d69c112ff5149cabe60d2e2ee948752c975d95f1e1096742e6077affd376"}, + {file = "coverage-7.6.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f76846299ba5c54d12c91d776d9605ae33f8ae2b9d1d3c3703cf2db1a67f2c0"}, + {file = "coverage-7.6.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:fe439416eb6380de434886b00c859304338f8b19f6f54811984f3420a2e03858"}, + {file = "coverage-7.6.4-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:0294ca37f1ba500667b1aef631e48d875ced93ad5e06fa665a3295bdd1d95111"}, + {file = "coverage-7.6.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6f01ba56b1c0e9d149f9ac85a2f999724895229eb36bd997b61e62999e9b0901"}, + {file = "coverage-7.6.4-cp39-cp39-win32.whl", hash = "sha256:bc66f0bf1d7730a17430a50163bb264ba9ded56739112368ba985ddaa9c3bd09"}, + {file = "coverage-7.6.4-cp39-cp39-win_amd64.whl", hash = "sha256:c481b47f6b5845064c65a7bc78bc0860e635a9b055af0df46fdf1c58cebf8e8f"}, + {file = "coverage-7.6.4-pp39.pp310-none-any.whl", hash = "sha256:3c65d37f3a9ebb703e710befdc489a38683a5b152242664b973a7b7b22348a4e"}, + {file = "coverage-7.6.4.tar.gz", hash = "sha256:29fc0f17b1d3fea332f8001d4558f8214af7f1d87a345f3a133c901d60347c73"}, ] [package.dependencies] @@ -568,6 +578,21 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "dill" +version = "0.3.9" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.9-py3-none-any.whl", hash = 
"sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, + {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + [[package]] name = "distlib" version = "0.3.8" @@ -1679,6 +1704,34 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4)"] tests = ["pytest (>=4.6)"] +[[package]] +name = "multiprocess" +version = "0.70.17" +description = "better multiprocessing and multithreading in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "multiprocess-0.70.17-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7ddb24e5bcdb64e90ec5543a1f05a39463068b6d3b804aa3f2a4e16ec28562d6"}, + {file = "multiprocess-0.70.17-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d729f55198a3579f6879766a6d9b72b42d4b320c0dcb7844afb774d75b573c62"}, + {file = "multiprocess-0.70.17-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c2c82d0375baed8d8dd0d8c38eb87c5ae9c471f8e384ad203a36f095ee860f67"}, + {file = "multiprocess-0.70.17-pp38-pypy38_pp73-macosx_10_9_arm64.whl", hash = "sha256:a22a6b1a482b80eab53078418bb0f7025e4f7d93cc8e1f36481477a023884861"}, + {file = "multiprocess-0.70.17-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:349525099a0c9ac5936f0488b5ee73199098dac3ac899d81d326d238f9fd3ccd"}, + {file = "multiprocess-0.70.17-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:27b8409c02b5dd89d336107c101dfbd1530a2cd4fd425fc27dcb7adb6e0b47bf"}, + {file = "multiprocess-0.70.17-pp39-pypy39_pp73-macosx_10_13_arm64.whl", hash = "sha256:2ea0939b0f4760a16a548942c65c76ff5afd81fbf1083c56ae75e21faf92e426"}, + {file = "multiprocess-0.70.17-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:2b12e081df87ab755190e227341b2c3b17ee6587e9c82fecddcbe6aa812cd7f7"}, + {file = "multiprocess-0.70.17-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = 
"sha256:a0f01cd9d079af7a8296f521dc03859d1a414d14c1e2b6e676ef789333421c95"}, + {file = "multiprocess-0.70.17-py310-none-any.whl", hash = "sha256:38357ca266b51a2e22841b755d9a91e4bb7b937979a54d411677111716c32744"}, + {file = "multiprocess-0.70.17-py311-none-any.whl", hash = "sha256:2884701445d0177aec5bd5f6ee0df296773e4fb65b11903b94c613fb46cfb7d1"}, + {file = "multiprocess-0.70.17-py312-none-any.whl", hash = "sha256:2818af14c52446b9617d1b0755fa70ca2f77c28b25ed97bdaa2c69a22c47b46c"}, + {file = "multiprocess-0.70.17-py313-none-any.whl", hash = "sha256:20c28ca19079a6c879258103a6d60b94d4ffe2d9da07dda93fb1c8bc6243f522"}, + {file = "multiprocess-0.70.17-py38-none-any.whl", hash = "sha256:1d52f068357acd1e5bbc670b273ef8f81d57863235d9fbf9314751886e141968"}, + {file = "multiprocess-0.70.17-py39-none-any.whl", hash = "sha256:c3feb874ba574fbccfb335980020c1ac631fbf2a3f7bee4e2042ede62558a021"}, + {file = "multiprocess-0.70.17.tar.gz", hash = "sha256:4ae2f11a3416809ebc9a48abfc8b14ecce0652a0944731a1493a3c1ba44ff57a"}, +] + +[package.dependencies] +dill = ">=0.3.9" + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -2901,6 +2954,30 @@ files = [ {file = "rpds_py-0.18.0.tar.gz", hash = "sha256:42821446ee7a76f5d9f71f9e33a4fb2ffd724bb3e7f93386150b61a43115788d"}, ] +[[package]] +name = "salib" +version = "1.5.1" +description = "Tools for global sensitivity analysis. 
Contains Sobol', Morris, FAST, DGSM, PAWN, HDMR, Moment Independent and fractional factorial methods" +optional = false +python-versions = ">=3.9" +files = [ + {file = "salib-1.5.1-py3-none-any.whl", hash = "sha256:a978b619c5a93eb14dd8c527f12e22d354b02f1f7143aba3cb84c1c7bc1382e5"}, + {file = "salib-1.5.1.tar.gz", hash = "sha256:e4a9c319b8dd02995a8dc983f57c452cb7e5b6dbd43e7b7856c90cb6a332bb5f"}, +] + +[package.dependencies] +matplotlib = ">=3.5" +multiprocess = "*" +numpy = ">=1.20.3" +pandas = ">=2.0" +scipy = ">=1.9.3" + +[package.extras] +dev = ["hatch", "myst-parser", "numpydoc", "pathos (>=0.3.2)", "pre-commit", "pydata-sphinx-theme (>=0.15.2)", "pytest", "pytest-cov", "sphinx"] +distributed = ["pathos (>=0.3.2)"] +doc = ["myst-parser", "numpydoc", "pydata-sphinx-theme (>=0.15.2)", "sphinx"] +test = ["pathos (>=0.3.2)", "pytest", "pytest-cov"] + [[package]] name = "scikit-learn" version = "1.4.1.post1" @@ -3920,4 +3997,4 @@ docs = [] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "4ca19f29dda8b99a25194547084293b75b74107e44e6b43663104cbf1ab18ba9" +content-hash = "7a6830c251b5aef183659d85f7eae56e380f37fed39a2bd6aa370a249233a3f5" diff --git a/pyproject.toml b/pyproject.toml index adc351c5..652012a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ iprogress = "^0.4" lightgbm = "^4.3.0" ipywidgets = "^8.1.2" gpytorch = "^1.12" +salib = "^1.5.1" [tool.poetry.group.dev.dependencies] @@ -41,7 +42,7 @@ black = "^23.10.1" pre-commit = "^3.5.0" jupyter-book = "^1.0.0" pytest-cov = "^4.1.0" -coverage = "^7.3.1" +coverage = "^7.6.4" plotnine = "^0.13.6" @@ -55,4 +56,8 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.coverage.run] -relative_files = true \ No newline at end of file +relative_files = true +source = [ + ".", + "/tmp" +] \ No newline at end of file diff --git a/tests/test_sensitivity_analysis.py b/tests/test_sensitivity_analysis.py new file mode 100644 index 00000000..db01adeb --- 
"""Tests for ``autoemulate.sensitivity_analysis``.

Module path in the repository: ``tests/test_sensitivity_analysis.py``.

The tests cover problem validation, output-name resolution, the Sobol
analysis itself, conversion of the results to a DataFrame, input validation
for plotting, subplot layout calculation and automatic problem generation.
"""
import numpy as np
import pandas as pd
import pytest
from sklearn.datasets import make_regression

from autoemulate.emulators import RandomForest
from autoemulate.experimental_design import LatinHypercube
from autoemulate.sensitivity_analysis import _calculate_layout
from autoemulate.sensitivity_analysis import _check_problem
from autoemulate.sensitivity_analysis import _generate_problem
from autoemulate.sensitivity_analysis import _get_output_names
from autoemulate.sensitivity_analysis import _validate_input
from autoemulate.sensitivity_analysis import sobol_analysis
from autoemulate.sensitivity_analysis import sobol_results_to_df
from autoemulate.simulations.projectile import simulate_projectile
from autoemulate.simulations.projectile import simulate_projectile_multioutput

# Bounds of the two inputs, used both for sampling the training data and for
# the problem definitions handed to the sensitivity-analysis functions.
_BOUNDS = ((-5.0, 1.0), (0.0, 1000.0))


def _make_problem(**overrides):
    """Return a fresh two-parameter problem definition.

    A new dict (with a new ``bounds`` list) is built on every call so that
    nothing a test or the code under test does to it can leak into another
    test.

    Parameters
    ----------
    **overrides
        Extra or replacement entries, e.g. ``output_names=[...]`` or a
        deliberately wrong ``bounds`` list.

    Returns
    -------
    dict
        Problem with the keys ``num_vars``, ``names`` and ``bounds`` plus any
        overrides.
    """
    problem = {
        "num_vars": 2,
        "names": ["c", "v0"],
        "bounds": list(_BOUNDS),
    }
    problem.update(overrides)
    return problem


@pytest.fixture
def Xy_1d():
    """Return 100 Latin-hypercube samples and the projectile output for each."""
    lhd = LatinHypercube(list(_BOUNDS))
    X = lhd.sample(100)
    y = np.array([simulate_projectile(x) for x in X])
    return X, y


@pytest.fixture
def Xy_2d():
    """Return 100 Latin-hypercube samples and the multi-output projectile result for each."""
    lhd = LatinHypercube(list(_BOUNDS))
    X = lhd.sample(100)
    y = np.array([simulate_projectile_multioutput(x) for x in X])
    return X, y


@pytest.fixture
def model_1d(Xy_1d):
    """Return a ``RandomForest`` emulator fitted on the ``Xy_1d`` data."""
    X, y = Xy_1d
    rf = RandomForest()
    rf.fit(X, y)
    return rf


@pytest.fixture
def model_2d(Xy_2d):
    """Return a ``RandomForest`` emulator fitted on the ``Xy_2d`` data."""
    X, y = Xy_2d
    rf = RandomForest()
    rf.fit(X, y)
    return rf


# test problem checking ----------------------------------------------------------------
def test_check_problem():
    """A complete problem definition is returned with its entries intact."""
    problem = _check_problem(_make_problem())
    assert problem["num_vars"] == 2
    assert problem["names"] == ["c", "v0"]
    assert problem["bounds"] == [(-5.0, 1.0), (0.0, 1000.0)]


def test_check_problem_invalid():
    """A problem without a ``bounds`` entry is rejected with ``ValueError``."""
    problem = _make_problem()
    del problem["bounds"]
    with pytest.raises(ValueError):
        _check_problem(problem)


def test_check_problem_bounds():
    """A problem with fewer bounds than variables is rejected with ``ValueError``."""
    problem = _make_problem(bounds=[(-5.0, 1.0)])
    with pytest.raises(ValueError):
        _check_problem(problem)


# test output name retrieval --------------------------------------------------
def test_get_output_names_default():
    """Without ``output_names`` in the problem, a single output is named ``y1``."""
    output_names = _get_output_names(_make_problem(), 1)
    assert output_names == ["y1"]


def test_get_output_names_custom():
    """``output_names`` given in the problem are returned unchanged."""
    problem = _make_problem(output_names=["lol1", "lol2"])
    output_names = _get_output_names(problem, 2)
    assert output_names == ["lol1", "lol2"]


def test_get_output_names_invalid():
    """``output_names`` given as a plain string is rejected with ``ValueError``."""
    problem = _make_problem(output_names="lol")
    with pytest.raises(ValueError):
        _get_output_names(problem, 1)


# test Sobol analysis ------------------------------------------------------------
def test_sobol_analysis(model_1d):
    """The analysis of a single-output model yields all Sobol indices under ``y1``."""
    Si = sobol_analysis(model_1d, _make_problem())
    assert isinstance(Si, dict)
    assert "y1" in Si
    assert all(
        key in Si["y1"] for key in ["S1", "S1_conf", "S2", "S2_conf", "ST", "ST_conf"]
    )


def test_sobol_analysis_2d(model_2d):
    """The analysis of the two-output model yields one entry per output, in order."""
    Si = sobol_analysis(model_2d, _make_problem())
    assert isinstance(Si, dict)
    assert ["y1", "y2"] == list(Si.keys())


@pytest.fixture
def sobol_results_1d(model_1d):
    """Return the Sobol analysis results for the single-output model."""
    return sobol_analysis(model_1d, _make_problem())


# test conversion to DataFrame --------------------------------------------------
def test_sobol_results_to_df(sobol_results_1d):
    """The results convert to a long-format DataFrame with the expected content."""
    df = sobol_results_to_df(sobol_results_1d)
    assert isinstance(df, pd.DataFrame)
    assert df.columns.tolist() == [
        "output",
        "parameter",
        "index",
        "value",
        "confidence",
    ]
    # `a_list in ndarray` is evaluated by NumPy as `(ndarray == a_list).any()`,
    # which is already true when a single position matches. Compare as sets so
    # that every expected parameter label really has to be present.
    assert {"X1", "X2", "X1-X2"} <= set(df["parameter"].unique())
    assert all(isinstance(x, float) for x in df["value"])
    assert all(isinstance(x, float) for x in df["confidence"])


# test plotting ----------------------------------------------------------------
# TODO: no plotting tests yet.


# test _validate_input ----------------------------------------------------------
def test_validate_input(sobol_results_1d):
    """An unknown index name (``"S3"``) is rejected with ``ValueError``."""
    with pytest.raises(ValueError):
        _validate_input(sobol_results_1d, "S3")


def test_validate_input_valid(sobol_results_1d):
    """With a valid index name the results come back as a DataFrame."""
    Si = _validate_input(sobol_results_1d, "S1")
    assert isinstance(Si, pd.DataFrame)


# test _calculate_layout ------------------------------------------------------
def test_calculate_layout():
    """One output gives a 1 x 1 layout."""
    n_rows, n_cols = _calculate_layout(1)
    assert n_rows == 1
    assert n_cols == 1


def test_calculate_layout_3_outputs():
    """Three outputs give a single row of three columns by default."""
    n_rows, n_cols = _calculate_layout(3)
    assert n_rows == 1
    assert n_cols == 3


def test_calculate_layout_custom():
    """Three outputs with a second argument of 2 give a 2 x 2 layout."""
    n_rows, n_cols = _calculate_layout(3, 2)
    assert n_rows == 2
    assert n_cols == 2


# test _generate_problem -----------------------------------------------------


def test_generate_problem():
    """A problem generated from ``X`` uses per-column min/max as bounds."""
    X = np.array([[0, 0], [1, 1], [2, 2]])
    problem = _generate_problem(X)
    assert problem["num_vars"] == 2
    assert problem["names"] == ["x1", "x2"]
    assert problem["bounds"] == [[0, 2], [0, 2]]