Skip to content

Commit

Permalink
Implement CS-Shapley from (Schoch et al.)[https://arxiv.org/abs/2211.…
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Sep 22, 2023
1 parent 43690b0 commit a023ecc
Show file tree
Hide file tree
Showing 11 changed files with 1,199 additions and 14 deletions.
47 changes: 47 additions & 0 deletions docs/value/shapley.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,53 @@ stop condition. This is an instance of a
[MaxTime][pydvl.value.stopping.MaxTime] and
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError].

### Class-wise Shapley

Class-wise Shapley [@schoch_csshapley_2022] offers a distinct Shapley framework tailored
for classification problems. Let $D$ be the dataset, $D_{y_i}$ be the subset of $D$ with
labels $y_i$, and $D_{-y_i}$ be the complement of $D_{y_i}$ in $D$. The key idea is that
a sample $(x_i, y_i)$ might enhance the overall performance on $D$, while being
detrimental for the performance on $D_{y_i}$. To address this nuanced behavior, the
authors introduced the estimator

$$
v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}} \frac{1}{|D_{y_i}|}
\sum_{S_{y_i}} \binom{|D_{y_i}|-1}{|S_{y_i}|}^{-1}
[u( S_{y_i} \cup \{i\} | S_{-y_i} ) - u( S_{y_i} | S_{-y_i})],
$$

where $S_{y_i} \subseteq D_{y_i} \setminus \{i\}$ and $S_{-y_i} \subseteq D_{-y_i}$. In
other words, the summations are over the powerset of $D_{y_i} \setminus \{i\}$ and
$D_{-y_i}$ respectively. The estimator employs a specialized utility function

$$
u(S_{y_i}|S_{-y_i}) = a_S(D_{y_i}) \exp(a_S(D_{-y_i})),
$$

where $S=S_{y_i} \cup S_{-y_i}$ and $a_S(D)$ is the accuracy of the model trained on $S$
and evaluated on $D$. In practical applications, the evaluation of this estimator
leverages both Monte Carlo sampling and permutation Monte Carlo sampling
[@castro_polynomial_2009].


```python
from pydvl.utils import Dataset, Utility
from pydvl.value import HistoryDeviation, MaxChecks, RelativeTruncation
from pydvl.value.shapley.classwise import (
    ClasswiseScorer,
    compute_classwise_shapley_values,
)

# Placeholders: supply your own classifier and dataset here.
model = ...
data = Dataset(...)

# The class-wise scorer wraps a regular scoring metric (here: accuracy).
scoring = ClasswiseScorer("accuracy")
utility = Utility(model, data, scoring)

values = compute_classwise_shapley_values(
    utility,
    done=HistoryDeviation(n_steps=500, rtol=5e-2),
    truncation=RelativeTruncation(utility, rtol=0.01),
    done_sample_complements=MaxChecks(1),
    normalize_values=True,
)
```

### Owen sampling

