From 95e709a316c9081ddb69a6bfbba284af9cd919fb Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Tue, 19 Sep 2023 21:57:22 +0100
Subject: [PATCH] Make removing NaNs from Y_t part of `Task` data operations

---
 deepsensor/data/task.py    | 65 +++++++++++++++++++++++++++++---------
 deepsensor/model/convnp.py | 45 ++------------------------
 2 files changed, 52 insertions(+), 58 deletions(-)

diff --git a/deepsensor/data/task.py b/deepsensor/data/task.py
index e5c54b5e..6eac3556 100644
--- a/deepsensor/data/task.py
+++ b/deepsensor/data/task.py
@@ -127,6 +127,40 @@ def cast_to_float32(self):
         """
         return self.op(lambda x: x.astype(np.float32), op_flag="float32")
 
+    def remove_nans_from_task_Y_t_if_present(self):
+        """If NaNs are present in task["Y_t"], remove them (and corresponding task["X_t"])"""
+        self["ops"].append("target_nans_removed")
+
+        # First check whether there are any NaNs that we need to remove
+        nans_present = False
+        for Y_t in self["Y_t"]:
+            if B.any(B.isnan(Y_t)):
+                nans_present = True
+
+        Y_t_nans_list = []
+        if nans_present:
+            for i, (X, Y) in enumerate(zip(self["X_t"], self["Y_t"])):
+                Y = flatten_Y(Y)
+                Y_t_nans = B.any(B.isnan(Y), axis=0)  # shape (n_targets,)
+                Y_t_nans_list.append(Y_t_nans)
+
+        if not nans_present:
+            return self
+
+        # NaNs present in self - remove NaNs
+        for i, (X, Y, Y_t_nans) in enumerate(zip(self["X_t"], self["Y_t"], Y_t_nans_list)):
+            if B.any(Y_t_nans):
+                if isinstance(X, tuple):
+                    # Gridded data
+                    X = flatten_X(X)
+                Y = flatten_Y(Y)
+                self["X_t"][i] = X[:, ~Y_t_nans]
+                self["Y_t"][i] = Y[:, ~Y_t_nans]
+                if "Y_t_aux" in self.keys():
+                    self["Y_t_aux"] = self["Y_t_aux"][:, ~Y_t_nans]
+
+        return self
+
     def mask_nans_numpy(self):
         """Replace NaNs with zeroes and set a mask to indicate where the NaNs were.
 
@@ -291,6 +325,20 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
     if len(tasks) == 1:
         return tasks[0]
 
+    for i, task in enumerate(tasks):
+        if "numpy_mask" in task["ops"] or "nps_mask" in task["ops"]:
+            raise ValueError(
+                "Cannot concatenate tasks that have had NaNs masked. "
+                "Masking will be applied automatically after concatenation."
+            )
+        if "target_nans_removed" not in task["ops"]:
+            task = task.remove_nans_from_task_Y_t_if_present()
+        if "batch_dim" not in task["ops"]:
+            task = task.add_batch_dim()
+        if "float32" not in task["ops"]:
+            task = task.cast_to_float32()
+        tasks[i] = task
+
     # Assert number of target sets equal
     n_target_sets = [len(task["Y_t"]) for task in tasks]
     if not all([n == n_target_sets[0] for n in n_target_sets]):
@@ -305,8 +353,8 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
     if not all([n == n_target_obs[0] for n in n_target_obs]):
         raise ValueError(
             f"All tasks must have the same number of targets to concatenate: got {n_target_sets}. "
-            "If you want to train using batches containing tasks with differing numbers of targets, "
-            "you can run the model individually over each task and average the losses."
+            "To train with Task batches containing differing numbers of targets, "
+            "run the model individually over each task and average the losses."
         )
 
     # Raise error if target sets are different types (gridded/non-gridded) across tasks
@@ -321,19 +369,6 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
     # For each task, store list of tuples of (x_c, y_c) (one tuple per context set)
     contexts = []
     for i, task in enumerate(tasks):
-        if "numpy_mask" in task["ops"] or "nps_mask" in task["ops"]:
-            raise ValueError(
-                "Cannot concatenate tasks that have had NaNs masked. "
-                "Masking will be applied automatically after concatenation."
-            )
-        # Ensure converted to tensors with batch dims
-        if "batch_dim" not in task["ops"]:
-            task = task.add_batch_dim()
-        if "float32" not in task["ops"]:
-            task = task.cast_to_float32()
-
-        tasks[i] = task
-
         contexts_i = list(zip(task["X_c"], task["Y_c"]))
         contexts.append(contexts_i)
 
diff --git a/deepsensor/model/convnp.py b/deepsensor/model/convnp.py
index b76d9b89..0af02481 100644
--- a/deepsensor/model/convnp.py
+++ b/deepsensor/model/convnp.py
@@ -238,6 +238,8 @@ def load(self, model_ID: str):
     def modify_task(cls, task):
         """Cast numpy arrays to TensorFlow or PyTorch tensors, add batch dim, and mask NaNs"""
 
+        if "target_nans_removed" not in task["ops"]:
+            task = task.remove_nans_from_task_Y_t_if_present()
         if "batch_dim" not in task["ops"]:
             task = task.add_batch_dim()
         if "float32" not in task["ops"]:
@@ -397,15 +399,6 @@ def loss_fn(self, task: Task, fix_noise=None, num_lv_samples=8, normalise=False)
         -------
         """
 
-        # Remove NaNs from the target data if present
-        task, nans_present = remove_nans_from_task_Y_t_if_present(task)
-        # if nans_present:
-        # TODO raise error like:
-        # warnings.warn(
-        #     "NaNs present in the target data. These will be removed before evaluating the loss.",
-        #
-        # )
-
         task = ConvNP.modify_task(task)
         context_data, xt, yt, model_kwargs = convert_task_to_nps_args(task)
 
@@ -522,37 +515,3 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
         FutureWarning,
     )
     return deepsensor.data.task.concat_tasks(tasks, multiple)
-
-
-def remove_nans_from_task_Y_t_if_present(task):
-    """If NaNs are present in task["Y_t"], remove them (and corresponding task["X_t"])"""
-    # First check whether there are any NaNs that we need to remove
-    nans_present = False
-    for Y_t in task["Y_t"]:
-        if B.any(B.isnan(Y_t)):
-            nans_present = True
-
-    Y_t_nans_list = []
-    if nans_present:
-        for i, (X, Y) in enumerate(zip(task["X_t"], task["Y_t"])):
-            Y = flatten_Y(Y)
-            Y_t_nans = B.any(B.isnan(Y), axis=0)  # shape (n_targets,)
-            Y_t_nans_list.append(Y_t_nans)
-
-    if not nans_present:
-        return task, False
-
-    # NaNs present in task - make deep copy and remove NaNs
-    task = copy.deepcopy(task)
-    for i, (X, Y, Y_t_nans) in enumerate(zip(task["X_t"], task["Y_t"], Y_t_nans_list)):
-        if B.any(Y_t_nans):
-            if isinstance(X, tuple):
-                # Gridded data
-                X = flatten_X(X)
-            Y = flatten_Y(Y)
-            task["X_t"][i] = X[:, ~Y_t_nans]
-            task["Y_t"][i] = Y[:, ~Y_t_nans]
-            if "Y_t_aux" in task.keys():
-                task["Y_t_aux"] = task["Y_t_aux"][:, ~Y_t_nans]
-
-    return task, True