Skip to content

Commit

Permalink
Fix test cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Aug 7, 2023
1 parent 7687b26 commit b81f85e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
6 changes: 1 addition & 5 deletions src/pydvl/value/shapley/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,7 @@ def truncated_montecarlo_shapley(
# including the ones that are running
n_submitted_jobs = 2 * n_jobs

accumulated_result = ValuationResult.zeros(
algorithm=algorithm,
indices=u.data.indices,
data_names=u.data.data_names,
)
accumulated_result = ValuationResult.zeros(algorithm=algorithm)

with init_executor(max_workers=n_jobs, config=config) as executor:
futures = set()
Expand Down
19 changes: 15 additions & 4 deletions tests/value/shapley/test_classwise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Test cases for the class wise shapley value.
"""
import random
from random import seed
from typing import Dict, Tuple

import numpy as np
Expand Down Expand Up @@ -239,7 +241,7 @@ def linear_classifier_cs_scorer_args_exact_solution_use_default_score() -> Tuple
]
)
),
{"rtol": 0.15},
{"atol": 0.05},
)


Expand All @@ -263,7 +265,7 @@ def linear_classifier_cs_scorer_args_exact_solution_use_default_score_norm() ->
]
)
),
{"rtol": 0.15},
{"atol": 0.05},
)


Expand Down Expand Up @@ -738,8 +740,13 @@ def linear_classifier_cs_scorer_args_exact_solution_use_add_idx_empty_set() -> T
)


@pytest.mark.parametrize("n_samples", [500])
@pytest.mark.parametrize("n_resample_complement_sets", [3])
@pytest.mark.parametrize("n_samples", [500], ids=lambda x: "n_samples={}".format(x))
@pytest.mark.parametrize(
"n_resample_complement_sets",
[20],
ids=lambda x: "n_resample_complement_sets={}".format(x),
)
@pytest.mark.parametrize("seed", [13, 42, 666], ids=lambda x: "seed={}".format(x))
@pytest.mark.parametrize(
"linear_classifier_cs_scorer_args_exact_solution",
[
Expand All @@ -754,8 +761,12 @@ def test_classwise_shapley(
linear_classifier_cs_scorer_args_exact_solution: Tuple[Dict, ValuationResult],
n_samples: int,
n_resample_complement_sets: int,
seed: int,
request,
):
np.random.seed(seed)
random.seed(seed)

args, exact_solution, check_args = request.getfixturevalue(
linear_classifier_cs_scorer_args_exact_solution
)
Expand Down

0 comments on commit b81f85e

Please sign in to comment.