Skip to content

Commit

Permalink
Extend save parameters to handle multiple realizations
Browse files Browse the repository at this point in the history
Add test that uses the new functionality and also documents
some troublesome behavior of adaptive localization.
  • Loading branch information
dafeda committed Nov 1, 2024
1 parent 40c8338 commit f28d949
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 24 deletions.
46 changes: 22 additions & 24 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from uuid import UUID

import numpy as np
Expand Down Expand Up @@ -788,51 +788,49 @@ def load_all_gen_kw_data(
def save_parameters(
self,
group: str,
realization: int,
realization: int | Sequence[int],
dataset: xr.Dataset,
) -> None:
"""
Saves the provided dataset under a parameter group and realization index
Saves the provided dataset under a parameter group and realization index(es)
Parameters
----------
group : str
Parameter group name for saving dataset.
realization : int
Realization index for saving group.
realization : int or sequence of int
Realization index(es) for saving group.
dataset : Dataset
Dataset to save. It must contain a variable named 'values'
which will be used when flattening out the parameters into
a 1d-vector.
a 1d-vector. When saving multiple realizations, dataset must
have a 'realizations' dimension.
"""

if "values" not in dataset.variables:
raise ValueError(
f"Dataset for parameter group '{group}' "
f"must contain a 'values' variable"
f"Dataset for parameter group '{group}' must contain a 'values' variable"
)

if dataset["values"].size == 0:
raise ValueError(
f"Parameters {group} are empty. Cannot proceed with saving to storage."
)

if dataset["values"].ndim >= 2 and dataset["values"].values.dtype == "float64":
logger.warning(
"Dataset uses 'float64' for fields/surfaces. Use 'float32' to save memory."
)

if group not in self.experiment.parameter_configuration:
raise ValueError(f"{group} is not registered to the experiment.")

path = self._realization_dir(realization) / f"{_escape_filename(group)}.nc"
path.parent.mkdir(exist_ok=True)
realizations = [realization] if np.isscalar(realization) else realization

self._storage._to_netcdf_transaction(
path, dataset.expand_dims(realizations=[realization])
)
if len(realizations) > 1 and "realizations" not in dataset.dims:
raise ValueError(
"Dataset must have 'realizations' dimension when saving multiple realizations"
)

for real in realizations:
path = self._realization_dir(real) / f"{_escape_filename(group)}.nc"
path.parent.mkdir(exist_ok=True)
if "realizations" in dataset.dims:
data_to_save = dataset.sel(realizations=real)
else:
data_to_save = dataset.expand_dims(realizations=[real])
self._storage._to_netcdf_transaction(path, data_to_save)

@require_write
def save_response(
Expand Down
116 changes: 116 additions & 0 deletions tests/ert/ui_tests/cli/test_field_parameter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import os
import stat
import warnings
from pathlib import Path
from textwrap import dedent

import numpy as np
import numpy.testing
import polars as pl
import resfo
import xtgeo

from ert.analysis import (
smoother_update,
)
from ert.config import ErtConfig
from ert.config.analysis_config import UpdateSettings
from ert.config.analysis_module import ESSettings
from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE
from ert.storage import open_storage

Expand Down Expand Up @@ -210,3 +217,112 @@ def test_parameter_update_with_inactive_cells_xtgeo_grdecl(tmpdir):
assert "nan" not in Path(
"simulations/realization-0/iter-1/my_param.grdecl"
).read_text(encoding="utf-8")


def test_field_param_update_using_heat_equation__zero_var_params_and_adaptive_loc(
heat_equation_storage,
):
"""Test field parameter updates with zero-variance regions and adaptive localization.
This test verifies the behavior of the ensemble smoother update when dealing with
field parameters that contain regions of zero variance (constant values across all
realizations). Such scenarios have been reported to cause performance issues and
numerical instabilities.
Specifically, this test:
1. Creates a field where the first 5 layers are set to constant values (1.0)
2. Performs a smoother update with adaptive localization
3. Verifies expected numerical warnings are raised due to zero variance
4. Confirms the update still reduces overall parameter uncertainty
The test documents known limitations with adaptive localization when handling
zero-variance regions, particularly:
- Runtime degradation
- Numerical warnings from division by zero
- Cross-correlation matrix instabilities
"""
config = ErtConfig.from_file("config.ert")
with open_storage(config.ens_path, mode="w") as storage:
experiment = storage.get_experiment_by_name("es-mda")
prior = experiment.get_ensemble_by_name("default_0")
cond = prior.load_parameters("COND")

new_experiment = storage.create_experiment(
parameters=config.ensemble_config.parameter_configuration,
responses=config.ensemble_config.response_configuration,
observations=config.observations,
name="exp-zero-var",
)
new_prior = storage.create_ensemble(
new_experiment,
ensemble_size=prior.ensemble_size,
iteration=0,
name="prior-zero-var",
)
cond["values"][:, :, :5, 0] = 1.0
new_prior.save_parameters("COND", range(prior.ensemble_size), cond)

responses = prior.load_responses("gen_data", range(prior.ensemble_size))
for realization in range(prior.ensemble_size):
df = responses.filter(pl.col("realization") == realization)
new_prior.save_response("gen_data", df, realization)

new_posterior = storage.create_ensemble(
new_experiment,
ensemble_size=config.model_config.num_realizations,
iteration=1,
name="new_ensemble",
prior_ensemble=new_prior,
)

smoother_update(
new_prior,
new_posterior,
experiment.observation_keys,
config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(localization=True),
)

with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always") # Ensure all warnings are always recorded
smoother_update(
new_prior,
new_posterior,
experiment.observation_keys,
config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(localization=True),
)

warning_messages = [(w.category, str(w.message)) for w in record]

# Check that each required warning appears at least once
assert any(
issubclass(w[0], RuntimeWarning)
and "divide by zero encountered in divide" in w[1]
for w in warning_messages
)
assert any(
issubclass(w[0], UserWarning)
and "Cross-correlation matrix has entries not in [-1, 1]" in w[1]
for w in warning_messages
)

param_config = config.ensemble_config.parameter_configs["COND"]
prior_result = new_prior.load_parameters("COND")["values"]
posterior_result = new_posterior.load_parameters("COND")["values"]
prior_covariance = np.cov(
prior_result.values.reshape(
new_prior.ensemble_size,
param_config.nx * param_config.ny * param_config.nz,
).T
)
posterior_covariance = np.cov(
posterior_result.values.reshape(
new_posterior.ensemble_size,
param_config.nx * param_config.ny * param_config.nz,
).T
)
# Check that generalized variance is reduced by update step.
assert np.trace(prior_covariance) > np.trace(posterior_covariance)

0 comments on commit f28d949

Please sign in to comment.