From c462af8cc2879d10aa2d9c41f7357cae29e3abb8 Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Fri, 18 Oct 2024 14:52:26 +0100 Subject: [PATCH 1/8] Add unit test for forecasting valid times --- tests/test_model.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 5193c480..85670227 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -55,8 +55,15 @@ class TestModel(unittest.TestCase): def setUpClass(cls): # super().__init__(*args, **kwargs) # It's safe to share data between tests because the TaskLoader does not modify data + cls.var_ID = "2m_temp" cls.da = _gen_data_xr() + cls.da.name = cls.var_ID cls.df = _gen_data_pandas() + cls.df.name = cls.var_ID + # Various tests assume we have a single target set with a single variable. + # If a test requires multiple target sets or variables, this is set up in the test. + assert isinstance(cls.da, xr.DataArray) + assert isinstance(cls.df, pd.Series) cls.dp = DataProcessor() _ = cls.dp([cls.da, cls.df]) # Compute normalisation parameters @@ -417,10 +424,10 @@ def test_highlevel_predict_coords_align_with_X_t_ongrid(self): task = tl("2020-01-01") pred = model.predict(task, X_t=da_raw) - assert np.array_equal( + np.testing.assert_array_equal( pred["dummy_data"]["mean"]["latitude"], da_raw["latitude"] ) - assert np.array_equal( + np.testing.assert_array_equal( pred["dummy_data"]["mean"]["longitude"], da_raw["longitude"] ) @@ -493,14 +500,14 @@ def test_highlevel_predict_with_pred_params_pandas(self): # Check that nothing breaks and the correct parameters are returned pred = model.predict(task, X_t=X_t, pred_params=pred_params) for pred_param in pred_params: - assert pred_param in pred["var"] + assert pred_param in pred[self.var_ID] # Test mixture probs special case pred_params = ["mixture_probs"] pred = model.predict(task, X_t=self.da, pred_params=pred_params) for component in range(model.N_mixture_components): pred_param = f"mixture_probs_{component}" - assert pred_param in pred["var"] + assert pred_param in pred[self.var_ID] def test_highlevel_predict_with_pred_params_xarray(self): """ @@ -528,14 +535,14 @@ def test_highlevel_predict_with_pred_params_xarray(self): # Check that nothing breaks and the correct parameters are returned pred = model.predict(task, X_t=self.da, pred_params=pred_params) for pred_param in pred_params: - assert pred_param in pred["var"] + assert pred_param in pred[self.var_ID] # Test mixture probs special case pred_params = ["mixture_probs"] pred = model.predict(task, X_t=self.da, pred_params=pred_params) for component in range(model.N_mixture_components): pred_param = f"mixture_probs_{component}" - assert pred_param in pred["var"] + assert pred_param in pred[self.var_ID] def test_highlevel_predict_with_invalid_pred_params(self): """Test that passing ``pred_params`` to ``.predict`` works.""" @@ -640,6 +647,25 @@ def test_ar_sample(self): ar_sample=True, ) + def test_forecasting_model_predict_return_valid_times(self): + """Test that the times returned by a forecasting model are valid.""" + lead_times_days = [1, 2, 3] + init_date = "2020-01-01" + expected_valid_times = [ + pd.Timestamp(init_date) + pd.DateOffset(days=lt) for lt in lead_times_days + ] + + tl = TaskLoader( + context=self.da, + target=[self.da,] * len(lead_times_days), + target_delta_t=lead_times_days, + time_freq="D", + ) + model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False) + task = tl(init_date, context_sampling=10) + pred = model.predict(task, X_t=self.da) + np.testing.assert_array_equal(pred[self.var_ID]["mean"].time.values, expected_valid_times) + def assert_shape(x, shape: tuple): """Assert that the shape of ``x`` matches ``shape``.""" From 5ec03a7d39e18afd021b2c52393a728c3e00017b Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Fri, 18 Oct 2024 15:00:00 +0100 Subject: [PATCH 2/8] run black --- tests/test_model.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 85670227..6e7497f5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -657,14 +657,19 @@ def test_forecasting_model_predict_return_valid_times(self): tl = TaskLoader( context=self.da, - target=[self.da,] * len(lead_times_days), + target=[ + self.da, + ] + * len(lead_times_days), target_delta_t=lead_times_days, time_freq="D", ) model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False) task = tl(init_date, context_sampling=10) pred = model.predict(task, X_t=self.da) - np.testing.assert_array_equal(pred[self.var_ID]["mean"].time.values, expected_valid_times) + np.testing.assert_array_equal( + pred[self.var_ID]["mean"].time.values, expected_valid_times + ) def assert_shape(x, shape: tuple): From 1e0c4a06c3cc0deb54b3b1e097485e5ed71c3ca9 Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Fri, 18 Oct 2024 15:04:43 +0100 Subject: [PATCH 3/8] Add lead_time coord to forecasting unit test --- tests/test_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 6e7497f5..a0847f1f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -651,8 +651,9 @@ def test_forecasting_model_predict_return_valid_times(self): """Test that the times returned by a forecasting model are valid.""" lead_times_days = [1, 2, 3] init_date = "2020-01-01" + expected_lead_times = [pd.Timedelta(days=lt) for lt in lead_times_days] expected_valid_times = [ - pd.Timestamp(init_date) + pd.DateOffset(days=lt) for lt in lead_times_days + pd.Timestamp(init_date) + lt for lt in expected_lead_times ] tl = TaskLoader( @@ -670,6 +671,9 @@ def test_forecasting_model_predict_return_valid_times(self): np.testing.assert_array_equal( pred[self.var_ID]["mean"].time.values, expected_valid_times ) + np.testing.assert_array_equal( + pred[self.var_ID]["mean"].lead_time.values, expected_lead_times + ) def assert_shape(x, shape: tuple): From edd477210fb7adb82cd25eb838ebbf560312fcc1 Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Sun, 20 Oct 2024 12:52:00 +0100 Subject: [PATCH 4/8] Add & dims and valid coord for model.predict forecastin; Add unit test for on-/off-grid --- deepsensor/model/model.py | 77 ++++++++++++++++++++++++++-- deepsensor/model/pred.py | 103 +++++++++++++++++++++++++++++++------- tests/test_model.py | 47 ++++++++++++----- 3 files changed, 193 insertions(+), 34 deletions(-) diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 9c69049c..b1e318cb 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -348,6 +348,34 @@ def predict( if ar_sample and n_samples < 1: raise ValueError("Must pass `n_samples` > 0 to use `ar_sample`.") + target_delta_t = self.task_loader.target_delta_t + dts = [pd.Timedelta(dt) for dt in target_delta_t] + dts_all_zero = all([dt == pd.Timedelta(seconds=0) for dt in dts]) + if target_delta_t is not None and dts_all_zero: + forecasting_mode = False + lead_times = None + elif target_delta_t is not None and not dts_all_zero: + target_var_IDs_set = set(self.task_loader.target_var_IDs) + msg = f""" + Got more than one set of target variables in target sets, + but predictions can only be made with one set of target variables + to simplify implementation. + Got {target_var_IDs_set}. + """ + assert len(target_var_IDs_set) == 1, msg + # Repeat lead_tim for each variable in each target set + lead_times = [] + for target_set_idx, dt in enumerate(target_delta_t): + target_set_dim = self.task_loader.target_dims[target_set_idx] + lead_times += [ + pd.Timedelta(dt, unit=self.task_loader.time_freq) + for _ in range(target_set_dim) + ] + forecasting_mode = True + else: + forecasting_mode = False + lead_times = None + if type(tasks) is Task: tasks = [tasks] @@ -355,12 +383,14 @@ def predict( B.set_random_seed(seed) np.random.seed(seed) - dates = [task["time"] for task in tasks] + init_dates = [task["time"] for task in tasks] # Flatten tuple of tuples to single list target_var_IDs = [ var_ID for set in self.task_loader.target_var_IDs for var_ID in set ] + if lead_times is not None: + assert len(lead_times) == len(target_var_IDs) # TODO consider removing this logic, can we just depend on the dim names in X_t? if not unnormalise: @@ -450,11 +480,13 @@ def predict( pred = Prediction( target_var_IDs, pred_params_to_store, - dates, + init_dates, X_t, X_t_mask, coord_names, n_samples=n_samples, + forecasting_mode=forecasting_mode, + lead_times=lead_times, ) def unnormalise_pred_array(arr, **kwargs): @@ -605,14 +637,22 @@ def unnormalise_pred_array(arr, **kwargs): # Assign predictions to Prediction object for param, arr in prediction_arrs.items(): if param != "mixture_probs": - pred.assign(param, task["time"], arr) + pred.assign(param, task["time"], arr, lead_times=lead_times) elif param == "mixture_probs": assert arr.shape[0] == self.N_mixture_components, ( f"Number of mixture components ({arr.shape[0]}) does not match " f"model attribute N_mixture_components ({self.N_mixture_components})." ) for component_i, probs in enumerate(arr): - pred.assign(f"{param}_{component_i}", task["time"], probs) + pred.assign( + f"{param}_{component_i}", + task["time"], + probs, + lead_times=lead_times, + ) + + if forecasting_mode: + pred = add_valid_time_coord_to_pred(pred) if verbose: dur = time.time() - tic @@ -621,6 +661,35 @@ def unnormalise_pred_array(arr, **kwargs): return pred +def add_valid_time_coord_to_pred(pred: Prediction) -> Prediction: + """ + Add a valid time coordinate "time" to a Prediction object based on the + initialisation times "init_time" and lead times "lead_time". + + Args: + pred (:class:`~.model.pred.Prediction`): + Prediction object to add valid time coordinate to. + + Returns: + :class:`~.model.pred.Prediction`: + Prediction object with valid time coordinate added. + """ + for var_ID in pred.keys(): + if isinstance(pred[var_ID], pd.DataFrame): + x = pred[var_ID].reset_index() + pred[var_ID]["time"] = (x["lead_time"] + x["init_time"]).values + print(f"{x}") + print(f"{x.dtypes}") + elif isinstance(pred[var_ID], xr.Dataset): + x = pred[var_ID] + pred[var_ID] = pred[var_ID].assign_coords( + time=x["lead_time"] + x["init_time"] + ) + else: + raise ValueError(f"Unsupported prediction type {type(pred[var_ID])}.") + return pred + + def main(): # pragma: no cover import deepsensor.tensorflow from deepsensor.data.loader import TaskLoader diff --git a/deepsensor/model/pred.py b/deepsensor/model/pred.py index 00d49f81..3b6d814a 100644 --- a/deepsensor/model/pred.py +++ b/deepsensor/model/pred.py @@ -4,6 +4,8 @@ import pandas as pd import xarray as xr +Timestamp = Union[str, pd.Timestamp, np.datetime64] + class Prediction(dict): """ @@ -32,13 +34,20 @@ class Prediction(dict): n_samples (int) Number of joint samples to draw from the model. If 0, will not draw samples. Default 0. + forecasting_mode (bool) + If True, stored forecast predictions with an init_time and lead_time dimension, + and a valid_time coordinate. If False, stores prediction at t=0 only + (i.e. spatial interpolation), with only a single time dimension. Default False. + lead_times (List[pd.Timedelta], optional) + List of lead times to store in predictions. Must be provided if + forecasting_mode is True. Default None. """ def __init__( self, target_var_IDs: List[str], pred_params: List[str], - dates: List[Union[str, pd.Timestamp]], + dates: List[Timestamp], X_t: Union[ xr.Dataset, xr.DataArray, @@ -50,6 +59,8 @@ def __init__( X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, coord_names: dict = None, n_samples: int = 0, + forecasting_mode: bool = False, + lead_times: Optional[List[pd.Timedelta]] = None, ): self.target_var_IDs = target_var_IDs self.X_t_mask = X_t_mask @@ -58,6 +69,13 @@ def __init__( self.x1_name = coord_names["x1"] self.x2_name = coord_names["x2"] + self.forecasting_mode = forecasting_mode + if forecasting_mode: + assert ( + lead_times is not None + ), "If forecasting_mode is True, lead_times must be provided." + self.lead_times = lead_times + self.mode = infer_prediction_modality_from_X_t(X_t) self.pred_params = pred_params @@ -67,15 +85,25 @@ def __init__( *[f"sample_{i}" for i in range(n_samples)], ] + # Create empty xarray/pandas objects to store predictions if self.mode == "on-grid": for var_ID in self.target_var_IDs: - # Create empty xarray/pandas objects to store predictions + if self.forecasting_mode: + prepend_dims = ["lead_time"] + prepend_coords = {"lead_time": lead_times} + else: + prepend_dims = None + prepend_coords = None self[var_ID] = create_empty_spatiotemporal_xarray( X_t, dates, data_vars=self.pred_params, coord_names=coord_names, + prepend_dims=prepend_dims, + prepend_coords=prepend_coords, ) + if self.forecasting_mode: + self[var_ID] = self[var_ID].rename(time="init_time") if self.X_t_mask is None: # Create 2D boolean array of True values to simplify indexing self.X_t_mask = ( @@ -86,8 +114,18 @@ def __init__( ) elif self.mode == "off-grid": # Repeat target locs for each date to create multiindex - idxs = [(date, *idxs) for date in dates for idxs in X_t.index] - index = pd.MultiIndex.from_tuples(idxs, names=["time", *X_t.index.names]) + if self.forecasting_mode: + index_names = ["lead_time", "init_time", *X_t.index.names] + idxs = [ + (lt, date, *idxs) + for lt in lead_times + for date in dates + for idxs in X_t.index + ] + else: + index_names = ["time", *X_t.index.names] + idxs = [(date, *idxs) for date in dates for idxs in X_t.index] + index = pd.MultiIndex.from_tuples(idxs, names=index_names) for var_ID in self.target_var_IDs: self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params) @@ -106,6 +144,7 @@ def assign( prediction_parameter: str, date: Union[str, pd.Timestamp], data: np.ndarray, + lead_times: Optional[List[pd.Timedelta]] = None, ): """ @@ -117,11 +156,29 @@ def assign( data (np.ndarray) If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets). If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2). + lead_time (pd.Timedelta, optional) + Lead time of the forecast. Required if forecasting_mode is True. Default None. """ + if self.forecasting_mode: + assert ( + lead_times is not None + ), "If forecasting_mode is True, lead_times must be provided." + + msg = f""" + If forecasting_mode is True, lead_times must be of equal length to the number of + variables in the data (the first dimension). Got {lead_times=} of length + {len(lead_times)} lead times and data shape {data.shape}. + """ + assert len(lead_times) == data.shape[0], msg + if self.mode == "on-grid": if prediction_parameter != "samples": - for var_ID, pred in zip(self.target_var_IDs, data): - self[var_ID][prediction_parameter].loc[date].data[ + for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)): + if self.forecasting_mode: + index = (lead_times[i], date) + else: + index = date + self[var_ID][prediction_parameter].loc[index].data[ self.X_t_mask.data ] = pred.ravel() elif prediction_parameter == "samples": @@ -130,28 +187,44 @@ def assign( f"have shape (N_samples, N_var, N_x1, N_x2). Got {data.shape}." ) for sample_i, sample in enumerate(data): - for var_ID, pred in zip(self.target_var_IDs, sample): - self[var_ID][f"sample_{sample_i}"].loc[date].data[ + for i, (var_ID, pred) in enumerate( + zip(self.target_var_IDs, sample) + ): + if self.forecasting_mode: + index = (lead_times[i], date) + else: + index = date + self[var_ID][f"sample_{sample_i}"].loc[index].data[ self.X_t_mask.data ] = pred.ravel() elif self.mode == "off-grid": if prediction_parameter != "samples": - for var_ID, pred in zip(self.target_var_IDs, data): - self[var_ID][prediction_parameter].loc[date] = pred + for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)): + if self.forecasting_mode: + index = (lead_times[i], date) + else: + index = date + self[var_ID][prediction_parameter].loc[index] = pred elif prediction_parameter == "samples": assert len(data.shape) == 3, ( f"If prediction_parameter is 'samples', and mode is 'off-grid', data must" f"have shape (N_samples, N_var, N_targets). Got {data.shape}." ) for sample_i, sample in enumerate(data): - for var_ID, pred in zip(self.target_var_IDs, sample): - self[var_ID][f"sample_{sample_i}"].loc[date] = pred + for i, (var_ID, pred) in enumerate( + zip(self.target_var_IDs, sample) + ): + if self.forecasting_mode: + index = (lead_times[i], date) + else: + index = date + self[var_ID][f"sample_{sample_i}"].loc[index] = pred def create_empty_spatiotemporal_xarray( X: Union[xr.Dataset, xr.DataArray], - dates: List, + dates: List[Timestamp], coord_names: dict = None, data_vars: List[str] = None, prepend_dims: Optional[List[str]] = None, @@ -231,10 +304,6 @@ def create_empty_spatiotemporal_xarray( # Convert time coord to pandas timestamps pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values)) - # TODO: Convert init time to forecast time? - # pred_ds = pred_ds.assign_coords( - # time=pred_ds['time'] + pd.Timedelta(days=task_loader.target_delta_t[0])) - return pred_ds diff --git a/tests/test_model.py b/tests/test_model.py index a0847f1f..f418f7c8 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -649,12 +649,17 @@ def test_ar_sample(self): def test_forecasting_model_predict_return_valid_times(self): """Test that the times returned by a forecasting model are valid.""" + init_dates = ["2020-01-01", "2020-01-02"] + expected_init_times = np.array(init_dates).astype(np.datetime64) + lead_times_days = [1, 2, 3] - init_date = "2020-01-01" - expected_lead_times = [pd.Timedelta(days=lt) for lt in lead_times_days] - expected_valid_times = [ - pd.Timestamp(init_date) + lt for lt in expected_lead_times - ] + expected_lead_times = np.array( + [np.timedelta64(lt, "D") for lt in lead_times_days] + ) + + expected_valid_times = np.array( + expected_lead_times[:, None] + expected_init_times[None, :] + ) tl = TaskLoader( context=self.da, @@ -666,14 +671,30 @@ def test_forecasting_model_predict_return_valid_times(self): time_freq="D", ) model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False) - task = tl(init_date, context_sampling=10) - pred = model.predict(task, X_t=self.da) - np.testing.assert_array_equal( - pred[self.var_ID]["mean"].time.values, expected_valid_times - ) - np.testing.assert_array_equal( - pred[self.var_ID]["mean"].lead_time.values, expected_lead_times - ) + tasks = tl(init_dates, context_sampling=10) + + X_ts = [ + # Gridded predictions (xarray) + self.da, + # Off-grid prediction (pandas) + np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]), + ] + for X_t in X_ts: + pred = model.predict(tasks, X_t=X_t) + + pred_var = pred[self.var_ID] + + if isinstance(pred_var, pd.DataFrame): + # Makes coordinate checking easier by avoiding repeat values + pred_var = pred_var.to_xarray().isel(x1=0, x2=0) + + np.testing.assert_array_equal( + pred_var.lead_time.values, expected_lead_times + ) + np.testing.assert_array_equal( + pred_var.init_time.values, expected_init_times + ) + np.testing.assert_array_equal(pred_var.time.values, expected_valid_times) def assert_shape(x, shape: tuple): From 30f311b83135ab8c922bd32e27de943d439e08c1 Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Sun, 20 Oct 2024 12:54:35 +0100 Subject: [PATCH 5/8] Bump version --- CITATION.cff | 4 ++-- README.md | 2 +- setup.cfg | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 4e5a9beb..4dfafe1b 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -38,5 +38,5 @@ keywords: - neural processes - active learning license: MIT -version: 0.3.8 -date-released: '2024-07-28' +version: 0.4.0 +date-released: '2024-10-20' diff --git a/README.md b/README.md index 459a2a72..0e8f2b3b 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ data with neural processes

