Skip to content

Commit

Permalink
Fix caching problem due to execution order.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Dec 20, 2023
1 parent 244d8d6 commit a7cdf5c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
4 changes: 3 additions & 1 deletion scripts/calculate_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def _calculate_values(
mc = Client(**asdict(mc_config)["client_config"])
last_run = mc.get("last_run", None)
if last_run is None or (
experiment_name != last_run["experiment"]
"calculate_values" != last_run.get("stage", "")
or experiment_name != last_run["experiment"]
or dataset_name != last_run["dataset"]
or model_name != last_run["model"]
or (
Expand All @@ -113,6 +114,7 @@ def _calculate_values(
mc.set(
"last_run",
{
"stage": "calculate_values",
"experiment": experiment_name,
"dataset": dataset_name,
"model": model_name,
Expand Down
30 changes: 30 additions & 0 deletions scripts/evaluate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,36 @@ def _evaluate_metrics(
n_pipeline_step = 5
seed = pipeline_seed(repetition_id, n_pipeline_step)

mc_config = MemcachedConfig()
mc = Client(**asdict(mc_config)["client_config"])
last_run = mc.get("last_run", None)
if last_run is None or (
"evaluate_metrics" != last_run.get("stage", "")
or experiment_name != last_run["experiment"]
or dataset_name != last_run["dataset"]
or model_name != last_run["model"]
or metric_name != last_run["metric"]
or (
valuation_method_name != last_run["method"]
and (
valuation_method_name == "classwise_shapley"
or last_run["method"] == "classwise_shapley"
)
)
):
mc.flush_all()
mc.set(
"last_run",
{
"stage": "evaluate_metrics",
"experiment": experiment_name,
"dataset": dataset_name,
"model": model_name,
"method": valuation_method_name,
"metric": metric_name,
},
)

logger.info("Evaluating metric...")
with n_threaded(n_threads=1):
metric_values, metric_curve = metric_fn(
Expand Down
40 changes: 25 additions & 15 deletions scripts/run_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,31 @@ def run_pipeline():
repetition_id,
)

for metric_name in params["experiments"][experiment_name]["metrics"].keys():
logger.info(
f"Calculate metric {metric_name} for dataset {dataset_name}, "
f"valuation method {valuation_method_name} and seed "
f"{repetition_id}."
)
logger.info(f"Evaluate metric {metric_name}.")
_evaluate_metrics(
experiment_name,
dataset_name,
model_name,
valuation_method_name,
repetition_id,
metric_name,
)
for (
dataset_name,
metric_name,
valuation_method_name,
repetition_id,
) in product(
active_params["datasets"],
params["experiments"][experiment_name]["metrics"].keys(),
active_params["valuation_methods"],
active_params["repetitions"],
):
logger.info(
f"Calculate metric {metric_name} for dataset {dataset_name}, "
f"valuation method {valuation_method_name} and seed "
f"{repetition_id}."
)
logger.info(f"Evaluate metric {metric_name}.")
_evaluate_metrics(
experiment_name,
dataset_name,
model_name,
valuation_method_name,
repetition_id,
metric_name,
)

logger.info(f"Render plots for {experiment_name} and {model_name}.")
_render_plots(experiment_name, model_name)
Expand Down

0 comments on commit a7cdf5c

Please sign in to comment.