From 81e8cb4d899f33ab2e9672bfa45e6775b52e1a8d Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Mon, 29 Jan 2024 15:05:46 +0100 Subject: [PATCH] Add cache group field to cache prefix. --- scripts/calculate_values.py | 11 +++++++---- scripts/evaluate_metrics.py | 17 +++++++++++------ src/re_classwise_shapley/cache.py | 4 ++-- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/scripts/calculate_values.py b/scripts/calculate_values.py index e4dfbd06..ffd0ed03 100644 --- a/scripts/calculate_values.py +++ b/scripts/calculate_values.py @@ -18,7 +18,6 @@ import os import pickle import time -from pathlib import Path import click import numpy as np @@ -95,9 +94,13 @@ def _calculate_values( ) params = load_params_fast() - cache = PrefixedMemcachedCacheBackend( - config=MemcachedClientConfig(), prefix=f"{experiment_name}/{dataset_name}" - ) + cache = None + if "cache_group" in params["valuation_methods"][valuation_method_name]: + cache_group = params["valuation_methods"][valuation_method_name]["cache_group"] + prefix = f"{experiment_name}/{dataset_name}/{cache_group}" + cache = PrefixedMemcachedCacheBackend( + config=MemcachedClientConfig(), prefix=prefix + ) val_set = Accessor.datasets(experiment_name, dataset_name).loc[0, "val_set"] diff --git a/scripts/evaluate_metrics.py b/scripts/evaluate_metrics.py index af995d59..d5e5be57 100644 --- a/scripts/evaluate_metrics.py +++ b/scripts/evaluate_metrics.py @@ -13,16 +13,14 @@ directory. The metrics are usually stored as `*.csv` files. Each metric consists of a single value and a curve. The curve is stored as `*.curve.csv` file. """ - import logging import os from functools import partial, reduce -from pathlib import Path import click import pandas as pd from pydvl.parallel import ParallelConfig -from pydvl.utils import DiskCacheBackend, MemcachedClientConfig +from pydvl.utils import MemcachedClientConfig from pydvl.utils.functional import maybe_add_argument from re_classwise_shapley.cache import PrefixedMemcachedCacheBackend @@ -127,9 +125,16 @@ def _evaluate_metrics( n_pipeline_step = 5 seed = pipeline_seed(repetition_id, n_pipeline_step) - cache = PrefixedMemcachedCacheBackend( - config=MemcachedClientConfig(), prefix=f"{experiment_name}/{dataset_name}" - ) + cache = None + if ( + "eval_model" in metric_kwargs + and "cache_group" in params["valuation_methods"][valuation_method_name] + ): + cache_group = params["valuation_methods"][valuation_method_name]["cache_group"] + prefix = f"{experiment_name}/{dataset_name}/{cache_group}" + cache = PrefixedMemcachedCacheBackend( + config=MemcachedClientConfig(), prefix=prefix + ) logger.info("Evaluating metric...") with n_threaded(n_threads=1): diff --git a/src/re_classwise_shapley/cache.py b/src/re_classwise_shapley/cache.py index a5d3db8b..ff83dd96 100644 --- a/src/re_classwise_shapley/cache.py +++ b/src/re_classwise_shapley/cache.py @@ -4,9 +4,9 @@ class PrefixedMemcachedCacheBackend(MemcachedCacheBackend): - def __init__(self, *args, prefix: str, **kwargs): + def __init__(self, *args, **kwargs): + self._prefix = kwargs["prefix"] super().__init__(*args, **kwargs) - self._prefix = prefix def get(self, key: str) -> Optional[Any]: return super().get(f"{self._prefix}/{key}")