Commit: Eliminate random aspects.

Markus Semmler committed Dec 19, 2023
1 parent b1d426b commit 0405ce3
Showing 8 changed files with 281 additions and 274 deletions.
4 changes: 1 addition & 3 deletions dvc.yaml
@@ -38,19 +38,17 @@ stages:
     matrix:
       experiment: ${active.experiments}
       dataset: ${active.datasets}
-      repetition: ${active.repetitions}
     cmd: >
       python -m scripts.sample_data
       --experiment-name ${item.experiment}
       --dataset-name ${item.dataset}
-      --repetition-id ${item.repetition}
     params:
       - experiments.${item.experiment}.sampler
     deps:
       - scripts/sample_data.py
       - output/preprocessed/${item.dataset}
     outs:
-      - output/sampled/${item.experiment}/${item.dataset}/${item.repetition}:
+      - output/sampled/${item.experiment}/${item.dataset}:
           persist: true
 
   calculate-values:
19 changes: 18 additions & 1 deletion params.yaml
@@ -27,13 +27,30 @@ active:
     - classwise_shapley
     - beta_shapley
     - tmc_shapley
+    - banzhaf_shapley
+    - owen_sampling_shapley
+    - least_core
   repetitions:
     - 1
     - 2
     - 3
     - 4
     - 5
-
+    - 6
+    - 7
+    - 8
+    - 9
+    - 10
+    - 11
+    - 12
+    - 13
+    - 14
+    - 15
+    - 16
+    - 17
+    - 18
+    - 19
+    - 20
 experiments:
   point_removal:
     sampler: default
487 changes: 241 additions & 246 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -20,7 +20,7 @@ torch = "^2.0.0"
 torchvision = "^0.15.1"
 openml = "^0.13.0"
 click = "^8.1.3"
-mlflow = "^2.6.0"
+mlflow = "^2.9.2"
 boto3 = "^1.28.36"
 plotly = "^5.16.1"
 dataframe_image = "*"
17 changes: 13 additions & 4 deletions scripts/calculate_values.py
@@ -18,10 +18,12 @@
 import os
 import pickle
 import time
+from dataclasses import asdict
 
 import click
 import numpy as np
-from pydvl.utils import Scorer, Utility
+from pydvl.utils import MemcachedConfig, Scorer, Utility
+from pymemcache.client import Client
 
 from re_classwise_shapley.io import Accessor
 from re_classwise_shapley.log import setup_logger
@@ -92,9 +94,13 @@ def _calculate_values(
         f"Values for {valuation_method_name} exist in '{output_dir}'. Skipping..."
     )
 
-    val_set = Accessor.datasets(experiment_name, dataset_name, repetition_id).loc[
-        0, "val_set"
-    ]
+    mc_config = MemcachedConfig()
+    mc = Client(**asdict(mc_config)["client_config"])
+    if dataset_name != mc.get("last_dataset", None):
+        mc.flush_all()
+        mc.set("last_dataset", dataset_name)
+
+    val_set = Accessor.datasets(experiment_name, dataset_name).loc[0, "val_set"]
 
     n_pipeline_step = 4
     seed = pipeline_seed(repetition_id, n_pipeline_step)
@@ -107,11 +113,14 @@ def _calculate_values(
 
     model_kwargs = params["models"][model_name]
     model = instantiate_model(model_name, model_kwargs, seed=int(sub_seeds[0]))
+
     u = Utility(
         data=val_set,
         model=model,
         scorer=Scorer("accuracy", default=0.0),
         catch_errors=True,
+        enable_cache=True,
+        cache_options=mc_config,
     )
 
     start_time = time.time()
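The hunks above route utility evaluations through a shared memcached instance and flush it whenever the pipeline moves to a new dataset, so cached values cannot leak across datasets. A minimal standalone sketch of that flush pattern, assuming a memcached server on the default localhost:11211 (the helper name and dataset name are illustrative, not from the commit); note that pymemcache returns raw bytes unless a deserializer is configured, so the stored value is decoded before comparison:

from pymemcache.client import Client

def flush_if_new_dataset(mc: Client, dataset_name: str) -> None:
    """Drop all cached utility values when the active dataset changes,
    so entries computed for one dataset never bleed into another."""
    last = mc.get("last_dataset")  # bytes, or None if the key is unset
    if last is None or last.decode("utf-8") != dataset_name:
        mc.flush_all()  # clear every key on this memcached server
        mc.set("last_dataset", dataset_name)

mc = Client(("localhost", 11211))
flush_if_new_dataset(mc, "cifar10")  # "cifar10" is a placeholder name

The Utility in the hunk above points at the same server via cache_options=mc_config, so the flush and the cache agree on which dataset is live.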
2 changes: 1 addition & 1 deletion scripts/evaluate_metrics.py
@@ -98,7 +98,7 @@ def _evaluate_metrics(
     values = Accessor.valuation_results(
         experiment_name, model_name, dataset_name, repetition_id, valuation_method_name
     ).loc[0, "valuation"]
-    dataset = Accessor.datasets(experiment_name, dataset_name, repetition_id).loc[0]
+    dataset = Accessor.datasets(experiment_name, dataset_name).loc[0]
     preprocess_info = dataset["preprocess_info"]
 
     os.makedirs(output_dir, exist_ok=True)
13 changes: 3 additions & 10 deletions scripts/sample_data.py
@@ -33,11 +33,9 @@
 @click.command()
 @click.option("--experiment-name", type=str, required=True)
 @click.option("--dataset-name", type=str, required=True)
-@click.option("--repetition-id", type=int, required=True)
 def sample_data(
     experiment_name: str,
     dataset_name: str,
-    repetition_id: int,
 ):
     """
     Samples a dataset from a preprocessed dataset. It accepts `experiment_name` and
@@ -51,29 +49,24 @@ def sample_data(
             `params.experiments` section.
         dataset_name: The name of the dataset to preprocess. As specified in the
             `params.datasets` section.
-        repetition_id: Repetition id of the experiment. It is used also as a seed for
-            all randomness.
     """
-    _sample_data(experiment_name, dataset_name, repetition_id)
+    _sample_data(experiment_name, dataset_name)
 
 
 def _sample_data(
     experiment_name: str,
     dataset_name: str,
-    repetition_id: int,
 ):
     params = load_params_fast()
     input_folder = Accessor.PREPROCESSED_PATH / dataset_name
-    output_dir = (
-        Accessor.SAMPLED_PATH / experiment_name / dataset_name / str(repetition_id)
-    )
+    output_dir = Accessor.SAMPLED_PATH / experiment_name / dataset_name
     if os.path.exists(output_dir / "val_set.pkl") and os.path.exists(
         output_dir / "test_set.pkl"
     ):
         return logger.info(f"Sampled data exists in '{output_dir}'. Skipping...")
 
     n_pipeline_step = 3
-    seed = pipeline_seed(repetition_id, n_pipeline_step)
+    seed = pipeline_seed(42, n_pipeline_step)
     seed_sequence = SeedSequence(seed).spawn(2)
 
     experiment_config = params["experiments"][experiment_name]
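With --repetition-id gone, the sampling step is seeded from the constant 42, so every repetition works on an identical sampled dataset; per-repetition randomness survives only in later steps such as calculate_values.py, which still calls pipeline_seed(repetition_id, ...). A small sketch of the resulting determinism, with a hypothetical stand-in for the repo-internal pipeline_seed helper (its real mixing logic is not shown in this diff):

from numpy.random import SeedSequence, default_rng

def pipeline_seed(seed: int, n_pipeline_step: int) -> int:
    # Hypothetical stand-in for the repo-internal helper: derive a
    # deterministic, step-specific seed from a base seed.
    return int(SeedSequence(seed, spawn_key=(n_pipeline_step,)).generate_state(1)[0])

# Fixed base seed: the sampling step now yields the same dataset on
# every invocation, independent of the repetition id.
seed = pipeline_seed(42, n_pipeline_step=3)
val_seed, test_seed = SeedSequence(seed).spawn(2)
print(default_rng(val_seed).integers(0, 100, size=3))  # identical on every run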
11 changes: 3 additions & 8 deletions src/re_classwise_shapley/io.py
@@ -220,32 +220,27 @@ def metrics_and_curves(
     def datasets(
         experiment_name: str,
         dataset_name: str,
-        repetition_id: int,
     ) -> Dict:
         """
         Load datasets from the specified directory.
 
         Args:
             experiment_name: The name of the experiment.
             dataset_name: The name of the dataset.
-            repetition_id: The repetition ID.
 
         Returns:
             A dictionary containing the loaded datasets and relevant information.
         """
-        base_path = (
-            Accessor.SAMPLED_PATH / experiment_name / dataset_name / str(repetition_id)
-        )
-        with open(base_path / f"val_set.pkl", "rb") as file:
+        base_path = Accessor.SAMPLED_PATH / experiment_name / dataset_name
+        with open(base_path / "val_set.pkl", "rb") as file:
             val_set = pickle.load(file)
 
-        with open(base_path / f"test_set.pkl", "rb") as file:
+        with open(base_path / "test_set.pkl", "rb") as file:
             test_set = pickle.load(file)
 
         row = {
             "experiment_name": experiment_name,
             "dataset_name": dataset_name,
-            "repetition_id": repetition_id,
             "val_set": val_set,
             "test_set": test_set,
         }
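For reference, a caller-side sketch of the slimmed-down loader, mirroring the evaluate_metrics.py hunk above (the experiment name comes from params.yaml; the dataset name is a placeholder):

from re_classwise_shapley.io import Accessor

# One sampled dataset per (experiment, dataset) pair; the repetition id
# no longer selects a subdirectory.
row = Accessor.datasets("point_removal", "cifar10").loc[0]
val_set, test_set = row["val_set"], row["test_set"]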
