Skip to content

Commit

Permalink
Adapt params for rank stability.
Browse files Browse the repository at this point in the history
  • Loading branch information
kosmitive committed May 1, 2024
1 parent 5a9aa5c commit 33e1e97
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ settings:
sample_data: false
calculate_values: false
calculate_threshold_characteristics: false
evaluate_curves: false
evaluate_curves: true
evaluate_metrics: true
render_plots: true

Expand Down Expand Up @@ -84,7 +84,7 @@ experiments:
fn: top_fraction
alpha_range:
from: 0.01
to: 0.05
to: 0.5
step: 0.01
plots:
- rank_stability
Expand Down
9 changes: 6 additions & 3 deletions src/re_classwise_shapley/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@ def curve_top_fraction(
alpha_range: A dictionary containing from, to and step keys.
Returns:
A pd.Series contianing the alpha value on the x-axis and a unfolded list on the
A pd.Series containing the alpha value on the x-axis and a unfolded list on the
y-axis.
"""
assert -1 <= alpha_range["to"] <= 1.0
assert -1 <= alpha_range["from"] <= 1.0
n = int((alpha_range["to"] - alpha_range["from"]) / alpha_range["step"]) + 1
alpha_range = np.arange(alpha_range["from"], alpha_range["to"], alpha_range["step"])
alpha_range = np.arange(
alpha_range["from"],
alpha_range["to"] + alpha_range["step"],
alpha_range["step"],
)
values.sort(reverse=np.all(alpha_range >= 0))

alpha_values = []
Expand Down

0 comments on commit 33e1e97

Please sign in to comment.