Merge pull request #338 from aai-institute/259-implement-class-wise-shapley

Implement class wise shapley
mdbenito authored Oct 14, 2023
2 parents 4c10cc3 + bebfd9c commit 1949fcc
Showing 29 changed files with 69,481 additions and 59 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,8 @@

## Unreleased

- New method: Class-wise Shapley values
[PR #338](https://github.com/aai-institute/pyDVL/pull/338)
- No longer using docker within tests to start a memcached server
[PR #444](https://github.com/aai-institute/pyDVL/pull/444)
- Faster semi-value computation with per-index check of stopping criteria (optional)
@@ -43,6 +45,9 @@ randomness.
`compute_beta_shapley_semivalues`, `compute_shapley_semivalues` and
`compute_generic_semivalues`.
[PR #428](https://github.com/aai-institute/pyDVL/pull/428)
- Added classwise Shapley as proposed by (Schoch et al. 2022)
[arXiv:2211.06800](https://arxiv.org/abs/2211.06800)
[PR #338](https://github.com/aai-institute/pyDVL/pull/338)

### Changed

5 changes: 5 additions & 0 deletions README.md
@@ -71,6 +71,11 @@ methods from the following papers:
Efficient Data Value](https://proceedings.mlr.press/v202/kwon23e.html). In
Proceedings of the 40th International Conference on Machine Learning, 18135–52.
PMLR, 2023.
- Schoch, Stephanie, Haifeng Xu, and Yangfeng Ji. [CS-Shapley: Class-Wise
Shapley Values for Data Valuation in
Classification](https://openreview.net/forum?id=KTOcrOR5mQ9). In Proc. of the
Thirty-Sixth Conference on Neural Information Processing Systems (NeurIPS).
New Orleans, Louisiana, USA, 2022.

Influence Functions compute the effect that single points have on an estimator /
model. We implement methods from the following papers:
68,001 changes: 68,001 additions & 0 deletions ...l/value/shapley/classwise/img/classwise-shapley-discounted-utility-function.svg
269 changes: 269 additions & 0 deletions docs/value/classwise-shapley.md
@@ -0,0 +1,269 @@
---
title: Class-wise Shapley
---

# Class-wise Shapley

Class-wise Shapley (CWS) [@schoch_csshapley_2022] offers a Shapley framework
tailored for classification problems. Given a sample $x_i$ with label $y_i \in
\mathbb{N}$, let $D_{y_i}$ be the subset of $D$ with labels $y_i$, and
$D_{-y_i}$ be the complement of $D_{y_i}$ in $D$. The key idea is that the
sample $(x_i, y_i)$ might improve the overall model performance on $D$, while
being detrimental for the performance on $D_{y_i},$ e.g. because of a wrong
label. To address this issue, the authors introduced

$$
v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}}
\left [
\frac{1}{|D_{y_i}|}\sum_{S_{y_i}} \binom{|D_{y_i}|-1}{|S_{y_i}|}^{-1}
\delta(S_{y_i} | S_{-y_i})
\right ],
$$

where $S_{y_i} \subseteq D_{y_i} \setminus \{i\}$ and $S_{-y_i} \subseteq
D_{-y_i}$ is _arbitrary_ (in particular, not the complement of $S_{y_i}$). The
function $\delta$ is called **set-conditional marginal Shapley value** and is
defined as

$$
\delta(S | C) = u(S_{+i} | C) - u(S | C),
$$

for any set $S$ such that $i \notin S \cup C$ and $S \cap C = \emptyset$.

In practical applications, this quantity is estimated with Monte Carlo sampling
over both the powerset of $D_{-y_i}$ and the set of index permutations of
$D_{y_i}$ [@castro_polynomial_2009]. Typically, this requires fewer samples than
the original Shapley value, although the actual speed-up depends on the model
and the dataset.
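
The estimator sketched below illustrates this strategy in plain NumPy:
out-of-class sets $S_{-y_i}$ are drawn uniformly from the powerset of
$D_{-y_i}$, and a random permutation of $D_{y_i}$ reproduces the binomial
weights of the inner sum. This is *not* pyDVL's implementation (which adds
truncation policies and stopping criteria), and the `utility` callable is a
hypothetical placeholder standing in for model retraining plus the class-wise
score introduced below.

```python
import numpy as np

def classwise_shapley_mc(utility, labels, n_updates=100, seed=None):
    """Minimal Monte Carlo sketch of the class-wise Shapley estimator.

    ``utility(in_class_idx, out_of_class_idx)`` is assumed to return
    u(S_{y_i} | S_{-y_i}), e.g. by retraining on the union of both index
    sets and evaluating a class-wise score.
    """
    rng = np.random.default_rng(seed)
    n = len(labels)
    values, counts = np.zeros(n), np.zeros(n)
    for _ in range(n_updates):
        for c in np.unique(labels):
            in_class = np.flatnonzero(labels == c)
            out_of_class = np.flatnonzero(labels != c)
            # Uniform sample from the powerset of D_{-c} (weight 1 / 2^{|D_{-c}|})
            s_out = out_of_class[rng.random(len(out_of_class)) < 0.5]
            # A random permutation of D_c yields the binomial weights over S_{y_i}
            perm = rng.permutation(in_class)
            prev = utility(perm[:0], s_out)
            for k, i in enumerate(perm):
                curr = utility(perm[: k + 1], s_out)
                values[i] += curr - prev  # one sample of delta(S_{y_i} | S_{-y_i})
                counts[i] += 1
                prev = curr
    return values / np.maximum(counts, 1)
```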


!!! Example "Computing classwise Shapley values"
Like all other game-theoretic valuation methods, CWS requires a
[Utility][pydvl.utils.utility.Utility] object constructed with model and
dataset, with the peculiarity of requiring a specific
[ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. The entry
point is the function
[compute_classwise_shapley_values][pydvl.value.shapley.classwise.compute_classwise_shapley_values]:

```python
from pydvl.value import *

model = ...
data = Dataset(...)
scorer = ClasswiseScorer(...)
utility = Utility(model, data, scorer)
values = compute_classwise_shapley_values(
utility,
done=HistoryDeviation(n_steps=500, rtol=5e-2) | MaxUpdates(5000),
truncation=RelativeTruncation(utility, rtol=0.01),
done_sample_complements=MaxChecks(1),
normalize_values=True
)
```
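
For concreteness, a possible instantiation of the placeholders above, assuming
a scikit-learn classifier and the `Dataset.from_sklearn` constructor (any other
model and dataset construction works equally well):

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression

from pydvl.value import *

model = LogisticRegression(max_iter=1000)
data = Dataset.from_sklearn(load_breast_cancer(), train_size=0.8)
scorer = ClasswiseScorer("accuracy")  # default discounts: f(x)=x, g(x)=e^x
utility = Utility(model, data, scorer)
values = compute_classwise_shapley_values(
    utility,
    done=HistoryDeviation(n_steps=500, rtol=5e-2) | MaxUpdates(5000),
    truncation=RelativeTruncation(utility, rtol=0.01),
    done_sample_complements=MaxChecks(1),
    normalize_values=True,
)
```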


### The class-wise scorer

In order to use the class-wise Shapley value, one needs to define a
[ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. This scorer
computes the utility

$$
u(S) = f(a_S(D_{y_i})) g(a_S(D_{-y_i})),
$$

where $f$ and $g$ are monotonically increasing functions, $a_S(D_{y_i})$ is the
**in-class accuracy**, and $a_S(D_{-y_i})$ is the **out-of-class accuracy** (the
names originate from the authors' choice to use accuracy, but in principle any
other score, like $F_1$, can be used).

The authors show that $f(x)=x$ and $g(x)=e^x$ have favorable properties and are
therefore the defaults, but we leave the option to set different functions $f$
and $g$ for an exploration with different base scores.

!!! Example "The default class-wise scorer"
Constructing the CWS scorer requires choosing a metric and the functions $f$
and $g$:

```python
import numpy as np
from pydvl.value.shapley.classwise import ClasswiseScorer

# These are the defaults
identity = lambda x: x
scorer = ClasswiseScorer(
"accuracy",
in_class_discount_fn=identity,
out_of_class_discount_fn=np.exp
)
```
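
For intuition, here is what the default utility evaluates to for a pair of
hypothetical in-class and out-of-class accuracies (numbers chosen arbitrarily):

```python
import numpy as np

in_class_acc = 0.9       # a_S(D_{y_i})
out_of_class_acc = 0.75  # a_S(D_{-y_i})

# Default discounts: f(x) = x, g(x) = exp(x)
u = in_class_acc * np.exp(out_of_class_acc)
print(f"u(S) = {u:.3f}")  # approximately 1.905
```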

??? "Surface of the discounted utility function"
The level curves for $f(x)=x$ and $g(x)=e^x$ are depicted below. The contour
lines are annotated with their respective gradients.
![Level curves of the class-wise
utility](img/classwise-shapley-discounted-utility-function.svg){ align=left width=33% class=invertible }

## Evaluation

We illustrate the method with two experiments: point removal and detection of
mis-labeled points, as well as an analysis of the distribution of the values.
For this we employ the nine datasets used in [@schoch_csshapley_2022], with the
same pre-processing. For image datasets, PCA is used to reduce the features
extracted by a pre-trained `Resnet18` model down to 32 dimensions. Standard
loc-scale normalization is performed for all models except gradient boosting,
since the latter is not sensitive to the scale of the features.

??? info "Datasets used for evaluation"
| Dataset | Data Type | Classes | Input Dims | OpenML ID |
|----------------|-----------|---------|------------|-----------|
| Diabetes | Tabular | 2 | 8 | 37 |
| Click | Tabular | 2 | 11 | 1216 |
| CPU | Tabular | 2 | 21 | 197 |
| Covertype | Tabular | 7 | 54 | 1596 |
| Phoneme | Tabular | 2 | 5 | 1489 |
| FMNIST | Image | 2 | 32 | 40996 |
| CIFAR10 | Image | 2 | 32 | 40927 |
| MNIST (binary) | Image | 2 | 32 | 554 |
| MNIST (multi) | Image | 10 | 32 | 554 |

We show mean and coefficient of variation (CV) $\frac{\sigma}{\mu}$ of an "inner
metric". The former shows the performance of the method, whereas the latter
displays its stability: we normalize by the mean to see the relative effect of
the standard deviation. Ideally the mean value is maximal and CV minimal.

Finally, we note that for all sampling-based valuation methods the same number
of _evaluations of the marginal utility_ was used. This is important to make the
algorithms comparable, but in practice one should consider using a more
sophisticated stopping criterion.

### Dataset pruning for logistic regression (point removal)

In (best-)point removal, one first computes values for the training set and then
removes in sequence the points with the highest values. After each removal, the
remaining points are used to train the model from scratch and performance is
measured on a test set. This produces a curve of performance vs. number of
points removed which we show below.

As a scalar summary of this curve, [@schoch_csshapley_2022] define **Weighted
Accuracy Drop** (WAD) as:

$$
\text{WAD} = \sum_{j=1}^{n} \left ( \frac{1}{j} \sum_{i=1}^{j}
a_{T_{-\{1 \colon i-1 \}}}(D) - a_{T_{-\{1 \colon i \}}}(D) \right)
= a_T(D) - \sum_{j=1}^{n} \frac{a_{T_{-\{1 \colon j \}}}(D)}{j} ,
$$

where $a_T(D)$ is the accuracy of the model (trained on $T$) evaluated on $D$,
and $T_{-\{1 \colon j \}}$ is the set $T$ with the $j$ highest-valued points
removed.
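
The definition translates directly into code. A minimal sketch (not part of
pyDVL), assuming `accuracies[j]` holds $a_{T_{-\{1 \colon j\}}}(D)$, i.e. the
test accuracy after removing the $j$ highest-valued points, with
`accuracies[0]` $= a_T(D)$:

```python
import numpy as np

def weighted_accuracy_drop(accuracies: np.ndarray) -> float:
    """Literal transcription of the first sum in the WAD definition."""
    drops = accuracies[:-1] - accuracies[1:]      # a_{T_{-{1:i-1}}}(D) - a_{T_{-{1:i}}}(D)
    cumulative_drops = np.cumsum(drops)           # inner sum over i = 1, ..., j
    weights = 1.0 / np.arange(1, len(drops) + 1)  # 1/j
    return float(np.sum(weights * cumulative_drops))

# Example: accuracy degrades as the highest-valued points are removed
print(weighted_accuracy_drop(np.array([0.90, 0.85, 0.82, 0.80, 0.79])))
```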

We run the point removal experiment for a logistic regression model five times
and compute WAD for each run, then report the mean $\mu_\text{WAD}$ and standard
deviation $\sigma_\text{WAD}$.

![Mean WAD for best-point removal on logistic regression. Values
computed using LOO, CWS, Beta Shapley, and TMCS
](img/classwise-shapley-metric-wad-mean.svg){ class=invertible }

We see that CWS is competitive with all three other methods. It outperforms
TMCS in all problems except `MNIST (multi)`, where TMCS has a slight advantage.

In order to understand the variability of WAD we look at its coefficient of
variation (lower is better):

![Coefficient of Variation of WAD for best-point removal on logistic regression.
Values computed using LOO, CWS, Beta Shapley, and TMCS
](img/classwise-shapley-metric-wad-cv.svg){ class=invertible }

CWS is not the best method in terms of CV. For `CIFAR10`, `Click`, `CPU` and
`MNIST (binary)` Beta Shapley has the lowest CV. For `Diabetes`, `MNIST (multi)`
and `Phoneme` CWS is the winner and for `FMNIST` and `Covertype` TMCS takes the
lead. Besides LOO, TMCS has the highest relative standard deviation.

The following plot shows accuracy vs number of samples removed. Random values
serve as a baseline. The shaded area represents the 95% bootstrap confidence
interval of the mean across 5 runs.

![Accuracy after best-sample removal using values from logistic
regression](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-logistic-regression.svg){ class=invertible }

Because samples are removed from high to low valuation order, we expect a steep
decrease in the curve.

Overall we conclude that in terms of mean WAD, CWS and TMCS perform best, with
CWS's CV on par with Beta Shapley's, making CWS a competitive method.


### Dataset pruning for a neural network by value transfer

Transfer of values from one model to another is probably of greater practical
relevance: values are computed using a cheap model and used to prune the dataset
before training a more expensive one.

The following plot shows accuracy vs number of samples removed for transfer from
logistic regression to a neural network. The shaded area represents the 95%
bootstrap confidence interval of the mean across 5 runs.

![Accuracy after sample removal using values transferred from logistic
regression to an MLP
](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg){ class=invertible }

As in the previous experiment, samples are removed from high to low valuation
order and hence we expect a steep decrease in the curve. CWS is competitive with
the other methods, especially on very unbalanced datasets like `Click`. On other
datasets, like `Covertype`, `Diabetes` and `MNIST (multi)`, the performance is
on par with TMCS.


### Detection of mis-labeled data points

The next experiment tries to detect mis-labeled data points in binary
classification tasks. 20% of the labels are flipped at random (we don't consider
multi-class datasets because there isn't a unique flipping strategy). The
following plot shows the mean of the area under the curve (AUC) for five runs.

![Mean AUC for mis-labeled data point detection. Values computed using LOO, CWS,
Beta Shapley, and
TMCS](img/classwise-shapley-metric-auc-mean.svg){ class=invertible }

In the majority of cases TMCS has a slight advantage over CWS, except for
`Click`, where CWS has a slight edge, most probably due to the unbalanced nature
of the dataset. The following plot shows the CV for the AUC of the five runs.

![Coefficient of variation of AUC for mis-labeled data point detection. Values
computed using LOO, CWS, Beta Shapley, and TMCS
](img/classwise-shapley-metric-auc-cv.svg){ class=invertible }

In terms of CV, CWS has a clear edge over TMCS and Beta Shapley.

Finally, we look at the ROC curves obtained when training the classifier on the
first $n$ samples in _increasing_ order of valuation (i.e. starting with the
lowest-valued ones):

![Mean ROC across 5 runs with 95% bootstrap
CI](img/classwise-shapley-roc-auc-logistic-regression.svg){ class=invertible }

Although at first sight TMCS seems to be the winner, CWS stays competitive after
factoring in running time. For a perfectly balanced dataset, CWS needs on
average fewer samples than TMCS.
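
For reference, the detection metric used in this section reduces to ranking
points by their value (mis-labeled points are expected to be valued lowest) and
computing the AUC against the flip mask. A self-contained sketch with made-up
numbers, not the actual experiment:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
n = 100

# 20% of the points are marked as (hypothetically) mis-labeled
flipped = np.zeros(n, dtype=bool)
flipped[rng.choice(n, size=n // 5, replace=False)] = True

# In the experiment these values come from a valuation method; here they are
# random, with flipped points shifted towards lower values for illustration.
values = rng.normal(size=n) - 2.0 * flipped

# Mis-labeled points should rank low, so detection is scored with -value
print(f"detection AUC: {roc_auc_score(flipped, -values):.2f}")
```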

### Value distribution

For illustration, we compare the distribution of values computed by TMCS and
CWS.

![Histogram and estimated density of the values computed by TMCS and
CWS on all nine datasets](img/classwise-shapley-density.svg){ class=invertible }

For `Click` TMCS has a multi-modal distribution of values. We hypothesize that
this is due to the highly unbalanced nature of the dataset, and notice that CWS
has a single mode, leading to its greater performance on this dataset.

## Conclusion

CWS is an effective way to handle classification problems, in particular for
unbalanced datasets. It reduces the computing requirements by considering
in-class and out-of-class points separately.

1 change: 1 addition & 0 deletions docs/value/img/classwise-shapley-density.svg
1 change: 1 addition & 0 deletions docs/value/img/classwise-shapley-metric-auc-cv.svg
1 change: 1 addition & 0 deletions docs/value/img/classwise-shapley-metric-auc-mean.svg
1 change: 1 addition & 0 deletions docs/value/img/classwise-shapley-metric-wad-cv.svg
1 change: 1 addition & 0 deletions docs/value/img/classwise-shapley-metric-wad-mean.svg
11 changes: 9 additions & 2 deletions docs/value/notation.md
@@ -4,17 +4,24 @@ title: Notation for valuation

# Notation for valuation

!!! todo
Organize this page better and use its content consistently throughout the
documentation.

The following notation is used throughout the documentation:

Let $D = \{x_1, \ldots, x_n\}$ be a training set of $n$ samples.

The utility function $u:\mathcal{D} \rightarrow \mathbb{R}$ maps subsets of $D$
to real numbers. In pyDVL, we typically call this mapping a **score** for
consistency with sklearn, and reserve the term **utility** for the triple of
dataset $D$, model $f$ and score $u$, since they are used together to compute
the value.

The value $v$ of the $i$-th sample in dataset $D$ wrt. utility $u$ is
denoted as $v_u(x_i)$ or simply $v(i)$.

For any $S \subseteq D$, we denote by $S_{-i}$ the set of samples in $D$
excluding $x_i$, and $S_{+i}$ denotes the set $S$ with $x_i$ added.

The marginal utility of adding sample $x_i$ to a subset $S$ is denoted as
5 changes: 3 additions & 2 deletions docs/value/shapley.md
@@ -175,8 +175,9 @@ values = compute_shapley_values(u=utility, mode="knn")

### Group testing

An alternative method for the approximation of Shapley values introduced in
[@jia_efficient_2019a] first estimates the differences of values with a Monte
Carlo sum. With

$$\hat{\Delta}_{i j} \approx v_i - v_j,$$

14 changes: 10 additions & 4 deletions docs_includes/abbreviations.md
@@ -1,15 +1,21 @@
*[CSP]: Constraint Satisfaction Problem
*[CV]: Coefficient of Variation
*[CWS]: Class-wise Shapley
*[DUL]: Data Utility Learning
*[GT]: Group Testing
*[IF]: Influence Function
*[iHVP]: inverse Hessian-vector product
*[LC]: Least Core
*[LiSSA]: Linear-time Stochastic Second-order Algorithm
*[LOO]: Leave-One-Out
*[MCLC]: Monte Carlo Least Core
*[MCS]: Monte Carlo Shapley
*[ML]: Machine Learning
*[MLP]: Multi-Layer Perceptron
*[MLRC]: Machine Learning Reproducibility Challenge
*[MSE]: Mean Squared Error
*[PCA]: Principal Component Analysis
*[ROC]: Receiver Operating Characteristic
*[SV]: Shapley Value
*[TMCS]: Truncated Monte Carlo Shapley
*[WAD]: Weighted Accuracy Drop
3 changes: 2 additions & 1 deletion mkdocs.yml
@@ -193,9 +193,10 @@ nav:
- Data Valuation:
- Introduction: value/index.md
- Notation: value/notation.md
- Shapley values: value/shapley.md
- Semi-values: value/semi-values.md
- The core: value/the-core.md
- Classwise Shapley: value/classwise-shapley.md
- Examples:
- Shapley values: examples/shapley_basic_spotify.ipynb
- KNN Shapley: examples/shapley_knn_flowers.ipynb