Skip to content

Commit

Permalink
Experimental: Store history.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Oct 16, 2023
1 parent 239c3f3 commit ad71a0a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/pydvl/value/semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
from itertools import islice
from typing import Iterable, List, Optional, Protocol, Tuple, Type, cast

import numpy as np
import scipy as sp
from deprecate import deprecated
from tqdm import tqdm
Expand Down Expand Up @@ -182,6 +183,7 @@ def compute_generic_semivalues(
n_jobs: int = 1,
config: ParallelConfig = ParallelConfig(),
progress: bool = False,
log_folder: Optional[Path] = None,
) -> ValuationResult:
"""Computes semi-values for a given utility function and subset sampler.
Expand All @@ -204,6 +206,8 @@ def compute_generic_semivalues(
config: Object configuring parallel computation, with cluster
address, number of cpus, etc.
progress: Whether to display a progress bar.
log_folder: If set, the history of the valuation result is written to
`history.txt` in the specified folder.
Returns:
Object with the results.
Expand Down Expand Up @@ -246,6 +250,7 @@ def compute_generic_semivalues(

sampler_it = iter(sampler)
pbar = tqdm(disable=not progress, total=100, unit="%")
history = [] if log_folder else None

with init_executor(
max_workers=max_workers, config=config, cancel_futures=True
Expand All @@ -259,7 +264,13 @@ def compute_generic_semivalues(
for future in completed:
for idx, marginal in future.result():
result.update(idx, marginal)
if log_folder is not None:
history.append(result.values[result.indices])

if done(result):
if log_folder:
np.savetxt(log_folder / "history.txt", np.array(history))

return result

# Ensure that we always have n_submitted_jobs running
Expand Down
14 changes: 14 additions & 0 deletions src/pydvl/value/shapley/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@
"""
import logging
import numbers
import os
from concurrent.futures import FIRST_COMPLETED, Future, wait
from copy import copy
from pathlib import Path
from typing import Callable, Optional, Set, Tuple, Union, cast

import numpy as np
Expand Down Expand Up @@ -252,6 +254,7 @@ def compute_classwise_shapley_values(
config: ParallelConfig = ParallelConfig(),
progress: bool = False,
seed: Optional[Seed] = None,
log_folder: Optional[Path] = None,
) -> ValuationResult:
r"""
Computes an approximate Class-wise Shapley value by sampling independent
Expand Down Expand Up @@ -291,6 +294,8 @@ def compute_classwise_shapley_values(
config: Parallel configuration.
progress: Whether to display a progress bar.
seed: Either an instance of a numpy random number generator or a seed for it.
log_folder: If set, the history of the valuation result is written to
`history.txt` in the specified folder.
Returns:
ValuationResult object containing computed data values.
Expand Down Expand Up @@ -326,6 +331,7 @@ def compute_classwise_shapley_values(
)
terminate_exec = False
seed_sequence = ensure_seed_sequence(seed)
history = [] if log_folder else None

with init_executor(max_workers=n_jobs, config=config) as executor:
pending: Set[Future] = set()
Expand All @@ -335,6 +341,11 @@ def compute_classwise_shapley_values(
)
for future in completed_futures:
accumulated_result += future.result()
if log_folder is not None:
history.append(
accumulated_result.values[accumulated_result.indices]
)

if done(accumulated_result):
terminate_exec = True
break
Expand Down Expand Up @@ -363,6 +374,9 @@ def compute_classwise_shapley_values(
if normalize_values:
result = _normalize_classwise_shapley_values(result, u)

if log_folder:
np.savetxt(log_folder / "history.txt", np.array(history))

return result


Expand Down
12 changes: 12 additions & 0 deletions src/pydvl/value/shapley/montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from concurrent.futures import FIRST_COMPLETED, Future, wait
from functools import reduce
from itertools import cycle, takewhile
from pathlib import Path
from typing import Optional, Sequence, Union

import numpy as np
Expand Down Expand Up @@ -142,6 +143,7 @@ def permutation_montecarlo_shapley(
config: ParallelConfig = ParallelConfig(),
progress: bool = False,
seed: Seed = None,
log_folder: Optional[Path] = None,
) -> ValuationResult:
r"""Computes an approximate Shapley value by sampling independent
permutations of the index set, approximating the sum:
Expand Down Expand Up @@ -189,6 +191,8 @@ def permutation_montecarlo_shapley(
number of cpus, etc.
progress: Whether to display a progress bar.
seed: Either an instance of a numpy random number generator or a seed for it.
log_folder: If set, the history of the valuation result is written to
`history.txt` in the specified folder.
Returns:
Object with the data values.
Expand All @@ -206,6 +210,7 @@ def permutation_montecarlo_shapley(
)

pbar = tqdm(disable=not progress, total=100, unit="%")
history = [] if log_folder else None

with init_executor(
max_workers=max_workers, config=config, cancel_futures=CancellationPolicy.ALL
Expand All @@ -222,7 +227,14 @@ def permutation_montecarlo_shapley(
result += future.result()
# we could check outside the loop, but that means more
# submissions if the stopping criterion is unstable

if log_folder is not None:
history.append(result.values[result.indices])

if done(result):
if log_folder:
np.savetxt(log_folder / "history.txt", np.array(history))

return result

# Ensure that we always have n_submitted_jobs in the queue or running
Expand Down

0 comments on commit ad71a0a

Please sign in to comment.