Skip to content

Commit

Permalink
Adapt test case to check if a small batch size has a performance improvement compared to a singular batch size.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Sep 6, 2023
1 parent df5d67d commit c081795
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions tests/value/test_semivalues.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import math
import time
from typing import Type

import numpy as np
import pytest

from pydvl.utils import ParallelConfig
from pydvl.utils import ParallelConfig, Seed
from pydvl.value.sampler import (
AntitheticSampler,
DeterministicPermutationSampler,
Expand All @@ -23,6 +24,7 @@
from pydvl.value.stopping import AbsoluteStandardError, MaxUpdates

from . import check_values
from .utils import measure_execution_time


@pytest.mark.parametrize("num_samples", [5])
Expand Down Expand Up @@ -60,29 +62,44 @@ def test_shapley(

@pytest.mark.parametrize(
"num_samples,sampler,coefficient,batch_size",
[(5, PermutationSampler, beta_coefficient(1, 1), 2)],
[(5, PermutationSampler, beta_coefficient(1, 1), 5)],
)
def test_shapley_batch_size(
num_samples: int,
analytic_shapley,
sampler: Type[PowersetSampler],
sampler: Type[PermutationSampler],
coefficient: SVCoefficient,
batch_size: int,
n_jobs: int,
parallel_config: ParallelConfig,
seed: Seed,
):
u, exact_values = analytic_shapley
criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2))
values = compute_generic_semivalues(
sampler(u.data.indices),
result_single_batch, total_seconds_single_batch = measure_execution_time(
compute_generic_semivalues
)(
sampler(u.data.indices, seed=seed),
u,
coefficient,
criterion,
n_jobs=n_jobs,
batch_size=1,
config=parallel_config,
)
result_multi_batch, total_seconds_multi_batch = measure_execution_time(
compute_generic_semivalues
)(
sampler(u.data.indices, seed=seed),
u,
coefficient,
criterion,
n_jobs=n_jobs,
batch_size=batch_size,
config=parallel_config,
)
check_values(values, exact_values, rtol=0.2)
assert total_seconds_multi_batch < total_seconds_single_batch
check_values(result_single_batch, result_multi_batch, rtol=0.0, atol=0.0)


@pytest.mark.parametrize("num_samples", [5])
Expand Down

0 comments on commit c081795

Please sign in to comment.