From d783385a88cb70d8a135a6524ccf71eb319e5e28 Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Wed, 3 Jan 2024 21:29:08 +0100 Subject: [PATCH] Refactor `determine_in_out_of_cls_marginal_accuracies`. --- src/re_classwise_shapley/plotting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/re_classwise_shapley/plotting.py b/src/re_classwise_shapley/plotting.py index 06c9cf5b..a5a937e7 100644 --- a/src/re_classwise_shapley/plotting.py +++ b/src/re_classwise_shapley/plotting.py @@ -443,7 +443,9 @@ def plot_threshold_characteristics( ax = ax.flatten() for dataset_idx, dataset_name in enumerate(dataset_names): dataset_df = results[dataset_name]["threshold_characteristics"] - idx = np.argwhere(np.max(dataset_df, axis=1) >= max_plotting_percentage)[-1, 0] + idx = np.argwhere(np.max(dataset_df.values, axis=1) >= max_plotting_percentage)[ + -1, 0 + ] dataset_df.iloc[:idx].plot_threshold_characteristics(ax=ax[dataset_idx]) ax[dataset_idx].set_xlim(0, dataset_df.index[idx]) ax[dataset_idx].set_title(f"({chr(97 + dataset_idx)}) {dataset_name}")