diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py index 474a2b7c1..ceffeb7bb 100644 --- a/tests/value/test_semivalues.py +++ b/tests/value/test_semivalues.py @@ -1,10 +1,11 @@ import math +import time from typing import Type import numpy as np import pytest -from pydvl.utils import ParallelConfig +from pydvl.utils import ParallelConfig, Seed from pydvl.value.sampler import ( AntitheticSampler, DeterministicPermutationSampler, @@ -23,6 +24,7 @@ from pydvl.value.stopping import AbsoluteStandardError, MaxUpdates from . import check_values +from .utils import measure_execution_time @pytest.mark.parametrize("num_samples", [5]) @@ -60,21 +62,35 @@ def test_shapley( @pytest.mark.parametrize( "num_samples,sampler,coefficient,batch_size", - [(5, PermutationSampler, beta_coefficient(1, 1), 2)], + [(5, PermutationSampler, beta_coefficient(1, 1), 5)], ) def test_shapley_batch_size( num_samples: int, analytic_shapley, - sampler: Type[PowersetSampler], + sampler: Type[PermutationSampler], coefficient: SVCoefficient, batch_size: int, n_jobs: int, parallel_config: ParallelConfig, + seed: Seed, ): u, exact_values = analytic_shapley criterion = AbsoluteStandardError(0.02, 1.0) | MaxUpdates(2 ** (num_samples * 2)) - values = compute_generic_semivalues( - sampler(u.data.indices), + result_single_batch, total_seconds_single_batch = measure_execution_time( + compute_generic_semivalues + )( + sampler(u.data.indices, seed=seed), + u, + coefficient, + criterion, + n_jobs=n_jobs, + batch_size=1, + config=parallel_config, + ) + result_multi_batch, total_seconds_multi_batch = measure_execution_time( + compute_generic_semivalues + )( + sampler(u.data.indices, seed=seed), u, coefficient, criterion, @@ -82,7 +98,8 @@ def test_shapley_batch_size( batch_size=batch_size, config=parallel_config, ) - check_values(values, exact_values, rtol=0.2) + assert total_seconds_multi_batch < total_seconds_single_batch + check_values(result_single_batch, result_multi_batch, rtol=0.0, atol=0.0) @pytest.mark.parametrize("num_samples", [5])