diff --git a/src/pydvl/utils/caching/config.py b/src/pydvl/utils/caching/config.py index 14f7cb761..62e96c2aa 100644 --- a/src/pydvl/utils/caching/config.py +++ b/src/pydvl/utils/caching/config.py @@ -34,7 +34,7 @@ class CachedFuncConfig: """ ignore_args: Collection[str] = field(default_factory=list) - time_threshold: float = 0 + time_threshold: float = 0.3 allow_repeated_evaluations: bool = False rtol_stderr: float = 0.1 min_repetitions: int = 3 diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index 2f8783b1e..4f488988f 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -141,7 +141,8 @@ def test_cache_backend_serialization(cache_backend): def test_single_job(cache_backend): - wrapped_foo = cache_backend.wrap(foo) + cached_func_config = CachedFuncConfig(time_threshold=0.0) + wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) n = 1000 wrapped_foo(np.arange(n)) @@ -186,6 +187,7 @@ def test_cache_time_threshold(cache_backend): def test_cache_ignore_args(cache_backend): # Note that we typically do NOT want to ignore run_id cached_func_config = CachedFuncConfig( + time_threshold=0.0, ignore_args=["job_id"], ) wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) @@ -235,6 +237,7 @@ def test_parallel_jobs(cache_backend, parallel_config): def test_repeated_training(cache_backend, worker_id: str): cached_func_config = CachedFuncConfig( + time_threshold=0.0, allow_repeated_evaluations=True, rtol_stderr=0.01, ) @@ -255,6 +258,7 @@ def test_repeated_training(cache_backend, worker_id: str): def test_faster_with_repeated_training(cache_backend, worker_id: str): cached_func_config = CachedFuncConfig( + time_threshold=0.0, allow_repeated_evaluations=True, rtol_stderr=0.1, ) @@ -300,6 +304,7 @@ def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float: # Note that we typically do NOT want to ignore run_id cached_func_config = CachedFuncConfig( + time_threshold=0.0, allow_repeated_evaluations=True, rtol_stderr=0.01, ignore_args=["job_id", "run_id"],