diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index 8915fc02f0e..59569ebb67d 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -610,6 +610,24 @@ def correlation_callback( param_ensemble_array = _load_param_ensemble_array( source_ensemble, param_group, iens_active_index ) + + # Calculate variance for each parameter + param_variance = np.var(param_ensemble_array, axis=1) + # Create mask for non-zero variance parameters + non_zero_variance_mask = ~np.isclose(param_variance, 0.0) + + log_msg = f"Updating {np.sum(non_zero_variance_mask)} parameters {'with' if module.localization else 'without'} adaptive localization." + logger.info(log_msg) + progress_callback(AnalysisStatusEvent(msg=log_msg)) + + log_msg = f"There are {num_obs} responses and {ensemble_size} realizations." + logger.info(log_msg) + progress_callback(AnalysisStatusEvent(msg=log_msg)) + + log_msg = f"There are {(~non_zero_variance_mask).sum()} parameters with 0 variance that will not be updated." + logger.info(log_msg) + progress_callback(AnalysisStatusEvent(msg=log_msg)) + if module.localization: config_node = source_ensemble.experiment.parameter_configuration[ param_group @@ -618,14 +636,15 @@ def correlation_callback( batch_size = _calculate_adaptive_batch_size(num_params, num_obs) batches = _split_by_batchsize(np.arange(0, num_params), batch_size) - log_msg = f"Running localization on {num_params} parameters, {num_obs} responses, {ensemble_size} realizations and {len(batches)} batches" + log_msg = f"Adaptive localization has split parameters into {len(batches)} batch{'es' if len(batches) != 1 else ''}." logger.info(log_msg) progress_callback(AnalysisStatusEvent(msg=log_msg)) start = time.time() cross_correlations: List[npt.NDArray[np.float64]] = [] for param_batch_idx in batches: - X_local = param_ensemble_array[param_batch_idx, :] + update_idx = param_batch_idx[non_zero_variance_mask[param_batch_idx]] + X_local = param_ensemble_array[update_idx, :] if isinstance(config_node, GenKwConfig): correlation_batch_callback = functools.partial( correlation_callback, @@ -633,17 +652,15 @@ def correlation_callback( ) else: correlation_batch_callback = None - param_ensemble_array[param_batch_idx, :] = ( - smoother_adaptive_es.assimilate( - X=X_local, - Y=S, - D=D, - alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage) - correlation_threshold=module.correlation_threshold, - cov_YY=cov_YY, - progress_callback=adaptive_localization_progress_callback, - correlation_callback=correlation_batch_callback, - ) + param_ensemble_array[update_idx, :] = smoother_adaptive_es.assimilate( + X=X_local, + Y=S, + D=D, + alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage) + correlation_threshold=module.correlation_threshold, + cov_YY=cov_YY, + progress_callback=adaptive_localization_progress_callback, + correlation_callback=correlation_batch_callback, ) if cross_correlations: @@ -665,9 +682,9 @@ def correlation_callback( else: # In-place multiplication is not yet supported, therefore avoiding @= - param_ensemble_array = param_ensemble_array @ T.astype( # noqa: PLR6104 - param_ensemble_array.dtype - ) + param_ensemble_array[non_zero_variance_mask] = param_ensemble_array[ # noqa: PLR6104 + non_zero_variance_mask + ] @ T.astype(param_ensemble_array.dtype) log_msg = f"Storing data for {param_group}.." logger.info(log_msg) diff --git a/test-data/ert/heat_equation/heat_equation.py b/test-data/ert/heat_equation/heat_equation.py index e36862ea1aa..63b086a492b 100755 --- a/test-data/ert/heat_equation/heat_equation.py +++ b/test-data/ert/heat_equation/heat_equation.py @@ -63,6 +63,8 @@ def sample_prior_conductivity(ensemble_size, nx, rng): rng = np.random.default_rng(iens) cond = sample_prior_conductivity(ensemble_size=1, nx=nx, rng=rng).reshape(nx, nx) + cond[nx // 2 :, :] = 1.0 + if iteration == 0: resfo.write( "cond.bgrdecl", [("COND ", cond.flatten(order="F").astype(np.float32))] diff --git a/tests/ert/ui_tests/cli/test_field_parameter.py b/tests/ert/ui_tests/cli/test_field_parameter.py index b26c6495b5e..6ebc139a89e 100644 --- a/tests/ert/ui_tests/cli/test_field_parameter.py +++ b/tests/ert/ui_tests/cli/test_field_parameter.py @@ -296,3 +296,7 @@ def test_parameter_update_with_inactive_cells_xtgeo_grdecl(tmpdir): assert "nan" not in Path( "simulations/realization-0/iter-1/my_param.grdecl" ).read_text(encoding="utf-8") + + +# def test_field_parameter_update_using_heat_equation(heat_equation_storage): +# print("Hei")