From b9653beb2fb68115c695800b923372b7291f0fb8 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Sun, 3 Mar 2024 19:45:08 +0100
Subject: [PATCH] Docs: fix trainer metric definitions (#1924)

* Docs: fix trainer metric definitions

* Link to torchmetrics docs

* Teach Sphinx where docs live
---
 docs/conf.py                        |  1 +
 torchgeo/trainers/classification.py | 30 +++++++++++++++--------------
 torchgeo/trainers/detection.py      | 11 ++++++-----
 torchgeo/trainers/regression.py     | 18 ++++++-----------
 torchgeo/trainers/segmentation.py   |  9 +++++----
 5 files changed, 34 insertions(+), 35 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index e62e91172d8..16c31ae6cb0 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -119,6 +119,7 @@
     "sklearn": ("https://scikit-learn.org/stable/", None),
     "timm": ("https://huggingface.co/docs/timm/main/en/", None),
     "torch": ("https://pytorch.org/docs/stable", None),
+    "torchmetrics": ("https://lightning.ai/docs/torchmetrics/stable/", None),
     "torchvision": ("https://pytorch.org/vision/stable", None),
 }
diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py
index 87c993b64be..2b2452c66a2 100644
--- a/torchgeo/trainers/classification.py
+++ b/torchgeo/trainers/classification.py
@@ -97,14 +97,15 @@ def configure_losses(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.
 
-        * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
-          Uses 'micro' averaging. Higher values are better.
-        * Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
-          Uses 'macro' averaging. Higher values are better.
-        * Multiclass Jaccard Index (IoU): Per-class overlap between predicted and
-          actual classes. Uses 'macro' averaging. Higher valuers are better.
-        * Multiclass F1 Score: The harmonic mean of precision and recall.
-          Uses 'micro' averaging. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassAccuracy`: The number of
+          true positives divided by the dataset size. Both overall accuracy (OA)
+          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
+          are reported. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
+          over union (IoU). Uses 'macro' averaging. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassFBetaScore`: F1 score.
+          The harmonic mean of precision and recall. Uses 'micro' averaging.
+          Higher values are better.
 
         .. note::
            * 'Micro' averaging suits overall performance evaluation but may not
@@ -266,12 +267,13 @@ class MultiLabelClassificationTask(ClassificationTask):
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.
 
-        * Multiclass Overall Accuracy (OA): Ratio of correctly classified pixels.
-          Uses 'micro' averaging. Higher values are better.
-        * Multiclass Average Accuracy (AA): Ratio of correctly classified classes.
-          Uses 'macro' averaging. Higher values are better.
-        * Multiclass F1 Score: The harmonic mean of precision and recall.
-          Uses 'micro' averaging. Higher values are better.
+        * :class:`~torchmetrics.classification.MultilabelAccuracy`: The number of
+          true positives divided by the dataset size. Both overall accuracy (OA)
+          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
+          are reported. Higher values are better.
+        * :class:`~torchmetrics.classification.MultilabelFBetaScore`: F1 score.
+          The harmonic mean of precision and recall. Uses 'micro' averaging.
+          Higher values are better.
 
         .. note::
            * 'Micro' averaging suits overall performance evaluation but may not
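
The ``:class:`` cross-references above resolve through the torchmetrics
intersphinx entry added to docs/conf.py. For reference, a minimal sketch of
the metric collection the rewritten docstring describes, assuming torchmetrics
is installed; ``num_classes`` and the dictionary keys are illustrative, not
necessarily TorchGeo's exact ones::

    import torch
    from torchmetrics import MetricCollection
    from torchmetrics.classification import (
        MulticlassAccuracy,
        MulticlassFBetaScore,
        MulticlassJaccardIndex,
    )

    num_classes = 5  # illustrative value

    metrics = MetricCollection(
        {
            # 'micro' averaging pools all samples: overall accuracy (OA)
            "OverallAccuracy": MulticlassAccuracy(num_classes=num_classes, average="micro"),
            # 'macro' averaging weights each class equally: average accuracy (AA)
            "AverageAccuracy": MulticlassAccuracy(num_classes=num_classes, average="macro"),
            "JaccardIndex": MulticlassJaccardIndex(num_classes=num_classes, average="macro"),
            "F1Score": MulticlassFBetaScore(beta=1.0, num_classes=num_classes, average="micro"),
        }
    )

    preds = torch.randn(8, num_classes)           # unnormalized class scores
    target = torch.randint(0, num_classes, (8,))  # ground-truth class indices
    print(metrics(preds, target))

Reporting accuracy under both averaging modes is the point of the OA/AA split:
'micro' lets a dominant class mask weak minority classes, while 'macro' weighs
every class equally.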
diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py
index d8f91665add..7838863a260 100644
--- a/torchgeo/trainers/detection.py
+++ b/torchgeo/trainers/detection.py
@@ -205,10 +205,11 @@ def configure_models(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.
 
-        * Mean Average Precision (mAP): Computes the Mean-Average-Precision (mAP) and
-          Mean-Average-Recall (mAR) for object detection. Prediction is based on the
-          intersection over union (IoU) between the predicted bounding boxes and the
-          ground truth bounding boxes. Uses 'macro' averaging. Higher values are better.
+        * :class:`~torchmetrics.detection.mean_ap.MeanAveragePrecision`: Mean average
+          precision (mAP) and mean average recall (mAR). Precision is the number of
+          true positives divided by the number of true positives + false positives.
+          Recall is the number of true positives divided by the number of true positives
+          + false negatives. Uses 'macro' averaging. Higher values are better.
 
         .. note::
            * 'Micro' averaging suits overall performance evaluation but may not
@@ -216,7 +217,7 @@ def configure_metrics(self) -> None:
              reflect minority class accuracy.
            * 'Macro' averaging gives equal weight to each class, and is useful for
              balanced performance assessment across imbalanced classes.
         """
-        metrics = MetricCollection([MeanAveragePrecision()])
+        metrics = MetricCollection([MeanAveragePrecision(average="macro")])
         self.val_metrics = metrics.clone(prefix="val_")
         self.test_metrics = metrics.clone(prefix="test_")
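
The only change to trainer code (as opposed to docstrings) is passing
``average="macro"`` explicitly, matching what the docstring now states. A
minimal usage sketch with illustrative boxes and labels; note that
MeanAveragePrecision also requires the pycocotools backend to be installed::

    import torch
    from torchmetrics.detection.mean_ap import MeanAveragePrecision

    metric = MeanAveragePrecision(average="macro")

    preds = [
        {
            "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),  # xyxy corners
            "scores": torch.tensor([0.9]),                      # confidence per box
            "labels": torch.tensor([0]),
        }
    ]
    target = [
        {
            "boxes": torch.tensor([[12.0, 12.0, 48.0, 48.0]]),
            "labels": torch.tensor([0]),
        }
    ]

    metric.update(preds, target)
    results = metric.compute()
    print(results["map"], results["mar_100"])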
diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py
index 92bab1ed0ef..1dc4b66873e 100644
--- a/torchgeo/trainers/regression.py
+++ b/torchgeo/trainers/regression.py
@@ -98,18 +98,12 @@ def configure_losses(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.
 
-        * Root Mean Squared Error (RMSE): The square root of the average of the squared
-          differences between the predicted and actual values. Lower values are better.
-        * Mean Squared Error (MSE): The average of the squared differences between the
-          predicted and actual values. Lower values are better.
-        * Mean Absolute Error (MAE): The average of the absolute differences between the
-          predicted and actual values. Lower values are better.
-
-        .. note::
-           * 'Micro' averaging suits overall performance evaluation but may not reflect
-             minority class accuracy.
-           * 'Macro' averaging gives equal weight to each class, and is useful for
-             balanced performance assessment across imbalanced classes.
+        * :class:`~torchmetrics.MeanSquaredError`: The average of the squared
+          differences between the predicted and actual values (MSE) and its
+          square root (RMSE). Lower values are better.
+        * :class:`~torchmetrics.MeanAbsoluteError`: The average of the absolute
+          differences between the predicted and actual values (MAE).
+          Lower values are better.
         """
         metrics = MetricCollection(
             {
diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py
index a83ababa764..d16b63e86db 100644
--- a/torchgeo/trainers/segmentation.py
+++ b/torchgeo/trainers/segmentation.py
@@ -126,10 +126,11 @@ def configure_losses(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.
 
-        * Multiclass Pixel Accuracy: Ratio of correctly classified pixels.
-          Uses 'micro' averaging. Higher values are better.
-        * Multiclass Jaccard Index (IoU): Per-pixel overlap between predicted and
-          actual segments. Uses 'macro' averaging. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
+          (OA) using 'micro' averaging. The number of true positives divided by the
+          dataset size. Higher values are better.
+        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
+          over union (IoU). Uses 'micro' averaging. Higher values are better.
 
         .. note::
            * 'Micro' averaging suits overall performance evaluation but may not reflect
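
Stepping back to the regression hunk: it ends just as its ``MetricCollection``
opens, so the keys are not visible above. The detail worth spelling out is that
torchmetrics derives RMSE from :class:`~torchmetrics.MeanSquaredError` via its
``squared`` flag, which is why the rewritten docstring folds RMSE into the MSE
entry. A sketch under that assumption, with illustrative key names::

    import torch
    from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

    metrics = MetricCollection(
        {
            "RMSE": MeanSquaredError(squared=False),  # square root of the MSE
            "MSE": MeanSquaredError(squared=True),
            "MAE": MeanAbsoluteError(),
        }
    )

    preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
    target = torch.tensor([3.0, -0.5, 2.0, 7.0])
    print(metrics(preds, target))  # lower is better for all three

The micro/macro note is dropped from the regression docstring because those
averaging modes are per-class concepts with no analogue for continuous
regression targets.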