diff --git a/dvc.yaml b/dvc.yaml index 5bb7a802..8017ed1a 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -72,8 +72,8 @@ stages: experiment: ${active.experiments} dataset: ${active.datasets} model: ${active.models} - repetition: ${active.repetitions} method: ${active.valuation_methods} + repetition: ${active.repetitions} cmd: > python -m scripts.calculate_values --experiment-name ${item.experiment} diff --git a/params.yaml b/params.yaml index 5b8e4307..b8f9c98c 100644 --- a/params.yaml +++ b/params.yaml @@ -24,12 +24,9 @@ active: - phoneme valuation_methods: - loo - - classwise_shapley - beta_shapley - tmc_shapley - - banzhaf_shapley - - owen_sampling_shapley - - least_core + - classwise_shapley repetitions: - 1 - 2 diff --git a/scripts/calculate_values.py b/scripts/calculate_values.py index a27406f2..6f11d8c3 100644 --- a/scripts/calculate_values.py +++ b/scripts/calculate_values.py @@ -96,9 +96,29 @@ def _calculate_values( mc_config = MemcachedConfig() mc = Client(**asdict(mc_config)["client_config"]) - if dataset_name != mc.get("last_dataset", None): + last_run = mc.get("last_run", None) + if last_run is None or ( + experiment_name != last_run["experiment"] + or dataset_name != last_run["dataset"] + or model_name != last_run["model"] + or ( + valuation_method_name != last_run["method"] + and ( + valuation_method_name == "classwise_shapley" + or last_run["last_method"] == "classwise_shapley" + ) + ) + ): mc.flush_all() - mc.set("last_dataset", dataset_name) + mc.set( + "last_run", + { + "experiment": experiment_name, + "dataset": dataset_name, + "model": model_name, + "method": valuation_method_name, + }, + ) val_set = Accessor.datasets(experiment_name, dataset_name).loc[0, "val_set"]