From 5e56dd59d8a173464bbafddef2d470cf6768f0a7 Mon Sep 17 00:00:00 2001
From: Markus Semmler
Date: Mon, 11 Sep 2023 02:40:07 +0200
Subject: [PATCH] Fix comments.

---
 src/pydvl/value/semivalues.py          | 13 +++++--
 tests/value/shapley/test_montecarlo.py |  6 ++--
 tests/value/test_semivalues.py         | 15 ++++----
 tests/value/utils.py                   | 49 ++++++++++++++------------
 4 files changed, 46 insertions(+), 37 deletions(-)

diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py
index 68766cc3b..348686e8e 100644
--- a/src/pydvl/value/semivalues.py
+++ b/src/pydvl/value/semivalues.py
@@ -72,6 +72,7 @@
 
 import logging
 import math
+import warnings
 from enum import Enum
 from itertools import islice
 from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast
@@ -172,7 +173,7 @@ def compute_generic_semivalues(
         u: Utility object with model, data, and scoring function.
         coefficient: The semi-value coefficient
         done: Stopping criterion.
-        batch_size: Number of marginal evaluations per (parallelized) task.
+        batch_size: Number of marginal evaluations per single parallel job.
         n_jobs: Number of parallel jobs to use.
         config: Object configuring parallel computation, with cluster address,
            number of cpus, etc.
@@ -191,6 +192,12 @@
             "will be doubled wrt. a 'direct' implementation of permutation MC"
         )
 
