diff --git a/flepimop/gempyor_pkg/tests/parameters/test_parameters_class.py b/flepimop/gempyor_pkg/tests/parameters/test_parameters_class.py index f82cf2a27..391434f7c 100644 --- a/flepimop/gempyor_pkg/tests/parameters/test_parameters_class.py +++ b/flepimop/gempyor_pkg/tests/parameters/test_parameters_class.py @@ -1,6 +1,8 @@ from datetime import date from functools import partial +from itertools import repeat import pathlib +from tempfile import NamedTemporaryFile from typing import Any, Callable from uuid import uuid4 @@ -8,7 +10,7 @@ import numpy as np import pandas as pd import pytest -from tempfile import NamedTemporaryFile +import tqdm.contrib.concurrent from gempyor.parameters import Parameters from gempyor.testing import ( @@ -192,6 +194,22 @@ def insufficient_dates_parameter_factory(tmp_path: pathlib.Path) -> MockParamete ) +def sample_params(params: Parameters, reinit: bool) -> np.ndarray: + """ + Helper method for unit testing. + + Args: + params: The instance of the Parameters class to sample from. + reinit: Whether to reinitialize the parameters. + + Returns: + The sampled parameters as a flattened numpy array. + """ + if reinit: + params.reinitialize_distributions() + return params.parameters_quick_draw(1, 1).flatten() + + class TestParameters: @pytest.mark.parametrize("factory", [(nonunique_invalid_parameter_factory)]) def test_nonunique_parameter_names_value_error( @@ -697,3 +715,34 @@ def test_parameters_reduce(self) -> None: # TODO: Come back and unit test this method after getting a better handle on # these NPI objects. pass + + def test_reinitialize_parameters(self, tmp_path: pathlib.Path) -> None: + mock_inputs = distribution_three_valid_parameter_factory(tmp_path) + + np.random.seed(123) + + params = mock_inputs.create_parameters_instance() + + results = tqdm.contrib.concurrent.process_map( + sample_params, + repeat(params, times=6), + repeat(False, times=6), + max_workers=2, + disable=True, + ) + + for i in range(1, len(results)): + assert np.allclose(results[i - 1], results[i]) + + np.random.seed(123) + + results_with_reinit = tqdm.contrib.concurrent.process_map( + sample_params, + repeat(params, times=6), + repeat(True, times=6), + max_workers=2, + disable=True, + ) + + for i in range(1, len(results_with_reinit)): + assert not np.allclose(results_with_reinit[i - 1], results_with_reinit[i])