From c462af8cc2879d10aa2d9c41f7357cae29e3abb8 Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Fri, 18 Oct 2024 14:52:26 +0100
Subject: [PATCH 1/8] Add unit test for forecasting valid times
---
tests/test_model.py | 38 ++++++++++++++++++++++++++++++++------
1 file changed, 32 insertions(+), 6 deletions(-)
diff --git a/tests/test_model.py b/tests/test_model.py
index 5193c480..85670227 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -55,8 +55,15 @@ class TestModel(unittest.TestCase):
def setUpClass(cls):
# super().__init__(*args, **kwargs)
# It's safe to share data between tests because the TaskLoader does not modify data
+ cls.var_ID = "2m_temp"
cls.da = _gen_data_xr()
+ cls.da.name = cls.var_ID
cls.df = _gen_data_pandas()
+ cls.df.name = cls.var_ID
+ # Various tests assume we have a single target set with a single variable.
+ # If a test requires multiple target sets or variables, this is set up in the test.
+ assert isinstance(cls.da, xr.DataArray)
+ assert isinstance(cls.df, pd.Series)
cls.dp = DataProcessor()
_ = cls.dp([cls.da, cls.df]) # Compute normalisation parameters
@@ -417,10 +424,10 @@ def test_highlevel_predict_coords_align_with_X_t_ongrid(self):
task = tl("2020-01-01")
pred = model.predict(task, X_t=da_raw)
- assert np.array_equal(
+ np.testing.assert_array_equal(
pred["dummy_data"]["mean"]["latitude"], da_raw["latitude"]
)
- assert np.array_equal(
+ np.testing.assert_array_equal(
pred["dummy_data"]["mean"]["longitude"], da_raw["longitude"]
)
@@ -493,14 +500,14 @@ def test_highlevel_predict_with_pred_params_pandas(self):
# Check that nothing breaks and the correct parameters are returned
pred = model.predict(task, X_t=X_t, pred_params=pred_params)
for pred_param in pred_params:
- assert pred_param in pred["var"]
+ assert pred_param in pred[self.var_ID]
# Test mixture probs special case
pred_params = ["mixture_probs"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for component in range(model.N_mixture_components):
pred_param = f"mixture_probs_{component}"
- assert pred_param in pred["var"]
+ assert pred_param in pred[self.var_ID]
def test_highlevel_predict_with_pred_params_xarray(self):
"""
@@ -528,14 +535,14 @@ def test_highlevel_predict_with_pred_params_xarray(self):
# Check that nothing breaks and the correct parameters are returned
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for pred_param in pred_params:
- assert pred_param in pred["var"]
+ assert pred_param in pred[self.var_ID]
# Test mixture probs special case
pred_params = ["mixture_probs"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for component in range(model.N_mixture_components):
pred_param = f"mixture_probs_{component}"
- assert pred_param in pred["var"]
+ assert pred_param in pred[self.var_ID]
def test_highlevel_predict_with_invalid_pred_params(self):
"""Test that passing ``pred_params`` to ``.predict`` works."""
@@ -640,6 +647,25 @@ def test_ar_sample(self):
ar_sample=True,
)
+ def test_forecasting_model_predict_return_valid_times(self):
+ """Test that the times returned by a forecasting model are valid."""
+ lead_times_days = [1, 2, 3]
+ init_date = "2020-01-01"
+ expected_valid_times = [
+ pd.Timestamp(init_date) + pd.DateOffset(days=lt) for lt in lead_times_days
+ ]
+
+ tl = TaskLoader(
+ context=self.da,
+ target=[self.da,] * len(lead_times_days),
+ target_delta_t=lead_times_days,
+ time_freq="D",
+ )
+ model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
+ task = tl(init_date, context_sampling=10)
+ pred = model.predict(task, X_t=self.da)
+ np.testing.assert_array_equal(pred[self.var_ID]["mean"].time.values, expected_valid_times)
+
def assert_shape(x, shape: tuple):
"""Assert that the shape of ``x`` matches ``shape``."""
From 5ec03a7d39e18afd021b2c52393a728c3e00017b Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Fri, 18 Oct 2024 15:00:00 +0100
Subject: [PATCH 2/8] run black
---
tests/test_model.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/tests/test_model.py b/tests/test_model.py
index 85670227..6e7497f5 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -657,14 +657,19 @@ def test_forecasting_model_predict_return_valid_times(self):
tl = TaskLoader(
context=self.da,
- target=[self.da,] * len(lead_times_days),
+ target=[
+ self.da,
+ ]
+ * len(lead_times_days),
target_delta_t=lead_times_days,
time_freq="D",
)
model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
task = tl(init_date, context_sampling=10)
pred = model.predict(task, X_t=self.da)
- np.testing.assert_array_equal(pred[self.var_ID]["mean"].time.values, expected_valid_times)
+ np.testing.assert_array_equal(
+ pred[self.var_ID]["mean"].time.values, expected_valid_times
+ )
def assert_shape(x, shape: tuple):
From 1e0c4a06c3cc0deb54b3b1e097485e5ed71c3ca9 Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Fri, 18 Oct 2024 15:04:43 +0100
Subject: [PATCH 3/8] Add lead_time coord to forecasting unit test
---
tests/test_model.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/tests/test_model.py b/tests/test_model.py
index 6e7497f5..a0847f1f 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -651,8 +651,9 @@ def test_forecasting_model_predict_return_valid_times(self):
"""Test that the times returned by a forecasting model are valid."""
lead_times_days = [1, 2, 3]
init_date = "2020-01-01"
+ expected_lead_times = [pd.Timedelta(days=lt) for lt in lead_times_days]
expected_valid_times = [
- pd.Timestamp(init_date) + pd.DateOffset(days=lt) for lt in lead_times_days
+ pd.Timestamp(init_date) + lt for lt in expected_lead_times
]
tl = TaskLoader(
@@ -670,6 +671,9 @@ def test_forecasting_model_predict_return_valid_times(self):
np.testing.assert_array_equal(
pred[self.var_ID]["mean"].time.values, expected_valid_times
)
+ np.testing.assert_array_equal(
+ pred[self.var_ID]["mean"].lead_time.values, expected_lead_times
+ )
def assert_shape(x, shape: tuple):
From edd477210fb7adb82cd25eb838ebbf560312fcc1 Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Sun, 20 Oct 2024 12:52:00 +0100
Subject: [PATCH 4/8] Add lead_time & init_time dims and valid time coord for
model.predict forecasting; Add unit test for on-/off-grid
---
deepsensor/model/model.py | 77 ++++++++++++++++++++++++++--
deepsensor/model/pred.py | 103 +++++++++++++++++++++++++++++++-------
tests/test_model.py | 47 ++++++++++++-----
3 files changed, 193 insertions(+), 34 deletions(-)
diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py
index 9c69049c..b1e318cb 100644
--- a/deepsensor/model/model.py
+++ b/deepsensor/model/model.py
@@ -348,6 +348,34 @@ def predict(
if ar_sample and n_samples < 1:
raise ValueError("Must pass `n_samples` > 0 to use `ar_sample`.")
+ target_delta_t = self.task_loader.target_delta_t
+ dts = [pd.Timedelta(dt) for dt in target_delta_t]
+ dts_all_zero = all([dt == pd.Timedelta(seconds=0) for dt in dts])
+ if target_delta_t is not None and dts_all_zero:
+ forecasting_mode = False
+ lead_times = None
+ elif target_delta_t is not None and not dts_all_zero:
+ target_var_IDs_set = set(self.task_loader.target_var_IDs)
+ msg = f"""
+ Got more than one set of target variables in target sets,
+ but predictions can only be made with one set of target variables
+ to simplify implementation.
+ Got {target_var_IDs_set}.
+ """
+ assert len(target_var_IDs_set) == 1, msg
+            # Repeat lead_time for each variable in each target set
+ lead_times = []
+ for target_set_idx, dt in enumerate(target_delta_t):
+ target_set_dim = self.task_loader.target_dims[target_set_idx]
+ lead_times += [
+ pd.Timedelta(dt, unit=self.task_loader.time_freq)
+ for _ in range(target_set_dim)
+ ]
+ forecasting_mode = True
+ else:
+ forecasting_mode = False
+ lead_times = None
+
if type(tasks) is Task:
tasks = [tasks]
@@ -355,12 +383,14 @@ def predict(
B.set_random_seed(seed)
np.random.seed(seed)
- dates = [task["time"] for task in tasks]
+ init_dates = [task["time"] for task in tasks]
# Flatten tuple of tuples to single list
target_var_IDs = [
var_ID for set in self.task_loader.target_var_IDs for var_ID in set
]
+ if lead_times is not None:
+ assert len(lead_times) == len(target_var_IDs)
# TODO consider removing this logic, can we just depend on the dim names in X_t?
if not unnormalise:
@@ -450,11 +480,13 @@ def predict(
pred = Prediction(
target_var_IDs,
pred_params_to_store,
- dates,
+ init_dates,
X_t,
X_t_mask,
coord_names,
n_samples=n_samples,
+ forecasting_mode=forecasting_mode,
+ lead_times=lead_times,
)
def unnormalise_pred_array(arr, **kwargs):
@@ -605,14 +637,22 @@ def unnormalise_pred_array(arr, **kwargs):
# Assign predictions to Prediction object
for param, arr in prediction_arrs.items():
if param != "mixture_probs":
- pred.assign(param, task["time"], arr)
+ pred.assign(param, task["time"], arr, lead_times=lead_times)
elif param == "mixture_probs":
assert arr.shape[0] == self.N_mixture_components, (
f"Number of mixture components ({arr.shape[0]}) does not match "
f"model attribute N_mixture_components ({self.N_mixture_components})."
)
for component_i, probs in enumerate(arr):
- pred.assign(f"{param}_{component_i}", task["time"], probs)
+ pred.assign(
+ f"{param}_{component_i}",
+ task["time"],
+ probs,
+ lead_times=lead_times,
+ )
+
+ if forecasting_mode:
+ pred = add_valid_time_coord_to_pred(pred)
if verbose:
dur = time.time() - tic
@@ -621,6 +661,35 @@ def unnormalise_pred_array(arr, **kwargs):
return pred
+def add_valid_time_coord_to_pred(pred: Prediction) -> Prediction:
+    """
+    Add a valid time coordinate "time" to a Prediction object based on the
+    initialisation times "init_time" and lead times "lead_time".
+
+    Args:
+        pred (:class:`~.model.pred.Prediction`):
+            Prediction object to add valid time coordinate to.
+
+    Returns:
+        :class:`~.model.pred.Prediction`:
+            The same object, with "time" added to each DataFrame/Dataset entry.
+            Raises ``ValueError`` if an entry has any other type.
+    """
+    for var_ID in pred:
+        pred_var = pred[var_ID]
+        if isinstance(pred_var, pd.DataFrame):
+            # Index levels become columns; `.values` assigns by position, not by index.
+            flat = pred_var.reset_index()
+            pred_var["time"] = (flat["lead_time"] + flat["init_time"]).values
+        elif isinstance(pred_var, xr.Dataset):
+            pred[var_ID] = pred_var.assign_coords(
+                time=pred_var["lead_time"] + pred_var["init_time"]
+            )
+        else:
+            raise ValueError(f"Unsupported prediction type {type(pred_var)}.")
+    return pred
+
+
def main(): # pragma: no cover
import deepsensor.tensorflow
from deepsensor.data.loader import TaskLoader
diff --git a/deepsensor/model/pred.py b/deepsensor/model/pred.py
index 00d49f81..3b6d814a 100644
--- a/deepsensor/model/pred.py
+++ b/deepsensor/model/pred.py
@@ -4,6 +4,8 @@
import pandas as pd
import xarray as xr
+Timestamp = Union[str, pd.Timestamp, np.datetime64]
+
class Prediction(dict):
"""
@@ -32,13 +34,20 @@ class Prediction(dict):
n_samples (int)
Number of joint samples to draw from the model. If 0, will not
draw samples. Default 0.
+        forecasting_mode (bool)
+            If True, stores forecast predictions with init_time and lead_time dimensions
+            (model.predict adds the valid time coordinate "time"). If False, stores prediction
+            at t=0 only (i.e. spatial interpolation), with a single time dimension. Default False.
+ lead_times (List[pd.Timedelta], optional)
+ List of lead times to store in predictions. Must be provided if
+ forecasting_mode is True. Default None.
"""
def __init__(
self,
target_var_IDs: List[str],
pred_params: List[str],
- dates: List[Union[str, pd.Timestamp]],
+ dates: List[Timestamp],
X_t: Union[
xr.Dataset,
xr.DataArray,
@@ -50,6 +59,8 @@ def __init__(
X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
coord_names: dict = None,
n_samples: int = 0,
+ forecasting_mode: bool = False,
+ lead_times: Optional[List[pd.Timedelta]] = None,
):
self.target_var_IDs = target_var_IDs
self.X_t_mask = X_t_mask
@@ -58,6 +69,13 @@ def __init__(
self.x1_name = coord_names["x1"]
self.x2_name = coord_names["x2"]
+ self.forecasting_mode = forecasting_mode
+ if forecasting_mode:
+ assert (
+ lead_times is not None
+ ), "If forecasting_mode is True, lead_times must be provided."
+ self.lead_times = lead_times
+
self.mode = infer_prediction_modality_from_X_t(X_t)
self.pred_params = pred_params
@@ -67,15 +85,25 @@ def __init__(
*[f"sample_{i}" for i in range(n_samples)],
]
+ # Create empty xarray/pandas objects to store predictions
if self.mode == "on-grid":
for var_ID in self.target_var_IDs:
- # Create empty xarray/pandas objects to store predictions
+ if self.forecasting_mode:
+ prepend_dims = ["lead_time"]
+ prepend_coords = {"lead_time": lead_times}
+ else:
+ prepend_dims = None
+ prepend_coords = None
self[var_ID] = create_empty_spatiotemporal_xarray(
X_t,
dates,
data_vars=self.pred_params,
coord_names=coord_names,
+ prepend_dims=prepend_dims,
+ prepend_coords=prepend_coords,
)
+ if self.forecasting_mode:
+ self[var_ID] = self[var_ID].rename(time="init_time")
if self.X_t_mask is None:
# Create 2D boolean array of True values to simplify indexing
self.X_t_mask = (
@@ -86,8 +114,18 @@ def __init__(
)
elif self.mode == "off-grid":
# Repeat target locs for each date to create multiindex
- idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
- index = pd.MultiIndex.from_tuples(idxs, names=["time", *X_t.index.names])
+ if self.forecasting_mode:
+ index_names = ["lead_time", "init_time", *X_t.index.names]
+ idxs = [
+ (lt, date, *idxs)
+ for lt in lead_times
+ for date in dates
+ for idxs in X_t.index
+ ]
+ else:
+ index_names = ["time", *X_t.index.names]
+ idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
+ index = pd.MultiIndex.from_tuples(idxs, names=index_names)
for var_ID in self.target_var_IDs:
self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params)
@@ -106,6 +144,7 @@ def assign(
prediction_parameter: str,
date: Union[str, pd.Timestamp],
data: np.ndarray,
+ lead_times: Optional[List[pd.Timedelta]] = None,
):
"""
@@ -117,11 +156,29 @@ def assign(
data (np.ndarray)
If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets).
If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2).
+            lead_times (List[pd.Timedelta], optional)
+                Lead time per variable. Required if forecasting_mode is True. Default None.
"""
+ if self.forecasting_mode:
+ assert (
+ lead_times is not None
+ ), "If forecasting_mode is True, lead_times must be provided."
+
+ msg = f"""
+ If forecasting_mode is True, lead_times must be of equal length to the number of
+ variables in the data (the first dimension). Got {lead_times=} of length
+ {len(lead_times)} lead times and data shape {data.shape}.
+ """
+ assert len(lead_times) == data.shape[0], msg
+
if self.mode == "on-grid":
if prediction_parameter != "samples":
- for var_ID, pred in zip(self.target_var_IDs, data):
- self[var_ID][prediction_parameter].loc[date].data[
+ for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
+ if self.forecasting_mode:
+ index = (lead_times[i], date)
+ else:
+ index = date
+ self[var_ID][prediction_parameter].loc[index].data[
self.X_t_mask.data
] = pred.ravel()
elif prediction_parameter == "samples":
@@ -130,28 +187,44 @@ def assign(
f"have shape (N_samples, N_var, N_x1, N_x2). Got {data.shape}."
)
for sample_i, sample in enumerate(data):
- for var_ID, pred in zip(self.target_var_IDs, sample):
- self[var_ID][f"sample_{sample_i}"].loc[date].data[
+ for i, (var_ID, pred) in enumerate(
+ zip(self.target_var_IDs, sample)
+ ):
+ if self.forecasting_mode:
+ index = (lead_times[i], date)
+ else:
+ index = date
+ self[var_ID][f"sample_{sample_i}"].loc[index].data[
self.X_t_mask.data
] = pred.ravel()
elif self.mode == "off-grid":
if prediction_parameter != "samples":
- for var_ID, pred in zip(self.target_var_IDs, data):
- self[var_ID][prediction_parameter].loc[date] = pred
+ for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
+ if self.forecasting_mode:
+ index = (lead_times[i], date)
+ else:
+ index = date
+ self[var_ID][prediction_parameter].loc[index] = pred
elif prediction_parameter == "samples":
assert len(data.shape) == 3, (
f"If prediction_parameter is 'samples', and mode is 'off-grid', data must"
f"have shape (N_samples, N_var, N_targets). Got {data.shape}."
)
for sample_i, sample in enumerate(data):
- for var_ID, pred in zip(self.target_var_IDs, sample):
- self[var_ID][f"sample_{sample_i}"].loc[date] = pred
+ for i, (var_ID, pred) in enumerate(
+ zip(self.target_var_IDs, sample)
+ ):
+ if self.forecasting_mode:
+ index = (lead_times[i], date)
+ else:
+ index = date
+ self[var_ID][f"sample_{sample_i}"].loc[index] = pred
def create_empty_spatiotemporal_xarray(
X: Union[xr.Dataset, xr.DataArray],
- dates: List,
+ dates: List[Timestamp],
coord_names: dict = None,
data_vars: List[str] = None,
prepend_dims: Optional[List[str]] = None,
@@ -231,10 +304,6 @@ def create_empty_spatiotemporal_xarray(
# Convert time coord to pandas timestamps
pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values))
- # TODO: Convert init time to forecast time?
- # pred_ds = pred_ds.assign_coords(
- # time=pred_ds['time'] + pd.Timedelta(days=task_loader.target_delta_t[0]))
-
return pred_ds
diff --git a/tests/test_model.py b/tests/test_model.py
index a0847f1f..f418f7c8 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -649,12 +649,17 @@ def test_ar_sample(self):
def test_forecasting_model_predict_return_valid_times(self):
"""Test that the times returned by a forecasting model are valid."""
+ init_dates = ["2020-01-01", "2020-01-02"]
+ expected_init_times = np.array(init_dates).astype(np.datetime64)
+
lead_times_days = [1, 2, 3]
- init_date = "2020-01-01"
- expected_lead_times = [pd.Timedelta(days=lt) for lt in lead_times_days]
- expected_valid_times = [
- pd.Timestamp(init_date) + lt for lt in expected_lead_times
- ]
+ expected_lead_times = np.array(
+ [np.timedelta64(lt, "D") for lt in lead_times_days]
+ )
+
+ expected_valid_times = np.array(
+ expected_lead_times[:, None] + expected_init_times[None, :]
+ )
tl = TaskLoader(
context=self.da,
@@ -666,14 +671,30 @@ def test_forecasting_model_predict_return_valid_times(self):
time_freq="D",
)
model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
- task = tl(init_date, context_sampling=10)
- pred = model.predict(task, X_t=self.da)
- np.testing.assert_array_equal(
- pred[self.var_ID]["mean"].time.values, expected_valid_times
- )
- np.testing.assert_array_equal(
- pred[self.var_ID]["mean"].lead_time.values, expected_lead_times
- )
+ tasks = tl(init_dates, context_sampling=10)
+
+ X_ts = [
+ # Gridded predictions (xarray)
+ self.da,
+ # Off-grid prediction (pandas)
+ np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]),
+ ]
+ for X_t in X_ts:
+ pred = model.predict(tasks, X_t=X_t)
+
+ pred_var = pred[self.var_ID]
+
+ if isinstance(pred_var, pd.DataFrame):
+ # Makes coordinate checking easier by avoiding repeat values
+ pred_var = pred_var.to_xarray().isel(x1=0, x2=0)
+
+ np.testing.assert_array_equal(
+ pred_var.lead_time.values, expected_lead_times
+ )
+ np.testing.assert_array_equal(
+ pred_var.init_time.values, expected_init_times
+ )
+ np.testing.assert_array_equal(pred_var.time.values, expected_valid_times)
def assert_shape(x, shape: tuple):
From 30f311b83135ab8c922bd32e27de943d439e08c1 Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Sun, 20 Oct 2024 12:54:35 +0100
Subject: [PATCH 5/8] Bump version
---
CITATION.cff | 4 ++--
README.md | 2 +-
setup.cfg | 2 +-
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/CITATION.cff b/CITATION.cff
index 4e5a9beb..4dfafe1b 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -38,5 +38,5 @@ keywords:
- neural processes
- active learning
license: MIT
-version: 0.3.8
-date-released: '2024-07-28'
+version: 0.4.0
+date-released: '2024-10-20'
diff --git a/README.md b/README.md
index 459a2a72..0e8f2b3b 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ data with neural processes
-----------
-[![release](https://img.shields.io/badge/release-v0.3.8-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
+[![release](https://img.shields.io/badge/release-v0.4.0-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://alan-turing-institute.github.io/deepsensor/)
![Tests](https://github.com/alan-turing-institute/deepsensor/actions/workflows/tests.yml/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/alan-turing-institute/deepsensor/badge.svg?branch=main)](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
diff --git a/setup.cfg b/setup.cfg
index eb671489..ccb2ad1e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = deepsensor
-version = 0.3.8
+version = 0.4.0
author = Tom R. Andersson
author_email = tomandersson3@gmail.com
description = A Python package for modelling xarray and pandas data with neural processes.
From 0e9202a9defe168bac978f1f301a073e57630373 Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Sun, 20 Oct 2024 15:52:38 +0100
Subject: [PATCH 6/8] Add forecast error to unit test
---
tests/test_model.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/tests/test_model.py b/tests/test_model.py
index f418f7c8..37bb0ff4 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -684,7 +684,12 @@ def test_forecasting_model_predict_return_valid_times(self):
pred_var = pred[self.var_ID]
- if isinstance(pred_var, pd.DataFrame):
+ if isinstance(pred_var, xr.Dataset):
+ # Check we can compute errors using the valid time coord ('time')
+ errors = pred_var["mean"] - self.da.sel(time=pred_var.time)
+ assert errors.dims == ("lead_time", "init_time", "x1", "x2")
+ assert errors.shape == pred_var.shape
+ elif isinstance(pred_var, pd.DataFrame):
# Makes coordinate checking easier by avoiding repeat values
pred_var = pred_var.to_xarray().isel(x1=0, x2=0)
From 0dc761403cac21ae90bcdbca957162e29071dc1f Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Sun, 20 Oct 2024 16:03:57 +0100
Subject: [PATCH 7/8] Fix .shape on dataset in unit test
---
tests/test_model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_model.py b/tests/test_model.py
index 37bb0ff4..cb643a95 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -688,7 +688,7 @@ def test_forecasting_model_predict_return_valid_times(self):
# Check we can compute errors using the valid time coord ('time')
errors = pred_var["mean"] - self.da.sel(time=pred_var.time)
assert errors.dims == ("lead_time", "init_time", "x1", "x2")
- assert errors.shape == pred_var.shape
+ assert errors.shape == pred_var["mean"].shape
elif isinstance(pred_var, pd.DataFrame):
# Makes coordinate checking easier by avoiding repeat values
pred_var = pred_var.to_xarray().isel(x1=0, x2=0)
From 37f71f6a5d123778a41966ff374bf343fbc9ddf2 Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Sun, 20 Oct 2024 16:24:58 +0100
Subject: [PATCH 8/8] Add pointwise error computation to method and add to
forecasting unit test
---
deepsensor/eval/__init__.py | 1 +
deepsensor/eval/metrics.py | 24 ++++++++++++++++++++++++
tests/test_model.py | 13 ++++++++++---
3 files changed, 35 insertions(+), 3 deletions(-)
create mode 100644 deepsensor/eval/__init__.py
create mode 100644 deepsensor/eval/metrics.py
diff --git a/deepsensor/eval/__init__.py b/deepsensor/eval/__init__.py
new file mode 100644
index 00000000..1761d1aa
--- /dev/null
+++ b/deepsensor/eval/__init__.py
@@ -0,0 +1 @@
+from .metrics import *
diff --git a/deepsensor/eval/metrics.py b/deepsensor/eval/metrics.py
new file mode 100644
index 00000000..c9616430
--- /dev/null
+++ b/deepsensor/eval/metrics.py
@@ -0,0 +1,24 @@
+import xarray as xr
+from deepsensor.model.pred import Prediction
+
+
+def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset:
+    """
+    Compute pointwise errors of the predicted mean against target data.
+
+    Args:
+        pred: Prediction object mapping each variable ID to its predictions.
+        target: Target data, indexed by the same variable IDs and by "time".
+
+    Returns:
+        xr.Dataset: One variable per variable ID in ``pred``, holding the "mean"
+            prediction minus the target selected at the prediction's "time" values.
+            The difference is positive when the prediction exceeds the target.
+    """
+    diffs_by_var = {}
+    for var_ID, var_pred in pred.items():
+        truth, pred_mean = target[var_ID], var_pred["mean"]
+        diff = pred_mean - truth.sel(time=var_pred.time)
+        diff.name = f"{var_ID}"
+        diffs_by_var[var_ID] = diff
+    return xr.Dataset(diffs_by_var)
diff --git a/tests/test_model.py b/tests/test_model.py
index cb643a95..6a0836ac 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -18,6 +18,7 @@
from deepsensor.data.loader import TaskLoader
from deepsensor.model.convnp import ConvNP
from deepsensor.train.train import Trainer
+from deepsensor.eval.metrics import compute_errors
from tests.utils import gen_random_data_xr, gen_random_data_pandas
@@ -686,9 +687,15 @@ def test_forecasting_model_predict_return_valid_times(self):
if isinstance(pred_var, xr.Dataset):
# Check we can compute errors using the valid time coord ('time')
- errors = pred_var["mean"] - self.da.sel(time=pred_var.time)
- assert errors.dims == ("lead_time", "init_time", "x1", "x2")
- assert errors.shape == pred_var["mean"].shape
+ errors = compute_errors(pred, self.da.to_dataset())
+ for var_ID in errors.keys():
+ assert tuple(errors[var_ID].dims) == (
+ "lead_time",
+ "init_time",
+ "x1",
+ "x2",
+ )
+ assert errors[var_ID].shape == pred[var_ID]["mean"].shape
elif isinstance(pred_var, pd.DataFrame):
# Makes coordinate checking easier by avoiding repeat values
pred_var = pred_var.to_xarray().isel(x1=0, x2=0)