Skip to content

Commit

Permalink
Add a re-nameable cache
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Jan 29, 2024
1 parent 89fe76f commit 8c7e587
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 23 deletions.
13 changes: 5 additions & 8 deletions scripts/calculate_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

import click
import numpy as np
from pydvl.utils import DiskCacheBackend, Scorer, Utility
from pydvl.utils import MemcachedClientConfig, Scorer, Utility

from re_classwise_shapley.cache import PrefixedMemcachedCacheBackend
from re_classwise_shapley.io import Accessor
from re_classwise_shapley.log import setup_logger
from re_classwise_shapley.model import instantiate_model
Expand Down Expand Up @@ -94,13 +95,9 @@ def _calculate_values(
)

params = load_params_fast()

cache = None
if "cache_group" in params["valuation_methods"][valuation_method_name]:
cache_group = params["valuation_methods"][valuation_method_name]["cache_group"]
cache = DiskCacheBackend(
Path(".cache") / experiment_name / dataset_name / model_name / cache_group
)
cache = PrefixedMemcachedCacheBackend(
config=MemcachedClientConfig(), prefix=f"{experiment_name}/{dataset_name}"
)

val_set = Accessor.datasets(experiment_name, dataset_name).loc[0, "val_set"]

Expand Down
20 changes: 5 additions & 15 deletions scripts/evaluate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
import click
import pandas as pd
from pydvl.parallel import ParallelConfig
from pydvl.utils import DiskCacheBackend
from pydvl.utils import DiskCacheBackend, MemcachedClientConfig
from pydvl.utils.functional import maybe_add_argument

from re_classwise_shapley.cache import PrefixedMemcachedCacheBackend
from re_classwise_shapley.io import Accessor
from re_classwise_shapley.log import setup_logger
from re_classwise_shapley.metric import MetricRegistry
Expand Down Expand Up @@ -126,20 +127,9 @@ def _evaluate_metrics(

n_pipeline_step = 5
seed = pipeline_seed(repetition_id, n_pipeline_step)

cache = None
if (
"eval_model" in metric_kwargs
and "cache_group" in params["valuation_methods"][valuation_method_name]
):
cache_group = params["valuation_methods"][valuation_method_name]["cache_group"]
cache = DiskCacheBackend(
Path(".cache")
/ experiment_name
/ dataset_name
/ metric_kwargs["eval_model"]
/ cache_group
)
cache = PrefixedMemcachedCacheBackend(
config=MemcachedClientConfig(), prefix=f"{experiment_name}/{dataset_name}"
)

logger.info("Evaluating metric...")
with n_threaded(n_threads=1):
Expand Down
15 changes: 15 additions & 0 deletions src/re_classwise_shapley/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Any, Optional

from pydvl.utils import MemcachedCacheBackend


class PrefixedMemcachedCacheBackend(MemcachedCacheBackend):
    """Memcached cache backend that namespaces every key with a fixed prefix.

    Each key passed to :meth:`get` or :meth:`set` is rewritten to
    ``"<prefix>/<key>"`` before it is handed to the parent backend, so
    callers constructed with different prefixes address different keys.

    NOTE(review): if instances get pickled, confirm that the parent's state
    restoration keeps ``_prefix`` -- this cannot be verified from this file.
    """

    def __init__(self, *args, prefix: str, **kwargs):
        """Initialise the parent backend, then remember the key prefix.

        Args:
            *args: Positional arguments forwarded unchanged to the parent
                constructor.
            prefix: Keyword-only string placed, followed by ``/``, in front
                of every key.
            **kwargs: Keyword arguments forwarded unchanged to the parent
                constructor.
        """
        super().__init__(*args, **kwargs)
        self._prefix = prefix

    def _namespaced(self, key: str) -> str:
        """Return ``key`` with the configured prefix and a ``/`` prepended."""
        return f"{self._prefix}/{key}"

    def get(self, key: str) -> Optional[Any]:
        """Look up ``key`` under this backend's prefix.

        Args:
            key: Un-prefixed cache key.

        Returns:
            Whatever the parent backend's ``get`` returns for the prefixed
            key.
        """
        full_key = self._namespaced(key)
        return super().get(full_key)

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` for ``key`` under this backend's prefix.

        Args:
            key: Un-prefixed cache key.
            value: Object passed on to the parent backend's ``set``.
        """
        full_key = self._namespaced(key)
        super().set(full_key, value)

0 comments on commit 8c7e587

Please sign in to comment.