CrossValidationReporter - add a roc function #977

Open
MarieS-WiMLDS opened this issue Dec 17, 2024 · 0 comments
Assignees
Labels
enhancement New feature or request


Is your feature request related to a problem? Please describe.

When I run cross-validation on a classification problem, I'd like to see how the model performs in terms of ROC curves.

Describe the solution you'd like

cv_report = skore.CrossValidationReporter(est, X, y)
cv_report.plots.roc

Here is an example from the scikit-learn documentation that uses matplotlib; you might want to rewrite it with plotly.

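import matplotlib.pyplot as plt
import numpy as np

from sklearn import svm
from sklearn.datasets import load_iris
from sklearn.metrics import RocCurveDisplay, auc
from sklearn.model_selection import StratifiedKFold

# Data setup assumed to match the scikit-learn example:
# keep two iris classes and append noisy features.
iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
X, y = X[y != 2], y[y != 2]
random_state = np.random.RandomState(0)
X = np.concatenate([X, random_state.randn(X.shape[0], 200 * X.shape[1])], axis=1)

n_splits = 6
cv = StratifiedKFold(n_splits=n_splits)
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)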

tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)

fig, ax = plt.subplots(figsize=(6, 6))
for fold, (train, test) in enumerate(cv.split(X, y)):
    classifier.fit(X[train], y[train])
    viz = RocCurveDisplay.from_estimator(
        classifier,
        X[test],
        y[test],
        name=f"ROC fold {fold}",
        alpha=0.3,
        lw=1,
        ax=ax,
        plot_chance_level=(fold == n_splits - 1),
    )
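    # Interpolate this fold's TPR onto the common FPR grid; tprs and aucs are
    # only needed if a mean curve is drawn afterwards.
    interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
    interp_tpr[0] = 0.0
    tprs.append(interp_tpr)
    aucs.append(viz.roc_auc)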

ax.set(
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title=f"ROC curve per cross-validation fold\n(Positive label '{target_names[1]}')",
)
ax.legend(loc="lower right")
plt.show()

Describe alternatives you've considered, if relevant

Plotting the "median" ROC curve with a +/- 1 standard deviation band, but it looks odd and it is unclear whether it has scientific value (cc @glemaitre for confirmation).
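
To make this alternative concrete, here is a minimal sketch of how such a band is typically computed, following the interpolation approach of the scikit-learn example (which averages the curves rather than taking a median). The helper name `mean_roc_band` and the two toy fold curves are hypothetical; whether the resulting band is statistically meaningful is exactly the open question.

```python
import numpy as np


def mean_roc_band(fold_curves, n_points=100):
    """Hypothetical helper: mean TPR and +/- 1 std band on a common FPR grid.

    fold_curves is a list of (fpr, tpr) array pairs, one pair per fold.
    """
    mean_fpr = np.linspace(0, 1, n_points)
    interp_tprs = []
    for fpr, tpr in fold_curves:
        # Resample each fold's curve onto the shared FPR grid.
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        interp_tprs.append(interp_tpr)
    interp_tprs = np.asarray(interp_tprs)
    mean_tpr = interp_tprs.mean(axis=0)
    mean_tpr[-1] = 1.0
    std_tpr = interp_tprs.std(axis=0)
    # Clip the band so it stays inside the unit square.
    lower = np.clip(mean_tpr - std_tpr, 0, 1)
    upper = np.clip(mean_tpr + std_tpr, 0, 1)
    return mean_fpr, mean_tpr, lower, upper


# Two made-up fold curves, for illustration only.
fold_curves = [
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.8, 1.0])),
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.6, 1.0])),
]
mean_fpr, mean_tpr, lower, upper = mean_roc_band(fold_curves, n_points=3)
```

Note that the averaging is done vertically (TPR at fixed FPR), so the band describes variability across folds at each FPR value rather than a confidence region for a single curve.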

Additional context

No response

@MarieS-WiMLDS MarieS-WiMLDS added enhancement New feature or request needs-triage This has been recently submitted and needs attention labels Dec 17, 2024
@MarieS-WiMLDS MarieS-WiMLDS removed the needs-triage This has been recently submitted and needs attention label Dec 17, 2024
@augustebaum augustebaum self-assigned this Dec 18, 2024