Skip to content

Commit

Permalink
Minor improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed May 2, 2023
1 parent dae8931 commit 75c4aab
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 181 deletions.
2 changes: 1 addition & 1 deletion docs/30-data-valuation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ than that of out-of-class points.
scoring = ClassWiseScorer("accuracy")
data = Dataset(...)
utility = Utility(model, data, scoring)
values = class_wise_shapley(
values = classwise_shapley(
utility,
done=MaxChecks(1000),
num_resample_complement_sets=10,
Expand Down
1 change: 0 additions & 1 deletion src/pydvl/utils/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Callable, Optional, Protocol, Tuple, Union

import numpy as np
from numpy._typing import NDArray
from numpy.typing import NDArray
from scipy.special import expit
from sklearn.metrics import accuracy_score, get_scorer, make_scorer
Expand Down
179 changes: 33 additions & 146 deletions src/pydvl/value/shapley/classwise.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
"""
Implementation of the algorithm footcite:t:`schoch_csshapley_2022`.
"""
import itertools
import logging
import numbers
import operator
from functools import reduce
from typing import Sequence, Tuple, cast
from typing import Sequence, cast

import numpy as np
from numpy._typing import NDArray
from numpy.typing import NDArray

from pydvl.utils import MapReduceJob, ParallelConfig, Utility

__all__ = [
"class_wise_shapley",
"classwise_shapley",
]

from tqdm import tqdm

from pydvl.utils.numeric import random_powerset_group_conditional
from pydvl.utils.score import ClassWiseScorer
from pydvl.value.result import ValuationResult
from pydvl.value.shapley.montecarlo import permutation_montecarlo_classwise_shapley
from pydvl.value.shapley.truncated import TruncationPolicy
from pydvl.value.stopping import MaxChecks, StoppingCriterion

logger = logging.getLogger(__name__)


def class_wise_shapley(
def classwise_shapley(
u: Utility,
*,
done: StoppingCriterion,
Expand Down Expand Up @@ -64,33 +65,35 @@ def _add(results: Sequence[ValuationResult]) -> ValuationResult:

map_reduce_job: MapReduceJob[NDArray, ValuationResult] = MapReduceJob(
u.data.indices,
map_func=_class_wise_shapley_worker,
map_func=_classwise_shapley_worker,
reduce_func=_add,
map_kwargs=dict(
u=u,
progress=progress,
truncation=truncation,
num_resample_complement_sets=n_resample_complement_sets,
done=done,
truncation=truncation,
normalize_score=False,
n_resample_complement_sets=n_resample_complement_sets,
progress=progress,
),
n_jobs=n_jobs,
config=config,
)
result = map_reduce_job()

if normalize_score:
_normalize_class_wise_shapley_values(result, u)
_normalize_classwise_shapley_values(result, u)

return result


def _class_wise_shapley_worker(
def _classwise_shapley_worker(
update_indices: NDArray[np.int_],
u: Utility,
*,
done: StoppingCriterion,
truncation: TruncationPolicy,
num_resample_complement_sets: int = 1,
normalize_score: bool = False,
n_resample_complement_sets: int = 1,
progress: bool = True,
) -> ValuationResult:
"""Computes class-wise shapley value by using a truncated monte-carlo permutation
Expand All @@ -102,41 +105,50 @@ def _class_wise_shapley_worker(
:param done: Function checking whether computation must stop.
:param truncation: A callable which decides whether to interrupt processing a
permutation and set all subsequent marginals to zero.
:param normalize_score: Whether to normalize the score by the number of classes.
This is handy to make the values comparable across different runs.
:param n_resample_complement_sets: How often the complement set shall be resampled
for each permutation.
:param progress: Whether to display progress bars for each job.
:return: ValuationResult object with the data values.
"""
_check_class_wise_shapley_utility(u)
_check_classwise_shapley_utility(u)

result = ValuationResult.zeros(
algorithm="class_wise_shapley",
algorithm="classwise_shapley",
)

x_train, y_train = u.data.get_training_data(update_indices)
unique_labels = np.unique(y_train)
pbar = tqdm(disable=not progress, position=0, total=100, unit="%")

while True:
for n_step in itertools.count():

pbar.n = 100 * done.completion()
pbar.refresh()

for label in unique_labels:
logger.debug(f"Processing label '{label}'.")
logger.debug(
f"Sampling subset #{n_step} complement index set of class '{label}'."
)
u.scorer.label = label
result += _class_complement_conditional_value_mc_estimator(
result += permutation_montecarlo_classwise_shapley(
u,
label,
update_indices=update_indices,
done=MaxChecks(num_resample_complement_sets),
done=MaxChecks(n_resample_complement_sets),
truncation=truncation,
)

if done(result):
break

if normalize_score:
_normalize_classwise_shapley_values(result, u)

return result


def _check_class_wise_shapley_utility(u: Utility):
def _check_classwise_shapley_utility(u: Utility):
"""
Checks whether the passed utility object is suitable for computing the class-wise
shapley values. Therefore, it is checked that the data passed is a pure
Expand Down Expand Up @@ -169,58 +181,7 @@ def _check_class_wise_shapley_utility(u: Utility):
)


def _class_complement_conditional_value_mc_estimator(
u: Utility,
label: int,
update_indices: NDArray[np.int_],
*,
done: StoppingCriterion,
truncation: TruncationPolicy,
) -> ValuationResult:
"""
Samples a random subset of the complement set and computes the truncated monte carlo
estimator.
:param u: Utility object with model, data, and scoring function. The scoring
function has to be of type :class:`~pydvl.utils.score.ClassWiseScorer`.
:param done: Function checking whether computation must stop.
:param label: The label for which the complement set shall be sampled.
:param update_indices: The indices of the active elements.
:param truncation: A callable which decides whether to interrupt processing a
permutation and set all subsequent marginals to zero.
:return: The updated result object.
"""

result = ValuationResult.zeros(
algorithm="class_wise_shapley",
)

_, y_train = u.data.get_training_data(u.data.indices)
class_indices_set, class_complement_indices_set = _split_into_index_set_by_label(
u.data.indices,
y_train,
label,
)
_, complement_y_train = u.data.get_training_data(class_complement_indices_set)

for subset_complement in random_powerset_group_conditional(
class_complement_indices_set, complement_y_train
):
indices_permutation = np.random.permutation(class_indices_set)
result += _truncated_permutation_mc(
u,
indices_permutation,
subset_complement,
truncation=truncation,
update_indices=update_indices,
)
if done(result):
break

return result


def _normalize_class_wise_shapley_values(
def _normalize_classwise_shapley_values(
result: ValuationResult,
u: Utility,
) -> ValuationResult:
Expand Down Expand Up @@ -251,77 +212,3 @@ def _normalize_class_wise_shapley_values(
result.values[indices_label_set] *= in_cls_acc / sigma

return result


def _split_into_index_set_by_label(
indices: NDArray[np.int_], labels: NDArray[np.int_], label: int
) -> Tuple[NDArray[np.int_], NDArray[np.int_]]:
"""
Splits the indices into two sets based on the passed `label` value.
:param indices: The indices to be used for referring to the data.
:param labels: Corresponding labels for the indices.
:param label: Label to be used for splitting.
:return: Tuple with two sets of indices.
"""
active_elements = labels == label
class_indices_set = np.where(active_elements)[0]
class_complement_indices_set = np.where(~active_elements)[0]
class_indices_set = indices[class_indices_set]
class_complement_indices_set = indices[class_complement_indices_set]
return class_indices_set, class_complement_indices_set


def _truncated_permutation_mc(
u: Utility,
class_indices_set: NDArray[np.int_],
class_complement_indices_set: NDArray[np.int_],
update_indices: NDArray[np.int_],
truncation: TruncationPolicy,
*,
algorithm: str = "class_wise_shapley",
) -> ValuationResult:
"""
A truncated version of a permutation-based MC estimator for class-wise shapley
values. It generates a permutation p[i] of the class label indices and iterates over
the subsets starting from the set containing only one element up to the full set of
indices.
:param u: Utility object with model, data, and scoring function. The scoring
function has to be of type :class:`~pydvl.utils.score.ClassWiseScorer`.
:param class_indices_set: Set of indices for data points with the label.
:param class_complement_indices_set: Set of indices for data points without the
label.
:param update_indices: The indices of the active elements.
:param truncation: A callable which decides whether to interrupt processing a
permutation and set all subsequent marginals to zero.
:return: ValuationResult object with the data values.
"""
if len(np.intersect1d(class_indices_set, class_complement_indices_set)) > 0:
raise ValueError(
"The class label set and the complement set have to be disjoint."
)

update_indices = [i for i in update_indices if i in class_indices_set]
result = ValuationResult.zeros(
algorithm=algorithm,
indices=update_indices,
data_names=u.data.data_names[update_indices],
)

train_set = np.concatenate((class_indices_set, class_complement_indices_set))
prev_score = None

for i in range(len(class_indices_set) + 1):
if prev_score is not None and truncation(i, prev_score):
score = prev_score
else:
score = u((*class_indices_set[:i], *class_complement_indices_set))

if prev_score is not None and class_indices_set[i - 1] in update_indices:
marginal = score - prev_score
result.update(class_indices_set[i - 1], marginal)

prev_score = score

return result
Loading

0 comments on commit 75c4aab

Please sign in to comment.