Skip to content

Commit

Permalink
Commitin
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Oct 30, 2024
1 parent 28763a6 commit 0723ff1
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 16 deletions.
49 changes: 33 additions & 16 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,24 @@ def correlation_callback(
param_ensemble_array = _load_param_ensemble_array(
source_ensemble, param_group, iens_active_index
)

# Calculate variance for each parameter
param_variance = np.var(param_ensemble_array, axis=1)
# Create mask for non-zero variance parameters
non_zero_variance_mask = ~np.isclose(param_variance, 0.0)

log_msg = f"Updating {np.sum(non_zero_variance_mask)} parameters {'with' if module.localization else 'without'} adaptive localization."
logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))

log_msg = f"There are {num_obs} responses and {ensemble_size} realizations."
logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))

log_msg = f"There are {(~non_zero_variance_mask).sum()} parameters with 0 variance that will not be updated."
logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))

if module.localization:
config_node = source_ensemble.experiment.parameter_configuration[
param_group
Expand All @@ -618,32 +636,31 @@ def correlation_callback(
batch_size = _calculate_adaptive_batch_size(num_params, num_obs)
batches = _split_by_batchsize(np.arange(0, num_params), batch_size)

log_msg = f"Running localization on {num_params} parameters, {num_obs} responses, {ensemble_size} realizations and {len(batches)} batches"
log_msg = f"Adaptive localization has split parameters into {len(batches)} batch{'es' if len(batches) != 1 else ''}."
logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))

start = time.time()
cross_correlations: List[npt.NDArray[np.float64]] = []
for param_batch_idx in batches:
X_local = param_ensemble_array[param_batch_idx, :]
update_idx = param_batch_idx[non_zero_variance_mask[param_batch_idx]]
X_local = param_ensemble_array[update_idx, :]
if isinstance(config_node, GenKwConfig):
correlation_batch_callback = functools.partial(
correlation_callback,
cross_correlations_accumulator=cross_correlations,
)
else:
correlation_batch_callback = None
param_ensemble_array[param_batch_idx, :] = (
smoother_adaptive_es.assimilate(
X=X_local,
Y=S,
D=D,
alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage)
correlation_threshold=module.correlation_threshold,
cov_YY=cov_YY,
progress_callback=adaptive_localization_progress_callback,
correlation_callback=correlation_batch_callback,
)
param_ensemble_array[update_idx, :] = smoother_adaptive_es.assimilate(
X=X_local,
Y=S,
D=D,
alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage)
correlation_threshold=module.correlation_threshold,
cov_YY=cov_YY,
progress_callback=adaptive_localization_progress_callback,
correlation_callback=correlation_batch_callback,
)

if cross_correlations:
Expand All @@ -665,9 +682,9 @@ def correlation_callback(

else:
# In-place multiplication is not yet supported, therefore avoiding @=
param_ensemble_array = param_ensemble_array @ T.astype( # noqa: PLR6104
param_ensemble_array.dtype
)
param_ensemble_array[non_zero_variance_mask] = param_ensemble_array[ # noqa: PLR6104
non_zero_variance_mask
] @ T.astype(param_ensemble_array.dtype)

log_msg = f"Storing data for {param_group}.."
logger.info(log_msg)
Expand Down
2 changes: 2 additions & 0 deletions test-data/ert/heat_equation/heat_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def sample_prior_conductivity(ensemble_size, nx, rng):
rng = np.random.default_rng(iens)
cond = sample_prior_conductivity(ensemble_size=1, nx=nx, rng=rng).reshape(nx, nx)

cond[nx // 2 :, :] = 1.0

if iteration == 0:
resfo.write(
"cond.bgrdecl", [("COND ", cond.flatten(order="F").astype(np.float32))]
Expand Down
4 changes: 4 additions & 0 deletions tests/ert/ui_tests/cli/test_field_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,7 @@ def test_parameter_update_with_inactive_cells_xtgeo_grdecl(tmpdir):
assert "nan" not in Path(
"simulations/realization-0/iter-1/my_param.grdecl"
).read_text(encoding="utf-8")


# def test_field_parameter_update_using_heat_equation(heat_equation_storage):
# print("Hei")

0 comments on commit 0723ff1

Please sign in to comment.