Expand Down
1 change: 1 addition & 0 deletions docs_includes/abbreviations.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*[CSP]: Constraint Satisfaction Problem
*[CS]: Class-wise Shapley
*[GT]: Group Testing
*[LC]: Least Core
*[LOO]: Leave-One-Out
Expand Down
60 changes: 60 additions & 0 deletions src/pydvl/utils/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"random_matrix_with_condition_number",
"random_subset",
"random_powerset",
"random_powerset_label_min",
"random_subset_of_size",
"top_k_value_accuracy",
]
Expand Down Expand Up @@ -133,6 +134,65 @@ def random_powerset(
total += 1


def random_powerset_label_min(
    s: NDArray[T],
    labels: NDArray[np.int_],
    min_elements_per_label: int = 1,
    seed: Optional[Seed] = None,
) -> Generator[NDArray[T], None, None]:
    """Draw random subsets of `s` with a minimum number of elements per label.

    Every draw picks, independently for each distinct label, a random number of
    elements carrying that label and then shuffles the union of the picks. It
    can be used for classification problems to ensure that a set contains
    information for all labels (or not if `min_elements_per_label=0`).

    This is an endless generator: it never stops on its own, so the caller has
    to bound the iteration. Because it is a generator, the argument checks
    below only run when the first element is requested, not when the function
    is called.

    Args:
        s: Set to sample from.
        labels: Labels for the samples, aligned with `s`.
        min_elements_per_label: Minimum number of elements for each label. If a
            label has fewer elements than this, the minimum for that label is
            lowered to the number of elements it has.
        seed: Either an instance of a numpy random number generator or a seed
            for it.

    Yields:
        A shuffled draw from the powerset of `s` with at least
        `min_elements_per_label` elements for each label (subject to the
        clamping described above). An empty array with the dtype of `s` is
        yielded when no element was picked for any label.

    Raises:
        ValueError: If `s` and `labels` are of different length or
            `min_elements_per_label` is smaller than 0.
    """
    if len(labels) != len(s):
        raise ValueError("Set and labels have to be of same size.")

    if min_elements_per_label < 0:
        raise ValueError(
            f"Parameter min_elements={min_elements_per_label} needs to be bigger or equal to 0."
        )

    rng = np.random.default_rng(seed)

    # The positions belonging to each label do not change between draws, so
    # the partition is computed once here instead of on every iteration.
    indices_per_label = [
        np.asarray(np.where(labels == label)[0]) for label in np.unique(labels)
    ]

    while True:
        subsets = []
        for label_indices in indices_per_label:
            n_available = len(label_indices)
            # The upper bound of `integers` is exclusive, hence the `+ 1`; the
            # lower bound is clamped so that it never exceeds what is available.
            subset_size = int(
                rng.integers(
                    min(min_elements_per_label, n_available),
                    n_available + 1,
                )
            )
            if subset_size > 0:
                subsets.append(
                    random_subset_of_size(s[label_indices], subset_size, seed=rng)
                )

        if subsets:
            subset = np.concatenate(tuple(subsets))
            # Mix the labels so that the draw is not grouped label by label.
            rng.shuffle(subset)
            yield subset
        else:
            yield np.array([], dtype=s.dtype)


def random_subset_of_size(
s: NDArray[T], size: int, seed: Optional[Seed] = None
) -> NDArray[T]:
Expand Down
19 changes: 17 additions & 2 deletions src/pydvl/value/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,14 @@ def __add__(
xm[other_pos] = other._values
vm[other_pos] = other._variances

# np.maximum(1, n + m) covers case n = m = 0.
n_m_sum = np.maximum(1, n + m)

# Sample mean of n+m samples from two means of n and m samples
xnm = (n * xn + m * xm) / (n + m)
xnm = (n * xn + m * xm) / n_m_sum

# Sample variance of n+m samples from two sample variances of n and m samples
vnm = (n * (vn + xn**2) + m * (vm + xm**2)) / (n + m) - xnm**2
vnm = (n * (vn + xn**2) + m * (vm + xm**2)) / n_m_sum - xnm**2

if np.any(vnm < 0):
if np.any(vnm < -1e-6):
Expand Down Expand Up @@ -627,6 +631,17 @@ def update(self, idx: int, new_value: float) -> ValuationResult[IndexT, NameT]:
)
return self

def scale(self, factor: float, indices: Optional[NDArray[IndexT]] = None):
    """
    Multiplies stored values by `factor` and stored variances by `factor**2`,
    modifying this object in place. Nothing is returned.

    Args:
        factor: Coefficient applied to the values; the variances are
            multiplied by its square.
        indices: Entries to scale. They are looked up through
            `self._sort_positions` before touching the underlying arrays.
            If None, all values are scaled.
    """
    # NOTE(review): when `indices` is None it is used as the index itself;
    # assuming `_sort_positions` is a NumPy array, that selects every entry
    # (None acts as a new axis) -- confirm this is the intended "all" path.
    positions = self._sort_positions[indices]

    self._values[positions] *= factor
    self._variances[positions] *= factor**2

def get(self, idx: Integral) -> ValueItem:
"""Retrieves a ValueItem by data index, as opposed to sort index, like
the indexing operator.
Expand Down
1 change: 1 addition & 0 deletions src/pydvl/value/shapley/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ..result import *
from ..stopping import *
from .classwise import *
from .common import *
from .gt import *
from .knn import *
Expand Down
Loading

0 comments on commit a023ecc

Please sign in to comment.