From 37f71f6a5d123778a41966ff374bf343fbc9ddf2 Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Sun, 20 Oct 2024 16:24:58 +0100 Subject: [PATCH] Add pointwise error computation to method and add to forecasting unit test --- deepsensor/eval/__init__.py | 1 + deepsensor/eval/metrics.py | 24 ++++++++++++++++++++++++ tests/test_model.py | 13 ++++++++++--- 3 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 deepsensor/eval/__init__.py create mode 100644 deepsensor/eval/metrics.py diff --git a/deepsensor/eval/__init__.py b/deepsensor/eval/__init__.py new file mode 100644 index 00000000..1761d1aa --- /dev/null +++ b/deepsensor/eval/__init__.py @@ -0,0 +1 @@ +from .metrics import * diff --git a/deepsensor/eval/metrics.py b/deepsensor/eval/metrics.py new file mode 100644 index 00000000..c9616430 --- /dev/null +++ b/deepsensor/eval/metrics.py @@ -0,0 +1,24 @@ +import xarray as xr +from deepsensor.model.pred import Prediction + + +def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset: + """ + Compute errors between predictions and targets. + + Args: + pred: Prediction object. + target: Target data. + + Returns: + xr.Dataset: Dataset of pointwise differences between predictions and targets + at the same valid time in the predictions. Note, the difference is positive + when the prediction is greater than the target. + """ + errors = {} + for var_ID, pred_var in pred.items(): + target_var = target[var_ID] + error = pred_var["mean"] - target_var.sel(time=pred_var.time) + error.name = f"{var_ID}" + errors[var_ID] = error + return xr.Dataset(errors) diff --git a/tests/test_model.py b/tests/test_model.py index cb643a95..6a0836ac 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -18,6 +18,7 @@ from deepsensor.data.loader import TaskLoader from deepsensor.model.convnp import ConvNP from deepsensor.train.train import Trainer +from deepsensor.eval.metrics import compute_errors from tests.utils import gen_random_data_xr, gen_random_data_pandas @@ -686,9 +687,15 @@ def test_forecasting_model_predict_return_valid_times(self): if isinstance(pred_var, xr.Dataset): # Check we can compute errors using the valid time coord ('time') - errors = pred_var["mean"] - self.da.sel(time=pred_var.time) - assert errors.dims == ("lead_time", "init_time", "x1", "x2") - assert errors.shape == pred_var["mean"].shape + errors = compute_errors(pred, self.da.to_dataset()) + for var_ID in errors.keys(): + assert tuple(errors[var_ID].dims) == ( + "lead_time", + "init_time", + "x1", + "x2", + ) + assert errors[var_ID].shape == pred[var_ID]["mean"].shape elif isinstance(pred_var, pd.DataFrame): # Makes coordinate checking easier by avoiding repeat values pred_var = pred_var.to_xarray().isel(x1=0, x2=0)