Skip to content

Commit

Permalink
Fix comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Sep 11, 2023
1 parent 307af25 commit 9083bc2
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 37 deletions.
13 changes: 10 additions & 3 deletions src/pydvl/value/semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@

import logging
import math
import warnings
from enum import Enum
from itertools import islice
from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast
Expand Down Expand Up @@ -172,7 +173,7 @@ def compute_generic_semivalues(
u: Utility object with model, data, and scoring function.
coefficient: The semi-value coefficient
done: Stopping criterion.
batch_size: Number of marginal evaluations per (parallelized) task.
batch_size: Number of marginal evaluations per single parallel job.
n_jobs: Number of parallel jobs to use.
config: Object configuring parallel computation, with cluster
address, number of cpus, etc.
Expand All @@ -191,6 +192,12 @@ def compute_generic_semivalues(
"will be doubled wrt. a 'direct' implementation of permutation MC"
)

if batch_size != 1:
warnings.warn(
"batch_size is deprecated and will be removed in future versions.",
DeprecationWarning,
)

result = ValuationResult.zeros(
algorithm=f"semivalue-{str(sampler)}-{coefficient.__name__}", # type: ignore
indices=u.data.indices,
Expand Down Expand Up @@ -298,7 +305,7 @@ def compute_shapley_semivalues(
done: Stopping criterion.
sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler`
for a list.
batch_size: Number of marginal evaluations per (parallelized) task.
batch_size: Number of marginal evaluations per single parallel job.
n_jobs: Number of parallel jobs to use.
config: Object configuring parallel computation, with cluster
address, number of cpus, etc.
Expand Down Expand Up @@ -342,7 +349,7 @@ def compute_banzhaf_semivalues(
done: Stopping criterion.
sampler_t: The sampler type to use. See :mod:`pydvl.value.sampler` for a
list.
batch_size: Number of marginal evaluations per (parallelized) task.
batch_size: Number of marginal evaluations per single parallel job.
n_jobs: Number of parallel jobs to use.
seed: Either an instance of a numpy random number generator or a seed for it.
config: Object configuring parallel computation, with cluster address,
Expand Down
6 changes: 3 additions & 3 deletions tests/value/shapley/test_montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .. import check_rank_correlation, check_total_value, check_values
from ..conftest import polynomial_dataset
from ..utils import call_fn_multiple_seeds
from ..utils import call_with_seeds

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -94,7 +94,7 @@ def test_montecarlo_shapley_housing_dataset_reproducible(
kwargs: dict,
seed: Seed,
):
values_1, values_2 = call_fn_multiple_seeds(
values_1, values_2 = call_with_seeds(
compute_shapley_values,
Utility(LinearRegression(), data=housing_dataset, scorer="r2"),
mode=fun,
Expand All @@ -121,7 +121,7 @@ def test_montecarlo_shapley_housing_dataset_stochastic(
seed: Seed,
seed_alt: Seed,
):
values_1, values_2 = call_fn_multiple_seeds(
values_1, values_2 = call_with_seeds(
compute_shapley_values,
Utility(LinearRegression(), data=housing_dataset, scorer="r2"),
mode=fun,
Expand Down
15 changes: 7 additions & 8 deletions tests/value/test_semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pydvl.value.stopping import AbsoluteStandardError, MaxUpdates

from . import check_values
from .utils import measure_execution_time
from .utils import timed


@pytest.mark.parametrize("num_samples", [5])
Expand Down Expand Up @@ -76,9 +76,8 @@ def test_shapley_batch_size(
):
u, exact_values = analytic_shapley
criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2))
result_single_batch, total_seconds_single_batch = measure_execution_time(
compute_generic_semivalues
)(
timed_fn = timed(compute_generic_semivalues)
result_single_batch = timed_fn(
sampler(u.data.indices, seed=seed),
u,
coefficient,
Expand All @@ -87,9 +86,8 @@ def test_shapley_batch_size(
batch_size=1,
config=parallel_config,
)
result_multi_batch, total_seconds_multi_batch = measure_execution_time(
compute_generic_semivalues
)(
total_seconds_single_batch = timed_fn.execution_time
result_multi_batch = timed_fn(
sampler(u.data.indices, seed=seed),
u,
coefficient,
Expand All @@ -98,8 +96,9 @@ def test_shapley_batch_size(
batch_size=batch_size,
config=parallel_config,
)
total_seconds_multi_batch = timed_fn.execution_time
assert total_seconds_multi_batch < total_seconds_single_batch
check_values(result_single_batch, result_multi_batch, rtol=0.0, atol=0.0)
check_values(result_single_batch, result_multi_batch, rtol=1e-4)


@pytest.mark.parametrize("num_samples", [5])
Expand Down
49 changes: 26 additions & 23 deletions tests/value/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,48 @@
from copy import deepcopy
from functools import wraps
from logging import getLogger
from typing import Callable, Optional, Tuple, TypeVar
from typing import Callable, Optional, Protocol, Tuple, TypeVar

from pydvl.utils.types import Seed

logger = getLogger(__name__)

ReturnType = TypeVar("ReturnType")
ReturnT = TypeVar("ReturnT")


def call_fn_multiple_seeds(
fn: Callable, *args, seeds: Tuple[Seed, ...], **kwargs
) -> Tuple:
def call_with_seeds(fun: Callable, *args, seeds: Tuple[Seed, ...], **kwargs) -> Tuple:
"""
Execute a function multiple times with different seeds. It copies the arguments
and keyword arguments before passing them to the function.
Args:
fn: The function to execute.
fun: The function to execute.
args: The arguments to pass to the function.
seeds: The seeds to use.
kwargs: The keyword arguments to pass to the function.
Returns:
A tuple of the results of the function.
"""
return tuple(fn(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds)
return tuple(fun(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds)


def measure_execution_time(
func: Callable[..., ReturnType]
) -> Callable[..., Tuple[Optional[ReturnType], float]]:
class TimedCallable(Protocol):
"""A callable that has an attribute to keep track of execution time."""

execution_time: float

def __call__(self, *args, **kwargs) -> ReturnT:
...


def timed(fun: Callable[..., ReturnT]) -> TimedCallable:
"""
Takes a function `func` and returns a function with the same input arguments and
the original return value along with the execution time.
Args:
func: The function to be measured, accepting arbitrary arguments and returning
fun: The function to be measured, accepting arbitrary arguments and returning
any type.
Returns:
Expand All @@ -49,18 +54,16 @@ def measure_execution_time(
will have the same input arguments and return type as the original function.
"""

@wraps(func)
def wrapper(*args, **kwargs) -> Tuple[Optional[ReturnType], float]:
result = None
wrapper: TimedCallable

@wraps(fun)
def wrapper(*args, **kwargs) -> ReturnT:
start_time = time.time()
try:
result = func(*args, **kwargs)
except Exception as e:
logger.error(f"Error in {func.__name__}: {e}")
finally:
end_time = time.time()
execution_time = end_time - start_time
logger.info(f"{func.__name__} took {execution_time:.5f} seconds.")
return result, execution_time
result = fun(*args, **kwargs)
end_time = time.time()
wrapper.execution_time = end_time - start_time
logger.info(f"{fun.__name__} took {wrapper.execution_time:.5f} seconds.")
return result

wrapper.execution_time = 0.0
return wrapper

0 comments on commit 9083bc2

Please sign in to comment.