Skip to content

Commit

Permalink
Refactor plots into benchmark repository.
Browse files Browse the repository at this point in the history
  • Loading branch information
kosmitive committed Feb 27, 2024
1 parent c1faa52 commit cbf67f3
Show file tree
Hide file tree
Showing 9 changed files with 356 additions and 71 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ experiments:
point_removal:
[...]
metrics:
weighted_accuracy_drop_logistic_regression:
accuracy_logistic_regression:
idx: weighted_metric_drop
metric: accuracy
eval_model: logistic_regression
Expand Down
94 changes: 65 additions & 29 deletions params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ settings:
plot_format: pdf

stages:
fetch_data: false
preprocess_data: false
sample_data: false
calculate_values: false
calculate_threshold_characteristics: false
fetch_data: true
preprocess_data: true
sample_data: true
calculate_values: true
calculate_threshold_characteristics: true
evaluate_curves: true
evaluate_metrics: true
render_plots: true
Expand Down Expand Up @@ -59,62 +59,84 @@ experiments:
point_removal:
sampler: default
curves:
weighted_accuracy_drop_logistic_regression:
accuracy_logistic_regression:
fn: metric
metric: accuracy
eval_model: logistic_regression
plots:
- accuracy
weighted_accuracy_drop_knn:
accuracy_knn:
fn: metric
metric: accuracy
eval_model: knn
plots:
- accuracy
weighted_accuracy_drop_gradient_boosting_classifier:
accuracy_gradient_boosting_classifier:
fn: metric
metric: accuracy
eval_model: gradient_boosting_classifier
plots:
- accuracy
weighted_accuracy_drop_svm:
accuracy_svm:
fn: metric
metric: accuracy
eval_model: svm
plots:
- accuracy
weighted_accuracy_drop_mlp:
accuracy_mlp:
fn: metric
metric: accuracy
eval_model: mlp
plots:
- accuracy

top_fraction:
fn: top_fraction
alpha_range:
from: 0.01
to: 0.05
step: 0.01
plots:
# - rank_stability

metrics:
geometric_weighted_drop:
weighted_relative_accuracy_difference_random:
curve:
- weighted_accuracy_drop_logistic_regression
- weighted_accuracy_drop_knn
- weighted_accuracy_drop_gradient_boosting_classifier
- weighted_accuracy_drop_svm
- weighted_accuracy_drop_mlp
fn: geometric_weighted_drop
input_perc: 1.0
- accuracy_logistic_regression
- accuracy_knn
- accuracy_gradient_boosting_classifier
- accuracy_svm
- accuracy_mlp
lamb: 0.1
fn: weighted_relative_accuracy_difference_random
plots:
- table_wad
- box_wad
- table
- box_wrad

geometric_weighted_drop_half:
curve:
- weighted_accuracy_drop_logistic_regression
- weighted_accuracy_drop_knn
- weighted_accuracy_drop_gradient_boosting_classifier
- weighted_accuracy_drop_svm
- weighted_accuracy_drop_mlp
- accuracy_logistic_regression
- accuracy_knn
- accuracy_gradient_boosting_classifier
- accuracy_svm
- accuracy_mlp
fn: geometric_weighted_drop
input_perc: 0.5
plots:
- table_wad
- table
- box_wad

geometric_weighted_drop:
curve:
- accuracy_logistic_regression
- accuracy_knn
- accuracy_gradient_boosting_classifier
- accuracy_svm
- accuracy_mlp
fn: geometric_weighted_drop
input_perc: 1.0
plots:
- table
- box_wad

noise_removal:
Expand All @@ -138,22 +160,36 @@ experiments:
- box_auc

plots:

density:

rank_stability:
type: line
mean_agg: intersect
x_label: "n"
y_label: "Accuracy"

accuracy:
type: line
mean_agg: mean
std_agg: bootstrap
plot_perc: 0.5
x_label: "n"
y_label: "Accuracy"

precision_recall:
type: line
mean_agg: mean
std_agg: bootstrap
x_label: "Recall"
y_label: "Precision"

table_wad:
table:
type: table

table_auc:
type: table
box_wrad:
type: boxplot
x_label: "WRAD"

box_wad:
type: boxplot
Expand Down
3 changes: 1 addition & 2 deletions scripts/evaluate_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from functools import partial, reduce

import click
import pandas as pd
from pydvl.parallel import ParallelConfig
from pydvl.utils.functional import maybe_add_argument

Expand Down Expand Up @@ -148,7 +147,7 @@ def _evaluate_curves(
cache=cache,
)

metric_curve.to_csv(output_dir / f"{curve_name}.curve.csv")
metric_curve.to_csv(output_dir / f"{curve_name}.csv")


if __name__ == "__main__":
Expand Down
29 changes: 27 additions & 2 deletions scripts/evaluate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@

import click
import pandas as pd
from pydvl.utils.functional import free_arguments, maybe_add_argument

from re_classwise_shapley.io import Accessor
from re_classwise_shapley.log import setup_logger
from re_classwise_shapley.metric import MetricsRegistry
from re_classwise_shapley.requests import FunctionalCurveRequest
from re_classwise_shapley.utils import load_params_fast

