Sensitivity analysis #260

Merged · 12 commits · Nov 7, 2024
15 changes: 13 additions & 2 deletions .github/workflows/ci.yaml
@@ -25,6 +25,12 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

# Clean python cache files
- name: Clean cache files
run: |
find . -type f -name "*.pyc" -delete
find . -type d -name "__pycache__" -delete
rm -f .coverage* coverage.xml
# Cache Poetry dependencies
# - name: Cache dependencies
# uses: actions/cache@v2
@@ -56,10 +62,15 @@ jobs:
run: |
poetry run python -c "import torch; print(torch.__version__); print('CUDA available:', torch.cuda.is_available())"

- name: Debug Coverage Config
run: |
cat pyproject.toml
poetry run coverage debug config

- name: Run Tests with Coverage
run: |
-poetry run coverage run -m pytest
-poetry run coverage xml -o coverage.xml
+poetry run coverage run --source=. -m pytest
+poetry run coverage xml -i -o coverage.xml
env:
COVERAGE_FILE: ".coverage.${{ matrix.python-version }}"

6 changes: 5 additions & 1 deletion .gitignore
@@ -27,4 +27,8 @@ Thumbs.db
*.DS_Store

# VScode settings
-.vscode/
+.vscode/

# Quarto
README.html
README_files/
20 changes: 14 additions & 6 deletions README.md
@@ -22,7 +22,7 @@ There's currently a lot of development, so we recommend installing the most curr
pip install git+https://github.com/alan-turing-institute/autoemulate.git
```

-There's also a release available on PyPI (will not contain the most recent features and models)
+There's also a release available on PyPI (note: currently an older version, out of date with the documentation)
```bash
pip install autoemulate
```
@@ -47,19 +47,27 @@ from autoemulate.simulations.projectile import simulate_projectile
lhd = LatinHypercube([(-5., 1.), (0., 1000.)])
X = lhd.sample(100)
y = np.array([simulate_projectile(x) for x in X])

# compare emulator models
ae = AutoEmulate()
ae.setup(X, y)
best_model = ae.compare()
best_emulator = ae.compare()

# training set cross-validation results
ae.summarise_cv()
ae.plot_cv()

# test set results for the best model
-ae.evaluate(best_model)
-ae.plot_eval(best_model)
+ae.evaluate(best_emulator)
+ae.plot_eval(best_emulator)

# refit on full data and emulate!
-best_model = ae.refit(best_model)
-best_model.predict(X)
+emulator = ae.refit(best_emulator)
+emulator.predict(X)

# global sensitivity analysis
si = ae.sensitivity_analysis(emulator)
ae.plot_sensitivity_analysis(si)
```
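Under the hood, `sensitivity_analysis` computes Sobol indices. As a rough illustration of what the first-order index S1 measures (the share of output variance attributable to each input alone), here is a minimal pure-NumPy sketch of a Saltelli-style estimator — illustrative only, not autoemulate's implementation:

```python
import numpy as np

def sobol_first_order(f, bounds, n=16384, seed=0):
    """Estimate first-order Sobol indices S1 with a Saltelli-style
    estimator (illustrative sketch, not autoemulate's implementation)."""
    rng = np.random.default_rng(seed)
    d = len(bounds)
    lo, hi = np.array(bounds, dtype=float).T
    # two independent sample matrices over the input box
    A = lo + (hi - lo) * rng.random((n, d))
    B = lo + (hi - lo) * rng.random((n, d))
    fA, fB = f(A), f(B)
    total_var = np.var(np.concatenate([fA, fB]))
    s1 = np.empty(d)
    for i in range(d):
        AB = A.copy()
        AB[:, i] = B[:, i]  # replace column i of A with column i of B
        # Saltelli (2010) estimator for the first-order effect of input i
        s1[i] = np.mean(fB * (f(AB) - fA)) / total_var
    return s1

# additive test function y = 2*x1 + x2 on [0, 1]^2:
# analytically S1 = (0.8, 0.2)
s1 = sobol_first_order(lambda X: 2 * X[:, 0] + X[:, 1], [(0.0, 1.0), (0.0, 1.0)])
```

In practice the emulator's `predict` plays the role of `f`, which is the point of the workflow above: the cheap surrogate makes the thousands of model evaluations affordable.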

## documentation
64 changes: 64 additions & 0 deletions autoemulate/compare.py
@@ -20,6 +20,9 @@
from autoemulate.plotting import _plot_model
from autoemulate.printing import _print_setup
from autoemulate.save import ModelSerialiser
from autoemulate.sensitivity_analysis import plot_sensitivity_analysis
from autoemulate.sensitivity_analysis import sensitivity_analysis
from autoemulate.utils import _ensure_2d
from autoemulate.utils import _get_full_model_name
from autoemulate.utils import _redirect_warnings
from autoemulate.utils import get_model_name
@@ -522,3 +525,64 @@ def plot_eval(
)

return fig

def sensitivity_analysis(
self, model=None, problem=None, N=1024, conf_level=0.95, as_df=True
):
"""Perform Sobol sensitivity analysis on a fitted emulator.

Parameters
----------
model : object, optional
Fitted model. If None, uses the best model from cross-validation.
problem : dict, optional
The problem definition, including 'num_vars', 'names', and 'bounds', and optionally 'output_names'.
If None, the problem is generated from X using minimum and maximum values of the features as bounds.

Example:
```python
problem = {
"num_vars": 2,
"names": ["x1", "x2"],
"bounds": [[0, 1], [0, 1]],
}
```
N : int, optional
Number of samples to generate. Default is 1024.
conf_level : float, optional
Confidence level for the confidence intervals. Default is 0.95.
as_df : bool, optional
If True, return a long-format pandas DataFrame (default is True).
"""
if model is None:
if not hasattr(self, "best_model"):
raise RuntimeError("Must run compare() before sensitivity_analysis()")
model = self.refit(self.best_model)
self.logger.info(
f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score, refitted on full data."
)

Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
return Si

def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None):
"""
Plot the sensitivity analysis results.

Parameters
----------
results : pd.DataFrame
The results from sobol_results_to_df.
index : str, default "S1"
The type of sensitivity index to plot.
- "S1": first-order indices
- "S2": second-order/interaction indices
- "ST": total-order indices
n_cols : int, optional
The number of columns in the plot. Defaults to 3 if there are 3 or more outputs,
otherwise the number of outputs.
figsize : tuple, optional
Figure size as (width, height) in inches. If None, automatically calculated.

"""
return plot_sensitivity_analysis(results, index, n_cols, figsize)
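The docstring above says that when no `problem` is passed, bounds are derived from the per-feature minimum and maximum of `X`. A minimal sketch of that fallback, producing a SALib-style problem dict (the helper name `make_problem` is hypothetical, not part of autoemulate):

```python
import numpy as np

def make_problem(X, names=None):
    """Build a SALib-style problem dict from a 2-D sample array,
    using per-feature min/max as bounds (hypothetical helper)."""
    X = np.asarray(X, dtype=float)
    num_vars = X.shape[1]
    return {
        "num_vars": num_vars,
        "names": names or [f"x{i + 1}" for i in range(num_vars)],
        "bounds": [[float(X[:, i].min()), float(X[:, i].max())]
                   for i in range(num_vars)],
    }
```

A dict built this way (or by hand, as in the docstring example) can then be passed as the `problem` argument to constrain the analysis to a chosen input region.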