Skip to content

Commit

Permalink
Pass seed to map_reduce call inside test_memcached_parallel_repeat…
Browse files Browse the repository at this point in the history
…ed_training
  • Loading branch information
AnesBenmerzoug committed Oct 26, 2023
1 parent ba3d13d commit 7dac359
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/utils/test_caching.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from time import sleep, time
from typing import Optional

import numpy as np
import pytest
from numpy.typing import NDArray

from pydvl.parallel import MapReduceJob
from pydvl.utils import memcached
from pydvl.utils.types import Seed

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -139,7 +141,6 @@ def test_memcached_parallel_repeated_training(
if parallel_config.backend != "joblib":
pytest.skip("We don't have to test this with all parallel backends")
_, config = memcached_client
rng = np.random.default_rng(seed)

@memcached(
client_config=config,
Expand All @@ -149,9 +150,10 @@ def test_memcached_parallel_repeated_training(
# Note that we typically do NOT want to ignore run_id
ignore_args=["job_id", "run_id"],
)
def map_func(indices: NDArray[np.int_]) -> float:
def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float:
# from pydvl.utils.logging import logger
# logger.info(f"run_id: {run_id}, running...")
rng = np.random.default_rng(seed)
return np.sum(indices).item() + rng.normal(scale=5)

def reduce_func(chunks: NDArray[np.float_]) -> float:
Expand All @@ -162,7 +164,7 @@ def reduce_func(chunks: NDArray[np.float_]) -> float:
)
results = []
for _ in range(n_runs):
result = map_reduce_job()
result = map_reduce_job(seed=seed)
results.append(result)

exact_value = np.sum(np.arange(n)).item()
Expand Down

0 comments on commit 7dac359

Please sign in to comment.