diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc82e515b..67d5f8ff7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,10 @@
   randomness. `pydvl.value.semivalues`. Introduced new type `Seed` and
   conversion function `ensure_seed_sequence`.
   [PR #396](https://github.com/aai-institute/pyDVL/pull/396)
+- Added `batch_size` parameter to `compute_banzhaf_semivalues`,
+  `compute_beta_shapley_semivalues`, `compute_shapley_semivalues` and
+  `compute_generic_semivalues`.
+  [PR #428](https://github.com/aai-institute/pyDVL/pull/428)
 
 ### Changed
 
@@ -240,3 +244,4 @@ It contains:
 - Parallelization of computations with Ray
 - Documentation
 - Notebooks containing examples of different use cases
+
diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py
index 488a25037..65a84c0dc 100644
--- a/src/pydvl/value/semivalues.py
+++ b/src/pydvl/value/semivalues.py
@@ -73,7 +73,8 @@
 import logging
 import math
 from enum import Enum
-from typing import Optional, Protocol, Tuple, Type, TypeVar, cast
+from itertools import islice
+from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast
 
 import numpy as np
 import scipy as sp
@@ -123,23 +124,28 @@ def __call__(self, n: int, k: int) -> float:
 MarginalT = Tuple[IndexT, float]
 
 
-def _marginal(u: Utility, coefficient: SVCoefficient, sample: SampleT) -> MarginalT:
+def _marginal(
+    u: Utility, coefficient: SVCoefficient, samples: Collection[SampleT]
+) -> Tuple[MarginalT, ...]:
     """Computation of marginal utility. This is a helper function for
     [compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues].
 
     Args:
         u: Utility object with model, data, and scoring function.
         coefficient: The semivalue coefficient and sampler weight
-        sample: A tuple of index and subset of indices to compute a marginal
-            utility.
+        samples: A collection of samples. Each sample is a tuple of index and subset of
+            indices to compute a marginal utility.
 
     Returns:
-        Tuple with index and its marginal utility.
+        A collection of marginals. Each marginal is a tuple with index and its marginal
+            utility.
     """
     n = len(u.data)
-    idx, s = sample
-    marginal = (u({idx}.union(s)) - u(s)) * coefficient(n, len(s))
-    return idx, marginal
+    marginals: List[MarginalT] = []
+    for idx, s in samples:
+        marginal = (u({idx}.union(s)) - u(s)) * coefficient(n, len(s))
+        marginals.append((idx, marginal))
+    return tuple(marginals)
 
 
 # @deprecated(
@@ -153,6 +159,7 @@ def compute_generic_semivalues(
     coefficient: SVCoefficient,
     done: StoppingCriterion,
     *,
+    batch_size: int = 1,
     n_jobs: int = 1,
     config: ParallelConfig = ParallelConfig(),
     progress: bool = False,
@@ -164,6 +171,7 @@ def compute_generic_semivalues(
         u: Utility object with model, data, and scoring function.
         coefficient: The semi-value coefficient
        done: Stopping criterion.
+        batch_size: Number of marginal evaluations per (parallelized) task.
         n_jobs: Number of parallel jobs to use.
         config: Object configuring parallel computation, with cluster address,
             number of cpus, etc.
@@ -210,20 +218,24 @@ def compute_generic_semivalues(
 
             completed, pending = wait(pending, timeout=1, return_when=FIRST_COMPLETED)
             for future in completed:
-                idx, marginal = future.result()
-                result.update(idx, marginal)
-                if done(result):
-                    return result
+                for idx, marginal in future.result():
+                    result.update(idx, marginal)
+                    if done(result):
+                        return result
 
             # Ensure that we always have n_submitted_jobs running
             try:
                 for _ in range(n_submitted_jobs - len(pending)):
+                    samples = tuple(islice(sampler_it, batch_size))
+                    if len(samples) == 0:
+                        raise StopIteration
+
                     pending.add(
                         executor.submit(
                             _marginal,
                             u=u,
                             coefficient=correction,
-                            sample=next(sampler_it),
+                            samples=samples,
                         )
                     )
             except StopIteration:
@@ -266,6 +278,7 @@ def compute_shapley_semivalues(
     *,
     done: StoppingCriterion = MaxUpdates(100),
     sampler_t: Type[StochasticSampler] = PermutationSampler,
+    batch_size: int = 1,
     n_jobs: int = 1,
     config: ParallelConfig = ParallelConfig(),
     progress: bool = False,
@@ -284,6 +297,7 @@ def compute_shapley_semivalues(
         done: Stopping criterion.
         sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler`
             for a list.
+        batch_size: Number of marginal evaluations per (parallelized) task.
         n_jobs: Number of parallel jobs to use.
         config: Object configuring parallel computation, with cluster address,
             number of cpus, etc.
@@ -298,6 +312,7 @@ def compute_shapley_semivalues(
         u,
         shapley_coefficient,
         done,
+        batch_size=batch_size,
         n_jobs=n_jobs,
         config=config,
         progress=progress,
@@ -309,6 +324,7 @@ def compute_banzhaf_semivalues(
     *,
     done: StoppingCriterion = MaxUpdates(100),
     sampler_t: Type[StochasticSampler] = PermutationSampler,
+    batch_size: int = 1,
     n_jobs: int = 1,
     config: ParallelConfig = ParallelConfig(),
     progress: bool = False,
@@ -325,6 +341,7 @@ def compute_banzhaf_semivalues(
         done: Stopping criterion.
         sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler`
             for a list.
+        batch_size: Number of marginal evaluations per (parallelized) task.
         n_jobs: Number of parallel jobs to use.
         seed: Either an instance of a numpy random number generator or a seed for it.
         config: Object configuring parallel computation, with cluster address,
@@ -339,6 +356,7 @@ def compute_banzhaf_semivalues(
         u,
         banzhaf_coefficient,
         done,
+        batch_size=batch_size,
         n_jobs=n_jobs,
         config=config,
         progress=progress,
@@ -352,6 +370,7 @@ def compute_beta_shapley_semivalues(
     beta: float = 1,
     done: StoppingCriterion = MaxUpdates(100),
     sampler_t: Type[StochasticSampler] = PermutationSampler,
+    batch_size: int = 1,
     n_jobs: int = 1,
     config: ParallelConfig = ParallelConfig(),
     progress: bool = False,
@@ -369,6 +388,7 @@ def compute_beta_shapley_semivalues(
         beta: Beta parameter of the Beta distribution.
         done: Stopping criterion.
         sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a list.
+        batch_size: Number of marginal evaluations per (parallelized) task.
         n_jobs: Number of parallel jobs to use.
         seed: Either an instance of a numpy random number generator or a seed for it.
         config: Object configuring parallel computation, with cluster address, number of
@@ -383,6 +403,7 @@ def compute_beta_shapley_semivalues(
         u,
         beta_coefficient(alpha, beta),
         done,
+        batch_size=batch_size,
         n_jobs=n_jobs,
         config=config,
         progress=progress,
diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py
index ec937d028..474a2b7c1 100644
--- a/tests/value/test_semivalues.py
+++ b/tests/value/test_semivalues.py
@@ -58,6 +58,33 @@ def test_shapley(
     check_values(values, exact_values, rtol=0.2)
 
 
+@pytest.mark.parametrize(
+    "num_samples,sampler,coefficient,batch_size",
+    [(5, PermutationSampler, beta_coefficient(1, 1), 2)],
+)
+def test_shapley_batch_size(
+    num_samples: int,
+    analytic_shapley,
+    sampler: Type[PowersetSampler],
+    coefficient: SVCoefficient,
+    batch_size: int,
+    n_jobs: int,
+    parallel_config: ParallelConfig,
+):
+    u, exact_values = analytic_shapley
+    criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2))
+    values = compute_generic_semivalues(
+        sampler(u.data.indices),
+        u,
+        coefficient,
+        criterion,
+        n_jobs=n_jobs,
+        batch_size=batch_size,
+        config=parallel_config,
+    )
+    check_values(values, exact_values, rtol=0.2)
+
+
 @pytest.mark.parametrize("num_samples", [5])
 @pytest.mark.parametrize(
     "sampler",
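For readers skimming the diff: the new batch_size parameter makes each parallel task evaluate several sampled subsets instead of one, amortizing task-scheduling overhead when a single utility evaluation is cheap. Below is a minimal usage sketch, not part of the PR; it assumes a scikit-learn classifier and the pyDVL 0.7.x import layout (Dataset/Utility from pydvl.utils, compute_shapley_semivalues and MaxUpdates re-exported from pydvl.value), which may differ in other versions.

# Usage sketch (hypothetical setup), showing the batch_size keyword added in this PR.
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

from pydvl.utils import Dataset, Utility
from pydvl.value import MaxUpdates, compute_shapley_semivalues

# A small dataset and a Utility with the model's default scorer.
data = Dataset.from_sklearn(load_iris(), train_size=0.8)
u = Utility(LogisticRegression(max_iter=500), data)

# batch_size groups several marginal evaluations into one submitted task;
# the computed values are the same as with the default batch_size=1.
values = compute_shapley_semivalues(
    u,
    done=MaxUpdates(500),
    batch_size=8,
    n_jobs=4,
)
print(values)

Note that batching only changes how work is shipped to the executor: with the default batch_size=1 the submission pattern matches the previous per-sample behavior, so existing callers are unaffected.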