Skip to content

Commit

Permalink
Fix doctests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Jul 25, 2023
1 parent 495211e commit 3720cd7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/pydvl/utils/parallel/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ def init_parallel_backend(
>>> config = ParallelConfig(backend="ray")
>>> parallel_backend = init_parallel_backend(config)
>>> parallel_backend
<RayParallelBackend: {'address': None, 'logging_level': 30, 'num_cpus': None}>
<RayParallelBackend: {'address': None, 'logging_level': 30, '_temp_dir': None, \
'num_cpus': None}>
"""
try:
Expand Down
6 changes: 5 additions & 1 deletion src/pydvl/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@ def _utility(self, indices: FrozenSet) -> float:
model = self.model

# Special case for classification with only one class
if y_train.dtype == int and len(np.unique(y_train)) == 1:
if (
y_train.dtype == int
and len(np.unique(y_train.reshape(-1))) == 1
and self.scorer._name == "accuracy"
):
unique_cls = y_train[0]
all_cls, counts = np.unique(y_test, return_counts=True)
cls_idx = np.argwhere(all_cls == unique_cls)[0, 0]
Expand Down

0 comments on commit 3720cd7

Please sign in to comment.