diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 1dc06652..aa1d1bb6 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -1,5 +1,10 @@ from deepsensor.data.loader import TaskLoader -from deepsensor.data.processor import DataProcessor +from deepsensor.data.processor import ( + DataProcessor, + process_X_mask_for_X, + xarray_to_coord_array_normalised, + mask_coord_array_normalised, +) from deepsensor.data.task import Task, flatten_X from typing import List, Union @@ -118,6 +123,7 @@ def predict( X_t: Union[ xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index, np.ndarray ], + X_t_mask: Union[xr.Dataset, xr.DataArray] = None, X_t_is_normalised: bool = False, resolution_factor=1, n_samples=0, @@ -131,13 +137,12 @@ def predict( ): """Predict on a regular grid or at off-grid locations. - TODO: - - Test with multiple targets model - Args: tasks: List of tasks containing context data. X_t: Target locations to predict at. Can be an xarray object containing on-grid locations or a pandas object containing off-grid locations. + X_t_mask: Optional 2D mask to apply to X_t (zero/False will be NaNs). Will be interpolated + to the same grid as X_t. Default None (no mask). X_t_is_normalised: Whether the `X_t` coords are normalised. If False, will normalise the coords before passing to model. Default False. resolution_factor: Optional factor to increase the resolution of the @@ -186,6 +191,10 @@ def predict( f"X_t must be and xarray, pandas or numpy object. Got {type(X_t)}." ) + if mode == "off-grid" and X_t_mask is not None: + # TODO: Unit test this + raise ValueError("X_t_mask can only be used with on-grid predictions.") + if type(tasks) is Task: tasks = [tasks] @@ -228,7 +237,7 @@ def predict( # Unnormalise coords to use for xarray/pandas objects for storing predictions X_t = self.data_processor.map_coords(X_t, unnorm=True) - else: + elif not X_t_is_normalised: # Normalise coords to use for model X_t_normalised = self.data_processor.map_coords(X_t) @@ -237,8 +246,15 @@ def predict( X_t_normalised = increase_spatial_resolution( X_t_normalised, resolution_factor ) - # TODO rename from _arr because not an array here - X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values) + + if X_t_mask is not None: + X_t_mask = process_X_mask_for_X(X_t_mask, X_t) + X_t_mask_normalised = self.data_processor.map_coords(X_t_mask) + X_t_arr = xarray_to_coord_array_normalised(X_t_normalised) + # Remove points that lie outside the mask + X_t_arr = mask_coord_array_normalised(X_t_arr, X_t_mask_normalised) + else: + X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values) elif mode == "off-grid": X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T @@ -379,13 +395,22 @@ def unnormalise_pred_array(arr, **kwargs): ) if mode == "on-grid": - mean.loc[:, task["time"], :, :] = mean_arr - std.loc[:, task["time"], :, :] = std_arr - if n_samples >= 1: - for sample_i in range(n_samples): - samples.loc[:, sample_i, task["time"], :, :] = samples_arr[ - sample_i - ] + if X_t_mask is None: + mean.loc[:, task["time"], :, :] = mean_arr + std.loc[:, task["time"], :, :] = std_arr + if n_samples >= 1: + for sample_i in range(n_samples): + samples.loc[:, sample_i, task["time"], :, :] = samples_arr[ + sample_i + ] + else: + mean.loc[:, task["time"], :, :].data[:, X_t_mask.data] = mean_arr + std.loc[:, task["time"], :, :].data[:, X_t_mask.data] = std_arr + if n_samples >= 1: + for sample_i in range(n_samples): + samples.loc[:, sample_i, task["time"], :, :].data[ + :, X_t_mask.data + ] = samples_arr[sample_i] elif mode == "off-grid": # TODO multi-target case mean.loc[task["time"]] = mean_arr.T