Skip to content

Commit

Permalink
Merge pull request #437 from aai-institute/feature/filter-converged
Browse files Browse the repository at this point in the history
Stop updating indices as soon as they converge in semivalue computations
  • Loading branch information
mdbenito authored Oct 8, 2023
2 parents 60d8aef + 491dbd1 commit 4c10cc3
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 36 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

- No longer using docker within tests to start a memcached server
[PR #444](https://github.com/aai-institute/pyDVL/pull/444)
- Faster semi-value computation with per-index check of stopping criteria (optional)
[PR #437](https://github.com/aai-institute/pyDVL/pull/437)
- Improvements and fixes to notebooks
[PR #436](https://github.com/aai-institute/pyDVL/pull/436)
- Fix initialization of `data_names` in `ValuationResult.zeros()`
Expand Down
2 changes: 1 addition & 1 deletion src/pydvl/reporting/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def plot_ci_array(
means = np.mean(data, axis=0)
variances = np.var(data, axis=0, ddof=1)

dummy: ValuationResult[np.int_, np.object_] = ValuationResult(
dummy = ValuationResult[np.int_, np.object_](
algorithm="dummy",
values=means,
variances=variances,
Expand Down
2 changes: 1 addition & 1 deletion src/pydvl/value/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def __len__(self) -> int:
return len(self._outer_indices)

def __str__(self):
    """Return the sampler's class name (used in logs and result metadata)."""
    # Fix: the merge left two return statements for this method; the second
    # was unreachable. The f-string wrapper around a plain attribute access
    # was redundant, so keep the direct form.
    return self.__class__.__name__

def __repr__(self):
    """Debug representation: class name plus the index sets driving the sampler."""
    cls_name = self.__class__.__name__
    return "{}({}, {})".format(cls_name, self._indices, self._outer_indices)
Expand Down
41 changes: 34 additions & 7 deletions src/pydvl/value/semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
import warnings
from enum import Enum
from itertools import islice
from typing import Collection, List, Optional, Protocol, Tuple, Type, TypeVar, cast
from typing import Iterable, List, Optional, Protocol, Tuple, Type, cast

import scipy as sp
from deprecate import deprecated
Expand Down Expand Up @@ -143,7 +143,7 @@ def __call__(self, n: int, k: int) -> float:


def _marginal(
u: Utility, coefficient: SVCoefficient, samples: Collection[SampleT]
u: Utility, coefficient: SVCoefficient, samples: Iterable[SampleT]
) -> Tuple[MarginalT, ...]:
"""Computation of marginal utility. This is a helper function for
[compute_generic_semivalues][pydvl.value.semivalues.compute_generic_semivalues].
Expand Down Expand Up @@ -186,6 +186,7 @@ def compute_generic_semivalues(
done: StoppingCriterion,
*,
batch_size: int = 1,
skip_converged: bool = False,
n_jobs: int = 1,
config: ParallelConfig = ParallelConfig(),
progress: bool = False,
Expand All @@ -198,6 +199,15 @@ def compute_generic_semivalues(
coefficient: The semi-value coefficient
done: Stopping criterion.
batch_size: Number of marginal evaluations per single parallel job.
skip_converged: Whether to skip marginal evaluations for indices that
have already converged. **CAUTION**: This is only entirely safe if
the stopping criterion is [MaxUpdates][pydvl.value.stopping.MaxUpdates].
For any other stopping criterion, the convergence status of indices
may change during the computation, or they may be marked as having
converged even though in fact the estimated values are far from the
true values (e.g. for
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError],
you will probably have to carefully adjust the threshold).
n_jobs: Number of parallel jobs to use.
config: Object configuring parallel computation, with cluster
address, number of cpus, etc.
Expand Down Expand Up @@ -262,16 +272,33 @@ def compute_generic_semivalues(

# Ensure that we always have n_submitted_jobs running
try:
for _ in range(n_submitted_jobs - len(pending)):
while len(pending) < n_submitted_jobs:
samples = tuple(islice(sampler_it, batch_size))
if len(samples) == 0:
raise StopIteration

pending.add(
executor.submit(
_marginal, u=u, coefficient=correction, samples=samples
# Filter out samples for indices that have already converged
filtered_samples = samples
if skip_converged and len(done.converged) > 0:
# cloudpickle can't pickle this on python 3.8:
# filtered_samples = filter(
# lambda t: not done.converged[t[0]], samples
# )
filtered_samples = tuple(
(idx, sample)
for idx, sample in samples
if not done.converged[idx]
)

if filtered_samples:
pending.add(
executor.submit(
_marginal,
u=u,
coefficient=correction,
samples=filtered_samples,
)
)
)
except StopIteration:
if len(pending) == 0:
return result
Expand Down
145 changes: 125 additions & 20 deletions src/pydvl/value/stopping.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,94 @@
"""
r"""
Stopping criteria for value computations.
This module provides a basic set of stopping criteria, like [MaxUpdates][pydvl.value.stopping.MaxUpdates],
[MaxTime][pydvl.value.stopping.MaxTime], or [HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others.
These can behave in different ways depending on the context.
For example, [MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
This module provides a basic set of stopping criteria, like
[MaxUpdates][pydvl.value.stopping.MaxUpdates],
[MaxTime][pydvl.value.stopping.MaxTime], or
[HistoryDeviation][pydvl.value.stopping.HistoryDeviation] among others. These
can behave in different ways depending on the context. For example,
[MaxUpdates][pydvl.value.stopping.MaxUpdates] limits
the number of updates to values, which depending on the algorithm may mean a
different number of utility evaluations or imply other computations like solving
a linear or quadratic program.
# Creating stopping criteria
Stopping criteria are callables that are evaluated on a
[ValuationResult][pydvl.value.result.ValuationResult] and return a
[Status][pydvl.utils.status.Status] object. They can be combined using boolean
operators.
## How convergence is determined
Most stopping criteria keep track of the convergence of each index separately
but make global decisions based on the overall convergence of some fraction of
all indices. For example, if we have a stopping criterion that checks whether
the standard error of 90% of values is below a threshold, then methods will keep
updating **all** indices until 90% of them have converged, irrespective of the
quality of the individual estimates, and *without freezing updates* for indices
along the way as values individually attain low standard error.
This has some practical implications, because some values do tend to converge
sooner than others. For example, assume we use the criterion
`AbsoluteStandardError(0.02) | MaxUpdates(1000)`. Then values close to 0 might
be marked as "converged" rather quickly because they fulfill the first
criterion, say after 20 iterations, despite being poor estimates. Because other
indices take much longer to have low standard error and the criterion is a
global check, the "converged" ones keep being updated and end up being good
estimates. In this case, this has been beneficial, but one might not wish for
converged values to be updated, if one is sure that the criterion is adequate
for individual values.
[Semi-value methods][pydvl.value.semivalues] include a parameter
`skip_converged` that allows skipping the computation of values that have
converged. The way to avoid doing this too early is to use a more stringent
check, e.g. `AbsoluteStandardError(1e-3) | MaxUpdates(1000)`. With
`skip_converged=True` this check can still take less time than the first one,
despite requiring more iterations for some indices.
## Choosing a stopping criterion
The choice of a stopping criterion greatly depends on the algorithm and the
context. A safe bet is to combine a [MaxUpdates][pydvl.value.stopping.MaxUpdates]
or a [MaxTime][pydvl.value.stopping.MaxTime] with a
[HistoryDeviation][pydvl.value.stopping.HistoryDeviation] or an
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError]. The former
will ensure that the computation does not run for too long, while the latter
will try to achieve results that are stable enough. Note however that if the
threshold is too strict, one will always end up running until a maximum number
of iterations or time. Also keep in mind that different values converge at
different times, so you might want to use tight thresholds and `skip_converged`
as described above for semi-values.
??? Example
```python
from pydvl.value import AbsoluteStandardError, MaxUpdates, compute_banzhaf_semivalues
utility = ... # some utility object
criterion = AbsoluteStandardError(threshold=1e-3, burn_in=32) | MaxUpdates(1000)
values = compute_banzhaf_semivalues(
utility,
criterion,
skip_converged=True, # skip values that have converged (CAREFUL!)
)
```
This will compute the Banzhaf semivalues for `utility` until either the
absolute standard error is below `1e-3` or `1000` updates have been
performed. The `burn_in` parameter is used to discard the first `32` updates
from the computation of the standard error. The `skip_converged` parameter
is used to avoid computing more marginals for indices that have converged,
which is useful if
[AbsoluteStandardError][pydvl.value.stopping.AbsoluteStandardError] is met
before [MaxUpdates][pydvl.value.stopping.MaxUpdates] for some indices.
!!! Warning
Be careful not to reuse the same stopping criterion for different
computations. The object has state and will not be reset between calls to
value computation methods. If you need to reuse the same criterion, you
should create a new instance.
## Creating stopping criteria
The easiest way is to declare a function implementing the interface
[StoppingCriterionCallable][pydvl.value.stopping.StoppingCriterionCallable] and
Expand All @@ -18,19 +97,19 @@
that can be composed with other stopping criteria.
Alternatively, and in particular if reporting of completion is required, one can
inherit from this class and implement the abstract methods
[_check][pydvl.value.stopping.StoppingCriterion._check] and
inherit from this class and implement the abstract methods `_check` and
[completion][pydvl.value.stopping.StoppingCriterion.completion].
# Composing stopping criteria
## Combining stopping criteria
Objects of type [StoppingCriterion][pydvl.value.stopping.StoppingCriterion] can
be composed with the binary operators `&` (*and*), and `|` (*or*), following the
be combined with the binary operators `&` (*and*) and `|` (*or*), following the
truth tables of [Status][pydvl.utils.status.Status]. The unary operator `~`
(*not*) is also supported. See
[StoppingCriterion][pydvl.value.stopping.StoppingCriterion] for details on how
these operations affect the behavior of the stopping criteria.
## References
[^1]: <a name="ghorbani_data_2019"></a>Ghorbani, A., Zou, J., 2019.
Expand Down Expand Up @@ -163,6 +242,15 @@ def converged(self) -> NDArray[np.bool_]:

@property
def name(self):
    """Deprecated display name of this stopping criterion.

    Returns the explicit ``_name`` if one was set, otherwise the class
    name. Emits a deprecation warning on every access; callers should
    switch to ``str(criterion)``.
    """
    log = logging.getLogger(__name__)
    # This string for the benefit of deprecation searches:
    # remove_in="0.8.0"
    log.warning(
        "The `name` attribute of `StoppingCriterion` is deprecated and will be removed in 0.8.0. "
    )
    # Fall back to the class name when no `_name` attribute was set
    # (e.g. criteria not created via make_criterion).
    return getattr(self, "_name", type(self).__name__)

def __str__(self):
    """Default display name of a criterion: its class name."""
    return self.__class__.__name__

def __call__(self, result: ValuationResult) -> Status:
Expand All @@ -182,23 +270,23 @@ def __and__(self, other: "StoppingCriterion") -> "StoppingCriterion":
fun=lambda result: self._check(result) & other._check(result),
converged=lambda: self.converged & other.converged,
completion=lambda: min(self.completion(), other.completion()),
name=f"Composite StoppingCriterion: {self.name} AND {other.name}",
name=f"Composite StoppingCriterion: {str(self)} AND {str(other)}",
)(modify_result=self.modify_result or other.modify_result)

def __or__(self, other: "StoppingCriterion") -> "StoppingCriterion":
return make_criterion(
fun=lambda result: self._check(result) | other._check(result),
converged=lambda: self.converged | other.converged,
completion=lambda: max(self.completion(), other.completion()),
name=f"Composite StoppingCriterion: {self.name} OR {other.name}",
name=f"Composite StoppingCriterion: {str(self)} OR {str(other)}",
)(modify_result=self.modify_result or other.modify_result)

def __invert__(self) -> "StoppingCriterion":
return make_criterion(
fun=lambda result: ~self._check(result),
converged=lambda: ~self.converged,
completion=lambda: 1 - self.completion(),
name=f"Composite StoppingCriterion: NOT {self.name}",
name=f"Composite StoppingCriterion: NOT {str(self)}",
)(modify_result=self.modify_result)


Expand Down Expand Up @@ -239,8 +327,7 @@ def converged(self) -> NDArray[np.bool_]:
return super().converged
return converged()

@property
def name(self):
def __str__(self):
return self._name

def completion(self) -> float:
Expand All @@ -254,13 +341,13 @@ def completion(self) -> float:
class AbsoluteStandardError(StoppingCriterion):
r"""Determine convergence based on the standard error of the values.
If $s_i$ is the standard error for datum $i$ and $v_i$ its value, then this
criterion returns [Converged][pydvl.utils.status.Status] if
$s_i < \epsilon$ for all $i$ and a threshold value $\epsilon \gt 0$.
If $s_i$ is the standard error for datum $i$, then this criterion returns
[Converged][pydvl.utils.status.Status] if $s_i < \epsilon$ for all $i$ and a
threshold value $\epsilon \gt 0$.
Args:
threshold: A value is considered to have converged if the standard
error is below this value. A way of choosing it is to pick some
error is below this threshold. A way of choosing it is to pick some
percentage of the range of the values. For Shapley values this is
the difference between the maximum and minimum of the utility
function (to see this substitute the maximum and minimum values of
Expand All @@ -270,7 +357,7 @@ class AbsoluteStandardError(StoppingCriterion):
burn_in: The number of iterations to ignore before checking for
convergence. This is required because computations typically start
with zero variance, as a result of using
[empty()][pydvl.value.result.ValuationResult.empty]. The default is
[zeros()][pydvl.value.result.ValuationResult.zeros]. The default is
set to an arbitrary minimum which is usually enough but may need to
be increased.
"""
Expand All @@ -295,6 +382,9 @@ def _check(self, result: ValuationResult) -> Status:
return Status.Converged
return Status.Pending

def __str__(self):
    """Display name including the criterion's configuration parameters."""
    return (
        f"AbsoluteStandardError(threshold={self.threshold}, "
        f"fraction={self.fraction}, burn_in={self.burn_in})"
    )


class StandardError(AbsoluteStandardError):
@deprecated(target=AbsoluteStandardError, deprecated_in="0.6.0", remove_in="0.8.0")
Expand Down Expand Up @@ -333,6 +423,9 @@ def completion(self) -> float:
return min(1.0, self._count / self.n_checks)
return 0.0

def __str__(self):
    """Display name including the configured number of checks."""
    return "MaxChecks(n_checks={})".format(self.n_checks)


class MaxUpdates(StoppingCriterion):
"""Terminate if any number of value updates exceeds or equals the given
Expand Down Expand Up @@ -377,6 +470,9 @@ def completion(self) -> float:
return self.last_max / self.n_updates
return 0.0

def __str__(self):
    """Display name including the configured update limit."""
    return "MaxUpdates(n_updates={})".format(self.n_updates)


class MinUpdates(StoppingCriterion):
"""Terminate as soon as all value updates exceed or equal the given threshold.
Expand Down Expand Up @@ -414,6 +510,9 @@ def completion(self) -> float:
return self.last_min / self.n_updates
return 0.0

def __str__(self):
    """Display name including the configured minimum number of updates."""
    return "MinUpdates(n_updates={})".format(self.n_updates)


class MaxTime(StoppingCriterion):
"""Terminate if the computation time exceeds the given number of seconds.
Expand Down Expand Up @@ -447,6 +546,9 @@ def completion(self) -> float:
return 0.0
return (time() - self.start) / self.max_seconds

def __str__(self):
    """Display name including the configured time budget in seconds."""
    return "MaxTime(seconds=" + str(self.max_seconds) + ")"


class HistoryDeviation(StoppingCriterion):
r"""A simple check for relative distance to a previous step in the
Expand Down Expand Up @@ -527,3 +629,6 @@ def _check(self, r: ValuationResult) -> Status:
if np.all(self._converged):
return Status.Converged
return Status.Pending

def __str__(self):
    """Display name including the history window size and relative tolerance."""
    return "HistoryDeviation(n_steps={}, rtol={})".format(self.n_steps, self.rtol)
Loading

0 comments on commit 4c10cc3

Please sign in to comment.