Skip to content

Commit

Permalink
Implement adaptive localization
Browse files Browse the repository at this point in the history
Add option of running adaptive localization that can simply
be turned on and does not need any user input.
Only parameters that are significantly correlated to responses
will be updated.
Default value of what constitutes significant correlation is calculated
based on theory, but can be set by the user.

Add tests for adaptive localization with threshold 0.0 and 1.0

Compute cross-correlation matrices without matrix inversion

Co-authored-by: Berent Å. S. Lunde <[email protected]>
Co-authored-by: Anna Kvashchuk <[email protected]>
  • Loading branch information
3 people committed Oct 18, 2023
1 parent 09aadf9 commit 5f4dec1
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 43 deletions.
132 changes: 114 additions & 18 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from iterative_ensemble_smoother.experimental import (
ensemble_smoother_update_step_row_scaling,
)
from tqdm import tqdm

from ert.config import Field, GenKwConfig, SurfaceConfig
from ert.realization_state import RealizationState
Expand Down Expand Up @@ -368,6 +369,12 @@ def _load_observations_and_responses(
)


def _split_by_batchsize(
arr: npt.NDArray[np.int_], batch_size: int
) -> List[npt.NDArray[np.int_]]:
return np.array_split(arr, int((arr.shape[0] / batch_size)) + 1)


