Skip to content

Commit

Permalink
Move shared rng to run models
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Oct 24, 2023
1 parent 222eb48 commit 0e82ac3
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 25 deletions.
8 changes: 0 additions & 8 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,6 @@ def __init__(self, config: "ErtConfig", read_only: bool = False) -> None:
substitute=self.get_context().substitute_real_iter,
)

self._shared_rng = np.random.default_rng(
_seed_sequence(self.ert_config.random_seed)
)

@property
def update_configuration(self) -> UpdateConfiguration:
if not self._update_configuration:
Expand Down Expand Up @@ -284,10 +280,6 @@ def getObservations(self) -> EnkfObs:
def have_observations(self) -> bool:
return len(self._observations) > 0

def rng(self) -> np.random.Generator:
"""Will return the random number generator used for updates."""
return self._shared_rng

def createRunPath(self, run_context: RunContext) -> None:
t = time.perf_counter()
for iens, run_arg in enumerate(run_context):
Expand Down
6 changes: 4 additions & 2 deletions src/ert/gui/tools/run_analysis/run_analysis_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from contextlib import contextmanager
from typing import Optional

import numpy as np
from qtpy.QtCore import QObject, Qt, QThread, Signal, Slot
from qtpy.QtWidgets import QApplication, QMessageBox

from ert.analysis import ErtAnalysisError, Progress, smoother_update
from ert.enkf_main import EnKFMain
from ert.enkf_main import EnKFMain, _seed_sequence
from ert.gui.ertnotifier import ErtNotifier
from ert.gui.ertwidgets import resourceIcon
from ert.gui.ertwidgets.statusdialog import StatusDialog
Expand Down Expand Up @@ -42,14 +43,15 @@ def run(self):
"""Runs analysis using target and source cases. Returns whether
the analysis was successful."""
error: Optional[str] = None
rng = np.random.default_rng(_seed_sequence(self._ert.ert_config.random_seed))
try:
smoother_update(
self._source_fs,
self._target_fs,
str(uuid.uuid4()),
self._ert.getLocalConfig(),
self._ert.analysisConfig(),
self._ert.rng(),
rng,
self.progress_callback,
)
except ErtAnalysisError as e:
Expand Down
5 changes: 4 additions & 1 deletion src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,17 @@ def smoother_update(
run_id: str,
progress_callback: Optional[ProgressCallback] = None,
global_std_scaling: float = 1.0,
rng: Optional[np.random.Generator] = None,
) -> None:
if rng is None:
rng = np.random.default_rng()
self.update_snapshots[run_id] = smoother_update(
prior_storage,
posterior_storage,
run_id,
self._enkf_main.getLocalConfig(),
self._enkf_main.analysisConfig(),
self._enkf_main.rng(),
rng,
progress_callback,
global_std_scaling,
)
Expand Down
7 changes: 6 additions & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Union

import numpy as np

