Skip to content

Commit

Permalink
Add unit test for reinitialize_distributions
Browse files Browse the repository at this point in the history
Added a unit test to the `Parameters` class unit tests for the
`reinitialize_distributions` method that demonstrates how this method
affects the seeding behavior across workers.
  • Loading branch information
TimothyWillard committed Dec 17, 2024
1 parent e39f961 commit c828f1e
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion flepimop/gempyor_pkg/tests/parameters/test_parameters_class.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from datetime import date
from functools import partial
from itertools import repeat
import pathlib
from tempfile import NamedTemporaryFile
from typing import Any, Callable
from uuid import uuid4

import confuse
import numpy as np
import pandas as pd
import pytest
from tempfile import NamedTemporaryFile
import tqdm.contrib.concurrent

from gempyor.parameters import Parameters
from gempyor.testing import (
Expand Down Expand Up @@ -192,6 +194,22 @@ def insufficient_dates_parameter_factory(tmp_path: pathlib.Path) -> MockParamete
)


def sample_params(params: Parameters, reinit: bool) -> np.ndarray:
"""
Helper method for unit testing.
Args:
params: The instance of the Parameters class to sample from.
reinit: Whether to reinitialize the parameters.
Returns:
The sampled parameters as a flattened numpy array.
"""
if reinit:
params.reinitialize_distributions()
return params.parameters_quick_draw(1, 1).flatten()


class TestParameters:
@pytest.mark.parametrize("factory", [(nonunique_invalid_parameter_factory)])
def test_nonunique_parameter_names_value_error(
Expand Down Expand Up @@ -697,3 +715,34 @@ def test_parameters_reduce(self) -> None:
# TODO: Come back and unit test this method after getting a better handle on
# these NPI objects.
pass

def test_reinitialize_parameters(self, tmp_path: pathlib.Path) -> None:
mock_inputs = distribution_three_valid_parameter_factory(tmp_path)

np.random.seed(123)

params = mock_inputs.create_parameters_instance()

results = tqdm.contrib.concurrent.process_map(
sample_params,
repeat(params, times=6),
repeat(False, times=6),
max_workers=2,
disable=True,
)

for i in range(1, len(results)):
assert np.allclose(results[i - 1], results[i])

np.random.seed(123)

results_with_reinit = tqdm.contrib.concurrent.process_map(
sample_params,
repeat(params, times=6),
repeat(True, times=6),
max_workers=2,
disable=True,
)

for i in range(1, len(results_with_reinit)):
assert not np.allclose(results_with_reinit[i - 1], results_with_reinit[i])

0 comments on commit c828f1e

Please sign in to comment.