From ff476b93388e5854b229c1129d9095dcdb4026d9 Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Mon, 30 Oct 2023 16:32:59 +0100
Subject: [PATCH 01/29] Move existing caching module to new memcached module
 within new caching package

---
 src/pydvl/utils/caching/__init__.py                  | 1 +
 src/pydvl/utils/{caching.py => caching/memcached.py} | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)
 create mode 100644 src/pydvl/utils/caching/__init__.py
 rename src/pydvl/utils/{caching.py => caching/memcached.py} (99%)

diff --git a/src/pydvl/utils/caching/__init__.py b/src/pydvl/utils/caching/__init__.py
new file mode 100644
index 000000000..7b8e4d66c
--- /dev/null
+++ b/src/pydvl/utils/caching/__init__.py
@@ -0,0 +1 @@
+from .memcached import *
diff --git a/src/pydvl/utils/caching.py b/src/pydvl/utils/caching/memcached.py
similarity index 99%
rename from src/pydvl/utils/caching.py
rename to src/pydvl/utils/caching/memcached.py
index 37d087de4..87bf9aa01 100644
--- a/src/pydvl/utils/caching.py
+++ b/src/pydvl/utils/caching/memcached.py
@@ -96,8 +96,8 @@
 from pymemcache import MemcacheUnexpectedCloseError
 from pymemcache.client import Client, RetryingClient
 
-from .config import MemcachedClientConfig
-from .numeric import running_moments
+from ..config import MemcachedClientConfig
+from ..numeric import running_moments
 
 PICKLE_VERSION = 5  # python >= 3.8

From 53374519fac3b2cf3ee89137646834a0976e630f Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Thu, 23 Nov 2023 11:07:05 +0100
Subject: [PATCH 02/29] Refactor caching into separate classes and add 2 more
 implementations

---
 src/pydvl/utils/caching/__init__.py  |  89 +++++++-
 src/pydvl/utils/caching/base.py      | 303 ++++++++++++++++++++++++++
 src/pydvl/utils/caching/config.py    |  40 ++++
 src/pydvl/utils/caching/disk.py      |  76 +++++++
 src/pydvl/utils/caching/memcached.py | 310 +++++++++------------------
 src/pydvl/utils/caching/memory.py    |  54 +++++
 src/pydvl/utils/config.py            |  72 +------
 7 files changed, 668 insertions(+), 276 deletions(-)
 create mode 100644 src/pydvl/utils/caching/base.py
 create mode 100644 src/pydvl/utils/caching/config.py
 create mode 100644 src/pydvl/utils/caching/disk.py
 create mode 100644 src/pydvl/utils/caching/memory.py

diff --git a/src/pydvl/utils/caching/__init__.py b/src/pydvl/utils/caching/__init__.py
index 7b8e4d66c..6a1a3e867 100644
--- a/src/pydvl/utils/caching/__init__.py
+++ b/src/pydvl/utils/caching/__init__.py
@@ -1 +1,88 @@
-from .memcached import *
+"""Caching of functions.
+
pyDVL caches utility values to allow reusing previously computed evaluations.

!!! Warning
    Function evaluations are cached with a key based on the function's signature
    and code. This can lead to undesired cache hits, see [Cache reuse](#cache-reuse).

    Remember **not to reuse utility objects for different datasets**.

# Configuration

Memoization is disabled by default but can be enabled easily,
see [Setting up the cache](#setting-up-the-cache).
When enabled, it will be added to any callable used to construct a
[Utility][pydvl.utils.utility.Utility] (done with the decorator [@memcached][pydvl.utils.caching.memcached]).
Depending on the nature of the utility you might want to
enable the computation of a running average of function values, see
[Usage with stochastic functions](#usage-with-stochastic-functions).
You can see all configuration options under [MemcachedConfig][pydvl.utils.config.MemcachedConfig].
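For example, a minimal sketch of memoizing a function with the in-memory
backend introduced in this PR (`expensive_fun` stands in for any pure
callable):

```python
from pydvl.utils.caching import InMemoryCacheBackend

cache = InMemoryCacheBackend()
cached_fun = cache.wrap(expensive_fun)
cached_fun(16)  # first call computes and stores the result
cached_fun(16)  # second call is served from the cache
```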
## Default configuration

```python
default_config = dict(
    server=('localhost', 11211),
    connect_timeout=1.0,
    timeout=0.1,
    # IMPORTANT! Disable small packet consolidation:
    no_delay=True,
    serde=serde.PickleSerde(pickle_version=PICKLE_VERSION)
)
```

# Usage with stochastic functions

In addition to standard memoization, the decorator
[memcached()][pydvl.utils.caching.memcached] can compute running average and
standard error of repeated evaluations for the same input. This can be useful
for stochastic functions with high variance (e.g. model training for small
sample sizes), but drastically reduces the speed benefits of memoization.

This behaviour can be activated with the argument `allow_repeated_evaluations`
to [memcached()][pydvl.utils.caching.memcached].

# Cache reuse

When working directly with [memcached()][pydvl.utils.caching.memcached], it is
essential to only cache pure functions. If they have any kind of state, either
internal or external (e.g. a closure over some data that may change), then the
cache will fail to notice this and the same value will be returned.

When a function is wrapped with [memcached()][pydvl.utils.caching.memcached] for
memoization, its signature (input and output names) and code are used as a key
for the cache. Alternatively you can pass a custom value to be used as key with

```python
cached_fun = memcached(**asdict(cache_options))(fun, signature=custom_signature)
```

If you are running experiments with the same [Utility][pydvl.utils.utility.Utility]
but different datasets, this will lead to evaluations of the utility on new data
returning old values because utilities only use sample indices as arguments (so
there is no way to tell the difference between '1' for dataset A and '1' for
dataset B from the point of view of the cache). One solution is to empty the
cache between runs, but the preferred one is to **use a different Utility
object for each dataset**.

# Unexpected cache misses

Because all arguments to a function are used as part of the key for the cache,
sometimes one must exclude some of them. For example, if a function is going to
run across multiple processes and some reporting arguments are added (like a
`job_id` for logging purposes), these will be part of the signature and make the
functions distinct in the eyes of the cache. This can be avoided with the use of
[ignore_args][pydvl.utils.config.MemcachedConfig] in the configuration.

"""

from .base import *
from .config import *
from .disk import *
from .memory import *

try:
    from .memcached import *
except ImportError:
    pass
diff --git a/src/pydvl/utils/caching/base.py b/src/pydvl/utils/caching/base.py
new file mode 100644
index 000000000..339c85389
--- /dev/null
+++ b/src/pydvl/utils/caching/base.py
@@ -0,0 +1,303 @@
import inspect
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Collection, Dict, Optional, Tuple, TypeVar, cast

from joblib import hashing
from joblib.func_inspect import filter_args

from ..numeric import running_moments
from .config import CachedFuncConfig

__all__ = ["CacheStats", "CacheBackendBase", "CachedFunc"]

T = TypeVar("T")

logger = logging.getLogger(__name__)


@dataclass
class CacheStats:
    """Statistics gathered by cached functions.
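These counters are maintained by each backend instance and are exposed by
    wrapped functions through the `stats` property of `CachedFunc` below.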
+ + Attributes: + sets: number of times a value was set in the cache + misses: number of times a value was not found in the cache + hits: number of times a value was found in the cache + timeouts: number of times a timeout occurred + errors: number of times an error occurred + reconnects: number of times the client reconnected to the server + """ + + sets: int = 0 + misses: int = 0 + hits: int = 0 + timeouts: int = 0 + errors: int = 0 + reconnects: int = 0 + + +@dataclass +class CacheResult: + """A dataclass to store the cached result of a computation + as well as count and variance when using repeated evaluation. + + Attributes: + value: The cached value. + count: The number of times this value has been computed. + variance: The variance associated with the cached value. + """ + + value: float + count: int = 1 + variance: float = 0.0 + + +class CacheBackendBase(ABC): + """Abstract base class for cache backends. + + Implements cache management including wrapping functions to cache results, + retrieving results, setting results, clearing the cache, and combining cache keys. + + Attributes: + stats: Cache statistics tracker. + """ + + def __init__(self) -> None: + self.stats = CacheStats() + + def wrap( + self, + func: Callable, + *, + cached_func_config: CachedFuncConfig = CachedFuncConfig(), + ) -> "CachedFunc": + """Wraps a function to cache its results. + + Args: + func: The function to wrap. + cached_func_config: Optional caching options for the wrapped function. + + Returns: + The wrapped cached function. + """ + return CachedFunc( + func, + cache_backend=self, + cached_func_options=cached_func_config, + ) + + @abstractmethod + def get(self, key: str) -> Optional[CacheResult]: + """Abstract method to retrieve a cached result. + + Implemented by subclasses. + + Args: + key: The cache key. + + Returns: + The cached result or None if not found. + """ + pass + + @abstractmethod + def set(self, key: str, value: CacheResult) -> None: + """Abstract method to set a cached result. + + Implemented by subclasses. + + Args: + key: The cache key. + value: The result to cache. + """ + pass + + @abstractmethod + def clear(self) -> None: + """Abstract method to clear the entire cache.""" + pass + + @abstractmethod + def combine_hashes(self, *args: str) -> str: + """Abstract method to combine cache keys.""" + pass + + +class CachedFunc: + """Caches callable function results with a provided cache backend. + + Wraps a callable function to cache its results using a provided + an instance of CacheBackendBase subclass. + Can configure cache keys, repeated evaluations, and cache stats tracking. + + Args: + func: Callable to wrap. + cache_backend: An instance of CacheBackendBase that handles + setting and getting values. + cached_func_options: + ignore_args: Do not take these keyword arguments into account when + hashing the wrapped function for usage as key. This allows + sharing the cache among different jobs for the same experiment run if + the callable happens to have "nuisance" parameters like `job_id` which + do not affect the result of the computation. + time_threshold: computations taking less time than this many seconds are + not cached. + allow_repeated_evaluations: If `True`, repeated calls to a function + with the same arguments will be allowed and outputs averaged until the + running standard deviation of the mean stabilizes below + `rtol_stderr * mean`. + rtol_stderr: relative tolerance for repeated evaluations. 
More precisely, + [memcached()][pydvl.utils.caching.memcached] will stop evaluating the function once the + standard deviation of the mean is smaller than `rtol_stderr * mean`. + min_repetitions: minimum number of times that a function evaluation + on the same arguments is repeated before returning cached values. Useful + for stochastic functions only. If the model training is very noisy, set + this number to higher values to reduce variance. + """ + + def __init__( + self, + func: Callable[..., T], + *, + cache_backend: CacheBackendBase, + cached_func_options: CachedFuncConfig = CachedFuncConfig(), + ) -> None: + self.func = func + self.cache_backend = cache_backend + self.cached_func_options = cached_func_options + + self.__doc__ = f"A wrapper around {func.__name__}() with caching enabled.\n" + ( + CachedFunc.__doc__ or "" + ) + self.__name__ = f"cached_{func.__name__}" + path = list(reversed(func.__qualname__.split("."))) + patched = [f"cached_{path[0]}"] + path[1:] + self.__qualname__ = ".".join(reversed(patched)) + + def __call__(self, *args, **kwargs) -> T: + """Call the wrapped cached function. + + Executes the wrapped function, caching and returning the result. + """ + return self._cached_call(args, kwargs) + + def _force_call(self, args, kwargs) -> Tuple[T, float]: + """Force re-evaluation of the wrapped function. + + Executes the wrapped function without caching. + + Returns: + Function result and execution duration. + """ + start = time.monotonic() + value = self.func(*args, **kwargs) + end = time.monotonic() + duration = end - start + return value, duration + + def _cached_call(self, args, kwargs) -> T: + """Cached wrapped function call. + + Executes the wrapped function with cache checking/setting. + + Returns: + Cached result of the wrapped function. 
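        When repeated evaluations are allowed, the cached value keeps being
        refined: the function is re-evaluated and the running mean and variance
        are updated until the standard error of the mean, computed as
        `sqrt(variance / count)`, is at most `rtol_stderr * value` and more
        than `min_repetitions` evaluations have been recorded.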
+ """ + key = self._get_cache_key(*args, **kwargs) + cached_result = self.cache_backend.get(key) + if cached_result is None: + value, duration = self._force_call(args, kwargs) + result = CacheResult(value) + if ( + duration >= self.cached_func_options.time_threshold + or self.cached_func_options.allow_repeated_evaluations + ): + self.cache_backend.set(key, result) + else: + result = cached_result + if self.cached_func_options.allow_repeated_evaluations: + error_on_average = (result.variance / result.count) ** (1 / 2) + if ( + error_on_average + > self.cached_func_options.rtol_stderr * result.value + or result.count <= self.cached_func_options.min_repetitions + ): + new_value, _ = self._force_call(args, kwargs) + new_avg, new_var = running_moments( + result.value, + result.variance, + result.count, + cast(float, new_value), + ) + result.value = new_avg + result.count += 1 + result.variance = new_var + self.cache_backend.set(key, result) + return result.value + + def _get_cache_key(self, *args, **kwargs) -> str: + """Returns a string key used to identify the function and input parameter hash.""" + func_hash = self._hash_function(self.func) + argument_hash = self._hash_arguments( + self.func, self.cached_func_options.ignore_args, args, kwargs + ) + key = self.cache_backend.combine_hashes(func_hash, argument_hash) + return key + + @staticmethod + def _hash_function(func: Callable) -> str: + """Create hash for wrapped function.""" + func_hash = hashing.hash((func.__code__.co_code, func.__code__.co_consts)) + return func_hash + + @staticmethod + def _hash_arguments( + func: Callable, + ignore_args: Collection[str], + args: Tuple[Any], + kwargs: Dict[str, Any], + ) -> str: + """Create hash for function arguments.""" + return hashing.hash( + CachedFunc._filter_args(func, ignore_args, args, kwargs), + ) + + @staticmethod + def _filter_args( + func: Callable, + ignore_args: Collection[str], + args: Tuple[Any], + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """Filter arguments to exclude from cache keys.""" + # Remove kwargs before calling filter_args + # Because some of them might not be explicitly in the function's signature + # and that would raise an error when calling filter_args + kwargs = {k: v for k, v in kwargs.items() if k not in ignore_args} # type: ignore + # Update ignore_args + func_signature = inspect.signature(func) + arg_names = [] + for param in func_signature.parameters.values(): + if param.kind in [ + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + param.KEYWORD_ONLY, + ]: + arg_names.append(param.name) + ignore_args = [x for x in ignore_args if x in arg_names] + filtered_args: Dict[str, Any] = filter_args(func, ignore_args, args, kwargs) # type: ignore + # We ignore 'self' because for our use case we only care about the method. + # We don't want a cache if another attribute changes in the instance. + try: + filtered_args.pop("self") + except KeyError: + pass + return filtered_args # type: ignore + + @property + def stats(self) -> CacheStats: + """Cache backend statistics.""" + return self.cache_backend.stats diff --git a/src/pydvl/utils/caching/config.py b/src/pydvl/utils/caching/config.py new file mode 100644 index 000000000..14f7cb761 --- /dev/null +++ b/src/pydvl/utils/caching/config.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass, field +from typing import Collection + +__all__ = ["CachedFuncConfig"] + + +@dataclass +class CachedFuncConfig: + """Configuration for cached functions and methods, providing + memoization of function calls. 
+ + Instances of this class are typically used as arguments for the construction + of a [Utility][pydvl.utils.utility.Utility]. + + Args: + ignore_args: Do not take these keyword arguments into account when + hashing the wrapped function for usage as key. This allows + sharing the cache among different jobs for the same experiment run if + the callable happens to have "nuisance" parameters like `job_id` which + do not affect the result of the computation. + time_threshold: Computations taking less time than this many seconds are + not cached. A value of 0 means that it will always cache results. + allow_repeated_evaluations: If `True`, repeated calls to a function + with the same arguments will be allowed and outputs averaged until the + running standard deviation of the mean stabilizes below + `rtol_stderr * mean`. + rtol_stderr: relative tolerance for repeated evaluations. More precisely, + [memcached()][pydvl.utils.caching.memcached] will stop evaluating the function once the + standard deviation of the mean is smaller than `rtol_stderr * mean`. + min_repetitions: minimum number of times that a function evaluation + on the same arguments is repeated before returning cached values. Useful + for stochastic functions only. If the model training is very noisy, set + this number to higher values to reduce variance. + """ + + ignore_args: Collection[str] = field(default_factory=list) + time_threshold: float = 0 + allow_repeated_evaluations: bool = False + rtol_stderr: float = 0.1 + min_repetitions: int = 3 diff --git a/src/pydvl/utils/caching/disk.py b/src/pydvl/utils/caching/disk.py new file mode 100644 index 000000000..e65ff040b --- /dev/null +++ b/src/pydvl/utils/caching/disk.py @@ -0,0 +1,76 @@ +import os +import shutil +from pathlib import Path +from typing import Any, Optional, Union + +import cloudpickle + +from pydvl.utils.caching.base import CacheBackendBase + +__all__ = ["DiskCacheBackend"] + +PICKLE_VERSION = 5 # python >= 3.8 + +DEFAULT_CACHE_DIR = Path().home() / ".pydvl_cache/disk" + + +class DiskCacheBackend(CacheBackendBase): + """Disk cache backend that stores results in files. + + Implements the CacheBackendBase interface for a disk-based cache. + Stores cache entries as pickled files on disk, keyed by cache key. + + Attributes: + cache_dir: Base directory for cache storage. + """ + + def __init__( + self, + cache_dir: Union[os.PathLike, str] = DEFAULT_CACHE_DIR, + ) -> None: + """Initialize the disk cache backend. + + Args: + cache_dir: Base directory for cache storage. + """ + super().__init__() + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(exist_ok=True, parents=True) + + def get(self, key: str) -> Optional[Any]: + """Get a value from the cache. + + Args: + key: Cache key. + + Returns: + Cached value or None if not found. + """ + cache_file = self.cache_dir / key + if not cache_file.exists(): + self.stats.misses += 1 + return None + self.stats.hits += 1 + with cache_file.open("rb") as f: + return cloudpickle.load(f) + + def set(self, key: str, value: Any) -> None: + """Set a value in the cache. + + Args: + key: Cache key. + value: Value to cache. 
+ """ + cache_file = self.cache_dir / key + self.stats.sets += 1 + with cache_file.open("wb") as f: + cloudpickle.dump(value, f, protocol=PICKLE_VERSION) + + def clear(self) -> None: + """Deletes cache directory and recreates it.""" + shutil.rmtree(self.cache_dir) + self.cache_dir.mkdir(exist_ok=True, parents=True) + + def combine_hashes(self, *args: str) -> str: + """Join cache key components.""" + return os.pathsep.join(args) diff --git a/src/pydvl/utils/caching/memcached.py b/src/pydvl/utils/caching/memcached.py index 87bf9aa01..1f04ad2f5 100644 --- a/src/pydvl/utils/caching/memcached.py +++ b/src/pydvl/utils/caching/memcached.py @@ -86,127 +86,113 @@ import uuid import warnings from dataclasses import asdict, dataclass -from functools import wraps -from hashlib import blake2b -from io import BytesIO -from time import time -from typing import Any, Callable, Dict, Iterable, Optional, TypeVar, cast +from typing import Any, Dict, Optional, Tuple -from cloudpickle import Pickler from pymemcache import MemcacheUnexpectedCloseError from pymemcache.client import Client, RetryingClient +from pymemcache.serde import PickleSerde -from ..config import MemcachedClientConfig -from ..numeric import running_moments +from .base import CacheBackendBase + +__all__ = ["MemcachedClientConfig", "MemcachedCacheBackend"] PICKLE_VERSION = 5 # python >= 3.8 logger = logging.getLogger(__name__) -T = TypeVar("T") - -@dataclass -class CacheStats: - """Statistics gathered by cached functions. +@dataclass(frozen=True) +class MemcachedClientConfig: + """Configuration of the memcached client. - Attributes: - sets: number of times a value was set in the cache - misses: number of times a value was not found in the cache - hits: number of times a value was found in the cache - timeouts: number of times a timeout occurred - errors: number of times an error occurred - reconnects: number of times the client reconnected to the server + Args: + server: A tuple of (IP|domain name, port). + connect_timeout: How many seconds to wait before raising + `ConnectionRefusedError` on failure to connect. + timeout: seconds to wait for send or recv calls on the socket + connected to memcached. + no_delay: set the `TCP_NODELAY` flag, which may help with performance + in some cases. + serde: a serializer / deserializer ("serde"). The default `PickleSerde` + should work in most cases. See [pymemcached's + documentation](https://pymemcache.readthedocs.io/en/latest/apidoc/pymemcache.client.base.html#pymemcache.client.base.Client) + for details. """ - sets: int = 0 - misses: int = 0 - hits: int = 0 - timeouts: int = 0 - errors: int = 0 - reconnects: int = 0 + server: Tuple[str, int] = ("localhost", 11211) + connect_timeout: float = 1.0 + timeout: float = 1.0 + no_delay: bool = True + serde: PickleSerde = PickleSerde(pickle_version=PICKLE_VERSION) -def serialize(x: Any) -> bytes: - """Serialize an object to bytes. - Args: - x: object to serialize. - - Returns: - serialized object. - """ - pickled_output = BytesIO() - pickler = Pickler(pickled_output, PICKLE_VERSION) - pickler.dump(x) - return pickled_output.getvalue() - - -def memcached( - client_config: Optional[MemcachedClientConfig] = None, - time_threshold: float = 0.3, - allow_repeated_evaluations: bool = False, - rtol_stderr: float = 0.1, - min_repetitions: int = 3, - ignore_args: Optional[Iterable[str]] = None, -) -> Callable[[Callable[..., T], bytes | None], Callable[..., T]]: - """ - Transparent, distributed memoization of function calls. 
+class MemcachedCacheBackend(CacheBackendBase): + """Memcached cache backend. - Given a function and its signature, memcached uses a distributed cache - that, for each set of inputs, keeps track of the average returned value, - with variance and number of times it was calculated. + Implements CacheBackendBase using a memcached client. - If the function is deterministic, i.e. same input corresponds to the same - exact output, set `allow_repeated_evaluations` to `False`. If instead the - function is stochastic (like the training of a model depending on random - initializations), memcached() allows to set a minimum number of evaluations - to compute a running average, and a tolerance after which the function will - not be called anymore. In other words, the function will be recomputed - until the value has stabilized with a standard error smaller than - `rtol_stderr * running average`. + Attributes: + config: Memcached client configuration. + client: Memcached client instance. + """ - !!! Warning - Do not cache functions with state! See [Cache reuse](cache-reuse) + def __init__(self, config: MemcachedClientConfig = MemcachedClientConfig()) -> None: + """Initialize memcached cache backend. - ??? Example - ```python - cached_fun = memcached(**asdict(cache_options))(heavy_computation) - ``` + Args: + config: Memcached client configuration. + """ + super().__init__() + self.config = config + self.client = self._connect(self.config) - Args: - client_config: configuration for pymemcache's - [Client][pymemcache.client.base.Client]. - Will be merged on top of the default configuration (see below). - time_threshold: computations taking less time than this many seconds are - not cached. - allow_repeated_evaluations: If `True`, repeated calls to a function - with the same arguments will be allowed and outputs averaged until the - running standard deviation of the mean stabilizes below - `rtol_stderr * mean`. - rtol_stderr: relative tolerance for repeated evaluations. More precisely, - [memcached()][pydvl.utils.caching.memcached] will stop evaluating the function once the - standard deviation of the mean is smaller than `rtol_stderr * mean`. - min_repetitions: minimum number of times that a function evaluation - on the same arguments is repeated before returning cached values. Useful - for stochastic functions only. If the model training is very noisy, set - this number to higher values to reduce variance. - ignore_args: Do not take these keyword arguments into account when - hashing the wrapped function for usage as key in memcached. This allows - sharing the cache among different jobs for the same experiment run if - the callable happens to have "nuisance" parameters like `job_id` which - do not affect the result of the computation. - - Returns: - A wrapped function + def get(self, key: str) -> Optional[Any]: + """Get value from memcached. - """ - if ignore_args is None: - ignore_args = [] + Args: + key: Cache key. - # Do I really need this? - def connect(config: MemcachedClientConfig): - """First tries to establish a connection, then tries setting and - getting a value.""" + Returns: + Cached value or None if not found or client disconnected. 
+ """ + result = None + try: + result = self.client.get(key) + except socket.timeout as e: + self.stats.timeouts += 1 + warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning) + except OSError as e: + self.stats.errors += 1 + warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning) + except AttributeError as e: + # FIXME: this depends on _recv() failing on invalid sockets + # See pymemcache.base.py, + self.stats.reconnects += 1 + warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning) + self.client = self._connect(self.config) + if result is None: + self.stats.misses += 1 + else: + self.stats.hits += 1 + return result + + def set(self, key: str, value: Any) -> None: + """Set value in memcached. + + Args: + key: Cache key. + value: Value to cache. + """ + self.client.set(key, value, noreply=True) + self.stats.sets += 1 + + def clear(self) -> None: + """Flush all values from memcached.""" + self.client.flush_all(noreply=True) + + @staticmethod + def _connect(config: MemcachedClientConfig) -> RetryingClient: + """Connect to memcached server.""" try: client = RetryingClient( Client(**asdict(config)), @@ -226,114 +212,28 @@ def connect(config: MemcachedClientConfig): f"to {config.server} after " f"{config.connect_timeout} seconds: {str(e)}. Did you start memcached?" ) - raise e + raise except AssertionError as e: logger.error( # type: ignore f"@memcached: Failure saving dummy value " f"to {config.server}: {str(e)}" ) - - def wrapper(fun: Callable[..., T], signature: Optional[bytes] = None): - if signature is None: - signature = serialize((fun.__code__.co_code, fun.__code__.co_consts)) - - @wraps(fun, updated=[]) # don't try to use update() for a class - class Wrapped: - config: MemcachedClientConfig - stats: CacheStats - client: RetryingClient - - def __init__(self, config: MemcachedClientConfig): - self.config = config - self.stats = CacheStats() - self.client = connect(self.config) - self._signature = signature - - def __call__(self, *args, **kwargs) -> T: - key_kwargs = {k: v for k, v in kwargs.items() if k not in ignore_args} # type: ignore - arg_signature: bytes = serialize((args, list(key_kwargs.items()))) - - key = blake2b(self._signature + arg_signature).hexdigest().encode("ASCII") # type: ignore - - result_dict: Dict[str, float] = self.get_key_value(key) - if result_dict is None: - result_dict = {} - start = time() - value = fun(*args, **kwargs) - end = time() - result_dict["value"] = value - if end - start >= time_threshold or allow_repeated_evaluations: - result_dict["count"] = 1 - result_dict["variance"] = 0 - self.client.set(key, result_dict, noreply=True) - self.stats.sets += 1 - self.stats.misses += 1 - elif allow_repeated_evaluations: - self.stats.hits += 1 - value = result_dict["value"] - count = result_dict["count"] - variance = result_dict["variance"] - error_on_average = (variance / count) ** (1 / 2) - if ( - error_on_average > rtol_stderr * value - or count <= min_repetitions - ): - new_value = fun(*args, **kwargs) - new_avg, new_var = running_moments( - value, variance, int(count), cast(float, new_value) - ) - result_dict["value"] = new_avg - result_dict["count"] = count + 1 - result_dict["variance"] = new_var - self.client.set(key, result_dict, noreply=True) - self.stats.sets += 1 - else: - self.stats.hits += 1 - return result_dict["value"] # type: ignore - - def __getstate__(self): - """Enables pickling after a socket has been opened to the - memcached server, by removing the client from the stored - data.""" - odict = self.__dict__.copy() - 
del odict["client"] - return odict - - def __setstate__(self, d: dict): - """Restores a client connection after loading from a pickle.""" - self.config = d["config"] - self.stats = d["stats"] - self.client = Client(**asdict(self.config)) - self._signature = signature - - def get_key_value(self, key: bytes): - result = None - try: - result = self.client.get(key) - except socket.timeout as e: - self.stats.timeouts += 1 - warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning) - except OSError as e: - self.stats.errors += 1 - warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning) - except AttributeError as e: - # FIXME: this depends on _recv() failing on invalid sockets - # See pymemcache.base.py, - self.stats.reconnects += 1 - warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning) - self.client = connect(self.config) - return result - - Wrapped.__doc__ = ( - f"A wrapper around {fun.__name__}() with remote caching enabled.\n" - + (Wrapped.__doc__ or "") - ) - Wrapped.__name__ = f"memcached_{fun.__name__}" - path = list(reversed(fun.__qualname__.split("."))) - patched = [f"memcached_{path[0]}"] + path[1:] - Wrapped.__qualname__ = ".".join(reversed(patched)) - - # TODO: pick from some config file or something - return Wrapped(client_config or MemcachedClientConfig()) - - return wrapper + raise + + def combine_hashes(self, *args: str) -> str: + """Join cache key components for Memcached.""" + return ":".join(args) + + def __getstate__(self) -> Dict: + """Enables pickling after a socket has been opened to the + memcached server, by removing the client from the stored + data.""" + odict = self.__dict__.copy() + del odict["client"] + return odict + + def __setstate__(self, d: Dict): + """Restores a client connection after loading from a pickle.""" + self.config = d["config"] + self.stats = d["stats"] + self.client = self._connect(self.config) diff --git a/src/pydvl/utils/caching/memory.py b/src/pydvl/utils/caching/memory.py new file mode 100644 index 000000000..9ffc48fa5 --- /dev/null +++ b/src/pydvl/utils/caching/memory.py @@ -0,0 +1,54 @@ +import os +from typing import Any, Optional + +from pydvl.utils.caching.base import CacheBackendBase + +__all__ = ["InMemoryCacheBackend"] + + +class InMemoryCacheBackend(CacheBackendBase): + """In-memory cache backend that stores results in a dictionary. + + Implements the CacheBackendBase interface for an in-memory-based cache. + Stores cache entries as values in a dictionary, keyed by cache key. + """ + + def __init__(self) -> None: + """Initialize the in-memory cache backend.""" + super().__init__() + self.cached_values = {} + + def get(self, key: str) -> Optional[Any]: + """Get a value from the cache. + + Args: + key: Cache key. + + Returns: + Cached value or None if not found. + """ + value = self.cached_values.get(key, None) + if value is not None: + self.stats.hits += 1 + else: + self.stats.misses += 1 + return value + + def set(self, key: str, value: Any) -> None: + """Set a value in the cache. + + Args: + key: Cache key. + value: Value to cache. 
+ """ + self.cached_values[key] = value + self.stats.sets += 1 + + def clear(self) -> None: + """Deletes cache dictionary and recreates it.""" + del self.cached_values + self.cached_values = {} + + def combine_hashes(self, *args: str) -> str: + """Join cache key components.""" + return os.pathsep.join(args) diff --git a/src/pydvl/utils/config.py b/src/pydvl/utils/config.py index 6e240bffc..8c77c0263 100644 --- a/src/pydvl/utils/config.py +++ b/src/pydvl/utils/config.py @@ -1,72 +1,4 @@ -from dataclasses import dataclass, field -from typing import Iterable, Optional, Tuple - -from pymemcache.serde import PickleSerde - from pydvl.parallel.config import ParallelConfig +from pydvl.utils.caching.config import CachedFuncConfig -PICKLE_VERSION = 5 # python >= 3.8 - - -__all__ = ["MemcachedClientConfig", "MemcachedConfig", "ParallelConfig"] - - -@dataclass(frozen=True) -class MemcachedClientConfig: - """Configuration of the memcached client. - - Args: - server: A tuple of (IP|domain name, port). - connect_timeout: How many seconds to wait before raising - `ConnectionRefusedError` on failure to connect. - timeout: seconds to wait for send or recv calls on the socket - connected to memcached. - no_delay: set the `TCP_NODELAY` flag, which may help with performance - in some cases. - serde: a serializer / deserializer ("serde"). The default `PickleSerde` - should work in most cases. See [pymemcached's - documentation](https://pymemcache.readthedocs.io/en/latest/apidoc/pymemcache.client.base.html#pymemcache.client.base.Client) - for details. - """ - - server: Tuple[str, int] = ("localhost", 11211) - connect_timeout: float = 1.0 - timeout: float = 1.0 - no_delay: bool = True - serde: PickleSerde = PickleSerde(pickle_version=PICKLE_VERSION) - - -@dataclass -class MemcachedConfig: - """Configuration for [memcached()][pydvl.utils.caching.memcached], providing - memoization of function calls. - - Instances of this class are typically used as arguments for the construction - of a [Utility][pydvl.utils.utility.Utility]. - - Args: - client_config: Configuration for the connection to the memcached server. - time_threshold: computations taking less time than this many seconds are - not cached. - allow_repeated_evaluations: If `True`, repeated calls to a function - with the same arguments will be allowed and outputs averaged until the - running standard deviation of the mean stabilises below - `rtol_stderr * mean`. - rtol_stderr: relative tolerance for repeated evaluations. More precisely, - [memcached()][pydvl.utils.caching.memcached] will stop evaluating - the function once the standard deviation of the mean is smaller than - `rtol_stderr * mean`. - min_repetitions: minimum number of times that a function evaluation - on the same arguments is repeated before returning cached values. Useful - for stochastic functions only. If the model training is very noisy, set - this number to higher values to reduce variance. - ignore_args: Do not take these keyword arguments into account when hashing - the wrapped function for usage as key in memcached. 
-    """
-
-    client_config: MemcachedClientConfig = field(default_factory=MemcachedClientConfig)
-    time_threshold: float = 0.3
-    allow_repeated_evaluations: bool = False
-    rtol_stderr: float = 0.1
-    min_repetitions: int = 3
-    ignore_args: Optional[Iterable[str]] = None
+__all__ = ["CachedFuncConfig", "ParallelConfig"]

From f5c947eb9bfffdefc1c64211a20c25e51c7581d2 Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Thu, 23 Nov 2023 11:07:23 +0100
Subject: [PATCH 03/29] Adapt Utility to caching change

---
 src/pydvl/utils/utility.py | 33 +++++++++++----------------------
 1 file changed, 11 insertions(+), 22 deletions(-)

diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py
index 767e7f9e1..302819f9a 100644
--- a/src/pydvl/utils/utility.py
+++ b/src/pydvl/utils/utility.py
@@ -25,7 +25,6 @@
 """
 import logging
 import warnings
-from dataclasses import asdict
 from typing import Dict, FrozenSet, Iterable, Optional, Tuple, Union, cast
 
 import numpy as np
@@ -34,8 +33,7 @@
 from sklearn.metrics import check_scoring
 
 from pydvl.utils import Dataset
-from pydvl.utils.caching import CacheStats, memcached, serialize
-from pydvl.utils.config import MemcachedConfig
+from pydvl.utils.caching import CacheBackendBase, CachedFuncConfig, CacheStats
 from pydvl.utils.score import Scorer
 from pydvl.utils.types import SupervisedModel
 
@@ -134,8 +132,8 @@
         score_range: Tuple[float, float] = (-np.inf, np.inf),
         catch_errors: bool = True,
         show_warnings: bool = False,
-        enable_cache: bool = False,
-        cache_options: Optional[MemcachedConfig] = None,
+        cache: Optional[CacheBackendBase] = None,
+        cached_func_options: CachedFuncConfig = CachedFuncConfig(),
         clone_before_fit: bool = True,
     ):
         self.model = self._clone_model(model)
@@ -148,10 +146,9 @@
         self.score_range = scorer.range if scorer is not None else np.array(score_range)
         self.catch_errors = catch_errors
         self.show_warnings = show_warnings
-        self.enable_cache = enable_cache
-        self.cache_options: MemcachedConfig = cache_options or MemcachedConfig()
+        self.cache = cache
+        self.cached_func_options = cached_func_options
         self.clone_before_fit = clone_before_fit
-        self._signature = serialize((hash(self.model), hash(data), hash(scorer)))
         self._initialize_utility_wrapper()
 
         # FIXME: can't modify docstring of methods. Instead, I could use a
         # self.__call__.__doc__ = self._utility_wrapper.__doc__
 
     def _initialize_utility_wrapper(self):
-        if self.enable_cache:
-            # asdict() is recursive, but we want client_config to remain a dataclass
-            options = asdict(self.cache_options)
-            options["client_config"] = self.cache_options.client_config
-            self._utility_wrapper = memcached(**options)(  # type: ignore
-                self._utility, signature=self._signature
+        if self.cache is not None:
+            self._utility_wrapper = self.cache.wrap(
+                self._utility, cached_func_config=self.cached_func_options
             )
         else:
             self._utility_wrapper = self._utility
@@ -244,18 +238,13 @@
         model = cast(SupervisedModel, model)
         return model
 
-    @property
-    def signature(self):
-        """Signature used for caching model results."""
-        return self._signature
-
     @property
     def cache_stats(self) -> Optional[CacheStats]:
         """Cache statistics are gathered when cache is enabled.
        See [CacheStats][pydvl.utils.caching.base.CacheStats] for all fields returned.
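        ??? Example
            A sketch using the in-memory backend (`model` and `data` as in the
            rest of this class; the values shown are what a first evaluation
            would produce):

            ```pycon
            >>> from pydvl.utils.caching import InMemoryCacheBackend
            >>> u = Utility(model, data, cache=InMemoryCacheBackend())
            >>> score = u(u.data.indices)
            >>> u.cache_stats  # doctest: +SKIP
            CacheStats(sets=1, misses=1, hits=0, timeouts=0, errors=0, reconnects=0)
            ```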
""" - if self.enable_cache: - return self._utility_wrapper.stats # type: ignore + if self.cache is not None: + return self._utility_wrapper.stats return None def __getstate__(self): From ab96b818e087fcc18d0e50dfbe0cda2827c16887 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Thu, 23 Nov 2023 11:07:32 +0100 Subject: [PATCH 04/29] Adapt tests --- tests/utils/test_caching.py | 285 ++++++++++++++++++++++++++---------- tests/utils/test_utility.py | 59 +++----- 2 files changed, 235 insertions(+), 109 deletions(-) diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index c30e38fd8..713d7337c 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -1,4 +1,5 @@ import logging +import tempfile from time import sleep, time from typing import Optional @@ -7,164 +8,300 @@ from numpy.typing import NDArray from pydvl.parallel import MapReduceJob -from pydvl.utils import memcached +from pydvl.utils.caching import ( + CachedFunc, + CachedFuncConfig, + DiskCacheBackend, + InMemoryCacheBackend, + MemcachedCacheBackend, +) from pydvl.utils.types import Seed logger = logging.getLogger(__name__) -def test_failed_connection(): +def foo(indices: NDArray[np.int_], *args, **kwargs) -> float: + return float(np.sum(indices)) + + +def foo_duplicate(indices: NDArray[np.int_], *args, **kwargs) -> float: + return float(np.sum(indices)) + + +def foo_with_random(indices: NDArray[np.int_], *args, **kwargs) -> float: + rng = np.random.default_rng() + scale = kwargs.get("scale", 1.0) + return float(np.sum(indices)) + rng.normal(scale=scale) + + +def foo_with_random_and_sleep(indices: NDArray[np.int_], *args, **kwargs) -> float: + sleep(0.01) + rng = np.random.default_rng() + scale = kwargs.get("scale", 1.0) + return float(np.sum(indices)) + rng.normal(scale=scale) + + +# Used to test caching of methods +class Test: + def __init__(self): + self.value = 0 + + def foo(self): + return 1 + + +@pytest.fixture(params=["in-memory", "disk", "memcached"]) +def cache(request): + backend: str = request.param + if backend == "in-memory": + cache = InMemoryCacheBackend() + yield cache + cache.clear() + elif backend == "disk": + with tempfile.TemporaryDirectory() as tempdir: + cache = DiskCacheBackend(tempdir) + yield cache + cache.clear() + elif backend == "memcached": + cache = MemcachedCacheBackend() + yield cache + cache.clear() + else: + raise ValueError(f"Unknown cache backend {backend}") + + +@pytest.mark.parametrize( + "f1, f2, expected_equal", + [ + # Test that the same function gets the same hash + (lambda x: x, lambda x: x, True), + (foo, foo, True), + # Test that different functions get different hashes + (foo, lambda x: x, False), + (foo, foo_with_random, False), + (foo_with_random, foo_with_random_and_sleep, False), + # Test that functions with different names but the same body get different hashes + (foo, foo_duplicate, True), + ], +) +def test_cached_func_hash_function(f1, f2, expected_equal): + f1_hash = CachedFunc._hash_function(f1) + f2_hash = CachedFunc._hash_function(f2) + if expected_equal: + assert f1_hash == f2_hash, f"{f1_hash} != {f2_hash}" + else: + assert f1_hash != f2_hash, f"{f1_hash} == {f2_hash}" + + +@pytest.mark.parametrize( + "args1, args2, expected_equal", + [ + # Test that the same arguments get the same hash + ([[]], [[]], True), + ([[1]], [[1]], True), + ([frozenset([])], [frozenset([])], True), + ([frozenset([1])], [frozenset([1])], True), + ([np.ones(3)], [np.ones(3)], True), + ([np.ones(3), 16], [np.ones(3), 16], True), + ([frozenset(np.ones(3))], 
[frozenset(np.ones(3))], True), + # Test that different arguments get different hashes + ([[1, 2, 3]], [np.ones(3)], False), + ([np.ones(3)], [np.ones(5)], False), + ([np.ones(3)], [frozenset(np.ones(3))], False), + ], +) +def test_cached_func_hash_arguments(args1, args2, expected_equal): + args1_hash = CachedFunc._hash_arguments(foo, ignore_args=[], args=args1, kwargs={}) + args2_hash = CachedFunc._hash_arguments(foo, ignore_args=[], args=args2, kwargs={}) + if expected_equal: + assert args1_hash == args2_hash, f"{args1_hash} != {args2_hash}" + else: + assert args1_hash != args2_hash, f"{args1_hash} == {args2_hash}" + + +def test_cached_func_hash_arguments_of_method(): + obj = Test() + + hash1 = CachedFunc._hash_arguments(obj.foo, [], tuple(), {}) + obj.value += 1 + hash2 = CachedFunc._hash_arguments(obj.foo, [], tuple(), {}) + assert hash1 == hash2 + + +def test_single_job(cache): + wrapped_foo = cache.wrap(foo) + + n = 1000 + wrapped_foo(np.arange(n)) + hits_before = wrapped_foo.stats.hits + wrapped_foo(np.arange(n)) + hits_after = wrapped_foo.stats.hits + + assert hits_after > hits_before + + +def test_memcached_failed_connection(): from pydvl.utils import MemcachedClientConfig - client_config = MemcachedClientConfig(server=("localhost", 0), connect_timeout=0.1) + config = MemcachedClientConfig(server=("localhost", 0), connect_timeout=0.1) with pytest.raises((ConnectionRefusedError, OSError)): - memcached(client_config)(lambda x: x) + MemcachedCacheBackend(config) -def test_memcached_single_job(memcached_client): - client, config = memcached_client +def test_cache_time_threshold(cache): + cached_func_config = CachedFuncConfig(time_threshold=1.0) + wrapped_foo = cache.wrap(foo, cached_func_config=cached_func_config) - # TODO: maybe this should be a fixture too... 
- @memcached(client_config=config, time_threshold=0) # Always cache results - def foo(indices: NDArray[np.int_]) -> float: - return float(np.sum(indices)) + n = 1000 + indices = np.arange(n) + wrapped_foo(indices) + hits_before = wrapped_foo.stats.hits + misses_before = wrapped_foo.stats.misses + wrapped_foo(indices) + hits_after = wrapped_foo.stats.hits + misses_after = wrapped_foo.stats.misses + + assert hits_after == hits_before + assert misses_after > misses_before + + +def test_cache_ignore_args(cache): + # Note that we typically do NOT want to ignore run_id + cached_func_config = CachedFuncConfig( + ignore_args=["job_id"], + ) + wrapped_foo = cache.wrap(foo, cached_func_config=cached_func_config) n = 1000 - foo(np.arange(n)) - hits_before = client.stats()[b"get_hits"] - foo(np.arange(n)) - hits_after = client.stats()[b"get_hits"] + indices = np.arange(n) + wrapped_foo(indices, job_id=0) + hits_before = wrapped_foo.stats.hits + wrapped_foo(indices, job_id=16) + hits_after = wrapped_foo.stats.hits assert hits_after > hits_before -def test_memcached_parallel_jobs(memcached_client, parallel_config): - client, config = memcached_client +def test_parallel_jobs(cache, parallel_config): + if not isinstance(cache, MemcachedCacheBackend): + pytest.skip("Only running this test with MemcachedCacheBackend") + if parallel_config.backend != "joblib": + pytest.skip("We don't have to test this with all parallel backends") - @memcached( - client_config=config, - time_threshold=0, # Always cache results - # Note that we typically do NOT want to ignore run_id + # Note that we typically do NOT want to ignore run_id + cached_func_config = CachedFuncConfig( ignore_args=["job_id", "run_id"], ) - def foo(indices: NDArray[np.int_], *args, **kwargs) -> float: - # logger.info(f"run_id: {run_id}, running...") - return float(np.sum(indices)) + wrapped_foo = cache.wrap(foo, cached_func_config=cached_func_config) n = 1234 n_runs = 10 - hits_before = client.stats()[b"get_hits"] + hits_before = cache.client.stats()[b"get_hits"] map_reduce_job = MapReduceJob( - np.arange(n), foo, np.sum, n_jobs=4, config=parallel_config + np.arange(n), wrapped_foo, np.sum, n_jobs=4, config=parallel_config ) results = [] for _ in range(n_runs): result = map_reduce_job() results.append(result) - hits_after = client.stats()[b"get_hits"] + hits_after = cache.client.stats()[b"get_hits"] assert results[0] == n * (n - 1) / 2 # Sanity check # FIXME! This is non-deterministic: if packets are delayed for longer than # the timeout configured then we won't have num_runs hits. So we add this # good old hard-coded magic number here. 
- assert hits_after - hits_before >= n_runs - 2 + assert hits_after - hits_before >= n_runs - 2, wrapped_foo.stats -def test_memcached_repeated_training(memcached_client, worker_id: str): - _, config = memcached_client - - @memcached( - client_config=config, - time_threshold=0, # Always cache results +def test_repeated_training(cache, worker_id: str): + cached_func_config = CachedFuncConfig( allow_repeated_evaluations=True, rtol_stderr=0.01, ) - def foo(indices: NDArray[np.int_], uid: str) -> float: - return float(np.sum(indices)) + np.random.normal(scale=10) + wrapped_foo = cache.wrap( + foo_with_random, + cached_func_config=cached_func_config, + ) n = 7 - foo(np.arange(n), worker_id) - for _ in range(10_000): - result = foo(np.arange(n), worker_id) + indices = np.arange(n) - assert (result - np.sum(np.arange(n))) < 1 - assert foo.stats.sets < foo.stats.hits + for _ in range(1_000): + result = wrapped_foo(indices, worker_id) + assert np.isclose(result, np.sum(indices), atol=1) + assert wrapped_foo.stats.sets < wrapped_foo.stats.hits -def test_memcached_faster_with_repeated_training(memcached_client, worker_id: str): - _, config = memcached_client - @memcached( - client_config=config, - time_threshold=0, # Always cache results +def test_faster_with_repeated_training(cache, worker_id: str): + cached_func_config = CachedFuncConfig( allow_repeated_evaluations=True, rtol_stderr=0.1, ) - def foo_cache(indices: NDArray[np.int_], uid: str) -> float: - sleep(0.01) - return float(np.sum(indices)) + np.random.normal(scale=1) - - def foo_no_cache(indices: NDArray[np.int_], uid: str) -> float: - sleep(0.01) - return float(np.sum(indices)) + np.random.normal(scale=1) + wrapped_foo = cache.wrap( + foo_with_random_and_sleep, + cached_func_config=cached_func_config, + ) n = 3 - foo_cache(np.arange(n), worker_id) - foo_no_cache(np.arange(n), worker_id) + n_repetitions = 500 + indices = np.arange(n) start = time() - for _ in range(300): - result_fast = foo_cache(np.arange(n), worker_id) + for _ in range(n_repetitions): + result_fast = wrapped_foo(indices, worker_id) end = time() fast_time = end - start start = time() results_slow = [] - for _ in range(300): - result = foo_no_cache(np.arange(n), worker_id) + for _ in range(n_repetitions): + result = foo_with_random_and_sleep(indices, worker_id) results_slow.append(result) end = time() slow_time = end - start - assert (result_fast - np.mean(results_slow)) < 1 + assert np.isclose(np.mean(results_slow), np.sum(indices), atol=0.1) + assert np.isclose(result_fast, np.mean(results_slow), atol=1) assert fast_time < slow_time @pytest.mark.parametrize("n, atol", [(10, 5), (20, 10)]) @pytest.mark.parametrize("n_jobs", [1, 2]) @pytest.mark.parametrize("n_runs", [20]) -def test_memcached_parallel_repeated_training( - memcached_client, n, atol, n_jobs, n_runs, parallel_config, seed -): +def test_parallel_repeated_training(cache, n, atol, n_jobs, n_runs, parallel_config): if parallel_config.backend != "joblib": pytest.skip("We don't have to test this with all parallel backends") - _, config = memcached_client - @memcached( - client_config=config, - time_threshold=0, # Always cache results + def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float: + return np.sum(indices).item() + np.random.normal(scale=1) + + # Note that we typically do NOT want to ignore run_id + cached_func_config = CachedFuncConfig( allow_repeated_evaluations=True, rtol_stderr=0.01, - # Note that we typically do NOT want to ignore run_id ignore_args=["job_id", "run_id"], ) - def 
map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float: - # from pydvl.utils.logging import logger - # logger.info(f"run_id: {run_id}, running...") - rng = np.random.default_rng(seed) - return np.sum(indices).item() + rng.normal(scale=5) + wrapped_map_func = cache.wrap( + map_func, + cached_func_config=cached_func_config, + ) def reduce_func(chunks: NDArray[np.float_]) -> float: return np.sum(chunks).item() map_reduce_job = MapReduceJob( - np.arange(n), map_func, reduce_func, n_jobs=n_jobs, config=parallel_config + np.arange(n), + wrapped_map_func, + reduce_func, + n_jobs=n_jobs, + config=parallel_config, ) results = [] for _ in range(n_runs): - result = map_reduce_job(seed=seed) + result = map_reduce_job() results.append(result) exact_value = np.sum(np.arange(n)).item() diff --git a/tests/utils/test_utility.py b/tests/utils/test_utility.py index dce0e3830..dddc172ec 100644 --- a/tests/utils/test_utility.py +++ b/tests/utils/test_utility.py @@ -1,11 +1,13 @@ # TODO add more tests! +import pickle import warnings import numpy as np import pytest from sklearn.linear_model import LinearRegression -from pydvl.utils import DataUtilityLearning, MemcachedConfig, Scorer, Utility, powerset +from pydvl.utils import DataUtilityLearning, Scorer, Utility, powerset +from pydvl.utils.caching import InMemoryCacheBackend @pytest.mark.parametrize("show_warnings", [False, True]) @@ -27,15 +29,14 @@ def score(self, x, y): utility = Utility( model=WarningModel(), data=housing_dataset, - enable_cache=False, show_warnings=show_warnings, ) utility([0]) if show_warnings: - assert len(recwarn) >= 1 + assert len(recwarn) >= 1, recwarn.list else: - assert len(recwarn) == 0 + assert len(recwarn) == 0, recwarn.list # noinspection PyUnresolvedReferences @@ -46,7 +47,6 @@ def test_data_utility_learning_wrapper(linear_dataset, training_budget): model=LinearRegression(), data=linear_dataset, scorer=Scorer("r2"), - enable_cache=False, ) wrapped_u = DataUtilityLearning(u, training_budget, LinearRegression()) subsets = list(powerset(wrapped_u.utility.data.indices)) @@ -59,51 +59,40 @@ def test_data_utility_learning_wrapper(linear_dataset, training_budget): # noinspection PyUnresolvedReferences @pytest.mark.parametrize("a, b, num_points", [(2, 0, 8)]) -def test_cache(linear_dataset, memcache_client_config): +def test_utility_with_cache(linear_dataset): u = Utility( model=LinearRegression(), data=linear_dataset, scorer=Scorer("r2"), - enable_cache=True, - cache_options=MemcachedConfig( - client_config=memcache_client_config, time_threshold=0 - ), + cache=InMemoryCacheBackend(), ) subsets = list(powerset(u.data.indices)) for s in subsets: u(s) - assert u._utility_wrapper.stats.hits == 0 + assert u._utility_wrapper.stats.hits == 0, u._utility_wrapper.stats for s in subsets: u(s) - assert u._utility_wrapper.stats.hits == len(subsets) + + assert u._utility_wrapper.stats.hits == len(subsets), u._utility_wrapper.stats @pytest.mark.parametrize("a, b, num_points", [(2, 0, 8)]) -@pytest.mark.parametrize("model_kwargs", [({}, {}), ({}, {"fit_intercept": False})]) -def test_different_cache_signature( - linear_dataset, memcache_client_config, model_kwargs -): - u1 = Utility( - model=LinearRegression(**model_kwargs[0]), - data=linear_dataset, - scorer=Scorer("r2"), - enable_cache=True, - cache_options=MemcachedConfig( - client_config=memcache_client_config, time_threshold=0 - ), - ) - u2 = Utility( - model=LinearRegression(**model_kwargs[1]), +@pytest.mark.parametrize("use_cache", [False, True]) +def 
test_utility_serialization(linear_dataset, use_cache):
+    if use_cache:
+        cache = InMemoryCacheBackend()
+    else:
+        cache = None
+    u = Utility(
+        model=LinearRegression(),
         data=linear_dataset,
         scorer=Scorer("r2"),
-        enable_cache=True,
-        cache_options=MemcachedConfig(
-            client_config=memcache_client_config, time_threshold=0
-        ),
+        cache=cache,
     )
-
-    assert u1.signature != u2.signature
-    assert u1.signature == u1.signature
-    assert u2.signature == u2.signature
+    u_unpickled = pickle.loads(pickle.dumps(u))
+    assert type(u.model) == type(u_unpickled.model)
+    assert type(u.scorer) == type(u_unpickled.scorer)
+    assert type(u.data) == type(u_unpickled.data)
+    assert (u.data.x_train == u_unpickled.data.x_train).all()

From 0811edeb25468bbf72268af5a338e84c1344b315 Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Thu, 23 Nov 2023 14:17:39 +0100
Subject: [PATCH 05/29] Remove caching section from readme

---
 README.md | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/README.md b/README.md
index 57bf56d33..6b1aff368 100644
--- a/README.md
+++ b/README.md
@@ -199,21 +199,6 @@
 and for influence functions (e.g.
 [Influence Functions for Neural
 Networks](https://pydvl.org/stable/examples/influence_imagenet/)) with details
 on the algorithms and their applications.
 
-## Caching
-
-pyDVL offers the possibility to cache certain results and
-speed up computation. It uses [Memcached](https://memcached.org/) For that.
-
-You can run it either locally or, using
-[Docker](https://www.docker.com/):
-
-```shell
-docker container run --rm -p 11211:11211 --name pydvl-cache -d memcached:latest
-```
-
-You can read more in the
-[documentation](https://pydvl.org/stable/getting-started/first-steps/#caching).
-
 # Contributing
 
 Please open new issues for bugs, feature requests and extensions. You can read

From 44c5f5755860804921d2feecfb35749495e5c1c8 Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Thu, 23 Nov 2023 17:42:22 +0100
Subject: [PATCH 06/29] Rename CacheBackendBase to CacheBackend, improve
 docstrings

---
 src/pydvl/utils/caching/__init__.py  |   6 ++
 src/pydvl/utils/caching/base.py      |  59 ++++-------
 src/pydvl/utils/caching/disk.py      |  47 ++++++++-
 src/pydvl/utils/caching/memcached.py | 142 ++++++++++-----------------
 src/pydvl/utils/caching/memory.py    |  46 ++++++++-
 src/pydvl/utils/utility.py           |  11 ++-
 6 files changed, 171 insertions(+), 140 deletions(-)

diff --git a/src/pydvl/utils/caching/__init__.py b/src/pydvl/utils/caching/__init__.py
index 6a1a3e867..6c98bf9cb 100644
--- a/src/pydvl/utils/caching/__init__.py
+++ b/src/pydvl/utils/caching/__init__.py
@@ -32,6 +32,12 @@
 )
 ```
 
+# Supported Backends
+
+- [InMemoryCacheBackend][]
+- [DiskCacheBackend][]
+- [MemcachedCacheBackend][]
+
 # Usage with stochastic functions
 
 In addition to standard memoization, the decorator
diff --git a/src/pydvl/utils/caching/base.py b/src/pydvl/utils/caching/base.py
index 339c85389..00f26d67e 100644
--- a/src/pydvl/utils/caching/base.py
+++ b/src/pydvl/utils/caching/base.py
@@ -11,7 +11,7 @@
 from ..numeric import running_moments
 from .config import CachedFuncConfig
 
-__all__ = ["CacheStats", "CacheBackendBase", "CachedFunc"]
+__all__ = ["CacheStats", "CacheBackend", "CachedFunc"]
 
@@ -20,15 +20,15 @@
 
 @dataclass
 class CacheStats:
-    """Statistics gathered by cached functions.
+    """Class used to store statistics gathered by cached functions.
Attributes:
-        sets: number of times a value was set in the cache
-        misses: number of times a value was not found in the cache
-        hits: number of times a value was found in the cache
-        timeouts: number of times a timeout occurred
-        errors: number of times an error occurred
-        reconnects: number of times the client reconnected to the server
+        sets: Number of times a value was set in the cache.
+        misses: Number of times a value was not found in the cache.
+        hits: Number of times a value was found in the cache.
+        timeouts: Number of times a timeout occurred.
+        errors: Number of times an error occurred.
+        reconnects: Number of times the client reconnected to the server.
     """
 
     sets: int = 0
@@ -41,13 +41,13 @@ class CacheStats:
 
 @dataclass
 class CacheResult:
-    """A dataclass to store the cached result of a computation
+    """A class used to store the cached result of a computation
     as well as count and variance when using repeated evaluation.
 
     Attributes:
-        value: The cached value.
-        count: The number of times this value has been computed.
-        variance: The variance associated with the cached value.
+        value: Cached value.
+        count: Number of times this value has been computed.
+        variance: Variance associated with the cached value.
     """
 
     value: float
@@ -55,11 +55,11 @@ class CacheResult:
     variance: float = 0.0
 
 
-class CacheBackendBase(ABC):
+class CacheBackend(ABC):
     """Abstract base class for cache backends.
 
-    Implements cache management including wrapping functions to cache results,
-    retrieving results, setting results, clearing the cache, and combining cache keys.
+    Defines the interface for cache access, including wrapping callables,
+    getting/setting results, clearing the cache, and combining cache keys.
 
     Attributes:
         stats: Cache statistics tracker.
@@ -130,39 +130,22 @@ class CachedFunc:
     """Caches callable function results with a provided cache backend.
 
     Wraps a callable function to cache its results using a provided
-    an instance of CacheBackendBase subclass.
-    Can configure cache keys, repeated evaluations, and cache stats tracking.
+    instance of a subclass of [CacheBackend][pydvl.utils.caching.base.CacheBackend].
+
+    This class is heavily inspired by [joblib.memory.MemorizedFunc][].
 
     Args:
         func: Callable to wrap.
-        cache_backend: An instance of CacheBackendBase that handles
+        cache_backend: Instance of CacheBackend that handles
            setting and getting values.
-        cached_func_options:
-        ignore_args: Do not take these keyword arguments into account when
-            hashing the wrapped function for usage as key. This allows
-            sharing the cache among different jobs for the same experiment run if
-            the callable happens to have "nuisance" parameters like `job_id` which
-            do not affect the result of the computation.
-        time_threshold: computations taking less time than this many seconds are
-            not cached.
-        allow_repeated_evaluations: If `True`, repeated calls to a function
-            with the same arguments will be allowed and outputs averaged until the
-            running standard deviation of the mean stabilizes below
-            `rtol_stderr * mean`.
-        rtol_stderr: relative tolerance for repeated evaluations. More precisely,
-            [memcached()][pydvl.utils.caching.memcached] will stop evaluating the function once the
-            standard deviation of the mean is smaller than `rtol_stderr * mean`.
-        min_repetitions: minimum number of times that a function evaluation
-            on the same arguments is repeated before returning cached values. Useful
-            for stochastic functions only. 
If the model training is very noisy, set - this number to higher values to reduce variance. + cached_func_options: Configuration for wrapped function. """ def __init__( self, func: Callable[..., T], *, - cache_backend: CacheBackendBase, + cache_backend: CacheBackend, cached_func_options: CachedFuncConfig = CachedFuncConfig(), ) -> None: self.func = func diff --git a/src/pydvl/utils/caching/disk.py b/src/pydvl/utils/caching/disk.py index e65ff040b..0966385d4 100644 --- a/src/pydvl/utils/caching/disk.py +++ b/src/pydvl/utils/caching/disk.py @@ -5,7 +5,7 @@ import cloudpickle -from pydvl.utils.caching.base import CacheBackendBase +from pydvl.utils.caching.base import CacheBackend __all__ = ["DiskCacheBackend"] @@ -14,14 +14,54 @@ DEFAULT_CACHE_DIR = Path().home() / ".pydvl_cache/disk" -class DiskCacheBackend(CacheBackendBase): +class DiskCacheBackend(CacheBackend): """Disk cache backend that stores results in files. - Implements the CacheBackendBase interface for a disk-based cache. + Implements the CacheBackend interface for a disk-based cache. Stores cache entries as pickled files on disk, keyed by cache key. + This allows sharing evaluations across processes in a single node/computer. + + Args: + cache_dir: Base directory for cache storage. Attributes: cache_dir: Base directory for cache storage. + + ??? Examples + ``` pycon + >>> from pydvl.utils.caching.disk import DiskCacheBackend + >>> cache = DiskCacheBackend(cache_dir="/tmp/pydvl_disk_cache") + >>> cache.clear() + >>> value = 42 + >>> cache.set("key", value) + >>> cache.get("key") + 42 + ``` + + ``` pycon + >>> from pydvl.utils.caching.disk import DiskCacheBackend + >>> cache = DiskCacheBackend(cache_dir="/tmp/pydvl_disk_cache") + >>> cache.clear() + >>> value = 42 + >>> def foo(x: int): + ... return x + 1 + ... + >>> wrapped_foo = cache.wrap(foo) + >>> wrapped_foo(value) + 43 + >>> wrapped_foo.stats.misses + 1 + >>> wrapped_foo.stats.hits + 0 + >>> wrapped_foo(value) + 43 + >>> wrapped_foo.stats.misses + 1 + >>> wrapped_foo.stats.hits + 1 + ``` + + """ def __init__( @@ -32,6 +72,7 @@ def __init__( Args: cache_dir: Base directory for cache storage. + By default, this is set to `~/.pydvl_cache/disk` """ super().__init__() self.cache_dir = Path(cache_dir) diff --git a/src/pydvl/utils/caching/memcached.py b/src/pydvl/utils/caching/memcached.py index 1f04ad2f5..ebef3b764 100644 --- a/src/pydvl/utils/caching/memcached.py +++ b/src/pydvl/utils/caching/memcached.py @@ -1,84 +1,3 @@ -""" Distributed caching of functions. - -pyDVL uses [memcached](https://memcached.org) to cache utility values, through -[pymemcache](https://pypi.org/project/pymemcache). This allows sharing -evaluations across processes and nodes in a cluster. You can run memcached as a -service, locally or remotely, see [Setting up the cache](#setting-up-the-cache) - -!!! Warning - Function evaluations are cached with a key based on the function's signature - and code. This can lead to undesired cache hits, see [Cache reuse](#cache-reuse). - - Remember **not to reuse utility objects for different datasets**. - -# Configuration - -Memoization is disabled by default but can be enabled easily, -see [Setting up the cache](#setting-up-the-cache). -When enabled, it will be added to any callable used to construct a -[Utility][pydvl.utils.utility.Utility] (done with the decorator [@memcached][pydvl.utils.caching.memcached]). 
-Depending on the nature of the utility you might want to -enable the computation of a running average of function values, see -[Usage with stochastic functions](#usaage-with-stochastic-functions). -You can see all configuration options under [MemcachedConfig][pydvl.utils.config.MemcachedConfig]. - -## Default configuration - -```python -default_config = dict( - server=('localhost', 11211), - connect_timeout=1.0, - timeout=0.1, - # IMPORTANT! Disable small packet consolidation: - no_delay=True, - serde=serde.PickleSerde(pickle_version=PICKLE_VERSION) -) -``` - -# Usage with stochastic functions - -In addition to standard memoization, the decorator -[memcached()][pydvl.utils.caching.memcached] can compute running average and -standard error of repeated evaluations for the same input. This can be useful -for stochastic functions with high variance (e.g. model training for small -sample sizes), but drastically reduces the speed benefits of memoization. - -This behaviour can be activated with the argument `allow_repeated_evaluations` -to [memcached()][pydvl.utils.caching.memcached]. - -# Cache reuse - -When working directly with [memcached()][pydvl.utils.caching.memcached], it is -essential to only cache pure functions. If they have any kind of state, either -internal or external (e.g. a closure over some data that may change), then the -cache will fail to notice this and the same value will be returned. - -When a function is wrapped with [memcached()][pydvl.utils.caching.memcached] for -memoization, its signature (input and output names) and code are used as a key -for the cache. Alternatively you can pass a custom value to be used as key with - -```python -cached_fun = memcached(**asdict(cache_options))(fun, signature=custom_signature) -``` - -If you are running experiments with the same [Utility][pydvl.utils.utility.Utility] -but different datasets, this will lead to evaluations of the utility on new data -returning old values because utilities only use sample indices as arguments (so -there is no way to tell the difference between '1' for dataset A and '1' for -dataset 2 from the point of view of the cache). One solution is to empty the -cache between runs, but the preferred one is to **use a different Utility -object for each dataset**. - -# Unexpected cache misses - -Because all arguments to a function are used as part of the key for the cache, -sometimes one must exclude some of them. For example, If a function is going to -run across multiple processes and some reporting arguments are added (like a -`job_id` for logging purposes), these will be part of the signature and make the -functions distinct to the eyes of the cache. This can be avoided with the use of -[ignore_args][pydvl.utils.config.MemcachedConfig] in the configuration. - -""" from __future__ import annotations import logging @@ -92,7 +11,7 @@ from pymemcache.client import Client, RetryingClient from pymemcache.serde import PickleSerde -from .base import CacheBackendBase +from .base import CacheBackend __all__ = ["MemcachedClientConfig", "MemcachedCacheBackend"] @@ -109,13 +28,12 @@ class MemcachedClientConfig: server: A tuple of (IP|domain name, port). connect_timeout: How many seconds to wait before raising `ConnectionRefusedError` on failure to connect. - timeout: seconds to wait for send or recv calls on the socket - connected to memcached. - no_delay: set the `TCP_NODELAY` flag, which may help with performance - in some cases. - serde: a serializer / deserializer ("serde"). 
The default `PickleSerde` - should work in most cases. See [pymemcached's - documentation](https://pymemcache.readthedocs.io/en/latest/apidoc/pymemcache.client.base.html#pymemcache.client.base.Client) + timeout: Duration in seconds to wait for send or recv calls + on the socket connected to memcached. + no_delay: If True, set the `TCP_NODELAY` flag, which may help + with performance in some cases. + serde: Serializer / Deserializer ("serde"). The default `PickleSerde` + should work in most cases. See [pymemcache.client.base.Client][] for details. """ @@ -126,14 +44,54 @@ class MemcachedClientConfig: serde: PickleSerde = PickleSerde(pickle_version=PICKLE_VERSION) -class MemcachedCacheBackend(CacheBackendBase): - """Memcached cache backend. +class MemcachedCacheBackend(CacheBackend): + """Memcached cache backend for the distributed caching of functions. - Implements CacheBackendBase using a memcached client. + Implements the CacheBackend interface for a memcached based cache. + This allows sharing evaluations across processes and nodes in a cluster. + You can run memcached as a service, locally or remotely, + see [Setting up the cache](#setting-up-the-cache) + + Args: + config: Memcached client configuration. Attributes: config: Memcached client configuration. client: Memcached client instance. + + ??? Examples + ``` pycon + >>> from pydvl.utils.caching.memcached import MemcachedCacheBackend + >>> cache = MemcachedCacheBackend() + >>> cache.clear() + >>> value = 42 + >>> cache.set("key", value) + >>> cache.get("key") + 42 + ``` + + ``` pycon + >>> from pydvl.utils.caching.memcached import MemcachedCacheBackend + >>> cache = MemcachedCacheBackend() + >>> cache.clear() + >>> value = 42 + >>> def foo(x: int): + ... return x + 1 + ... + >>> wrapped_foo = cache.wrap(foo) + >>> wrapped_foo(value) + 43 + >>> wrapped_foo.stats.misses + 1 + >>> wrapped_foo.stats.hits + 0 + >>> wrapped_foo(value) + 43 + >>> wrapped_foo.stats.misses + 1 + >>> wrapped_foo.stats.hits + 1 + ``` """ def __init__(self, config: MemcachedClientConfig = MemcachedClientConfig()) -> None: diff --git a/src/pydvl/utils/caching/memory.py b/src/pydvl/utils/caching/memory.py index 9ffc48fa5..064ce0d3d 100644 --- a/src/pydvl/utils/caching/memory.py +++ b/src/pydvl/utils/caching/memory.py @@ -1,16 +1,56 @@ import os from typing import Any, Optional -from pydvl.utils.caching.base import CacheBackendBase +from pydvl.utils.caching.base import CacheBackend __all__ = ["InMemoryCacheBackend"] -class InMemoryCacheBackend(CacheBackendBase): +class InMemoryCacheBackend(CacheBackend): """In-memory cache backend that stores results in a dictionary. - Implements the CacheBackendBase interface for an in-memory-based cache. + Implements the CacheBackend interface for an in-memory-based cache. Stores cache entries as values in a dictionary, keyed by cache key. + This allows sharing evaluations across threads in a single process. + + The implementation is not thread-safe. + + Attributes: + cached_values: Dictionary used to store cached values. + + ??? Examples + ``` pycon + >>> from pydvl.utils.caching.memory import InMemoryCacheBackend + >>> cache = InMemoryCacheBackend() + >>> cache.clear() + >>> value = 42 + >>> cache.set("key", value) + >>> cache.get("key") + 42 + ``` + + ``` pycon + >>> from pydvl.utils.caching.memory import InMemoryCacheBackend + >>> cache = InMemoryCacheBackend() + >>> cache.clear() + >>> value = 42 + >>> def foo(x: int): + ... return x + 1 + ... 
+ >>> wrapped_foo = cache.wrap(foo) + >>> wrapped_foo(value) + 43 + >>> wrapped_foo.stats.misses + 1 + >>> wrapped_foo.stats.hits + 0 + >>> wrapped_foo(value) + 43 + >>> wrapped_foo.stats.misses + 1 + >>> wrapped_foo.stats.hits + 1 + ``` """ def __init__(self) -> None: diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py index 302819f9a..8d2ca7205 100644 --- a/src/pydvl/utils/utility.py +++ b/src/pydvl/utils/utility.py @@ -33,7 +33,7 @@ from sklearn.metrics import check_scoring from pydvl.utils import Dataset -from pydvl.utils.caching import CacheBackendBase, CachedFuncConfig, CacheStats +from pydvl.utils.caching import CacheBackend, CachedFuncConfig, CacheStats from pydvl.utils.score import Scorer from pydvl.utils.types import SupervisedModel @@ -100,8 +100,11 @@ class Utility: calculations. When this happens, the `default_score` is returned as a score and computation continues. show_warnings: Set to `False` to suppress warnings thrown by `fit()`. - enable_cache: If `True`, use memcached for memoization. - cache_options: Optional configuration object for memcached. + cache: Optional instance of [CacheBackend][pydvl.utils.caching.base.CacheBackend] + used to wrap the _utility method of the Utility instance. + By default, this is set to None and that means that the utility evaluations + will not be cached. + cached_func_options: Optional configuration object for cached utility evaluation. clone_before_fit: If `True`, the model will be cloned before calling `fit()`. @@ -132,7 +135,7 @@ def __init__( score_range: Tuple[float, float] = (-np.inf, np.inf), catch_errors: bool = True, show_warnings: bool = False, - cache: Optional[CacheBackendBase] = None, + cache: Optional[CacheBackend] = None, cached_func_options: CachedFuncConfig = CachedFuncConfig(), clone_before_fit: bool = True, ): From 073745ffba51296db5b85582c63ec7a82ac39de2 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Thu, 23 Nov 2023 17:43:14 +0100 Subject: [PATCH 07/29] Make pymemcached an optional dependency, define new memcached extra --- requirements.txt | 1 - setup.py | 1 + tox.ini | 1 + 3 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 471ba7475..aedacc12f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ scikit-learn scipy>=1.7.0 cvxpy>=1.3.0 joblib -pymemcache cloudpickle tqdm matplotlib diff --git a/setup.py b/setup.py index 1883f176e..d23fe0a29 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ tests_require=["pytest"], extras_require={ "cupy": ["cupy-cuda11x>=12.1.0"], + "memcached": ["pymemcache"], "influence": ["torch>=2.0.0"], "ray": ["ray>=0.8"], }, diff --git a/tox.ini b/tox.ini index f48a3d9fd..666a5760c 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,7 @@ passenv = extras = ray influence + memcached commands = pytest -n auto --dist worksteal --cov "{envsitepackagesdir}/pydvl" {posargs} From 6b2e3a7bbe456c31a9fe244b9af9d2fd01f88e13 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Thu, 23 Nov 2023 17:43:46 +0100 Subject: [PATCH 08/29] Add joblib documentation inventory --- mkdocs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/mkdocs.yml b/mkdocs.yml index c4a80316a..f33c43574 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -69,6 +69,7 @@ plugins: - https://scikit-learn.org/stable/objects.inv - https://pytorch.org/docs/stable/objects.inv - https://pymemcache.readthedocs.io/en/latest/objects.inv + - https://joblib.readthedocs.io/en/stable/objects.inv paths: [ src ] # search packages in the src folder options: 
docstring_style: google From 6b5e6ee46662a5421bcdb72a45074b8c362e0c23 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Thu, 23 Nov 2023 17:44:04 +0100 Subject: [PATCH 09/29] Update and improve installation and first-steps docs --- docs/getting-started/first-steps.md | 89 +++++++++++++++---- docs/getting-started/installation.md | 123 ++++++++++++++------------- 2 files changed, 135 insertions(+), 77 deletions(-) diff --git a/docs/getting-started/first-steps.md b/docs/getting-started/first-steps.md index a86cf6307..dcbe54d24 100644 --- a/docs/getting-started/first-steps.md +++ b/docs/getting-started/first-steps.md @@ -9,8 +9,7 @@ alias: !!! Warning Make sure you have read [[installation]] before using the library. - In particular read about how caching and parallelization work, - since they might require additional setup. + In particular read about which extra dependencies you may need. ## Main concepts @@ -23,7 +22,6 @@ should be enough to get you started. computation and related methods. * [[influence-values]] for instructions on how to compute influence functions. - ## Running the examples If you are somewhat familiar with the concepts of data valuation, you can start @@ -38,21 +36,20 @@ by browsing our worked-out examples illustrating pyDVL's capabilities either: # Advanced usage -Besides the do's and don'ts of data valuation itself, which are the subject of +Besides the dos and don'ts of data valuation itself, which are the subject of the examples and the documentation of each method, there are two main things to keep in mind when using pyDVL. ## Caching -pyDVL uses [memcached](https://memcached.org/) to cache the computation of the -utility function and speed up some computations (see the [installation -guide](installation.md/#setting-up-the-cache)). - -Caching of the utility function is disabled by default. When it is enabled it -takes into account the data indices passed as argument and the utility function -wrapped into the [Utility][pydvl.utils.utility.Utility] object. This means that +PyDVL can cache the computation of the utility function +and speed up some computations for data valuation. +It is however disabled by default. +When it is enabled it takes into account the data indices passed as argument +and the utility function wrapped into the +[Utility][pydvl.utils.utility.Utility] object. This means that care must be taken when reusing the same utility function with different data, -see the documentation for the [caching module][pydvl.utils.caching] for more +see the documentation for the [caching package][pydvl.utils.caching] for more information. In general, caching won't play a major role in the computation of Shapley values @@ -61,20 +58,65 @@ the same utility function computation, is very low. However, it can be very useful when comparing methods that use the same utility function, or when running multiple experiments with the same data. +pyDVL supports different caching backends: + +- [InMemoryCacheBackend][pydvl.utils.caching.memory.InMemoryCacheBackend]: + an in-memory cache backend that uses a dictionary to store and retrieve + cached values. This is used to share cached values between threads + in a single process. +- [DiskCacheBackend][pydvl.utils.caching.disk.DiskCacheBackend]: + a disk-based cache backend that uses pickled values written to and read from disk. + This is used to share cached values between processes in a single machine. 
+- [MemcachedCacheBackend][pydvl.utils.caching.memcached.MemcachedCacheBackend]:
+  a [Memcached](https://memcached.org/)-based cache backend that uses pickled values written to
+  and read from a Memcached server. This is used to share cached values
+  between processes across multiple machines.
+
+  **Note** This specific backend requires optional dependencies.
+  See [[installation#extras]] for more information.
+
 !!! tip "When is the cache really necessary?"
     Crucially, semi-value computations with the
     [PermutationSampler][pydvl.value.sampler.PermutationSampler] require caching
     to be enabled, or they will take twice as long as the direct implementation
     in [compute_shapley_values][pydvl.value.shapley.compute_shapley_values].
 
+!!! tip "Using the cache"
+    Continue reading about the cache in the documentation
+    for the [caching package][pydvl.utils.caching].
+
+### Setting up the Memcached cache
+
+[Memcached](https://memcached.org/) is an in-memory key-value store accessible
+over the network. pyDVL can use it to cache the computation of the utility function
+and speed up some computations (in particular, semi-value computations with the
+[PermutationSampler][pydvl.value.sampler.PermutationSampler] but other methods
+may benefit as well).
+
+You can either install it as a package or run it inside a docker container (the
+simplest option). For installation instructions, refer to the [Getting
+started](https://github.com/memcached/memcached/wiki#getting-started) section in
+memcached's wiki. Then you can run it with:
+
+```shell
+memcached -u user
+```
+
+To run memcached inside a container in daemon mode instead, use:
+
+```shell
+docker container run -d --rm -p 11211:11211 memcached:latest
+```
+
 ## Parallelization
 
-pyDVL supports [joblib](https://joblib.readthedocs.io/en/latest/) for local
-parallelization (within one machine) and [ray](https://ray.io) for distributed
-parallelization (across multiple machines).
+pyDVL uses [joblib](https://joblib.readthedocs.io/en/latest/) for local
+parallelization (within one machine) and supports using
+[Ray](https://ray.io) for distributed parallelization (across multiple machines).
 
-The former works out of the box but for the latter you will need to provide a
-running cluster (or run ray in local mode).
+The former works out of the box but for the latter you will need to install
+additional dependencies (see [[installation#extras]])
+and to provide a running cluster (or run ray in local mode).
 
 As of v0.7.0 pyDVL does not allow requesting resources per task sent to the
 cluster, so you will need to make sure that each worker has enough resources to
@@ -82,3 +124,16 @@ handle the tasks it receives. A data valuation task using game-theoretic methods
 will typically make a copy of the whole model and dataset to each worker, even
 if the re-training only happens on a subset of the data. This means that you
 should make sure that each worker has enough memory to handle the whole dataset.
+
+### Ray
+
+Please follow the instructions in Ray's documentation to set up a cluster.
+Once you have a running cluster, you can use it by passing the address
+of the head node to parallel methods via [ParallelConfig][pydvl.parallel.config.ParallelConfig].
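+
+For a remote cluster, a minimal sketch could look as follows. Note that the
+`address` argument and the `ray://` URL scheme follow Ray's client API and the
+concrete value is a placeholder to adapt to your own setup:
+
+```python
+from pydvl.parallel.config import ParallelConfig
+
+# "ray://<head-node-ip>:10001" is a placeholder; use your head node's address.
+config = ParallelConfig(backend="ray", address="ray://<head-node-ip>:10001")
+```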
+ +For a local ray cluster you would use: + +```python +from pydvl.parallel.config import ParallelConfig +config = ParallelConfig(backend="ray") +``` diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md index 2d2164ada..125f81d13 100644 --- a/docs/getting-started/installation.md +++ b/docs/getting-started/installation.md @@ -13,32 +13,6 @@ To install the latest release use: pip install pyDVL ``` -To use all features of influence functions use instead: - -```shell -pip install pyDVL[influence] -``` - -This includes a dependency on [PyTorch](https://pytorch.org/) (Version 2.0 and -above) and thus is left out by default. - -In case that you have a supported version of CUDA installed (v11.2 to 11.8 as of -this writing), you can enable eigenvalue computations for low-rank approximations -with [CuPy](https://docs.cupy.dev/en/stable/index.html) on the GPU by using: - -```shell -pip install pyDVL[cupy] -``` - -If you use a different version of CUDA, please install CuPy -[manually](https://docs.cupy.dev/en/stable/install.html). - -In order to check the installation you can use: - -```shell -python -c "import pydvl; print(pydvl.__version__)" -``` - You can also install the latest development version from [TestPyPI](https://test.pypi.org/project/pyDVL/): @@ -46,42 +20,71 @@ You can also install the latest development version from pip install pyDVL --index-url https://test.pypi.org/simple/ ``` -## Dependencies - -pyDVL requires Python >= 3.8, [Memcached](https://memcached.org/) for caching -and [Ray](https://ray.io) for parallelization in a cluster (locally it uses joblib). -Additionally, the [Influence functions][pydvl.influence] module requires PyTorch -(see [[installation]]). - -ray is used to distribute workloads across nodes in a cluster (it can be used -locally as well, but for this we recommend joblib instead). Please follow the -instructions in their documentation to set up the cluster. Once you have a -running cluster, you can use it by passing the address of the head node to -parallel methods via [ParallelConfig][pydvl.utils.parallel]. - -## Setting up the cache - -[memcached](https://memcached.org/) is an in-memory key-value store accessible -over the network. pyDVL uses it to cache the computation of the utility function -and speed up some computations (in particular, semi-value computations with the -[PermutationSampler][pydvl.value.sampler.PermutationSampler] but other methods -may benefit as well). - -You can either install it as a package or run it inside a docker container (the -simplest). For installation instructions, refer to the [Getting -started](https://github.com/memcached/memcached/wiki#getting-started) section in -memcached's wiki. Then you can run it with: +In order to check the installation you can use: ```shell -memcached -u user +python -c "import pydvl; print(pydvl.__version__)" ``` -To run memcached inside a container in daemon mode instead, do: - -```shell -docker container run -d --rm -p 11211:11211 memcached:latest -``` +## Dependencies -!!! tip "Using the cache" - Continue reading about the cache in the [First Steps](first-steps.md#caching) - and the documentation for the [caching module][pydvl.utils.caching]. +pyDVL requires Python >= 3.8, [numpy](https://numpy.org/), +[scikit-learn](https://scikit-learn.org/stable/), [scipy](https://scipy.org/), +[cvxpy](https://www.cvxpy.org/) for the Core methods, +and [joblib](https://joblib.readthedocs.io/en/stable/) +for parallelization locally. 
Additionally, the [Influence functions][pydvl.influence]
+module requires PyTorch (see [[installation#extras]]).
+
+### Extras
+
+pyDVL has a few [extra](https://peps.python.org/pep-0508/#extras) dependencies
+that can be optionally installed:
+
+- `influence`:
+
+    To use all features of influence functions, install pyDVL with:
+
+    ```shell
+    pip install pyDVL[influence]
+    ```
+
+    This includes a dependency on [PyTorch](https://pytorch.org/) (Version 2.0 and
+    above) and thus is left out by default.
+
+- `cupy`:
+
+    In case you have a supported version of CUDA installed (v11.2 to 11.8 as of
+    this writing), you can enable eigenvalue computations for low-rank approximations
+    with [CuPy](https://docs.cupy.dev/en/stable/index.html) on the GPU by using:
+
+    ```shell
+    pip install pyDVL[cupy]
+    ```
+
+    This installs [cupy-cuda11x](https://pypi.org/project/cupy-cuda11x/).
+
+    If you use a different version of CUDA, please install CuPy
+    [manually](https://docs.cupy.dev/en/stable/install.html).
+
+- `ray`:
+
+    If you want to use [Ray](https://www.ray.io/) to distribute data valuation
+    workloads across nodes in a cluster (it can be used locally as well,
+    but for this we recommend joblib instead), install pyDVL using:
+
+    ```shell
+    pip install pyDVL[ray]
+    ```
+
+    See [[getting-started#ray]] for more details on how to use it.
+
+- `memcached`:
+
+    If you want to use [Memcached](https://memcached.org/) for caching
+    utility evaluations, use:
+
+    ```shell
+    pip install pyDVL[memcached]
+    ```
+
+    This additionally installs [pymemcache](https://github.com/pinterest/pymemcache).

From 82fa95257004b439349e9e1218c810a9b83937cd Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Thu, 23 Nov 2023 17:44:21 +0100
Subject: [PATCH 10/29] Add link to extras section of docs to readme

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 6b1aff368..74aa31fa6 100644
--- a/README.md
+++ b/README.md
@@ -111,7 +111,9 @@ pip install pyDVL --index-url https://test.pypi.org/simple/
 
 For more instructions and information refer to [Installing pyDVL
 ](https://pydvl.org/stable/getting-started/installation/) in the
-documentation.
+documentation. Refer more specifically to the [Dependencies
+](https://pydvl.org/stable/getting-started/installation/#extras)
+section for a list of extra requirements.
# Usage From 16ae64011578d1cafdf0fa59acf7f46428a158cc Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Thu, 23 Nov 2023 17:55:30 +0100 Subject: [PATCH 11/29] Update changelog --- CHANGELOG.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c2a06c7ec..0c7a61669 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,18 @@ # Changelog - ## Unreleased +### Added + +- New cache backends: InMemoryCacheBackend and DiskCacheBackend + [PR #458](https://github.com/aai-institute/pyDVL/pull/458) + ### Changed - Simplify and improve tests, add CodeCov code coverage [PR #429](https://github.com/aai-institute/pyDVL/pull/429) +- Refactor and simplify caching implementation + [PR #458](https://github.com/aai-institute/pyDVL/pull/458) ## 0.7.1 - 🆕 New methods, bug fixes and improvements for local tests 🐞🧪 From fbc96cf13b8651d8e0b12f72260c05ed9518f3bb Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Fri, 24 Nov 2023 10:19:01 +0100 Subject: [PATCH 12/29] Fix type hints --- src/pydvl/utils/caching/base.py | 19 +++++++++---------- src/pydvl/utils/caching/memory.py | 4 ++-- src/pydvl/utils/utility.py | 5 +++-- src/pydvl/value/semivalues.py | 2 +- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/pydvl/utils/caching/base.py b/src/pydvl/utils/caching/base.py index 00f26d67e..f040f48bb 100644 --- a/src/pydvl/utils/caching/base.py +++ b/src/pydvl/utils/caching/base.py @@ -13,8 +13,6 @@ __all__ = ["CacheStats", "CacheBackend", "CachedFunc"] -T = TypeVar("T") - logger = logging.getLogger(__name__) @@ -143,7 +141,7 @@ class CachedFunc: def __init__( self, - func: Callable[..., T], + func: Callable[..., float], *, cache_backend: CacheBackend, cached_func_options: CachedFuncConfig = CachedFuncConfig(), @@ -160,14 +158,14 @@ def __init__( patched = [f"cached_{path[0]}"] + path[1:] self.__qualname__ = ".".join(reversed(patched)) - def __call__(self, *args, **kwargs) -> T: + def __call__(self, *args, **kwargs) -> float: """Call the wrapped cached function. Executes the wrapped function, caching and returning the result. """ return self._cached_call(args, kwargs) - def _force_call(self, args, kwargs) -> Tuple[T, float]: + def _force_call(self, args, kwargs) -> Tuple[float, float]: """Force re-evaluation of the wrapped function. Executes the wrapped function without caching. @@ -181,7 +179,7 @@ def _force_call(self, args, kwargs) -> Tuple[T, float]: duration = end - start return value, duration - def _cached_call(self, args, kwargs) -> T: + def _cached_call(self, args, kwargs) -> float: """Cached wrapped function call. Executes the wrapped function with cache checking/setting. 
@@ -233,26 +231,27 @@ def _get_cache_key(self, *args, **kwargs) -> str: @staticmethod def _hash_function(func: Callable) -> str: """Create hash for wrapped function.""" - func_hash = hashing.hash((func.__code__.co_code, func.__code__.co_consts)) + func_hash: str = hashing.hash((func.__code__.co_code, func.__code__.co_consts)) return func_hash @staticmethod def _hash_arguments( func: Callable, ignore_args: Collection[str], - args: Tuple[Any], + args: Tuple[Any, ...], kwargs: Dict[str, Any], ) -> str: """Create hash for function arguments.""" - return hashing.hash( + args_hash: str = hashing.hash( CachedFunc._filter_args(func, ignore_args, args, kwargs), ) + return args_hash @staticmethod def _filter_args( func: Callable, ignore_args: Collection[str], - args: Tuple[Any], + args: Tuple[Any, ...], kwargs: Dict[str, Any], ) -> Dict[str, Any]: """Filter arguments to exclude from cache keys.""" diff --git a/src/pydvl/utils/caching/memory.py b/src/pydvl/utils/caching/memory.py index 064ce0d3d..6ca27b7b5 100644 --- a/src/pydvl/utils/caching/memory.py +++ b/src/pydvl/utils/caching/memory.py @@ -1,5 +1,5 @@ import os -from typing import Any, Optional +from typing import Any, Dict, Optional from pydvl.utils.caching.base import CacheBackend @@ -56,7 +56,7 @@ class InMemoryCacheBackend(CacheBackend): def __init__(self) -> None: """Initialize the in-memory cache backend.""" super().__init__() - self.cached_values = {} + self.cached_values: Dict[str, Any] = {} def get(self, key: str) -> Optional[Any]: """Get a value from the cache. diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py index 8d2ca7205..10bc7ebf4 100644 --- a/src/pydvl/utils/utility.py +++ b/src/pydvl/utils/utility.py @@ -246,9 +246,10 @@ def cache_stats(self) -> Optional[CacheStats]: """Cache statistics are gathered when cache is enabled. See [CacheStats][pydvl.utils.caching.base.CacheStats] for all fields returned. """ + cache_stats: Optional[CacheStats] = None if self.cache is not None: - return self._utility_wrapper.stats - return None + cache_stats = self._utility_wrapper.stats + return cache_stats def __getstate__(self): state = self.__dict__.copy() diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index eceba171e..9eee1c83d 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -216,7 +216,7 @@ def compute_generic_semivalues( from pydvl.parallel import effective_n_jobs, init_executor, init_parallel_backend - if isinstance(sampler, PermutationSampler) and not u.enable_cache: + if isinstance(sampler, PermutationSampler) and u.cache is None: log.warning( "PermutationSampler requires caching to be enabled or computation " "will be doubled wrt. a 'direct' implementation of permutation MC" From ab4578a6c399177543167f950c70b226e35dad50 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Fri, 24 Nov 2023 15:41:08 +0100 Subject: [PATCH 13/29] Remove leftover uses of enable_cache argument --- src/pydvl/utils/utility.py | 3 ++- tests/value/conftest.py | 5 +++-- tests/value/shapley/test_knn.py | 5 ++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py index 10bc7ebf4..2773b96a1 100644 --- a/src/pydvl/utils/utility.py +++ b/src/pydvl/utils/utility.py @@ -179,7 +179,8 @@ def _utility(self, indices: FrozenSet) -> float: """Clones the model, fits it on a subset of the training data and scores it on the test data. 
- If the object is constructed with `enable_cache = True`, results are + If an instance of [CacheBackend][pydvl.utils.caching.base.CacheBackend] + is passed during construction, results are memoized to avoid duplicate computation. This is useful in particular when computing utilities of permutations of indices or when randomly sampling from the powerset of indices. diff --git a/tests/value/conftest.py b/tests/value/conftest.py index 3eaa3d672..33e58bf64 100644 --- a/tests/value/conftest.py +++ b/tests/value/conftest.py @@ -72,7 +72,6 @@ def score(self, x: NDArray, y: NDArray) -> float: score_range=(0, x.sum() / x.max()), catch_errors=False, show_warnings=True, - enable_cache=False, ) @@ -122,7 +121,9 @@ def linear_shapley(cache, linear_dataset, scorer, n_jobs): if u is None: u = Utility( - LinearRegression(), data=linear_dataset, scorer=scorer, enable_cache=False + LinearRegression(), + data=linear_dataset, + scorer=scorer, ) exact_values = combinatorial_exact_shapley(u, progress=False, n_jobs=n_jobs) cache.set(u_cache_key, u) diff --git a/tests/value/shapley/test_knn.py b/tests/value/shapley/test_knn.py index 1ca7a1fbc..cf935f347 100644 --- a/tests/value/shapley/test_knn.py +++ b/tests/value/shapley/test_knn.py @@ -40,7 +40,10 @@ def knn_loss_function(labels, predictions, n_classes=3): ) utility = Utility( - model, data=data, scorer=scorer, show_warnings=False, enable_cache=False + model, + data=data, + scorer=scorer, + show_warnings=False, ) exact_values = combinatorial_exact_shapley( utility, progress=False, n_jobs=min(len(data), available_cpus()) From 6ce679acb9272b515641fdc590794529f997df53 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Fri, 24 Nov 2023 16:22:23 +0100 Subject: [PATCH 14/29] Use name cache_backend instead of cache --- src/pydvl/utils/caching/disk.py | 17 ++++++++--------- src/pydvl/utils/caching/memcached.py | 14 +++++++------- src/pydvl/utils/caching/memory.py | 16 ++++++++-------- src/pydvl/utils/utility.py | 6 +++--- 4 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/pydvl/utils/caching/disk.py b/src/pydvl/utils/caching/disk.py index 0966385d4..71ac56c2c 100644 --- a/src/pydvl/utils/caching/disk.py +++ b/src/pydvl/utils/caching/disk.py @@ -30,23 +30,23 @@ class DiskCacheBackend(CacheBackend): ??? Examples ``` pycon >>> from pydvl.utils.caching.disk import DiskCacheBackend - >>> cache = DiskCacheBackend(cache_dir="/tmp/pydvl_disk_cache") - >>> cache.clear() + >>> cache_backend = DiskCacheBackend(cache_dir="/tmp/pydvl_disk_cache") + >>> cache_backend.clear() >>> value = 42 - >>> cache.set("key", value) - >>> cache.get("key") + >>> cache_backend.set("key", value) + >>> cache_backend.get("key") 42 ``` ``` pycon - >>> from pydvl.utils.caching.disk import DiskCacheBackend - >>> cache = DiskCacheBackend(cache_dir="/tmp/pydvl_disk_cache") - >>> cache.clear() + >>> from pydvl.utils.caching.memcached import MemcachedCacheBackend + >>> cache_backend = MemcachedCacheBackend() + >>> cache_backend.clear() >>> value = 42 >>> def foo(x: int): ... return x + 1 ... - >>> wrapped_foo = cache.wrap(foo) + >>> wrapped_foo = cache_backend.wrap(foo) >>> wrapped_foo(value) 43 >>> wrapped_foo.stats.misses @@ -61,7 +61,6 @@ class DiskCacheBackend(CacheBackend): 1 ``` - """ def __init__( diff --git a/src/pydvl/utils/caching/memcached.py b/src/pydvl/utils/caching/memcached.py index ebef3b764..e96a26b97 100644 --- a/src/pydvl/utils/caching/memcached.py +++ b/src/pydvl/utils/caching/memcached.py @@ -62,23 +62,23 @@ class MemcachedCacheBackend(CacheBackend): ??? 
Examples
         ``` pycon
         >>> from pydvl.utils.caching.memcached import MemcachedCacheBackend
-        >>> cache = MemcachedCacheBackend()
-        >>> cache.clear()
+        >>> cache_backend = MemcachedCacheBackend()
+        >>> cache_backend.clear()
         >>> value = 42
-        >>> cache.set("key", value)
-        >>> cache.get("key")
+        >>> cache_backend.set("key", value)
+        >>> cache_backend.get("key")
         42
         ```
 
         ``` pycon
         >>> from pydvl.utils.caching.memcached import MemcachedCacheBackend
-        >>> cache = MemcachedCacheBackend()
-        >>> cache.clear()
+        >>> cache_backend = MemcachedCacheBackend()
+        >>> cache_backend.clear()
         >>> value = 42
         >>> def foo(x: int):
         ...     return x + 1
         ...
-        >>> wrapped_foo = cache.wrap(foo)
+        >>> wrapped_foo = cache_backend.wrap(foo)
         >>> wrapped_foo(value)
         43
diff --git a/src/pydvl/utils/caching/memory.py b/src/pydvl/utils/caching/memory.py
index 6ca27b7b5..8843e7083 100644
--- a/src/pydvl/utils/caching/memory.py
+++ b/src/pydvl/utils/caching/memory.py
@@ -21,23 +21,23 @@ class InMemoryCacheBackend(CacheBackend):
     ??? Examples
         ``` pycon
         >>> from pydvl.utils.caching.memory import InMemoryCacheBackend
-        >>> cache = InMemoryCacheBackend()
-        >>> cache.clear()
+        >>> cache_backend = InMemoryCacheBackend()
+        >>> cache_backend.clear()
         >>> value = 42
-        >>> cache.set("key", value)
-        >>> cache.get("key")
+        >>> cache_backend.set("key", value)
+        >>> cache_backend.get("key")
         42
         ```
 
         ``` pycon
-        >>> from pydvl.utils.caching.memory import InMemoryCacheBackend
-        >>> cache = InMemoryCacheBackend()
-        >>> cache.clear()
+        >>> from pydvl.utils.caching.memory import InMemoryCacheBackend
+        >>> cache_backend = InMemoryCacheBackend()
+        >>> cache_backend.clear()
         >>> value = 42
         >>> def foo(x: int):
         ...     return x + 1
         ...
-        >>> wrapped_foo = cache.wrap(foo)
+        >>> wrapped_foo = cache_backend.wrap(foo)
         >>> wrapped_foo(value)
         43
diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py
index 2773b96a1..d8cd7ae08 100644
--- a/src/pydvl/utils/utility.py
+++ b/src/pydvl/utils/utility.py
@@ -100,8 +100,8 @@ class Utility:
         calculations. When this happens, the `default_score` is returned as a
         score and computation continues.
     show_warnings: Set to `False` to suppress warnings thrown by `fit()`.
-    cache: Optional instance of [CacheBackend][pydvl.utils.caching.base.CacheBackend]
+    cache_backend: Optional instance of [CacheBackend][pydvl.utils.caching.base.CacheBackend]
         used to wrap the _utility method of the Utility instance.
        By default, this is set to None and that means that the utility evaluations
        will not be cached.
@@ -135,7 +135,7 @@ def __init__( score_range: Tuple[float, float] = (-np.inf, np.inf), catch_errors: bool = True, show_warnings: bool = False, - cache: Optional[CacheBackend] = None, + cache_backend: Optional[CacheBackend] = None, cached_func_options: CachedFuncConfig = CachedFuncConfig(), clone_before_fit: bool = True, ): @@ -149,7 +149,7 @@ def __init__( self.score_range = scorer.range if scorer is not None else np.array(score_range) self.catch_errors = catch_errors self.show_warnings = show_warnings - self.cache = cache + self.cache = cache_backend self.cached_func_options = cached_func_options self.clone_before_fit = clone_before_fit self._initialize_utility_wrapper() From ab9a20aba5571dc6f2831ec65a703d7d4425e0d2 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Fri, 24 Nov 2023 16:22:35 +0100 Subject: [PATCH 15/29] Fix tests --- tests/utils/test_caching.py | 36 ++++++++++++++------------ tests/utils/test_utility.py | 4 +-- tests/value/conftest.py | 8 ++++++ tests/value/shapley/test_montecarlo.py | 9 ++++--- tests/value/shapley/test_naive.py | 21 ++++++++++----- tests/value/shapley/test_truncated.py | 5 ++-- 6 files changed, 51 insertions(+), 32 deletions(-) diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index 713d7337c..14a82d46a 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -51,7 +51,7 @@ def foo(self): @pytest.fixture(params=["in-memory", "disk", "memcached"]) -def cache(request): +def cache_backend(request): backend: str = request.param if backend == "in-memory": cache = InMemoryCacheBackend() @@ -128,8 +128,8 @@ def test_cached_func_hash_arguments_of_method(): assert hash1 == hash2 -def test_single_job(cache): - wrapped_foo = cache.wrap(foo) +def test_single_job(cache_backend): + wrapped_foo = cache_backend.wrap(foo) n = 1000 wrapped_foo(np.arange(n)) @@ -148,9 +148,9 @@ def test_memcached_failed_connection(): MemcachedCacheBackend(config) -def test_cache_time_threshold(cache): +def test_cache_time_threshold(cache_backend): cached_func_config = CachedFuncConfig(time_threshold=1.0) - wrapped_foo = cache.wrap(foo, cached_func_config=cached_func_config) + wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) n = 1000 indices = np.arange(n) @@ -165,12 +165,12 @@ def test_cache_time_threshold(cache): assert misses_after > misses_before -def test_cache_ignore_args(cache): +def test_cache_ignore_args(cache_backend): # Note that we typically do NOT want to ignore run_id cached_func_config = CachedFuncConfig( ignore_args=["job_id"], ) - wrapped_foo = cache.wrap(foo, cached_func_config=cached_func_config) + wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) n = 1000 indices = np.arange(n) @@ -182,8 +182,8 @@ def test_cache_ignore_args(cache): assert hits_after > hits_before -def test_parallel_jobs(cache, parallel_config): - if not isinstance(cache, MemcachedCacheBackend): +def test_parallel_jobs(cache_backend, parallel_config): + if not isinstance(cache_backend, MemcachedCacheBackend): pytest.skip("Only running this test with MemcachedCacheBackend") if parallel_config.backend != "joblib": pytest.skip("We don't have to test this with all parallel backends") @@ -192,11 +192,11 @@ def test_parallel_jobs(cache, parallel_config): cached_func_config = CachedFuncConfig( ignore_args=["job_id", "run_id"], ) - wrapped_foo = cache.wrap(foo, cached_func_config=cached_func_config) + wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) n = 1234 n_runs = 10 - hits_before = 
cache.client.stats()[b"get_hits"] + hits_before = cache_backend.client.stats()[b"get_hits"] map_reduce_job = MapReduceJob( np.arange(n), wrapped_foo, np.sum, n_jobs=4, config=parallel_config @@ -215,12 +215,12 @@ def test_parallel_jobs(cache, parallel_config): assert hits_after - hits_before >= n_runs - 2, wrapped_foo.stats -def test_repeated_training(cache, worker_id: str): +def test_repeated_training(cache_backend, worker_id: str): cached_func_config = CachedFuncConfig( allow_repeated_evaluations=True, rtol_stderr=0.01, ) - wrapped_foo = cache.wrap( + wrapped_foo = cache_backend.wrap( foo_with_random, cached_func_config=cached_func_config, ) @@ -235,12 +235,12 @@ def test_repeated_training(cache, worker_id: str): assert wrapped_foo.stats.sets < wrapped_foo.stats.hits -def test_faster_with_repeated_training(cache, worker_id: str): +def test_faster_with_repeated_training(cache_backend, worker_id: str): cached_func_config = CachedFuncConfig( allow_repeated_evaluations=True, rtol_stderr=0.1, ) - wrapped_foo = cache.wrap( + wrapped_foo = cache_backend.wrap( foo_with_random_and_sleep, cached_func_config=cached_func_config, ) @@ -271,7 +271,9 @@ def test_faster_with_repeated_training(cache, worker_id: str): @pytest.mark.parametrize("n, atol", [(10, 5), (20, 10)]) @pytest.mark.parametrize("n_jobs", [1, 2]) @pytest.mark.parametrize("n_runs", [20]) -def test_parallel_repeated_training(cache, n, atol, n_jobs, n_runs, parallel_config): +def test_parallel_repeated_training( + cache_backend, n, atol, n_jobs, n_runs, parallel_config +): if parallel_config.backend != "joblib": pytest.skip("We don't have to test this with all parallel backends") @@ -284,7 +286,7 @@ def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float: rtol_stderr=0.01, ignore_args=["job_id", "run_id"], ) - wrapped_map_func = cache.wrap( + wrapped_map_func = cache_backend.wrap( map_func, cached_func_config=cached_func_config, ) diff --git a/tests/utils/test_utility.py b/tests/utils/test_utility.py index dddc172ec..c11be1c5f 100644 --- a/tests/utils/test_utility.py +++ b/tests/utils/test_utility.py @@ -64,7 +64,7 @@ def test_utility_with_cache(linear_dataset): model=LinearRegression(), data=linear_dataset, scorer=Scorer("r2"), - cache=InMemoryCacheBackend(), + cache_backend=InMemoryCacheBackend(), ) subsets = list(powerset(u.data.indices)) @@ -89,7 +89,7 @@ def test_utility_serialization(linear_dataset, use_cache): model=LinearRegression(), data=linear_dataset, scorer=Scorer("r2"), - cache=cache, + cache_backend=cache, ) u_unpickled = pickle.loads(pickle.dumps(u)) assert type(u.model) == type(u_unpickled.model) diff --git a/tests/value/conftest.py b/tests/value/conftest.py index 33e58bf64..19e84a878 100644 --- a/tests/value/conftest.py +++ b/tests/value/conftest.py @@ -8,6 +8,7 @@ from pydvl.parallel.config import ParallelConfig from pydvl.utils import Dataset, SupervisedModel, Utility +from pydvl.utils.caching import InMemoryCacheBackend from pydvl.utils.status import Status from pydvl.value import ValuationResult from pydvl.value.shapley.naive import combinatorial_exact_shapley @@ -134,3 +135,10 @@ def linear_shapley(cache, linear_dataset, scorer, n_jobs): @pytest.fixture(scope="module") def parallel_config(): yield ParallelConfig(backend="joblib", n_cpus_local=num_workers(), wait_timeout=0.1) + + +@pytest.fixture() +def cache_backend(): + cache = InMemoryCacheBackend() + yield cache + cache.clear() diff --git a/tests/value/shapley/test_montecarlo.py b/tests/value/shapley/test_montecarlo.py index 
b2b558461..0c5d781ea 100644 --- a/tests/value/shapley/test_montecarlo.py +++ b/tests/value/shapley/test_montecarlo.py @@ -6,7 +6,7 @@ from sklearn.linear_model import LinearRegression from pydvl.parallel.config import ParallelConfig -from pydvl.utils import Dataset, GroupedDataset, MemcachedConfig, Status, Utility +from pydvl.utils import Dataset, GroupedDataset, Status, Utility from pydvl.utils.numeric import num_samples_permutation_hoeffding from pydvl.utils.score import Scorer, squashed_r2 from pydvl.utils.types import Seed @@ -224,6 +224,7 @@ def test_linear_montecarlo_with_outlier( total_atol: float, fun, kwargs: dict, + cache_backend, ): """Tests whether valuation methods are able to detect an obvious outlier. @@ -241,7 +242,7 @@ def test_linear_montecarlo_with_outlier( LinearRegression(), data=linear_dataset, scorer=scorer, - cache_options=MemcachedConfig(client_config=memcache_client_config), + cache_backend=cache_backend, ) values = compute_shapley_values( linear_utility, mode=fun, progress=False, n_jobs=n_jobs, **kwargs @@ -266,12 +267,12 @@ def test_linear_montecarlo_with_outlier( def test_grouped_linear_montecarlo_shapley( linear_dataset, n_jobs, - memcache_client_config: "MemcachedClientConfig", num_groups: int, fun: ShapleyMode, scorer: Scorer, rtol: float, kwargs: dict, + cache_backend, ): """ For permutation and truncated montecarlo, the rtol for each scorer is chosen @@ -285,7 +286,7 @@ def test_grouped_linear_montecarlo_shapley( LinearRegression(), data=grouped_linear_dataset, scorer=scorer, - cache_options=MemcachedConfig(client_config=memcache_client_config), + cache_backend=cache_backend, ) exact_values = combinatorial_exact_shapley(grouped_linear_utility, progress=False) diff --git a/tests/value/shapley/test_naive.py b/tests/value/shapley/test_naive.py index 9b1151d99..45c32b1a9 100644 --- a/tests/value/shapley/test_naive.py +++ b/tests/value/shapley/test_naive.py @@ -4,7 +4,7 @@ import pytest from sklearn.linear_model import LinearRegression -from pydvl.utils import GroupedDataset, MemcachedConfig, Utility +from pydvl.utils import GroupedDataset, Utility from pydvl.value.shapley.naive import ( combinatorial_exact_shapley, permutation_exact_shapley, @@ -43,13 +43,18 @@ def test_analytic_exact_shapley(num_samples, analytic_shapley, fun, rtol, total_ ], ) def test_linear( - linear_dataset, memcache_client_config, scorer, rtol=0.01, total_atol=1e-5 + linear_dataset, + memcache_client_config, + scorer, + cache_backend, + rtol=0.01, + total_atol=1e-5, ): linear_utility = Utility( LinearRegression(), data=linear_dataset, scorer=scorer, - cache_options=MemcachedConfig(client_config=memcache_client_config), + cache_backend=cache_backend, ) values_combinatorial = combinatorial_exact_shapley(linear_utility, progress=False) @@ -70,6 +75,7 @@ def test_grouped_linear( num_groups, memcache_client_config, scorer, + cache_backend, rtol=0.01, total_atol=1e-5, ): @@ -81,7 +87,7 @@ def test_grouped_linear( LinearRegression(), data=grouped_linear_dataset, scorer=scorer, - cache_options=MemcachedConfig(client_config=memcache_client_config), + cache_backend=cache_backend, ) values_combinatorial = combinatorial_exact_shapley( grouped_linear_utility, progress=False @@ -107,7 +113,7 @@ def test_grouped_linear( ], ) def test_linear_with_outlier( - linear_dataset, memcache_client_config, scorer, total_atol=1e-5 + linear_dataset, memcache_client_config, scorer, cache_backend, total_atol=1e-5 ): outlier_idx = np.random.randint(len(linear_dataset.y_train)) linear_dataset.y_train[outlier_idx] -= 
100 @@ -115,7 +121,7 @@ def test_linear_with_outlier( LinearRegression(), data=linear_dataset, scorer=scorer, - cache_options=MemcachedConfig(client_config=memcache_client_config), + cache_backend=cache_backend, ) values = permutation_exact_shapley(linear_utility, progress=False) values.sort() @@ -169,6 +175,7 @@ def test_polynomial_with_outlier( polynomial_pipeline, memcache_client_config, scorer, + cache_backend, total_atol=1e-5, ): dataset, _ = polynomial_dataset @@ -178,7 +185,7 @@ def test_polynomial_with_outlier( polynomial_pipeline, dataset, scorer=scorer, - cache_options=MemcachedConfig(client_config=memcache_client_config), + cache_backend=cache_backend, ) shapley_values = permutation_exact_shapley(poly_utility, progress=False) diff --git a/tests/value/shapley/test_truncated.py b/tests/value/shapley/test_truncated.py index 4727c087d..ac980ab96 100644 --- a/tests/value/shapley/test_truncated.py +++ b/tests/value/shapley/test_truncated.py @@ -4,7 +4,7 @@ import pytest from sklearn.linear_model import LinearRegression -from pydvl.utils import MemcachedConfig, Status, Utility +from pydvl.utils import Status, Utility from pydvl.utils.score import Scorer, squashed_r2 from pydvl.value import compute_shapley_values from pydvl.value.shapley import ShapleyMode @@ -125,6 +125,7 @@ def test_tmcs_linear_montecarlo_with_outlier( n_jobs, memcache_client_config, scorer: Scorer, + cache_backend, total_atol: float, fun, kwargs: dict, @@ -145,7 +146,7 @@ def test_tmcs_linear_montecarlo_with_outlier( LinearRegression(), data=linear_dataset, scorer=scorer, - cache_options=MemcachedConfig(client_config=memcache_client_config), + cache_backend=cache_backend, ) values = compute_shapley_values( linear_utility, mode=fun, progress=False, n_jobs=n_jobs, **kwargs From 8cb02ac0279fb528ea6a3bc0b7b077cfe83a4353 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Fri, 24 Nov 2023 16:40:29 +0100 Subject: [PATCH 16/29] More fixes --- tests/utils/test_caching.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index 14a82d46a..8727357f5 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -42,7 +42,7 @@ def foo_with_random_and_sleep(indices: NDArray[np.int_], *args, **kwargs) -> flo # Used to test caching of methods -class Test: +class CacheTest: def __init__(self): self.value = 0 @@ -54,18 +54,18 @@ def foo(self): def cache_backend(request): backend: str = request.param if backend == "in-memory": - cache = InMemoryCacheBackend() - yield cache - cache.clear() + cache_backend = InMemoryCacheBackend() + yield cache_backend + cache_backend.clear() elif backend == "disk": with tempfile.TemporaryDirectory() as tempdir: - cache = DiskCacheBackend(tempdir) - yield cache - cache.clear() + cache_backend = DiskCacheBackend(tempdir) + yield cache_backend + cache_backend.clear() elif backend == "memcached": - cache = MemcachedCacheBackend() - yield cache - cache.clear() + cache_backend = MemcachedCacheBackend() + yield cache_backend + cache_backend.clear() else: raise ValueError(f"Unknown cache backend {backend}") @@ -120,7 +120,7 @@ def test_cached_func_hash_arguments(args1, args2, expected_equal): def test_cached_func_hash_arguments_of_method(): - obj = Test() + obj = CacheTest() hash1 = CachedFunc._hash_arguments(obj.foo, [], tuple(), {}) obj.value += 1 @@ -206,7 +206,7 @@ def test_parallel_jobs(cache_backend, parallel_config): for _ in range(n_runs): result = map_reduce_job() results.append(result) 
- hits_after = cache.client.stats()[b"get_hits"] + hits_after = cache_backend.client.stats()[b"get_hits"] assert results[0] == n * (n - 1) / 2 # Sanity check # FIXME! This is non-deterministic: if packets are delayed for longer than From 89472bc0a2fbda1bd1af41ca82e809932838ced9 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Sun, 26 Nov 2023 10:53:42 +0100 Subject: [PATCH 17/29] Handle usage of MemcachedCacheBackend when pymemcache is not installed --- src/pydvl/utils/caching/__init__.py | 9 ++------- src/pydvl/utils/caching/memcached.py | 18 +++++++++++++----- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/pydvl/utils/caching/__init__.py b/src/pydvl/utils/caching/__init__.py index 6c98bf9cb..bc72741e0 100644 --- a/src/pydvl/utils/caching/__init__.py +++ b/src/pydvl/utils/caching/__init__.py @@ -1,6 +1,6 @@ """Caching of functions. -pyDVL caches utility values to allow reusing previously computed evaluations. +pyDVL caches (memoizes) utility values to allow reusing previously computed evaluations. !!! Warning Function evaluations are cached with a key based on the function's signature @@ -82,13 +82,8 @@ [ignore_args][pydvl.utils.config.MemcachedConfig] in the configuration. """ - from .base import * from .config import * from .disk import * +from .memcached import * from .memory import * - -try: - from .memcached import * -except ImportError: - pass diff --git a/src/pydvl/utils/caching/memcached.py b/src/pydvl/utils/caching/memcached.py index e96a26b97..63855682f 100644 --- a/src/pydvl/utils/caching/memcached.py +++ b/src/pydvl/utils/caching/memcached.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import logging import socket import uuid @@ -7,9 +5,14 @@ from dataclasses import asdict, dataclass from typing import Any, Dict, Optional, Tuple -from pymemcache import MemcacheUnexpectedCloseError -from pymemcache.client import Client, RetryingClient -from pymemcache.serde import PickleSerde +try: + from pymemcache import MemcacheUnexpectedCloseError + from pymemcache.client import Client, RetryingClient + from pymemcache.serde import PickleSerde + + PYMEMCACHE_INSTALLED = True +except ImportError: + PYMEMCACHE_INSTALLED = False from .base import CacheBackend @@ -100,6 +103,11 @@ def __init__(self, config: MemcachedClientConfig = MemcachedClientConfig()) -> N Args: config: Memcached client configuration. """ + if not PYMEMCACHE_INSTALLED: + raise ModuleNotFoundError( + "Cannot use MemcachedCacheBackend because pymemcache was not installed. " + "Make sure to install pyDVL using `pip install pyDVL[memcached]`" + ) super().__init__() self.config = config self.client = self._connect(self.config) From 030763d5e4fafd99bf55441669a47871e4acd040 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Sun, 26 Nov 2023 11:10:33 +0100 Subject: [PATCH 18/29] Fix and improve caching package's docstring --- docs/getting-started/first-steps.md | 14 +++--- src/pydvl/utils/caching/__init__.py | 73 ++++++++++++++--------------- 2 files changed, 43 insertions(+), 44 deletions(-) diff --git a/docs/getting-started/first-steps.md b/docs/getting-started/first-steps.md index dcbe54d24..403724362 100644 --- a/docs/getting-started/first-steps.md +++ b/docs/getting-started/first-steps.md @@ -34,15 +34,15 @@ by browsing our worked-out examples illustrating pyDVL's capabilities either: have to install jupyter first manually since it's not a dependency of the library. 
-# Advanced usage
+## Advanced usage

Besides the dos and don'ts of data valuation itself, which are the subject of
the examples and the documentation of each method, there are two main things to
keep in mind when using pyDVL.

-## Caching
+### Caching

-PyDVL can cache the computation of the utility function
+PyDVL can cache (memoize) the computation of the utility function
and speed up some computations for data valuation.
It is however disabled by default.
When it is enabled it takes into account the data indices passed as argument
@@ -58,7 +58,7 @@
the same utility function computation, is very low.
However, it can be very useful when comparing methods that use the same
utility function, or when running multiple experiments with the same data.

-pyDVL supports different caching backends:
+pyDVL supports 3 different caching backends:

- [InMemoryCacheBackend][pydvl.utils.caching.memory.InMemoryCacheBackend]:
  an in-memory cache backend that uses a dictionary to store and retrieve
  cached values. This is used to share cached values between threads
  in a single process.
- [DiskCacheBackend][pydvl.utils.caching.disk.DiskCacheBackend]:
  a disk-based cache backend that uses pickled values written to and read from disk.
  This is used to share cached values between processes on a single machine.
- [MemcachedCacheBackend][pydvl.utils.caching.memcached.MemcachedCacheBackend]:
  a [Memcached](https://memcached.org/)-based cache backend that uses pickled values written to
  and read from a Memcached server. This is used to share cached values
  between processes across multiple machines.

Continue reading about the cache in the documentation
for the [caching package][pydvl.utils.caching].

-### Setting up the Memcached cache
+#### Setting up the Memcached cache

[Memcached](https://memcached.org/) is an in-memory key-value store accessible
over the network. pyDVL can use it to cache the computation of the utility function
@@ -108,7 +108,7 @@
To run memcached inside a container in daemon mode instead, use:

```
docker container run -d --rm -p 11211:11211 memcached:latest
```

-## Parallelization
+### Parallelization

pyDVL uses [joblib](https://joblib.readthedocs.io/en/latest/) for local
parallelization (within one machine) and supports using
@@ -125,7 +125,7 @@
will typically make a copy of the whole model and dataset to each worker, even
if the re-training only happens on a subset of the data. This means that you
should make sure that each worker has enough memory to handle the whole dataset.

-### Ray
+#### Ray

Please follow the instructions in Ray's documentation to set up a cluster.
Once you have a running cluster, you can use it by passing the address
diff --git a/src/pydvl/utils/caching/__init__.py b/src/pydvl/utils/caching/__init__.py
index bc72741e0..1089628bc 100644
--- a/src/pydvl/utils/caching/__init__.py
+++ b/src/pydvl/utils/caching/__init__.py
@@ -1,6 +1,7 @@
"""Caching of functions.

-pyDVL caches (memoizes) utility values to allow reusing previously computed evaluations.
+PyDVL can cache (memoize) the computation of the utility function
+and speed up some computations for data valuation.

!!! Warning
    Function evaluations are cached with a key based on the function's signature
@@ -10,67 +11,65 @@

# Configuration

-Memoization is disabled by default but can be enabled easily,
+Caching is disabled by default but can be enabled easily,
see [Setting up the cache](#setting-up-the-cache).
When enabled, it will be added to any callable used to construct a
-[Utility][pydvl.utils.utility.Utility] (done with the decorator [@memcached][pydvl.utils.caching.memcached]).
+[Utility][pydvl.utils.utility.Utility] (done with the `wrap` method of
+[CacheBackend][pydvl.utils.caching.base.CacheBackend]).
Depending on the nature of the utility you might want to
enable the computation of a running average of function values, see
[Usage with stochastic functions](#usage-with-stochastic-functions).
-You can see all configuration options under [MemcachedConfig][pydvl.utils.config.MemcachedConfig].
+You can see all configuration options under
+[CachedFuncConfig][pydvl.utils.caching.config.CachedFuncConfig].

-## Default configuration
+# Supported Backends

-```python
-default_config = dict(
-    server=('localhost', 11211),
-    connect_timeout=1.0,
-    timeout=0.1,
-    # IMPORTANT! Disable small packet consolidation:
-    no_delay=True,
-    serde=serde.PickleSerde(pickle_version=PICKLE_VERSION)
-)
-```
+pyDVL supports 3 different caching backends:

-# Supported Backends
+- [InMemoryCacheBackend][pydvl.utils.caching.memory.InMemoryCacheBackend]:
+  an in-memory cache backend that uses a dictionary to store and retrieve
+  cached values. This is used to share cached values between threads
+  in a single process.
+- [DiskCacheBackend][pydvl.utils.caching.disk.DiskCacheBackend]:
+  a disk-based cache backend that uses pickled values written to and read from disk.
+  This is used to share cached values between processes on a single machine.
+- [MemcachedCacheBackend][pydvl.utils.caching.memcached.MemcachedCacheBackend]:
+  a [Memcached](https://memcached.org/)-based cache backend that uses pickled values written to
+  and read from a Memcached server. This is used to share cached values
+  between processes across multiple machines.

-- [InMemoryCacheBackend][]
-- [DiskCacheBackend][]
-- [MemcachedCacheBackend][]
+  **Note**: This specific backend requires optional dependencies.
+  See [[installation#extras]] for more information.

# Usage with stochastic functions

-In addition to standard memoization, the decorator
-[memcached()][pydvl.utils.caching.memcached] can compute running average and
-standard error of repeated evaluations for the same input. This can be useful
-for stochastic functions with high variance (e.g. model training for small
-sample sizes), but drastically reduces the speed benefits of memoization.
+In addition to standard memoization, the wrapped functions
+can compute running average and standard error of repeated evaluations
+for the same input. This can be useful for stochastic functions with high variance
+(e.g. model training for small sample sizes), but drastically reduces
+the speed benefits of memoization.

-This behaviour can be activated with the argument `allow_repeated_evaluations`
-to [memcached()][pydvl.utils.caching.memcached].
+This behaviour can be activated with the option
+[allow_repeated_evaluations][pydvl.utils.caching.config.CachedFuncConfig].

# Cache reuse

-When working directly with [memcached()][pydvl.utils.caching.memcached], it is
+When working directly with [CachedFunc][pydvl.utils.caching.base.CachedFunc], it is
essential to only cache pure functions. If they have any kind of state, either
internal or external (e.g. a closure over some data that may change), then the
cache will fail to notice this and the same value will be returned.

-When a function is wrapped with [memcached()][pydvl.utils.caching.memcached] for
-memoization, its signature (input and output names) and code are used as a key
-for the cache. Alternatively you can pass a custom value to be used as key with
-
-```python
-cached_fun = memcached(**asdict(cache_options))(fun, signature=custom_signature)
-```
+When a function is wrapped with [CachedFunc][pydvl.utils.caching.base.CachedFunc]
+for memoization, its signature (input and output names) and code are used as a key
+for the cache.
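+For illustration, a minimal sketch of the new API (the function and backend
+here are only placeholders; whether the second call is a hit depends on the
+configured `time_threshold`):
+
+```python
+from pydvl.utils.caching.memory import InMemoryCacheBackend
+
+def fun(x: int) -> float:
+    return float(x**2)
+
+cache_backend = InMemoryCacheBackend()
+cached_fun = cache_backend.wrap(fun)
+
+cached_fun(3)  # evaluated and, depending on the configuration, stored
+cached_fun(3)  # may now be served from the cache
+print(cache_backend.stats)  # hits, misses, sets, ...
+```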
If you are running experiments with the same [Utility][pydvl.utils.utility.Utility]
but different datasets, this will lead to evaluations of the utility on new data
returning old values because utilities only use sample indices as arguments (so
there is no way to tell the difference between '1' for dataset A and '1' for
dataset B from the point of view of the cache). One solution is to empty the
-cache between runs, but the preferred one is to **use a different Utility
-object for each dataset**.
+cache between runs by calling the `clear` method of the cache backend instance,
+but the preferred one is to **use a different Utility object for each dataset**.

# Unexpected cache misses

@@ -79,7 +78,7 @@
run across multiple processes and some reporting arguments are added (like a
`job_id` for logging purposes), these will be part of the signature and make the
functions distinct to the eyes of the cache. This can be avoided with the use of
-[ignore_args][pydvl.utils.config.MemcachedConfig] in the configuration.
+the [ignore_args][pydvl.utils.caching.config.CachedFuncConfig] option in the configuration.

"""
from .base import *
from .config import *
from .disk import *
from .memcached import *
from .memory import *

From 763f6ec1fe1170a6c3cd0c3cd5a99a3c1704dd56 Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Sun, 26 Nov 2023 11:51:56 +0100
Subject: [PATCH 19/29] Fix tests
---
 tests/test_plugin.py    | 4 +---
 tests/value/conftest.py | 8 ++++++--
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/test_plugin.py b/tests/test_plugin.py
index c20e19db4..efbceeb2a 100644
--- a/tests/test_plugin.py
+++ b/tests/test_plugin.py
@@ -7,9 +7,7 @@ def test_marker_only(i):
    assert False


-@pytest.fixture(
-    scope="function", params=[0, pytest.param(1, marks=pytest.mark.xfail), 2]
-)
+@pytest.fixture(scope="function", params=[0, pytest.param(1, marks=pytest.mark.xfail)])
def data(request):
    yield request.param

diff --git a/tests/value/conftest.py b/tests/value/conftest.py
index 19e84a878..0e3c48d29 100644
--- a/tests/value/conftest.py
+++ b/tests/value/conftest.py
@@ -117,8 +117,12 @@ def linear_shapley(cache, linear_dataset, scorer, n_jobs):
    args_hash = cache.hash_arguments(linear_dataset, scorer, n_jobs)
    u_cache_key = f"linear_shapley_u_{args_hash}"
    exact_values_cache_key = f"linear_shapley_exact_values_{args_hash}"
-    u = cache.get(u_cache_key, None)
-    exact_values = cache.get(exact_values_cache_key, None)
+    try:
+        u = cache.get(u_cache_key, None)
+        exact_values = cache.get(exact_values_cache_key, None)
+    except Exception:
+        cache.clear_cache(cache._cachedir)
+        raise

    if u is None:
        u = Utility(
From bc92356809d92ad46c9fdbe0e686bb3cd86e7c1e Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Sat, 9 Dec 2023 20:35:15 +0100
Subject: [PATCH 20/29] Add test for case when pymemcache is not installed
---
 tests/utils/test_caching.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py
index 8727357f5..0f2d63709 100644
--- a/tests/utils/test_caching.py
+++ b/tests/utils/test_caching.py
@@ -140,6 +140,12 @@ def test_single_job(cache_backend):
    assert hits_after > hits_before


+def test_without_pymemcache(mocker):
+    mocker.patch("pydvl.utils.caching.memcached.PYMEMCACHE_INSTALLED", False)
+    with pytest.raises(ModuleNotFoundError):
+        MemcachedCacheBackend()
+
+
def test_memcached_failed_connection():
    from pydvl.utils import MemcachedClientConfig

From bb69b776aeac32a13b9d519b60aedd48fb60501d Mon Sep 17 00:00:00 2001
From: Anes Benmerzoug
Date: Sat, 9 Dec 2023 20:40:14 +0100
Subject: [PATCH 21/29] Add test for cache backend serialization
---
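Backends are expected to survive pickling, e.g. so that a Utility holding one
can be shipped to worker processes. A rough sketch of the guarantee the new
test exercises (the value and backend mirror the test below):

```python
import pickle

from pydvl.utils.caching.memory import InMemoryCacheBackend

cache_backend = InMemoryCacheBackend()
cache_backend.set("key", 16.8)

# A deserialized copy must still return previously cached values.
restored = pickle.loads(pickle.dumps(cache_backend))
assert restored.get("key") == 16.8
```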
tests/utils/test_caching.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index 0f2d63709..2f8783b1e 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -1,4 +1,5 @@ import logging +import pickle import tempfile from time import sleep, time from typing import Optional @@ -128,6 +129,17 @@ def test_cached_func_hash_arguments_of_method(): assert hash1 == hash2 +def test_cache_backend_serialization(cache_backend): + value = 16.8 + cache_backend.set("key", value) + deserialized_cache_backend = pickle.loads(pickle.dumps(cache_backend)) + assert deserialized_cache_backend.get("key") == value + if isinstance(cache_backend, InMemoryCacheBackend): + assert cache_backend.cached_values == deserialized_cache_backend.cached_values + elif isinstance(cache_backend, DiskCacheBackend): + assert cache_backend.cache_dir == deserialized_cache_backend.cache_dir + + def test_single_job(cache_backend): wrapped_foo = cache_backend.wrap(foo) From 486b43a60803f8a72a9f7e86aa8d0e58b3a31768 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Sat, 9 Dec 2023 20:44:20 +0100 Subject: [PATCH 22/29] Use newly created temporary directory for DiskCacheBackend --- src/pydvl/utils/caching/disk.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/pydvl/utils/caching/disk.py b/src/pydvl/utils/caching/disk.py index 71ac56c2c..c75513dec 100644 --- a/src/pydvl/utils/caching/disk.py +++ b/src/pydvl/utils/caching/disk.py @@ -1,5 +1,6 @@ import os import shutil +import tempfile from pathlib import Path from typing import Any, Optional, Union @@ -11,8 +12,6 @@ PICKLE_VERSION = 5 # python >= 3.8 -DEFAULT_CACHE_DIR = Path().home() / ".pydvl_cache/disk" - class DiskCacheBackend(CacheBackend): """Disk cache backend that stores results in files. @@ -65,15 +64,18 @@ class DiskCacheBackend(CacheBackend): def __init__( self, - cache_dir: Union[os.PathLike, str] = DEFAULT_CACHE_DIR, + cache_dir: Optional[Union[os.PathLike, str]] = None, ) -> None: """Initialize the disk cache backend. Args: cache_dir: Base directory for cache storage. - By default, this is set to `~/.pydvl_cache/disk` + If not provided, this defaults to a newly created + temporary directory. 
""" super().__init__() + if cache_dir is None: + cache_dir = tempfile.mkdtemp(prefix="pydvl") self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(exist_ok=True, parents=True) From ef8bd3300cdd2d9b155ec7a5d545e96c03d2df71 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Sat, 9 Dec 2023 20:50:48 +0100 Subject: [PATCH 23/29] Set default value of cached_func_options to None --- src/pydvl/utils/utility.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py index d8cd7ae08..dbf5a1ad7 100644 --- a/src/pydvl/utils/utility.py +++ b/src/pydvl/utils/utility.py @@ -136,7 +136,7 @@ def __init__( catch_errors: bool = True, show_warnings: bool = False, cache_backend: Optional[CacheBackend] = None, - cached_func_options: CachedFuncConfig = CachedFuncConfig(), + cached_func_options: Optional[CachedFuncConfig] = None, clone_before_fit: bool = True, ): self.model = self._clone_model(model) @@ -150,6 +150,8 @@ def __init__( self.catch_errors = catch_errors self.show_warnings = show_warnings self.cache = cache_backend + if cached_func_options is None: + cached_func_options = CachedFuncConfig() self.cached_func_options = cached_func_options self.clone_before_fit = clone_before_fit self._initialize_utility_wrapper() From 73e8e54b4656b69330f011b245627035ad6390ba Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Wed, 13 Dec 2023 16:32:49 +0100 Subject: [PATCH 24/29] Set backend time_threshold to 0.3 --- src/pydvl/utils/caching/config.py | 2 +- tests/utils/test_caching.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pydvl/utils/caching/config.py b/src/pydvl/utils/caching/config.py index 14f7cb761..62e96c2aa 100644 --- a/src/pydvl/utils/caching/config.py +++ b/src/pydvl/utils/caching/config.py @@ -34,7 +34,7 @@ class CachedFuncConfig: """ ignore_args: Collection[str] = field(default_factory=list) - time_threshold: float = 0 + time_threshold: float = 0.3 allow_repeated_evaluations: bool = False rtol_stderr: float = 0.1 min_repetitions: int = 3 diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index 2f8783b1e..4f488988f 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -141,7 +141,8 @@ def test_cache_backend_serialization(cache_backend): def test_single_job(cache_backend): - wrapped_foo = cache_backend.wrap(foo) + cached_func_config = CachedFuncConfig(time_threshold=0.0) + wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) n = 1000 wrapped_foo(np.arange(n)) @@ -186,6 +187,7 @@ def test_cache_time_threshold(cache_backend): def test_cache_ignore_args(cache_backend): # Note that we typically do NOT want to ignore run_id cached_func_config = CachedFuncConfig( + time_threshold=0.0, ignore_args=["job_id"], ) wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) @@ -235,6 +237,7 @@ def test_parallel_jobs(cache_backend, parallel_config): def test_repeated_training(cache_backend, worker_id: str): cached_func_config = CachedFuncConfig( + time_threshold=0.0, allow_repeated_evaluations=True, rtol_stderr=0.01, ) @@ -255,6 +258,7 @@ def test_repeated_training(cache_backend, worker_id: str): def test_faster_with_repeated_training(cache_backend, worker_id: str): cached_func_config = CachedFuncConfig( + time_threshold=0.0, allow_repeated_evaluations=True, rtol_stderr=0.1, ) @@ -300,6 +304,7 @@ def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float: # Note that we typically do NOT want to ignore run_id 
cached_func_config = CachedFuncConfig( + time_threshold=0.0, allow_repeated_evaluations=True, rtol_stderr=0.01, ignore_args=["job_id", "run_id"], From 287ca1d8b93fa230f562870421dc08ff3f4c6948 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Wed, 13 Dec 2023 16:56:12 +0100 Subject: [PATCH 25/29] Fix test of utility with cache --- tests/utils/test_utility.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_utility.py b/tests/utils/test_utility.py index c11be1c5f..5d2fd6504 100644 --- a/tests/utils/test_utility.py +++ b/tests/utils/test_utility.py @@ -7,7 +7,7 @@ from sklearn.linear_model import LinearRegression from pydvl.utils import DataUtilityLearning, Scorer, Utility, powerset -from pydvl.utils.caching import InMemoryCacheBackend +from pydvl.utils.caching import CachedFuncConfig, InMemoryCacheBackend @pytest.mark.parametrize("show_warnings", [False, True]) @@ -65,6 +65,7 @@ def test_utility_with_cache(linear_dataset): data=linear_dataset, scorer=Scorer("r2"), cache_backend=InMemoryCacheBackend(), + cached_func_options=CachedFuncConfig(time_threshold=0.0), ) subsets = list(powerset(u.data.indices)) From 196b3105cad48df2a1447fa3fb3371c870d7cf1c Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Thu, 14 Dec 2023 20:38:59 +0100 Subject: [PATCH 26/29] Add hash_prefix parameter to CachedFuncConfig, use it in utility --- src/pydvl/utils/caching/base.py | 32 +++++++++++++----------- src/pydvl/utils/caching/config.py | 4 ++- src/pydvl/utils/utility.py | 6 +++-- tests/utils/test_caching.py | 14 +++++------ tests/utils/test_utility.py | 41 +++++++++++++++++++++++++++++++ 5 files changed, 72 insertions(+), 25 deletions(-) diff --git a/src/pydvl/utils/caching/base.py b/src/pydvl/utils/caching/base.py index f040f48bb..f301003e6 100644 --- a/src/pydvl/utils/caching/base.py +++ b/src/pydvl/utils/caching/base.py @@ -3,7 +3,7 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Collection, Dict, Optional, Tuple, TypeVar, cast +from typing import Any, Callable, Collection, Dict, Optional, Tuple, cast from joblib import hashing from joblib.func_inspect import filter_args @@ -70,13 +70,13 @@ def wrap( self, func: Callable, *, - cached_func_config: CachedFuncConfig = CachedFuncConfig(), + config: Optional[CachedFuncConfig] = None, ) -> "CachedFunc": """Wraps a function to cache its results. Args: func: The function to wrap. - cached_func_config: Optional caching options for the wrapped function. + config: Optional caching options for the wrapped function. Returns: The wrapped cached function. @@ -84,7 +84,7 @@ def wrap( return CachedFunc( func, cache_backend=self, - cached_func_options=cached_func_config, + config=config, ) @abstractmethod @@ -136,7 +136,7 @@ class CachedFunc: func: Callable to wrap. cache_backend: Instance of CacheBackendBase that handles setting and getting values. - cached_func_options: Configuration for wrapped function. + config: Configuration for wrapped function. 
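+
+    Example:
+        A minimal sketch of direct use (the backend and options shown are
+        illustrative, not required):
+
+        ```python
+        from pydvl.utils.caching.base import CachedFunc
+        from pydvl.utils.caching.config import CachedFuncConfig
+        from pydvl.utils.caching.memory import InMemoryCacheBackend
+
+        def fun(x: int, job_id: int = 0) -> float:
+            return float(x)
+
+        cached_fun = CachedFunc(
+            fun,
+            cache_backend=InMemoryCacheBackend(),
+            # Store results regardless of evaluation time and exclude the
+            # reporting argument from the cache key.
+            config=CachedFuncConfig(time_threshold=0.0, ignore_args=["job_id"]),
+        )
+        cached_fun(1, job_id=42)  # evaluated and stored
+        cached_fun(1, job_id=43)  # cache hit: job_id is ignored in the key
+        ```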
""" def __init__( @@ -144,11 +144,13 @@ def __init__( func: Callable[..., float], *, cache_backend: CacheBackend, - cached_func_options: CachedFuncConfig = CachedFuncConfig(), + config: Optional[CachedFuncConfig] = None, ) -> None: self.func = func self.cache_backend = cache_backend - self.cached_func_options = cached_func_options + if config is None: + config = CachedFuncConfig() + self.config = config self.__doc__ = f"A wrapper around {func.__name__}() with caching enabled.\n" + ( CachedFunc.__doc__ or "" @@ -193,18 +195,17 @@ def _cached_call(self, args, kwargs) -> float: value, duration = self._force_call(args, kwargs) result = CacheResult(value) if ( - duration >= self.cached_func_options.time_threshold - or self.cached_func_options.allow_repeated_evaluations + duration >= self.config.time_threshold + or self.config.allow_repeated_evaluations ): self.cache_backend.set(key, result) else: result = cached_result - if self.cached_func_options.allow_repeated_evaluations: + if self.config.allow_repeated_evaluations: error_on_average = (result.variance / result.count) ** (1 / 2) if ( - error_on_average - > self.cached_func_options.rtol_stderr * result.value - or result.count <= self.cached_func_options.min_repetitions + error_on_average > self.config.rtol_stderr * result.value + or result.count <= self.config.min_repetitions ): new_value, _ = self._force_call(args, kwargs) new_avg, new_var = running_moments( @@ -223,9 +224,10 @@ def _get_cache_key(self, *args, **kwargs) -> str: """Returns a string key used to identify the function and input parameter hash.""" func_hash = self._hash_function(self.func) argument_hash = self._hash_arguments( - self.func, self.cached_func_options.ignore_args, args, kwargs + self.func, self.config.ignore_args, args, kwargs ) - key = self.cache_backend.combine_hashes(func_hash, argument_hash) + hashes = list(filter(bool, [self.config.hash_prefix, func_hash, argument_hash])) + key = self.cache_backend.combine_hashes(*hashes) return key @staticmethod diff --git a/src/pydvl/utils/caching/config.py b/src/pydvl/utils/caching/config.py index 62e96c2aa..d44850274 100644 --- a/src/pydvl/utils/caching/config.py +++ b/src/pydvl/utils/caching/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Collection +from typing import Collection, Optional __all__ = ["CachedFuncConfig"] @@ -13,6 +13,7 @@ class CachedFuncConfig: of a [Utility][pydvl.utils.utility.Utility]. Args: + hash_prefix: Optional string prefix that be prepended to the cache key. ignore_args: Do not take these keyword arguments into account when hashing the wrapped function for usage as key. This allows sharing the cache among different jobs for the same experiment run if @@ -33,6 +34,7 @@ class CachedFuncConfig: this number to higher values to reduce variance. """ + hash_prefix: Optional[str] = None ignore_args: Collection[str] = field(default_factory=list) time_threshold: float = 0.3 allow_repeated_evaluations: bool = False diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py index dbf5a1ad7..7374a04cf 100644 --- a/src/pydvl/utils/utility.py +++ b/src/pydvl/utils/utility.py @@ -28,6 +28,7 @@ from typing import Dict, FrozenSet, Iterable, Optional, Tuple, Union, cast import numpy as np +from joblib import hashing from numpy.typing import NDArray from sklearn.base import clone from sklearn.metrics import check_scoring @@ -147,13 +148,14 @@ def __init__( self.default_score = scorer.default if scorer is not None else default_score # TODO: auto-fill from known scorers ? 
self.score_range = scorer.range if scorer is not None else np.array(score_range) + self.clone_before_fit = clone_before_fit self.catch_errors = catch_errors self.show_warnings = show_warnings self.cache = cache_backend if cached_func_options is None: cached_func_options = CachedFuncConfig() + cached_func_options.hash_prefix = hashing.hash((model, data, scorer)) self.cached_func_options = cached_func_options - self.clone_before_fit = clone_before_fit self._initialize_utility_wrapper() # FIXME: can't modify docstring of methods. Instead, I could use a @@ -163,7 +165,7 @@ def __init__( def _initialize_utility_wrapper(self): if self.cache is not None: self._utility_wrapper = self.cache.wrap( - self._utility, cached_func_config=self.cached_func_options + self._utility, config=self.cached_func_options ) else: self._utility_wrapper = self._utility diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index 4f488988f..b02949e63 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -142,7 +142,7 @@ def test_cache_backend_serialization(cache_backend): def test_single_job(cache_backend): cached_func_config = CachedFuncConfig(time_threshold=0.0) - wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) + wrapped_foo = cache_backend.wrap(foo, config=cached_func_config) n = 1000 wrapped_foo(np.arange(n)) @@ -169,7 +169,7 @@ def test_memcached_failed_connection(): def test_cache_time_threshold(cache_backend): cached_func_config = CachedFuncConfig(time_threshold=1.0) - wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) + wrapped_foo = cache_backend.wrap(foo, config=cached_func_config) n = 1000 indices = np.arange(n) @@ -190,7 +190,7 @@ def test_cache_ignore_args(cache_backend): time_threshold=0.0, ignore_args=["job_id"], ) - wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) + wrapped_foo = cache_backend.wrap(foo, config=cached_func_config) n = 1000 indices = np.arange(n) @@ -212,7 +212,7 @@ def test_parallel_jobs(cache_backend, parallel_config): cached_func_config = CachedFuncConfig( ignore_args=["job_id", "run_id"], ) - wrapped_foo = cache_backend.wrap(foo, cached_func_config=cached_func_config) + wrapped_foo = cache_backend.wrap(foo, config=cached_func_config) n = 1234 n_runs = 10 @@ -243,7 +243,7 @@ def test_repeated_training(cache_backend, worker_id: str): ) wrapped_foo = cache_backend.wrap( foo_with_random, - cached_func_config=cached_func_config, + config=cached_func_config, ) n = 7 @@ -264,7 +264,7 @@ def test_faster_with_repeated_training(cache_backend, worker_id: str): ) wrapped_foo = cache_backend.wrap( foo_with_random_and_sleep, - cached_func_config=cached_func_config, + config=cached_func_config, ) n = 3 @@ -311,7 +311,7 @@ def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float: ) wrapped_map_func = cache_backend.wrap( map_func, - cached_func_config=cached_func_config, + config=cached_func_config, ) def reduce_func(chunks: NDArray[np.float_]) -> float: diff --git a/tests/utils/test_utility.py b/tests/utils/test_utility.py index 5d2fd6504..335b0c136 100644 --- a/tests/utils/test_utility.py +++ b/tests/utils/test_utility.py @@ -79,6 +79,47 @@ def test_utility_with_cache(linear_dataset): assert u._utility_wrapper.stats.hits == len(subsets), u._utility_wrapper.stats +@pytest.mark.parametrize("a, b, num_points", [(2, 0, 8)]) +def test_different_utility_with_same_cache(linear_dataset): + cache_backend = InMemoryCacheBackend() + u1 = Utility( + 
model=LinearRegression(), + data=linear_dataset, + scorer=Scorer("r2"), + cache_backend=cache_backend, + cached_func_options=CachedFuncConfig(time_threshold=0.0), + ) + u2 = Utility( + model=LinearRegression(), + data=linear_dataset, + scorer=Scorer("max_error"), + cache_backend=cache_backend, + cached_func_options=CachedFuncConfig(time_threshold=0.0), + ) + + subset = u1.data.indices + # Call first utility with empty cache + # We expect a cache miss + u1(subset) + assert cache_backend.stats.hits == 0 + assert cache_backend.stats.misses == 1 + assert cache_backend.stats.sets == 1 + + # Call first utility again + # We expect a cache hit + u1(subset) + assert cache_backend.stats.hits == 1 + assert cache_backend.stats.misses == 1 + assert cache_backend.stats.sets == 1 + + # Call second utility + # We expect a cache miss + u2(subset) + assert cache_backend.stats.hits == 1 + assert cache_backend.stats.misses == 2 + assert cache_backend.stats.sets == 2 + + @pytest.mark.parametrize("a, b, num_points", [(2, 0, 8)]) @pytest.mark.parametrize("use_cache", [False, True]) def test_utility_serialization(linear_dataset, use_cache): From 02da3429c220ff56e78f93288668f4944ec6da52 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Thu, 14 Dec 2023 21:11:57 +0100 Subject: [PATCH 27/29] Please mypy --- src/pydvl/utils/caching/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pydvl/utils/caching/base.py b/src/pydvl/utils/caching/base.py index f301003e6..b3574644a 100644 --- a/src/pydvl/utils/caching/base.py +++ b/src/pydvl/utils/caching/base.py @@ -226,7 +226,9 @@ def _get_cache_key(self, *args, **kwargs) -> str: argument_hash = self._hash_arguments( self.func, self.config.ignore_args, args, kwargs ) - hashes = list(filter(bool, [self.config.hash_prefix, func_hash, argument_hash])) + hashes = [func_hash, argument_hash] + if self.config.hash_prefix is not None: + hashes.insert(0, self.config.hash_prefix) key = self.cache_backend.combine_hashes(*hashes) return key From a2662f23fa6e9aa7d6201a24de769b1b92bb9803 Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Thu, 14 Dec 2023 22:04:25 +0100 Subject: [PATCH 28/29] Use builtin hash to compute hash_prefix --- src/pydvl/utils/utility.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py index 7374a04cf..3b27d6e48 100644 --- a/src/pydvl/utils/utility.py +++ b/src/pydvl/utils/utility.py @@ -23,12 +23,12 @@ learning](https://arxiv.org/abs/2107.06336). arXiv preprint arXiv:2107.06336. 
""" +import hashlib import logging import warnings from typing import Dict, FrozenSet, Iterable, Optional, Tuple, Union, cast import numpy as np -from joblib import hashing from numpy.typing import NDArray from sklearn.base import clone from sklearn.metrics import check_scoring @@ -154,7 +154,8 @@ def __init__( self.cache = cache_backend if cached_func_options is None: cached_func_options = CachedFuncConfig() - cached_func_options.hash_prefix = hashing.hash((model, data, scorer)) + # TODO: Find a better way to do this + cached_func_options.hash_prefix = str(hash((model, data, scorer))) self.cached_func_options = cached_func_options self._initialize_utility_wrapper() From 471a64eab74ee41666fc13c60bd79d17ded1fca7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristof=20Schr=C3=B6der?= Date: Mon, 18 Dec 2023 12:49:05 +0100 Subject: [PATCH 29/29] Add suggestions from review session --- src/pydvl/utils/caching/base.py | 7 +++++++ src/pydvl/utils/caching/config.py | 5 +++-- src/pydvl/utils/caching/disk.py | 6 +++--- src/pydvl/utils/caching/memory.py | 4 ++-- src/pydvl/utils/utility.py | 24 ++++++++++++++++++------ 5 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/pydvl/utils/caching/base.py b/src/pydvl/utils/caching/base.py index b3574644a..85c46326b 100644 --- a/src/pydvl/utils/caching/base.py +++ b/src/pydvl/utils/caching/base.py @@ -132,6 +132,13 @@ class CachedFunc: This class is heavily inspired from that of [joblib.memory.MemorizedFunc][]. + This class caches calls to the wrapped callable by generating a hash key + based on the wrapped callable's code, the arguments passed to it and the optional + hash_prefix. + + !!! Warning + This class only works with hashable arguments to the wrapped callable. + Args: func: Callable to wrap. cache_backend: Instance of CacheBackendBase that handles diff --git a/src/pydvl/utils/caching/config.py b/src/pydvl/utils/caching/config.py index d44850274..c110ffce7 100644 --- a/src/pydvl/utils/caching/config.py +++ b/src/pydvl/utils/caching/config.py @@ -14,6 +14,7 @@ class CachedFuncConfig: Args: hash_prefix: Optional string prefix that be prepended to the cache key. + This can be provided in order to guarantee cache reuse across runs. ignore_args: Do not take these keyword arguments into account when hashing the wrapped function for usage as key. This allows sharing the cache among different jobs for the same experiment run if @@ -26,8 +27,8 @@ class CachedFuncConfig: running standard deviation of the mean stabilizes below `rtol_stderr * mean`. rtol_stderr: relative tolerance for repeated evaluations. More precisely, - [memcached()][pydvl.utils.caching.memcached] will stop evaluating the function once the - standard deviation of the mean is smaller than `rtol_stderr * mean`. + [memcached()][pydvl.utils.caching.memcached] will stop evaluating the function + once the standard deviation of the mean is smaller than `rtol_stderr * mean`. min_repetitions: minimum number of times that a function evaluation on the same arguments is repeated before returning cached values. Useful for stochastic functions only. If the model training is very noisy, set diff --git a/src/pydvl/utils/caching/disk.py b/src/pydvl/utils/caching/disk.py index c75513dec..06250a450 100644 --- a/src/pydvl/utils/caching/disk.py +++ b/src/pydvl/utils/caching/disk.py @@ -29,7 +29,7 @@ class DiskCacheBackend(CacheBackend): ??? 
Examples
        ``` pycon
        >>> from pydvl.utils.caching.disk import DiskCacheBackend
-       >>> cache_backend = DiskCacheBackend(cache_dir="/tmp/pydvl_disk_cache")
+       >>> cache_backend = DiskCacheBackend()
        >>> cache_backend.clear()
        >>> value = 42
        >>> cache_backend.set("key", value)
@@ -38,8 +38,8 @@ class DiskCacheBackend(CacheBackend):
        ```

        ``` pycon
-       >>> from pydvl.utils.caching.memcached import MemcachedCacheBackend
-       >>> cache_backend = MemcachedCacheBackend()
+       >>> from pydvl.utils.caching.disk import DiskCacheBackend
+       >>> cache_backend = DiskCacheBackend()
        >>> cache_backend.clear()
        >>> value = 42
        >>> def foo(x: int):
diff --git a/src/pydvl/utils/caching/memory.py b/src/pydvl/utils/caching/memory.py
index 8843e7083..270d3ce1a 100644
--- a/src/pydvl/utils/caching/memory.py
+++ b/src/pydvl/utils/caching/memory.py
@@ -30,8 +30,8 @@ class InMemoryCacheBackend(CacheBackend):
    ```

    ``` pycon
-    >>> from pydvl.utils.caching.memcached import MemcachedCacheBackend
-    >>> cache_backend = MemcachedCacheBackend()
+    >>> from pydvl.utils.caching.memory import InMemoryCacheBackend
+    >>> cache_backend = InMemoryCacheBackend()
    >>> cache_backend.clear()
    >>> value = 42
    >>> def foo(x: int):
diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py
index 3b27d6e48..b975c0ff2 100644
--- a/src/pydvl/utils/utility.py
+++ b/src/pydvl/utils/utility.py
@@ -120,6 +120,20 @@ class Utility:
    0.9
    ```

+    With caching enabled:
+
+    ```pycon
+    >>> from pydvl.utils import Utility, Dataset
+    >>> from pydvl.utils.caching.memory import InMemoryCacheBackend
+    >>> from sklearn.linear_model import LogisticRegression
+    >>> from sklearn.datasets import load_iris
+    >>> dataset = Dataset.from_sklearn(load_iris(), random_state=16)
+    >>> cache_backend = InMemoryCacheBackend()
+    >>> u = Utility(LogisticRegression(random_state=16), dataset, cache_backend=cache_backend)
+    >>> u(dataset.indices)
+    0.9
+    ```
+
    """

    model: SupervisedModel
@@ -154,15 +168,13 @@ def __init__(
        self.cache = cache_backend
        if cached_func_options is None:
            cached_func_options = CachedFuncConfig()
-        # TODO: Find a better way to do this
-        cached_func_options.hash_prefix = str(hash((model, data, scorer)))
+        # TODO: Find a better way to do this.
+        if cached_func_options.hash_prefix is None:
+            # FIX: This does not handle reusing the same prefix across runs.
+            cached_func_options.hash_prefix = str(hash((model, data, scorer)))
        self.cached_func_options = cached_func_options
        self._initialize_utility_wrapper()

-        # FIXME: can't modify docstring of methods. Instead, I could use a
-        #  factory which creates the class on the fly with the right doc.
-        # self.__call__.__doc__ = self._utility_wrapper.__doc__
-
    def _initialize_utility_wrapper(self):
        if self.cache is not None:
            self._utility_wrapper = self.cache.wrap(
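
Taken together, the pieces above make cache entries reusable across runs when
the user supplies an explicit prefix, which is the supported way around the
limitation flagged in the FIX comment. A sketch (the directory name and prefix
are illustrative, not part of the library):

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

from pydvl.utils import Dataset, Utility
from pydvl.utils.caching.config import CachedFuncConfig
from pydvl.utils.caching.disk import DiskCacheBackend

dataset = Dataset.from_sklearn(load_iris(), random_state=16)

# An explicit, user-chosen prefix yields stable cache keys across interpreter
# runs, unlike the fallback based on the builtin, per-process hash().
u = Utility(
    LogisticRegression(random_state=16),
    dataset,
    cache_backend=DiskCacheBackend("pydvl_cache"),
    cached_func_options=CachedFuncConfig(hash_prefix="iris-logreg"),
)
u(dataset.indices)  # stored on disk; a rerun with the same prefix reuses it
```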