diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py index c782f3ca..f3927165 100644 --- a/wtpsplit/train/evaluate.py +++ b/wtpsplit/train/evaluate.py @@ -37,7 +37,7 @@ def compute_f1(pred, true): ) -def get_metrics(labels, preds, threshold: float = 0.01): +def get_metrics(labels, preds, threshold: float = 0.5): # Compute precision-recall curve and AUC precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, preds) pr_auc = sklearn.metrics.auc(recall, precision)