diff --git a/src/pydvl/value/shapley/truncated.py b/src/pydvl/value/shapley/truncated.py index 006487d0a..006a02ae9 100644 --- a/src/pydvl/value/shapley/truncated.py +++ b/src/pydvl/value/shapley/truncated.py @@ -255,11 +255,7 @@ def truncated_montecarlo_shapley( # including the ones that are running n_submitted_jobs = 2 * n_jobs - accumulated_result = ValuationResult.zeros( - algorithm=algorithm, - indices=u.data.indices, - data_names=u.data.data_names, - ) + accumulated_result = ValuationResult.zeros(algorithm=algorithm) with init_executor(max_workers=n_jobs, config=config) as executor: futures = set() diff --git a/tests/value/shapley/test_classwise.py b/tests/value/shapley/test_classwise.py index a6ab5947e..fcc54f23d 100644 --- a/tests/value/shapley/test_classwise.py +++ b/tests/value/shapley/test_classwise.py @@ -1,6 +1,8 @@ """ Test cases for the class wise shapley value. """ +import random +from random import seed from typing import Dict, Tuple import numpy as np @@ -239,7 +241,7 @@ def linear_classifier_cs_scorer_args_exact_solution_use_default_score() -> Tuple ] ) ), - {"rtol": 0.15}, + {"atol": 0.05}, ) @@ -263,7 +265,7 @@ def linear_classifier_cs_scorer_args_exact_solution_use_default_score_norm() -> ] ) ), - {"rtol": 0.15}, + {"atol": 0.05}, ) @@ -738,8 +740,12 @@ def linear_classifier_cs_scorer_args_exact_solution_use_add_idx_empty_set() -> T ) -@pytest.mark.parametrize("n_samples", [500]) -@pytest.mark.parametrize("n_resample_complement_sets", [3]) +@pytest.mark.parametrize("n_samples", [500], ids=lambda x: "n_samples={}".format(x)) +@pytest.mark.parametrize( + "n_resample_complement_sets", + [20], + ids=lambda x: "n_resample_complement_sets={}".format(x), +) @pytest.mark.parametrize( "linear_classifier_cs_scorer_args_exact_solution", [