From b81f85edee0cbb92b4bae6dc177dbaf2a0c99ec3 Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Mon, 7 Aug 2023 20:32:11 +0200 Subject: [PATCH] Fix test cases. --- src/pydvl/value/shapley/truncated.py | 6 +----- tests/value/shapley/test_classwise.py | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/pydvl/value/shapley/truncated.py b/src/pydvl/value/shapley/truncated.py index 006487d0a..006a02ae9 100644 --- a/src/pydvl/value/shapley/truncated.py +++ b/src/pydvl/value/shapley/truncated.py @@ -255,11 +255,7 @@ def truncated_montecarlo_shapley( # including the ones that are running n_submitted_jobs = 2 * n_jobs - accumulated_result = ValuationResult.zeros( - algorithm=algorithm, - indices=u.data.indices, - data_names=u.data.data_names, - ) + accumulated_result = ValuationResult.zeros(algorithm=algorithm) with init_executor(max_workers=n_jobs, config=config) as executor: futures = set() diff --git a/tests/value/shapley/test_classwise.py b/tests/value/shapley/test_classwise.py index a6ab5947e..6470f90b6 100644 --- a/tests/value/shapley/test_classwise.py +++ b/tests/value/shapley/test_classwise.py @@ -1,6 +1,8 @@ """ Test cases for the class wise shapley value. """ +import random +from random import seed from typing import Dict, Tuple import numpy as np @@ -239,7 +241,7 @@ def linear_classifier_cs_scorer_args_exact_solution_use_default_score() -> Tuple ] ) ), - {"rtol": 0.15}, + {"atol": 0.05}, ) @@ -263,7 +265,7 @@ def linear_classifier_cs_scorer_args_exact_solution_use_default_score_norm() -> ] ) ), - {"rtol": 0.15}, + {"atol": 0.05}, ) @@ -738,8 +740,13 @@ def linear_classifier_cs_scorer_args_exact_solution_use_add_idx_empty_set() -> T ) -@pytest.mark.parametrize("n_samples", [500]) -@pytest.mark.parametrize("n_resample_complement_sets", [3]) +@pytest.mark.parametrize("n_samples", [500], ids=lambda x: "n_samples={}".format(x)) +@pytest.mark.parametrize( + "n_resample_complement_sets", + [20], + ids=lambda x: "n_resample_complement_sets={}".format(x), +) +@pytest.mark.parametrize("seed", [13, 42, 666], ids=lambda x: "seed={}".format(x)) @pytest.mark.parametrize( "linear_classifier_cs_scorer_args_exact_solution", [ @@ -754,8 +761,12 @@ def test_classwise_shapley( linear_classifier_cs_scorer_args_exact_solution: Tuple[Dict, ValuationResult], n_samples: int, n_resample_complement_sets: int, + seed: int, request, ): + np.random.seed(seed) + random.seed(seed) + args, exact_solution, check_args = request.getfixturevalue( linear_classifier_cs_scorer_args_exact_solution )