
Commit

dafeda committed Nov 1, 2024
1 parent 40c8338 commit edef4be
Showing 3 changed files with 146 additions and 40 deletions.
49 changes: 33 additions & 16 deletions src/ert/analysis/_es_update.py
@@ -610,6 +610,24 @@ def correlation_callback(
param_ensemble_array = _load_param_ensemble_array(
source_ensemble, param_group, iens_active_index
)

# Calculate variance for each parameter
param_variance = np.var(param_ensemble_array, axis=1)
# Create mask for non-zero variance parameters
non_zero_variance_mask = ~np.isclose(param_variance, 0.0)

log_msg = f"Updating {np.sum(non_zero_variance_mask)} parameters {'with' if module.localization else 'without'} adaptive localization."
logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))

log_msg = f"There are {num_obs} responses and {ensemble_size} realizations."
logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))

log_msg = f"There are {(~non_zero_variance_mask).sum()} parameters with 0 variance that will not be updated."
logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))

if module.localization:
config_node = source_ensemble.experiment.parameter_configuration[
param_group
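
The lines added in the hunk above compute the per-parameter variance across realizations and build a mask of parameters whose variance is numerically zero. Such parameters carry no ensemble spread, cannot be correlated with any response, and dividing by their zero standard deviation is what triggered the runtime warnings this commit addresses. A minimal sketch of the masking idea, using plain NumPy with illustrative shapes (not the ert API itself):

import numpy as np

# Illustrative ensemble: 4 parameters (rows) x 5 realizations (columns).
# The second parameter is identical in every realization.
param_ensemble_array = np.array(
    [
        [0.1, 0.4, 0.2, 0.9, 0.5],
        [1.0, 1.0, 1.0, 1.0, 1.0],  # zero variance across realizations
        [0.3, 0.1, 0.8, 0.2, 0.6],
        [2.0, 2.5, 1.5, 2.2, 1.9],
    ]
)

# Variance over the realization axis gives one value per parameter.
param_variance = np.var(param_ensemble_array, axis=1)

# Parameters with (numerically) zero variance are excluded from the update.
non_zero_variance_mask = ~np.isclose(param_variance, 0.0)

print(non_zero_variance_mask)  # [ True False  True  True]
print(f"{(~non_zero_variance_mask).sum()} parameters with 0 variance will not be updated.")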
@@ -618,32 +636,31 @@ def correlation_callback(
batch_size = _calculate_adaptive_batch_size(num_params, num_obs)
batches = _split_by_batchsize(np.arange(0, num_params), batch_size)

log_msg = f"Running localization on {num_params} parameters, {num_obs} responses, {ensemble_size} realizations and {len(batches)} batches"
log_msg = f"Adaptive localization has split parameters into {len(batches)} batch{'es' if len(batches) != 1 else ''}."
logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))

start = time.time()
cross_correlations: List[npt.NDArray[np.float64]] = []
for param_batch_idx in batches:
X_local = param_ensemble_array[param_batch_idx, :]
update_idx = param_batch_idx[non_zero_variance_mask[param_batch_idx]]
X_local = param_ensemble_array[update_idx, :]
if isinstance(config_node, GenKwConfig):
correlation_batch_callback = functools.partial(
correlation_callback,
cross_correlations_accumulator=cross_correlations,
)
else:
correlation_batch_callback = None
param_ensemble_array[param_batch_idx, :] = (
smoother_adaptive_es.assimilate(
X=X_local,
Y=S,
D=D,
alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage)
correlation_threshold=module.correlation_threshold,
cov_YY=cov_YY,
progress_callback=adaptive_localization_progress_callback,
correlation_callback=correlation_batch_callback,
)
param_ensemble_array[update_idx, :] = smoother_adaptive_es.assimilate(
X=X_local,
Y=S,
D=D,
alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage)
correlation_threshold=module.correlation_threshold,
cov_YY=cov_YY,
progress_callback=adaptive_localization_progress_callback,
correlation_callback=correlation_batch_callback,
)

if cross_correlations:
@@ -665,9 +682,9 @@ def correlation_callback(

else:
# In-place multiplication is not yet supported, therefore avoiding @=
param_ensemble_array = param_ensemble_array @ T.astype( # noqa: PLR6104
param_ensemble_array.dtype
)
param_ensemble_array[non_zero_variance_mask] = param_ensemble_array[ # noqa: PLR6104
non_zero_variance_mask
] @ T.astype(param_ensemble_array.dtype)

log_msg = f"Storing data for {param_group}.."
logger.info(log_msg)
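Both update paths in this file now honour the mask. In the adaptive-localization branch each batch of parameter indices is filtered down to update_idx before assimilate is called, and in the non-localized branch only the masked rows are multiplied by the transition matrix T. A self-contained sketch of the masked row update with made-up shapes (a stand-in for the smoother, not the actual implementation):

import numpy as np

rng = np.random.default_rng(seed=0)

num_params, ensemble_size = 6, 4
param_ensemble_array = rng.normal(size=(num_params, ensemble_size))
param_ensemble_array[2, :] = 3.0  # a constant, zero-variance parameter

non_zero_variance_mask = ~np.isclose(np.var(param_ensemble_array, axis=1), 0.0)

# Per-batch filtering as in the localization branch:
batches = np.array_split(np.arange(num_params), 2)
for param_batch_idx in batches:
    update_idx = param_batch_idx[non_zero_variance_mask[param_batch_idx]]
    # param_ensemble_array[update_idx, :] is what would be passed to assimilate().

# Masked multiplication as in the non-localized branch; T stands in for the
# (ensemble_size x ensemble_size) transition matrix from the smoother.
T = rng.normal(size=(ensemble_size, ensemble_size))
param_ensemble_array[non_zero_variance_mask] = (
    param_ensemble_array[non_zero_variance_mask] @ T
)

# The constant parameter keeps its prior values.
assert np.allclose(param_ensemble_array[2], 3.0)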
46 changes: 22 additions & 24 deletions src/ert/storage/local_ensemble.py
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from uuid import UUID

import numpy as np
@@ -788,51 +788,49 @@ def load_all_gen_kw_data(
def save_parameters(
self,
group: str,
realization: int,
realization: int | Sequence[int],
dataset: xr.Dataset,
) -> None:
"""
Saves the provided dataset under a parameter group and realization index
Saves the provided dataset under a parameter group and realization index(es)
Parameters
----------
group : str
Parameter group name for saving dataset.
realization : int
Realization index for saving group.
realization : int or sequence of int
Realization index(es) for saving group.
dataset : Dataset
Dataset to save. It must contain a variable named 'values'
which will be used when flattening out the parameters into
a 1d-vector.
a 1d-vector. When saving multiple realizations, dataset must
have a 'realizations' dimension.
"""

if "values" not in dataset.variables:
raise ValueError(
f"Dataset for parameter group '{group}' "
f"must contain a 'values' variable"
f"Dataset for parameter group '{group}' must contain a 'values' variable"
)

if dataset["values"].size == 0:
raise ValueError(
f"Parameters {group} are empty. Cannot proceed with saving to storage."
)

if dataset["values"].ndim >= 2 and dataset["values"].values.dtype == "float64":
logger.warning(
"Dataset uses 'float64' for fields/surfaces. Use 'float32' to save memory."
)

if group not in self.experiment.parameter_configuration:
raise ValueError(f"{group} is not registered to the experiment.")

path = self._realization_dir(realization) / f"{_escape_filename(group)}.nc"
path.parent.mkdir(exist_ok=True)
realizations = [realization] if np.isscalar(realization) else realization

self._storage._to_netcdf_transaction(
path, dataset.expand_dims(realizations=[realization])
)
if len(realizations) > 1 and "realizations" not in dataset.dims:
raise ValueError(
"Dataset must have 'realizations' dimension when saving multiple realizations"
)

for real in realizations:
path = self._realization_dir(real) / f"{_escape_filename(group)}.nc"
path.parent.mkdir(exist_ok=True)
if "realizations" in dataset.dims:
data_to_save = dataset.sel(realizations=real)
else:
data_to_save = dataset.expand_dims(realizations=[real])
self._storage._to_netcdf_transaction(path, data_to_save)

@require_write
def save_response(
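save_parameters previously accepted a single realization index; it now also takes a sequence, normalizes the argument to a list, and writes one NetCDF file per realization, selecting the matching slice when the dataset already has a realizations dimension and expanding the dataset when it does not. A hedged usage sketch (the variable layout and the "COND" group name are illustrative; in ert the group must also be registered on the experiment):

import numpy as np
import xarray as xr

# A dataset holding one parameter group for three realizations.
dataset = xr.Dataset(
    {"values": (("realizations", "x"), np.random.rand(3, 10).astype(np.float32))},
    coords={"realizations": [0, 1, 2]},
)

# Single realization, as before: a dataset without a 'realizations'
# dimension is expanded with expand_dims(realizations=[real]) on save.
single = dataset.sel(realizations=0, drop=True)
# ensemble.save_parameters("COND", 0, single)

# New in this commit: several realizations in one call. The dataset must
# carry a 'realizations' dimension, and dataset.sel(realizations=real)
# is written to each realization directory in turn.
# ensemble.save_parameters("COND", range(3), dataset)

The scalar check uses np.isscalar, so both 0 and range(3) are routed correctly, while passing multiple realizations without a realizations dimension raises a ValueError.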
91 changes: 91 additions & 0 deletions tests/ert/ui_tests/cli/test_field_parameter.py
@@ -5,10 +5,16 @@

import numpy as np
import numpy.testing
import polars as pl
import resfo
import xtgeo

from ert.analysis import (
smoother_update,
)
from ert.config import ErtConfig
from ert.config.analysis_config import UpdateSettings
from ert.config.analysis_module import ESSettings
from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE
from ert.storage import open_storage

@@ -210,3 +216,88 @@ def test_parameter_update_with_inactive_cells_xtgeo_grdecl(tmpdir):
assert "nan" not in Path(
"simulations/realization-0/iter-1/my_param.grdecl"
).read_text(encoding="utf-8")


import warnings


def test_field_param_update_using_heat_equation_and_zero_var_params(
heat_equation_storage,
):
config = ErtConfig.from_file("config.ert")
with open_storage(config.ens_path, mode="w") as storage:
experiment = storage.get_experiment_by_name("es-mda")
prior = experiment.get_ensemble_by_name("default_0")
# posterior = experiment.get_ensemble_by_name("default_1")
cond = prior.load_parameters("COND")

new_experiment = storage.create_experiment(
parameters=config.ensemble_config.parameter_configuration,
responses=config.ensemble_config.response_configuration,
observations=config.observations,
name="exp-zero-var",
)
new_prior = storage.create_ensemble(
new_experiment,
ensemble_size=prior.ensemble_size,
iteration=0,
name="prior-zero-var",
)
cond["values"][:, :, :5, 0] = 1.0
new_prior.save_parameters("COND", range(prior.ensemble_size), cond)

responses = prior.load_responses("gen_data", range(prior.ensemble_size))
for realization in range(prior.ensemble_size):
df = responses.filter(pl.col("realization") == realization)
new_prior.save_response("gen_data", df, realization)

new_posterior = storage.create_ensemble(
new_experiment,
ensemble_size=config.model_config.num_realizations,
iteration=1,
name="new_ensemble",
prior_ensemble=new_prior,
)

with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always") # Ensure all warnings are always recorded
smoother_update(
new_prior,
new_posterior,
experiment.observation_keys,
config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(localization=True),
)

# Check that our specific warnings are not in the recorded warnings
for warning in record:
assert not (
isinstance(warning.message, RuntimeWarning)
and "divide by zero encountered in divide" in str(warning.message)
)
assert not (
isinstance(warning.message, UserWarning)
and "Cross-correlation matrix has entries not in [-1, 1]"
in str(warning.message)
)

# param_config = config.ensemble_config.parameter_configs["COND"]
# assert len(prior_result.x) == param_config.nx
# assert len(prior_result.y) == param_config.ny
# assert len(prior_result.z) == param_config.nz

# posterior_result = posterior.load_parameters("COND")["values"]
# prior_covariance = np.cov(
# prior_result.values.reshape(
# prior.ensemble_size, param_config.nx * param_config.ny * param_config.nz
# )
# )
# posterior_covariance = np.cov(
# posterior_result.values.reshape(
# posterior.ensemble_size,
# param_config.nx * param_config.ny * param_config.nz,
# )
# )
# # Check that generalized variance is reduced by update step.
# assert np.trace(prior_covariance) > np.trace(posterior_covariance)
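
The test forces part of the COND prior to a constant (cond["values"][:, :, :5, 0] = 1.0), copies the prior responses into a fresh experiment, runs a localized smoother update, and asserts that the two symptoms of the old behaviour, the divide-by-zero RuntimeWarning and the cross-correlation-range UserWarning, are no longer emitted. The warnings-recording pattern it relies on is standard library; a minimal, self-contained version:

import warnings

def might_warn(trigger: bool) -> None:
    if trigger:
        warnings.warn("divide by zero encountered in divide", RuntimeWarning)

with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")  # record every warning, even duplicates
    might_warn(trigger=False)

# Assert that the specific warning is absent rather than requiring
# that the code emits no warnings at all.
for warning in record:
    assert not (
        issubclass(warning.category, RuntimeWarning)
        and "divide by zero" in str(warning.message)
    )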
