diff --git a/CHANGELOG.md b/CHANGELOG.md index cb32f71e6..31a385f0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,10 @@ randomness. `pydvl.value.semivalues`. Introduced new type `Seed` and conversion function `ensure_seed_sequence`. [PR #396](https://github.com/aai-institute/pyDVL/pull/396) +- Added `batch_size` parameter to `compute_banzhaf_semivalues`, + `compute_beta_shapley_semivalues`, `compute_shapley_semivalues` and + `compute_generic_semivalues`. + [PR #428](https://github.com/aai-institute/pyDVL/pull/428) ### Changed @@ -247,3 +251,4 @@ It contains: - Parallelization of computations with Ray - Documentation - Notebooks containing examples of different use cases + diff --git a/src/pydvl/utils/dataset.py b/src/pydvl/utils/dataset.py index 12a123806..c6331ce08 100644 --- a/src/pydvl/utils/dataset.py +++ b/src/pydvl/utils/dataset.py @@ -40,10 +40,10 @@ class Dataset: def __init__( self, - x_train: Union[np.ndarray, pd.DataFrame], - y_train: Union[np.ndarray, pd.DataFrame], - x_test: Union[np.ndarray, pd.DataFrame], - y_test: Union[np.ndarray, pd.DataFrame], + x_train: Union[NDArray, pd.DataFrame], + y_train: Union[NDArray, pd.DataFrame], + x_test: Union[NDArray, pd.DataFrame], + y_test: Union[NDArray, pd.DataFrame], feature_names: Optional[Sequence[str]] = None, target_names: Optional[Sequence[str]] = None, data_names: Optional[Sequence[str]] = None, @@ -124,8 +124,12 @@ def make_names(s: str, a: np.ndarray) -> List[str]: raise ValueError("Mismatching number of targets and names") self.description = description or "No description" - self._indices = np.arange(len(self.x_train)) - self._data_names = data_names if data_names is not None else self._indices + self._indices = np.arange(len(self.x_train), dtype=np.int_) + self._data_names = ( + np.array(data_names, dtype=object) + if data_names is not None + else self._indices.astype(object) + ) def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple: return self.x_train[idx], self.y_train[idx] @@ -220,7 +224,7 @@ def target(self, name: str) -> Tuple[slice, int]: raise ValueError(f"Target {name} is not in {self.target_names}") @property - def indices(self): + def indices(self) -> NDArray[np.int_]: """Index of positions in data.x_train. Contiguous integers from 0 to len(Dataset). @@ -228,7 +232,7 @@ def indices(self): return self._indices @property - def data_names(self): + def data_names(self) -> NDArray[np.object_]: """Names of each individual datapoint. Used for reporting Shapley values. @@ -236,9 +240,9 @@ def data_names(self): return self._data_names @property - def dim(self): + def dim(self) -> int: """Returns the number of dimensions of a sample.""" - return self.x_train.shape[1] if len(self.x_train.shape) > 1 else 1 + return int(self.x_train.shape[1]) if len(self.x_train.shape) > 1 else 1 def __str__(self): return self.description @@ -256,7 +260,7 @@ def from_sklearn( **kwargs, ) -> "Dataset": """Constructs a [Dataset][pydvl.utils.Dataset] object from a - [sklearn.utils.Bunch][sklearn.utils.Bunch], as returned by the `load_*` + [sklearn.utils.Bunch][], as returned by the `load_*` functions in [scikit-learn toy datasets](https://scikit-learn.org/stable/datasets/toy_dataset.html). ??? Example @@ -360,10 +364,10 @@ def from_arrays( class GroupedDataset(Dataset): def __init__( self, - x_train: np.ndarray, - y_train: np.ndarray, - x_test: np.ndarray, - y_test: np.ndarray, + x_train: NDArray, + y_train: NDArray, + x_test: NDArray, + y_test: NDArray, data_groups: Sequence, feature_names: Optional[Sequence[str]] = None, target_names: Optional[Sequence[str]] = None, @@ -423,7 +427,9 @@ def __init__( self.group_items = list(self.groups.items()) self._indices = np.arange(len(self.groups.keys())) self._data_names = ( - group_names if group_names is not None else list(self.groups.keys()) + np.array(group_names, dtype=object) + if group_names is not None + else np.array(list(self.groups.keys()), dtype=object) ) def __len__(self): diff --git a/src/pydvl/utils/types.py b/src/pydvl/utils/types.py index 5df91923d..e548dbfc4 100644 --- a/src/pydvl/utils/types.py +++ b/src/pydvl/utils/types.py @@ -6,11 +6,14 @@ from abc import ABCMeta from typing import Any, Optional, Protocol, TypeVar, Union, cast +import numpy as np from numpy.random import Generator, SeedSequence from numpy.typing import NDArray __all__ = [ "ensure_seed_sequence", + "IndexT", + "NameT", "MapFunction", "NoPublicConstructor", "ReduceFunction", @@ -18,7 +21,10 @@ "SupervisedModel", ] +IndexT = TypeVar("IndexT", bound=np.int_) +NameT = TypeVar("NameT", bound=np.object_) R = TypeVar("R", covariant=True) +Seed = Union[int, Generator] class MapFunction(Protocol[R]): @@ -74,9 +80,6 @@ def create(cls, *args: Any, **kwargs: Any): return super().__call__(*args, **kwargs) -Seed = Union[int, Generator] - - def ensure_seed_sequence( seed: Optional[Union[Seed, SeedSequence]] = None ) -> SeedSequence: diff --git a/src/pydvl/value/oob/oob.py b/src/pydvl/value/oob/oob.py index c62a6255d..617046355 100644 --- a/src/pydvl/value/oob/oob.py +++ b/src/pydvl/value/oob/oob.py @@ -70,7 +70,7 @@ def compute_data_oob( Object with the data values. """ - result: ValuationResult[np.int_, np.float_] = ValuationResult.empty( + result: ValuationResult[np.int_, np.object_] = ValuationResult.empty( algorithm="data_oob", indices=u.data.indices, data_names=u.data.data_names ) diff --git a/src/pydvl/value/result.py b/src/pydvl/value/result.py index 989d6d92e..04b922a75 100644 --- a/src/pydvl/value/result.py +++ b/src/pydvl/value/result.py @@ -57,8 +57,6 @@ Literal, Optional, Sequence, - Tuple, - TypeVar, Union, cast, overload, @@ -71,21 +69,17 @@ from pydvl.utils.dataset import Dataset from pydvl.utils.numeric import running_moments from pydvl.utils.status import Status -from pydvl.utils.types import Seed +from pydvl.utils.types import IndexT, NameT, Seed try: import pandas # Try to import here for the benefit of mypy except ImportError: pass -__all__ = ["ValuationResult", "ValueItem", "IndexT", "NameT"] +__all__ = ["ValuationResult", "ValueItem"] logger = logging.getLogger(__name__) -# TODO: Move to value.types once it's there -IndexT = TypeVar("IndexT", bound=np.int_) -NameT = TypeVar("NameT", bound=Any) - @total_ordering @dataclass @@ -484,7 +478,7 @@ def __repr__(self) -> str: repr_string += ")" return repr_string - def _check_compatible(self, other: "ValuationResult"): + def _check_compatible(self, other: ValuationResult): if not isinstance(other, ValuationResult): raise NotImplementedError( f"Cannot combine ValuationResult with {type(other)}" @@ -492,7 +486,9 @@ def _check_compatible(self, other: "ValuationResult"): if self.algorithm and self.algorithm != other.algorithm: raise ValueError("Cannot combine results from different algorithms") - def __add__(self, other: "ValuationResult") -> "ValuationResult": + def __add__( + self, other: ValuationResult[IndexT, NameT] + ) -> ValuationResult[IndexT, NameT]: """Adds two ValuationResults. The values must have been computed with the same algorithm. An exception @@ -601,7 +597,7 @@ def __add__(self, other: "ValuationResult") -> "ValuationResult": # extra_values=self._extra_values.update(other._extra_values), ) - def update(self, idx: int, new_value: float) -> "ValuationResult": + def update(self, idx: int, new_value: float) -> ValuationResult[IndexT, NameT]: """Updates the result in place with a new value, using running mean and variance. @@ -623,7 +619,7 @@ def update(self, idx: int, new_value: float) -> "ValuationResult": self._values[pos], self._variances[pos], self._counts[pos], new_value ) self[pos] = ValueItem( - index=cast(IndexT, idx), + index=cast(IndexT, idx), # FIXME name=self._names[pos], value=val, variance=var, @@ -738,7 +734,7 @@ def empty( indices: Optional[Sequence[IndexT] | NDArray[IndexT]] = None, data_names: Optional[Sequence[NameT] | NDArray[NameT]] = None, n_samples: int = 0, - ) -> "ValuationResult": + ) -> ValuationResult: """Creates an empty [ValuationResult][pydvl.value.result.ValuationResult] object. Empty results are characterised by having an empty array of values. When @@ -766,7 +762,7 @@ def zeros( indices: Optional[Sequence[IndexT] | NDArray[IndexT]] = None, data_names: Optional[Sequence[NameT] | NDArray[NameT]] = None, n_samples: int = 0, - ) -> "ValuationResult": + ) -> ValuationResult: """Creates an empty [ValuationResult][pydvl.value.result.ValuationResult] object. Empty results are characterised by having an empty array of values. When @@ -787,12 +783,12 @@ def zeros( if indices is None: indices = np.arange(n_samples, dtype=np.int_) else: - indices = np.array(indices) + indices = np.array(indices, dtype=np.int_) return cls( algorithm=algorithm, status=Status.Pending, indices=indices, - data_names=data_names + data_names=np.array(data_names, dtype=object) if data_names is not None else np.empty_like(indices, dtype=object), values=np.zeros(len(indices)), diff --git a/src/pydvl/value/sampler.py b/src/pydvl/value/sampler.py index 0e3e479e9..a51dbfb79 100644 --- a/src/pydvl/value/sampler.py +++ b/src/pydvl/value/sampler.py @@ -51,7 +51,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, overload, ) @@ -61,7 +60,7 @@ from numpy.typing import NDArray from pydvl.utils.numeric import powerset, random_subset, random_subset_of_size -from pydvl.utils.types import Seed +from pydvl.utils.types import IndexT, Seed __all__ = [ "AntitheticSampler", @@ -74,13 +73,11 @@ "StochasticSamplerMixin", ] - -T = TypeVar("T", bound=np.generic) -SampleT = Tuple[T, NDArray[T]] +SampleT = Tuple[IndexT, NDArray[IndexT]] Sequence.register(np.ndarray) -class PowersetSampler(abc.ABC, Iterable[SampleT], Generic[T]): +class PowersetSampler(abc.ABC, Iterable[SampleT], Generic[IndexT]): """Samplers are custom iterables over subsets of indices. Calling ``iter()`` on a sampler returns an iterator over tuples of the form @@ -121,9 +118,9 @@ class IndexIteration(Enum): def __init__( self, - indices: NDArray[T], + indices: NDArray[IndexT], index_iteration: IndexIteration = IndexIteration.Sequential, - outer_indices: NDArray[T] | None = None, + outer_indices: NDArray[IndexT] | None = None, ): """ Args: @@ -141,11 +138,11 @@ def __init__( self._n_samples = 0 @property - def indices(self) -> NDArray[T]: + def indices(self) -> NDArray[IndexT]: return self._indices @indices.setter - def indices(self, indices: NDArray[T]): + def indices(self, indices: NDArray[IndexT]): raise AttributeError("Cannot set indices of sampler") @property @@ -156,10 +153,10 @@ def n_samples(self) -> int: def n_samples(self, n: int): raise AttributeError("Cannot reset a sampler's number of samples") - def complement(self, exclude: Sequence[T]) -> NDArray[T]: - return np.setxor1d(self._indices, exclude) + def complement(self, exclude: Sequence[IndexT]) -> NDArray[IndexT]: + return np.setxor1d(self._indices, exclude) # type: ignore - def iterindices(self) -> Iterator[T]: + def iterindices(self) -> Iterator[IndexT]: """Iterates over indices in the order specified at construction. FIXME: this is probably not very useful, but I couldn't decide @@ -173,14 +170,14 @@ def iterindices(self) -> Iterator[T]: yield np.random.choice(self._outer_indices, size=1).item() @overload - def __getitem__(self, key: slice) -> PowersetSampler[T]: + def __getitem__(self, key: slice) -> PowersetSampler[IndexT]: ... @overload - def __getitem__(self, key: list[int]) -> PowersetSampler[T]: + def __getitem__(self, key: list[int]) -> PowersetSampler[IndexT]: ... - def __getitem__(self, key: slice | list[int]) -> PowersetSampler[T]: + def __getitem__(self, key: slice | list[int]) -> PowersetSampler[IndexT]: if isinstance(key, slice) or isinstance(key, Iterable): return self.__class__( self._indices, @@ -231,8 +228,8 @@ def __init__(self, *args, seed: Optional[Seed] = None, **kwargs): self._rng = np.random.default_rng(seed) -class DeterministicUniformSampler(PowersetSampler[T]): - def __init__(self, indices: NDArray[T], *args, **kwargs): +class DeterministicUniformSampler(PowersetSampler[IndexT]): + def __init__(self, indices: NDArray[IndexT], *args, **kwargs): """An iterator to perform uniform deterministic sampling of subsets. For every index $i$, each subset of the complement `indices - {i}` is @@ -268,7 +265,7 @@ def weight(cls, n: int, subset_len: int) -> float: return float(2 ** (n - 1)) if n > 0 else 1.0 -class UniformSampler(StochasticSamplerMixin, PowersetSampler[T]): +class UniformSampler(StochasticSamplerMixin, PowersetSampler[IndexT]): """An iterator to perform uniform random sampling of subsets. Iterating over every index $i$, either in sequence or at random depending on @@ -306,15 +303,15 @@ def weight(cls, n: int, subset_len: int) -> float: return float(2 ** (n - 1)) if n > 0 else 1.0 -class DeterministicCombinatorialSampler(DeterministicUniformSampler[T]): +class DeterministicCombinatorialSampler(DeterministicUniformSampler[IndexT]): @deprecated( target=DeterministicUniformSampler, deprecated_in="0.6.0", remove_in="0.8.0" ) - def __init__(self, indices: NDArray[T], *args, **kwargs): + def __init__(self, indices: NDArray[IndexT], *args, **kwargs): void(indices, args, kwargs) -class AntitheticSampler(StochasticSamplerMixin, PowersetSampler[T]): +class AntitheticSampler(StochasticSamplerMixin, PowersetSampler[IndexT]): """An iterator to perform uniform random sampling of subsets, and their complements. @@ -339,7 +336,7 @@ def weight(cls, n: int, subset_len: int) -> float: return float(2 ** (n - 1)) if n > 0 else 1.0 -class PermutationSampler(StochasticSamplerMixin, PowersetSampler[T]): +class PermutationSampler(StochasticSamplerMixin, PowersetSampler[IndexT]): """Sample permutations of indices and iterate through each returning increasing subsets, as required for the permutation definition of semi-values. @@ -365,7 +362,7 @@ def __iter__(self) -> Iterator[SampleT]: if self._n_samples == 0: # Empty index set break - def __getitem__(self, key: slice | list[int]) -> PowersetSampler[T]: + def __getitem__(self, key: slice | list[int]) -> PowersetSampler[IndexT]: """Permutation samplers cannot be split across indices, so we return a copy of the full sampler.""" return super().__getitem__(slice(None)) @@ -375,7 +372,7 @@ def weight(cls, n: int, subset_len: int) -> float: return n * math.comb(n - 1, subset_len) if n > 0 else 1.0 -class DeterministicPermutationSampler(PermutationSampler[T]): +class DeterministicPermutationSampler(PermutationSampler[IndexT]): """Samples all n! permutations of the indices deterministically, and iterates through them, returning sets as required for the permutation-based definition of semi-values. @@ -397,7 +394,7 @@ def __iter__(self) -> Iterator[SampleT]: self._n_samples += 1 -class RandomHierarchicalSampler(StochasticSamplerMixin, PowersetSampler[T]): +class RandomHierarchicalSampler(StochasticSamplerMixin, PowersetSampler[IndexT]): """For every index, sample a set size, then a set of that size. !!! Todo @@ -424,5 +421,8 @@ def weight(cls, n: int, subset_len: int) -> float: # TODO Replace by Intersection[StochasticSamplerMixin, PowersetSampler[T]] # See https://github.com/python/typing/issues/213 StochasticSampler = Union[ - UniformSampler, PermutationSampler, AntitheticSampler, RandomHierarchicalSampler + UniformSampler[IndexT], + PermutationSampler[IndexT], + AntitheticSampler[IndexT], + RandomHierarchicalSampler[IndexT], ] diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index 97a778ba9..69bf8e809 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -9,12 +9,14 @@ $$\sum_{k=1}^n w(k) = 1.$$ -!!! Note +??? Note For implementation consistency, we slightly depart from the common definition of semi-values, which includes a factor $1/n$ in the sum over subsets. Instead, we subsume this factor into the coefficient $w(k)$. -As such, the computation of a semi-value requires two components: +## Main components + +The computation of a semi-value requires two components: 1. A **subset sampler** that generates subsets of the set $D$ of interest. 2. A **coefficient** $w(k)$ that assigns a weight to each subset size $k$. @@ -44,6 +46,11 @@ require caching to be enabled or computation will be doubled wrt. a 'direct' implementation of permutation MC. +## Computing semi-values + +Samplers and coefficients can be arbitrarily mixed by means of the main entry +point of this module, +[compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues]. There are several pre-defined coefficients, including the Shapley value of (Ghorbani and Zou, 2019)[^1], the Banzhaf index of (Wang and Jia)[^3], and the Beta coefficient of (Kwon and Zou, 2022)[^2]. For each of these methods, there is a @@ -53,6 +60,16 @@ and [compute_beta_shapley_semivalues][pydvl.value.semivalues.compute_beta_shapley_semivalues]. instead. +!!! tip "Parallelization and batching" + In order to ensure reproducibility and fine-grained control of + parallelization, samples are generated in the main process and then + distributed to worker processes for evaluation. For small sample sizes, this + can lead to a significant overhead. To avoid this, we temporarily provide an + additional argument `batch_size` to all methods which can improve + performance with small models up to an order of magnitude. Note that this + argument will be removed before version 1.0 in favour of a more general + solution. + ## References @@ -72,17 +89,18 @@ import logging import math +import warnings from enum import Enum -from typing import Optional, Protocol, Tuple, Type, TypeVar, cast +from itertools import islice +from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast -import numpy as np import scipy as sp from deprecate import deprecated from tqdm import tqdm from pydvl.parallel.config import ParallelConfig from pydvl.utils import Utility -from pydvl.utils.types import Seed +from pydvl.utils.types import IndexT, Seed from pydvl.value import ValuationResult from pydvl.value.sampler import ( PermutationSampler, @@ -108,7 +126,8 @@ class SVCoefficient(Protocol): - """A coefficient for the computation of semi-values.""" + """The protocol that coefficients for the computation of semi-values must + fulfill.""" def __call__(self, n: int, k: int) -> float: """Computes the coefficient for a given subset size. @@ -120,27 +139,31 @@ def __call__(self, n: int, k: int) -> float: ... -IndexT = TypeVar("IndexT", bound=np.generic) MarginalT = Tuple[IndexT, float] -def _marginal(u: Utility, coefficient: SVCoefficient, sample: SampleT) -> MarginalT: +def _marginal( + u: Utility, coefficient: SVCoefficient, samples: Collection[SampleT] +) -> Tuple[MarginalT, ...]: """Computation of marginal utility. This is a helper function for [compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues]. Args: u: Utility object with model, data, and scoring function. coefficient: The semivalue coefficient and sampler weight - sample: A tuple of index and subset of indices to compute a marginal - utility. + samples: A collection of samples. Each sample is a tuple of index and subset of + indices to compute a marginal utility. Returns: - Tuple with index and its marginal utility. + A collection of marginals. Each marginal is a tuple with index and its marginal + utility. """ n = len(u.data) - idx, s = sample - marginal = (u({idx}.union(s)) - u(s)) * coefficient(n, len(s)) - return idx, marginal + marginals: List[MarginalT] = [] + for idx, s in samples: + marginal = (u({idx}.union(s)) - u(s)) * coefficient(n, len(s)) + marginals.append((idx, marginal)) + return tuple(marginals) # @deprecated( @@ -148,12 +171,21 @@ def _marginal(u: Utility, coefficient: SVCoefficient, sample: SampleT) -> Margin # deprecated_in="0.8.0", # remove_in="0.9.0", # ) +@deprecated( + target=True, + deprecated_in="0.7.0", + remove_in="0.9.0", + args_mapping={"batch_size": None}, + template_mgs="batch_size is for experimental use and will be removed" + "in future versions.", +) def compute_generic_semivalues( - sampler: PowersetSampler, + sampler: PowersetSampler[IndexT], u: Utility, coefficient: SVCoefficient, done: StoppingCriterion, *, + batch_size: int = 1, n_jobs: int = 1, config: ParallelConfig = ParallelConfig(), progress: bool = False, @@ -165,6 +197,7 @@ def compute_generic_semivalues( u: Utility object with model, data, and scoring function. coefficient: The semi-value coefficient done: Stopping criterion. + batch_size: Number of marginal evaluations per single parallel job. n_jobs: Number of parallel jobs to use. config: Object configuring parallel computation, with cluster address, number of cpus, etc. @@ -172,6 +205,10 @@ def compute_generic_semivalues( Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ from concurrent.futures import FIRST_COMPLETED, Future, wait @@ -183,6 +220,13 @@ def compute_generic_semivalues( "will be doubled wrt. a 'direct' implementation of permutation MC" ) + if batch_size != 1: + warnings.warn( + "Parameter `batch_size` is for experimental use and will be" + " removed in future versions", + DeprecationWarning, + ) + result = ValuationResult.zeros( algorithm=f"semivalue-{str(sampler)}-{coefficient.__name__}", # type: ignore indices=u.data.indices, @@ -211,20 +255,21 @@ def compute_generic_semivalues( completed, pending = wait(pending, timeout=1, return_when=FIRST_COMPLETED) for future in completed: - idx, marginal = future.result() - result.update(idx, marginal) - if done(result): - return result + for idx, marginal in future.result(): + result.update(idx, marginal) + if done(result): + return result # Ensure that we always have n_submitted_jobs running try: for _ in range(n_submitted_jobs - len(pending)): + samples = tuple(islice(sampler_it, batch_size)) + if len(samples) == 0: + raise StopIteration + pending.add( executor.submit( - _marginal, - u=u, - coefficient=correction, - sample=next(sampler_it), + _marginal, u=u, coefficient=correction, samples=samples ) ) except StopIteration: @@ -262,11 +307,20 @@ def beta_coefficient_w(n: int, k: int) -> float: return cast(SVCoefficient, beta_coefficient_w) +@deprecated( + target=True, + deprecated_in="0.7.0", + remove_in="0.9.0", + args_mapping={"batch_size": None}, + template_mgs="batch_size is for experimental use and will be removed" + "in future versions.", +) def compute_shapley_semivalues( u: Utility, *, done: StoppingCriterion = MaxUpdates(100), sampler_t: Type[StochasticSampler] = PermutationSampler, + batch_size: int = 1, n_jobs: int = 1, config: ParallelConfig = ParallelConfig(), progress: bool = False, @@ -283,33 +337,50 @@ def compute_shapley_semivalues( Args: u: Utility object with model, data, and scoring function. done: Stopping criterion. - sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` - for a list. + sampler_t: The sampler type to use. See the + [sampler][pydvl.value.sampler] module for a list. + batch_size: Number of marginal evaluations per single parallel job. n_jobs: Number of parallel jobs to use. config: Object configuring parallel computation, with cluster address, number of cpus, etc. - seed: Either an instance of a numpy random number generator or a seed for it. + seed: Either an instance of a numpy random number generator or a seed + for it. progress: Whether to display a progress bar. Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ - return compute_generic_semivalues( + # HACK: cannot infer return type because of useless IndexT, NameT + return compute_generic_semivalues( # type: ignore sampler_t(u.data.indices, seed=seed), u, shapley_coefficient, done, + batch_size=batch_size, n_jobs=n_jobs, config=config, progress=progress, ) +@deprecated( + target=True, + deprecated_in="0.7.0", + remove_in="0.9.0", + args_mapping={"batch_size": None}, + template_mgs="batch_size is for experimental use and will be removed" + "in future versions.", +) def compute_banzhaf_semivalues( u: Utility, *, done: StoppingCriterion = MaxUpdates(100), sampler_t: Type[StochasticSampler] = PermutationSampler, + batch_size: int = 1, n_jobs: int = 1, config: ParallelConfig = ParallelConfig(), progress: bool = False, @@ -324,28 +395,44 @@ def compute_banzhaf_semivalues( Args: u: Utility object with model, data, and scoring function. done: Stopping criterion. - sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a - list. + sampler_t: The sampler type to use. See the + [sampler][pydvl.value.sampler] module for a list. + batch_size: Number of marginal evaluations per single parallel job. n_jobs: Number of parallel jobs to use. - seed: Either an instance of a numpy random number generator or a seed for it. + seed: Either an instance of a numpy random number generator or a seed + for it. config: Object configuring parallel computation, with cluster address, number of cpus, etc. progress: Whether to display a progress bar. Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ - return compute_generic_semivalues( + # HACK: cannot infer return type because of useless IndexT, NameT + return compute_generic_semivalues( # type: ignore sampler_t(u.data.indices, seed=seed), u, banzhaf_coefficient, done, + batch_size=batch_size, n_jobs=n_jobs, config=config, progress=progress, ) +@deprecated( + target=True, + deprecated_in="0.7.0", + remove_in="0.9.0", + args_mapping={"batch_size": None}, + template_mgs="batch_size is for experimental use and will be removed" + "in future versions.", +) def compute_beta_shapley_semivalues( u: Utility, *, @@ -353,6 +440,7 @@ def compute_beta_shapley_semivalues( beta: float = 1, done: StoppingCriterion = MaxUpdates(100), sampler_t: Type[StochasticSampler] = PermutationSampler, + batch_size: int = 1, n_jobs: int = 1, config: ParallelConfig = ParallelConfig(), progress: bool = False, @@ -369,7 +457,9 @@ def compute_beta_shapley_semivalues( alpha: Alpha parameter of the Beta distribution. beta: Beta parameter of the Beta distribution. done: Stopping criterion. - sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a list. + sampler_t: The sampler type to use. See the + [sampler][pydvl.value.sampler] module for a list. + batch_size: Number of marginal evaluations per (parallelized) task. n_jobs: Number of parallel jobs to use. seed: Either an instance of a numpy random number generator or a seed for it. config: Object configuring parallel computation, with cluster address, number of @@ -378,23 +468,25 @@ def compute_beta_shapley_semivalues( Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ - return compute_generic_semivalues( + # HACK: cannot infer return type because of useless IndexT, NameT + return compute_generic_semivalues( # type: ignore sampler_t(u.data.indices, seed=seed), u, beta_coefficient(alpha, beta), done, + batch_size=batch_size, n_jobs=n_jobs, config=config, progress=progress, ) -@deprecated( - target=True, - deprecated_in="0.7.0", - remove_in="0.8.0", -) +@deprecated(target=True, deprecated_in="0.7.0", remove_in="0.8.0") class SemiValueMode(str, Enum): """Enumeration of semi-value modes. @@ -414,7 +506,8 @@ def compute_semivalues( *, done: StoppingCriterion = MaxUpdates(100), mode: SemiValueMode = SemiValueMode.Shapley, - sampler_t: Type[StochasticSampler] = PermutationSampler, + sampler_t: Type[StochasticSampler[IndexT]] = PermutationSampler[IndexT], + batch_size: int = 1, n_jobs: int = 1, seed: Optional[Seed] = None, **kwargs, @@ -458,6 +551,7 @@ def compute_semivalues( [SemiValueMode][pydvl.value.semivalues.SemiValueMode] for a list. sampler_t: The sampler type to use. See [sampler][pydvl.value.sampler] for a list. + batch_size: Number of marginal evaluations per (parallelized) task. n_jobs: Number of parallel jobs to use. seed: Either an instance of a numpy random number generator or a seed for it. kwargs: Additional keyword arguments passed to @@ -465,6 +559,10 @@ def compute_semivalues( Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ if mode == SemiValueMode.Shapley: coefficient = shapley_coefficient @@ -477,11 +575,14 @@ def compute_semivalues( else: raise ValueError(f"Unknown mode {mode}") coefficient = cast(SVCoefficient, coefficient) - return compute_generic_semivalues( + + # HACK: cannot infer return type because of useless IndexT, NameT + return compute_generic_semivalues( # type: ignore sampler_t(u.data.indices, seed=seed), u, coefficient, done, n_jobs=n_jobs, + batch_size=batch_size, **kwargs, ) diff --git a/src/pydvl/value/shapley/montecarlo.py b/src/pydvl/value/shapley/montecarlo.py index 1046326a5..af0f67ce2 100644 --- a/src/pydvl/value/shapley/montecarlo.py +++ b/src/pydvl/value/shapley/montecarlo.py @@ -201,7 +201,9 @@ def permutation_montecarlo_shapley( n_submitted_jobs = 2 * max_workers # number of jobs in the executor's queue seed_sequence = ensure_seed_sequence(seed) - result = ValuationResult.zeros(algorithm=algorithm) + result = ValuationResult.zeros( + algorithm=algorithm, indices=u.data.indices, data_names=u.data.data_names + ) pbar = tqdm(disable=not progress, total=100, unit="%") diff --git a/tests/value/shapley/test_montecarlo.py b/tests/value/shapley/test_montecarlo.py index 3024ed198..1c9fbf4a7 100644 --- a/tests/value/shapley/test_montecarlo.py +++ b/tests/value/shapley/test_montecarlo.py @@ -17,7 +17,7 @@ from .. import check_rank_correlation, check_total_value, check_values from ..conftest import polynomial_dataset -from ..utils import call_fn_multiple_seeds +from ..utils import call_with_seeds log = logging.getLogger(__name__) @@ -94,7 +94,7 @@ def test_montecarlo_shapley_housing_dataset_reproducible( kwargs: dict, seed: Seed, ): - values_1, values_2 = call_fn_multiple_seeds( + values_1, values_2 = call_with_seeds( compute_shapley_values, Utility(LinearRegression(), data=housing_dataset, scorer="r2"), mode=fun, @@ -121,7 +121,7 @@ def test_montecarlo_shapley_housing_dataset_stochastic( seed: Seed, seed_alt: Seed, ): - values_1, values_2 = call_fn_multiple_seeds( + values_1, values_2 = call_with_seeds( compute_shapley_values, Utility(LinearRegression(), data=housing_dataset, scorer="r2"), mode=fun, diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py index ea10dd339..0a71bcd54 100644 --- a/tests/value/test_semivalues.py +++ b/tests/value/test_semivalues.py @@ -5,6 +5,7 @@ import pytest from pydvl.parallel.config import ParallelConfig +from pydvl.utils.types import Seed from pydvl.value.sampler import ( AntitheticSampler, DeterministicPermutationSampler, @@ -23,6 +24,7 @@ from pydvl.value.stopping import AbsoluteStandardError, MaxUpdates from . import check_values +from .utils import timed @pytest.mark.parametrize("num_samples", [5]) @@ -58,6 +60,49 @@ def test_shapley( check_values(values, exact_values, rtol=0.2) +@pytest.mark.parametrize( + "num_samples,sampler,coefficient,batch_size", + [(5, PermutationSampler, beta_coefficient(1, 1), 5)], +) +def test_shapley_batch_size( + num_samples: int, + analytic_shapley, + sampler: Type[PermutationSampler], + coefficient: SVCoefficient, + batch_size: int, + n_jobs: int, + parallel_config: ParallelConfig, + seed: Seed, +): + u, exact_values = analytic_shapley + criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2)) + timed_fn = timed(compute_generic_semivalues) + result_single_batch = timed_fn( + sampler(u.data.indices, seed=seed), + u, + coefficient, + criterion, + n_jobs=n_jobs, + batch_size=1, + config=parallel_config, + ) + total_seconds_single_batch = timed_fn.execution_time + result_multi_batch = timed_fn( + sampler(u.data.indices, seed=seed), + u, + coefficient, + criterion, + n_jobs=n_jobs, + batch_size=batch_size, + config=parallel_config, + ) + total_seconds_multi_batch = timed_fn.execution_time + assert total_seconds_multi_batch < total_seconds_single_batch * 1.1 + + # Occasionally, batch_2 arrives before batch_1, so rtol isn't always 0. + check_values(result_single_batch, result_multi_batch, rtol=1e-4) + + @pytest.mark.parametrize("num_samples", [5]) @pytest.mark.parametrize( "sampler", diff --git a/tests/value/utils.py b/tests/value/utils.py index 7c38e344f..c55ab13eb 100644 --- a/tests/value/utils.py +++ b/tests/value/utils.py @@ -1,20 +1,25 @@ from __future__ import annotations +import time from copy import deepcopy -from typing import Callable, Tuple +from functools import wraps +from logging import getLogger +from typing import Callable, Optional, Protocol, Tuple, TypeVar from pydvl.utils.types import Seed +logger = getLogger(__name__) -def call_fn_multiple_seeds( - fn: Callable, *args, seeds: Tuple[Seed, ...], **kwargs -) -> Tuple: +ReturnT = TypeVar("ReturnT") + + +def call_with_seeds(fun: Callable, *args, seeds: Tuple[Seed, ...], **kwargs) -> Tuple: """ Execute a function multiple times with different seeds. It copies the arguments and keyword arguments before passing them to the function. Args: - fn: The function to execute. + fun: The function to execute. args: The arguments to pass to the function. seeds: The seeds to use. kwargs: The keyword arguments to pass to the function. @@ -22,4 +27,43 @@ def call_fn_multiple_seeds( Returns: A tuple of the results of the function. """ - return tuple(fn(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds) + return tuple(fun(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds) + + +class TimedCallable(Protocol): + """A callable that has an attribute to keep track of execution time.""" + + execution_time: float + + def __call__(self, *args, **kwargs) -> ReturnT: + ... + + +def timed(fun: Callable[..., ReturnT]) -> TimedCallable: + """ + Takes a function `func` and returns a function with the same input arguments and + the original return value along with the execution time. + + Args: + fun: The function to be measured, accepting arbitrary arguments and returning + any type. + + Returns: + A wrapped function that, when called, returns a tuple containing the original + function's result and its execution time in seconds. The decorated function + will have the same input arguments and return type as the original function. + """ + + wrapper: TimedCallable + + @wraps(fun) + def wrapper(*args, **kwargs) -> ReturnT: + start_time = time.time() + result = fun(*args, **kwargs) + end_time = time.time() + wrapper.execution_time = end_time - start_time + logger.info(f"{fun.__name__} took {wrapper.execution_time:.5f} seconds.") + return result + + wrapper.execution_time = 0.0 + return wrapper