Skip to content

Commit

Permalink
Relative imports are difficult
Browse files Browse the repository at this point in the history
  • Loading branch information
Kalle Westerling committed Oct 2, 2023
1 parent 1819cd9 commit 4a4af88
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions deepsensor/data/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import plum
import copy

from . import backend
from .. import backend
from ..errors import TaskSetIndexError, GriddedDataError


Expand Down Expand Up @@ -109,7 +109,9 @@ def recurse(k, v):
return [recurse(k, vi) for vi in v]
elif type(v) is tuple:
return (recurse(k, v[0]), recurse(k, v[1]))
elif isinstance(v, (np.ndarray, np.ma.MaskedArray, backend.nps.Masked)):
elif isinstance(
v, (np.ndarray, np.ma.MaskedArray, backend.nps.Masked)
):
return f(v)
else:
return v # covers metadata entries
Expand Down Expand Up @@ -183,7 +185,9 @@ def mask_nans_numpy(self):
task : dict. Task with NaNs set to zeros and a mask indicating where the missing values are.
"""
if "batch_dim" not in self["ops"]:
raise ValueError("Must call `add_batch_dim` before `mask_nans_numpy`")
raise ValueError(
"Must call `add_batch_dim` before `mask_nans_numpy`"
)

def f(arr):
if isinstance(arr, backend.nps.Masked):
Expand All @@ -205,9 +209,13 @@ def f(arr):

def mask_nans_nps(self):
if "batch_dim" not in self["ops"]:
raise ValueError("Must call `add_batch_dim` before `mask_nans_nps`")
raise ValueError(
"Must call `add_batch_dim` before `mask_nans_nps`"
)
if "numpy_mask" not in self["ops"]:
raise ValueError("Must call `mask_nans_numpy` before `mask_nans_nps`")
raise ValueError(
"Must call `mask_nans_numpy` before `mask_nans_nps`"
)

def f(arr):
if isinstance(arr, np.ma.MaskedArray):
Expand Down Expand Up @@ -286,7 +294,9 @@ def append_obs_to_task(
return task_with_new


def flatten_X(X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
def flatten_X(
X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
) -> np.ndarray:
"""
Convert tuple of gridded coords to (2, N) array if necessary.
Expand All @@ -306,7 +316,9 @@ def flatten_X(X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray
return X


def flatten_Y(Y: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
def flatten_Y(
Y: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
) -> np.ndarray:
"""
Convert gridded data of shape (N_dim, N_x1, N_x2) to (N_dim, N_x1 * N_x2)
array if necessary.
Expand Down Expand Up @@ -463,7 +475,9 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
)
else:
# Target set is off-the-grid with tensor for `X_t`
merged_task["X_t"][i] = B.concat(*[t["X_t"][i] for t in tasks], axis=0)
merged_task["X_t"][i] = B.concat(
*[t["X_t"][i] for t in tasks], axis=0
)
merged_task["Y_t"][i] = B.concat(*[t["Y_t"][i] for t in tasks], axis=0)

merged_task["time"] = [t["time"] for t in tasks]
Expand Down

0 comments on commit 4a4af88

Please sign in to comment.