def analysis_ES(
updatestep: UpdateConfiguration,
rng: np.random.Generator,
Expand Down Expand Up @@ -413,21 +420,17 @@ def analysis_ES(

# pylint: disable=unsupported-assignment-operation
smoother_snapshot.update_step_snapshots[update_step.name] = update_snapshot
if len(observation_values) == 0:

num_obs = len(observation_values)
if num_obs == 0:
raise ErtAnalysisError(
f"No active observations for update step: {update_step.name}."
)
noise = rng.standard_normal(size=(len(observation_values), S.shape[1]))

smoother = ies.ES()
smoother.fit(
S,
observation_errors,
observation_values,
noise=noise,
truncation=module.get_truncation(),
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
truncation = module.get_truncation()
noise = rng.standard_normal(size=(num_obs, ensemble_size))

for param_group in update_step.parameters:
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
Expand All @@ -437,15 +440,96 @@ def analysis_ES(
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
progress_callback(Progress(Task("Updating data", 2, 3), None))
if active_indices := param_group.index_list:
temp_storage[param_group.name][active_indices, :] = smoother.update(
temp_storage[param_group.name][active_indices, :]
if module.localization():
Y_prime = S - S.mean(axis=1, keepdims=True)
C_YY = Y_prime @ Y_prime.T / (ensemble_size - 1)
Sigma_Y = np.std(S, axis=1, ddof=1)
batch_size: int = 1000
correlation_threshold = module.localization_correlation_threshold(
ensemble_size
)
# for parameter in update_step.parameters:
num_params = temp_storage[param_group.name].shape[0]

print(
(
f"Running localization on {num_params} parameters,",
f"{num_obs} responses and {ensemble_size} realizations...",
)
)
batches = _split_by_batchsize(np.arange(0, num_params), batch_size)
for param_batch_idx in tqdm(batches):
X_local = temp_storage[param_group.name][param_batch_idx, :]
# Parameter standard deviations
Sigma_A = np.std(X_local, axis=1, ddof=1)
# Cross-covariance between parameters and measurements
A = X_local - X_local.mean(axis=1, keepdims=True)
C_AY = A @ Y_prime.T / (ensemble_size - 1)
# Cross-correlation between parameters and measurements
c_AY = np.abs(
(C_AY / Sigma_Y.reshape(1, -1)) / Sigma_A.reshape(-1, 1)
)
# Absolute values of the correlation matrix
c_bool = c_AY > correlation_threshold
# Some parameters might be significantly correlated
# to the exact same responses.
# We want to call the update only once per such parameter group
# to speed up computation.
# Here we create a collection of unique sets of parameter-to-observation
# correlations.
param_correlation_sets: npt.NDArray[np.bool_] = np.unique(
c_bool, axis=0
)
# Drop the correlation set that does not correlate to any responses.
row_with_all_false = np.all(~param_correlation_sets, axis=1)
param_correlation_sets = param_correlation_sets[~row_with_all_false]

for param_correlation_set in param_correlation_sets:
# Find the rows matching the parameter group
matching_rows = np.all(c_bool == param_correlation_set, axis=1)
# Get the indices of the matching rows
row_indices = np.where(matching_rows)[0]
X_chunk = temp_storage[param_group.name][param_batch_idx, :][
row_indices, :
]
S_chunk = S[param_correlation_set, :]
observation_errors_loc = observation_errors[
param_correlation_set
]
observation_values_loc = observation_values[
param_correlation_set
]
smoother.fit(
S_chunk,
observation_errors_loc,
observation_values_loc,
noise=noise[param_correlation_set],
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
temp_storage[param_group.name][
param_batch_idx[row_indices], :
] = smoother.update(X_chunk)
else:
temp_storage[param_group.name] = smoother.update(
temp_storage[param_group.name]
smoother.fit(
S,
observation_errors,
observation_values,
noise=noise,
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
if active_indices := param_group.index_list:
temp_storage[param_group.name][active_indices, :] = smoother.update(
temp_storage[param_group.name][active_indices, :]
)
else:
temp_storage[param_group.name] = smoother.update(
temp_storage[param_group.name]
)

if params_with_row_scaling := _get_params_with_row_scaling(
temp_storage, update_step.row_scaling_parameters
):
Expand All @@ -461,7 +545,19 @@ def analysis_ES(
for row_scaling_parameter, (A, _) in zip(
update_step.row_scaling_parameters, params_with_row_scaling
):
_save_to_temp_storage(temp_storage, [row_scaling_parameter], A)
params_with_row_scaling = ensemble_smoother_update_step_row_scaling(
S,
params_with_row_scaling,
observation_errors,
observation_values,
noise,
module.get_truncation(),
ies.InversionType(module.inversion),
)
for row_scaling_parameter, (A, _) in zip(
update_step.row_scaling_parameters, params_with_row_scaling
):
_save_to_temp_storage(temp_storage, [row_scaling_parameter], A)

progress_callback(Progress(Task("Storing data", 3, 3), None))
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
Expand Down
124 changes: 100 additions & 24 deletions src/ert/config/analysis_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import math
import sys
from typing import TYPE_CHECKING, Dict, List, Type, TypedDict, Union

Expand Down Expand Up @@ -33,6 +34,23 @@ class VariableInfo(TypedDict):
DEFAULT_IES_DEC_STEPLENGTH = 2.50
DEFAULT_ENKF_TRUNCATION = 0.98
DEFAULT_IES_INVERSION = 0
DEFAULT_LOCALIZATION = False
# Default threshold is a function of ensemble size which is not available here.
DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD = -1


def correlation_threshold(ensemble_size: int, user_defined_threshold: float) -> float:
"""Decides whether or not to use user-defined or default threshold.
Default threshold taken from luo2022,
Continuous Hyper-parameter OPtimization (CHOP) in an ensemble Kalman filter
Section 2.3 - Localization in the CHOP problem
"""
default_threshold = 3 / math.sqrt(ensemble_size)
if user_defined_threshold == -1:
return default_threshold

return user_defined_threshold


class AnalysisMode(StrEnum):
Expand All @@ -58,6 +76,22 @@ def get_mode_variables(mode: AnalysisMode) -> Dict[str, "VariableInfo"]:
"step": 0.01,
"labelname": "Singular value truncation",
},
"LOCALIZATION": {
"type": bool,
"min": 0.0,
"value": DEFAULT_LOCALIZATION,
"max": 1.0,
"step": 1.0,
"labelname": "Adaptive localization",
},
"LOCALIZATION_CORRELATION_THRESHOLD": {
"type": float,
"min": 0.0,
"value": DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD,
"max": 1.0,
"step": 0.1,
"labelname": "Adaptive localization correlation threshold",
},
}
ies_variables: Dict[str, "VariableInfo"] = {
"IES_MAX_STEPLENGTH": {
Expand Down Expand Up @@ -169,31 +203,47 @@ def set_var(self, var_name: str, value: Union[float, int, bool, str]) -> None:
self.handle_special_key_set(var_name, value)
elif var_name in self._variables:
var = self._variables[var_name]
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"

if var["type"] is not bool:
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"
)
else:
var["value"] = new_value

except ValueError as e:
raise ConfigValidationError(
f"Variable {var_name!r} with value {value!r} has "
f"incorrect type."
f" Expected type {var['type'].__name__!r} but received"
f" value {value!r} of type {type(value).__name__!r}"
) from e
else:
if not isinstance(var["value"], bool):
raise ValueError(
f"Variable {var_name} expected type {var['type']}"
f" received value `{value}` of type `{type(value)}`"
)
else:
var["value"] = new_value

except ValueError as e:
raise ConfigValidationError(
f"Variable {var_name!r} with value {value!r} has incorrect type."
f" Expected type {var['type'].__name__!r} but received"
f" value {value!r} of type {type(value).__name__!r}"
) from e
# When config is first read, `value` is a string
# that's either "False" or "True",
# but since bool("False") is True we need to convert it to bool.
if not isinstance(value, bool):
value = str(value).lower() != "false"

var["value"] = var["type"](value)
else:
raise ConfigValidationError(
f"Variable {var_name!r} not found in {self.name!r} analysis module"
Expand All @@ -210,6 +260,32 @@ def inversion(self, value: int) -> None:
def get_truncation(self) -> float:
return self.get_variable_value("ENKF_TRUNCATION")

def localization(self) -> bool:
return bool(self.get_variable_value("LOCALIZATION"))

def localization_correlation_threshold(self, ensemble_size: int) -> float:
return correlation_threshold(
ensemble_size, self.get_variable_value("LOCALIZATION_CORRELATION_THRESHOLD")
)

def get_steplength(self, iteration_nr: int) -> float:
"""
This is an implementation of Eq. (49), which calculates a suitable
step length for the update step, from the book:
Geir Evensen, Formulating the history matching problem with
consistent error statistics, Computational Geosciences (2021) 25:945 –970
Function not really used moved from C to keep the class interface consistent
should be investigated for possible removal.
"""
min_step_length = self.get_variable_value("IES_MIN_STEPLENGTH")
max_step_length = self.get_variable_value("IES_MAX_STEPLENGTH")
dec_step_length = self.get_variable_value("IES_DEC_STEPLENGTH")
step_length = min_step_length + (max_step_length - min_step_length) * pow(
2, -(iteration_nr - 1) / (dec_step_length - 1)
)
return step_length

def __repr__(self) -> str:
return f"AnalysisModule(name = {self.name})"

Expand Down
19 changes: 19 additions & 0 deletions src/ert/gui/ertwidgets/analysismodulevariablespanel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
QWidget,
)

from ert.config.analysis_module import correlation_threshold
from ert.gui.ertwidgets.models.analysismodulevariablesmodel import (
AnalysisModuleVariablesModel,
)
Expand Down Expand Up @@ -41,10 +42,16 @@ def __init__(self, analysis_module_name: str, facade: LibresFacade):
variable_type = analysis_module_variables_model.getVariableType(
variable_name
)

variable_value = analysis_module_variables_model.getVariableValue(
self.facade, self._analysis_module_name, variable_name
)

if variable_name == "LOCALIZATION_CORRELATION_THRESHOLD":
variable_value = correlation_threshold(
self.facade.get_ensemble_size(), variable_value
)

label_name = analysis_module_variables_model.getVariableLabelName(
variable_name
)
Expand Down Expand Up @@ -123,6 +130,17 @@ def __init__(self, analysis_module_name: str, facade: LibresFacade):
lambda value: self.update_truncation_spinners(value, truncation_spinner)
)

localization_checkbox = self.widget_from_layout(layout, "LOCALIZATION")
localization_correlation_spinner = self.widget_from_layout(
layout, "LOCALIZATION_CORRELATION_THRESHOLD"
)
localization_correlation_spinner.setEnabled(localization_checkbox.isChecked())
localization_checkbox.stateChanged.connect(
lambda localization_is_on: localization_correlation_spinner.setEnabled(True)
if localization_is_on
else localization_correlation_spinner.setEnabled(False)
)

self.setLayout(layout)
self.blockSignals(False)

Expand Down Expand Up @@ -172,6 +190,7 @@ def createSpinBox(
def createCheckBox(self, variable_name, variable_value, variable_type):
spinner = QCheckBox()
spinner.setChecked(variable_value)
spinner.setObjectName(variable_name)
spinner.clicked.connect(
partial(self.valueChanged, variable_name, variable_type, spinner)
)
Expand Down
Loading

0 comments on commit 5f4dec1

Please sign in to comment.