From e6ae5681faa04e7e17c12c6a44ffc2f86a435e0e Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Thu, 15 Feb 2024 02:17:06 +0100 Subject: [PATCH] Compute precision-recall over the full ranking, replace the ">,<" threshold characteristic with ">,>", and update plot sizes and params. --- params.yaml | 6 +++--- src/re_classwise_shapley/metric.py | 12 ++++++------ src/re_classwise_shapley/plotting.py | 5 +++-- src/re_classwise_shapley/utils.py | 5 ++--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/params.yaml b/params.yaml index 44082f6c..31e8e24c 100644 --- a/params.yaml +++ b/params.yaml @@ -15,7 +15,7 @@ settings: time: - active: false + active: true threshold_characteristics: active: false @@ -47,9 +47,9 @@ active: - classwise_shapley - tmc_shapley - beta_shapley - - banzhaf_shapley + - loo repetitions: - from: 1 + from: 6 to: 20 experiments: diff --git a/src/re_classwise_shapley/metric.py b/src/re_classwise_shapley/metric.py index 9d734d6a..64251e75 100644 --- a/src/re_classwise_shapley/metric.py +++ b/src/re_classwise_shapley/metric.py @@ -136,15 +136,15 @@ def _curve_precision_recall_ranking( the ranking. The index of the series is the recall and the values are the corresponding precision values. 
""" - precision, recall = np.zeros(len(ranked_list) - 2), np.zeros(len(ranked_list) - 2) - for idx in range(2, len(ranked_list)): + precision, recall = np.zeros(len(ranked_list)), np.zeros(len(ranked_list)) + for idx in range(len(ranked_list)): partial_list = ranked_list[: idx + 1] intersection = list(set(target_list) & set(partial_list)) - precision[idx - 2] = float(len(intersection) / len(partial_list)) - recall[idx - 2] = float(len(intersection) / len(target_list)) + precision[idx] = float(len(intersection) / np.maximum(1, len(partial_list))) + recall[idx] = float(len(intersection) / np.maximum(1, len(target_list))) - graph = pd.Series(precision, index=recall) - graph = graph[~graph.index.duplicated(keep="first")] + graph = pd.DataFrame({"precision": precision, "recall": recall}) + graph = graph.groupby("recall")["precision"].mean() graph = graph.sort_index(ascending=True) graph.index.name = "recall" graph.name = "precision" diff --git a/src/re_classwise_shapley/plotting.py b/src/re_classwise_shapley/plotting.py index 4f4b798c..121edd5f 100644 --- a/src/re_classwise_shapley/plotting.py +++ b/src/re_classwise_shapley/plotting.py @@ -286,7 +286,7 @@ def plot_histogram_func( @contextmanager def plot_time( data: pd.DataFrame, - patch_size: Tuple[float, float] = (5, 5), + patch_size: Tuple[float, float] = (5, 4), n_cols: int = 5, ) -> plt.Figure: """ @@ -355,6 +355,7 @@ def plot_curves_func(data: pd.DataFrame, ax: plt.Axes, **kwargs): mean_color, shade_color = COLORS[color_name] results = pd.concat(method_data["curve"].tolist(), axis=1) + results = results.iloc[1:-1] if plot_perc is not None: results = results.iloc[: int(m.ceil(plot_perc * results.shape[0])), :] @@ -400,7 +401,7 @@ def plot_metric_table( @contextmanager def plot_metric_boxplot( data: pd.DataFrame, - patch_size: Tuple[float, float] = (5, 5), + patch_size: Tuple[float, float] = (5, 4), n_cols: int = 5, x_label: str = None, ) -> plt.Figure: diff --git a/src/re_classwise_shapley/utils.py 
b/src/re_classwise_shapley/utils.py index 3fe8fdbf..6b69e701 100644 --- a/src/re_classwise_shapley/utils.py +++ b/src/re_classwise_shapley/utils.py @@ -150,7 +150,7 @@ def calculate_threshold_characteristic_curves( Returns: A pd.DataFrame with all four characteristic curves. """ - characteristics = pd.DataFrame(index=x_range, columns=["<,>", ">,<"]) + characteristics = pd.DataFrame(index=x_range, columns=["<,>", ">,>"]) n_data = len(in_cls_mar_acc) for i, threshold in enumerate(characteristics.index): @@ -159,8 +159,7 @@ def calculate_threshold_characteristic_curves( / n_data ) characteristics.iloc[i, 1] = ( - np.sum((in_cls_mar_acc > threshold) & (global_mar_acc < -threshold)) - / n_data + np.sum((in_cls_mar_acc > threshold) & (global_mar_acc > threshold)) / n_data ) return characteristics