Skip to content

Commit

Permalink
Commit last changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Feb 15, 2024
1 parent 4b7ac54 commit e6ae568
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
6 changes: 3 additions & 3 deletions params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ settings:


time:
active: false
active: true

threshold_characteristics:
active: false
Expand Down Expand Up @@ -47,9 +47,9 @@ active:
- classwise_shapley
- tmc_shapley
- beta_shapley
- banzhaf_shapley
- loo
repetitions:
from: 1
from: 6
to: 20

experiments:
Expand Down
12 changes: 6 additions & 6 deletions src/re_classwise_shapley/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,15 @@ def _curve_precision_recall_ranking(
the ranking. The index of the series is the recall and the values are the
corresponding precision values.
"""
precision, recall = np.zeros(len(ranked_list) - 2), np.zeros(len(ranked_list) - 2)
for idx in range(2, len(ranked_list)):
precision, recall = np.zeros(len(ranked_list)), np.zeros(len(ranked_list))
for idx in range(len(ranked_list)):
partial_list = ranked_list[: idx + 1]
intersection = list(set(target_list) & set(partial_list))
precision[idx - 2] = float(len(intersection) / len(partial_list))
recall[idx - 2] = float(len(intersection) / len(target_list))
precision[idx] = float(len(intersection) / np.maximum(1, len(partial_list)))
recall[idx] = float(len(intersection) / np.maximum(1, len(target_list)))

graph = pd.Series(precision, index=recall)
graph = graph[~graph.index.duplicated(keep="first")]
graph = pd.DataFrame({"precision": precision, "recall": recall})
graph = graph.groupby("recall")["precision"].mean()
graph = graph.sort_index(ascending=True)
graph.index.name = "recall"
graph.name = "precision"
Expand Down
5 changes: 3 additions & 2 deletions src/re_classwise_shapley/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def plot_histogram_func(
@contextmanager
def plot_time(
data: pd.DataFrame,
patch_size: Tuple[float, float] = (5, 5),
patch_size: Tuple[float, float] = (5, 4),
n_cols: int = 5,
) -> plt.Figure:
"""
Expand Down Expand Up @@ -355,6 +355,7 @@ def plot_curves_func(data: pd.DataFrame, ax: plt.Axes, **kwargs):
mean_color, shade_color = COLORS[color_name]

results = pd.concat(method_data["curve"].tolist(), axis=1)
results = results.iloc[1:-1]
if plot_perc is not None:
results = results.iloc[: int(m.ceil(plot_perc * results.shape[0])), :]

Expand Down Expand Up @@ -400,7 +401,7 @@ def plot_metric_table(
@contextmanager
def plot_metric_boxplot(
data: pd.DataFrame,
patch_size: Tuple[float, float] = (5, 5),
patch_size: Tuple[float, float] = (5, 4),
n_cols: int = 5,
x_label: str = None,
) -> plt.Figure:
Expand Down
5 changes: 2 additions & 3 deletions src/re_classwise_shapley/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def calculate_threshold_characteristic_curves(
Returns:
A pd.DataFrame with all four characteristic curves.
"""
characteristics = pd.DataFrame(index=x_range, columns=["<,>", ">,<"])
characteristics = pd.DataFrame(index=x_range, columns=["<,>", ">,>"])
n_data = len(in_cls_mar_acc)

for i, threshold in enumerate(characteristics.index):
Expand All @@ -159,8 +159,7 @@ def calculate_threshold_characteristic_curves(
/ n_data
)
characteristics.iloc[i, 1] = (
np.sum((in_cls_mar_acc > threshold) & (global_mar_acc < -threshold))
/ n_data
np.sum((in_cls_mar_acc > threshold) & (global_mar_acc > threshold)) / n_data
)

return characteristics

0 comments on commit e6ae568

Please sign in to comment.