Skip to content

Commit

Permalink
Make removing NaNs from Y_t part of Task data operations
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Sep 19, 2023
1 parent 5af8fc3 commit 95e709a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 58 deletions.
65 changes: 50 additions & 15 deletions deepsensor/data/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,40 @@ def cast_to_float32(self):
"""
return self.op(lambda x: x.astype(np.float32), op_flag="float32")

def remove_nans_from_task_Y_t_if_present(self):
    """Remove NaN target observations in-place, if any are present.

    Drops every target point whose ``Y_t`` value contains a NaN in any
    variable, together with the corresponding ``X_t`` location (and the
    matching columns of ``Y_t_aux``, when present). Records
    ``"target_nans_removed"`` in ``self["ops"]`` either way, so callers can
    tell the operation has already been applied.

    Returns
    -------
    Task
        ``self``, mutated in place.
    """
    self["ops"].append("target_nans_removed")

    # Single pass: build one boolean NaN mask per target set and note
    # whether any set actually needs trimming. (Avoids scanning the data
    # twice as separate "detect" and "mask" loops would.)
    nan_masks = []
    nans_present = False
    for Y in self["Y_t"]:
        mask = B.any(B.isnan(flatten_Y(Y)), axis=0)  # shape (n_targets,)
        nan_masks.append(mask)
        if B.any(mask):
            nans_present = True

    if not nans_present:
        return self

    # NaNs present in self - remove the masked columns from each set
    for i, (X, Y, mask) in enumerate(zip(self["X_t"], self["Y_t"], nan_masks)):
        if not B.any(mask):
            continue
        if isinstance(X, tuple):
            # Gridded data: flatten to column form before boolean masking
            X = flatten_X(X)
            Y = flatten_Y(Y)
        self["X_t"][i] = X[:, ~mask]
        self["Y_t"][i] = Y[:, ~mask]
        if "Y_t_aux" in self.keys():
            # NOTE(review): with more than one target set this re-slices the
            # single Y_t_aux entry with each set's mask in turn — confirm
            # Y_t_aux is only used alongside a single target set.
            self["Y_t_aux"] = self["Y_t_aux"][:, ~mask]

    return self

def mask_nans_numpy(self):
"""Replace NaNs with zeroes and set a mask to indicate where the NaNs were.
Expand Down Expand Up @@ -291,6 +325,20 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
if len(tasks) == 1:
return tasks[0]

for i, task in enumerate(tasks):
if "numpy_mask" in task["ops"] or "nps_mask" in task["ops"]:
raise ValueError(
"Cannot concatenate tasks that have had NaNs masked. "
"Masking will be applied automatically after concatenation."
)
if "target_nans_removed" not in task["ops"]:
task = task.remove_nans_from_task_Y_t_if_present()
if "batch_dim" not in task["ops"]:
task = task.add_batch_dim()
if "float32" not in task["ops"]:
task = task.cast_to_float32()
tasks[i] = task

# Assert number of target sets equal
n_target_sets = [len(task["Y_t"]) for task in tasks]
if not all([n == n_target_sets[0] for n in n_target_sets]):
Expand All @@ -305,8 +353,8 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
if not all([n == n_target_obs[0] for n in n_target_obs]):
raise ValueError(
f"All tasks must have the same number of targets to concatenate: got {n_target_sets}. "
"If you want to train using batches containing tasks with differing numbers of targets, "
"you can run the model individually over each task and average the losses."
"To train with Task batches containing differing numbers of targets, "
"run the model individually over each task and average the losses."
)

# Raise error if target sets are different types (gridded/non-gridded) across tasks
Expand All @@ -321,19 +369,6 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
# For each task, store list of tuples of (x_c, y_c) (one tuple per context set)
contexts = []
for i, task in enumerate(tasks):
if "numpy_mask" in task["ops"] or "nps_mask" in task["ops"]:
raise ValueError(
"Cannot concatenate tasks that have had NaNs masked. "
"Masking will be applied automatically after concatenation."
)
# Ensure converted to tensors with batch dims
if "batch_dim" not in task["ops"]:
task = task.add_batch_dim()
if "float32" not in task["ops"]:
task = task.cast_to_float32()

tasks[i] = task

contexts_i = list(zip(task["X_c"], task["Y_c"]))
contexts.append(contexts_i)

Expand Down
45 changes: 2 additions & 43 deletions deepsensor/model/convnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ def load(self, model_ID: str):
def modify_task(cls, task):
"""Cast numpy arrays to TensorFlow or PyTorch tensors, add batch dim, and mask NaNs"""

if "target_nans_removed" not in task["ops"]:
task = task.remove_nans_from_task_Y_t_if_present()
if "batch_dim" not in task["ops"]:
task = task.add_batch_dim()
if "float32" not in task["ops"]:
Expand Down Expand Up @@ -397,15 +399,6 @@ def loss_fn(self, task: Task, fix_noise=None, num_lv_samples=8, normalise=False)
-------
"""
# Remove NaNs from the target data if present
task, nans_present = remove_nans_from_task_Y_t_if_present(task)
# if nans_present:
# TODO raise error like:
# warnings.warn(
# "NaNs present in the target data. These will be removed before evaluating the loss.",
#
# )

task = ConvNP.modify_task(task)

context_data, xt, yt, model_kwargs = convert_task_to_nps_args(task)
Expand Down Expand Up @@ -522,37 +515,3 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
FutureWarning,
)
return deepsensor.data.task.concat_tasks(tasks, multiple)


def remove_nans_from_task_Y_t_if_present(task):
    """If NaNs are present in task["Y_t"], remove them (and corresponding task["X_t"])"""
    # Scan every target set first: only pay for a deep copy when needed
    nans_present = False
    for Y in task["Y_t"]:
        if B.any(B.isnan(Y)):
            nans_present = True

    # Per-target-set boolean masks marking which columns contain a NaN
    nan_masks = []
    if nans_present:
        for X, Y in zip(task["X_t"], task["Y_t"]):
            flat = flatten_Y(Y)
            nan_masks.append(B.any(B.isnan(flat), axis=0))  # shape (n_targets,)

    if not nans_present:
        return task, False

    # Work on a deep copy so the caller's task is left untouched
    task = copy.deepcopy(task)
    for i, (X, Y, mask) in enumerate(zip(task["X_t"], task["Y_t"], nan_masks)):
        if not B.any(mask):
            continue
        if isinstance(X, tuple):
            # Gridded targets: flatten to column form before masking
            X = flatten_X(X)
            Y = flatten_Y(Y)
        task["X_t"][i] = X[:, ~mask]
        task["Y_t"][i] = Y[:, ~mask]
        if "Y_t_aux" in task.keys():
            task["Y_t_aux"] = task["Y_t_aux"][:, ~mask]

    return task, True

0 comments on commit 95e709a

Please sign in to comment.