Skip to content

Commit

Permalink
Set backend time_threshold to 0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
AnesBenmerzoug committed Dec 13, 2023
1 parent 798c232 commit 73e8e54
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/pydvl/utils/caching/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CachedFuncConfig:
"""

ignore_args: Collection[str] = field(default_factory=list)
time_threshold: float = 0
time_threshold: float = 0.3
allow_repeated_evaluations: bool = False
rtol_stderr: float = 0.1
min_repetitions: int = 3
7 changes: 6 additions & 1 deletion tests/utils/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def test_cache_backend_serialization(cache_backend):


def test_single_job(cache_backend):
wrapped_foo = cache_backend.wrap(foo)
cached_func_config = CachedFuncConfig(time_threshold=0.0)
wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config)

n = 1000
wrapped_foo(np.arange(n))
Expand Down Expand Up @@ -186,6 +187,7 @@ def test_cache_time_threshold(cache_backend):
def test_cache_ignore_args(cache_backend):
# Note that we typically do NOT want to ignore run_id
cached_func_config = CachedFuncConfig(
time_threshold=0.0,
ignore_args=["job_id"],
)
wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config)
Expand Down Expand Up @@ -235,6 +237,7 @@ def test_parallel_jobs(cache_backend, parallel_config):

def test_repeated_training(cache_backend, worker_id: str):
cached_func_config = CachedFuncConfig(
time_threshold=0.0,
allow_repeated_evaluations=True,
rtol_stderr=0.01,
)
Expand All @@ -255,6 +258,7 @@ def test_repeated_training(cache_backend, worker_id: str):

def test_faster_with_repeated_training(cache_backend, worker_id: str):
cached_func_config = CachedFuncConfig(
time_threshold=0.0,
allow_repeated_evaluations=True,
rtol_stderr=0.1,
)
Expand Down Expand Up @@ -300,6 +304,7 @@ def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float:

# Note that we typically do NOT want to ignore run_id
cached_func_config = CachedFuncConfig(
time_threshold=0.0,
allow_repeated_evaluations=True,
rtol_stderr=0.01,
ignore_args=["job_id", "run_id"],
Expand Down

0 comments on commit 73e8e54

Please sign in to comment.