Skip to content

Commit

Permalink
Optimize caching.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Dec 19, 2023
1 parent 375bc93 commit 037e28e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion dvc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ stages:
experiment: ${active.experiments}
dataset: ${active.datasets}
model: ${active.models}
repetition: ${active.repetitions}
method: ${active.valuation_methods}
repetition: ${active.repetitions}
cmd: >
python -m scripts.calculate_values
--experiment-name ${item.experiment}
Expand Down
5 changes: 1 addition & 4 deletions params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@ active:
- phoneme
valuation_methods:
- loo
- classwise_shapley
- beta_shapley
- tmc_shapley
- banzhaf_shapley
- owen_sampling_shapley
- least_core
- classwise_shapley
repetitions:
- 1
- 2
Expand Down
24 changes: 22 additions & 2 deletions scripts/calculate_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,29 @@ def _calculate_values(

mc_config = MemcachedConfig()
mc = Client(**asdict(mc_config)["client_config"])
if dataset_name != mc.get("last_dataset", None):
last_run = mc.get("last_run", None)
if last_run is None or (
experiment_name != last_run["experiment"]
or dataset_name != last_run["dataset"]
or model_name != last_run["model"]
or (
valuation_method_name != last_run["method"]
and (
valuation_method_name == "classwise_shapley"
or last_run["last_method"] == "classwise_shapley"
)
)
):
mc.flush_all()
mc.set("last_dataset", dataset_name)
mc.set(
"last_run",
{
"experiment": experiment_name,
"dataset": dataset_name,
"model": model_name,
"method": valuation_method_name,
},
)

val_set = Accessor.datasets(experiment_name, dataset_name).loc[0, "val_set"]

Expand Down

0 comments on commit 037e28e

Please sign in to comment.