Skip to content

Commit

Permalink
Move mask processing to deepsensor.data.processor
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Sep 10, 2023
1 parent 0768554 commit faa5a23
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
30 changes: 7 additions & 23 deletions deepsensor/active_learning/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
mask_coord_array_normalised,
da1_da2_same_grid,
interp_da1_to_da2,
process_X_mask_for_X,
)
from deepsensor.model.model import DeepSensorModel, create_empty_spatiotemporal_xarray
from deepsensor.data.task import Task, append_obs_to_task
Expand Down Expand Up @@ -69,9 +70,6 @@ def __init__(
self.task_loader = task_loader
self.pbar = None

self.X_s_mask = X_s_mask
self.X_t_mask = X_t_mask

self.x1_name = self.model.data_processor.config["coords"]["x1"]["name"]
self.x2_name = self.model.data_processor.config["coords"]["x2"]["name"]

Expand All @@ -86,9 +84,14 @@ def __init__(

self.X_s = X_s
self.X_t = X_t
self.X_s_mask = X_s_mask
self.X_t_mask = X_t_mask

# Interpolate masks onto search and target coords
self.X_s_mask, self.X_t_mask = self._process_masks(X_s_mask, X_t_mask, X_s, X_t)
if self.X_s_mask is not None:
self.X_s_mask = process_X_mask_for_X(self.X_s_mask, self.X_s)
if self.X_t_mask is not None:
self.X_t_mask = process_X_mask_for_X(self.X_t_mask, self.X_t)

# Interpolate overridden infill datasets at search points if necessary
if query_infill is not None and not da1_da2_same_grid(query_infill, X_s):
Expand Down Expand Up @@ -141,25 +144,6 @@ def _validate_n_new_context(
f"and less than the number of search points ({N_s})"
)

@classmethod
def _process_masks(cls, X_s_mask: xr.DataArray, X_t_mask: xr.DataArray, X_s, X_t):
    """Interpolate the search and target masks onto X_s and X_t respectively.

    Each mask is handled independently; a mask passed as ``None`` is
    returned as ``None``.

    Args:
        X_s_mask: Mask for the search points, or ``None``.
        X_t_mask: Mask for the target points, or ``None``.
        X_s: Object whose coordinates ``X_s_mask`` is interpolated onto.
        X_t: Object whose coordinates ``X_t_mask`` is interpolated onto.

    Returns:
        Tuple ``(X_s_mask, X_t_mask)`` of the processed masks.
    """

    def _regrid(mask, target):
        """Nearest-neighbour interpolate ``mask`` onto ``target`` as booleans."""
        if mask is None:
            return None
        # Cast to float before interpolating; a fill value of 0 becomes False
        # once the data is cast back to bool below.
        mask = mask.astype(float).interp_like(
            target, method="nearest", kwargs={"fill_value": 0}
        )
        mask.data = mask.data.astype(bool)
        mask.load()
        return mask

    return _regrid(X_s_mask, X_s), _regrid(X_t_mask, X_t)

def _get_times_from_tasks(self):
"""Get times from tasks"""
times = [task["time"] for task in self.tasks]
Expand Down
22 changes: 21 additions & 1 deletion deepsensor/data/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,27 @@ def xarray_to_coord_array_normalised(da: Union[xr.Dataset, xr.DataArray]):
return np.stack([X1.ravel(), X2.ravel()], axis=0)


def mask_coord_array_normalised(coord_arr, mask_da):
def process_X_mask_for_X(X_mask: xr.DataArray, X: xr.DataArray):
    """Interpolate ``X_mask`` onto the coordinates of ``X`` as a boolean mask.

    The mask is cast to float, interpolated with the nearest-neighbour method
    (using a fill value of 0), cast back to boolean, and loaded before being
    returned.

    NOTE(review): ``X_mask`` and ``X`` are presumed to share spatial
    coordinate names - confirm against callers.

    Args:
        X_mask: Mask to interpolate.
        X: Object providing the coordinates to interpolate onto.

    Returns:
        The interpolated boolean mask.
    """
    # Interpolate on a float copy of the mask; 0-filled values become False
    # when the data is cast back to bool.
    mask_as_float = X_mask.astype(float)
    regridded = mask_as_float.interp_like(
        X, method="nearest", kwargs={"fill_value": 0}
    )
    regridded.data = regridded.data.astype(bool)
    regridded.load()
    return regridded


def mask_coord_array_normalised(
coord_arr: np.ndarray, mask_da: Union[xr.DataArray, xr.Dataset, None]
):
"""Remove points from (2, N) numpy array that are outside gridded xarray boolean mask.
If `coord_arr` is shape `(2, N)`, then `mask_da` is a shape `(N,)` boolean array
(True if point is inside mask, False if outside).
"""
if mask_da is None:
return coord_arr
mask_da = mask_da.astype(float) # Temporarily convert to float for interpolation
Expand Down

0 comments on commit faa5a23

Please sign in to comment.