
Implement class wise shapley #338

Merged 53 commits on Oct 14, 2023
3e69df0
Implement algorithm from paper `CS-Shapley: Class-wise Shapley Values…
Aug 12, 2023
06769be
Remove setter method from dataset to fix https://github.com/appliedAI…
Aug 28, 2023
d77fbc0
Moved files to `*classwise.py` file as in https://github.com/appliedA…
Aug 28, 2023
078e723
Optimize stopping criterion as described in https://github.com/applie…
Aug 28, 2023
684217c
Fix https://github.com/appliedAI-Initiative/pyDVL/pull/338#discussion…
Aug 28, 2023
4237e94
Move files from tests/misc.py file to test_classwise.py
Aug 28, 2023
cbeb569
Fix https://github.com/appliedAI-Initiative/pyDVL/pull/338#discussion…
kosmitive Aug 28, 2023
d8e218b
Remove processing details from comment.
Aug 28, 2023
375c200
Make _permutation_montecarlo_classwise_shapley private.
Aug 28, 2023
594e2a1
Merged develop.
Aug 28, 2023
0bd79d0
Merge branch 'develop' into 259-implement-class-wise-shapley
Sep 1, 2023
9df9a9b
Refactor executor pattern.
Sep 1, 2023
382552f
Remove writeable_array function.
Sep 1, 2023
5869e6e
Fix comments.
Sep 1, 2023
e64c637
Add short documentation.
Sep 1, 2023
c9a333d
Merge branch 'develop' into 259-implement-class-wise-shapley
Sep 1, 2023
33bdf6f
Refactor stopping criterion and solve type issue.
Sep 1, 2023
f4c9024
Fix test case regarding max checks.
Sep 1, 2023
adef9b4
Merge branch 'develop' into 259-implement-class-wise-shapley
mdbenito Sep 2, 2023
62c0f09
Move cs-shap specific function to its file
mdbenito Sep 2, 2023
bfce984
Fix docstring
mdbenito Sep 2, 2023
b4f50fc
Fix reference
mdbenito Sep 2, 2023
9444908
Merge branch 'develop' into 259-implement-class-wise-shapley
Sep 4, 2023
587a56b
Add seed parameter to class wise shapley.
Sep 4, 2023
e1f0e5d
Merge branch 'develop' into 259-implement-class-wise-shapley
kosmitive Sep 11, 2023
faf16df
Fix styling of comments, improve comments, rename methods and introdu…
Sep 12, 2023
ad31787
Merge branch 'develop' into 259-implement-class-wise-shapley
Sep 13, 2023
2554c9f
Change pi to sigma.
Sep 13, 2023
a95dae1
Rename variable from `subset_length` to `subset_size`.
Sep 13, 2023
9f6ec30
Adapt documentation.
Sep 13, 2023
9dfc3fd
Merge branch 'develop' into 259-implement-class-wise-shapley
Sep 23, 2023
037fa66
Move imports fix some stuff and add to CHANGELOG.md
Sep 23, 2023
94228d8
Fix that batch_size is overwritten.
Sep 25, 2023
5526031
Rework documentation.
Sep 25, 2023
6e2b0a4
Rework documentation.
Sep 25, 2023
cd01657
Finalize documentation.
Oct 6, 2023
4b4b4e2
Merge branch 'develop' into 259-implement-class-wise-shapley
mdbenito Oct 7, 2023
67416a3
Improvements to CWS docs
mdbenito Oct 7, 2023
74ba6b2
Actually allow for other inner metrics in ClasswiseScorer
mdbenito Oct 7, 2023
6d21579
[skip ci] More doc changes for CWS
mdbenito Oct 8, 2023
0963615
[skip ci] Minor
mdbenito Oct 8, 2023
c1195ec
Delete duplicate section on CWS
mdbenito Oct 8, 2023
79e4e5a
[skip ci] More doc tweaks
mdbenito Oct 8, 2023
6236caa
[skip ci] Add CWS paper to README.md
mdbenito Oct 8, 2023
f2a0bfc
[skip ci] Updated CHANGELOG.md
mdbenito Oct 8, 2023
94d95db
Finalize documentation, update plots.
Oct 10, 2023
a27a655
Merge branch 'develop' into 259-implement-class-wise-shapley
kosmitive Oct 10, 2023
37efbb8
Reformat `stopping.py`.
Oct 10, 2023
fdc4a06
Add noise removal to README.md.
Oct 10, 2023
3cee0f8
Refactor some text and add legend to density plot.
Oct 11, 2023
72d0105
Remove variance remark.
Oct 12, 2023
5ee24bd
[skip ci] Bunch of fixes to the doc
mdbenito Oct 14, 2023
bebfd9c
[skip ci] More fixes
mdbenito Oct 14, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -41,6 +41,9 @@ randomness.
`compute_beta_shapley_semivalues`, `compute_shapley_semivalues` and
`compute_generic_semivalues`.
[PR #428](https://github.com/aai-institute/pyDVL/pull/428)
- Added class-wise Shapley value as proposed by Schoch et al. (2022),
  [arXiv:2211.06800](https://arxiv.org/abs/2211.06800)
  [PR #338](https://github.com/aai-institute/pyDVL/pull/338)

### Changed

68,001 changes: 68,001 additions & 0 deletions ...l/value/shapley/classwise/img/classwise-shapley-discounted-utility-function.svg
145 changes: 145 additions & 0 deletions docs/value/classwise-shapley.md
@@ -0,0 +1,145 @@
---
title: Class-wise Shapley
---

# Class-wise Shapley

Class-wise Shapley (CWS) [@schoch_csshapley_2022] offers a Shapley framework
tailored for classification problems. Let $D$ be a dataset, $D_{y_i}$ be the
subset of $D$ with labels $y_i$, and $D_{-y_i}$ be the complement of $D_{y_i}$
in $D$. The key idea is that a sample $(x_i, y_i)$ might enhance the overall
performance on $D$, while being detrimental for the performance on $D_{y_i}$. To
address this issue, the authors introduce the estimator

$$
v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}}
\left [
\frac{1}{|D_{y_i}|}\sum_{S_{y_i}} \binom{|D_{y_i}|-1}{|S_{y_i}|}^{-1}
\delta(S_{y_i} | S_{-y_i})
\right ],
$$

where $S_{y_i} \subseteq D_{y_i} \setminus \{i\}$ and $S_{-y_i} \subseteq
D_{-y_i}$, and the function $\delta$ is called **set-conditional marginal
Shapley value**. It is defined as

$$
\delta(S | C) = u( S \cup \{i\} | C ) - u(S | C),
$$

where $i \notin S \cup C$ and $S \cap C = \emptyset$.
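For intuition, the set-conditional marginal can be written directly from this
definition. A minimal sketch, with a generic utility callable `u` and a toy
additive utility standing in for a trained model's score (both hypothetical,
for illustration only):

```python
from typing import Callable, FrozenSet

def conditional_marginal(
    u: Callable[[FrozenSet[int]], float],
    i: int,
    s: FrozenSet[int],
    c: FrozenSet[int],
) -> float:
    """delta(S | C) = u(S + {i} | C) - u(S | C).

    Requires i not in S or C, and S, C disjoint. Conditioning on C is
    modelled by always including C in the utility's argument.
    """
    assert i not in s and i not in c and not (s & c)
    return u(s | c | {i}) - u(s | c)

# Toy additive utility: the marginal of i is simply its weight.
weights = {0: 0.1, 1: 0.5, 2: -0.2, 3: 0.3}
u = lambda idx: sum(weights[j] for j in idx)

delta = conditional_marginal(u, i=1, s=frozenset({0}), c=frozenset({2, 3}))
```

For an additive utility the set-conditional marginal reduces to the weight of
$i$; the interesting behaviour of $\delta$ arises with non-additive utilities
such as model accuracy.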

In practice, this quantity is estimated with Monte Carlo sampling over both
the powerset of $D_{-y_i}$ and the set of index permutations of $D_{y_i}$
[@castro_polynomial_2009]. This typically requires fewer samples than the
original Shapley value, although the actual speed-up depends on the model and
the dataset.


??? Example "Computing classwise Shapley values"
```python
from pydvl.value import *

model = ...
data = Dataset(...)
scorer = ClasswiseScorer(...)
utility = Utility(model, data, scorer)
values = compute_classwise_shapley_values(
utility,
done=HistoryDeviation(n_steps=500, rtol=5e-2) | MaxUpdates(5000),
truncation=RelativeTruncation(utility, rtol=0.01),
done_sample_complements=MaxChecks(1),
normalize_values=True
)
```


## Class-wise scorer

In order to use the classwise Shapley value, one needs to define a
[ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. Given a sample
$x_i$ with label $y_i \in \mathbb{N}$, we split the data into the two disjoint
sets $D_{y_i}$ and $D_{-y_i}$, and define the utility

$$
u(S) = f(a_S(D_{y_i}))\, g(a_S(D_{-y_i})),
$$

where $f$ and $g$ are monotonically increasing functions, $a_S(D_{y_i})$ is the
**in-class accuracy**, and $a_S(D_{-y_i})$ is the **out-of-class accuracy** (the
names reflect the authors' choice of accuracy, but in principle any other
score, like $F_1$, can be used).

The authors show that $f(x)=x$ and $g(x)=e^x$ have favorable properties and are
therefore the defaults, but we leave the option to set different functions $f$
and $g$ for an exploration with different base scores.

??? Example "The default class-wise scorer"
```python
import numpy as np
from pydvl.value.shapley.classwise import ClasswiseScorer

# These are the defaults
identity = lambda x: x
scorer = ClasswiseScorer(
"accuracy",
in_class_discount_fn=identity,
out_of_class_discount_fn=np.exp
)
```

The level curves for $f(x)=x$ and $g(x)=e^x$ are depicted below. The white
contour lines are annotated with their respective gradients.

![Level curves of the class-wise utility](img/classwise-shapley-discounted-utility-function.svg)
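The shape of this utility can also be checked numerically. A small sketch
evaluating $u = f(a_1)\,g(a_2) = a_1 e^{a_2}$ on a grid of (in-class,
out-of-class) accuracies; contours could then be drawn with e.g.
`matplotlib.pyplot.contour`:

```python
import math

# Evaluate u(x, y) = x * exp(y) on an 11x11 grid of accuracies in [0, 1],
# with x the in-class and y the out-of-class accuracy.
xs = [i / 10 for i in range(11)]
ys = [j / 10 for j in range(11)]
grid = [[x * math.exp(y) for x in xs] for y in ys]

# u increases monotonically in both arguments: perfect in-class accuracy
# with zero out-of-class accuracy gives u = 1 * exp(0) = 1, and perfect
# accuracy on both gives u = e.
```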

## Evaluation

We evaluate the method on the nine datasets used in [@schoch_csshapley_2022],
with the same pre-processing. For the image datasets, features are extracted
with a `Resnet18` model and reduced to 32 dimensions with PCA. For more details
on the pre-processing steps, please refer to the paper.

??? info "Datasets used for evaluation"
| Dataset | Data Type | Classes | Input Dims | OpenML ID |
|----------------|-----------|---------|------------|-----------|
| Diabetes | Tabular | 2 | 8 | 37 |
| Click | Tabular | 2 | 11 | 1216 |
| CPU | Tabular | 2 | 21 | 197 |
| Covertype | Tabular | 7 | 54 | 1596 |
| Phoneme | Tabular | 2 | 5 | 1489 |
| FMNIST | Image | 2 | 32 | 40996 |
| CIFAR10 | Image | 2 | 32 | 40927 |
| MNIST (binary) | Image | 2 | 32 | 554 |
| MNIST (multi) | Image | 10 | 32 | 554 |
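For reference, the OpenML IDs above can be collected into a small lookup;
`sklearn.datasets.fetch_openml` accepts these IDs directly via its `data_id`
parameter (a convenience sketch, not part of pyDVL):

```python
# OpenML dataset IDs from the table above. Each dataset can be browsed at
# https://www.openml.org/d/<ID> or fetched with
# sklearn.datasets.fetch_openml(data_id=<ID>).
OPENML_IDS = {
    "Diabetes": 37,
    "Click": 1216,
    "CPU": 197,
    "Covertype": 1596,
    "Phoneme": 1489,
    "FMNIST": 40996,
    "CIFAR10": 40927,
    "MNIST": 554,  # used for both the binary and the multi-class variant
}

url = f"https://www.openml.org/d/{OPENML_IDS['Diabetes']}"
```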


### Value transfer

We compare CWS to TMCS, Beta Shapley and LOO on a sample removal task with value
transfer to another model. Values are computed using a logistic regression
model, and used to prune the training set for a neural network. Random values
serve as a baseline. The following plot shows valuation-set accuracy of the
network on the y-axis, and the number of samples removed on the x-axis.

![Accuracy after sample removal using values transferred from logistic regression
to an MLP](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg)

Samples are removed from highest to lowest value, so we expect a steep initial
decrease in the curve. CWS is competitive with the other methods. On very
unbalanced datasets such as `Click`, its performance appears superior. On
`Covertype`, `Diabetes` and `MNIST (multi)` it is on par with TMCS, and on
`MNIST (binary)` and `Phoneme` it remains competitive. We remark that the same
number of _evaluations of the marginal utility_ was used for all valuation
methods.
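The removal experiment itself follows a simple loop: sort by value, drop the
$k$ highest-valued samples, retrain the transfer model and score it. A minimal
sketch, with a hypothetical `retrain_and_score` callback standing in for
training the MLP and evaluating it on the valuation set:

```python
from typing import Callable, List, Sequence

def removal_curve(
    values: Sequence[float],
    retrain_and_score: Callable[[List[int]], float],
    step: int = 1,
) -> List[float]:
    """Accuracy after removing samples in descending value order."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    curve = []
    for k in range(0, len(values) + 1, step):
        remaining = order[k:]  # drop the k highest-valued samples
        curve.append(retrain_and_score(remaining))
    return curve

# Toy example: "accuracy" is the fraction of total value kept, so removing
# samples from high to low yields a monotonically decreasing curve.
vals = [0.9, 0.1, 0.7, 0.3]
score = lambda idx: sum(vals[i] for i in idx) / sum(vals)
curve = removal_curve(vals, score)
```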

### Value distribution

This section examines the distribution of CWS values for the above datasets.
The histograms show that the distribution is slightly biased and often skewed.

![Distribution of values computed using logistic
regression](img/classwise-shapley-example-densities.svg)

The curves approximate the density via kernel density estimation (KDE).
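For illustration, such a kernel density estimate can be sketched in a few
lines (a pure-Python Gaussian KDE with a fixed bandwidth; in practice one
would use e.g. `scipy.stats.gaussian_kde`):

```python
import math
from typing import Callable, Sequence

def gaussian_kde(
    samples: Sequence[float], bandwidth: float
) -> Callable[[float], float]:
    """Density estimate: average of Gaussian bumps centred at the samples."""
    n = len(samples)
    norm = 1.0 / (n * bandwidth * math.sqrt(2.0 * math.pi))

    def density(x: float) -> float:
        return norm * sum(
            math.exp(-0.5 * ((x - s) / bandwidth) ** 2) for s in samples
        )

    return density

# Toy sample of "values" symmetric around 0: the estimate peaks at the centre.
f = gaussian_kde([-1.0, 0.0, 1.0], bandwidth=0.5)
```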