diff --git a/scripts/calculate_values.py b/scripts/calculate_values.py
index 8a15a178..e60d40d7 100644
--- a/scripts/calculate_values.py
+++ b/scripts/calculate_values.py
@@ -98,7 +98,8 @@ def _calculate_values(
     mc = Client(**asdict(mc_config)["client_config"])
     last_run = mc.get("last_run", None)
     if last_run is None or (
-        experiment_name != last_run["experiment"]
+        "calculate_values" != last_run.get("stage", "")
+        or experiment_name != last_run["experiment"]
         or dataset_name != last_run["dataset"]
         or model_name != last_run["model"]
         or (
@@ -113,6 +114,7 @@ def _calculate_values(
         mc.set(
             "last_run",
             {
+                "stage": "calculate_values",
                 "experiment": experiment_name,
                 "dataset": dataset_name,
                 "model": model_name,
diff --git a/scripts/evaluate_metrics.py b/scripts/evaluate_metrics.py
index 57cd20dd..ca330e76 100644
--- a/scripts/evaluate_metrics.py
+++ b/scripts/evaluate_metrics.py
@@ -125,6 +125,36 @@ def _evaluate_metrics(
     n_pipeline_step = 5
     seed = pipeline_seed(repetition_id, n_pipeline_step)
 
+    mc_config = MemcachedConfig()
+    mc = Client(**asdict(mc_config)["client_config"])
+    last_run = mc.get("last_run", None)
+    if last_run is None or (
+        "evaluate_metrics" != last_run.get("stage", "")
+        or experiment_name != last_run["experiment"]
+        or dataset_name != last_run["dataset"]
+        or model_name != last_run["model"]
+        or metric_name != last_run["metric"]
+        or (
+            valuation_method_name != last_run["method"]
+            and (
+                valuation_method_name == "classwise_shapley"
+                or last_run["method"] == "classwise_shapley"
+            )
+        )
+    ):
+        mc.flush_all()
+        mc.set(
+            "last_run",
+            {
+                "stage": "evaluate_metrics",
+                "experiment": experiment_name,
+                "dataset": dataset_name,
+                "model": model_name,
+                "method": valuation_method_name,
+                "metric": metric_name,
+            },
+        )
+
     logger.info("Evaluating metric...")
     with n_threaded(n_threads=1):
         metric_values, metric_curve = metric_fn(
diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py
index a1af813b..1dfc4e81 100644
--- a/scripts/run_pipeline.py
+++ b/scripts/run_pipeline.py
@@ -93,21 +93,31 @@ def run_pipeline():
                 repetition_id,
             )
 
-            for metric_name in params["experiments"][experiment_name]["metrics"].keys():
-                logger.info(
-                    f"Calculate metric {metric_name} for dataset {dataset_name}, "
-                    f"valuation method {valuation_method_name} and seed "
-                    f"{repetition_id}."
-                )
-                logger.info(f"Evaluate metric {metric_name}.")
-                _evaluate_metrics(
-                    experiment_name,
-                    dataset_name,
-                    model_name,
-                    valuation_method_name,
-                    repetition_id,
-                    metric_name,
-                )
+        for (
+            dataset_name,
+            metric_name,
+            valuation_method_name,
+            repetition_id,
+        ) in product(
+            active_params["datasets"],
+            params["experiments"][experiment_name]["metrics"].keys(),
+            active_params["valuation_methods"],
+            active_params["repetitions"],
+        ):
+            logger.info(
+                f"Calculate metric {metric_name} for dataset {dataset_name}, "
+                f"valuation method {valuation_method_name} and seed "
+                f"{repetition_id}."
+            )
+            logger.info(f"Evaluate metric {metric_name}.")
+            _evaluate_metrics(
+                experiment_name,
+                dataset_name,
+                model_name,
+                valuation_method_name,
+                repetition_id,
+                metric_name,
+            )
 
         logger.info(f"Render plots for {experiment_name} and {model_name}.")
         _render_plots(experiment_name, model_name)
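For context, the first two hunks implement the same guard in both stages: each stage fingerprints its run parameters (now including a `"stage"` field) in memcached under `last_run`, flushes the whole cache whenever the fingerprint changes, and records the new fingerprint. Below is a minimal standalone sketch of that pattern, assuming a pymemcache-style client and a memcached instance on localhost:11211; the `flush_if_stale` helper and the example parameter values are illustrative and not part of this patch, and it uses plain dict equality instead of the patch's classwise_shapley special case on method changes.

```python
from pymemcache import serde
from pymemcache.client.base import Client


def flush_if_stale(mc: Client, stage: str, **run_params: str) -> None:
    """Flush every cached value when the run fingerprint changed."""
    fingerprint = {"stage": stage, **run_params}
    last_run = mc.get("last_run", None)
    if last_run != fingerprint:  # also true when nothing is stored yet
        mc.flush_all()  # drop cached values from the previous run
        mc.set("last_run", fingerprint)  # record the new fingerprint


if __name__ == "__main__":
    # pickle_serde lets us store the fingerprint dict directly.
    mc = Client(("localhost", 11211), serde=serde.pickle_serde)
    flush_if_stale(
        mc,
        stage="evaluate_metrics",
        experiment="example_experiment",
        dataset="example_dataset",
        model="example_model",
        method="classwise_shapley",
        metric="example_metric",
    )
```

One consequence of keeping everything under a single `last_run` key, in the sketch as in the patch: alternating between two stages or two configurations flushes the cache on every switch, so caching only pays off for repeated runs of the same stage with the same parameters.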