----------- -[![release](https://img.shields.io/badge/release-v0.3.8-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases) +[![release](https://img.shields.io/badge/release-v0.4.0-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases) [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://alan-turing-institute.github.io/deepsensor/) ![Tests](https://github.com/alan-turing-institute/deepsensor/actions/workflows/tests.yml/badge.svg) [![Coverage Status](https://coveralls.io/repos/github/alan-turing-institute/deepsensor/badge.svg?branch=main)](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main) diff --git a/setup.cfg b/setup.cfg index eb671489..ccb2ad1e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = deepsensor -version = 0.3.8 +version = 0.4.0 author = Tom R. Andersson author_email = tomandersson3@gmail.com description = A Python package for modelling xarray and pandas data with neural processes. From 0e9202a9defe168bac978f1f301a073e57630373 Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Sun, 20 Oct 2024 15:52:38 +0100 Subject: [PATCH 6/8] Add forecast error to unit test --- tests/test_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index f418f7c8..37bb0ff4 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -684,7 +684,12 @@ def test_forecasting_model_predict_return_valid_times(self): pred_var = pred[self.var_ID] - if isinstance(pred_var, pd.DataFrame): + if isinstance(pred_var, xr.Dataset): + # Check we can compute errors using the valid time coord ('time') + errors = pred_var["mean"] - self.da.sel(time=pred_var.time) + assert errors.dims == ("lead_time", "init_time", "x1", "x2") + assert errors.shape == pred_var.shape + elif isinstance(pred_var, pd.DataFrame): # Makes coordinate checking easier by avoiding repeat values pred_var = pred_var.to_xarray().isel(x1=0, x2=0) From 0dc761403cac21ae90bcdbca957162e29071dc1f Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Sun, 20 Oct 2024 16:03:57 +0100 Subject: [PATCH 7/8] Fix .shape on dataset in unit test --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 37bb0ff4..cb643a95 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -688,7 +688,7 @@ def test_forecasting_model_predict_return_valid_times(self): # Check we can compute errors using the valid time coord ('time') errors = pred_var["mean"] - self.da.sel(time=pred_var.time) assert errors.dims == ("lead_time", "init_time", "x1", "x2") - assert errors.shape == pred_var.shape + assert errors.shape == pred_var["mean"].shape elif isinstance(pred_var, pd.DataFrame): # Makes coordinate checking easier by avoiding repeat values pred_var = pred_var.to_xarray().isel(x1=0, x2=0) From 37f71f6a5d123778a41966ff374bf343fbc9ddf2 Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Sun, 20 Oct 2024 16:24:58 +0100 Subject: [PATCH 8/8] Add pointwise error computation to method and add to forecasting unit test --- deepsensor/eval/__init__.py | 1 + deepsensor/eval/metrics.py | 24 ++++++++++++++++++++++++ tests/test_model.py | 13 ++++++++++--- 3 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 deepsensor/eval/__init__.py create mode 100644 deepsensor/eval/metrics.py diff --git a/deepsensor/eval/__init__.py b/deepsensor/eval/__init__.py new file mode 100644 index 00000000..1761d1aa --- /dev/null +++ b/deepsensor/eval/__init__.py @@ -0,0 +1 @@ +from .metrics import * diff --git a/deepsensor/eval/metrics.py b/deepsensor/eval/metrics.py new file mode 100644 index 00000000..c9616430 --- /dev/null +++ b/deepsensor/eval/metrics.py @@ -0,0 +1,24 @@ +import xarray as xr +from deepsensor.model.pred import Prediction + + +def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset: + """ + Compute errors between predictions and targets. + + Args: + pred: Prediction object. + target: Target data. + + Returns: + xr.Dataset: Dataset of pointwise differences between predictions and targets + at the same valid time in the predictions. Note, the difference is positive + when the prediction is greater than the target. + """ + errors = {} + for var_ID, pred_var in pred.items(): + target_var = target[var_ID] + error = pred_var["mean"] - target_var.sel(time=pred_var.time) + error.name = f"{var_ID}" + errors[var_ID] = error + return xr.Dataset(errors) diff --git a/tests/test_model.py b/tests/test_model.py index cb643a95..6a0836ac 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -18,6 +18,7 @@ from deepsensor.data.loader import TaskLoader from deepsensor.model.convnp import ConvNP from deepsensor.train.train import Trainer +from deepsensor.eval.metrics import compute_errors from tests.utils import gen_random_data_xr, gen_random_data_pandas @@ -686,9 +687,15 @@ def test_forecasting_model_predict_return_valid_times(self): if isinstance(pred_var, xr.Dataset): # Check we can compute errors using the valid time coord ('time') - errors = pred_var["mean"] - self.da.sel(time=pred_var.time) - assert errors.dims == ("lead_time", "init_time", "x1", "x2") - assert errors.shape == pred_var["mean"].shape + errors = compute_errors(pred, self.da.to_dataset()) + for var_ID in errors.keys(): + assert tuple(errors[var_ID].dims) == ( + "lead_time", + "init_time", + "x1", + "x2", + ) + assert errors[var_ID].shape == pred[var_ID]["mean"].shape elif isinstance(pred_var, pd.DataFrame): # Makes coordinate checking easier by avoiding repeat values pred_var = pred_var.to_xarray().isel(x1=0, x2=0)