From 7dac3590fde0876acb65b76575425a40ca5509d7 Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Thu, 26 Oct 2023 13:44:18 +0200
Subject: [PATCH] Pass seed to map_reduce call inside
 test_memcached_parallel_repeated_training

---
 tests/utils/test_caching.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py
index 008a65847..c30e38fd8 100644
--- a/tests/utils/test_caching.py
+++ b/tests/utils/test_caching.py
@@ -1,5 +1,6 @@
 import logging
 from time import sleep, time
+from typing import Optional
 
 import numpy as np
 import pytest
@@ -7,6 +8,7 @@
 
 from pydvl.parallel import MapReduceJob
 from pydvl.utils import memcached
+from pydvl.utils.types import Seed
 
 logger = logging.getLogger(__name__)
 
@@ -139,7 +141,6 @@ def test_memcached_parallel_repeated_training(
     if parallel_config.backend != "joblib":
         pytest.skip("We don't have to test this with all parallel backends")
     _, config = memcached_client
-    rng = np.random.default_rng(seed)
 
     @memcached(
         client_config=config,
@@ -149,9 +150,10 @@
         # Note that we typically do NOT want to ignore run_id
         ignore_args=["job_id", "run_id"],
     )
-    def map_func(indices: NDArray[np.int_]) -> float:
+    def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float:
         # from pydvl.utils.logging import logger
         # logger.info(f"run_id: {run_id}, running...")
+        rng = np.random.default_rng(seed)
         return np.sum(indices).item() + rng.normal(scale=5)
 
     def reduce_func(chunks: NDArray[np.float_]) -> float:
@@ -162,7 +164,7 @@ def reduce_func(chunks: NDArray[np.float_]) -> float:
     )
     results = []
     for _ in range(n_runs):
-        result = map_reduce_job()
+        result = map_reduce_job(seed=seed)
         results.append(result)
 
     exact_value = np.sum(np.arange(n)).item()
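
Note (not part of the patch): a minimal, self-contained sketch of the pattern the change adopts, namely seeding a fresh generator inside the map function rather than reusing an rng created once in the test body, presumably so that repeated runs with the same seed produce identical values. The standalone map_func below is illustrative only and uses a plain int seed instead of pyDVL's Seed type.

    from typing import Optional

    import numpy as np
    from numpy.typing import NDArray

    def map_func(indices: NDArray[np.int_], seed: Optional[int] = None) -> float:
        # A fresh generator per call: the same seed always yields the same noise.
        rng = np.random.default_rng(seed)
        return np.sum(indices).item() + rng.normal(scale=5)

    indices = np.arange(10)
    # Two calls with the same seed return the same result.
    assert map_func(indices, seed=42) == map_func(indices, seed=42)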