diff --git a/CHANGELOG.md b/CHANGELOG.md
index 967555de2..a0d0f51a8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,10 +1,11 @@
 # Changelog
-
 ## Unreleased
 
 ### Added
 
+- New cache backends: InMemoryCacheBackend and DiskCacheBackend
+  [PR #458](https://github.com/aai-institute/pyDVL/pull/458)
 - New influence function interface `InfluenceFunctionModel`
 - Data parallel computation with `DaskInfluenceCalculator`
   [PR #26](https://github.com/aai-institute/pyDVL/issues/26)
@@ -15,6 +16,8 @@
 
 ### Changed
 
+- Refactor and simplify caching implementation
+  [PR #458](https://github.com/aai-institute/pyDVL/pull/458)
 - Simplify display of computation progress
   [PR #466](https://github.com/aai-institute/pyDVL/pull/466)
 - Improve readme and explain better the examples
diff --git a/docs/getting-started/first-steps.md b/docs/getting-started/first-steps.md
index a86cf6307..403724362 100644
--- a/docs/getting-started/first-steps.md
+++ b/docs/getting-started/first-steps.md
@@ -9,8 +9,7 @@ alias:
 !!! Warning
     Make sure you have read [[installation]] before using the library.
-    In particular read about how caching and parallelization work,
-    since they might require additional setup.
+    In particular read about which extra dependencies you may need.
 
 ## Main concepts
 
@@ -23,7 +22,6 @@ should be enough to get you started.
    computation and related methods.
  * [[influence-values]] for instructions on how to compute influence
    functions.
-
 ## Running the examples
 
 If you are somewhat familiar with the concepts of data valuation, you can start
 by browsing our worked-out examples illustrating pyDVL's capabilities either:
 
 have to install jupyter first manually since it's not a dependency of the
 library.
 
-# Advanced usage
+## Advanced usage
 
-Besides the do's and don'ts of data valuation itself, which are the subject of
+Besides the dos and don'ts of data valuation itself, which are the subject of
 the examples and the documentation of each method, there are two main things to
 keep in mind when using pyDVL.
 
-## Caching
-
-pyDVL uses [memcached](https://memcached.org/) to cache the computation of the
-utility function and speed up some computations (see the [installation
-guide](installation.md/#setting-up-the-cache)).
+### Caching
 
-Caching of the utility function is disabled by default. When it is enabled it
-takes into account the data indices passed as argument and the utility function
-wrapped into the [Utility][pydvl.utils.utility.Utility] object. This means that
+PyDVL can cache (memoize) the computation of the utility function
+and speed up some computations for data valuation.
+It is, however, disabled by default.
+When it is enabled it takes into account the data indices passed as argument
+and the utility function wrapped into the
+[Utility][pydvl.utils.utility.Utility] object. This means that
 care must be taken when reusing the same utility function with different data,
-see the documentation for the [caching module][pydvl.utils.caching] for more
+see the documentation for the [caching package][pydvl.utils.caching] for more
 information.
 
 In general, caching won't play a major role in the computation of Shapley values
@@ -61,20 +58,65 @@ the same utility function computation, is very low.
 However, it can be very useful when comparing methods that use the same utility
 function, or when running multiple experiments with the same data.
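+
+For example, to enable caching of utility evaluations you pass a cache backend
+to the [Utility][pydvl.utils.utility.Utility] (a minimal sketch mirroring the
+example in the `Utility` docstring; it uses the in-memory backend introduced
+below):
+
+```python
+from sklearn.datasets import load_iris
+from sklearn.linear_model import LogisticRegression
+
+from pydvl.utils import Dataset, Utility
+from pydvl.utils.caching.memory import InMemoryCacheBackend
+
+dataset = Dataset.from_sklearn(load_iris(), random_state=16)
+# Repeated evaluations on the same index sets will now hit the cache
+u = Utility(
+    LogisticRegression(random_state=16),
+    dataset,
+    cache_backend=InMemoryCacheBackend(),
+)
+```
+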
+pyDVL supports three different caching backends:
+
+- [InMemoryCacheBackend][pydvl.utils.caching.memory.InMemoryCacheBackend]:
+  an in-memory cache backend that uses a dictionary to store and retrieve
+  cached values. This is used to share cached values between threads
+  in a single process.
+- [DiskCacheBackend][pydvl.utils.caching.disk.DiskCacheBackend]:
+  a disk-based cache backend that uses pickled values written to and read from disk.
+  This is used to share cached values between processes on a single machine.
+- [MemcachedCacheBackend][pydvl.utils.caching.memcached.MemcachedCacheBackend]:
+  a [Memcached](https://memcached.org/)-based cache backend that uses pickled values written to
+  and read from a Memcached server. This is used to share cached values
+  between processes across multiple machines.
+
+  **Note**: this backend requires optional dependencies.
+  See [[installation#extras]] for more information.
+
 !!! tip "When is the cache really necessary?"
     Crucially, semi-value computations with the
     [PermutationSampler][pydvl.value.sampler.PermutationSampler] require caching
     to be enabled, or they will take twice as long as the direct implementation
     in [compute_shapley_values][pydvl.value.shapley.compute_shapley_values].
 
-## Parallelization
+!!! tip "Using the cache"
+    Continue reading about the cache in the documentation
+    for the [caching package][pydvl.utils.caching].
+
+#### Setting up the Memcached cache
+
+[Memcached](https://memcached.org/) is an in-memory key-value store accessible
+over the network. pyDVL can use it to cache the computation of the utility function
+and speed up some computations (in particular, semi-value computations with the
+[PermutationSampler][pydvl.value.sampler.PermutationSampler] but other methods
+may benefit as well).
+
+You can either install it as a package or run it inside a docker container (the
+simplest option). For installation instructions, refer to the [Getting
+started](https://github.com/memcached/memcached/wiki#getting-started) section in
+memcached's wiki. Then you can run it with:
 
-pyDVL supports [joblib](https://joblib.readthedocs.io/en/latest/) for local
-parallelization (within one machine) and [ray](https://ray.io) for distributed
-parallelization (across multiple machines).
+```shell
+memcached -u user
+```
 
-The former works out of the box but for the latter you will need to provide a
-running cluster (or run ray in local mode).
+To run memcached inside a container in daemon mode instead, use:
+
+```shell
+docker container run -d --rm -p 11211:11211 memcached:latest
+```
+
+### Parallelization
+
+pyDVL uses [joblib](https://joblib.readthedocs.io/en/latest/) for local
+parallelization (within one machine) and supports using
+[Ray](https://ray.io) for distributed parallelization (across multiple machines).
+
+The former works out of the box but for the latter you will need to install
+additional dependencies (see [[installation#extras]])
+and to provide a running cluster (or run Ray in local mode).
 
 As of v0.7.0 pyDVL does not allow requesting resources per task sent to the
 cluster, so you will need to make sure that each worker has enough resources to
@@ -82,3 +124,16 @@ handle the tasks it receives.
 A data valuation task using game-theoretic methods will typically make a copy
 of the whole model and dataset to each worker, even if the re-training only
 happens on a subset of the data. This means that you should make sure that each
 worker has enough memory to handle the whole dataset.
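+
+If, on the other hand, you only need local parallelization, constructing a
+default [ParallelConfig][pydvl.parallel.config.ParallelConfig] is usually
+enough (a minimal sketch; `backend="joblib"` mirrors the Ray example in the
+next section):
+
+```python
+from pydvl.parallel.config import ParallelConfig
+
+# joblib spawns workers on the local machine only
+config = ParallelConfig(backend="joblib")
+```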
+
+#### Ray
+
+Please follow the instructions in Ray's documentation to set up a cluster.
+Once you have a running cluster, you can use it by passing the address
+of the head node to parallel methods via [ParallelConfig][pydvl.parallel.config.ParallelConfig].
+
+For a local Ray cluster you would use:
+
+```python
+from pydvl.parallel.config import ParallelConfig
+
+# For a remote cluster, also pass the address of the head node,
+# e.g. ParallelConfig(backend="ray", address="<head-node-address>").
+config = ParallelConfig(backend="ray")
+```
diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md
index 2d2164ada..125f81d13 100644
--- a/docs/getting-started/installation.md
+++ b/docs/getting-started/installation.md
@@ -13,32 +13,6 @@ To install the latest release use:
 pip install pyDVL
 ```
 
-To use all features of influence functions use instead:
-
-```shell
-pip install pyDVL[influence]
-```
-
-This includes a dependency on [PyTorch](https://pytorch.org/) (Version 2.0 and
-above) and thus is left out by default.
-
-In case that you have a supported version of CUDA installed (v11.2 to 11.8 as of
-this writing), you can enable eigenvalue computations for low-rank approximations
-with [CuPy](https://docs.cupy.dev/en/stable/index.html) on the GPU by using:
-
-```shell
-pip install pyDVL[cupy]
-```
-
-If you use a different version of CUDA, please install CuPy
-[manually](https://docs.cupy.dev/en/stable/install.html).
-
-In order to check the installation you can use:
-
-```shell
-python -c "import pydvl; print(pydvl.__version__)"
-```
-
 You can also install the latest development version from
 [TestPyPI](https://test.pypi.org/project/pyDVL/):
 
@@ -46,42 +20,71 @@ You can also install the latest development version from
 pip install pyDVL --index-url https://test.pypi.org/simple/
 ```
 
-## Dependencies
-
-pyDVL requires Python >= 3.8, [Memcached](https://memcached.org/) for caching
-and [Ray](https://ray.io) for parallelization in a cluster (locally it uses joblib).
-Additionally, the [Influence functions][pydvl.influence] module requires PyTorch
-(see [[installation]]).
-
-ray is used to distribute workloads across nodes in a cluster (it can be used
-locally as well, but for this we recommend joblib instead). Please follow the
-instructions in their documentation to set up the cluster. Once you have a
-running cluster, you can use it by passing the address of the head node to
-parallel methods via [ParallelConfig][pydvl.utils.parallel].
-
-## Setting up the cache
-
-[memcached](https://memcached.org/) is an in-memory key-value store accessible
-over the network. pyDVL uses it to cache the computation of the utility function
-and speed up some computations (in particular, semi-value computations with the
-[PermutationSampler][pydvl.value.sampler.PermutationSampler] but other methods
-may benefit as well).
-
-You can either install it as a package or run it inside a docker container (the
-simplest). For installation instructions, refer to the [Getting
-started](https://github.com/memcached/memcached/wiki#getting-started) section in
-memcached's wiki. Then you can run it with:
+In order to check the installation you can use:
 
 ```shell
-memcached -u user
+python -c "import pydvl; print(pydvl.__version__)"
 ```
 
-To run memcached inside a container in daemon mode instead, do:
-
-```shell
-docker container run -d --rm -p 11211:11211 memcached:latest
-```
+## Dependencies
 
-!!! tip "Using the cache"
-    Continue reading about the cache in the [First Steps](first-steps.md#caching)
-    and the documentation for the [caching module][pydvl.utils.caching].
+pyDVL requires Python >= 3.8, [numpy](https://numpy.org/),
+[scikit-learn](https://scikit-learn.org/stable/), [scipy](https://scipy.org/),
+[cvxpy](https://www.cvxpy.org/) for the Core methods,
+and [joblib](https://joblib.readthedocs.io/en/stable/)
+for local parallelization. Additionally, the [Influence functions][pydvl.influence]
+module requires PyTorch (see [[installation#extras]]).
+
+### Extras
+
+pyDVL has a few [extra](https://peps.python.org/pep-0508/#extras) dependencies
+that can be optionally installed:
+
+- `influence`:
+
+  To use all features of influence functions, install with:
+
+  ```shell
+  pip install pyDVL[influence]
+  ```
+
+  This includes a dependency on [PyTorch](https://pytorch.org/) (Version 2.0 and
+  above) and thus is left out by default.
+
+- `cupy`:
+
+  If you have a supported version of CUDA installed (v11.2 to 11.8 as of
+  this writing), you can enable eigenvalue computations for low-rank approximations
+  with [CuPy](https://docs.cupy.dev/en/stable/index.html) on the GPU by using:
+
+  ```shell
+  pip install pyDVL[cupy]
+  ```
+
+  This installs [cupy-cuda11x](https://pypi.org/project/cupy-cuda11x/).
+
+  If you use a different version of CUDA, please install CuPy
+  [manually](https://docs.cupy.dev/en/stable/install.html).
+
+- `ray`:
+
+  If you want to use [Ray](https://www.ray.io/) to distribute data valuation
+  workloads across nodes in a cluster (it can be used locally as well,
+  but for this we recommend joblib instead), install pyDVL using:
+
+  ```shell
+  pip install pyDVL[ray]
+  ```
+
+  See [[getting-started#ray]] for more details on how to use it.
+
+- `memcached`:
+
+  If you want to use [Memcached](https://memcached.org/) for caching
+  utility evaluations, use:
+
+  ```shell
+  pip install pyDVL[memcached]
+  ```
+
+  This additionally installs [pymemcache](https://github.com/pinterest/pymemcache).
diff --git a/mkdocs.yml b/mkdocs.yml
index b00eefdbe..408b26b75 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -69,6 +69,7 @@ plugins:
           - https://scikit-learn.org/stable/objects.inv
           - https://pytorch.org/docs/stable/objects.inv
           - https://pymemcache.readthedocs.io/en/latest/objects.inv
+          - https://joblib.readthedocs.io/en/stable/objects.inv
           - https://docs.dask.org/en/latest/objects.inv
           - https://distributed.dask.org/en/latest/objects.inv
       paths: [ src ]  # search packages in the src folder
diff --git a/requirements.txt b/requirements.txt
index 471ba7475..aedacc12f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,6 @@ scikit-learn
 scipy>=1.7.0
 cvxpy>=1.3.0
 joblib
-pymemcache
 cloudpickle
 tqdm
 matplotlib
diff --git a/setup.py b/setup.py
index 240152868..e8d77d53a 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@
     tests_require=["pytest"],
     extras_require={
         "cupy": ["cupy-cuda11x>=12.1.0"],
+        "memcached": ["pymemcache"],
         "influence": [
             "torch>=2.0.0",
             "dask>=2023.5.0",
diff --git a/src/pydvl/utils/caching.py b/src/pydvl/utils/caching.py
deleted file mode 100644
index 37d087de4..000000000
--- a/src/pydvl/utils/caching.py
+++ /dev/null
@@ -1,339 +0,0 @@
-""" Distributed caching of functions.
-
-pyDVL uses [memcached](https://memcached.org) to cache utility values, through
-[pymemcache](https://pypi.org/project/pymemcache). This allows sharing
-evaluations across processes and nodes in a cluster. You can run memcached as a
-service, locally or remotely, see [Setting up the cache](#setting-up-the-cache)
-
-!!! Warning
-    Function evaluations are cached with a key based on the function's signature
-    and code.
This can lead to undesired cache hits, see [Cache reuse](#cache-reuse). - - Remember **not to reuse utility objects for different datasets**. - -# Configuration - -Memoization is disabled by default but can be enabled easily, -see [Setting up the cache](#setting-up-the-cache). -When enabled, it will be added to any callable used to construct a -[Utility][pydvl.utils.utility.Utility] (done with the decorator [@memcached][pydvl.utils.caching.memcached]). -Depending on the nature of the utility you might want to -enable the computation of a running average of function values, see -[Usage with stochastic functions](#usaage-with-stochastic-functions). -You can see all configuration options under [MemcachedConfig][pydvl.utils.config.MemcachedConfig]. - -## Default configuration - -```python -default_config = dict( - server=('localhost', 11211), - connect_timeout=1.0, - timeout=0.1, - # IMPORTANT! Disable small packet consolidation: - no_delay=True, - serde=serde.PickleSerde(pickle_version=PICKLE_VERSION) -) -``` - -# Usage with stochastic functions - -In addition to standard memoization, the decorator -[memcached()][pydvl.utils.caching.memcached] can compute running average and -standard error of repeated evaluations for the same input. This can be useful -for stochastic functions with high variance (e.g. model training for small -sample sizes), but drastically reduces the speed benefits of memoization. - -This behaviour can be activated with the argument `allow_repeated_evaluations` -to [memcached()][pydvl.utils.caching.memcached]. - -# Cache reuse - -When working directly with [memcached()][pydvl.utils.caching.memcached], it is -essential to only cache pure functions. If they have any kind of state, either -internal or external (e.g. a closure over some data that may change), then the -cache will fail to notice this and the same value will be returned. - -When a function is wrapped with [memcached()][pydvl.utils.caching.memcached] for -memoization, its signature (input and output names) and code are used as a key -for the cache. Alternatively you can pass a custom value to be used as key with - -```python -cached_fun = memcached(**asdict(cache_options))(fun, signature=custom_signature) -``` - -If you are running experiments with the same [Utility][pydvl.utils.utility.Utility] -but different datasets, this will lead to evaluations of the utility on new data -returning old values because utilities only use sample indices as arguments (so -there is no way to tell the difference between '1' for dataset A and '1' for -dataset 2 from the point of view of the cache). One solution is to empty the -cache between runs, but the preferred one is to **use a different Utility -object for each dataset**. - -# Unexpected cache misses - -Because all arguments to a function are used as part of the key for the cache, -sometimes one must exclude some of them. For example, If a function is going to -run across multiple processes and some reporting arguments are added (like a -`job_id` for logging purposes), these will be part of the signature and make the -functions distinct to the eyes of the cache. This can be avoided with the use of -[ignore_args][pydvl.utils.config.MemcachedConfig] in the configuration. 
- -""" -from __future__ import annotations - -import logging -import socket -import uuid -import warnings -from dataclasses import asdict, dataclass -from functools import wraps -from hashlib import blake2b -from io import BytesIO -from time import time -from typing import Any, Callable, Dict, Iterable, Optional, TypeVar, cast - -from cloudpickle import Pickler -from pymemcache import MemcacheUnexpectedCloseError -from pymemcache.client import Client, RetryingClient - -from .config import MemcachedClientConfig -from .numeric import running_moments - -PICKLE_VERSION = 5 # python >= 3.8 - -logger = logging.getLogger(__name__) - -T = TypeVar("T") - - -@dataclass -class CacheStats: - """Statistics gathered by cached functions. - - Attributes: - sets: number of times a value was set in the cache - misses: number of times a value was not found in the cache - hits: number of times a value was found in the cache - timeouts: number of times a timeout occurred - errors: number of times an error occurred - reconnects: number of times the client reconnected to the server - """ - - sets: int = 0 - misses: int = 0 - hits: int = 0 - timeouts: int = 0 - errors: int = 0 - reconnects: int = 0 - - -def serialize(x: Any) -> bytes: - """Serialize an object to bytes. - Args: - x: object to serialize. - - Returns: - serialized object. - """ - pickled_output = BytesIO() - pickler = Pickler(pickled_output, PICKLE_VERSION) - pickler.dump(x) - return pickled_output.getvalue() - - -def memcached( - client_config: Optional[MemcachedClientConfig] = None, - time_threshold: float = 0.3, - allow_repeated_evaluations: bool = False, - rtol_stderr: float = 0.1, - min_repetitions: int = 3, - ignore_args: Optional[Iterable[str]] = None, -) -> Callable[[Callable[..., T], bytes | None], Callable[..., T]]: - """ - Transparent, distributed memoization of function calls. - - Given a function and its signature, memcached uses a distributed cache - that, for each set of inputs, keeps track of the average returned value, - with variance and number of times it was calculated. - - If the function is deterministic, i.e. same input corresponds to the same - exact output, set `allow_repeated_evaluations` to `False`. If instead the - function is stochastic (like the training of a model depending on random - initializations), memcached() allows to set a minimum number of evaluations - to compute a running average, and a tolerance after which the function will - not be called anymore. In other words, the function will be recomputed - until the value has stabilized with a standard error smaller than - `rtol_stderr * running average`. - - !!! Warning - Do not cache functions with state! See [Cache reuse](cache-reuse) - - ??? Example - ```python - cached_fun = memcached(**asdict(cache_options))(heavy_computation) - ``` - - Args: - client_config: configuration for pymemcache's - [Client][pymemcache.client.base.Client]. - Will be merged on top of the default configuration (see below). - time_threshold: computations taking less time than this many seconds are - not cached. - allow_repeated_evaluations: If `True`, repeated calls to a function - with the same arguments will be allowed and outputs averaged until the - running standard deviation of the mean stabilizes below - `rtol_stderr * mean`. - rtol_stderr: relative tolerance for repeated evaluations. More precisely, - [memcached()][pydvl.utils.caching.memcached] will stop evaluating the function once the - standard deviation of the mean is smaller than `rtol_stderr * mean`. 
- min_repetitions: minimum number of times that a function evaluation - on the same arguments is repeated before returning cached values. Useful - for stochastic functions only. If the model training is very noisy, set - this number to higher values to reduce variance. - ignore_args: Do not take these keyword arguments into account when - hashing the wrapped function for usage as key in memcached. This allows - sharing the cache among different jobs for the same experiment run if - the callable happens to have "nuisance" parameters like `job_id` which - do not affect the result of the computation. - - Returns: - A wrapped function - - """ - if ignore_args is None: - ignore_args = [] - - # Do I really need this? - def connect(config: MemcachedClientConfig): - """First tries to establish a connection, then tries setting and - getting a value.""" - try: - client = RetryingClient( - Client(**asdict(config)), - attempts=3, - retry_delay=0.1, - retry_for=[MemcacheUnexpectedCloseError], - ) - - temp_key = str(uuid.uuid4()) - client.set(temp_key, 7) - assert client.get(temp_key) == 7 - client.delete(temp_key, 0) - return client - except ConnectionRefusedError as e: - logger.error( # type: ignore - f"@memcached: Timeout connecting " - f"to {config.server} after " - f"{config.connect_timeout} seconds: {str(e)}. Did you start memcached?" - ) - raise e - except AssertionError as e: - logger.error( # type: ignore - f"@memcached: Failure saving dummy value " - f"to {config.server}: {str(e)}" - ) - - def wrapper(fun: Callable[..., T], signature: Optional[bytes] = None): - if signature is None: - signature = serialize((fun.__code__.co_code, fun.__code__.co_consts)) - - @wraps(fun, updated=[]) # don't try to use update() for a class - class Wrapped: - config: MemcachedClientConfig - stats: CacheStats - client: RetryingClient - - def __init__(self, config: MemcachedClientConfig): - self.config = config - self.stats = CacheStats() - self.client = connect(self.config) - self._signature = signature - - def __call__(self, *args, **kwargs) -> T: - key_kwargs = {k: v for k, v in kwargs.items() if k not in ignore_args} # type: ignore - arg_signature: bytes = serialize((args, list(key_kwargs.items()))) - - key = blake2b(self._signature + arg_signature).hexdigest().encode("ASCII") # type: ignore - - result_dict: Dict[str, float] = self.get_key_value(key) - if result_dict is None: - result_dict = {} - start = time() - value = fun(*args, **kwargs) - end = time() - result_dict["value"] = value - if end - start >= time_threshold or allow_repeated_evaluations: - result_dict["count"] = 1 - result_dict["variance"] = 0 - self.client.set(key, result_dict, noreply=True) - self.stats.sets += 1 - self.stats.misses += 1 - elif allow_repeated_evaluations: - self.stats.hits += 1 - value = result_dict["value"] - count = result_dict["count"] - variance = result_dict["variance"] - error_on_average = (variance / count) ** (1 / 2) - if ( - error_on_average > rtol_stderr * value - or count <= min_repetitions - ): - new_value = fun(*args, **kwargs) - new_avg, new_var = running_moments( - value, variance, int(count), cast(float, new_value) - ) - result_dict["value"] = new_avg - result_dict["count"] = count + 1 - result_dict["variance"] = new_var - self.client.set(key, result_dict, noreply=True) - self.stats.sets += 1 - else: - self.stats.hits += 1 - return result_dict["value"] # type: ignore - - def __getstate__(self): - """Enables pickling after a socket has been opened to the - memcached server, by removing the client from the stored - 
data.""" - odict = self.__dict__.copy() - del odict["client"] - return odict - - def __setstate__(self, d: dict): - """Restores a client connection after loading from a pickle.""" - self.config = d["config"] - self.stats = d["stats"] - self.client = Client(**asdict(self.config)) - self._signature = signature - - def get_key_value(self, key: bytes): - result = None - try: - result = self.client.get(key) - except socket.timeout as e: - self.stats.timeouts += 1 - warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning) - except OSError as e: - self.stats.errors += 1 - warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning) - except AttributeError as e: - # FIXME: this depends on _recv() failing on invalid sockets - # See pymemcache.base.py, - self.stats.reconnects += 1 - warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning) - self.client = connect(self.config) - return result - - Wrapped.__doc__ = ( - f"A wrapper around {fun.__name__}() with remote caching enabled.\n" - + (Wrapped.__doc__ or "") - ) - Wrapped.__name__ = f"memcached_{fun.__name__}" - path = list(reversed(fun.__qualname__.split("."))) - patched = [f"memcached_{path[0]}"] + path[1:] - Wrapped.__qualname__ = ".".join(reversed(patched)) - - # TODO: pick from some config file or something - return Wrapped(client_config or MemcachedClientConfig()) - - return wrapper diff --git a/src/pydvl/utils/caching/__init__.py b/src/pydvl/utils/caching/__init__.py new file mode 100644 index 000000000..1089628bc --- /dev/null +++ b/src/pydvl/utils/caching/__init__.py @@ -0,0 +1,88 @@ +"""Caching of functions. + +PyDVL can cache (memoize) the computation of the utility function +and speed up some computations for data valuation. + +!!! Warning + Function evaluations are cached with a key based on the function's signature + and code. This can lead to undesired cache hits, see [Cache reuse](#cache-reuse). + + Remember **not to reuse utility objects for different datasets**. + +# Configuration + +Caching is disabled by default but can be enabled easily, +see [Setting up the cache](#setting-up-the-cache). +When enabled, it will be added to any callable used to construct a +[Utility][pydvl.utils.utility.Utility] (done with the wrap method of +[CacheBackend][pydvl.utils.caching.base.CacheBackend]). +Depending on the nature of the utility you might want to +enable the computation of a running average of function values, see +[Usage with stochastic functions](#usaage-with-stochastic-functions). +You can see all configuration options under +[CachedFuncConfig][pydvl.utils.caching.config.CachedFuncConfig]. + +# Supported Backends + +pyDVL supports 3 different caching backends: + +- [InMemoryCacheBackend][pydvl.utils.caching.memory.InMemoryCacheBackend]: + an in-memory cache backend that uses a dictionary to store and retrieve + cached values. This is used to share cached values between threads + in a single process. +- [DiskCacheBackend][pydvl.utils.caching.disk.DiskCacheBackend]: + a disk-based cache backend that uses pickled values written to and read from disk. + This is used to share cached values between processes in a single machine. +- [MemcachedCacheBackend][pydvl.utils.caching.memcached.MemcachedCacheBackend]: + a [Memcached](https://memcached.org/)-based cache backend that uses pickled values written to + and read from a Memcached server. This is used to share cached values + between processes across multiple machines. + + **Note** This specific backend requires optional dependencies. 
+  See [[installation#extras]] for more information.
+
+# Usage with stochastic functions
+
+In addition to standard memoization, the wrapped functions
+can compute running average and standard error of repeated evaluations
+for the same input. This can be useful for stochastic functions with high variance
+(e.g. model training for small sample sizes), but drastically reduces
+the speed benefits of memoization.
+
+This behaviour can be activated with the option
+[allow_repeated_evaluations][pydvl.utils.caching.config.CachedFuncConfig].
+
+# Cache reuse
+
+When working directly with [CachedFunc][pydvl.utils.caching.base.CachedFunc], it is
+essential to only cache pure functions. If they have any kind of state, either
+internal or external (e.g. a closure over some data that may change), then the
+cache will fail to notice this and the same value will be returned.
+
+When a function is wrapped with [CachedFunc][pydvl.utils.caching.base.CachedFunc]
+for memoization, its signature (input and output names) and code are used as a key
+for the cache.
+
+If you are running experiments with the same [Utility][pydvl.utils.utility.Utility]
+but different datasets, this will lead to evaluations of the utility on new data
+returning old values because utilities only use sample indices as arguments (so
+there is no way to tell the difference between '1' for dataset A and '1' for
+dataset B from the point of view of the cache). One solution is to empty the
+cache between runs by calling the `clear` method of the cache backend instance,
+but the preferred one is to **use a different Utility object for each dataset**.
+
+# Unexpected cache misses
+
+Because all arguments to a function are used as part of the key for the cache,
+sometimes one must exclude some of them. For example, if a function is going to
+run across multiple processes and some reporting arguments are added (like a
+`job_id` for logging purposes), these will be part of the signature and make the
+functions distinct to the eyes of the cache. This can be avoided with the use of
+the [ignore_args][pydvl.utils.caching.config.CachedFuncConfig] option in the configuration.
+
+"""
+from .base import *
+from .config import *
+from .disk import *
+from .memcached import *
+from .memory import *
diff --git a/src/pydvl/utils/caching/base.py b/src/pydvl/utils/caching/base.py
new file mode 100644
index 000000000..85c46326b
--- /dev/null
+++ b/src/pydvl/utils/caching/base.py
@@ -0,0 +1,296 @@
+import inspect
+import logging
+import time
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Callable, Collection, Dict, Optional, Tuple, cast
+
+from joblib import hashing
+from joblib.func_inspect import filter_args
+
+from ..numeric import running_moments
+from .config import CachedFuncConfig
+
+__all__ = ["CacheStats", "CacheBackend", "CachedFunc"]
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CacheStats:
+    """Class used to store statistics gathered by cached functions.
+
+    Attributes:
+        sets: Number of times a value was set in the cache.
+        misses: Number of times a value was not found in the cache.
+        hits: Number of times a value was found in the cache.
+        timeouts: Number of times a timeout occurred.
+        errors: Number of times an error occurred.
+        reconnects: Number of times the client reconnected to the server.
+    """
+
+    sets: int = 0
+    misses: int = 0
+    hits: int = 0
+    timeouts: int = 0
+    errors: int = 0
+    reconnects: int = 0
+
+
+@dataclass
+class CacheResult:
+    """A class used to store the cached result of a computation
+    as well as count and variance when using repeated evaluation.
+
+    Attributes:
+        value: Cached value.
+        count: Number of times this value has been computed.
+        variance: Variance associated with the cached value.
+    """
+
+    value: float
+    count: int = 1
+    variance: float = 0.0
+
+
+class CacheBackend(ABC):
+    """Abstract base class for cache backends.
+
+    Defines interface for cache access including wrapping callables,
+    getting/setting results, clearing cache, and combining cache keys.
+
+    Attributes:
+        stats: Cache statistics tracker.
+    """
+
+    def __init__(self) -> None:
+        self.stats = CacheStats()
+
+    def wrap(
+        self,
+        func: Callable,
+        *,
+        config: Optional[CachedFuncConfig] = None,
+    ) -> "CachedFunc":
+        """Wraps a function to cache its results.
+
+        Args:
+            func: The function to wrap.
+            config: Optional caching options for the wrapped function.
+
+        Returns:
+            The wrapped cached function.
+        """
+        return CachedFunc(
+            func,
+            cache_backend=self,
+            config=config,
+        )
+
+    @abstractmethod
+    def get(self, key: str) -> Optional[CacheResult]:
+        """Abstract method to retrieve a cached result.
+
+        Implemented by subclasses.
+
+        Args:
+            key: The cache key.
+
+        Returns:
+            The cached result or None if not found.
+        """
+        pass
+
+    @abstractmethod
+    def set(self, key: str, value: CacheResult) -> None:
+        """Abstract method to set a cached result.
+
+        Implemented by subclasses.
+
+        Args:
+            key: The cache key.
+            value: The result to cache.
+        """
+        pass
+
+    @abstractmethod
+    def clear(self) -> None:
+        """Abstract method to clear the entire cache."""
+        pass
+
+    @abstractmethod
+    def combine_hashes(self, *args: str) -> str:
+        """Abstract method to combine cache keys."""
+        pass
+
+
+class CachedFunc:
+    """Caches callable function results with a provided cache backend.
+
+    Wraps a callable function to cache its results using a provided
+    instance of a subclass of [CacheBackend][pydvl.utils.caching.base.CacheBackend].
+
+    This class is heavily inspired by [joblib.memory.MemorizedFunc][].
+
+    This class caches calls to the wrapped callable by generating a hash key
+    based on the wrapped callable's code, the arguments passed to it and the optional
+    `hash_prefix`.
+
+    !!! Warning
+        This class only works with hashable arguments to the wrapped callable.
+
+    Args:
+        func: Callable to wrap.
+        cache_backend: Instance of [CacheBackend][pydvl.utils.caching.base.CacheBackend]
+            that handles setting and getting values.
+        config: Configuration for the wrapped function.
+    """
+
+    def __init__(
+        self,
+        func: Callable[..., float],
+        *,
+        cache_backend: CacheBackend,
+        config: Optional[CachedFuncConfig] = None,
+    ) -> None:
+        self.func = func
+        self.cache_backend = cache_backend
+        if config is None:
+            config = CachedFuncConfig()
+        self.config = config
+
+        self.__doc__ = f"A wrapper around {func.__name__}() with caching enabled.\n" + (
+            CachedFunc.__doc__ or ""
+        )
+        self.__name__ = f"cached_{func.__name__}"
+        path = list(reversed(func.__qualname__.split(".")))
+        patched = [f"cached_{path[0]}"] + path[1:]
+        self.__qualname__ = ".".join(reversed(patched))
+
+    def __call__(self, *args, **kwargs) -> float:
+        """Call the wrapped cached function.
+
+        Executes the wrapped function, caching and returning the result.
+ """ + return self._cached_call(args, kwargs) + + def _force_call(self, args, kwargs) -> Tuple[float, float]: + """Force re-evaluation of the wrapped function. + + Executes the wrapped function without caching. + + Returns: + Function result and execution duration. + """ + start = time.monotonic() + value = self.func(*args, **kwargs) + end = time.monotonic() + duration = end - start + return value, duration + + def _cached_call(self, args, kwargs) -> float: + """Cached wrapped function call. + + Executes the wrapped function with cache checking/setting. + + Returns: + Cached result of the wrapped function. + """ + key = self._get_cache_key(*args, **kwargs) + cached_result = self.cache_backend.get(key) + if cached_result is None: + value, duration = self._force_call(args, kwargs) + result = CacheResult(value) + if ( + duration >= self.config.time_threshold + or self.config.allow_repeated_evaluations + ): + self.cache_backend.set(key, result) + else: + result = cached_result + if self.config.allow_repeated_evaluations: + error_on_average = (result.variance / result.count) ** (1 / 2) + if ( + error_on_average > self.config.rtol_stderr * result.value + or result.count <= self.config.min_repetitions + ): + new_value, _ = self._force_call(args, kwargs) + new_avg, new_var = running_moments( + result.value, + result.variance, + result.count, + cast(float, new_value), + ) + result.value = new_avg + result.count += 1 + result.variance = new_var + self.cache_backend.set(key, result) + return result.value + + def _get_cache_key(self, *args, **kwargs) -> str: + """Returns a string key used to identify the function and input parameter hash.""" + func_hash = self._hash_function(self.func) + argument_hash = self._hash_arguments( + self.func, self.config.ignore_args, args, kwargs + ) + hashes = [func_hash, argument_hash] + if self.config.hash_prefix is not None: + hashes.insert(0, self.config.hash_prefix) + key = self.cache_backend.combine_hashes(*hashes) + return key + + @staticmethod + def _hash_function(func: Callable) -> str: + """Create hash for wrapped function.""" + func_hash: str = hashing.hash((func.__code__.co_code, func.__code__.co_consts)) + return func_hash + + @staticmethod + def _hash_arguments( + func: Callable, + ignore_args: Collection[str], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> str: + """Create hash for function arguments.""" + args_hash: str = hashing.hash( + CachedFunc._filter_args(func, ignore_args, args, kwargs), + ) + return args_hash + + @staticmethod + def _filter_args( + func: Callable, + ignore_args: Collection[str], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """Filter arguments to exclude from cache keys.""" + # Remove kwargs before calling filter_args + # Because some of them might not be explicitly in the function's signature + # and that would raise an error when calling filter_args + kwargs = {k: v for k, v in kwargs.items() if k not in ignore_args} # type: ignore + # Update ignore_args + func_signature = inspect.signature(func) + arg_names = [] + for param in func_signature.parameters.values(): + if param.kind in [ + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + param.KEYWORD_ONLY, + ]: + arg_names.append(param.name) + ignore_args = [x for x in ignore_args if x in arg_names] + filtered_args: Dict[str, Any] = filter_args(func, ignore_args, args, kwargs) # type: ignore + # We ignore 'self' because for our use case we only care about the method. 
# We don't want a cache if another attribute changes in the instance.
+        try:
+            filtered_args.pop("self")
+        except KeyError:
+            pass
+        return filtered_args  # type: ignore
+
+    @property
+    def stats(self) -> CacheStats:
+        """Cache backend statistics."""
+        return self.cache_backend.stats
diff --git a/src/pydvl/utils/caching/config.py b/src/pydvl/utils/caching/config.py
new file mode 100644
index 000000000..c110ffce7
--- /dev/null
+++ b/src/pydvl/utils/caching/config.py
@@ -0,0 +1,43 @@
+from dataclasses import dataclass, field
+from typing import Collection, Optional
+
+__all__ = ["CachedFuncConfig"]
+
+
+@dataclass
+class CachedFuncConfig:
+    """Configuration for cached functions and methods, providing
+    memoization of function calls.
+
+    Instances of this class are typically used as arguments for the construction
+    of a [Utility][pydvl.utils.utility.Utility].
+
+    Args:
+        hash_prefix: Optional string prefix to be prepended to the cache key.
+            This can be provided in order to guarantee cache reuse across runs.
+        ignore_args: Do not take these keyword arguments into account when
+            hashing the wrapped function for usage as key. This allows
+            sharing the cache among different jobs for the same experiment run if
+            the callable happens to have "nuisance" parameters like `job_id` which
+            do not affect the result of the computation.
+        time_threshold: Computations taking less time than this many seconds are
+            not cached. A value of 0 means that it will always cache results.
+        allow_repeated_evaluations: If `True`, repeated calls to a function
+            with the same arguments will be allowed and outputs averaged until the
+            running standard deviation of the mean stabilizes below
+            `rtol_stderr * mean`.
+        rtol_stderr: Relative tolerance for repeated evaluations. More precisely,
+            [CachedFunc][pydvl.utils.caching.base.CachedFunc] will stop evaluating
+            the function once the standard deviation of the mean is smaller than
+            `rtol_stderr * mean`.
+        min_repetitions: Minimum number of times that a function evaluation
+            on the same arguments is repeated before returning cached values. Useful
+            for stochastic functions only. If the model training is very noisy, set
+            this number to higher values to reduce variance.
+    """
+
+    hash_prefix: Optional[str] = None
+    ignore_args: Collection[str] = field(default_factory=list)
+    time_threshold: float = 0.3
+    allow_repeated_evaluations: bool = False
+    rtol_stderr: float = 0.1
+    min_repetitions: int = 3
diff --git a/src/pydvl/utils/caching/disk.py b/src/pydvl/utils/caching/disk.py
new file mode 100644
index 000000000..06250a450
--- /dev/null
+++ b/src/pydvl/utils/caching/disk.py
@@ -0,0 +1,118 @@
+import os
+import shutil
+import tempfile
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import cloudpickle
+
+from pydvl.utils.caching.base import CacheBackend
+
+__all__ = ["DiskCacheBackend"]
+
+PICKLE_VERSION = 5  # python >= 3.8
+
+
+class DiskCacheBackend(CacheBackend):
+    """Disk cache backend that stores results in files.
+
+    Implements the CacheBackend interface for a disk-based cache.
+    Stores cache entries as pickled files on disk, keyed by cache key.
+    This allows sharing evaluations across processes in a single node/computer.
+
+    Args:
+        cache_dir: Base directory for cache storage.
+
+    Attributes:
+        cache_dir: Base directory for cache storage.
+
+    ??? Examples
+        ``` pycon
+        >>> from pydvl.utils.caching.disk import DiskCacheBackend
+        >>> cache_backend = DiskCacheBackend()
+        >>> cache_backend.clear()
+        >>> value = 42
+        >>> cache_backend.set("key", value)
+        >>> cache_backend.get("key")
+        42
+        ```
+
+        ``` pycon
+        >>> from pydvl.utils.caching.disk import DiskCacheBackend
+        >>> cache_backend = DiskCacheBackend()
+        >>> cache_backend.clear()
+        >>> value = 42
+        >>> def foo(x: int):
+        ...     return x + 1
+        ...
+        >>> wrapped_foo = cache_backend.wrap(foo)
+        >>> wrapped_foo(value)
+        43
+        >>> wrapped_foo.stats.misses
+        1
+        >>> wrapped_foo.stats.hits
+        0
+        >>> wrapped_foo(value)
+        43
+        >>> wrapped_foo.stats.misses
+        1
+        >>> wrapped_foo.stats.hits
+        1
+        ```
+
+    """
+
+    def __init__(
+        self,
+        cache_dir: Optional[Union[os.PathLike, str]] = None,
+    ) -> None:
+        """Initialize the disk cache backend.
+
+        Args:
+            cache_dir: Base directory for cache storage.
+                If not provided, this defaults to a newly created
+                temporary directory.
+        """
+        super().__init__()
+        if cache_dir is None:
+            cache_dir = tempfile.mkdtemp(prefix="pydvl")
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(exist_ok=True, parents=True)
+
+    def get(self, key: str) -> Optional[Any]:
+        """Get a value from the cache.
+
+        Args:
+            key: Cache key.
+
+        Returns:
+            Cached value or None if not found.
+        """
+        cache_file = self.cache_dir / key
+        if not cache_file.exists():
+            self.stats.misses += 1
+            return None
+        self.stats.hits += 1
+        with cache_file.open("rb") as f:
+            return cloudpickle.load(f)
+
+    def set(self, key: str, value: Any) -> None:
+        """Set a value in the cache.
+
+        Args:
+            key: Cache key.
+            value: Value to cache.
+        """
+        cache_file = self.cache_dir / key
+        self.stats.sets += 1
+        with cache_file.open("wb") as f:
+            cloudpickle.dump(value, f, protocol=PICKLE_VERSION)
+
+    def clear(self) -> None:
+        """Deletes cache directory and recreates it."""
+        shutil.rmtree(self.cache_dir)
+        self.cache_dir.mkdir(exist_ok=True, parents=True)
+
+    def combine_hashes(self, *args: str) -> str:
+        """Join cache key components."""
+        return os.pathsep.join(args)
diff --git a/src/pydvl/utils/caching/memcached.py b/src/pydvl/utils/caching/memcached.py
new file mode 100644
index 000000000..63855682f
--- /dev/null
+++ b/src/pydvl/utils/caching/memcached.py
@@ -0,0 +1,205 @@
+import logging
+import socket
+import uuid
+import warnings
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, Optional, Tuple
+
+try:
+    from pymemcache import MemcacheUnexpectedCloseError
+    from pymemcache.client import Client, RetryingClient
+    from pymemcache.serde import PickleSerde
+
+    PYMEMCACHE_INSTALLED = True
+except ImportError:
+    PYMEMCACHE_INSTALLED = False
+
+from .base import CacheBackend
+
+__all__ = ["MemcachedClientConfig", "MemcachedCacheBackend"]
+
+PICKLE_VERSION = 5  # python >= 3.8
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class MemcachedClientConfig:
+    """Configuration of the memcached client.
+
+    Args:
+        server: A tuple of (IP|domain name, port).
+        connect_timeout: How many seconds to wait before raising
+            `ConnectionRefusedError` on failure to connect.
+        timeout: Duration in seconds to wait for send or recv calls
+            on the socket connected to memcached.
+        no_delay: If True, set the `TCP_NODELAY` flag, which may help
+            with performance in some cases.
+        serde: Serializer / Deserializer ("serde"). The default `PickleSerde`
+            should work in most cases. See [pymemcache.client.base.Client][]
+            for details.
+    """
+
+    server: Tuple[str, int] = ("localhost", 11211)
+    connect_timeout: float = 1.0
+    timeout: float = 1.0
+    no_delay: bool = True
+    serde: PickleSerde = PickleSerde(pickle_version=PICKLE_VERSION)
+
+
+class MemcachedCacheBackend(CacheBackend):
+    """Memcached cache backend for the distributed caching of functions.
+
+    Implements the CacheBackend interface for a memcached-based cache.
+    This allows sharing evaluations across processes and nodes in a cluster.
+    You can run memcached as a service, locally or remotely;
+    see [Setting up the cache](#setting-up-the-cache).
+
+    Args:
+        config: Memcached client configuration.
+
+    Attributes:
+        config: Memcached client configuration.
+        client: Memcached client instance.
+
+    ??? Examples
+        ``` pycon
+        >>> from pydvl.utils.caching.memcached import MemcachedCacheBackend
+        >>> cache_backend = MemcachedCacheBackend()
+        >>> cache_backend.clear()
+        >>> value = 42
+        >>> cache_backend.set("key", value)
+        >>> cache_backend.get("key")
+        42
+        ```
+
+        ``` pycon
+        >>> from pydvl.utils.caching.memcached import MemcachedCacheBackend
+        >>> cache_backend = MemcachedCacheBackend()
+        >>> cache_backend.clear()
+        >>> value = 42
+        >>> def foo(x: int):
+        ...     return x + 1
+        ...
+        >>> wrapped_foo = cache_backend.wrap(foo)
+        >>> wrapped_foo(value)
+        43
+        >>> wrapped_foo.stats.misses
+        1
+        >>> wrapped_foo.stats.hits
+        0
+        >>> wrapped_foo(value)
+        43
+        >>> wrapped_foo.stats.misses
+        1
+        >>> wrapped_foo.stats.hits
+        1
+        ```
+    """
+
+    def __init__(self, config: MemcachedClientConfig = MemcachedClientConfig()) -> None:
+        """Initialize memcached cache backend.
+
+        Args:
+            config: Memcached client configuration.
+        """
+        if not PYMEMCACHE_INSTALLED:
+            raise ModuleNotFoundError(
+                "Cannot use MemcachedCacheBackend because pymemcache was not installed. "
+                "Make sure to install pyDVL using `pip install pyDVL[memcached]`"
+            )
+        super().__init__()
+        self.config = config
+        self.client = self._connect(self.config)
+
+    def get(self, key: str) -> Optional[Any]:
+        """Get value from memcached.
+
+        Args:
+            key: Cache key.
+
+        Returns:
+            Cached value or None if not found or client disconnected.
+        """
+        result = None
+        try:
+            result = self.client.get(key)
+        except socket.timeout as e:
+            self.stats.timeouts += 1
+            warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning)
+        except OSError as e:
+            self.stats.errors += 1
+            warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning)
+        except AttributeError as e:
+            # FIXME: this depends on _recv() failing on invalid sockets
+            # See pymemcache.base.py,
+            self.stats.reconnects += 1
+            warnings.warn(f"{type(self).__name__}: {str(e)}", RuntimeWarning)
+            self.client = self._connect(self.config)
+        if result is None:
+            self.stats.misses += 1
+        else:
+            self.stats.hits += 1
+        return result
+
+    def set(self, key: str, value: Any) -> None:
+        """Set value in memcached.
+
+        Args:
+            key: Cache key.
+            value: Value to cache.
+ """ + self.client.set(key, value, noreply=True) + self.stats.sets += 1 + + def clear(self) -> None: + """Flush all values from memcached.""" + self.client.flush_all(noreply=True) + + @staticmethod + def _connect(config: MemcachedClientConfig) -> RetryingClient: + """Connect to memcached server.""" + try: + client = RetryingClient( + Client(**asdict(config)), + attempts=3, + retry_delay=0.1, + retry_for=[MemcacheUnexpectedCloseError], + ) + + temp_key = str(uuid.uuid4()) + client.set(temp_key, 7) + assert client.get(temp_key) == 7 + client.delete(temp_key, 0) + return client + except ConnectionRefusedError as e: + logger.error( # type: ignore + f"@memcached: Timeout connecting " + f"to {config.server} after " + f"{config.connect_timeout} seconds: {str(e)}. Did you start memcached?" + ) + raise + except AssertionError as e: + logger.error( # type: ignore + f"@memcached: Failure saving dummy value " + f"to {config.server}: {str(e)}" + ) + raise + + def combine_hashes(self, *args: str) -> str: + """Join cache key components for Memcached.""" + return ":".join(args) + + def __getstate__(self) -> Dict: + """Enables pickling after a socket has been opened to the + memcached server, by removing the client from the stored + data.""" + odict = self.__dict__.copy() + del odict["client"] + return odict + + def __setstate__(self, d: Dict): + """Restores a client connection after loading from a pickle.""" + self.config = d["config"] + self.stats = d["stats"] + self.client = self._connect(self.config) diff --git a/src/pydvl/utils/caching/memory.py b/src/pydvl/utils/caching/memory.py new file mode 100644 index 000000000..270d3ce1a --- /dev/null +++ b/src/pydvl/utils/caching/memory.py @@ -0,0 +1,94 @@ +import os +from typing import Any, Dict, Optional + +from pydvl.utils.caching.base import CacheBackend + +__all__ = ["InMemoryCacheBackend"] + + +class InMemoryCacheBackend(CacheBackend): + """In-memory cache backend that stores results in a dictionary. + + Implements the CacheBackend interface for an in-memory-based cache. + Stores cache entries as values in a dictionary, keyed by cache key. + This allows sharing evaluations across threads in a single process. + + The implementation is not thread-safe. + + Attributes: + cached_values: Dictionary used to store cached values. + + ??? Examples + ``` pycon + >>> from pydvl.utils.caching.memory import InMemoryCacheBackend + >>> cache_backend = InMemoryCacheBackend() + >>> cache_backend.clear() + >>> value = 42 + >>> cache_backend.set("key", value) + >>> cache_backend.get("key") + 42 + ``` + + ``` pycon + >>> from pydvl.utils.caching.memory import InMemoryCacheBackend + >>> cache_backend = InMemoryCacheBackend() + >>> cache_backend.clear() + >>> value = 42 + >>> def foo(x: int): + ... return x + 1 + ... + >>> wrapped_foo = cache_backend.wrap(foo) + >>> wrapped_foo(value) + 43 + >>> wrapped_foo.stats.misses + 1 + >>> wrapped_foo.stats.hits + 0 + >>> wrapped_foo(value) + 43 + >>> wrapped_foo.stats.misses + 1 + >>> wrapped_foo.stats.hits + 1 + ``` + """ + + def __init__(self) -> None: + """Initialize the in-memory cache backend.""" + super().__init__() + self.cached_values: Dict[str, Any] = {} + + def get(self, key: str) -> Optional[Any]: + """Get a value from the cache. + + Args: + key: Cache key. + + Returns: + Cached value or None if not found. 
+ """ + value = self.cached_values.get(key, None) + if value is not None: + self.stats.hits += 1 + else: + self.stats.misses += 1 + return value + + def set(self, key: str, value: Any) -> None: + """Set a value in the cache. + + Args: + key: Cache key. + value: Value to cache. + """ + self.cached_values[key] = value + self.stats.sets += 1 + + def clear(self) -> None: + """Deletes cache dictionary and recreates it.""" + del self.cached_values + self.cached_values = {} + + def combine_hashes(self, *args: str) -> str: + """Join cache key components.""" + return os.pathsep.join(args) diff --git a/src/pydvl/utils/config.py b/src/pydvl/utils/config.py index 6e240bffc..8c77c0263 100644 --- a/src/pydvl/utils/config.py +++ b/src/pydvl/utils/config.py @@ -1,72 +1,4 @@ -from dataclasses import dataclass, field -from typing import Iterable, Optional, Tuple - -from pymemcache.serde import PickleSerde - from pydvl.parallel.config import ParallelConfig +from pydvl.utils.caching.config import CachedFuncConfig -PICKLE_VERSION = 5 # python >= 3.8 - - -__all__ = ["MemcachedClientConfig", "MemcachedConfig", "ParallelConfig"] - - -@dataclass(frozen=True) -class MemcachedClientConfig: - """Configuration of the memcached client. - - Args: - server: A tuple of (IP|domain name, port). - connect_timeout: How many seconds to wait before raising - `ConnectionRefusedError` on failure to connect. - timeout: seconds to wait for send or recv calls on the socket - connected to memcached. - no_delay: set the `TCP_NODELAY` flag, which may help with performance - in some cases. - serde: a serializer / deserializer ("serde"). The default `PickleSerde` - should work in most cases. See [pymemcached's - documentation](https://pymemcache.readthedocs.io/en/latest/apidoc/pymemcache.client.base.html#pymemcache.client.base.Client) - for details. - """ - - server: Tuple[str, int] = ("localhost", 11211) - connect_timeout: float = 1.0 - timeout: float = 1.0 - no_delay: bool = True - serde: PickleSerde = PickleSerde(pickle_version=PICKLE_VERSION) - - -@dataclass -class MemcachedConfig: - """Configuration for [memcached()][pydvl.utils.caching.memcached], providing - memoization of function calls. - - Instances of this class are typically used as arguments for the construction - of a [Utility][pydvl.utils.utility.Utility]. - - Args: - client_config: Configuration for the connection to the memcached server. - time_threshold: computations taking less time than this many seconds are - not cached. - allow_repeated_evaluations: If `True`, repeated calls to a function - with the same arguments will be allowed and outputs averaged until the - running standard deviation of the mean stabilises below - `rtol_stderr * mean`. - rtol_stderr: relative tolerance for repeated evaluations. More precisely, - [memcached()][pydvl.utils.caching.memcached] will stop evaluating - the function once the standard deviation of the mean is smaller than - `rtol_stderr * mean`. - min_repetitions: minimum number of times that a function evaluation - on the same arguments is repeated before returning cached values. Useful - for stochastic functions only. If the model training is very noisy, set - this number to higher values to reduce variance. - ignore_args: Do not take these keyword arguments into account when hashing - the wrapped function for usage as key in memcached. 
- """ - - client_config: MemcachedClientConfig = field(default_factory=MemcachedClientConfig) - time_threshold: float = 0.3 - allow_repeated_evaluations: bool = False - rtol_stderr: float = 0.1 - min_repetitions: int = 3 - ignore_args: Optional[Iterable[str]] = None +__all__ = ["CachedFuncConfig", "ParallelConfig"] diff --git a/src/pydvl/utils/utility.py b/src/pydvl/utils/utility.py index 767e7f9e1..b975c0ff2 100644 --- a/src/pydvl/utils/utility.py +++ b/src/pydvl/utils/utility.py @@ -23,9 +23,9 @@ learning](https://arxiv.org/abs/2107.06336). arXiv preprint arXiv:2107.06336. """ +import hashlib import logging import warnings -from dataclasses import asdict from typing import Dict, FrozenSet, Iterable, Optional, Tuple, Union, cast import numpy as np @@ -34,8 +34,7 @@ from sklearn.metrics import check_scoring from pydvl.utils import Dataset -from pydvl.utils.caching import CacheStats, memcached, serialize -from pydvl.utils.config import MemcachedConfig +from pydvl.utils.caching import CacheBackend, CachedFuncConfig, CacheStats from pydvl.utils.score import Scorer from pydvl.utils.types import SupervisedModel @@ -102,8 +101,11 @@ class Utility: calculations. When this happens, the `default_score` is returned as a score and computation continues. show_warnings: Set to `False` to suppress warnings thrown by `fit()`. - enable_cache: If `True`, use memcached for memoization. - cache_options: Optional configuration object for memcached. + cache_backend: Optional instance of [CacheBackend][pydvl.utils.caching.base.CacheBackend] + used to wrap the _utility method of the Utility instance. + By default, this is set to None and that means that the utility evaluations + will not be cached. + cached_func_options: Optional configuration object for cached utility evaluation. clone_before_fit: If `True`, the model will be cloned before calling `fit()`. @@ -118,6 +120,20 @@ class Utility: 0.9 ``` + With caching enabled: + + ```pycon + >>> from pydvl.utils import Utility, DataUtilityLearning, Dataset + >>> from pydvl.utils.caching.memory import InMemoryCacheBackend + >>> from sklearn.linear_model import LinearRegression, LogisticRegression + >>> from sklearn.datasets import load_iris + >>> dataset = Dataset.from_sklearn(load_iris(), random_state=16) + >>> cache_backend = InMemoryCacheBackend() + >>> u = Utility(LogisticRegression(random_state=16), dataset, cache_backend=cache_backend) + >>> u(dataset.indices) + 0.9 + ``` + """ model: SupervisedModel @@ -134,8 +150,8 @@ def __init__( score_range: Tuple[float, float] = (-np.inf, np.inf), catch_errors: bool = True, show_warnings: bool = False, - enable_cache: bool = False, - cache_options: Optional[MemcachedConfig] = None, + cache_backend: Optional[CacheBackend] = None, + cached_func_options: Optional[CachedFuncConfig] = None, clone_before_fit: bool = True, ): self.model = self._clone_model(model) @@ -146,25 +162,23 @@ def __init__( self.default_score = scorer.default if scorer is not None else default_score # TODO: auto-fill from known scorers ? 
        self.score_range = scorer.range if scorer is not None else np.array(score_range)
+        self.clone_before_fit = clone_before_fit
        self.catch_errors = catch_errors
        self.show_warnings = show_warnings
-        self.enable_cache = enable_cache
-        self.cache_options: MemcachedConfig = cache_options or MemcachedConfig()
-        self.clone_before_fit = clone_before_fit
-        self._signature = serialize((hash(self.model), hash(data), hash(scorer)))
+        self.cache = cache_backend
+        if cached_func_options is None:
+            cached_func_options = CachedFuncConfig()
+        # TODO: Find a better way to do this.
+        if cached_func_options.hash_prefix is None:
+            # FIX: hash() is not stable across interpreter runs, so this prefix
+            # does not allow reusing the same cache across runs.
+            cached_func_options.hash_prefix = str(hash((model, data, scorer)))
+        self.cached_func_options = cached_func_options
        self._initialize_utility_wrapper()
 
-        # FIXME: can't modify docstring of methods. Instead, I could use a
-        # factory which creates the class on the fly with the right doc.
-        # self.__call__.__doc__ = self._utility_wrapper.__doc__
-
    def _initialize_utility_wrapper(self):
-        if self.enable_cache:
-            # asdict() is recursive, but we want client_config to remain a dataclass
-            options = asdict(self.cache_options)
-            options["client_config"] = self.cache_options.client_config
-            self._utility_wrapper = memcached(**options)(  # type: ignore
-                self._utility, signature=self._signature
+        if self.cache is not None:
+            self._utility_wrapper = self.cache.wrap(
+                self._utility, config=self.cached_func_options
            )
        else:
            self._utility_wrapper = self._utility
@@ -182,7 +196,8 @@ def _utility(self, indices: FrozenSet) -> float:
        """Clones the model, fits it on a subset of the training data
        and scores it on the test data.
 
-        If the object is constructed with `enable_cache = True`, results are
+        If an instance of [CacheBackend][pydvl.utils.caching.base.CacheBackend]
+        is passed during construction, results are
        memoized to avoid duplicate computation. This is useful in particular
        when computing utilities of permutations of indices or when randomly
        sampling from the powerset of indices.
@@ -244,19 +259,15 @@ def _clone_model(model: SupervisedModel) -> SupervisedModel:
            model = cast(SupervisedModel, model)
        return model
 
-    @property
-    def signature(self):
-        """Signature used for caching model results."""
-        return self._signature
-
    @property
    def cache_stats(self) -> Optional[CacheStats]:
        """Cache statistics are gathered when cache is enabled.
 
-        See [CacheStats][pydvl.utils.caching.CacheStats] for all fields returned.
+        See [CacheStats][pydvl.utils.caching.base.CacheStats] for all fields returned.
        """
-        if self.enable_cache:
-            return self._utility_wrapper.stats  # type: ignore
-        return None
+        cache_stats: Optional[CacheStats] = None
+        if self.cache is not None:
+            cache_stats = self._utility_wrapper.stats
+        return cache_stats
 
    def __getstate__(self):
        state = self.__dict__.copy()
diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py
index eceba171e..9eee1c83d 100644
--- a/src/pydvl/value/semivalues.py
+++ b/src/pydvl/value/semivalues.py
@@ -216,7 +216,7 @@ def compute_generic_semivalues(
    from pydvl.parallel import effective_n_jobs, init_executor, init_parallel_backend
 
-    if isinstance(sampler, PermutationSampler) and not u.enable_cache:
+    if isinstance(sampler, PermutationSampler) and u.cache is None:
        log.warning(
            "PermutationSampler requires caching to be enabled or computation "
            "will be doubled wrt. a 'direct' implementation of permutation MC"
diff --git a/tests/test_plugin.py b/tests/test_plugin.py
index c20e19db4..efbceeb2a 100644
--- a/tests/test_plugin.py
+++ b/tests/test_plugin.py
@@ -7,9 +7,7 @@ def test_marker_only(i):
    assert False
 
 
-@pytest.fixture(
-    scope="function", params=[0, pytest.param(1, marks=pytest.mark.xfail), 2]
-)
+@pytest.fixture(scope="function", params=[0, pytest.param(1, marks=pytest.mark.xfail)])
def data(request):
    yield request.param
 
diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py
index c30e38fd8..b02949e63 100644
--- a/tests/utils/test_caching.py
+++ b/tests/utils/test_caching.py
@@ -1,4 +1,6 @@
 import logging
+import pickle
+import tempfile
 from time import sleep, time
 from typing import Optional
 
@@ -7,164 +9,324 @@
 from numpy.typing import NDArray
 
 from pydvl.parallel import MapReduceJob
-from pydvl.utils import memcached
+from pydvl.utils.caching import (
+    CachedFunc,
+    CachedFuncConfig,
+    DiskCacheBackend,
+    InMemoryCacheBackend,
+    MemcachedCacheBackend,
+)
 from pydvl.utils.types import Seed
 
 logger = logging.getLogger(__name__)
 
 
-def test_failed_connection():
+def foo(indices: NDArray[np.int_], *args, **kwargs) -> float:
+    return float(np.sum(indices))
+
+
+def foo_duplicate(indices: NDArray[np.int_], *args, **kwargs) -> float:
+    return float(np.sum(indices))
+
+
+def foo_with_random(indices: NDArray[np.int_], *args, **kwargs) -> float:
+    rng = np.random.default_rng()
+    scale = kwargs.get("scale", 1.0)
+    return float(np.sum(indices)) + rng.normal(scale=scale)
+
+
+def foo_with_random_and_sleep(indices: NDArray[np.int_], *args, **kwargs) -> float:
+    sleep(0.01)
+    rng = np.random.default_rng()
+    scale = kwargs.get("scale", 1.0)
+    return float(np.sum(indices)) + rng.normal(scale=scale)
+
+
+# Used to test caching of methods
+class CacheTest:
+    def __init__(self):
+        self.value = 0
+
+    def foo(self):
+        return 1
+
+
+@pytest.fixture(params=["in-memory", "disk", "memcached"])
+def cache_backend(request):
+    backend: str = request.param
+    if backend == "in-memory":
+        cache_backend = InMemoryCacheBackend()
+        yield cache_backend
+        cache_backend.clear()
+    elif backend == "disk":
+        with tempfile.TemporaryDirectory() as tempdir:
+            cache_backend = DiskCacheBackend(tempdir)
+            yield cache_backend
+            cache_backend.clear()
+    elif backend == "memcached":
+        cache_backend = MemcachedCacheBackend()
+        yield cache_backend
+        cache_backend.clear()
+    else:
+        raise ValueError(f"Unknown cache backend {backend}")
+
+
+@pytest.mark.parametrize(
+    "f1, f2, expected_equal",
+    [
+        # Test that the same function gets the same hash
+        (lambda x: x, lambda x: x, True),
+        (foo, foo, True),
+        # Test that different functions get different hashes
+        (foo, lambda x: x, False),
+        (foo, foo_with_random, False),
+        (foo_with_random, foo_with_random_and_sleep, False),
+        # Test that functions with different names but the same body get the same hash
+        (foo, foo_duplicate, True),
+    ],
+)
+def test_cached_func_hash_function(f1, f2, expected_equal):
+    f1_hash = CachedFunc._hash_function(f1)
+    f2_hash = CachedFunc._hash_function(f2)
+    if expected_equal:
+        assert f1_hash == f2_hash, f"{f1_hash} != {f2_hash}"
+    else:
+        assert f1_hash != f2_hash, f"{f1_hash} == {f2_hash}"
+
+
+@pytest.mark.parametrize(
+    "args1, args2, expected_equal",
+    [
+        # Test that the same arguments get the same hash
+        ([[]], [[]], True),
+        ([[1]], [[1]], True),
+        ([frozenset([])], [frozenset([])], True),
+        ([frozenset([1])], [frozenset([1])], True),
+        ([np.ones(3)], [np.ones(3)], True),
+        ([np.ones(3), 16], [np.ones(3), 16], True),
+        ([frozenset(np.ones(3))], [frozenset(np.ones(3))], True),
+        # Test that different arguments get different hashes
+        ([[1, 2, 3]], [np.ones(3)], False),
+        ([np.ones(3)], [np.ones(5)], False),
+        ([np.ones(3)], [frozenset(np.ones(3))], False),
+    ],
+)
+def test_cached_func_hash_arguments(args1, args2, expected_equal):
+    args1_hash = CachedFunc._hash_arguments(foo, ignore_args=[], args=args1, kwargs={})
+    args2_hash = CachedFunc._hash_arguments(foo, ignore_args=[], args=args2, kwargs={})
+    if expected_equal:
+        assert args1_hash == args2_hash, f"{args1_hash} != {args2_hash}"
+    else:
+        assert args1_hash != args2_hash, f"{args1_hash} == {args2_hash}"
+
+
+def test_cached_func_hash_arguments_of_method():
+    obj = CacheTest()
+
+    hash1 = CachedFunc._hash_arguments(obj.foo, [], tuple(), {})
+    obj.value += 1
+    hash2 = CachedFunc._hash_arguments(obj.foo, [], tuple(), {})
+    assert hash1 == hash2
+
+
+def test_cache_backend_serialization(cache_backend):
+    value = 16.8
+    cache_backend.set("key", value)
+    deserialized_cache_backend = pickle.loads(pickle.dumps(cache_backend))
+    assert deserialized_cache_backend.get("key") == value
+    if isinstance(cache_backend, InMemoryCacheBackend):
+        assert cache_backend.cached_values == deserialized_cache_backend.cached_values
+    elif isinstance(cache_backend, DiskCacheBackend):
+        assert cache_backend.cache_dir == deserialized_cache_backend.cache_dir
+
+
+def test_single_job(cache_backend):
+    cached_func_config = CachedFuncConfig(time_threshold=0.0)
+    wrapped_foo = cache_backend.wrap(foo, config=cached_func_config)
+
+    n = 1000
+    wrapped_foo(np.arange(n))
+    hits_before = wrapped_foo.stats.hits
+    wrapped_foo(np.arange(n))
+    hits_after = wrapped_foo.stats.hits
+
+    assert hits_after > hits_before
+
+
+def test_without_pymemcache(mocker):
+    mocker.patch("pydvl.utils.caching.memcached.PYMEMCACHE_INSTALLED", False)
+    with pytest.raises(ModuleNotFoundError):
+        MemcachedCacheBackend()
+
+
+def test_memcached_failed_connection():
    from pydvl.utils import MemcachedClientConfig
 
-    client_config = MemcachedClientConfig(server=("localhost", 0), connect_timeout=0.1)
+    config = MemcachedClientConfig(server=("localhost", 0), connect_timeout=0.1)
    with pytest.raises((ConnectionRefusedError, OSError)):
-        memcached(client_config)(lambda x: x)
+        MemcachedCacheBackend(config)
 
 
-def test_memcached_single_job(memcached_client):
-    client, config = memcached_client
+def test_cache_time_threshold(cache_backend):
+    cached_func_config = CachedFuncConfig(time_threshold=1.0)
+    wrapped_foo = cache_backend.wrap(foo, config=cached_func_config)
 
-    # TODO: maybe this should be a fixture too...
-    @memcached(client_config=config, time_threshold=0)  # Always cache results
-    def foo(indices: NDArray[np.int_]) -> float:
-        return float(np.sum(indices))
+    n = 1000
+    indices = np.arange(n)
+    wrapped_foo(indices)
+    hits_before = wrapped_foo.stats.hits
+    misses_before = wrapped_foo.stats.misses
+    wrapped_foo(indices)
+    hits_after = wrapped_foo.stats.hits
+    misses_after = wrapped_foo.stats.misses
+
+    assert hits_after == hits_before
+    assert misses_after > misses_before
+
+
+def test_cache_ignore_args(cache_backend):
+    # Note that we typically do NOT want to ignore run_id
+    cached_func_config = CachedFuncConfig(
+        time_threshold=0.0,
+        ignore_args=["job_id"],
+    )
+    wrapped_foo = cache_backend.wrap(foo, config=cached_func_config)
 
    n = 1000
-    foo(np.arange(n))
-    hits_before = client.stats()[b"get_hits"]
-    foo(np.arange(n))
-    hits_after = client.stats()[b"get_hits"]
+    indices = np.arange(n)
+    wrapped_foo(indices, job_id=0)
+    hits_before = wrapped_foo.stats.hits
+    wrapped_foo(indices, job_id=16)
+    hits_after = wrapped_foo.stats.hits
 
    assert hits_after > hits_before
 
 
-def test_memcached_parallel_jobs(memcached_client, parallel_config):
-    client, config = memcached_client
+def test_parallel_jobs(cache_backend, parallel_config):
+    if not isinstance(cache_backend, MemcachedCacheBackend):
+        pytest.skip("Only running this test with MemcachedCacheBackend")
+    if parallel_config.backend != "joblib":
+        pytest.skip("We don't have to test this with all parallel backends")
 
-    @memcached(
-        client_config=config,
-        time_threshold=0,  # Always cache results
-        # Note that we typically do NOT want to ignore run_id
+    # Note that we typically do NOT want to ignore run_id
+    cached_func_config = CachedFuncConfig(
        ignore_args=["job_id", "run_id"],
    )
-    def foo(indices: NDArray[np.int_], *args, **kwargs) -> float:
-        # logger.info(f"run_id: {run_id}, running...")
-        return float(np.sum(indices))
+    wrapped_foo = cache_backend.wrap(foo, config=cached_func_config)
 
    n = 1234
    n_runs = 10
-    hits_before = client.stats()[b"get_hits"]
+    hits_before = cache_backend.client.stats()[b"get_hits"]
    map_reduce_job = MapReduceJob(
-        np.arange(n), foo, np.sum, n_jobs=4, config=parallel_config
+        np.arange(n), wrapped_foo, np.sum, n_jobs=4, config=parallel_config
    )
    results = []
    for _ in range(n_runs):
        result = map_reduce_job()
        results.append(result)
-    hits_after = client.stats()[b"get_hits"]
+    hits_after = cache_backend.client.stats()[b"get_hits"]
 
    assert results[0] == n * (n - 1) / 2  # Sanity check
    # FIXME! This is non-deterministic: if packets are delayed for longer than
    # the timeout configured then we won't have num_runs hits. So we add this
    # good old hard-coded magic number here.
-    assert hits_after - hits_before >= n_runs - 2
-
+    assert hits_after - hits_before >= n_runs - 2, wrapped_foo.stats
 
 
-def test_memcached_repeated_training(memcached_client, worker_id: str):
-    _, config = memcached_client
 
-    @memcached(
-        client_config=config,
-        time_threshold=0,  # Always cache results
+def test_repeated_training(cache_backend, worker_id: str):
+    cached_func_config = CachedFuncConfig(
+        time_threshold=0.0,
        allow_repeated_evaluations=True,
        rtol_stderr=0.01,
    )
-    def foo(indices: NDArray[np.int_], uid: str) -> float:
-        return float(np.sum(indices)) + np.random.normal(scale=10)
+    wrapped_foo = cache_backend.wrap(
+        foo_with_random,
+        config=cached_func_config,
+    )
 
    n = 7
-    foo(np.arange(n), worker_id)
-    for _ in range(10_000):
-        result = foo(np.arange(n), worker_id)
+    indices = np.arange(n)
 
-    assert (result - np.sum(np.arange(n))) < 1
-    assert foo.stats.sets < foo.stats.hits
+    for _ in range(1_000):
+        result = wrapped_foo(indices, worker_id)
 
+    assert np.isclose(result, np.sum(indices), atol=1)
+    assert wrapped_foo.stats.sets < wrapped_foo.stats.hits
 
 
-def test_memcached_faster_with_repeated_training(memcached_client, worker_id: str):
-    _, config = memcached_client
 
-    @memcached(
-        client_config=config,
-        time_threshold=0,  # Always cache results
+def test_faster_with_repeated_training(cache_backend, worker_id: str):
+    cached_func_config = CachedFuncConfig(
+        time_threshold=0.0,
        allow_repeated_evaluations=True,
        rtol_stderr=0.1,
    )
-    def foo_cache(indices: NDArray[np.int_], uid: str) -> float:
-        sleep(0.01)
-        return float(np.sum(indices)) + np.random.normal(scale=1)
-
-    def foo_no_cache(indices: NDArray[np.int_], uid: str) -> float:
-        sleep(0.01)
-        return float(np.sum(indices)) + np.random.normal(scale=1)
+    wrapped_foo = cache_backend.wrap(
+        foo_with_random_and_sleep,
+        config=cached_func_config,
+    )
 
    n = 3
-    foo_cache(np.arange(n), worker_id)
-    foo_no_cache(np.arange(n), worker_id)
+    n_repetitions = 500
+    indices = np.arange(n)
 
    start = time()
-    for _ in range(300):
-        result_fast = foo_cache(np.arange(n), worker_id)
+    for _ in range(n_repetitions):
+        result_fast = wrapped_foo(indices, worker_id)
    end = time()
    fast_time = end - start
 
    start = time()
    results_slow = []
-    for _ in range(300):
-        result = foo_no_cache(np.arange(n), worker_id)
+    for _ in range(n_repetitions):
+        result = foo_with_random_and_sleep(indices, worker_id)
        results_slow.append(result)
    end = time()
    slow_time = end - start
 
-    assert (result_fast - np.mean(results_slow)) < 1
+    assert np.isclose(np.mean(results_slow), np.sum(indices), atol=0.1)
+    assert np.isclose(result_fast, np.mean(results_slow), atol=1)
    assert fast_time < slow_time
 
 
@pytest.mark.parametrize("n, atol", [(10, 5), (20, 10)])
@pytest.mark.parametrize("n_jobs", [1, 2])
@pytest.mark.parametrize("n_runs", [20])
-def test_memcached_parallel_repeated_training(
-    memcached_client, n, atol, n_jobs, n_runs, parallel_config, seed
+def test_parallel_repeated_training(
+    cache_backend, n, atol, n_jobs, n_runs, parallel_config
):
    if parallel_config.backend != "joblib":
        pytest.skip("We don't have to test this with all parallel backends")
 
-    _, config = memcached_client
 
-    @memcached(
-        client_config=config,
-        time_threshold=0,  # Always cache results
+    def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float:
+        return np.sum(indices).item() + np.random.normal(scale=1)
+
+    # Note that we typically do NOT want to ignore run_id
+    cached_func_config = CachedFuncConfig(
+        time_threshold=0.0,
        allow_repeated_evaluations=True,
        rtol_stderr=0.01,
-        # Note that we typically do NOT want to ignore run_id
        ignore_args=["job_id", "run_id"],
    )
-    def map_func(indices: NDArray[np.int_], seed: Optional[Seed] = None) -> float:
-        # from pydvl.utils.logging import logger
-        # logger.info(f"run_id: {run_id}, running...")
-        rng = np.random.default_rng(seed)
-        return np.sum(indices).item() + rng.normal(scale=5)
+    wrapped_map_func = cache_backend.wrap(
+        map_func,
+        config=cached_func_config,
+    )
 
    def reduce_func(chunks: NDArray[np.float_]) -> float:
        return np.sum(chunks).item()
 
    map_reduce_job = MapReduceJob(
-        np.arange(n), map_func, reduce_func, n_jobs=n_jobs, config=parallel_config
+        np.arange(n),
+        wrapped_map_func,
+        reduce_func,
+        n_jobs=n_jobs,
+        config=parallel_config,
    )
    results = []
    for _ in range(n_runs):
-        result = map_reduce_job(seed=seed)
+        result = map_reduce_job()
        results.append(result)
 
    exact_value = np.sum(np.arange(n)).item()
diff --git a/tests/utils/test_utility.py b/tests/utils/test_utility.py
index dce0e3830..335b0c136 100644
--- a/tests/utils/test_utility.py
+++ b/tests/utils/test_utility.py
@@ -1,11 +1,13 @@
 # TODO add more tests!
+import pickle
 import warnings
 
 import numpy as np
 import pytest
 from sklearn.linear_model import LinearRegression
 
-from pydvl.utils import DataUtilityLearning, MemcachedConfig, Scorer, Utility, powerset
+from pydvl.utils import DataUtilityLearning, Scorer, Utility, powerset
+from pydvl.utils.caching import CachedFuncConfig, InMemoryCacheBackend
 
 
@pytest.mark.parametrize("show_warnings", [False, True])
@@ -27,15 +29,14 @@ def score(self, x, y):
    utility = Utility(
        model=WarningModel(),
        data=housing_dataset,
-        enable_cache=False,
        show_warnings=show_warnings,
    )
    utility([0])
 
    if show_warnings:
-        assert len(recwarn) >= 1
+        assert len(recwarn) >= 1, recwarn.list
    else:
-        assert len(recwarn) == 0
+        assert len(recwarn) == 0, recwarn.list
 
 
# noinspection PyUnresolvedReferences
@@ -46,7 +47,6 @@ def test_data_utility_learning_wrapper(linear_dataset, training_budget):
        model=LinearRegression(),
        data=linear_dataset,
        scorer=Scorer("r2"),
-        enable_cache=False,
    )
    wrapped_u = DataUtilityLearning(u, training_budget, LinearRegression())
    subsets = list(powerset(wrapped_u.utility.data.indices))
@@ -59,51 +59,82 @@ def test_data_utility_learning_wrapper(linear_dataset, training_budget):
 
# noinspection PyUnresolvedReferences
@pytest.mark.parametrize("a, b, num_points", [(2, 0, 8)])
-def test_cache(linear_dataset, memcache_client_config):
+def test_utility_with_cache(linear_dataset):
    u = Utility(
        model=LinearRegression(),
        data=linear_dataset,
        scorer=Scorer("r2"),
-        enable_cache=True,
-        cache_options=MemcachedConfig(
-            client_config=memcache_client_config, time_threshold=0
-        ),
+        cache_backend=InMemoryCacheBackend(),
+        cached_func_options=CachedFuncConfig(time_threshold=0.0),
    )
    subsets = list(powerset(u.data.indices))
    for s in subsets:
        u(s)
-    assert u._utility_wrapper.stats.hits == 0
+    assert u._utility_wrapper.stats.hits == 0, u._utility_wrapper.stats
 
    for s in subsets:
        u(s)
-    assert u._utility_wrapper.stats.hits == len(subsets)
+
+    assert u._utility_wrapper.stats.hits == len(subsets), u._utility_wrapper.stats
 
 
@pytest.mark.parametrize("a, b, num_points", [(2, 0, 8)])
-@pytest.mark.parametrize("model_kwargs", [({}, {}), ({}, {"fit_intercept": False})])
-def test_different_cache_signature(
-    linear_dataset, memcache_client_config, model_kwargs
-):
+def test_different_utility_with_same_cache(linear_dataset):
+    cache_backend = InMemoryCacheBackend()
    u1 = Utility(
-        model=LinearRegression(**model_kwargs[0]),
+        model=LinearRegression(),
        data=linear_dataset,
        scorer=Scorer("r2"),
-        enable_cache=True,
-        cache_options=MemcachedConfig(
-            client_config=memcache_client_config, time_threshold=0
-        ),
+        cache_backend=cache_backend,
+        cached_func_options=CachedFuncConfig(time_threshold=0.0),
    )
    u2 = Utility(
-        model=LinearRegression(**model_kwargs[1]),
+        model=LinearRegression(),
        data=linear_dataset,
-        scorer=Scorer("r2"),
-        enable_cache=True,
-        cache_options=MemcachedConfig(
-            client_config=memcache_client_config, time_threshold=0
-        ),
+        scorer=Scorer("max_error"),
+        cache_backend=cache_backend,
+        cached_func_options=CachedFuncConfig(time_threshold=0.0),
    )
 
-    assert u1.signature != u2.signature
-    assert u1.signature == u1.signature
-    assert u2.signature == u2.signature
+    subset = u1.data.indices
+    # Call first utility with empty cache
+    # We expect a cache miss
+    u1(subset)
+    assert cache_backend.stats.hits == 0
+    assert cache_backend.stats.misses == 1
+    assert cache_backend.stats.sets == 1
+
+    # Call first utility again
+    # We expect a cache hit
+    u1(subset)
+    assert cache_backend.stats.hits == 1
+    assert cache_backend.stats.misses == 1
+    assert cache_backend.stats.sets == 1
+
+    # Call second utility
+    # We expect a cache miss
+    u2(subset)
+    assert cache_backend.stats.hits == 1
+    assert cache_backend.stats.misses == 2
+    assert cache_backend.stats.sets == 2
+
+
+@pytest.mark.parametrize("a, b, num_points", [(2, 0, 8)])
+@pytest.mark.parametrize("use_cache", [False, True])
+def test_utility_serialization(linear_dataset, use_cache):
+    if use_cache:
+        cache = InMemoryCacheBackend()
+    else:
+        cache = None
+    u = Utility(
+        model=LinearRegression(),
+        data=linear_dataset,
+        scorer=Scorer("r2"),
+        cache_backend=cache,
+    )
+    u_unpickled = pickle.loads(pickle.dumps(u))
+    assert type(u.model) == type(u_unpickled.model)
+    assert type(u.scorer) == type(u_unpickled.scorer)
+    assert type(u.data) == type(u_unpickled.data)
+    assert (u.data.x_train == u_unpickled.data.x_train).all()
diff --git a/tests/value/conftest.py b/tests/value/conftest.py
index 3eaa3d672..0e3c48d29 100644
--- a/tests/value/conftest.py
+++ b/tests/value/conftest.py
@@ -8,6 +8,7 @@
 from pydvl.parallel.config import ParallelConfig
 from pydvl.utils import Dataset, SupervisedModel, Utility
+from pydvl.utils.caching import InMemoryCacheBackend
 from pydvl.utils.status import Status
 from pydvl.value import ValuationResult
 from pydvl.value.shapley.naive import combinatorial_exact_shapley
@@ -72,7 +73,6 @@ def score(self, x: NDArray, y: NDArray) -> float:
        score_range=(0, x.sum() / x.max()),
        catch_errors=False,
        show_warnings=True,
-        enable_cache=False,
    )
 
 
@@ -117,12 +117,18 @@ def linear_shapley(cache, linear_dataset, scorer, n_jobs):
    args_hash = cache.hash_arguments(linear_dataset, scorer, n_jobs)
    u_cache_key = f"linear_shapley_u_{args_hash}"
    exact_values_cache_key = f"linear_shapley_exact_values_{args_hash}"
-    u = cache.get(u_cache_key, None)
-    exact_values = cache.get(exact_values_cache_key, None)
+    try:
+        u = cache.get(u_cache_key, None)
+        exact_values = cache.get(exact_values_cache_key, None)
+    except Exception:
+        cache.clear_cache(cache._cachedir)
+        raise
    if u is None:
        u = Utility(
-            LinearRegression(), data=linear_dataset, scorer=scorer, enable_cache=False
+            LinearRegression(),
+            data=linear_dataset,
+            scorer=scorer,
        )
        exact_values = combinatorial_exact_shapley(u, progress=False, n_jobs=n_jobs)
        cache.set(u_cache_key, u)
@@ -133,3 +139,10 @@ def linear_shapley(cache, linear_dataset, scorer, n_jobs):
 
@pytest.fixture(scope="module")
def parallel_config():
    yield ParallelConfig(backend="joblib", n_cpus_local=num_workers(), wait_timeout=0.1)
+
+
+@pytest.fixture()
+def cache_backend():
+    cache = InMemoryCacheBackend()
+    yield cache
+    cache.clear()
diff --git a/tests/value/shapley/test_knn.py b/tests/value/shapley/test_knn.py
index 1ca7a1fbc..cf935f347 100644
--- a/tests/value/shapley/test_knn.py
+++ b/tests/value/shapley/test_knn.py
@@ -40,7 +40,10 @@ def knn_loss_function(labels, predictions, n_classes=3):
    )
 
    utility = Utility(
-        model, data=data, scorer=scorer, show_warnings=False, enable_cache=False
+        model,
+        data=data,
+        scorer=scorer,
+        show_warnings=False,
    )
    exact_values = combinatorial_exact_shapley(
        utility, progress=False, n_jobs=min(len(data), available_cpus())
diff --git a/tests/value/shapley/test_montecarlo.py b/tests/value/shapley/test_montecarlo.py
index d95cdce9e..ef9deed1f 100644
--- a/tests/value/shapley/test_montecarlo.py
+++ b/tests/value/shapley/test_montecarlo.py
@@ -6,7 +6,7 @@
 from sklearn.linear_model import LinearRegression
 
 from pydvl.parallel.config import ParallelConfig
-from pydvl.utils import Dataset, GroupedDataset, MemcachedConfig, Status, Utility
+from pydvl.utils import Dataset, GroupedDataset, Status, Utility
 from pydvl.utils.numeric import num_samples_permutation_hoeffding
 from pydvl.utils.score import Scorer, squashed_r2
 from pydvl.utils.types import Seed
@@ -224,6 +224,7 @@ def test_linear_montecarlo_with_outlier(
    total_atol: float,
    fun,
    kwargs: dict,
+    cache_backend,
):
    """Tests whether valuation methods are able to detect an obvious outlier.
 
@@ -241,7 +242,7 @@ def test_linear_montecarlo_with_outlier(
        LinearRegression(),
        data=linear_dataset,
        scorer=scorer,
-        cache_options=MemcachedConfig(client_config=memcache_client_config),
+        cache_backend=cache_backend,
    )
    values = compute_shapley_values(
        linear_utility, mode=fun, progress=False, n_jobs=n_jobs, **kwargs
@@ -266,12 +267,12 @@ def test_grouped_linear_montecarlo_shapley(
    linear_dataset,
    n_jobs,
-    memcache_client_config: "MemcachedClientConfig",
    num_groups: int,
    fun: ShapleyMode,
    scorer: Scorer,
    rtol: float,
    kwargs: dict,
+    cache_backend,
):
    """
    For permutation and truncated montecarlo, the rtol for each scorer is chosen
@@ -285,7 +286,7 @@ def test_grouped_linear_montecarlo_shapley(
        LinearRegression(),
        data=grouped_linear_dataset,
        scorer=scorer,
-        cache_options=MemcachedConfig(client_config=memcache_client_config),
+        cache_backend=cache_backend,
    )
    exact_values = combinatorial_exact_shapley(grouped_linear_utility, progress=False)
 
diff --git a/tests/value/shapley/test_naive.py b/tests/value/shapley/test_naive.py
index 9b1151d99..45c32b1a9 100644
--- a/tests/value/shapley/test_naive.py
+++ b/tests/value/shapley/test_naive.py
@@ -4,7 +4,7 @@
 import pytest
 from sklearn.linear_model import LinearRegression
 
-from pydvl.utils import GroupedDataset, MemcachedConfig, Utility
+from pydvl.utils import GroupedDataset, Utility
 from pydvl.value.shapley.naive import (
    combinatorial_exact_shapley,
    permutation_exact_shapley,
@@ -43,13 +43,18 @@ def test_analytic_exact_shapley(num_samples, analytic_shapley, fun, rtol, total_
    ],
)
def test_linear(
-    linear_dataset, memcache_client_config, scorer, rtol=0.01, total_atol=1e-5
+    linear_dataset,
+    memcache_client_config,
+    scorer,
+    cache_backend,
+    rtol=0.01,
+    total_atol=1e-5,
):
    linear_utility = Utility(
        LinearRegression(),
        data=linear_dataset,
        scorer=scorer,
-        cache_options=MemcachedConfig(client_config=memcache_client_config),
+        cache_backend=cache_backend,
    )
    values_combinatorial = combinatorial_exact_shapley(linear_utility, progress=False)
@@ -70,6 +75,7 @@ def test_grouped_linear(
    num_groups,
    memcache_client_config,
    scorer,
+    cache_backend,
    rtol=0.01,
    total_atol=1e-5,
):
@@ -81,7 +87,7 @@ def test_grouped_linear(
        LinearRegression(),
        data=grouped_linear_dataset,
        scorer=scorer,
-        cache_options=MemcachedConfig(client_config=memcache_client_config),
+        cache_backend=cache_backend,
    )
    values_combinatorial = combinatorial_exact_shapley(
        grouped_linear_utility, progress=False
@@ -107,7 +113,7 @@ def test_grouped_linear(
    ],
)
def test_linear_with_outlier(
-    linear_dataset, memcache_client_config, scorer, total_atol=1e-5
+    linear_dataset, memcache_client_config, scorer, cache_backend, total_atol=1e-5
):
    outlier_idx = np.random.randint(len(linear_dataset.y_train))
    linear_dataset.y_train[outlier_idx] -= 100
@@ -115,7 +121,7 @@ def test_linear_with_outlier(
        LinearRegression(),
        data=linear_dataset,
        scorer=scorer,
-        cache_options=MemcachedConfig(client_config=memcache_client_config),
+        cache_backend=cache_backend,
    )
    values = permutation_exact_shapley(linear_utility, progress=False)
    values.sort()
@@ -169,6 +175,7 @@ def test_polynomial_with_outlier(
    polynomial_pipeline,
    memcache_client_config,
    scorer,
+    cache_backend,
    total_atol=1e-5,
):
    dataset, _ = polynomial_dataset
@@ -178,7 +185,7 @@ def test_polynomial_with_outlier(
        polynomial_pipeline,
        dataset,
        scorer=scorer,
-        cache_options=MemcachedConfig(client_config=memcache_client_config),
+        cache_backend=cache_backend,
    )
    shapley_values = permutation_exact_shapley(poly_utility, progress=False)
 
diff --git a/tests/value/shapley/test_truncated.py b/tests/value/shapley/test_truncated.py
index 4727c087d..ac980ab96 100644
--- a/tests/value/shapley/test_truncated.py
+++ b/tests/value/shapley/test_truncated.py
@@ -4,7 +4,7 @@
 import pytest
 from sklearn.linear_model import LinearRegression
 
-from pydvl.utils import MemcachedConfig, Status, Utility
+from pydvl.utils import Status, Utility
 from pydvl.utils.score import Scorer, squashed_r2
 from pydvl.value import compute_shapley_values
 from pydvl.value.shapley import ShapleyMode
@@ -125,6 +125,7 @@ def test_tmcs_linear_montecarlo_with_outlier(
    n_jobs,
    memcache_client_config,
    scorer: Scorer,
+    cache_backend,
    total_atol: float,
    fun,
    kwargs: dict,
@@ -145,7 +146,7 @@ def test_tmcs_linear_montecarlo_with_outlier(
        LinearRegression(),
        data=linear_dataset,
        scorer=scorer,
-        cache_options=MemcachedConfig(client_config=memcache_client_config),
+        cache_backend=cache_backend,
    )
    values = compute_shapley_values(
        linear_utility, mode=fun, progress=False, n_jobs=n_jobs, **kwargs
diff --git a/tox.ini b/tox.ini
index f48a3d9fd..666a5760c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -18,6 +18,7 @@ passenv =
 extras =
    ray
    influence
+    memcached
 commands =
    pytest -n auto --dist worksteal --cov "{envsitepackagesdir}/pydvl" {posargs}
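
The tests above all follow the same pattern introduced by this patch: a plain function is wrapped by a backend's `wrap()` method and memoized through it. A minimal, self-contained sketch of that pattern, using only names introduced in this diff (`InMemoryCacheBackend`, `CachedFuncConfig`, the wrapper's `stats` counters):

```python
import numpy as np
from numpy.typing import NDArray

from pydvl.utils.caching import CachedFuncConfig, InMemoryCacheBackend


def foo(indices: NDArray[np.int_]) -> float:
    return float(np.sum(indices))


backend = InMemoryCacheBackend()
# time_threshold=0.0 caches every result, regardless of how long it took to compute
wrapped_foo = backend.wrap(foo, config=CachedFuncConfig(time_threshold=0.0))

indices = np.arange(1000)
wrapped_foo(indices)  # first call: a miss, the result is computed and stored
wrapped_foo(indices)  # second call: served from the cache
print(wrapped_foo.stats)  # expect one miss, one set and one hit
```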
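`test_repeated_training` exercises the repeated-evaluations mechanism carried over from the old `memcached()` decorator: with `allow_repeated_evaluations=True`, calls with the same arguments keep being re-evaluated and averaged until, per the removed `MemcachedConfig` docstring, the standard deviation of the mean falls below `rtol_stderr * mean`. A sketch along the lines of that test:

```python
import numpy as np
from numpy.typing import NDArray

from pydvl.utils.caching import CachedFuncConfig, InMemoryCacheBackend


def noisy_foo(indices: NDArray[np.int_]) -> float:
    # A stochastic function: each evaluation adds fresh Gaussian noise
    rng = np.random.default_rng()
    return float(np.sum(indices)) + rng.normal(scale=1.0)


backend = InMemoryCacheBackend()
wrapped = backend.wrap(
    noisy_foo,
    config=CachedFuncConfig(
        time_threshold=0.0,
        allow_repeated_evaluations=True,
        rtol_stderr=0.01,
    ),
)

indices = np.arange(7)
for _ in range(1_000):
    result = wrapped(indices)
# The running average converges towards np.sum(indices) == 21
print(result, wrapped.stats)
```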
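The `FIX` comment in `Utility.__init__` points at a limitation of the default `hash_prefix`: it is derived from `hash((model, data, scorer))`, which is not stable between interpreter runs, so a `DiskCacheBackend` populated in one run is not found again in the next. A hypothetical workaround is to pin the prefix yourself. This is only a sketch: the cache directory and prefix strings are arbitrary examples, and it assumes that `DiskCacheBackend` accepts a path string and that `CachedFuncConfig` takes `hash_prefix` as a constructor argument (it is set as an attribute in the diff above):

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

from pydvl.utils import Dataset, Utility
from pydvl.utils.caching import CachedFuncConfig, DiskCacheBackend

dataset = Dataset.from_sklearn(load_iris(), random_state=16)
u = Utility(
    LogisticRegression(random_state=16),
    dataset,
    cache_backend=DiskCacheBackend("pydvl_cache"),  # assumption: accepts a path string
    # Pinning hash_prefix keeps cache keys stable across interpreter runs
    cached_func_options=CachedFuncConfig(hash_prefix="iris-logreg-v1"),
)
u(dataset.indices)  # computed once; later runs with the same prefix hit the disk cache
```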