from ert.cli import MODULE_MODE
from ert.config import HookRuntime, QueueSystem
from ert.enkf_main import EnKFMain
from ert.enkf_main import EnKFMain, _seed_sequence
from ert.ensemble_evaluator import (
Ensemble,
EnsembleBuilder,
Expand Down Expand Up @@ -103,6 +105,9 @@ def __init__(
self._iter_map: Dict[int, str] = {}
self.validate()
self._context_env_keys: List[str] = []
self.rng = np.random.default_rng(
_seed_sequence(simulation_arguments.random_seed)
)

@property
def queue_system(self) -> QueueSystem:
Expand Down
1 change: 1 addition & 0 deletions src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def run_experiment(
prior_context.sim_fs,
posterior_context.sim_fs,
prior_context.run_id, # type: ignore
rng=self.rng,
)
except ErtAnalysisError as e:
raise ErtRunError(
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/iterated_ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def analyzeStep(
ensemble_id,
self.ert.getLocalConfig(),
self.ert.analysisConfig(),
self.ert.rng(),
self.rng,
)
except ErtAnalysisError as e:
raise ErtRunError(
Expand Down
1 change: 1 addition & 0 deletions src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def update(
posterior_context.sim_fs,
prior_context.run_id, # type: ignore
global_std_scaling=weight,
rng=self.rng,
)
except ErtAnalysisError as e:
raise ErtRunError(
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_update_snapshot(
"id",
ert.getLocalConfig(),
ert.analysisConfig(),
ert.rng(),
np.random.default_rng(3593114179000630026631423308983283277868),
)
else:
smoother_update(
Expand All @@ -141,7 +141,7 @@ def test_update_snapshot(
"id",
ert.getLocalConfig(),
ert.analysisConfig(),
ert.rng(),
np.random.default_rng(3593114179000630026631423308983283277868),
)

sim_gen_kw = list(prior_ens.load_parameters("SNAKE_OIL_PARAM", 0).values.flatten())
Expand Down Expand Up @@ -291,7 +291,7 @@ def test_localization(
prior.run_id,
ert.getLocalConfig(),
ert.analysisConfig(),
ert.rng(),
np.random.default_rng(3593114179000630026631423308983283277868),
)

sim_gen_kw = list(
Expand Down
17 changes: 11 additions & 6 deletions tests/unit_tests/gui/run_analysis/test_run_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,14 @@ class MockedQIcon(QIcon):


@pytest.fixture
def mock_tool(mock_storage):
def ert_mock():
ert_mock = Mock()
ert_mock.ert_config.random_seed = None
return ert_mock


@pytest.fixture
def mock_tool(mock_storage, ert_mock):
with patch("ert.gui.tools.run_analysis.run_analysis_tool.resourceIcon") as rs:
rs.return_value = MockedQIcon()
(target, source) = mock_storage
Expand All @@ -39,8 +46,7 @@ def mock_tool(mock_storage):
notifier = Mock(spec_set=ErtNotifier)
notifier.storage.to_accessor.return_value = notifier.storage
notifier.storage.create_ensemble.return_value = target

tool = RunAnalysisTool(Mock(spec_set=EnKFMain), notifier)
tool = RunAnalysisTool(ert_mock, notifier)
tool._run_widget = run_widget
tool._dialog = Mock(spec_set=StatusDialog)

Expand All @@ -56,10 +62,9 @@ def mock_storage(storage):


@pytest.mark.requires_window_manager
def test_analyse_success(mock_storage, qtbot):
def test_analyse_success(mock_storage, qtbot, ert_mock):
(target, source) = mock_storage

analyse = Analyse(Mock(spec_set=EnKFMain), target, source)
analyse = Analyse(ert_mock, target, source)
thread = QThread()
with qtbot.waitSignals(
[analyse.finished, thread.finished], timeout=2000, raising=True
Expand Down
7 changes: 4 additions & 3 deletions tests/unit_tests/storage/test_parameter_sample_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def create_runpath(
*,
ensemble: Optional[EnsembleAccessor] = None,
iteration=0,
random_seed: Optional[int] = 1234,
) -> Tuple[EnKFMain, EnsembleAccessor]:
active_mask = [True] if active_mask is None else active_mask
ert_config = ErtConfig.from_file(config)
Expand All @@ -58,7 +59,7 @@ def create_runpath(
sample_prior(
ensemble,
[i for i, active in enumerate(active_mask) if active],
random_seed=1234,
random_seed=random_seed,
)
ert.createRunPath(prior)
return ert, ensemble
Expand Down Expand Up @@ -438,7 +439,7 @@ def test_initialize_random_seed(
fh.writelines("MY_KEYWORD <MY_KEYWORD>")
with open("prior.txt", mode="w", encoding="utf-8") as fh:
fh.writelines("MY_KEYWORD NORMAL 0 1")
create_runpath(storage, "config.ert")
create_runpath(storage, "config.ert", random_seed=None)
# We read the first parameter value as a reference value
expected = Path("simulations/realization-0/iter-0/kw.txt").read_text("utf-8")

Expand All @@ -458,7 +459,7 @@ def test_initialize_random_seed(
with open("prior.txt", mode="w", encoding="utf-8") as fh:
fh.writelines("MY_KEYWORD NORMAL 0 1")

create_runpath(storage, "config_2.ert")
create_runpath(storage, "config_2.ert", random_seed=int(random_seed))
with expectation:
assert (
Path("simulations/realization-0/iter-0/kw.txt").read_text("utf-8")
Expand Down

0 comments on commit 0e82ac3

Please sign in to comment.