From 0914b6666cdfb295afc0128d7ba38887c9e99332 Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Wed, 6 Sep 2023 01:04:11 +0200 Subject: [PATCH 01/20] Added `batch_size` parameter to `compute_banzhaf_semivalues`, `compute_beta_shapley_semivalues`, `compute_shapley_semivalues` and `compute_generic_semivalues`. --- CHANGELOG.md | 5 ++++ src/pydvl/value/semivalues.py | 47 ++++++++++++++++++++++++---------- tests/value/test_semivalues.py | 27 +++++++++++++++++++ 3 files changed, 66 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc82e515b..67d5f8ff7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,10 @@ randomness. `pydvl.value.semivalues`. Introduced new type `Seed` and conversion function `ensure_seed_sequence`. [PR #396](https://github.com/aai-institute/pyDVL/pull/396) +- Added `batch_size` parameter to `compute_banzhaf_semivalues`, + `compute_beta_shapley_semivalues`, `compute_shapley_semivalues` and + `compute_generic_semivalues`. + [PR #428](https://github.com/aai-institute/pyDVL/pull/428) ### Changed @@ -240,3 +244,4 @@ It contains: - Parallelization of computations with Ray - Documentation - Notebooks containing examples of different use cases + diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index 488a25037..65a84c0dc 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -73,7 +73,8 @@ import logging import math from enum import Enum -from typing import Optional, Protocol, Tuple, Type, TypeVar, cast +from itertools import islice +from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast import numpy as np import scipy as sp @@ -123,23 +124,28 @@ def __call__(self, n: int, k: int) -> float: MarginalT = Tuple[IndexT, float] -def _marginal(u: Utility, coefficient: SVCoefficient, sample: SampleT) -> MarginalT: +def _marginal( + u: Utility, coefficient: SVCoefficient, samples: Collection[SampleT] +) -> Tuple[MarginalT, ...]: """Computation of marginal utility. This is a helper function for [compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues]. Args: u: Utility object with model, data, and scoring function. coefficient: The semivalue coefficient and sampler weight - sample: A tuple of index and subset of indices to compute a marginal - utility. + samples: A collection of samples. Each sample is a tuple of index and subset of + indices to compute a marginal utility. Returns: - Tuple with index and its marginal utility. + A collection of marginals. Each marginal is a tuple with index and its marginal + utility. """ n = len(u.data) - idx, s = sample - marginal = (u({idx}.union(s)) - u(s)) * coefficient(n, len(s)) - return idx, marginal + marginals: List[MarginalT] = [] + for idx, s in samples: + marginal = (u({idx}.union(s)) - u(s)) * coefficient(n, len(s)) + marginals.append((idx, marginal)) + return tuple(marginals) # @deprecated( @@ -153,6 +159,7 @@ def compute_generic_semivalues( coefficient: SVCoefficient, done: StoppingCriterion, *, + batch_size: int = 1, n_jobs: int = 1, config: ParallelConfig = ParallelConfig(), progress: bool = False, @@ -164,6 +171,7 @@ def compute_generic_semivalues( u: Utility object with model, data, and scoring function. coefficient: The semi-value coefficient done: Stopping criterion. + batch_size: Number of marginal evaluations per (parallelized) task. n_jobs: Number of parallel jobs to use. config: Object configuring parallel computation, with cluster address, number of cpus, etc. @@ -210,20 +218,24 @@ def compute_generic_semivalues( completed, pending = wait(pending, timeout=1, return_when=FIRST_COMPLETED) for future in completed: - idx, marginal = future.result() - result.update(idx, marginal) - if done(result): - return result + for idx, marginal in future.result(): + result.update(idx, marginal) + if done(result): + return result # Ensure that we always have n_submitted_jobs running try: for _ in range(n_submitted_jobs - len(pending)): + samples = tuple(islice(sampler_it, batch_size)) + if len(samples) == 0: + raise StopIteration + pending.add( executor.submit( _marginal, u=u, coefficient=correction, - sample=next(sampler_it), + samples=samples, ) ) except StopIteration: @@ -266,6 +278,7 @@ def compute_shapley_semivalues( *, done: StoppingCriterion = MaxUpdates(100), sampler_t: Type[StochasticSampler] = PermutationSampler, + batch_size: int = 1, n_jobs: int = 1, config: ParallelConfig = ParallelConfig(), progress: bool = False, @@ -284,6 +297,7 @@ def compute_shapley_semivalues( done: Stopping criterion. sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a list. + batch_size: Number of marginal evaluations per (parallelized) task. n_jobs: Number of parallel jobs to use. config: Object configuring parallel computation, with cluster address, number of cpus, etc. @@ -298,6 +312,7 @@ def compute_shapley_semivalues( u, shapley_coefficient, done, + batch_size=batch_size, n_jobs=n_jobs, config=config, progress=progress, @@ -309,6 +324,7 @@ def compute_banzhaf_semivalues( *, done: StoppingCriterion = MaxUpdates(100), sampler_t: Type[StochasticSampler] = PermutationSampler, + batch_size: int = 1, n_jobs: int = 1, config: ParallelConfig = ParallelConfig(), progress: bool = False, @@ -325,6 +341,7 @@ def compute_banzhaf_semivalues( done: Stopping criterion. sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a list. + batch_size: Number of marginal evaluations per (parallelized) task. n_jobs: Number of parallel jobs to use. seed: Either an instance of a numpy random number generator or a seed for it. config: Object configuring parallel computation, with cluster address, @@ -339,6 +356,7 @@ def compute_banzhaf_semivalues( u, banzhaf_coefficient, done, + batch_size=batch_size, n_jobs=n_jobs, config=config, progress=progress, @@ -352,6 +370,7 @@ def compute_beta_shapley_semivalues( beta: float = 1, done: StoppingCriterion = MaxUpdates(100), sampler_t: Type[StochasticSampler] = PermutationSampler, + batch_size: int = 1, n_jobs: int = 1, config: ParallelConfig = ParallelConfig(), progress: bool = False, @@ -369,6 +388,7 @@ def compute_beta_shapley_semivalues( beta: Beta parameter of the Beta distribution. done: Stopping criterion. sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a list. + batch_size: Number of marginal evaluations per (parallelized) task. n_jobs: Number of parallel jobs to use. seed: Either an instance of a numpy random number generator or a seed for it. config: Object configuring parallel computation, with cluster address, number of @@ -383,6 +403,7 @@ def compute_beta_shapley_semivalues( u, beta_coefficient(alpha, beta), done, + batch_size=batch_size, n_jobs=n_jobs, config=config, progress=progress, diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py index ec937d028..474a2b7c1 100644 --- a/tests/value/test_semivalues.py +++ b/tests/value/test_semivalues.py @@ -58,6 +58,33 @@ def test_shapley( check_values(values, exact_values, rtol=0.2) +@pytest.mark.parametrize( + "num_samples,sampler,coefficient,batch_size", + [(5, PermutationSampler, beta_coefficient(1, 1), 2)], +) +def test_shapley_batch_size( + num_samples: int, + analytic_shapley, + sampler: Type[PowersetSampler], + coefficient: SVCoefficient, + batch_size: int, + n_jobs: int, + parallel_config: ParallelConfig, +): + u, exact_values = analytic_shapley + criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2)) + values = compute_generic_semivalues( + sampler(u.data.indices), + u, + coefficient, + criterion, + n_jobs=n_jobs, + batch_size=batch_size, + config=parallel_config, + ) + check_values(values, exact_values, rtol=0.2) + + @pytest.mark.parametrize("num_samples", [5]) @pytest.mark.parametrize( "sampler", From df5d67d98369e36029c7f1e0e16ba4a8b38c44c9 Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Wed, 6 Sep 2023 02:21:34 +0200 Subject: [PATCH 02/20] Add function `measure_execution_time`. --- tests/value/utils.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/value/utils.py b/tests/value/utils.py index 7c38e344f..6681a5835 100644 --- a/tests/value/utils.py +++ b/tests/value/utils.py @@ -1,10 +1,17 @@ from __future__ import annotations +import time from copy import deepcopy -from typing import Callable, Tuple +from functools import wraps +from logging import getLogger +from typing import Callable, Optional, Tuple, TypeVar from pydvl.utils.types import Seed +logger = getLogger(__name__) + +ReturnType = TypeVar("ReturnType") + def call_fn_multiple_seeds( fn: Callable, *args, seeds: Tuple[Seed, ...], **kwargs @@ -23,3 +30,37 @@ def call_fn_multiple_seeds( A tuple of the results of the function. """ return tuple(fn(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds) + + +def measure_execution_time( + func: Callable[..., ReturnType] +) -> Callable[..., Tuple[Optional[ReturnType], float]]: + """ + Takes a function `func` and returns a function with the same input arguments and + the original return value along with the execution time. + + Args: + func: The function to be measured, accepting arbitrary arguments and returning + any type. + + Returns: + A wrapped function that, when called, returns a tuple containing the original + function's result and its execution time in seconds. The decorated function + will have the same input arguments and return type as the original function. + """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Tuple[Optional[ReturnType], float]: + result = None + start_time = time.time() + try: + result = func(*args, **kwargs) + except Exception as e: + logger.error(f"Error in {func.__name__}: {e}") + finally: + end_time = time.time() + execution_time = end_time - start_time + logger.info(f"{func.__name__} took {execution_time:.5f} seconds.") + return result, execution_time + + return wrapper From c0817951e8eb3b9584e7cf8e4787141e237781b5 Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Wed, 6 Sep 2023 02:22:20 +0200 Subject: [PATCH 03/20] Adapt test case to check if a small batch size has a performance improvement compared to a singular batch size. --- tests/value/test_semivalues.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py index 474a2b7c1..ceffeb7bb 100644 --- a/tests/value/test_semivalues.py +++ b/tests/value/test_semivalues.py @@ -1,10 +1,11 @@ import math +import time from typing import Type import numpy as np import pytest -from pydvl.utils import ParallelConfig +from pydvl.utils import ParallelConfig, Seed from pydvl.value.sampler import ( AntitheticSampler, DeterministicPermutationSampler, @@ -23,6 +24,7 @@ from pydvl.value.stopping import AbsoluteStandardError, MaxUpdates from . import check_values +from .utils import measure_execution_time @pytest.mark.parametrize("num_samples", [5]) @@ -60,21 +62,35 @@ def test_shapley( @pytest.mark.parametrize( "num_samples,sampler,coefficient,batch_size", - [(5, PermutationSampler, beta_coefficient(1, 1), 2)], + [(5, PermutationSampler, beta_coefficient(1, 1), 5)], ) def test_shapley_batch_size( num_samples: int, analytic_shapley, - sampler: Type[PowersetSampler], + sampler: Type[PermutationSampler], coefficient: SVCoefficient, batch_size: int, n_jobs: int, parallel_config: ParallelConfig, + seed: Seed, ): u, exact_values = analytic_shapley criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2)) - values = compute_generic_semivalues( - sampler(u.data.indices), + result_single_batch, total_seconds_single_batch = measure_execution_time( + compute_generic_semivalues + )( + sampler(u.data.indices, seed=seed), + u, + coefficient, + criterion, + n_jobs=n_jobs, + batch_size=1, + config=parallel_config, + ) + result_multi_batch, total_seconds_multi_batch = measure_execution_time( + compute_generic_semivalues + )( + sampler(u.data.indices, seed=seed), u, coefficient, criterion, @@ -82,7 +98,8 @@ def test_shapley_batch_size( batch_size=batch_size, config=parallel_config, ) - check_values(values, exact_values, rtol=0.2) + assert total_seconds_multi_batch < total_seconds_single_batch + check_values(result_single_batch, result_multi_batch, rtol=0.0, atol=0.0) @pytest.mark.parametrize("num_samples", [5]) From 9083bc2fc1fc7ed457b6cb618073dec03807fe7f Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Mon, 11 Sep 2023 02:40:07 +0200 Subject: [PATCH 04/20] Fix comments. --- src/pydvl/value/semivalues.py | 13 +++++-- tests/value/shapley/test_montecarlo.py | 6 ++-- tests/value/test_semivalues.py | 15 ++++---- tests/value/utils.py | 49 ++++++++++++++------------ 4 files changed, 46 insertions(+), 37 deletions(-) diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index 68766cc3b..348686e8e 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -72,6 +72,7 @@ import logging import math +import warnings from enum import Enum from itertools import islice from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast @@ -172,7 +173,7 @@ def compute_generic_semivalues( u: Utility object with model, data, and scoring function. coefficient: The semi-value coefficient done: Stopping criterion. - batch_size: Number of marginal evaluations per (parallelized) task. + batch_size: Number of marginal evaluations per single parallel job. n_jobs: Number of parallel jobs to use. config: Object configuring parallel computation, with cluster address, number of cpus, etc. @@ -191,6 +192,12 @@ def compute_generic_semivalues( "will be doubled wrt. a 'direct' implementation of permutation MC" ) + if batch_size != 1: + warnings.warn( + "batch_size is deprecated and will be removed in future versions.", + DeprecationWarning, + ) + result = ValuationResult.zeros( algorithm=f"semivalue-{str(sampler)}-{coefficient.__name__}", # type: ignore indices=u.data.indices, @@ -298,7 +305,7 @@ def compute_shapley_semivalues( done: Stopping criterion. sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a list. - batch_size: Number of marginal evaluations per (parallelized) task. + batch_size: Number of marginal evaluations per single parallel job. n_jobs: Number of parallel jobs to use. config: Object configuring parallel computation, with cluster address, number of cpus, etc. @@ -342,7 +349,7 @@ def compute_banzhaf_semivalues( done: Stopping criterion. sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a list. - batch_size: Number of marginal evaluations per (parallelized) task. + batch_size: Number of marginal evaluations per single parallel job. n_jobs: Number of parallel jobs to use. seed: Either an instance of a numpy random number generator or a seed for it. config: Object configuring parallel computation, with cluster address, diff --git a/tests/value/shapley/test_montecarlo.py b/tests/value/shapley/test_montecarlo.py index 3024ed198..1c9fbf4a7 100644 --- a/tests/value/shapley/test_montecarlo.py +++ b/tests/value/shapley/test_montecarlo.py @@ -17,7 +17,7 @@ from .. import check_rank_correlation, check_total_value, check_values from ..conftest import polynomial_dataset -from ..utils import call_fn_multiple_seeds +from ..utils import call_with_seeds log = logging.getLogger(__name__) @@ -94,7 +94,7 @@ def test_montecarlo_shapley_housing_dataset_reproducible( kwargs: dict, seed: Seed, ): - values_1, values_2 = call_fn_multiple_seeds( + values_1, values_2 = call_with_seeds( compute_shapley_values, Utility(LinearRegression(), data=housing_dataset, scorer="r2"), mode=fun, @@ -121,7 +121,7 @@ def test_montecarlo_shapley_housing_dataset_stochastic( seed: Seed, seed_alt: Seed, ): - values_1, values_2 = call_fn_multiple_seeds( + values_1, values_2 = call_with_seeds( compute_shapley_values, Utility(LinearRegression(), data=housing_dataset, scorer="r2"), mode=fun, diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py index 825170802..54db72301 100644 --- a/tests/value/test_semivalues.py +++ b/tests/value/test_semivalues.py @@ -24,7 +24,7 @@ from pydvl.value.stopping import AbsoluteStandardError, MaxUpdates from . import check_values -from .utils import measure_execution_time +from .utils import timed @pytest.mark.parametrize("num_samples", [5]) @@ -76,9 +76,8 @@ def test_shapley_batch_size( ): u, exact_values = analytic_shapley criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2)) - result_single_batch, total_seconds_single_batch = measure_execution_time( - compute_generic_semivalues - )( + timed_fn = timed(compute_generic_semivalues) + result_single_batch = timed_fn( sampler(u.data.indices, seed=seed), u, coefficient, @@ -87,9 +86,8 @@ def test_shapley_batch_size( batch_size=1, config=parallel_config, ) - result_multi_batch, total_seconds_multi_batch = measure_execution_time( - compute_generic_semivalues - )( + total_seconds_single_batch = timed_fn.execution_time + result_multi_batch = timed_fn( sampler(u.data.indices, seed=seed), u, coefficient, @@ -98,8 +96,9 @@ def test_shapley_batch_size( batch_size=batch_size, config=parallel_config, ) + total_seconds_multi_batch = timed_fn.execution_time assert total_seconds_multi_batch < total_seconds_single_batch - check_values(result_single_batch, result_multi_batch, rtol=0.0, atol=0.0) + check_values(result_single_batch, result_multi_batch, rtol=1e-4) @pytest.mark.parametrize("num_samples", [5]) diff --git a/tests/value/utils.py b/tests/value/utils.py index 6681a5835..c55ab13eb 100644 --- a/tests/value/utils.py +++ b/tests/value/utils.py @@ -4,24 +4,22 @@ from copy import deepcopy from functools import wraps from logging import getLogger -from typing import Callable, Optional, Tuple, TypeVar +from typing import Callable, Optional, Protocol, Tuple, TypeVar from pydvl.utils.types import Seed logger = getLogger(__name__) -ReturnType = TypeVar("ReturnType") +ReturnT = TypeVar("ReturnT") -def call_fn_multiple_seeds( - fn: Callable, *args, seeds: Tuple[Seed, ...], **kwargs -) -> Tuple: +def call_with_seeds(fun: Callable, *args, seeds: Tuple[Seed, ...], **kwargs) -> Tuple: """ Execute a function multiple times with different seeds. It copies the arguments and keyword arguments before passing them to the function. Args: - fn: The function to execute. + fun: The function to execute. args: The arguments to pass to the function. seeds: The seeds to use. kwargs: The keyword arguments to pass to the function. @@ -29,18 +27,25 @@ def call_fn_multiple_seeds( Returns: A tuple of the results of the function. """ - return tuple(fn(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds) + return tuple(fun(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds) -def measure_execution_time( - func: Callable[..., ReturnType] -) -> Callable[..., Tuple[Optional[ReturnType], float]]: +class TimedCallable(Protocol): + """A callable that has an attribute to keep track of execution time.""" + + execution_time: float + + def __call__(self, *args, **kwargs) -> ReturnT: + ... + + +def timed(fun: Callable[..., ReturnT]) -> TimedCallable: """ Takes a function `func` and returns a function with the same input arguments and the original return value along with the execution time. Args: - func: The function to be measured, accepting arbitrary arguments and returning + fun: The function to be measured, accepting arbitrary arguments and returning any type. Returns: @@ -49,18 +54,16 @@ def measure_execution_time( will have the same input arguments and return type as the original function. """ - @wraps(func) - def wrapper(*args, **kwargs) -> Tuple[Optional[ReturnType], float]: - result = None + wrapper: TimedCallable + + @wraps(fun) + def wrapper(*args, **kwargs) -> ReturnT: start_time = time.time() - try: - result = func(*args, **kwargs) - except Exception as e: - logger.error(f"Error in {func.__name__}: {e}") - finally: - end_time = time.time() - execution_time = end_time - start_time - logger.info(f"{func.__name__} took {execution_time:.5f} seconds.") - return result, execution_time + result = fun(*args, **kwargs) + end_time = time.time() + wrapper.execution_time = end_time - start_time + logger.info(f"{fun.__name__} took {wrapper.execution_time:.5f} seconds.") + return result + wrapper.execution_time = 0.0 return wrapper From 365cf8cbcc6693e3ad4c4c92815a070f6897830c Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Mon, 11 Sep 2023 15:38:11 +0200 Subject: [PATCH 05/20] Add small comment, if test case fails. --- tests/value/test_semivalues.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py index 54db72301..9fcf6ea1b 100644 --- a/tests/value/test_semivalues.py +++ b/tests/value/test_semivalues.py @@ -98,6 +98,8 @@ def test_shapley_batch_size( ) total_seconds_multi_batch = timed_fn.execution_time assert total_seconds_multi_batch < total_seconds_single_batch + + # Occasionally, batch_2 arrives before batch_1, so rtol isn't always 0. check_values(result_single_batch, result_multi_batch, rtol=1e-4) From 9f47916d3006a1fe00f504942fcfa1286302736b Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Tue, 12 Sep 2023 00:56:03 +0200 Subject: [PATCH 06/20] Add deprecation notice to all functions in the documentation. --- src/pydvl/value/semivalues.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index 348686e8e..62ca06e0c 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -181,6 +181,10 @@ def compute_generic_semivalues( Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in future + versions. """ from concurrent.futures import FIRST_COMPLETED, Future, wait @@ -194,7 +198,8 @@ def compute_generic_semivalues( if batch_size != 1: warnings.warn( - "batch_size is deprecated and will be removed in future versions.", + "Parameter `batch_size` is for experimental use and will be removed in " + "future versions", DeprecationWarning, ) @@ -314,6 +319,10 @@ def compute_shapley_semivalues( Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in future + versions. """ return compute_generic_semivalues( sampler_t(u.data.indices, seed=seed), @@ -358,6 +367,10 @@ def compute_banzhaf_semivalues( Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in future + versions. """ return compute_generic_semivalues( sampler_t(u.data.indices, seed=seed), @@ -405,6 +418,10 @@ def compute_beta_shapley_semivalues( Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in future + versions. """ return compute_generic_semivalues( sampler_t(u.data.indices, seed=seed), @@ -443,6 +460,7 @@ def compute_semivalues( done: StoppingCriterion = MaxUpdates(100), mode: SemiValueMode = SemiValueMode.Shapley, sampler_t: Type[StochasticSampler] = PermutationSampler, + batch_size: int = 1, n_jobs: int = 1, seed: Optional[Seed] = None, **kwargs, @@ -486,6 +504,7 @@ def compute_semivalues( [SemiValueMode][pydvl.value.semivalues.SemiValueMode] for a list. sampler_t: The sampler type to use. See [sampler][pydvl.value.sampler] for a list. + batch_size: Number of marginal evaluations per (parallelized) task. n_jobs: Number of parallel jobs to use. seed: Either an instance of a numpy random number generator or a seed for it. kwargs: Additional keyword arguments passed to @@ -493,6 +512,10 @@ def compute_semivalues( Returns: Object with the results. + + !!! warning "Deprecation notice" + Parameter `batch_size` is for experimental use and will be removed in future + versions. """ if mode == SemiValueMode.Shapley: coefficient = shapley_coefficient From 313fcc006b22d9a61225ca157c99d1c36c57b0f3 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 12:13:37 +0200 Subject: [PATCH 07/20] Add missing forwarded arg --- src/pydvl/value/semivalues.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index 62ca06e0c..ceadae9b4 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -534,5 +534,6 @@ def compute_semivalues( coefficient, done, n_jobs=n_jobs, + batch_size=batch_size, **kwargs, ) From 378feef176222c2f3932d51b048241e0ede71380 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 12:15:01 +0200 Subject: [PATCH 08/20] Rename for consistency --- src/pydvl/value/sampler.py | 46 +++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/pydvl/value/sampler.py b/src/pydvl/value/sampler.py index 0e3e479e9..b692604a8 100644 --- a/src/pydvl/value/sampler.py +++ b/src/pydvl/value/sampler.py @@ -75,12 +75,12 @@ ] -T = TypeVar("T", bound=np.generic) -SampleT = Tuple[T, NDArray[T]] +IndexT = TypeVar("IndexT", bound=np.int_) +SampleT = Tuple[IndexT, NDArray[IndexT]] Sequence.register(np.ndarray) -class PowersetSampler(abc.ABC, Iterable[SampleT], Generic[T]): +class PowersetSampler(abc.ABC, Iterable[SampleT], Generic[IndexT]): """Samplers are custom iterables over subsets of indices. Calling ``iter()`` on a sampler returns an iterator over tuples of the form @@ -121,9 +121,9 @@ class IndexIteration(Enum): def __init__( self, - indices: NDArray[T], + indices: NDArray[IndexT], index_iteration: IndexIteration = IndexIteration.Sequential, - outer_indices: NDArray[T] | None = None, + outer_indices: NDArray[IndexT] | None = None, ): """ Args: @@ -141,11 +141,11 @@ def __init__( self._n_samples = 0 @property - def indices(self) -> NDArray[T]: + def indices(self) -> NDArray[IndexT]: return self._indices @indices.setter - def indices(self, indices: NDArray[T]): + def indices(self, indices: NDArray[IndexT]): raise AttributeError("Cannot set indices of sampler") @property @@ -156,10 +156,10 @@ def n_samples(self) -> int: def n_samples(self, n: int): raise AttributeError("Cannot reset a sampler's number of samples") - def complement(self, exclude: Sequence[T]) -> NDArray[T]: - return np.setxor1d(self._indices, exclude) + def complement(self, exclude: Sequence[IndexT]) -> NDArray[IndexT]: + return np.setxor1d(self._indices, exclude) # type: ignore - def iterindices(self) -> Iterator[T]: + def iterindices(self) -> Iterator[IndexT]: """Iterates over indices in the order specified at construction. FIXME: this is probably not very useful, but I couldn't decide @@ -173,14 +173,14 @@ def iterindices(self) -> Iterator[T]: yield np.random.choice(self._outer_indices, size=1).item() @overload - def __getitem__(self, key: slice) -> PowersetSampler[T]: + def __getitem__(self, key: slice) -> PowersetSampler[IndexT]: ... @overload - def __getitem__(self, key: list[int]) -> PowersetSampler[T]: + def __getitem__(self, key: list[int]) -> PowersetSampler[IndexT]: ... - def __getitem__(self, key: slice | list[int]) -> PowersetSampler[T]: + def __getitem__(self, key: slice | list[int]) -> PowersetSampler[IndexT]: if isinstance(key, slice) or isinstance(key, Iterable): return self.__class__( self._indices, @@ -231,8 +231,8 @@ def __init__(self, *args, seed: Optional[Seed] = None, **kwargs): self._rng = np.random.default_rng(seed) -class DeterministicUniformSampler(PowersetSampler[T]): - def __init__(self, indices: NDArray[T], *args, **kwargs): +class DeterministicUniformSampler(PowersetSampler[IndexT]): + def __init__(self, indices: NDArray[IndexT], *args, **kwargs): """An iterator to perform uniform deterministic sampling of subsets. For every index $i$, each subset of the complement `indices - {i}` is @@ -268,7 +268,7 @@ def weight(cls, n: int, subset_len: int) -> float: return float(2 ** (n - 1)) if n > 0 else 1.0 -class UniformSampler(StochasticSamplerMixin, PowersetSampler[T]): +class UniformSampler(StochasticSamplerMixin, PowersetSampler[IndexT]): """An iterator to perform uniform random sampling of subsets. Iterating over every index $i$, either in sequence or at random depending on @@ -306,15 +306,15 @@ def weight(cls, n: int, subset_len: int) -> float: return float(2 ** (n - 1)) if n > 0 else 1.0 -class DeterministicCombinatorialSampler(DeterministicUniformSampler[T]): +class DeterministicCombinatorialSampler(DeterministicUniformSampler[IndexT]): @deprecated( target=DeterministicUniformSampler, deprecated_in="0.6.0", remove_in="0.8.0" ) - def __init__(self, indices: NDArray[T], *args, **kwargs): + def __init__(self, indices: NDArray[IndexT], *args, **kwargs): void(indices, args, kwargs) -class AntitheticSampler(StochasticSamplerMixin, PowersetSampler[T]): +class AntitheticSampler(StochasticSamplerMixin, PowersetSampler[IndexT]): """An iterator to perform uniform random sampling of subsets, and their complements. @@ -339,7 +339,7 @@ def weight(cls, n: int, subset_len: int) -> float: return float(2 ** (n - 1)) if n > 0 else 1.0 -class PermutationSampler(StochasticSamplerMixin, PowersetSampler[T]): +class PermutationSampler(StochasticSamplerMixin, PowersetSampler[IndexT]): """Sample permutations of indices and iterate through each returning increasing subsets, as required for the permutation definition of semi-values. @@ -365,7 +365,7 @@ def __iter__(self) -> Iterator[SampleT]: if self._n_samples == 0: # Empty index set break - def __getitem__(self, key: slice | list[int]) -> PowersetSampler[T]: + def __getitem__(self, key: slice | list[int]) -> PowersetSampler[IndexT]: """Permutation samplers cannot be split across indices, so we return a copy of the full sampler.""" return super().__getitem__(slice(None)) @@ -375,7 +375,7 @@ def weight(cls, n: int, subset_len: int) -> float: return n * math.comb(n - 1, subset_len) if n > 0 else 1.0 -class DeterministicPermutationSampler(PermutationSampler[T]): +class DeterministicPermutationSampler(PermutationSampler[IndexT]): """Samples all n! permutations of the indices deterministically, and iterates through them, returning sets as required for the permutation-based definition of semi-values. @@ -397,7 +397,7 @@ def __iter__(self) -> Iterator[SampleT]: self._n_samples += 1 -class RandomHierarchicalSampler(StochasticSamplerMixin, PowersetSampler[T]): +class RandomHierarchicalSampler(StochasticSamplerMixin, PowersetSampler[IndexT]): """For every index, sample a set size, then a set of that size. !!! Todo From f6f563942ac26d53103e7e7af201bbe7fa9f1804 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 12:15:19 +0200 Subject: [PATCH 09/20] Fix alias of generics --- src/pydvl/value/sampler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pydvl/value/sampler.py b/src/pydvl/value/sampler.py index b692604a8..7d57cae8e 100644 --- a/src/pydvl/value/sampler.py +++ b/src/pydvl/value/sampler.py @@ -424,5 +424,8 @@ def weight(cls, n: int, subset_len: int) -> float: # TODO Replace by Intersection[StochasticSamplerMixin, PowersetSampler[T]] # See https://github.com/python/typing/issues/213 StochasticSampler = Union[ - UniformSampler, PermutationSampler, AntitheticSampler, RandomHierarchicalSampler + UniformSampler[IndexT], + PermutationSampler[IndexT], + AntitheticSampler[IndexT], + RandomHierarchicalSampler[IndexT], ] From ab4a2035ed55e8c4215c9578d668d3f68a8694fa Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 12:18:11 +0200 Subject: [PATCH 10/20] epsilon more consistent typing for results (generics don't make sense here anyway) --- src/pydvl/value/result.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/pydvl/value/result.py b/src/pydvl/value/result.py index 989d6d92e..5c875f884 100644 --- a/src/pydvl/value/result.py +++ b/src/pydvl/value/result.py @@ -484,7 +484,7 @@ def __repr__(self) -> str: repr_string += ")" return repr_string - def _check_compatible(self, other: "ValuationResult"): + def _check_compatible(self, other: ValuationResult): if not isinstance(other, ValuationResult): raise NotImplementedError( f"Cannot combine ValuationResult with {type(other)}" @@ -492,7 +492,9 @@ def _check_compatible(self, other: "ValuationResult"): if self.algorithm and self.algorithm != other.algorithm: raise ValueError("Cannot combine results from different algorithms") - def __add__(self, other: "ValuationResult") -> "ValuationResult": + def __add__( + self, other: ValuationResult[IndexT, NameT] + ) -> ValuationResult[IndexT, NameT]: """Adds two ValuationResults. The values must have been computed with the same algorithm. An exception @@ -601,7 +603,7 @@ def __add__(self, other: "ValuationResult") -> "ValuationResult": # extra_values=self._extra_values.update(other._extra_values), ) - def update(self, idx: int, new_value: float) -> "ValuationResult": + def update(self, idx: int, new_value: float) -> ValuationResult[IndexT, NameT]: """Updates the result in place with a new value, using running mean and variance. @@ -623,7 +625,7 @@ def update(self, idx: int, new_value: float) -> "ValuationResult": self._values[pos], self._variances[pos], self._counts[pos], new_value ) self[pos] = ValueItem( - index=cast(IndexT, idx), + index=cast(IndexT, idx), # FIXME name=self._names[pos], value=val, variance=var, @@ -766,7 +768,7 @@ def zeros( indices: Optional[Sequence[IndexT] | NDArray[IndexT]] = None, data_names: Optional[Sequence[NameT] | NDArray[NameT]] = None, n_samples: int = 0, - ) -> "ValuationResult": + ) -> ValuationResult: """Creates an empty [ValuationResult][pydvl.value.result.ValuationResult] object. Empty results are characterised by having an empty array of values. When From dfddb584bf06eab43966a321158e29291840ad71 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 12:19:06 +0200 Subject: [PATCH 11/20] Fix docstrings --- src/pydvl/value/semivalues.py | 61 +++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index ceadae9b4..6a78f9594 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -9,12 +9,14 @@ $$\sum_{k=1}^n w(k) = 1.$$ -!!! Note +??? Note For implementation consistency, we slightly depart from the common definition of semi-values, which includes a factor $1/n$ in the sum over subsets. Instead, we subsume this factor into the coefficient $w(k)$. -As such, the computation of a semi-value requires two components: +## Main components + +The computation of a semi-value requires two components: 1. A **subset sampler** that generates subsets of the set $D$ of interest. 2. A **coefficient** $w(k)$ that assigns a weight to each subset size $k$. @@ -44,6 +46,11 @@ require caching to be enabled or computation will be doubled wrt. a 'direct' implementation of permutation MC. +## Computing semi-values + +Samplers and coefficients can be arbitrarily mixed by means of the main entry +point of this module, +[compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues]. There are several pre-defined coefficients, including the Shapley value of (Ghorbani and Zou, 2019)[^1], the Banzhaf index of (Wang and Jia)[^3], and the Beta coefficient of (Kwon and Zou, 2022)[^2]. For each of these methods, there is a @@ -53,6 +60,16 @@ and [compute_beta_shapley_semivalues][pydvl.value.semivalues.compute_beta_shapley_semivalues]. instead. +!!! tip "Parallelization and batching" + In order to ensure reproducibility and fine-grained control of + parallelization, samples are generated in the main process and then + distributed to worker processes for evaluation. For small sample sizes, this + can lead to a significant overhead. To avoid this, we temporarily provide an + additional argument `batch_size` to all methods which can improve + performance with small models up to an order of magnitude. Note that this + argument will be removed before version 1.0 in favour of a more general + solution. + ## References @@ -110,7 +127,8 @@ class SVCoefficient(Protocol): - """A coefficient for the computation of semi-values.""" + """The protocol that coefficients for the computation of semi-values must + fulfill.""" def __call__(self, n: int, k: int) -> float: """Computes the coefficient for a given subset size. @@ -183,8 +201,8 @@ def compute_generic_semivalues( Object with the results. !!! warning "Deprecation notice" - Parameter `batch_size` is for experimental use and will be removed in future - versions. + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ from concurrent.futures import FIRST_COMPLETED, Future, wait @@ -308,21 +326,22 @@ def compute_shapley_semivalues( Args: u: Utility object with model, data, and scoring function. done: Stopping criterion. - sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` - for a list. + sampler_t: The sampler type to use. See the + [sampler][pydvl.value.sampler] module for a list. batch_size: Number of marginal evaluations per single parallel job. n_jobs: Number of parallel jobs to use. config: Object configuring parallel computation, with cluster address, number of cpus, etc. - seed: Either an instance of a numpy random number generator or a seed for it. + seed: Either an instance of a numpy random number generator or a seed + for it. progress: Whether to display a progress bar. Returns: Object with the results. !!! warning "Deprecation notice" - Parameter `batch_size` is for experimental use and will be removed in future - versions. + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ return compute_generic_semivalues( sampler_t(u.data.indices, seed=seed), @@ -356,11 +375,12 @@ def compute_banzhaf_semivalues( Args: u: Utility object with model, data, and scoring function. done: Stopping criterion. - sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a - list. + sampler_t: The sampler type to use. See the + [sampler][pydvl.value.sampler] module for a list. batch_size: Number of marginal evaluations per single parallel job. n_jobs: Number of parallel jobs to use. - seed: Either an instance of a numpy random number generator or a seed for it. + seed: Either an instance of a numpy random number generator or a seed + for it. config: Object configuring parallel computation, with cluster address, number of cpus, etc. progress: Whether to display a progress bar. @@ -369,8 +389,8 @@ def compute_banzhaf_semivalues( Object with the results. !!! warning "Deprecation notice" - Parameter `batch_size` is for experimental use and will be removed in future - versions. + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ return compute_generic_semivalues( sampler_t(u.data.indices, seed=seed), @@ -408,7 +428,8 @@ def compute_beta_shapley_semivalues( alpha: Alpha parameter of the Beta distribution. beta: Beta parameter of the Beta distribution. done: Stopping criterion. - sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a list. + sampler_t: The sampler type to use. See the + [sampler][pydvl.value.sampler] module for a list. batch_size: Number of marginal evaluations per (parallelized) task. n_jobs: Number of parallel jobs to use. seed: Either an instance of a numpy random number generator or a seed for it. @@ -420,8 +441,8 @@ def compute_beta_shapley_semivalues( Object with the results. !!! warning "Deprecation notice" - Parameter `batch_size` is for experimental use and will be removed in future - versions. + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ return compute_generic_semivalues( sampler_t(u.data.indices, seed=seed), @@ -514,8 +535,8 @@ def compute_semivalues( Object with the results. !!! warning "Deprecation notice" - Parameter `batch_size` is for experimental use and will be removed in future - versions. + Parameter `batch_size` is for experimental use and will be removed in + future versions. """ if mode == SemiValueMode.Shapley: coefficient = shapley_coefficient From 4118505b24f668f09b2fd67a7aae1d2713e719b0 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 12:19:53 +0200 Subject: [PATCH 12/20] Improve deprecation warnings --- src/pydvl/value/semivalues.py | 42 +++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index 6a78f9594..c300b8dab 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -173,6 +173,14 @@ def _marginal( # deprecated_in="0.8.0", # remove_in="0.9.0", # ) +@deprecated( + target=True, + deprecated_in="0.7.0", + remove_in="0.9.0", + args_mapping={"batch_size": None}, + template_mgs="batch_size is for experimental use and will be removed" + "in future versions.", +) def compute_generic_semivalues( sampler: PowersetSampler, u: Utility, @@ -216,8 +224,8 @@ def compute_generic_semivalues( if batch_size != 1: warnings.warn( - "Parameter `batch_size` is for experimental use and will be removed in " - "future versions", + "Parameter `batch_size` is for experimental use and will be" + " removed in future versions", DeprecationWarning, ) @@ -304,6 +312,14 @@ def beta_coefficient_w(n: int, k: int) -> float: return cast(SVCoefficient, beta_coefficient_w) +@deprecated( + target=True, + deprecated_in="0.7.0", + remove_in="0.9.0", + args_mapping={"batch_size": None}, + template_mgs="batch_size is for experimental use and will be removed" + "in future versions.", +) def compute_shapley_semivalues( u: Utility, *, @@ -355,6 +371,14 @@ def compute_shapley_semivalues( ) +@deprecated( + target=True, + deprecated_in="0.7.0", + remove_in="0.9.0", + args_mapping={"batch_size": None}, + template_mgs="batch_size is for experimental use and will be removed" + "in future versions.", +) def compute_banzhaf_semivalues( u: Utility, *, @@ -404,6 +428,14 @@ def compute_banzhaf_semivalues( ) +@deprecated( + target=True, + deprecated_in="0.7.0", + remove_in="0.9.0", + args_mapping={"batch_size": None}, + template_mgs="batch_size is for experimental use and will be removed" + "in future versions.", +) def compute_beta_shapley_semivalues( u: Utility, *, @@ -456,11 +488,7 @@ def compute_beta_shapley_semivalues( ) -@deprecated( - target=True, - deprecated_in="0.7.0", - remove_in="0.8.0", -) +@deprecated(target=True, deprecated_in="0.7.0", remove_in="0.8.0") class SemiValueMode(str, Enum): """Enumeration of semi-value modes. From db416d3d58f0527a1ec5da15355a3687910c2f70 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 15:47:45 +0200 Subject: [PATCH 13/20] Use NDArray more consistently --- src/pydvl/utils/dataset.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/src/pydvl/utils/dataset.py b/src/pydvl/utils/dataset.py index 12a123806..883a1df36 100644 --- a/src/pydvl/utils/dataset.py +++ b/src/pydvl/utils/dataset.py @@ -40,10 +40,10 @@ class Dataset: def __init__( self, - x_train: Union[np.ndarray, pd.DataFrame], - y_train: Union[np.ndarray, pd.DataFrame], - x_test: Union[np.ndarray, pd.DataFrame], - y_test: Union[np.ndarray, pd.DataFrame], + x_train: Union[NDArray, pd.DataFrame], + y_train: Union[NDArray, pd.DataFrame], + x_test: Union[NDArray, pd.DataFrame], + y_test: Union[NDArray, pd.DataFrame], feature_names: Optional[Sequence[str]] = None, target_names: Optional[Sequence[str]] = None, data_names: Optional[Sequence[str]] = None, @@ -124,8 +124,12 @@ def make_names(s: str, a: np.ndarray) -> List[str]: raise ValueError("Mismatching number of targets and names") self.description = description or "No description" - self._indices = np.arange(len(self.x_train)) - self._data_names = data_names if data_names is not None else self._indices + self._indices = np.arange(len(self.x_train), dtype=np.int_) + self._data_names = ( + np.array(data_names, dtype=object) + if data_names is not None + else self._indices.astype(object) + ) def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple: return self.x_train[idx], self.y_train[idx] @@ -220,7 +224,7 @@ def target(self, name: str) -> Tuple[slice, int]: raise ValueError(f"Target {name} is not in {self.target_names}") @property - def indices(self): + def indices(self) -> NDArray[np.int_]: """Index of positions in data.x_train. Contiguous integers from 0 to len(Dataset). @@ -228,7 +232,7 @@ def indices(self): return self._indices @property - def data_names(self): + def data_names(self) -> NDArray[np.object_]: """Names of each individual datapoint. Used for reporting Shapley values. @@ -236,9 +240,9 @@ def data_names(self): return self._data_names @property - def dim(self): + def dim(self) -> int: """Returns the number of dimensions of a sample.""" - return self.x_train.shape[1] if len(self.x_train.shape) > 1 else 1 + return int(self.x_train.shape[1]) if len(self.x_train.shape) > 1 else 1 def __str__(self): return self.description @@ -256,7 +260,7 @@ def from_sklearn( **kwargs, ) -> "Dataset": """Constructs a [Dataset][pydvl.utils.Dataset] object from a - [sklearn.utils.Bunch][sklearn.utils.Bunch], as returned by the `load_*` + [sklearn.utils.Bunch][], as returned by the `load_*` functions in [scikit-learn toy datasets](https://scikit-learn.org/stable/datasets/toy_dataset.html). ??? Example @@ -360,10 +364,10 @@ def from_arrays( class GroupedDataset(Dataset): def __init__( self, - x_train: np.ndarray, - y_train: np.ndarray, - x_test: np.ndarray, - y_test: np.ndarray, + x_train: NDArray, + y_train: NDArray, + x_test: NDArray, + y_test: NDArray, data_groups: Sequence, feature_names: Optional[Sequence[str]] = None, target_names: Optional[Sequence[str]] = None, @@ -423,7 +427,9 @@ def __init__( self.group_items = list(self.groups.items()) self._indices = np.arange(len(self.groups.keys())) self._data_names = ( - group_names if group_names is not None else list(self.groups.keys()) + np.array(group_names, dtype=object) + if group_names is not None + else np.array(self.groups.keys(), dtype=object) ) def __len__(self): From eab59c5d9db07addfb7809b254ca288d1ccf1416 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 15:48:36 +0200 Subject: [PATCH 14/20] Move IndexT, NameT --- src/pydvl/utils/types.py | 9 ++++++--- src/pydvl/value/result.py | 10 ++-------- src/pydvl/value/sampler.py | 5 +---- src/pydvl/value/semivalues.py | 4 +--- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/src/pydvl/utils/types.py b/src/pydvl/utils/types.py index 5df91923d..e548dbfc4 100644 --- a/src/pydvl/utils/types.py +++ b/src/pydvl/utils/types.py @@ -6,11 +6,14 @@ from abc import ABCMeta from typing import Any, Optional, Protocol, TypeVar, Union, cast +import numpy as np from numpy.random import Generator, SeedSequence from numpy.typing import NDArray __all__ = [ "ensure_seed_sequence", + "IndexT", + "NameT", "MapFunction", "NoPublicConstructor", "ReduceFunction", @@ -18,7 +21,10 @@ "SupervisedModel", ] +IndexT = TypeVar("IndexT", bound=np.int_) +NameT = TypeVar("NameT", bound=np.object_) R = TypeVar("R", covariant=True) +Seed = Union[int, Generator] class MapFunction(Protocol[R]): @@ -74,9 +80,6 @@ def create(cls, *args: Any, **kwargs: Any): return super().__call__(*args, **kwargs) -Seed = Union[int, Generator] - - def ensure_seed_sequence( seed: Optional[Union[Seed, SeedSequence]] = None ) -> SeedSequence: diff --git a/src/pydvl/value/result.py b/src/pydvl/value/result.py index 5c875f884..95216f8a9 100644 --- a/src/pydvl/value/result.py +++ b/src/pydvl/value/result.py @@ -57,8 +57,6 @@ Literal, Optional, Sequence, - Tuple, - TypeVar, Union, cast, overload, @@ -71,21 +69,17 @@ from pydvl.utils.dataset import Dataset from pydvl.utils.numeric import running_moments from pydvl.utils.status import Status -from pydvl.utils.types import Seed +from pydvl.utils.types import IndexT, NameT, Seed try: import pandas # Try to import here for the benefit of mypy except ImportError: pass -__all__ = ["ValuationResult", "ValueItem", "IndexT", "NameT"] +__all__ = ["ValuationResult", "ValueItem"] logger = logging.getLogger(__name__) -# TODO: Move to value.types once it's there -IndexT = TypeVar("IndexT", bound=np.int_) -NameT = TypeVar("NameT", bound=Any) - @total_ordering @dataclass diff --git a/src/pydvl/value/sampler.py b/src/pydvl/value/sampler.py index 7d57cae8e..a51dbfb79 100644 --- a/src/pydvl/value/sampler.py +++ b/src/pydvl/value/sampler.py @@ -51,7 +51,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, overload, ) @@ -61,7 +60,7 @@ from numpy.typing import NDArray from pydvl.utils.numeric import powerset, random_subset, random_subset_of_size -from pydvl.utils.types import Seed +from pydvl.utils.types import IndexT, Seed __all__ = [ "AntitheticSampler", @@ -74,8 +73,6 @@ "StochasticSamplerMixin", ] - -IndexT = TypeVar("IndexT", bound=np.int_) SampleT = Tuple[IndexT, NDArray[IndexT]] Sequence.register(np.ndarray) diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index c300b8dab..626379d7f 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -94,14 +94,13 @@ from itertools import islice from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast -import numpy as np import scipy as sp from deprecate import deprecated from tqdm import tqdm from pydvl.parallel.config import ParallelConfig from pydvl.utils import Utility -from pydvl.utils.types import Seed +from pydvl.utils.types import IndexT, Seed from pydvl.value import ValuationResult from pydvl.value.sampler import ( PermutationSampler, @@ -140,7 +139,6 @@ def __call__(self, n: int, k: int) -> float: ... -IndexT = TypeVar("IndexT", bound=np.generic) MarginalT = Tuple[IndexT, float] From cc4a26c0033cb08f2f8eac81f765a7cae77ee038 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 15:49:59 +0200 Subject: [PATCH 15/20] Some type fixes --- src/pydvl/value/oob/oob.py | 2 +- src/pydvl/value/result.py | 8 ++++---- src/pydvl/value/semivalues.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pydvl/value/oob/oob.py b/src/pydvl/value/oob/oob.py index c62a6255d..617046355 100644 --- a/src/pydvl/value/oob/oob.py +++ b/src/pydvl/value/oob/oob.py @@ -70,7 +70,7 @@ def compute_data_oob( Object with the data values. """ - result: ValuationResult[np.int_, np.float_] = ValuationResult.empty( + result: ValuationResult[np.int_, np.object_] = ValuationResult.empty( algorithm="data_oob", indices=u.data.indices, data_names=u.data.data_names ) diff --git a/src/pydvl/value/result.py b/src/pydvl/value/result.py index 95216f8a9..b7df8b53a 100644 --- a/src/pydvl/value/result.py +++ b/src/pydvl/value/result.py @@ -734,7 +734,7 @@ def empty( indices: Optional[Sequence[IndexT] | NDArray[IndexT]] = None, data_names: Optional[Sequence[NameT] | NDArray[NameT]] = None, n_samples: int = 0, - ) -> "ValuationResult": + ) -> ValuationResult: """Creates an empty [ValuationResult][pydvl.value.result.ValuationResult] object. Empty results are characterised by having an empty array of values. When @@ -762,7 +762,7 @@ def zeros( indices: Optional[Sequence[IndexT] | NDArray[IndexT]] = None, data_names: Optional[Sequence[NameT] | NDArray[NameT]] = None, n_samples: int = 0, - ) -> ValuationResult: + ) -> ValuationResult[IndexT, NameT]: """Creates an empty [ValuationResult][pydvl.value.result.ValuationResult] object. Empty results are characterised by having an empty array of values. When @@ -783,12 +783,12 @@ def zeros( if indices is None: indices = np.arange(n_samples, dtype=np.int_) else: - indices = np.array(indices) + indices = np.array(indices, dtype=np.int_) return cls( algorithm=algorithm, status=Status.Pending, indices=indices, - data_names=data_names + data_names=np.array(data_names, dtype=object) if data_names is not None else np.empty_like(indices, dtype=object), values=np.zeros(len(indices)), diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index 626379d7f..a30fdaf8e 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -180,7 +180,7 @@ def _marginal( "in future versions.", ) def compute_generic_semivalues( - sampler: PowersetSampler, + sampler: PowersetSampler[IndexT], u: Utility, coefficient: SVCoefficient, done: StoppingCriterion, From 0a13195c490756268f6ef59772e0d31140c2d1e3 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 15:51:46 +0200 Subject: [PATCH 16/20] Pre-generate result indices and names as usual --- src/pydvl/value/shapley/montecarlo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pydvl/value/shapley/montecarlo.py b/src/pydvl/value/shapley/montecarlo.py index 1046326a5..af0f67ce2 100644 --- a/src/pydvl/value/shapley/montecarlo.py +++ b/src/pydvl/value/shapley/montecarlo.py @@ -201,7 +201,9 @@ def permutation_montecarlo_shapley( n_submitted_jobs = 2 * max_workers # number of jobs in the executor's queue seed_sequence = ensure_seed_sequence(seed) - result = ValuationResult.zeros(algorithm=algorithm) + result = ValuationResult.zeros( + algorithm=algorithm, indices=u.data.indices, data_names=u.data.data_names + ) pbar = tqdm(disable=not progress, total=100, unit="%") From 4f160a2a663978c334a1c17592d71d47d5d1ab94 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 15:52:27 +0200 Subject: [PATCH 17/20] Ignore mypy (generic in ValuationResult is mostly useless) --- src/pydvl/value/semivalues.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index a30fdaf8e..69bf8e809 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -269,10 +269,7 @@ def compute_generic_semivalues( pending.add( executor.submit( - _marginal, - u=u, - coefficient=correction, - samples=samples, + _marginal, u=u, coefficient=correction, samples=samples ) ) except StopIteration: @@ -357,7 +354,8 @@ def compute_shapley_semivalues( Parameter `batch_size` is for experimental use and will be removed in future versions. """ - return compute_generic_semivalues( + # HACK: cannot infer return type because of useless IndexT, NameT + return compute_generic_semivalues( # type: ignore sampler_t(u.data.indices, seed=seed), u, shapley_coefficient, @@ -414,7 +412,8 @@ def compute_banzhaf_semivalues( Parameter `batch_size` is for experimental use and will be removed in future versions. """ - return compute_generic_semivalues( + # HACK: cannot infer return type because of useless IndexT, NameT + return compute_generic_semivalues( # type: ignore sampler_t(u.data.indices, seed=seed), u, banzhaf_coefficient, @@ -474,7 +473,8 @@ def compute_beta_shapley_semivalues( Parameter `batch_size` is for experimental use and will be removed in future versions. """ - return compute_generic_semivalues( + # HACK: cannot infer return type because of useless IndexT, NameT + return compute_generic_semivalues( # type: ignore sampler_t(u.data.indices, seed=seed), u, beta_coefficient(alpha, beta), @@ -506,7 +506,7 @@ def compute_semivalues( *, done: StoppingCriterion = MaxUpdates(100), mode: SemiValueMode = SemiValueMode.Shapley, - sampler_t: Type[StochasticSampler] = PermutationSampler, + sampler_t: Type[StochasticSampler[IndexT]] = PermutationSampler[IndexT], batch_size: int = 1, n_jobs: int = 1, seed: Optional[Seed] = None, @@ -575,7 +575,9 @@ def compute_semivalues( else: raise ValueError(f"Unknown mode {mode}") coefficient = cast(SVCoefficient, coefficient) - return compute_generic_semivalues( + + # HACK: cannot infer return type because of useless IndexT, NameT + return compute_generic_semivalues( # type: ignore sampler_t(u.data.indices, seed=seed), u, coefficient, From 431ee065595e81cac6a14865e99049c5f960702f Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 15:57:51 +0200 Subject: [PATCH 18/20] Fix construction of data_names --- src/pydvl/utils/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pydvl/utils/dataset.py b/src/pydvl/utils/dataset.py index 883a1df36..c6331ce08 100644 --- a/src/pydvl/utils/dataset.py +++ b/src/pydvl/utils/dataset.py @@ -429,7 +429,7 @@ def __init__( self._data_names = ( np.array(group_names, dtype=object) if group_names is not None - else np.array(self.groups.keys(), dtype=object) + else np.array(list(self.groups.keys()), dtype=object) ) def __len__(self): From d398c271e8d222529574e85ab0617cf0beb251e8 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 17:05:18 +0200 Subject: [PATCH 19/20] Work around mypy's odd behaviour in CI --- src/pydvl/value/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pydvl/value/result.py b/src/pydvl/value/result.py index b7df8b53a..04b922a75 100644 --- a/src/pydvl/value/result.py +++ b/src/pydvl/value/result.py @@ -762,7 +762,7 @@ def zeros( indices: Optional[Sequence[IndexT] | NDArray[IndexT]] = None, data_names: Optional[Sequence[NameT] | NDArray[NameT]] = None, n_samples: int = 0, - ) -> ValuationResult[IndexT, NameT]: + ) -> ValuationResult: """Creates an empty [ValuationResult][pydvl.value.result.ValuationResult] object. Empty results are characterised by having an empty array of values. When From a3b1d4a9b8e25c98bbf46fca1b7cd701c0b4f138 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Sun, 17 Sep 2023 21:31:50 +0200 Subject: [PATCH 20/20] Give some slack to the multi batch test --- tests/value/test_semivalues.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py index 9fcf6ea1b..0a71bcd54 100644 --- a/tests/value/test_semivalues.py +++ b/tests/value/test_semivalues.py @@ -97,7 +97,7 @@ def test_shapley_batch_size( config=parallel_config, ) total_seconds_multi_batch = timed_fn.execution_time - assert total_seconds_multi_batch < total_seconds_single_batch + assert total_seconds_multi_batch < total_seconds_single_batch * 1.1 # Occasionally, batch_2 arrives before batch_1, so rtol isn't always 0. check_values(result_single_batch, result_multi_batch, rtol=1e-4)