diff --git a/scripts/calculate_values.py b/scripts/calculate_values.py index 7bf87d6f4..e4dfbd069 100644 --- a/scripts/calculate_values.py +++ b/scripts/calculate_values.py @@ -22,8 +22,9 @@ import click import numpy as np -from pydvl.utils import DiskCacheBackend, Scorer, Utility +from pydvl.utils import MemcachedClientConfig, Scorer, Utility +from re_classwise_shapley.cache import PrefixedMemcachedCacheBackend from re_classwise_shapley.io import Accessor from re_classwise_shapley.log import setup_logger from re_classwise_shapley.model import instantiate_model @@ -94,13 +95,9 @@ def _calculate_values( ) params = load_params_fast() - - cache = None - if "cache_group" in params["valuation_methods"][valuation_method_name]: - cache_group = params["valuation_methods"][valuation_method_name]["cache_group"] - cache = DiskCacheBackend( - Path(".cache") / experiment_name / dataset_name / model_name / cache_group - ) + cache = PrefixedMemcachedCacheBackend( + config=MemcachedClientConfig(), prefix=f"{experiment_name}/{dataset_name}" + ) val_set = Accessor.datasets(experiment_name, dataset_name).loc[0, "val_set"] diff --git a/scripts/evaluate_metrics.py b/scripts/evaluate_metrics.py index 58c8f45b5..af995d593 100644 --- a/scripts/evaluate_metrics.py +++ b/scripts/evaluate_metrics.py @@ -22,9 +22,10 @@ import click import pandas as pd from pydvl.parallel import ParallelConfig -from pydvl.utils import DiskCacheBackend +from pydvl.utils import DiskCacheBackend, MemcachedClientConfig from pydvl.utils.functional import maybe_add_argument +from re_classwise_shapley.cache import PrefixedMemcachedCacheBackend from re_classwise_shapley.io import Accessor from re_classwise_shapley.log import setup_logger from re_classwise_shapley.metric import MetricRegistry @@ -126,20 +127,9 @@ def _evaluate_metrics( n_pipeline_step = 5 seed = pipeline_seed(repetition_id, n_pipeline_step) - - cache = None - if ( - "eval_model" in metric_kwargs - and "cache_group" in params["valuation_methods"][valuation_method_name] - ): - cache_group = params["valuation_methods"][valuation_method_name]["cache_group"] - cache = DiskCacheBackend( - Path(".cache") - / experiment_name - / dataset_name - / metric_kwargs["eval_model"] - / cache_group - ) + cache = PrefixedMemcachedCacheBackend( + config=MemcachedClientConfig(), prefix=f"{experiment_name}/{dataset_name}" + ) logger.info("Evaluating metric...") with n_threaded(n_threads=1): diff --git a/src/re_classwise_shapley/cache.py b/src/re_classwise_shapley/cache.py new file mode 100644 index 000000000..a5d3db8bf --- /dev/null +++ b/src/re_classwise_shapley/cache.py @@ -0,0 +1,15 @@ +from typing import Any, Optional + +from pydvl.utils import MemcachedCacheBackend + + +class PrefixedMemcachedCacheBackend(MemcachedCacheBackend): + def __init__(self, *args, prefix: str, **kwargs): + super().__init__(*args, **kwargs) + self._prefix = prefix + + def get(self, key: str) -> Optional[Any]: + return super().get(f"{self._prefix}/{key}") + + def set(self, key: str, value: Any) -> None: + super().set(f"{self._prefix}/{key}", value)