+    if batch_size != 1:
+        warnings.warn(
+            "batch_size is deprecated and will be removed in future versions.",
+            DeprecationWarning,
+        )
+
     result = ValuationResult.zeros(
         algorithm=f"semivalue-{str(sampler)}-{coefficient.__name__}",  # type: ignore
         indices=u.data.indices,
@@ -298,7 +305,7 @@ def compute_shapley_semivalues(
         done: Stopping criterion.
         sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler`
            for a list.
-        batch_size: Number of marginal evaluations per (parallelized) task.
+        batch_size: Number of marginal evaluations per single parallel job.
         n_jobs: Number of parallel jobs to use.
         config: Object configuring parallel computation, with cluster address,
            number of cpus, etc.
@@ -342,7 +349,7 @@ def compute_banzhaf_semivalues(
         done: Stopping criterion.
         sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler`
            for a list.
-        batch_size: Number of marginal evaluations per (parallelized) task.
+        batch_size: Number of marginal evaluations per single parallel job.
         n_jobs: Number of parallel jobs to use.
         seed: Either an instance of a numpy random number generator or a seed for it.
         config: Object configuring parallel computation, with cluster address,
diff --git a/tests/value/shapley/test_montecarlo.py b/tests/value/shapley/test_montecarlo.py
index 3024ed198..1c9fbf4a7 100644
--- a/tests/value/shapley/test_montecarlo.py
+++ b/tests/value/shapley/test_montecarlo.py
@@ -17,7 +17,7 @@
 
 from .. import check_rank_correlation, check_total_value, check_values
 from ..conftest import polynomial_dataset
-from ..utils import call_fn_multiple_seeds
+from ..utils import call_with_seeds
 
 log = logging.getLogger(__name__)
 
@@ -94,7 +94,7 @@ def test_montecarlo_shapley_housing_dataset_reproducible(
     kwargs: dict,
     seed: Seed,
 ):
-    values_1, values_2 = call_fn_multiple_seeds(
+    values_1, values_2 = call_with_seeds(
         compute_shapley_values,
         Utility(LinearRegression(), data=housing_dataset, scorer="r2"),
         mode=fun,
@@ -121,7 +121,7 @@ def test_montecarlo_shapley_housing_dataset_stochastic(
     seed: Seed,
     seed_alt: Seed,
 ):
-    values_1, values_2 = call_fn_multiple_seeds(
+    values_1, values_2 = call_with_seeds(
         compute_shapley_values,
         Utility(LinearRegression(), data=housing_dataset, scorer="r2"),
         mode=fun,
diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py
index 825170802..997434694 100644
--- a/tests/value/test_semivalues.py
+++ b/tests/value/test_semivalues.py
@@ -24,7 +24,7 @@
 from pydvl.value.stopping import AbsoluteStandardError, MaxUpdates
 
 from . import check_values
-from .utils import measure_execution_time
+from .utils import timed
 
 
 @pytest.mark.parametrize("num_samples", [5])
@@ -76,9 +76,8 @@ def test_shapley_batch_size(
 ):
     u, exact_values = analytic_shapley
     criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2))
-    result_single_batch, total_seconds_single_batch = measure_execution_time(
-        compute_generic_semivalues
-    )(
+    timed_fn = timed(compute_generic_semivalues)
+    result_single_batch = timed_fn(
         sampler(u.data.indices, seed=seed),
         u,
         coefficient,
@@ -87,9 +86,8 @@ def test_shapley_batch_size(
         batch_size=1,
         config=parallel_config,
     )
-    result_multi_batch, total_seconds_multi_batch = measure_execution_time(
-        compute_generic_semivalues
-    )(
+    total_seconds_single_batch = timed_fn.execution_time
+    result_multi_batch = timed_fn(
         sampler(u.data.indices, seed=seed),
         u,
         coefficient,
@@ -98,8 +96,9 @@ def test_shapley_batch_size(
         batch_size=batch_size,
         config=parallel_config,
     )
+    total_seconds_multi_batch = timed_fn.execution_time
     assert total_seconds_multi_batch < total_seconds_single_batch
-    check_values(result_single_batch, result_multi_batch, rtol=0.0, atol=0.0)
+    check_values(result_single_batch, result_multi_batch, rtol=1e-2)
 
 
 @pytest.mark.parametrize("num_samples", [5])
diff --git a/tests/value/utils.py b/tests/value/utils.py
index 6681a5835..c55ab13eb 100644
--- a/tests/value/utils.py
+++ b/tests/value/utils.py
@@ -4,24 +4,22 @@
 from copy import deepcopy
 from functools import wraps
 from logging import getLogger
-from typing import Callable, Optional, Tuple, TypeVar
+from typing import Callable, Optional, Protocol, Tuple, TypeVar
 
 from pydvl.utils.types import Seed
 
 logger = getLogger(__name__)
 
-ReturnType = TypeVar("ReturnType")
+ReturnT = TypeVar("ReturnT")
 
 
-def call_fn_multiple_seeds(
-    fn: Callable, *args, seeds: Tuple[Seed, ...], **kwargs
-) -> Tuple:
+def call_with_seeds(fun: Callable, *args, seeds: Tuple[Seed, ...], **kwargs) -> Tuple:
     """
     Execute a function multiple times with different seeds. It copies the arguments
     and keyword arguments before passing them to the function.
 
     Args:
-        fn: The function to execute.
+        fun: The function to execute.
         args: The arguments to pass to the function.
         seeds: The seeds to use.
         kwargs: The keyword arguments to pass to the function.
@@ -29,18 +27,25 @@
     Returns:
         A tuple of the results of the function.
""" - return tuple(fn(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds) + return tuple(fun(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds) -def measure_execution_time( - func: Callable[..., ReturnType] -) -> Callable[..., Tuple[Optional[ReturnType], float]]: +class TimedCallable(Protocol): + """A callable that has an attribute to keep track of execution time.""" + + execution_time: float + + def __call__(self, *args, **kwargs) -> ReturnT: + ... + + +def timed(fun: Callable[..., ReturnT]) -> TimedCallable: """ Takes a function `func` and returns a function with the same input arguments and the original return value along with the execution time. Args: - func: The function to be measured, accepting arbitrary arguments and returning + fun: The function to be measured, accepting arbitrary arguments and returning any type. Returns: @@ -49,18 +54,16 @@ def measure_execution_time( will have the same input arguments and return type as the original function. """ - @wraps(func) - def wrapper(*args, **kwargs) -> Tuple[Optional[ReturnType], float]: - result = None + wrapper: TimedCallable + + @wraps(fun) + def wrapper(*args, **kwargs) -> ReturnT: start_time = time.time() - try: - result = func(*args, **kwargs) - except Exception as e: - logger.error(f"Error in {func.__name__}: {e}") - finally: - end_time = time.time() - execution_time = end_time - start_time - logger.info(f"{func.__name__} took {execution_time:.5f} seconds.") - return result, execution_time + result = fun(*args, **kwargs) + end_time = time.time() + wrapper.execution_time = end_time - start_time + logger.info(f"{fun.__name__} took {wrapper.execution_time:.5f} seconds.") + return result + wrapper.execution_time = 0.0 return wrapper