logger = setup_logger("evaluate_metrics")
Expand Down Expand Up @@ -95,7 +97,18 @@ def _evaluate_metrics(
metric_fn = metrics_kwargs.pop("fn")
metrics_kwargs.pop("plots", None)
curve_names = metrics_kwargs.pop("curve")
metrics_fn = partial(MetricsRegistry[metric_fn], **metrics_kwargs)
repetitions = get_active_repetitions(params)

fn = MetricsRegistry[metric_fn]
requests = []
if "random_base_line" in free_arguments(fn):
requests.append(
FunctionalCurveRequest("random_base_line", "random", repetitions)
)

fn = maybe_add_argument(fn, "random_base_line")
metrics_fn = partial(fn, **metrics_kwargs)

os.makedirs(output_dir, exist_ok=True)
curves = list(
Accessor.curves(
Expand All @@ -114,12 +127,24 @@ def _evaluate_metrics(
if os.path.exists(output_dir / f"{metric_name}.{curve_name}.csv"):
continue

metric = metrics_fn(curve)
extra_kwargs = {
request.arg_name: request.request(
experiment_name, model_name, dataset_name, curve_names
)
for request in requests
}
metric = metrics_fn(curve, **extra_kwargs)
evaluated_metrics = pd.Series([metric])
evaluated_metrics.name = "value"
evaluated_metrics.index.name = "metric"
evaluated_metrics.to_csv(output_dir / f"{metric_name}.{curve_name}.csv")


def get_active_repetitions(params):
    """Return the list of active repetition indices declared in *params*.

    Reads ``params["active"]["repetitions"]``, a mapping with inclusive
    ``from``/``to`` bounds, and expands it into an explicit list of ints.
    """
    bounds = params["active"]["repetitions"]
    return list(range(bounds["from"], bounds["to"] + 1))


if __name__ == "__main__":
evaluate_metrics()
41 changes: 26 additions & 15 deletions scripts/render_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
plot_metric_table,
plot_threshold_characteristics,
plot_time,
plot_value_decay,
)
from re_classwise_shapley.utils import (
flatten_dict,
Expand Down Expand Up @@ -103,6 +104,26 @@ def _render_plots(experiment_name: str, model_name: str):

params = load_params_fast()
plot_format = params["settings"]["plot_format"]

logger.info(f"Load valuations results.")
valuation_results = Accessor.valuation_results(
experiment_name,
model_name,
dataset_names,
repetitions,
method_names,
)
logger.info(f"Plotting value decay for all methods.")
with plot_value_decay(valuation_results, method_names) as fig:
log_figure(fig, output_folder, f"decay.{plot_format}", "values")

for method_name in method_names:
logger.info(f"Plot histogram for values of method `{method_name}`.")
with plot_histogram(valuation_results, [method_name]) as fig:
log_figure(
fig, output_folder, f"density.{method_name}.{plot_format}", "values"
)

threshold_characteristics_settings = params["settings"][
"threshold_characteristics"
]
Expand All @@ -126,20 +147,6 @@ def _render_plots(experiment_name: str, model_name: str):
f"threshold_characteristics.{plot_format}",
"threshold_characteristics",
)
logger.info(f"Load valuations results.")
valuation_results = Accessor.valuation_results(
experiment_name,
model_name,
dataset_names,
repetitions,
method_names,
)
for method_name in method_names:
logger.info(f"Plot histogram for values of method `{method_name}`.")
with plot_histogram(valuation_results, [method_name]) as fig:
log_figure(
fig, output_folder, f"density.{method_name}.{plot_format}", "densities"
)

params = load_params_fast()
time_settings = params["settings"]["time"]
Expand Down Expand Up @@ -173,14 +180,18 @@ def _render_plots(experiment_name: str, model_name: str):
plot_perc = plot_settings.get("plot_perc", 1.0)
x_label = plot_settings.get("x_label", None)
y_label = plot_settings.get("y_label", None)
agg = plot_settings.get("agg", "mean")
with plot_curves(
selected_loaded_curves,
plot_perc=plot_perc,
x_label=x_label,
y_label=y_label,
) as fig:
log_figure(
fig, output_folder, f"{curve_name}.{plot_format}", "curves"
fig,
output_folder,
f"{curve_name}.{plot_format}",
"curves",
)
case _:
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions src/re_classwise_shapley/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ class Accessor:
THRESHOLD_CHARACTERISTICS_PATH = OUTPUT_PATH / "threshold_characteristics"
SAMPLED_PATH = OUTPUT_PATH / "sampled"
VALUES_PATH = OUTPUT_PATH / "values"
CURVES_PATH = OUTPUT_PATH / "results"
CURVES_PATH = OUTPUT_PATH / "curves"
METRICS_PATH = OUTPUT_PATH / "metrics"
PLOT_PATH = OUTPUT_PATH / "plots"

Expand Down Expand Up @@ -340,7 +340,7 @@ def curves(
/ str(repetition_id)
/ method_name
)
curve = pd.read_csv(base_path / f"{curve_name}.curve.csv")
curve = pd.read_csv(base_path / f"{curve_name}.csv")
curve.index = curve[curve.columns[0]]
curve = curve.drop(columns=[curve.columns[0]]).iloc[:, -1]

Expand Down
Loading

0 comments on commit cbf67f3

Please sign in to comment.