Skip to content

Commit

Permalink
Remove leftover uses of enable_cache argument
Browse files Browse the repository at this point in the history
  • Loading branch information
AnesBenmerzoug committed Nov 24, 2023
1 parent fbc96cf commit ab4578a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/pydvl/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def _utility(self, indices: FrozenSet) -> float:
"""Clones the model, fits it on a subset of the training data
and scores it on the test data.
If the object is constructed with `enable_cache = True`, results are
If an instance of [CacheBackend][pydvl.utils.caching.base.CacheBackend]
is passed during construction, results are
memoized to avoid duplicate computation. This is useful in particular
when computing utilities of permutations of indices or when randomly
sampling from the powerset of indices.
Expand Down
5 changes: 3 additions & 2 deletions tests/value/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def score(self, x: NDArray, y: NDArray) -> float:
score_range=(0, x.sum() / x.max()),
catch_errors=False,
show_warnings=True,
enable_cache=False,
)


Expand Down Expand Up @@ -122,7 +121,9 @@ def linear_shapley(cache, linear_dataset, scorer, n_jobs):

if u is None:
u = Utility(
LinearRegression(), data=linear_dataset, scorer=scorer, enable_cache=False
LinearRegression(),
data=linear_dataset,
scorer=scorer,
)
exact_values = combinatorial_exact_shapley(u, progress=False, n_jobs=n_jobs)
cache.set(u_cache_key, u)
Expand Down
5 changes: 4 additions & 1 deletion tests/value/shapley/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def knn_loss_function(labels, predictions, n_classes=3):
)

utility = Utility(
model, data=data, scorer=scorer, show_warnings=False, enable_cache=False
model,
data=data,
scorer=scorer,
show_warnings=False,
)
exact_values = combinatorial_exact_shapley(
utility, progress=False, n_jobs=min(len(data), available_cpus())
Expand Down

0 comments on commit ab4578a

Please sign in to comment.