diff --git a/CHANGELOG.md b/CHANGELOG.md
index b1dc3abae..11295ed3e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,8 @@
## Unreleased
+- New method: Class-wise Shapley values
+ [PR #338](https://github.com/aai-institute/pyDVL/pull/338)
- No longer using docker within tests to start a memcached server
[PR #444](https://github.com/aai-institute/pyDVL/pull/444)
- Faster semi-value computation with per-index check of stopping criteria (optional)
@@ -43,6 +45,9 @@ randomness.
`compute_beta_shapley_semivalues`, `compute_shapley_semivalues` and
`compute_generic_semivalues`.
[PR #428](https://github.com/aai-institute/pyDVL/pull/428)
+- Added class-wise Shapley as proposed by Schoch et al. (2022)
+  [arXiv:2211.06800](https://arxiv.org/abs/2211.06800)
+ [PR #338](https://github.com/aai-institute/pyDVL/pull/338)
### Changed
diff --git a/README.md b/README.md
index ae72b280c..80e5c1f2f 100644
--- a/README.md
+++ b/README.md
@@ -71,6 +71,11 @@ methods from the following papers:
Efficient Data Value](https://proceedings.mlr.press/v202/kwon23e.html). In
Proceedings of the 40th International Conference on Machine Learning, 18135–52.
PMLR, 2023.
+- Schoch, Stephanie, Haifeng Xu, and Yangfeng Ji. [CS-Shapley: Class-Wise
+ Shapley Values for Data Valuation in
+ Classification](https://openreview.net/forum?id=KTOcrOR5mQ9). In Proc. of the
+ Thirty-Sixth Conference on Neural Information Processing Systems (NeurIPS).
+ New Orleans, Louisiana, USA, 2022.
Influence Functions compute the effect that single points have on an estimator /
model. We implement methods from the following papers:
diff --git a/docs/api/pydvl/value/shapley/classwise/img/classwise-shapley-discounted-utility-function.svg b/docs/api/pydvl/value/shapley/classwise/img/classwise-shapley-discounted-utility-function.svg
new file mode 100644
index 000000000..c925f1e4a
--- /dev/null
+++ b/docs/api/pydvl/value/shapley/classwise/img/classwise-shapley-discounted-utility-function.svg
@@ -0,0 +1,68001 @@
+
+
+
diff --git a/docs/value/classwise-shapley.md b/docs/value/classwise-shapley.md
new file mode 100644
index 000000000..a6911812a
--- /dev/null
+++ b/docs/value/classwise-shapley.md
@@ -0,0 +1,269 @@
+---
+title: Class-wise Shapley
+---
+
+# Class-wise Shapley
+
+Class-wise Shapley (CWS) [@schoch_csshapley_2022] offers a Shapley framework
+tailored for classification problems. Given a sample $x_i$ with label $y_i \in
+\mathbb{N}$, let $D_{y_i}$ be the subset of $D$ with labels $y_i$, and
+$D_{-y_i}$ be the complement of $D_{y_i}$ in $D$. The key idea is that the
+sample $(x_i, y_i)$ might improve the overall model performance on $D$, while
+being detrimental for the performance on $D_{y_i},$ e.g. because of a wrong
+label. To address this issue, the authors introduced
+
+$$
+v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}}
+\left [
+\frac{1}{|D_{y_i}|}\sum_{S_{y_i}} \binom{|D_{y_i}|-1}{|S_{y_i}|}^{-1}
+\delta(S_{y_i} | S_{-y_i})
+\right ],
+$$
+
+where $S_{y_i} \subseteq D_{y_i} \setminus \{i\}$ and $S_{-y_i} \subseteq
+D_{-y_i}$ is _arbitrary_ (in particular, not the complement of $S_{y_i}$). The
+function $\delta$ is called **set-conditional marginal Shapley value** and is
+defined as
+
+$$
+\delta(S | C) = u( S_{+i} | C ) - u(S | C),
+$$
+
+for any set $S$ such that $i \notin S \cup C$ and $S \cap C = \emptyset$.
+
+In practice, this quantity is estimated with Monte Carlo sampling over both the
+powerset of $D_{-y_i}$ and the set of index permutations of $D_{y_i}$
+[@castro_polynomial_2009]. Typically, this requires fewer samples than the
+original Shapley value, although the actual speed-up depends on the model and
+the dataset.
+
+
+!!! Example "Computing classwise Shapley values"
+ Like all other game-theoretic valuation methods, CWS requires a
+ [Utility][pydvl.utils.utility.Utility] object constructed with model and
+ dataset, with the peculiarity of requiring a specific
+ [ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. The entry
+ point is the function
+ [compute_classwise_shapley_values][pydvl.value.shapley.classwise.compute_classwise_shapley_values]:
+
+ ```python
+ from pydvl.value import *
+
+ model = ...
+ data = Dataset(...)
+ scorer = ClasswiseScorer(...)
+ utility = Utility(model, data, scorer)
+ values = compute_classwise_shapley_values(
+ utility,
+ done=HistoryDeviation(n_steps=500, rtol=5e-2) | MaxUpdates(5000),
+ truncation=RelativeTruncation(utility, rtol=0.01),
+ done_sample_complements=MaxChecks(1),
+ normalize_values=True
+ )
+ ```
+
+
+### The class-wise scorer
+
+In order to use the classwise Shapley value, one needs to define a
+[ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer]. This scorer
+is defined as
+
+$$
+u(S) = f(a_S(D_{y_i})) g(a_S(D_{-y_i})),
+$$
+
+where $f$ and $g$ are monotonically increasing functions, $a_S(D_{y_i})$ is the
+**in-class accuracy**, and $a_S(D_{-y_i})$ is the **out-of-class accuracy** (the
+names originate from the authors' choice of accuracy, but in principle any
+other score, like $F_1$, can be used).
+
+The authors show that $f(x)=x$ and $g(x)=e^x$ have favorable properties and are
+therefore the defaults, but we leave the option to set different functions $f$
+and $g$ for an exploration with different base scores.
+
+!!! Example "The default class-wise scorer"
+ Constructing the CWS scorer requires choosing a metric and the functions $f$
+ and $g$:
+
+ ```python
+ import numpy as np
+ from pydvl.value.shapley.classwise import ClasswiseScorer
+
+ # These are the defaults
+ identity = lambda x: x
+ scorer = ClasswiseScorer(
+ "accuracy",
+ in_class_discount_fn=identity,
+ out_of_class_discount_fn=np.exp
+ )
+ ```
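+
+To make the discounting concrete, here is a minimal sketch (the accuracy values
+are made up for illustration) of the score that results from given in-class and
+out-of-class accuracies:
+
+```python
+import numpy as np
+
+in_class_acc, out_of_class_acc = 0.8, 0.5  # hypothetical accuracies
+
+# Default discounts: f(x) = x, g(x) = exp(x)
+score = in_class_acc * np.exp(out_of_class_acc)  # ≈ 1.319
+```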
+
+??? "Surface of the discounted utility function"
+    The level curves of the utility for $f(x)=x$ and $g(x)=e^x$ are depicted
+    below, annotated with their respective gradients.
+ ![Level curves of the class-wise
+ utility](img/classwise-shapley-discounted-utility-function.svg){ align=left width=33% class=invertible }
+
+## Evaluation
+
+We illustrate the method with two experiments: point removal and detection of
+mis-labeled points, as well as an analysis of the distribution of the values.
+For this we employ the nine datasets used in [@schoch_csshapley_2022], with the
+same pre-processing. For image datasets, PCA is used to reduce the features
+extracted by a pre-trained `Resnet18` model down to 32 dimensions. Standard
+loc-scale normalization is performed for all models except gradient boosting,
+since the latter is not sensitive to the scale of the features.
+
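+A minimal sketch of this pre-processing, assuming `features` holds the
+`Resnet18` embeddings (the variable names are illustrative):
+
+```python
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+
+# Reduce the embeddings to 32 dimensions, then loc-scale normalize.
+features_32 = PCA(n_components=32).fit_transform(features)
+features_32 = StandardScaler().fit_transform(features_32)
+```
+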
+??? info "Datasets used for evaluation"
+ | Dataset | Data Type | Classes | Input Dims | OpenML ID |
+ |----------------|-----------|---------|------------|-----------|
+ | Diabetes | Tabular | 2 | 8 | 37 |
+ | Click | Tabular | 2 | 11 | 1216 |
+ | CPU | Tabular | 2 | 21 | 197 |
+ | Covertype | Tabular | 7 | 54 | 1596 |
+ | Phoneme | Tabular | 2 | 5 | 1489 |
+ | FMNIST | Image | 2 | 32 | 40996 |
+ | CIFAR10 | Image | 2 | 32 | 40927 |
+ | MNIST (binary) | Image | 2 | 32 | 554 |
+ | MNIST (multi) | Image | 10 | 32 | 554 |
+
+We show mean and coefficient of variation (CV) $\frac{\sigma}{\mu}$ of an "inner
+metric". The former shows the performance of the method, whereas the latter
+displays its stability: we normalize by the mean to see the relative effect of
+the standard deviation. Ideally the mean value is maximal and CV minimal.
+
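+For instance, a minimal sketch of the CV computation over hypothetical metric
+values from five runs:
+
+```python
+import numpy as np
+
+metric_runs = np.array([0.42, 0.39, 0.45, 0.41, 0.40])  # illustrative values
+cv = metric_runs.std() / metric_runs.mean()  # relative dispersion; lower is better
+```
+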
+Finally, we note that for all sampling-based valuation methods the same number
+of _evaluations of the marginal utility_ was used. This is important to make the
+algorithms comparable, but in practice one should consider using a more
+sophisticated stopping criterion.
+
+### Dataset pruning for logistic regression (point removal)
+
+In (best-)point removal, one first computes values for the training set and then
+removes in sequence the points with the highest values. After each removal, the
+remaining points are used to train the model from scratch and performance is
+measured on a test set. This produces a curve of performance vs. number of
+points removed which we show below.
+
+As a scalar summary of this curve, [@schoch_csshapley_2022] define **Weighted
+Accuracy Drop** (WAD) as:
+
+$$
+\text{WAD} = \sum_{j=1}^{n} \frac{1}{j} \sum_{i=1}^{j} \left(
+a_{T_{-\{1 \colon i-1 \}}}(D) - a_{T_{-\{1 \colon i \}}}(D) \right)
+= a_T(D) - \sum_{j=1}^{n} \frac{a_{T_{-\{1 \colon j \}}}(D)}{j} ,
+$$
+
+where $a_T(D)$ is the accuracy of the model (trained on $T$) evaluated on $D$
+and $T_{-\{1 \colon j \}}$ is the set $T$ without elements from $\{1, \dots , j
+\}$.
+
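+A minimal sketch of this metric, reading the inner summand as the accuracy drop
+at step $i$ (the helper name is illustrative, and `accuracies[j]` is assumed to
+hold the test accuracy after removing the $j$ highest-valued points, so that
+`accuracies[0]` is $a_T(D)$):
+
+```python
+import numpy as np
+
+def weighted_accuracy_drop(accuracies: np.ndarray) -> float:
+    n = len(accuracies) - 1
+    drops = accuracies[:-1] - accuracies[1:]  # a_{j-1} - a_j for j = 1..n
+    inner = np.cumsum(drops)                  # telescopes to a_T(D) - a_j
+    return float(np.sum(inner / np.arange(1, n + 1)))
+```
+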
+We run the point removal experiment for a logistic regression model five times
+and compute WAD for each run, then report the mean $\mu_\text{WAD}$ and standard
+deviation $\sigma_\text{WAD}$.
+
+![Mean WAD for best-point removal on logistic regression. Values
+computed using LOO, CWS, Beta Shapley, and TMCS
+](img/classwise-shapley-metric-wad-mean.svg){ class=invertible }
+
+We see that CWS is competitive with all three other methods. In all problems
+except `MNIST (multi)` it outperforms TMCS; there, TMCS has a slight
+advantage.
+
+In order to understand the variability of WAD we look at its coefficient of
+variation (lower is better):
+
+![Coefficient of Variation of WAD for best-point removal on logistic regression.
+Values computed using LOO, CWS, Beta Shapley, and TMCS
+](img/classwise-shapley-metric-wad-cv.svg){ class=invertible }
+
+CWS is not the best method in terms of CV. For `CIFAR10`, `Click`, `CPU` and
+`MNIST (binary)` Beta Shapley has the lowest CV. For `Diabetes`, `MNIST (multi)`
+and `Phoneme` CWS is the winner and for `FMNIST` and `Covertype` TMCS takes the
+lead. Aside from LOO, TMCS has the highest relative standard deviation overall.
+
+The following plot shows accuracy vs number of samples removed. Random values
+serve as a baseline. The shaded area represents the 95% bootstrap confidence
+interval of the mean across 5 runs.
+
+![Accuracy after best-sample removal using values from logistic
+regression](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-logistic-regression.svg){ class=invertible }
+
+Because samples are removed from high to low valuation order, we expect a steep
+decrease in the curve.
+
+Overall we conclude that in terms of mean WAD, CWS and TMCS perform best, with
+CWS's CV on par with Beta Shapley's, making CWS a competitive method.
+
+
+### Dataset pruning for a neural network by value transfer
+
+Transfer of values from one model to another is probably of greater practical
+relevance: values are computed using a cheap model and used to prune the dataset
+before training a more expensive one.
+
+The following plot shows accuracy vs number of samples removed for transfer from
+logistic regression to a neural network. The shaded area represents the 95%
+bootstrap confidence interval of the mean across 5 runs.
+
+![Accuracy after sample removal using values transferred from logistic
+regression to an MLP
+](img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg){ class=invertible }
+
+As in the previous experiment, samples are removed from high to low valuation
+order, and hence we expect a steep decrease in the curve. CWS is competitive with
+the other methods, especially in very unbalanced datasets like `Click`. In other
+datasets, like `Covertype`, `Diabetes` and `MNIST (multi)` the performance is on
+par with TMCS.
+
+
+### Detection of mis-labeled data points
+
+The next experiment tries to detect mis-labeled data points in binary
+classification tasks. The labels of 20% of the points are flipped at random (we
+don't consider multi-class datasets because there isn't a unique flipping
+strategy). The following plot shows the mean of the area under the curve (AUC)
+for five runs.
+
+![Mean AUC for mis-labeled data point detection. Values computed using LOO, CWS,
+Beta Shapley, and
+TMCS](img/classwise-shapley-metric-auc-mean.svg){ class=invertible }
+
+In the majority of cases TMCS has a slight advantage over CWS, except for
+`Click`, where CWS has a slight edge, most probably due to the unbalanced nature
+of the dataset. The following plot shows the CV for the AUC of the five runs.
+
+![Coefficient of variation of AUC for mis-labeled data point detection. Values
+computed using LOO, CWS, Beta Shapley, and TMCS
+](img/classwise-shapley-metric-auc-cv.svg){ class=invertible }
+
+In terms of CV, CWS has a clear edge over TMCS and Beta Shapley.
+
+Finally, we look at the ROC curves obtained when training the classifier on the
+first $n$ samples in _increasing_ order of valuation (i.e. starting with the
+worst):
+
+![Mean ROC across 5 runs with 95% bootstrap
+CI](img/classwise-shapley-roc-auc-logistic-regression.svg){ class=invertible }
+
+Although at first sight TMCS seems to be the winner, CWS stays competitive after
+factoring in running time: for a perfectly balanced dataset, CWS needs on
+average fewer utility evaluations than TMCS.
+
+### Value distribution
+
+For illustration, we compare the distribution of values computed by TMCS and
+CWS.
+
+![Histogram and estimated density of the values computed by TMCS and
+CWS on all nine datasets](img/classwise-shapley-density.svg){ class=invertible }
+
+For `Click` TMCS produces a multi-modal distribution of values. We hypothesize
+that this is due to the highly unbalanced nature of the dataset, and note that
+CWS instead has a single mode, which may explain its better performance on this
+dataset.
+
+## Conclusion
+
+CWS is an effective way to handle classification problems, in particular for
+unbalanced datasets. It reduces the computing requirements by considering
+in-class and out-of-class points separately.
+
diff --git a/docs/value/img/classwise-shapley-density.svg b/docs/value/img/classwise-shapley-density.svg
new file mode 100644
index 000000000..44d954546
--- /dev/null
+++ b/docs/value/img/classwise-shapley-density.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/value/img/classwise-shapley-discounted-utility-function.svg b/docs/value/img/classwise-shapley-discounted-utility-function.svg
new file mode 100644
index 000000000..70ed7ab58
--- /dev/null
+++ b/docs/value/img/classwise-shapley-discounted-utility-function.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/value/img/classwise-shapley-metric-auc-cv.svg b/docs/value/img/classwise-shapley-metric-auc-cv.svg
new file mode 100644
index 000000000..3ddc5f5a4
--- /dev/null
+++ b/docs/value/img/classwise-shapley-metric-auc-cv.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/value/img/classwise-shapley-metric-auc-mean.svg b/docs/value/img/classwise-shapley-metric-auc-mean.svg
new file mode 100644
index 000000000..197ada82b
--- /dev/null
+++ b/docs/value/img/classwise-shapley-metric-auc-mean.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/value/img/classwise-shapley-metric-wad-cv.svg b/docs/value/img/classwise-shapley-metric-wad-cv.svg
new file mode 100644
index 000000000..696226e83
--- /dev/null
+++ b/docs/value/img/classwise-shapley-metric-wad-cv.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/value/img/classwise-shapley-metric-wad-mean.svg b/docs/value/img/classwise-shapley-metric-wad-mean.svg
new file mode 100644
index 000000000..7f74a384a
--- /dev/null
+++ b/docs/value/img/classwise-shapley-metric-wad-mean.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/value/img/classwise-shapley-roc-auc-logistic-regression.svg b/docs/value/img/classwise-shapley-roc-auc-logistic-regression.svg
new file mode 100644
index 000000000..0ec200f83
--- /dev/null
+++ b/docs/value/img/classwise-shapley-roc-auc-logistic-regression.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/value/img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-logistic-regression.svg b/docs/value/img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-logistic-regression.svg
new file mode 100644
index 000000000..1071d5f0b
--- /dev/null
+++ b/docs/value/img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-logistic-regression.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/value/img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg b/docs/value/img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg
new file mode 100644
index 000000000..85a3244d8
--- /dev/null
+++ b/docs/value/img/classwise-shapley-weighted-accuracy-drop-logistic-regression-to-mlp.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/value/notation.md b/docs/value/notation.md
index f14ce6466..83054d5e6 100644
--- a/docs/value/notation.md
+++ b/docs/value/notation.md
@@ -4,17 +4,24 @@ title: Notation for valuation
# Notation for valuation
+!!! todo
+ Organize this page better and use its content consistently throughout the
+ documentation.
+
The following notation is used throughout the documentation:
Let $D = \{x_1, \ldots, x_n\}$ be a training set of $n$ samples.
The utility function $u:\mathcal{D} \rightarrow \mathbb{R}$ maps subsets of $D$
-to real numbers.
+to real numbers. In pyDVL, we typically call this mapping a **score** for
+consistency with sklearn, and reserve the term **utility** for the triple of
+dataset $D$, model $f$ and score $u$, since they are used together to compute
+the value.
The value $v$ of the $i$-th sample in dataset $D$ wrt. utility $u$ is
denoted as $v_u(x_i)$ or simply $v(i)$.
-For any $S \subseteq D$, we donote by $S_{-i}$ the set of samples in $D$
+For any $S \subseteq D$, we denote by $S_{-i}$ the set of samples in $D$
excluding $x_i$, and $S_{+i}$ denotes the set $S$ with $x_i$ added.
The marginal utility of adding sample $x_i$ to a subset $S$ is denoted as
diff --git a/docs/value/shapley.md b/docs/value/shapley.md
index 77af2ae2b..6de30f0ab 100644
--- a/docs/value/shapley.md
+++ b/docs/value/shapley.md
@@ -175,8 +175,9 @@ values = compute_shapley_values(u=utility, mode="knn")
### Group testing
-An alternative approach introduced in [@jia_efficient_2019a] first approximates
-the differences of values with a Monte Carlo sum. With
+An alternative method for the approximation of Shapley values introduced in
+[@jia_efficient_2019a] first estimates the differences of values with a Monte
+Carlo sum. With
$$\hat{\Delta}_{i j} \approx v_i - v_j,$$
diff --git a/docs_includes/abbreviations.md b/docs_includes/abbreviations.md
index e0fa67a4c..a89425885 100644
--- a/docs_includes/abbreviations.md
+++ b/docs_includes/abbreviations.md
@@ -1,15 +1,21 @@
*[CSP]: Constraint Satisfaction Problem
+*[CV]: Coefficient of Variation
+*[CWS]: Class-wise Shapley
+*[DUL]: Data Utility Learning
*[GT]: Group Testing
+*[IF]: Influence Function
+*[iHVP]: inverse Hessian-vector product
*[LC]: Least Core
+*[LiSSA]: Linear-time Stochastic Second-order Algorithm
*[LOO]: Leave-One-Out
*[MCLC]: Monte Carlo Least Core
*[MCS]: Monte Carlo Shapley
*[ML]: Machine Learning
+*[MLP]: Multi-Layer Perceptron
*[MLRC]: Machine Learning Reproducibility Challenge
*[MSE]: Mean Squared Error
+*[PCA]: Principal Component Analysis
+*[ROC]: Receiver Operating Characteristic
*[SV]: Shapley Value
*[TMCS]: Truncated Monte Carlo Shapley
-*[IF]: Influence Function
-*[iHVP]: inverse Hessian-vector product
-*[LiSSA]: Linear-time Stochastic Second-order Algorithm
-*[DUL]: Data Utility Learning
+*[WAD]: Weighted Accuracy Drop
diff --git a/mkdocs.yml b/mkdocs.yml
index dde7b7e55..c4a80316a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -193,9 +193,10 @@ nav:
- Data Valuation:
- Introduction: value/index.md
- Notation: value/notation.md
- - Shapley Values: value/shapley.md
+ - Shapley values: value/shapley.md
- Semi-values: value/semi-values.md
- The core: value/the-core.md
+ - Classwise Shapley: value/classwise-shapley.md
- Examples:
- Shapley values: examples/shapley_basic_spotify.ipynb
- KNN Shapley: examples/shapley_knn_flowers.ipynb
diff --git a/src/pydvl/utils/numeric.py b/src/pydvl/utils/numeric.py
index d223673ed..679573a82 100644
--- a/src/pydvl/utils/numeric.py
+++ b/src/pydvl/utils/numeric.py
@@ -5,7 +5,16 @@
from __future__ import annotations
from itertools import chain, combinations
-from typing import Collection, Generator, Iterator, Optional, Tuple, TypeVar, overload
+from typing import (
+ Collection,
+ Generator,
+ Iterator,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ overload,
+)
import numpy as np
from numpy.typing import NDArray
@@ -19,6 +28,7 @@
"random_matrix_with_condition_number",
"random_subset",
"random_powerset",
+ "random_powerset_label_min",
"random_subset_of_size",
"top_k_value_accuracy",
]
@@ -133,6 +143,66 @@ def random_powerset(
total += 1
+def random_powerset_label_min(
+ s: NDArray[T],
+ labels: NDArray[np.int_],
+ min_elements_per_label: int = 1,
+ seed: Optional[Seed] = None,
+) -> Generator[NDArray[T], None, None]:
+ """Draws random subsets from `s`, while ensuring that at least
+ `min_elements_per_label` elements per label are included in the draw. It can be used
+ for classification problems to ensure that a set contains information for all labels
+ (or not if `min_elements_per_label=0`).
+
+ Args:
+ s: Set to sample from
+ labels: Labels for the samples
+ min_elements_per_label: Minimum number of elements for each label.
+ seed: Either an instance of a numpy random number generator or a seed for it.
+
+ Returns:
+        Generated draws from the powerset of `s`, each with at least
+        `min_elements_per_label` elements for each label.
+
+ Raises:
+ ValueError: If `s` and `labels` are of different length or
+ `min_elements_per_label` is smaller than 0.
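+
+    Example:
+        A minimal usage sketch (the data is illustrative):
+
+        ```python
+        s = np.arange(6)
+        labels = np.array([0, 0, 0, 1, 1, 1])
+        subsets = random_powerset_label_min(s, labels, min_elements_per_label=1)
+        subset = next(subsets)  # contains at least one element of each label
+        ```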
+ """
+ if len(labels) != len(s):
+ raise ValueError("Set and labels have to be of same size.")
+
+ if min_elements_per_label < 0:
+ raise ValueError(
+ f"Parameter min_elements={min_elements_per_label} needs to be bigger or "
+ f"equal to 0."
+ )
+
+ rng = np.random.default_rng(seed)
+ unique_labels = np.unique(labels)
+
+ while True:
+ subsets: List[NDArray[T]] = []
+ for label in unique_labels:
+ label_indices = np.asarray(np.where(labels == label)[0])
+ subset_size = int(
+ rng.integers(
+ min(min_elements_per_label, len(label_indices)),
+ len(label_indices) + 1,
+ )
+ )
+ if subset_size > 0:
+ subsets.append(
+ random_subset_of_size(s[label_indices], subset_size, seed=rng)
+ )
+
+ if len(subsets) > 0:
+ subset = np.concatenate(tuple(subsets))
+ rng.shuffle(subset)
+ yield subset
+ else:
+ yield np.array([], dtype=s.dtype)
+
+
def random_subset_of_size(
s: NDArray[T], size: int, seed: Optional[Seed] = None
) -> NDArray[T]:
diff --git a/src/pydvl/utils/score.py b/src/pydvl/utils/score.py
index a5d1aceef..f077c9a53 100644
--- a/src/pydvl/utils/score.py
+++ b/src/pydvl/utils/score.py
@@ -26,7 +26,13 @@
from pydvl.utils.types import SupervisedModel
-__all__ = ["Scorer", "compose_score", "squashed_r2", "squashed_variance"]
+__all__ = [
+ "Scorer",
+ "ScorerCallable",
+ "compose_score",
+ "squashed_r2",
+ "squashed_variance",
+]
class ScorerCallable(Protocol):
diff --git a/src/pydvl/value/result.py b/src/pydvl/value/result.py
index 3e4bc1b73..20def1390 100644
--- a/src/pydvl/value/result.py
+++ b/src/pydvl/value/result.py
@@ -538,10 +538,14 @@ def __add__(
xm[other_pos] = other._values
vm[other_pos] = other._variances
+ # np.maximum(1, n + m) covers case n = m = 0.
+ n_m_sum = np.maximum(1, n + m)
+
# Sample mean of n+m samples from two means of n and m samples
- xnm = (n * xn + m * xm) / (n + m)
+ xnm = (n * xn + m * xm) / n_m_sum
+
# Sample variance of n+m samples from two sample variances of n and m samples
- vnm = (n * (vn + xn**2) + m * (vm + xm**2)) / (n + m) - xnm**2
+ vnm = (n * (vn + xn**2) + m * (vm + xm**2)) / n_m_sum - xnm**2
if np.any(vnm < 0):
if np.any(vnm < -1e-6):
@@ -627,6 +631,17 @@ def update(self, idx: int, new_value: float) -> ValuationResult[IndexT, NameT]:
)
return self
+ def scale(self, factor: float, indices: Optional[NDArray[IndexT]] = None):
+ """
+ Scales the values and variances of the result by a coefficient.
+
+ Args:
+ factor: Factor to scale by.
+ indices: Indices to scale. If None, all values are scaled.
+ """
+ self._values[self._sort_positions[indices]] *= factor
+ self._variances[self._sort_positions[indices]] *= factor**2
+
def get(self, idx: Integral) -> ValueItem:
"""Retrieves a ValueItem by data index, as opposed to sort index, like
the indexing operator.
diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py
index 95011b1b9..cabe3c3ba 100644
--- a/src/pydvl/value/semivalues.py
+++ b/src/pydvl/value/semivalues.py
@@ -171,14 +171,6 @@ def _marginal(
# deprecated_in="0.8.0",
# remove_in="0.9.0",
# )
-@deprecated(
- target=True,
- deprecated_in="0.7.0",
- remove_in="0.9.0",
- args_mapping={"batch_size": None},
- template_mgs="batch_size is for experimental use and will be removed"
- "in future versions.",
-)
def compute_generic_semivalues(
sampler: PowersetSampler[IndexT],
u: Utility,
@@ -334,14 +326,6 @@ def beta_coefficient_w(n: int, k: int) -> float:
return cast(SVCoefficient, beta_coefficient_w)
-@deprecated(
- target=True,
- deprecated_in="0.7.0",
- remove_in="0.9.0",
- args_mapping={"batch_size": None},
- template_mgs="batch_size is for experimental use and will be removed"
- "in future versions.",
-)
def compute_shapley_semivalues(
u: Utility,
*,
@@ -394,14 +378,6 @@ def compute_shapley_semivalues(
)
-@deprecated(
- target=True,
- deprecated_in="0.7.0",
- remove_in="0.9.0",
- args_mapping={"batch_size": None},
- template_mgs="batch_size is for experimental use and will be removed"
- "in future versions.",
-)
def compute_banzhaf_semivalues(
u: Utility,
*,
@@ -452,14 +428,6 @@ def compute_banzhaf_semivalues(
)
-@deprecated(
- target=True,
- deprecated_in="0.7.0",
- remove_in="0.9.0",
- args_mapping={"batch_size": None},
- template_mgs="batch_size is for experimental use and will be removed"
- "in future versions.",
-)
def compute_beta_shapley_semivalues(
u: Utility,
*,
diff --git a/src/pydvl/value/shapley/__init__.py b/src/pydvl/value/shapley/__init__.py
index d4730237e..ec1ec44b5 100644
--- a/src/pydvl/value/shapley/__init__.py
+++ b/src/pydvl/value/shapley/__init__.py
@@ -11,6 +11,7 @@
from ..result import *
from ..stopping import *
+from .classwise import *
from .common import *
from .gt import *
from .knn import *
diff --git a/src/pydvl/value/shapley/classwise.py b/src/pydvl/value/shapley/classwise.py
new file mode 100644
index 000000000..438d953c8
--- /dev/null
+++ b/src/pydvl/value/shapley/classwise.py
@@ -0,0 +1,599 @@
+r"""
+Class-wise Shapley (Schoch et al., 2022)[^1] offers a Shapley framework tailored
+for classification problems. Let $D$ be a dataset, $D_{y_i}$ be the subset of
+$D$ with labels $y_i$, and $D_{-y_i}$ be the complement of $D_{y_i}$ in $D$. The
+key idea is that a sample $(x_i, y_i)$, might enhance the overall performance on
+$D$, while being detrimental for the performance on $D_{y_i}$. The Class-wise
+value is defined as:
+
+$$
+v_u(i) = \frac{1}{2^{|D_{-y_i}|}} \sum_{S_{-y_i}} \frac{1}{|D_{y_i}|}
+\sum_{S_{y_i}} \binom{|D_{y_i}|-1}{|S_{y_i}|}^{-1}
+[u( S_{y_i} \cup \{i\} | S_{-y_i} ) - u( S_{y_i} | S_{-y_i})],
+$$
+
+where $S_{y_i} \subseteq D_{y_i} \setminus \{i\}$ and $S_{-y_i} \subseteq
+D_{-y_i}$.
+
+!!! tip "Analysis of Class-wise Shapley"
+ For a detailed analysis of the method, with comparison to other valuation
+ techniques, please refer to the [main
+ documentation](../../../../../value/classwise-shapley).
+
+In practice, the quantity above is estimated using Monte Carlo sampling of
+the powerset and the set of index permutations. This results in the estimator
+
+$$
+v_u(i) = \frac{1}{K} \sum_k \frac{1}{L} \sum_l
+[u(\sigma^{(l)}_{:i} \cup \{i\} | S^{(k)} ) - u( \sigma^{(l)}_{:i} | S^{(k)})],
+$$
+
+with $S^{(1)}, \dots, S^{(K)} \subseteq T_{-y_i},$ $\sigma^{(1)}, \dots,
+\sigma^{(L)} \in \Pi(T_{y_i}\setminus\{i\}),$ and $\sigma^{(l)}_{:i}$ denoting
+the set of indices in permutation $\sigma^{(l)}$ before the position where $i$
+appears. The sets $T_{y_i}$ and $T_{-y_i}$ are the training sets for the labels
+$y_i$ and $-y_i$, respectively.
+
+??? info "Notes for derivation of test cases"
+ The unit tests include the following manually constructed data:
+ Let $D=\{(1,0),(2,0),(3,0),(4,1)\}$ be the test set and $T=\{(1,0),(2,0),(3,1),(4,1)\}$
+ the train set. This specific dataset is chosen as it allows to solve the model
+
+ $$y = \max(0, \min(1, \text{round}(\beta^T x)))$$
+
+ in closed form $\beta = \frac{\text{dot}(x, y)}{\text{dot}(x, x)}$. From the closed-form
+ solution, the tables for in-class accuracy $a_S(D_{y_i})$ and out-of-class accuracy
+ $a_S(D_{-y_i})$ can be calculated. By using these tables and setting
+ $\{S^{(1)}, \dots, S^{(K)}\} = 2^{T_{-y_i}}$ and
+ $\{\sigma^{(1)}, \dots, \sigma^{(L)}\} = \Pi(T_{y_i}\setminus\{i\})$,
+ the Monte Carlo estimator can be evaluated ($2^M$ is the powerset of $M$).
+ The details of the derivation are left to the eager reader.
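+
+    A minimal sketch of this closed-form fit for the train set above (the
+    variable names are illustrative):
+
+    ```python
+    import numpy as np
+
+    x = np.array([1.0, 2.0, 3.0, 4.0])
+    y = np.array([0, 0, 1, 1])
+    beta = np.dot(x, y) / np.dot(x, x)          # = 7/30, closed-form solution
+    y_pred = np.clip(np.round(beta * x), 0, 1)  # recovers y exactly
+    ```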
+
+# References
+
+[^1]: Schoch, Stephanie, Haifeng Xu, and
+ Yangfeng Ji. [CS-Shapley: Class-wise Shapley Values for Data Valuation in
+ Classification](https://openreview.net/forum?id=KTOcrOR5mQ9). In Proc. of
+ the Thirty-Sixth Conference on Neural Information Processing Systems
+ (NeurIPS). New Orleans, Louisiana, USA, 2022.
+
+"""
+import logging
+import numbers
+from concurrent.futures import FIRST_COMPLETED, Future, wait
+from copy import copy
+from typing import Callable, Optional, Set, Tuple, Union, cast
+
+import numpy as np
+from numpy.random import SeedSequence
+from numpy.typing import NDArray
+from tqdm import tqdm
+
+from pydvl.parallel import (
+ ParallelConfig,
+ effective_n_jobs,
+ init_executor,
+ init_parallel_backend,
+)
+from pydvl.utils import (
+ Dataset,
+ Scorer,
+ ScorerCallable,
+ Seed,
+ SupervisedModel,
+ Utility,
+ ensure_seed_sequence,
+ random_powerset_label_min,
+)
+from pydvl.value.result import ValuationResult
+from pydvl.value.shapley.truncated import TruncationPolicy
+from pydvl.value.stopping import MaxChecks, StoppingCriterion
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["ClasswiseScorer", "compute_classwise_shapley_values"]
+
+
+class ClasswiseScorer(Scorer):
+ r"""A Scorer designed for evaluation in classification problems. Its value
+ is computed from an in-class and an out-of-class "inner score" (Schoch et
+    al., 2022). Let $S$ be the
+ training set and $D$ be the valuation set. For each label $c$, $D$ is
+ factorized into two disjoint sets: $D_c$ for in-class instances and $D_{-c}$
+ for out-of-class instances. The score combines an in-class metric of
+ performance, adjusted by a discounted out-of-class metric. These inner
+ scores must be provided upon construction or default to accuracy. They are
+ combined into:
+
+ $$
+ u(S_{y_i}) = f(a_S(D_{y_i}))\ g(a_S(D_{-y_i})),
+ $$
+
+ where $f$ and $g$ are continuous, monotonic functions. For a detailed
+    explanation, refer to section four of (Schoch et al., 2022).
+
+    !!! warning "Multi-class support"
+ Metrics must support multiple class labels if you intend to apply them
+ to a multi-class problem. For instance, the metric 'accuracy' supports
+ multiple classes, but the metric `f1` does not. For a two-class
+ classification problem, using `f1_weighted` is essentially equivalent to
+ using `accuracy`.
+
+ Args:
+ scoring: Name of the scoring function or a callable that can be passed
+ to [Scorer][pydvl.utils.score.Scorer].
+ default: Score to use when a model fails to provide a number, e.g. when
+            too little data was used to train it, or errors arise.
+ range: Numerical range of the score function. Some Monte Carlo methods
+ can use this to estimate the number of samples required for a
+ certain quality of approximation. If not provided, it can be read
+ from the `scoring` object if it provides it, for instance if it was
+ constructed with
+ [compose_score][pydvl.utils.score.compose_score].
+ in_class_discount_fn: Continuous, monotonic increasing function used to
+ discount the in-class score.
+ out_of_class_discount_fn: Continuous, monotonic increasing function used
+ to discount the out-of-class score.
+ initial_label: Set initial label (for the first iteration)
+ name: Name of the scorer. If not provided, the name of the inner scoring
+ function will be prefixed by `classwise `.
+
+ !!! tip "New in version 0.7.1"
+ """
+
+ def __init__(
+ self,
+ scoring: Union[str, ScorerCallable] = "accuracy",
+ default: float = 0.0,
+ range: Tuple[float, float] = (0, 1),
+ in_class_discount_fn: Callable[[float], float] = lambda x: x,
+ out_of_class_discount_fn: Callable[[float], float] = np.exp,
+ initial_label: Optional[int] = None,
+ name: Optional[str] = None,
+ ):
+ disc_score_in_class = in_class_discount_fn(range[1])
+ disc_score_out_of_class = out_of_class_discount_fn(range[1])
+ transformed_range = (0, disc_score_in_class * disc_score_out_of_class)
+ super().__init__(
+ scoring=scoring,
+ range=transformed_range,
+ default=default,
+ name=name or f"classwise {str(scoring)}",
+ )
+ self._in_class_discount_fn = in_class_discount_fn
+ self._out_of_class_discount_fn = out_of_class_discount_fn
+ self.label = initial_label
+
+ def __str__(self):
+ return self._name
+
+ def __call__(
+ self: "ClasswiseScorer",
+ model: SupervisedModel,
+ x_test: NDArray[np.float_],
+ y_test: NDArray[np.int_],
+ ) -> float:
+ (
+ in_class_score,
+ out_of_class_score,
+ ) = self.estimate_in_class_and_out_of_class_score(model, x_test, y_test)
+ disc_score_in_class = self._in_class_discount_fn(in_class_score)
+ disc_score_out_of_class = self._out_of_class_discount_fn(out_of_class_score)
+ return disc_score_in_class * disc_score_out_of_class
+
+ def estimate_in_class_and_out_of_class_score(
+ self,
+ model: SupervisedModel,
+ x_test: NDArray[np.float_],
+ y_test: NDArray[np.int_],
+ rescale_scores: bool = True,
+ ) -> Tuple[float, float]:
+ r"""
+ Computes in-class and out-of-class scores using the provided inner
+ scoring function. The result is
+
+ $$
+        a_S(D=\{(x_1, y_1), \dots, (x_K, y_K)\}) = \frac{1}{K} \sum_{k=1}^{K} s(y(x_k), y_k).
+ $$
+
+ In this context, for label $c$ calculations are executed twice: once for $D_c$
+ and once for $D_{-c}$ to determine the in-class and out-of-class scores,
+ respectively. By default, the raw scores are multiplied by $\frac{|D_c|}{|D|}$
+ and $\frac{|D_{-c}|}{|D|}$, respectively. This is done to ensure that both
+ scores are of the same order of magnitude. This normalization is particularly
+ useful when the inner score function $a_S$ is calculated by an estimator of the
+ form $\frac{1}{N} \sum_i x_i$, e.g. the accuracy.
+
+ Args:
+ model: Model used for computing the score on the validation set.
+ x_test: Array containing the features of the classification problem.
+ y_test: Array containing the labels of the classification problem.
+ rescale_scores: If set to True, the scores will be denormalized. This is
+ particularly useful when the inner score function $a_S$ is calculated by
+ an estimator of the form $\frac{1}{N} \sum_i x_i$.
+
+ Returns:
+ Tuple containing the in-class and out-of-class scores.
+ """
+ scorer = self._scorer
+ label_set_match = y_test == self.label
+ label_set = np.where(label_set_match)[0]
+ num_classes = len(np.unique(y_test))
+
+ if len(label_set) == 0:
+ return 0, 1 / (num_classes - 1)
+
+ complement_label_set = np.where(~label_set_match)[0]
+ in_class_score = scorer(model, x_test[label_set], y_test[label_set])
+ out_of_class_score = scorer(
+ model, x_test[complement_label_set], y_test[complement_label_set]
+ )
+
+ if rescale_scores:
+ n_in_class = np.count_nonzero(y_test == self.label)
+ n_out_of_class = len(y_test) - n_in_class
+ in_class_score *= n_in_class / (n_in_class + n_out_of_class)
+ out_of_class_score *= n_out_of_class / (n_in_class + n_out_of_class)
+
+ return in_class_score, out_of_class_score
+
+
+def compute_classwise_shapley_values(
+ u: Utility,
+ *,
+ done: StoppingCriterion,
+ truncation: TruncationPolicy,
+ done_sample_complements: Optional[StoppingCriterion] = None,
+ normalize_values: bool = True,
+ use_default_scorer_value: bool = True,
+ min_elements_per_label: int = 1,
+ n_jobs: int = 1,
+ config: ParallelConfig = ParallelConfig(),
+ progress: bool = False,
+ seed: Optional[Seed] = None,
+) -> ValuationResult:
+ r"""
+ Computes an approximate Class-wise Shapley value by sampling independent
+ permutations of the index set for each label and index sets sampled from the
+ powerset of the complement (with respect to the currently evaluated label),
+ approximating the sum:
+
+ $$
+ v_u(i) = \frac{1}{K} \sum_k \frac{1}{L} \sum_l
+    [u(\sigma^{(l)}_{:i} \cup \{i\} | S^{(k)} ) - u( \sigma^{(l)}_{:i} | S^{(k)})],
+ $$
+
+    where $\sigma_{:i}$ denotes the set of indices in permutation $\sigma$ before
+    the position where $i$ appears and $S$ is a subset of the index set of all
+    other labels (see [[data-valuation]] for details).
+
+ Args:
+ u: Utility object containing model, data, and scoring function. The
+ scorer must be of type
+ [ClasswiseScorer][pydvl.value.shapley.classwise.ClasswiseScorer].
+ done: Function that checks whether the computation needs to stop.
+ truncation: Callable function that decides whether to interrupt processing a
+ permutation and set subsequent marginals to zero.
+        done_sample_complements: Stopping criterion for the sampling of
+            complement sets for each permutation, e.g. a fixed number of samples
+            via [MaxChecks][pydvl.value.stopping.MaxChecks]. Defaults to a
+            single sample per permutation.
+        normalize_values: Indicates whether to normalize the values by the variation
+            in each class times their in-class accuracy.
+ use_default_scorer_value: The first set of indices is the sampled complement
+            set. Unless otherwise specified, the default scorer value is used for
+            this. If set to False, the base score is calculated from the utility.
+ min_elements_per_label: The minimum number of elements for each opposite
+ label.
+ n_jobs: Number of parallel jobs to run.
+ config: Parallel configuration.
+ progress: Whether to display a progress bar.
+ seed: Either an instance of a numpy random number generator or a seed for it.
+
+ Returns:
+ ValuationResult object containing computed data values.
+
+ !!! tip "New in version 0.7.1"
+ """
+ dim_correct = u.data.y_train.ndim == 1 and u.data.y_test.ndim == 1
+ is_integral = all(
+ map(
+ lambda v: isinstance(v, numbers.Integral), (*u.data.y_train, *u.data.y_test)
+ )
+ )
+ if not dim_correct or not is_integral:
+ raise ValueError(
+ "The supplied dataset has to be a 1-dimensional classification dataset."
+ )
+
+ if not isinstance(u.scorer, ClasswiseScorer):
+ raise ValueError(
+ "Please set a subclass of ClasswiseScorer object as scorer object of the"
+ " utility. See scoring argument of Utility."
+ )
+
+ parallel_backend = init_parallel_backend(config)
+ u_ref = parallel_backend.put(u)
+ n_jobs = effective_n_jobs(n_jobs, config)
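+    # Submit twice as many jobs as there are workers, so that the queue stays
+    # full and workers don't idle while completed futures are processed.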
+ n_submitted_jobs = 2 * n_jobs
+
+ pbar = tqdm(disable=not progress, position=0, total=100, unit="%")
+ algorithm = "classwise_shapley"
+ accumulated_result = ValuationResult.zeros(
+ algorithm=algorithm, indices=u.data.indices, data_names=u.data.data_names
+ )
+ terminate_exec = False
+ seed_sequence = ensure_seed_sequence(seed)
+
+ with init_executor(max_workers=n_jobs, config=config) as executor:
+ pending: Set[Future] = set()
+ while True:
+ completed_futures, pending = wait(
+ pending, timeout=60, return_when=FIRST_COMPLETED
+ )
+ for future in completed_futures:
+ accumulated_result += future.result()
+ if done(accumulated_result):
+ terminate_exec = True
+ break
+
+ pbar.n = 100 * done.completion()
+ pbar.refresh()
+ if terminate_exec:
+ break
+
+ n_remaining_slots = n_submitted_jobs - len(pending)
+ seeds = seed_sequence.spawn(n_remaining_slots)
+ for i in range(n_remaining_slots):
+ future = executor.submit(
+ _permutation_montecarlo_classwise_shapley_one_step,
+ u_ref,
+ truncation=truncation,
+ done_sample_complements=done_sample_complements,
+ use_default_scorer_value=use_default_scorer_value,
+ min_elements_per_label=min_elements_per_label,
+ algorithm_name=algorithm,
+ seed=seeds[i],
+ )
+ pending.add(future)
+
+ result = accumulated_result
+ if normalize_values:
+ result = _normalize_classwise_shapley_values(result, u)
+
+ return result
+
+
+def _permutation_montecarlo_classwise_shapley_one_step(
+ u: Utility,
+ *,
+    done_sample_complements: Optional[StoppingCriterion] = None,
+ truncation: TruncationPolicy,
+ use_default_scorer_value: bool = True,
+ min_elements_per_label: int = 1,
+ algorithm_name: str = "classwise_shapley",
+ seed: Optional[SeedSequence] = None,
+) -> ValuationResult:
+ """Helper function for [compute_classwise_shapley_values()]
+ [pydvl.value.shapley.classwise.compute_classwise_shapley_values].
+
+
+ Args:
+ u: Utility object containing model, data, and scoring function. The
+ scorer must be of type [ClasswiseScorer]
+ [pydvl.value.shapley.classwise.ClasswiseScorer].
+        done_sample_complements: Stopping criterion for the sampling of
+            complement sets for each permutation. Defaults to a single sample.
+ truncation: Callable function that decides whether to interrupt processing a
+ permutation and set subsequent marginals to zero.
+ use_default_scorer_value: The first set of indices is the sampled complement
+            set. Unless otherwise specified, the default scorer value is used for
+            this. If set to False, the base score is calculated from the utility.
+ min_elements_per_label: The minimum number of elements for each opposite
+ label.
+ algorithm_name: For the results object.
+ seed: Either an instance of a numpy random number generator or a seed for it.
+
+ Returns:
+ ValuationResult object containing computed data values.
+ """
+ if done_sample_complements is None:
+ done_sample_complements = MaxChecks(1)
+
+ result = ValuationResult.zeros(
+ algorithm=algorithm_name, indices=u.data.indices, data_names=u.data.data_names
+ )
+ rng = np.random.default_rng(seed)
+ x_train, y_train = u.data.get_training_data(u.data.indices)
+ unique_labels = np.unique(y_train)
+ scorer = cast(ClasswiseScorer, copy(u.scorer))
+ u.scorer = scorer
+
+ for label in unique_labels:
+ u.scorer.label = label
+ class_indices_set, class_complement_indices_set = _split_indices_by_label(
+ u.data.indices, y_train, label
+ )
+ _, complement_y_train = u.data.get_training_data(class_complement_indices_set)
+ indices_permutation = rng.permutation(class_indices_set)
+ done_sample_complements.reset()
+
+ for subset_idx, subset_complement in enumerate(
+ random_powerset_label_min(
+ class_complement_indices_set,
+ complement_y_train,
+ min_elements_per_label=min_elements_per_label,
+ seed=rng,
+ )
+ ):
+ result += _permutation_montecarlo_shapley_rollout(
+ u,
+ indices_permutation,
+ additional_indices=subset_complement,
+ truncation=truncation,
+ algorithm_name=algorithm_name,
+ use_default_scorer_value=use_default_scorer_value,
+ )
+ if done_sample_complements(result):
+ break
+
+ return result
+
+
+def _normalize_classwise_shapley_values(
+ result: ValuationResult, u: Utility
+) -> ValuationResult:
+ r"""
+ Normalize a valuation result specific to classwise Shapley.
+
+ Each value $v_i$ associated with the sample $(x_i, y_i)$ is normalized by
+ multiplying it with $a_S(D_{y_i})$ and dividing by $\sum_{j \in D_{y_i}} v_j$. For
+    more details, see (Schoch et al., 2022).
+
+ Args:
+ result: ValuationResult object to be normalized.
+ u: Utility object containing model, data, and scoring function. The
+ scorer must be of type [ClasswiseScorer]
+ [pydvl.value.shapley.classwise.ClasswiseScorer].
+
+ Returns:
+ Normalized ValuationResult object.
+ """
+ y_train = u.data.y_train
+ unique_labels = np.unique(np.concatenate((y_train, u.data.y_test)))
+ scorer = cast(ClasswiseScorer, u.scorer)
+
+    u.model.fit(u.data.x_train, u.data.y_train)
+
+    for label in unique_labels:
+        scorer.label = label
+        active_elements = y_train == label
+        indices_label_set = np.where(active_elements)[0]
+        indices_label_set = u.data.indices[indices_label_set]
+
+ in_class_acc, _ = scorer.estimate_in_class_and_out_of_class_score(
+ u.model, u.data.x_test, u.data.y_test
+ )
+
+ sigma = np.sum(result.values[indices_label_set])
+ if sigma != 0:
+ result.scale(in_class_acc / sigma, indices=indices_label_set)
+
+ return result
+
+
+def _permutation_montecarlo_shapley_rollout(
+ u: Utility,
+ permutation: NDArray[np.int_],
+ truncation: TruncationPolicy,
+ algorithm_name: str,
+ additional_indices: Optional[NDArray[np.int_]] = None,
+ use_default_scorer_value: bool = True,
+) -> ValuationResult:
+ """
+    A truncated version of a permutation-based MC estimator. It iterates over the
+    prefixes of `permutation`, from the empty set to the full set of indices. For
+    each prefix, the marginal contribution is computed and added to the result.
+    The computation is interrupted if the truncation policy returns `True`.
+
+ !!! Todo
+ Reuse in [permutation_montecarlo_shapley()]
+ [pydvl.value.shapley.montecarlo.permutation_montecarlo_shapley]
+
+ Args:
+ u: Utility object containing model, data, and scoring function.
+ permutation: Permutation of indices to be considered.
+ truncation: Callable which decides whether to interrupt processing a
+ permutation and set all subsequent marginals to zero.
+ algorithm_name: For the results object. Used internally by different
+ variants of Shapley using this subroutine
+ additional_indices: Set of additional indices for data points which should be
+ always considered.
+ use_default_scorer_value: Use default scorer value even if additional_indices
+ is not None.
+
+ Returns:
+ ValuationResult object containing computed data values.
+ """
+ if (
+ additional_indices is not None
+ and len(np.intersect1d(permutation, additional_indices)) > 0
+ ):
+ raise ValueError(
+ "The class label set and the complement set have to be disjoint."
+ )
+
+ result = ValuationResult.zeros(
+ algorithm=algorithm_name, indices=u.data.indices, data_names=u.data.data_names
+ )
+
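+    # Baseline: use the default score unless the caller requested evaluating
+    # the utility on a non-empty sampled complement set.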
+ prev_score = (
+ u.default_score
+ if (
+ use_default_scorer_value
+ or additional_indices is None
+ or additional_indices is not None
+ and len(additional_indices) == 0
+ )
+ else u(additional_indices)
+ )
+
+ truncation_u = u
+ if additional_indices is not None:
+ # hack to calculate the correct value in reset.
+ truncation_indices = np.sort(np.concatenate((permutation, additional_indices)))
+ truncation_u = Utility(
+ u.model,
+ Dataset(
+ u.data.x_train[truncation_indices],
+ u.data.y_train[truncation_indices],
+ u.data.x_test,
+ u.data.y_test,
+ ),
+ u.scorer,
+ )
+ truncation.reset(truncation_u)
+
+ is_terminated = False
+ for i, idx in enumerate(permutation):
+ if is_terminated or (is_terminated := truncation(i, prev_score)):
+ score = prev_score
+ else:
+ score = u(
+ np.concatenate((permutation[: i + 1], additional_indices))
+ if additional_indices is not None and len(additional_indices) > 0
+ else permutation[: i + 1]
+ )
+
+ marginal = score - prev_score
+ result.update(idx, marginal)
+ prev_score = score
+
+ return result
+
+
+def _split_indices_by_label(
+ indices: NDArray[np.int_], labels: NDArray[np.int_], label: int
+) -> Tuple[NDArray[np.int_], NDArray[np.int_]]:
+ """
+ Splits the indices into two sets based on the value of `label`, e.g. those samples
+ with and without that label.
+
+ Args:
+ indices: The indices to be used for referring to the data.
+ labels: Corresponding labels for the indices.
+ label: Label to be used for splitting.
+
+ Returns:
+ Tuple with two sets of indices.
+ """
+ active_elements = labels == label
+ class_indices_set = np.where(active_elements)[0]
+ class_complement_indices_set = np.where(~active_elements)[0]
+ class_indices_set = indices[class_indices_set]
+ class_complement_indices_set = indices[class_complement_indices_set]
+ return class_indices_set, class_complement_indices_set
diff --git a/src/pydvl/value/shapley/gt.py b/src/pydvl/value/shapley/gt.py
index b193af6e5..2d3be7710 100644
--- a/src/pydvl/value/shapley/gt.py
+++ b/src/pydvl/value/shapley/gt.py
@@ -17,8 +17,10 @@
## References
[^1]: Jia, R. et al., 2019.
- [Towards Efficient Data Valuation Based on the Shapley Value](https://proceedings.mlr.press/v89/jia19a.html).
- In: Proceedings of the 22nd International Conference on Artificial Intelligence and Statistics, pp. 1167–1176. PMLR.
+ [Towards Efficient Data Valuation Based on the Shapley
+ Value](https://proceedings.mlr.press/v89/jia19a.html).
+ In: Proceedings of the 22nd International Conference on Artificial
+ Intelligence and Statistics, pp. 1167–1176. PMLR.
"""
import logging
from collections import namedtuple
diff --git a/src/pydvl/value/shapley/truncated.py b/src/pydvl/value/shapley/truncated.py
index 5e9f3f729..b54c8f6b3 100644
--- a/src/pydvl/value/shapley/truncated.py
+++ b/src/pydvl/value/shapley/truncated.py
@@ -8,7 +8,7 @@
"""
import abc
import logging
-from typing import cast
+from typing import Optional, cast
import numpy as np
from deprecate import deprecated
@@ -16,7 +16,7 @@
from pydvl.parallel.config import ParallelConfig
from pydvl.utils import Utility, running_moments
from pydvl.value import ValuationResult
-from pydvl.value.stopping import StoppingCriterion
+from pydvl.value.stopping import MaxChecks, StoppingCriterion
__all__ = [
"TruncationPolicy",
@@ -58,7 +58,7 @@ def _check(self, idx: int, score: float) -> bool:
...
@abc.abstractmethod
- def reset(self):
+ def reset(self, u: Optional[Utility] = None):
"""Reset the policy to a state ready for a new permutation."""
...
@@ -84,7 +84,7 @@ class NoTruncation(TruncationPolicy):
def _check(self, idx: int, score: float) -> bool:
return False
- def reset(self):
+ def reset(self, u: Optional[Utility] = None):
pass
@@ -115,7 +115,7 @@ def _check(self, idx: int, score: float) -> bool:
self.count += 1
return self.count >= self.max_marginals
- def reset(self):
+ def reset(self, u: Optional[Utility] = None):
self.count = 0
@@ -134,14 +134,18 @@ def __init__(self, u: Utility, rtol: float):
super().__init__()
self.rtol = rtol
logger.info("Computing total utility for permutation truncation.")
- self.total_utility = u(u.data.indices)
+        self._u = u
+        self.reset(u)
def _check(self, idx: int, score: float) -> bool:
# Explicit cast for the benefit of mypy 🤷
return bool(np.allclose(score, self.total_utility, rtol=self.rtol))
- def reset(self):
- pass
+ def reset(self, u: Optional[Utility] = None):
+ if u is None:
+ u = self._u
+
+ self.total_utility = u(u.data.indices)
class BootstrapTruncation(TruncationPolicy):
@@ -179,7 +183,7 @@ def _check(self, idx: int, score: float) -> bool:
self.sigmas * np.sqrt(self.variance)
)
- def reset(self):
+ def reset(self, u: Optional[Utility] = None):
self.count = 0
self.variance = self.mean = 0
diff --git a/src/pydvl/value/stopping.py b/src/pydvl/value/stopping.py
index f2a236340..4ce4b27e8 100644
--- a/src/pydvl/value/stopping.py
+++ b/src/pydvl/value/stopping.py
@@ -226,6 +226,9 @@ def completion(self) -> float:
return 0.0
return float(np.mean(self.converged).item())
+ def reset(self):
+ pass
+
@property
def converged(self) -> NDArray[np.bool_]:
"""Returns a boolean array indicating whether the values have converged
@@ -413,7 +416,7 @@ def __init__(self, n_checks: Optional[int], modify_result: bool = True):
def _check(self, result: ValuationResult) -> Status:
if self.n_checks:
self._count += 1
- if self._count > self.n_checks:
+ if self._count >= self.n_checks:
self._converged = np.ones_like(result.values, dtype=bool)
return Status.Converged
return Status.Pending
@@ -423,6 +426,9 @@ def completion(self) -> float:
return min(1.0, self._count / self.n_checks)
return 0.0
+ def reset(self):
+ self._count = 0
+
def __str__(self):
return f"MaxChecks(n_checks={self.n_checks})"
@@ -546,6 +552,9 @@ def completion(self) -> float:
return 0.0
return (time() - self.start) / self.max_seconds
+ def reset(self):
+ self.start = time()
+
def __str__(self):
return f"MaxTime(seconds={self.max_seconds})"
@@ -622,7 +631,7 @@ def _check(self, r: ValuationResult) -> Status:
quots = np.divide(diffs, curr[ii], out=diffs, where=curr[ii] != 0)
# quots holds the quotients when the denominator is non-zero, and
# the absolute difference, which is just the memory, otherwise.
- if np.mean(quots) < self.rtol:
+ if len(quots) > 0 and np.mean(quots) < self.rtol:
self._converged = self.update_op(
self._converged, r.counts > self.n_steps
) # type: ignore
@@ -630,5 +639,8 @@ def _check(self, r: ValuationResult) -> Status:
return Status.Converged
return Status.Pending
+ def reset(self):
+ self._memory = None # type: ignore
+
def __str__(self):
return f"HistoryDeviation(n_steps={self.n_steps}, rtol={self.rtol})"
diff --git a/tests/utils/test_numeric.py b/tests/utils/test_numeric.py
index 13423b286..b722c24f8 100644
--- a/tests/utils/test_numeric.py
+++ b/tests/utils/test_numeric.py
@@ -5,6 +5,7 @@
powerset,
random_matrix_with_condition_number,
random_powerset,
+ random_powerset_label_min,
random_subset_of_size,
running_moments,
)
@@ -248,3 +249,27 @@ def test_running_moments():
true_variances = [np.var(vv) for vv in values]
assert np.allclose(means, true_means)
assert np.allclose(variances, true_variances)
+
+
+@pytest.mark.parametrize(
+ "min_elements_per_label,num_elements_per_label,num_labels,check_num_samples",
+ [(0, 10, 3, 1000), (1, 10, 3, 1000), (2, 10, 3, 1000)],
+)
+def test_random_powerset_label_min(
+ min_elements_per_label: int,
+ num_elements_per_label: int,
+ num_labels: int,
+ check_num_samples: int,
+):
+ s = np.arange(num_labels * num_elements_per_label)
+ labels = np.arange(num_labels).repeat(num_elements_per_label)
+
+ for idx, subset in enumerate(
+ random_powerset_label_min(s, labels, min_elements_per_label)
+ ):
+ assert np.all(np.isin(subset, s))
+ for group in np.unique(labels):
+ assert np.sum(group == labels[subset]) >= min_elements_per_label
+
+ if idx == check_num_samples:
+ break
diff --git a/tests/value/shapley/test_classwise.py b/tests/value/shapley/test_classwise.py
new file mode 100644
index 000000000..bd4f55a5d
--- /dev/null
+++ b/tests/value/shapley/test_classwise.py
@@ -0,0 +1,416 @@
+from typing import Dict, Tuple, cast
+
+import numpy as np
+import pandas as pd
+import pytest
+from numpy.typing import NDArray
+
+from pydvl.utils import Dataset, Utility, powerset
+from pydvl.value import MaxChecks, ValuationResult
+from pydvl.value.shapley.classwise import (
+ ClasswiseScorer,
+ compute_classwise_shapley_values,
+)
+from pydvl.value.shapley.truncated import NoTruncation
+from tests.value import check_values
+
+
+@pytest.fixture(scope="function")
+def classwise_shapley_exact_solution() -> Tuple[Dict, ValuationResult, Dict]:
+ """
+ See [classwise.py][pydvl.value.shapley.classwise] for details of the derivation.
+ """
+ return (
+ {
+ "normalize_values": False,
+ },
+ ValuationResult(
+ values=np.array(
+ [
+ 1 / 6 * np.exp(1 / 4),
+ 1 / 3 * np.exp(1 / 4),
+ 1 / 12 * np.exp(1 / 4) + 1 / 24 * np.exp(1 / 2),
+ 1 / 8 * np.exp(1 / 2),
+ ]
+ )
+ ),
+ {"atol": 0.05},
+ )
+
+
+@pytest.fixture(scope="function")
+def classwise_shapley_exact_solution_normalized(
+ classwise_shapley_exact_solution,
+) -> Tuple[Dict, ValuationResult, Dict]:
+ """
+ It additionally normalizes the values using the argument `normalize_values`. See
+ [classwise.py][pydvl.value.shapley.classwise] for details of the derivation.
+ """
+ values = classwise_shapley_exact_solution[1].values
+ label_zero_coefficient = 1 / np.exp(1 / 4)
+ label_one_coefficient = 1 / (1 / 3 * np.exp(1 / 4) + 2 / 3 * np.exp(1 / 2))
+
+ return (
+ {
+ "normalize_values": True,
+ },
+ ValuationResult(
+ values=np.array(
+ [
+ values[0] * label_zero_coefficient,
+ values[1] * label_zero_coefficient,
+ values[2] * label_one_coefficient,
+ values[3] * label_one_coefficient,
+ ]
+ )
+ ),
+ {"atol": 0.05},
+ )
+
+
+@pytest.fixture(scope="function")
+def classwise_shapley_exact_solution_no_default() -> Tuple[Dict, ValuationResult, Dict]:
+ """
+ Note that this special case doesn't set the utility to 0 if the permutation is
+ empty. See [classwise.py][pydvl.value.shapley.classwise] for details of the
+ derivation.
+ """
+ return (
+ {
+ "use_default_scorer_value": False,
+ "normalize_values": False,
+ },
+ ValuationResult(
+ values=np.array(
+ [
+ 1 / 24 * np.exp(1 / 4),
+ 5 / 24 * np.exp(1 / 4),
+ 1 / 12 * np.exp(1 / 4) + 1 / 24 * np.exp(1 / 2),
+ 1 / 8 * np.exp(1 / 2),
+ ]
+ )
+ ),
+ {"atol": 0.05},
+ )
+
+
+@pytest.fixture(scope="function")
+def classwise_shapley_exact_solution_no_default_allow_empty_set() -> (
+ Tuple[Dict, ValuationResult, Dict]
+):
+ r"""
+ Note that this special case doesn't set the utility to 0 if the permutation is
+ empty and additionally allows $S^{(k)} = \emptyset$. See
+ [classwise.py][pydvl.value.shapley.classwise] for details of the derivation.
+ """
+ return (
+ {
+ "use_default_scorer_value": False,
+ "min_elements_per_label": 0,
+ "normalize_values": False,
+ },
+ ValuationResult(
+ values=np.array(
+ [
+ 3 / 32 + 1 / 32 * np.exp(1 / 4),
+ 3 / 32 + 5 / 32 * np.exp(1 / 4),
+ 5 / 32 * np.exp(1 / 4) + 1 / 32 * np.exp(1 / 2),
+ 1 / 32 * np.exp(1 / 4) + 3 / 32 * np.exp(1 / 2),
+ ]
+ )
+ ),
+ {"atol": 0.05},
+ )
+
+
+@pytest.mark.parametrize("n_samples", [500], ids=lambda x: "n_samples={}".format(x))
+@pytest.mark.parametrize(
+ "n_resample_complement_sets",
+ [1],
+ ids=lambda x: "n_resample_complement_sets={}".format(x),
+)
+@pytest.mark.parametrize(
+ "exact_solution",
+ [
+ "classwise_shapley_exact_solution",
+ "classwise_shapley_exact_solution_normalized",
+ "classwise_shapley_exact_solution_no_default",
+ "classwise_shapley_exact_solution_no_default_allow_empty_set",
+ ],
+)
+def test_classwise_shapley(
+ classwise_shapley_utility: Utility,
+ exact_solution: Tuple[Dict, ValuationResult, Dict],
+ n_samples: int,
+ n_resample_complement_sets: int,
+ request,
+):
+ args, exact_solution, check_args = request.getfixturevalue(exact_solution)
+ values = compute_classwise_shapley_values(
+ classwise_shapley_utility,
+ done=MaxChecks(n_samples),
+ truncation=NoTruncation(),
+ done_sample_complements=MaxChecks(n_resample_complement_sets),
+ **args,
+ progress=True,
+ )
+ check_values(values, exact_solution, **check_args)
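+    # Each value must have been updated n_samples * n_resample_complement_sets times.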
+ assert np.all(values.counts == n_samples * n_resample_complement_sets)
+
+
+def test_classwise_scorer_representation():
+ """
+    Tests the string representation of the ClasswiseScorer.
+ """
+
+ scorer = ClasswiseScorer("accuracy", initial_label=0)
+ assert str(scorer) == "classwise accuracy"
+ assert repr(scorer) == "ClasswiseAccuracy (scorer=make_scorer(accuracy_score))"
+
+
+@pytest.mark.parametrize("n_element, left_margin, right_margin", [(101, 0.3, 0.4)])
+def test_classwise_scorer_utility(dataset_left_right_margins):
+ """
+    Tests whether the ClasswiseScorer returns the expected utility value.
+ See [classwise.py][pydvl.value.shapley.classwise] for more details.
+ """
+ scorer = ClasswiseScorer("accuracy", initial_label=0)
+ x, y, info = dataset_left_right_margins
+ n_element = len(x)
+ target_in_cls_acc_0 = (info["left_margin"] * 100 + 1) / n_element
+ target_out_of_cls_acc_0 = (info["right_margin"] * 100 + 1) / n_element
+
+ model = ThresholdClassifier()
+ in_cls_acc_0, out_of_cls_acc_0 = scorer.estimate_in_class_and_out_of_class_score(
+ model, x, y
+ )
+ assert np.isclose(in_cls_acc_0, target_in_cls_acc_0)
+ assert np.isclose(out_of_cls_acc_0, target_out_of_cls_acc_0)
+
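+    # The scorer for label 0 computes in_cls_acc_0 * exp(out_of_cls_acc_0);
+    # setting the label to 1 swaps the two accuracies.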
+ value = scorer(model, x, y)
+ assert np.isclose(value, in_cls_acc_0 * np.exp(out_of_cls_acc_0))
+
+ scorer.label = 1
+ value = scorer(model, x, y)
+ assert np.isclose(value, out_of_cls_acc_0 * np.exp(in_cls_acc_0))
+
+
+@pytest.mark.parametrize("n_element, left_margin, right_margin", [(101, 0.3, 0.4)])
+def test_classwise_scorer_is_symmetric(
+ dataset_left_right_margins,
+):
+ """
+    Tests whether the ClasswiseScorer is symmetric: for a two-class problem, the
+    in-class accuracy for one label must match the out-of-class accuracy for the
+    other label, and vice versa. See
+    [classwise.py][pydvl.value.shapley.classwise] for more details.
+ """
+ scorer = ClasswiseScorer("accuracy", initial_label=0)
+ x, y, info = dataset_left_right_margins
+ model = ThresholdClassifier()
+ in_cls_acc_0, out_of_cls_acc_0 = scorer.estimate_in_class_and_out_of_class_score(
+ model, x, y
+ )
+ scorer.label = 1
+ in_cls_acc_1, out_of_cls_acc_1 = scorer.estimate_in_class_and_out_of_class_score(
+ model, x, y
+ )
+ assert in_cls_acc_1 == out_of_cls_acc_0
+ assert in_cls_acc_0 == out_of_cls_acc_1
+
+
+def test_classwise_scorer_accuracies_manual_derivation(
+ classwise_shapley_utility: Utility,
+):
+ """
+ Tests whether the model of the scorer is fitted correctly and returns the expected
+ in-class and out-of-class accuracies. See
+ [classwise.py][pydvl.value.shapley.classwise] for more details.
+ """
+ subsets_zero = list(powerset(np.array((0, 1))))
+ subsets_one = list(powerset(np.array((2, 3))))
+ subsets_zero = [tuple(s) for s in subsets_zero]
+ subsets_one = [tuple(s) for s in subsets_one]
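+    # Rows index training subsets of the label-0 points {0, 1}; columns index
+    # training subsets of the label-1 points {2, 3}.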
+ target_accuracies_zero = pd.DataFrame(
+ [
+ [0, 1 / 4, 1 / 4, 1 / 4],
+ [3 / 4, 1 / 4, 1 / 2, 1 / 4],
+ [3 / 4, 1 / 2, 1 / 2, 1 / 2],
+ [3 / 4, 1 / 2, 1 / 2, 1 / 2],
+ ],
+ index=subsets_zero,
+ columns=subsets_one,
+ )
+ target_accuracies_one = pd.DataFrame(
+ [
+ [0, 1 / 4, 1 / 4, 1 / 4],
+ [0, 1 / 4, 1 / 4, 1 / 4],
+ [0, 1 / 4, 1 / 4, 1 / 4],
+ [0, 1 / 4, 1 / 4, 1 / 4],
+ ],
+ index=subsets_zero,
+ columns=subsets_one,
+ )
+ model = classwise_shapley_utility.model
+ scorer = cast(ClasswiseScorer, classwise_shapley_utility.scorer)
+ scorer.label = 0
+
+ for set_zero_idx in range(len(subsets_zero)):
+ for set_one_idx in range(len(subsets_one)):
+ indices = list(subsets_zero[set_zero_idx] + subsets_one[set_one_idx])
+ (
+ x_train,
+ y_train,
+ ) = classwise_shapley_utility.data.get_training_data(indices)
+ classwise_shapley_utility.model.fit(x_train, y_train)
+
+ (
+ x_test,
+ y_test,
+ ) = classwise_shapley_utility.data.get_test_data()
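+            # The out-of-class accuracy for label 0 equals the in-class accuracy for label 1.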
+ (
+ in_cls_acc_0,
+ in_cls_acc_1,
+ ) = scorer.estimate_in_class_and_out_of_class_score(model, x_test, y_test)
+ assert (
+ in_cls_acc_0 == target_accuracies_zero.iloc[set_zero_idx, set_one_idx]
+ )
+ assert in_cls_acc_1 == target_accuracies_one.iloc[set_zero_idx, set_one_idx]
+
+
+@pytest.mark.parametrize("n_element, left_margin, right_margin", [(101, 0.3, 0.4)])
+def test_classwise_scorer_accuracies_left_right_margins(dataset_left_right_margins):
+ """
+ Tests whether the model of the scorer is fitted correctly and returns the expected
+ in-class and out-of-class accuracies. See
+ [classwise.py][pydvl.value.shapley.classwise] for more details.
+ """
+ scorer = ClasswiseScorer("accuracy", initial_label=0)
+ x, y, info = dataset_left_right_margins
+ n_element = len(x)
+
+ target_in_cls_acc_0 = (info["left_margin"] * 100 + 1) / n_element
+ target_out_of_cls_acc_0 = (info["right_margin"] * 100 + 1) / n_element
+
+ model = ThresholdClassifier()
+ in_cls_acc_0, out_of_cls_acc_0 = scorer.estimate_in_class_and_out_of_class_score(
+ model, x, y
+ )
+ assert np.isclose(in_cls_acc_0, target_in_cls_acc_0)
+ assert np.isclose(out_of_cls_acc_0, target_out_of_cls_acc_0)
+
+
+def test_closed_form_linear_classifier(
+ classwise_shapley_utility: Utility,
+):
+ """
+ Tests whether the model is fitted correctly and contains the right $\beta$
+ parameter. See [classwise.py][pydvl.value.shapley.classwise] for more details.
+ """
+ subsets_zero = list(powerset(np.array((0, 1))))
+ subsets_one = list(powerset(np.array((2, 3))))
+ subsets_zero = [tuple(s) for s in subsets_zero]
+ subsets_one = [tuple(s) for s in subsets_one]
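+    # Rows and columns index training subsets of the label-0 and label-1 points.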
+ target_betas = pd.DataFrame(
+ [
+ [np.nan, 1 / 3, 1 / 4, 7 / 25],
+ [0, 3 / 10, 4 / 17, 7 / 26],
+ [0, 3 / 13, 1 / 5, 7 / 29],
+ [0, 3 / 14, 4 / 21, 7 / 30],
+ ],
+ index=subsets_zero,
+ columns=subsets_one,
+ )
+ scorer = cast(ClasswiseScorer, classwise_shapley_utility.scorer)
+ scorer.label = 0
+
+ for set_zero_idx in range(len(subsets_zero)):
+ for set_one_idx in range(len(subsets_one)):
+ indices = list(subsets_zero[set_zero_idx] + subsets_one[set_one_idx])
+ (
+ x_train,
+ y_train,
+ ) = classwise_shapley_utility.data.get_training_data(indices)
+ classwise_shapley_utility.model.fit(x_train, y_train)
+ fitted_beta = classwise_shapley_utility.model._beta # noqa
+ target_beta = target_betas.iloc[set_zero_idx, set_one_idx]
+ assert (
+ np.isnan(fitted_beta)
+ if np.isnan(target_beta)
+ else fitted_beta == target_beta
+ )
+
+
+class ThresholdClassifier:
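+    """Mock classifier that predicts 1 iff the first feature exceeds 0.5."""
+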
+ def fit(self, x: NDArray, y: NDArray) -> float:
+ raise NotImplementedError("Mock model")
+
+ def predict(self, x: NDArray) -> NDArray:
+        return (x[:, 0] > 0.5).astype(int)
+
+ def score(self, x: NDArray, y: NDArray) -> float:
+ raise NotImplementedError("Mock model")
+
+
+class ClosedFormLinearClassifier:
+ def __init__(self):
+ self._beta = None
+
+ def fit(self, x: NDArray, y: NDArray) -> float:
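+        # Closed-form least-squares fit of y ~ beta * x through the origin:
+        # beta = <x, y> / <x, x>.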
+ v = x[:, 0]
+ self._beta = np.dot(v, y) / np.dot(v, v)
+ return -1
+
+ def predict(self, x: NDArray) -> NDArray:
+ if self._beta is None:
+ raise AttributeError("Model not fitted")
+
+ x = x[:, 0]
+ probs = self._beta * x
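+        # Round to the nearest label in {0, 1}; the epsilon breaks ties at 0.5 upwards.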
+ return np.clip(np.round(probs + 1e-10), 0, 1).astype(int)
+
+ def score(self, x: NDArray, y: NDArray) -> float:
+ pred_y = self.predict(x)
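+        # Accuracy over the four-point test set (hence the hard-coded divisor).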
+ return np.sum(pred_y == y) / 4
+
+
+@pytest.fixture(scope="function")
+def classwise_shapley_utility(
+ dataset_manual_derivation: Dataset,
+) -> Utility:
+ return Utility(
+ ClosedFormLinearClassifier(),
+ dataset_manual_derivation,
+ ClasswiseScorer("accuracy"),
+ catch_errors=False,
+ )
+
+
+@pytest.fixture(scope="function")
+def dataset_manual_derivation() -> Dataset:
+ """
+ See [classwise.py][pydvl.value.shapley.classwise] for more details.
+ """
+ x_train = np.arange(1, 5).reshape([-1, 1])
+ y_train = np.array([0, 0, 1, 1])
+ x_test = x_train
+ y_test = np.array([0, 0, 0, 1])
+ return Dataset(x_train, y_train, x_test, y_test)
+
+
+@pytest.fixture(scope="function")
+def dataset_left_right_margins(
+ n_element: int, left_margin: float, right_margin: float
+) -> Tuple[NDArray[np.float_], NDArray[np.int_], Dict[str, float]]:
+ """
+    The labels follow the pattern 0...01...10...01...1 (e.g. 0000011100011111),
+    with adjustable left and right margins. The left margin is the fraction of
+    zeros at the start of the sequence, the right margin the fraction of ones at
+    its end. Accuracy can then be calculated in closed form.
+ """
+ x = np.linspace(0, 1, n_element)
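+    # Label 1 on [left_margin, 0.5) and on [1 - right_margin, 1], label 0 elsewhere.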
+ y = ((left_margin <= x) & (x < 0.5)) | ((1 - right_margin) <= x)
+ y = y.astype(int)
+ x = np.expand_dims(x, -1)
+ return x, y, {"left_margin": left_margin, "right_margin": right_margin}
diff --git a/tests/value/test_stopping.py b/tests/value/test_stopping.py
index c57d5f56f..7399dc9c3 100644
--- a/tests/value/test_stopping.py
+++ b/tests/value/test_stopping.py
@@ -193,6 +193,6 @@ def test_max_checks():
assert not done(v)
done = MaxChecks(5)
- for _ in range(5):
+ for _ in range(4):
assert not done(v)
assert done(v)