Skip to content

Commit

Permalink
Add X_t_mask kwarg to DeepSensorModel.predict
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Sep 10, 2023
1 parent 814ec86 commit ca4948c
Showing 1 changed file with 39 additions and 14 deletions.
53 changes: 39 additions & 14 deletions deepsensor/model/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from deepsensor.data.loader import TaskLoader
from deepsensor.data.processor import DataProcessor
from deepsensor.data.processor import (
DataProcessor,
process_X_mask_for_X,
xarray_to_coord_array_normalised,
mask_coord_array_normalised,
)
from deepsensor.data.task import Task, flatten_X

from typing import List, Union
Expand Down Expand Up @@ -118,6 +123,7 @@ def predict(
X_t: Union[
xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index, np.ndarray
],
X_t_mask: Union[xr.Dataset, xr.DataArray] = None,
X_t_is_normalised: bool = False,
resolution_factor=1,
n_samples=0,
Expand All @@ -131,13 +137,12 @@ def predict(
):
"""Predict on a regular grid or at off-grid locations.
TODO:
- Test with multiple targets model
Args:
tasks: List of tasks containing context data.
X_t: Target locations to predict at. Can be an xarray object containing
on-grid locations or a pandas object containing off-grid locations.
X_t_mask: Optional 2D mask to apply to X_t (zero/False will be NaNs). Will be interpolated
to the same grid as X_t. Default None (no mask).
X_t_is_normalised: Whether the `X_t` coords are normalised.
If False, will normalise the coords before passing to model. Default False.
resolution_factor: Optional factor to increase the resolution of the
Expand Down Expand Up @@ -186,6 +191,10 @@ def predict(
f"X_t must be and xarray, pandas or numpy object. Got {type(X_t)}."
)

if mode == "off-grid" and X_t_mask is not None:
# TODO: Unit test this
raise ValueError("X_t_mask can only be used with on-grid predictions.")

if type(tasks) is Task:
tasks = [tasks]

Expand Down Expand Up @@ -228,7 +237,7 @@ def predict(

# Unnormalise coords to use for xarray/pandas objects for storing predictions
X_t = self.data_processor.map_coords(X_t, unnorm=True)
else:
elif not X_t_is_normalised:
# Normalise coords to use for model
X_t_normalised = self.data_processor.map_coords(X_t)

Expand All @@ -237,8 +246,15 @@ def predict(
X_t_normalised = increase_spatial_resolution(
X_t_normalised, resolution_factor
)
# TODO rename from _arr because not an array here
X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values)

if X_t_mask is not None:
X_t_mask = process_X_mask_for_X(X_t_mask, X_t)
X_t_mask_normalised = self.data_processor.map_coords(X_t_mask)
X_t_arr = xarray_to_coord_array_normalised(X_t_normalised)
# Remove points that lie outside the mask
X_t_arr = mask_coord_array_normalised(X_t_arr, X_t_mask_normalised)
else:
X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values)
elif mode == "off-grid":
X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T

Expand Down Expand Up @@ -379,13 +395,22 @@ def unnormalise_pred_array(arr, **kwargs):
)

if mode == "on-grid":
mean.loc[:, task["time"], :, :] = mean_arr
std.loc[:, task["time"], :, :] = std_arr
if n_samples >= 1:
for sample_i in range(n_samples):
samples.loc[:, sample_i, task["time"], :, :] = samples_arr[
sample_i
]
if X_t_mask is None:
mean.loc[:, task["time"], :, :] = mean_arr
std.loc[:, task["time"], :, :] = std_arr
if n_samples >= 1:
for sample_i in range(n_samples):
samples.loc[:, sample_i, task["time"], :, :] = samples_arr[
sample_i
]
else:
mean.loc[:, task["time"], :, :].data[:, X_t_mask.data] = mean_arr
std.loc[:, task["time"], :, :].data[:, X_t_mask.data] = std_arr
if n_samples >= 1:
for sample_i in range(n_samples):
samples.loc[:, sample_i, task["time"], :, :].data[
:, X_t_mask.data
] = samples_arr[sample_i]
elif mode == "off-grid":
# TODO multi-target case
mean.loc[task["time"]] = mean_arr.T
Expand Down

0 comments on commit ca4948c

Please sign in to comment.