From 22aa8d27c0c89d8be19c292182db4c0b10fd6df9 Mon Sep 17 00:00:00 2001 From: Kalle Westerling Date: Fri, 13 Oct 2023 17:49:47 +0100 Subject: [PATCH] Work commences! --- deepsensor/active_learning/acquisition_fns.py | 388 ++++++------ deepsensor/active_learning/algorithms.py | 165 +++--- deepsensor/data/loader.py | 559 ++++++++++-------- deepsensor/data/processor.py | 326 +++++----- 4 files changed, 726 insertions(+), 712 deletions(-) diff --git a/deepsensor/active_learning/acquisition_fns.py b/deepsensor/active_learning/acquisition_fns.py index c379a2b0..72f1e7dd 100644 --- a/deepsensor/active_learning/acquisition_fns.py +++ b/deepsensor/active_learning/acquisition_fns.py @@ -15,13 +15,12 @@ class AcquisitionFunction: def __init__(self, model: ProbabilisticModel): """ - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... - context_set_idx : int - Index of context set to add new observations to when computing the - acquisition function. + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] + context_set_idx (int): + Index of context set to add new observations to when computing + the acquisition function. """ self.model = model self.min_or_max = -1 @@ -30,21 +29,18 @@ def __call__(self, task: Task) -> np.ndarray: """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - Task object containing context and target sets. + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - :class:`numpy:numpy.ndarray` - Acquisition function value/s. Shape (). + Returns: + :class:`numpy:numpy.ndarray`: + Acquisition function value/s. Shape (). - Raises - ------ - NotImplementedError - Because this is an abstract method, it must be implemented by the - subclass. + Raises: + NotImplementedError: + Because this is an abstract method, it must be implemented by + the subclass. """ raise NotImplementedError @@ -66,23 +62,20 @@ def __call__(self, task: Task, X_s: np.ndarray) -> np.ndarray: """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - Task object containing context and target sets. - X_s : :class:`numpy:numpy.ndarray` - Search points. Shape (2, N_search). + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. + X_s (:class:`numpy:numpy.ndarray`): + Search points. Shape (2, N_search). - Returns - ------- - :class:`numpy:numpy.ndarray` - Should return acquisition function value/s. Shape (N_search,). + Returns: + :class:`numpy:numpy.ndarray`: + Should return acquisition function value/s. Shape (N_search,). - Raises - ------ - NotImplementedError - Because this is an abstract method, it must be implemented by the - subclass. + Raises: + NotImplementedError: + Because this is an abstract method, it must be implemented by + the subclass. """ raise NotImplementedError @@ -94,10 +87,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -106,17 +98,15 @@ def __call__(self, task: Task, target_set_idx: int = 0): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - target_set_idx : int, optional - ..., by default 0 + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] + target_set_idx (int, optional): + [Description of the target_set_idx parameter.], by default 0 - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return np.mean(self.model.stddev(task)[target_set_idx]) @@ -128,10 +118,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -140,17 +129,15 @@ def __call__(self, task: Task, target_set_idx: int = 0): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - target_set_idx : int, optional - ..., by default 0 + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] + target_set_idx (int, optional): + [Description of the target_set_idx parameter.], default is 0 - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return np.mean(self.model.variance(task)[target_set_idx]) @@ -162,10 +149,9 @@ def __init__(self, *args, p: int = 1, **kwargs): """ ... - Parameters - ---------- - p : int, optional - ..., by default 1 + Args: + p (int, optional): + [Description of the parameter p.], default is 1 """ super().__init__(*args, **kwargs) self.p = p @@ -175,17 +161,15 @@ def __call__(self, task: Task, target_set_idx: int = 0): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - target_set_idx : int, optional - ..., by default 0 + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] + target_set_idx (int, optional): + [Description of the target_set_idx parameter.], defaults to 0 - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return np.linalg.norm( self.model.stddev(task)[target_set_idx].ravel(), ord=self.p @@ -199,10 +183,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -211,15 +194,13 @@ def __call__(self, task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ marginal_entropy = self.model.mean_marginal_entropy(task) return marginal_entropy @@ -232,10 +213,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -244,15 +224,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return self.model.joint_entropy(task) @@ -264,10 +242,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -276,15 +253,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ pred = self.model.mean(task) true = task["Y_t"] @@ -298,10 +273,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -310,15 +284,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ pred = self.model.mean(task) true = task["Y_t"] @@ -332,10 +304,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -344,19 +315,19 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ pred = self.model.mean(task) true = task["Y_t"] - return -np.mean(norm.logpdf(true, loc=pred, scale=self.model.stddev(task))) + return -np.mean( + norm.logpdf(true, loc=pred, scale=self.model.stddev(task)) + ) class OracleJointNLL(AcquisitionFunctionOracle): @@ -366,10 +337,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -378,15 +348,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return -self.model.logpdf(task) @@ -398,10 +366,9 @@ def __init__(self, seed: int = 42): """ ... - Parameters - ---------- - seed : int, optional - Random seed, by default 42. + Args: + seed (int, optional): + Random seed, defaults to 42. """ self.rng = np.random.default_rng(seed) self.min_or_max = "max" @@ -410,17 +377,15 @@ def __call__(self, task: Task, X_s: np.ndarray): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - X_s : :class:`numpy:numpy.ndarray` - ... + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] + X_s (:class:`numpy:numpy.ndarray`): + [Description of the X_s parameter.] - Returns - ------- - float - A random acquisition function value. + Returns: + float: + A random acquisition function value. """ return self.rng.random(X_s.shape[1]) @@ -432,10 +397,9 @@ def __init__(self, context_set_idx: int = 0): """ ... - Parameters - ---------- - context_set_idx : int, optional - ..., by default 0 + Args: + context_set_idx (int, optional): + [Description of the context_set_idx parameter.], defaults to 0 """ self.context_set_idx = context_set_idx self.min_or_max = "max" @@ -444,17 +408,15 @@ def __call__(self, task: Task, X_s: np.ndarray): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - X_s : :class:`numpy:numpy.ndarray` - ... + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] + X_s (:class:`numpy:numpy.ndarray`): + [Description of the X_s parameter.] - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ X_c = task["X_c"][self.context_set_idx] @@ -482,10 +444,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "max" @@ -494,19 +455,17 @@ def __call__(self, task: Task, X_s: np.ndarray, target_set_idx: int = 0): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - X_s : :class:`numpy:numpy.ndarray` - ... - target_set_idx : int, optional - ..., by default 0 - - Returns - ------- - ... - ... + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] + X_s (:class:`numpy:numpy.ndarray`): + [Description of the X_s parameter.] + target_set_idx (int, optional): + [Description of the target_set_idx parameter.], defaults to 0 + + Returns: + [Type of the return value]: + [Description of the return value.] """ # Set the target points to the search points task = copy.deepcopy(task) @@ -527,13 +486,12 @@ class ExpectedImprovement(AcquisitionFunctionParallel): def __init__(self, model: ProbabilisticModel, context_set_idx: int = 0): """ - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... - context_set_idx : int - Index of context set to add new observations to when computing the - acquisition function. + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] + context_set_idx (int): + Index of context set to add new observations to when computing the + acquisition function. """ super().__init__(model) self.context_set_idx = context_set_idx @@ -543,19 +501,17 @@ def __call__( self, task: Task, X_s: np.ndarray, target_set_idx: int = 0 ) -> np.ndarray: """ - Parameters - ---------- - task : :class:`~.data.task.Task` - Task object containing context and target sets. - X_s : :class:`numpy:numpy.ndarray` - Search points. Shape (2, N_search). - target_set_idx : int - Index of target set to compute acquisition function for. + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. + X_s (:class:`numpy:numpy.ndarray`): + Search points. Shape (2, N_search). + target_set_idx (int): + Index of target set to compute acquisition function for. - Returns - ------- - :class:`numpy:numpy.ndarray` - Acquisition function value/s. Shape (N_search,). + Returns: + :class:`numpy:numpy.ndarray`: + Acquisition function value/s. Shape (N_search,). """ # Set the target points to the search points task = copy.deepcopy(task) @@ -577,6 +533,8 @@ def __call__( # Compute the expected improvement Z = (mean - best_target_value) / stddev - ei = stddev * (mean - best_target_value) * norm.cdf(Z) + stddev * norm.pdf(Z) + ei = stddev * (mean - best_target_value) * norm.cdf( + Z + ) + stddev * norm.pdf(Z) return ei diff --git a/deepsensor/active_learning/algorithms.py b/deepsensor/active_learning/algorithms.py index 51f474d3..18ec3534 100644 --- a/deepsensor/active_learning/algorithms.py +++ b/deepsensor/active_learning/algorithms.py @@ -33,8 +33,12 @@ class GreedyAlgorithm: def __init__( self, model: DeepSensorModel, - X_s: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index], - X_t: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index], + X_s: Union[ + xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index + ], + X_t: Union[ + xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index + ], X_s_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, N_new_context: int = 1, @@ -53,47 +57,45 @@ def __init__( """ ... - Parameters - ---------- - model : :class:`~.model.model.DeepSensorModel` - Trained model to use for proposing new context points. - X_s : :class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` - Search coordinates. - X_t : :class:`xarray.Dataset` | :class:`xarray.DataArray` - Target coordinates. - X_s_mask : :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional - Mask for search coordinates. If provided, only points where mask - is True will be considered. Defaults to None. - X_t_mask : :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional - ..., by default None. - N_new_context : int, optional - ..., by default 1. - X_normalised : bool, optional - ..., by default False. - model_infill_method : str, optional - ..., by default "mean". - query_infill : :class:`xarray.DataArray`, optional - ..., by default None. - proposed_infill : :class:`xarray.DataArray`, optional - ..., by default None. - context_set_idx : int, optional - ..., by default 0. - target_set_idx : int, optional - ..., by default 0. - progress_bar : bool, optional - ..., by default False. - min_or_max : str, optional - ..., by default "min". - task_loader : :class:`~.data.loader.TaskLoader`, optional - ..., by default None. - verbose : bool, optional - ..., by default False. - - Raises - ------ - ValueError - If the ``model`` passed does not inherit from - :class:`~.model.model.DeepSensorModel`. + Args: + model (:class:`~.model.model.DeepSensorModel`): + Trained model to use for proposing new context points. + X_s (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index`): + Search coordinates. + X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray`): + Target coordinates. + X_s_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional): + Mask for search coordinates. If provided, only points where mask + is True will be considered. Defaults to None. + X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional): + [Description of the X_t_mask parameter.], defaults to None. + N_new_context (int, optional): + [Description of the N_new_context parameter.], defaults to 1. + X_normalised (bool, optional): + [Description of the X_normalised parameter.], defaults to False. + model_infill_method (str, optional): + [Description of the model_infill_method parameter.], defaults to "mean". + query_infill (:class:`xarray.DataArray`, optional): + [Description of the query_infill parameter.], defaults to None. + proposed_infill (:class:`xarray.DataArray`, optional): + [Description of the proposed_infill parameter.], defaults to None. + context_set_idx (int, optional): + [Description of the context_set_idx parameter.], defaults to 0. + target_set_idx (int, optional): + [Description of the target_set_idx parameter.], defaults to 0. + progress_bar (bool, optional): + [Description of the progress_bar parameter.], defaults to False. + min_or_max (str, optional): + [Description of the min_or_max parameter.], defaults to "min". + task_loader (:class:`~.data.loader.TaskLoader`, optional): + [Description of the task_loader parameter.], defaults to None. + verbose (bool, optional): + [Description of the verbose parameter.], defaults to False. + + Raises: + ValueError: + If the ``model`` passed does not inherit from + :class:`~.model.model.DeepSensorModel`. """ if not isinstance(model, DeepSensorModel): raise ValueError( @@ -136,11 +138,15 @@ def __init__( self.X_t_mask = process_X_mask_for_X(self.X_t_mask, self.X_t) # Interpolate overridden infill datasets at search points if necessary - if query_infill is not None and not da1_da2_same_grid(query_infill, X_s): + if query_infill is not None and not da1_da2_same_grid( + query_infill, X_s + ): if verbose: print("query_infill not on search grid, interpolating.") query_infill = interp_da1_to_da2(query_infill, self.X_s) - if proposed_infill is not None and not da1_da2_same_grid(proposed_infill, X_s): + if proposed_infill is not None and not da1_da2_same_grid( + proposed_infill, X_s + ): if verbose: print("proposed_infill not on search grid, interpolating.") proposed_infill = interp_da1_to_da2(proposed_infill, self.X_s) @@ -153,7 +159,9 @@ def __init__( self.X_t_arr = xarray_to_coord_array_normalised(X_t) if self.X_t_mask is not None: # Remove points that lie outside the mask - self.X_t_arr = mask_coord_array_normalised(self.X_t_arr, self.X_t_mask) + self.X_t_arr = mask_coord_array_normalised( + self.X_t_arr, self.X_t_mask + ) elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index)): # Targets off-grid self.X_t_arr = X_t.reset_index()[["x1", "x2"]].values.T @@ -200,7 +208,9 @@ def _get_times_from_tasks(self): def _model_infill_at_search_points( self, - X_s: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index], + X_s: Union[ + xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index + ], ): """ Computes and sets the model infill y-values over whole search grid @@ -345,7 +355,9 @@ def _search(self, acquisition_fn: AcquisitionFunction): and self.task_loader.aux_at_contexts ): # Add auxiliary variable sampled at context set as a new context variable - X_c = task_with_new["X_c"][self.task_loader.aux_at_contexts[0]] + X_c = task_with_new["X_c"][ + self.task_loader.aux_at_contexts[0] + ] Y_c_aux = self.task_loader.sample_offgrid_aux( X_c, self.task_loader.aux_at_contexts[1] ) @@ -426,32 +438,25 @@ def __call__( """ Iteratively... docstring TODO - Returns a tensor of proposed new sensor locations (in greedy - iteration/priority order) and their corresponding list of indexes in - the search space. - - Parameters - ---------- - acquisition_fn: :class:`~.active_learning.acquisition_fns.AcquisitionFunction` - ... - tasks: List[:class:`~.data.task.Task`] | :class:`~.data.task.Task` - ... - - Returns - ------- - X_new_df, acquisition_fn_ds: Tuple[:class:`pandas.DataFrame`, :class:`xarray.Dataset`] - ... - - Raises - ------ - ValueError - If ``acquisition_fn`` is an - :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle` - and ``task_loader`` is None. - ValueError - If ``min_or_max`` is not ``"min"`` or ``"max"``. - ValueError - If ``Y_t_aux`` is in ``tasks`` but ``task_loader`` is None. + Args: + acquisition_fn (:class:`~.active_learning.acquisition_fns.AcquisitionFunction`): + [Description of the acquisition_fn parameter.] + tasks (List[:class:`~.data.task.Task`] | :class:`~.data.task.Task`): + [Description of the tasks parameter.] + + Returns: + Tuple[:class:`pandas.DataFrame`, :class:`xarray.Dataset`]: + X_new_df, acquisition_fn_ds - [Description of the return values.] + + Raises: + ValueError: + If ``acquisition_fn`` is an + :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle` + and ``task_loader`` is None. + ValueError: + If ``min_or_max`` is not ``"min"`` or ``"max"``. + ValueError: + If ``Y_t_aux`` is in ``tasks`` but ``task_loader`` is None. """ if ( isinstance(acquisition_fn, AcquisitionFunctionOracle) @@ -465,7 +470,8 @@ def __call__( self.min_or_max = acquisition_fn.min_or_max if self.min_or_max not in ["min", "max"]: raise ValueError( - f"min_or_max must be either 'min' or 'max', got " f"{self.min_or_max}." + f"min_or_max must be either 'min' or 'max', got " + f"{self.min_or_max}." ) if diff and isinstance(acquisition_fn, AcquisitionFunctionParallel): @@ -496,7 +502,10 @@ def __call__( "Model expects Y_t_aux data but a TaskLoader isn't " "provided to GreedyAlgorithm." ) - if self.task_loader is not None and self.task_loader.aux_at_target_dims > 0: + if ( + self.task_loader is not None + and self.task_loader.aux_at_target_dims > 0 + ): tasks[i]["Y_t_aux"] = self.task_loader.sample_offgrid_aux( self.X_t_arr, self.task_loader.aux_at_targets ) @@ -532,7 +541,9 @@ def __call__( if self.model_infill_method == "sample": total_iterations *= self.n_samples - with tqdm(total=total_iterations, disable=not self.progress_bar) as self.pbar: + with tqdm( + total=total_iterations, disable=not self.progress_bar + ) as self.pbar: for iteration in range(self.N_new_context): self.iteration = iteration x_new = self._single_greedy_iteration(acquisition_fn) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 08c66456..1bb0003a 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -33,7 +33,9 @@ def __init__( str, List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]], ] = None, - aux_at_contexts: Optional[Tuple[int, Union[xr.DataArray, xr.Dataset]]] = None, + aux_at_contexts: Optional[ + Tuple[int, Union[xr.DataArray, xr.Dataset]] + ] = None, aux_at_targets: Optional[ Union[ xr.DataArray, @@ -56,56 +58,55 @@ def __init__( - Either all data is passed as paths, or all data is passed as loaded data (else ValueError) - If all data passed as paths, the TaskLoader can be saved with the `save` method (using config) - Parameters - ---------- - task_loader_ID : ... - If loading a TaskLoader from a config file, this is the folder the - TaskLoader was saved in (using `.save`). If this argument is passed, all other - arguments are ignored. - context : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`] - Context data. Can be a single :class:`xarray.DataArray`, - :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a - list/tuple of these. - target : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`] - Target data. Can be a single :class:`xarray.DataArray`, - :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a - list/tuple of these. - aux_at_contexts : Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional - Auxiliary data at context locations. Tuple of two elements, where - the first element is the index of the context set for which the - auxiliary data will be sampled at, and the second element is the - auxiliary data, which can be a single :class:`xarray.DataArray` or - :class:`xarray.Dataset`. Default: None. - aux_at_targets : :class:`xarray.DataArray` | :class:`xarray.Dataset`, optional - Auxiliary data at target locations. Can be a single - :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default: - None. - links : Tuple[int, int] | List[Tuple[int, int]], optional - Specifies links between context and target data. Each link is a - tuple of two integers, where the first integer is the index of the - context data and the second integer is the index of the target - data. Can be a single tuple in the case of a single link. If None, - no links are specified. Default: None. - context_delta_t : int | List[int], optional - Time difference between context data and t=0 (task init time). Can - be a single int (same for all context data) or a list/tuple of - ints. Default is 0. - target_delta_t : int | List[int], optional - Time difference between target data and t=0 (task init time). Can - be a single int (same for all target data) or a list/tuple of ints. - Default is 0. - time_freq : str, optional - Time frequency of the data. Default: ``'D'`` (daily). - xarray_interp_method : str, optional - Interpolation method to use when interpolating - :class:`xarray.DataArray`. Default is ``'linear'``. - discrete_xarray_sampling : bool, optional - When randomly sampling xarray variables, whether to sample at - discrete points defined at grid cell centres, or at continuous - points within the grid. Default is ``False``. - dtype : object, optional - Data type of the data. Used to cast the data to the specified - dtype. Default: ``'float32'``. + Args: + task_loader_ID: + If loading a TaskLoader from a config file, this is the folder the + TaskLoader was saved in (using `.save`). If this argument is passed, all other + arguments are ignored. + context (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`]): + Context data. Can be a single :class:`xarray.DataArray`, + :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a + list/tuple of these. + target (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`]): + Target data. Can be a single :class:`xarray.DataArray`, + :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a + list/tuple of these. + aux_at_contexts (Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional): + Auxiliary data at context locations. Tuple of two elements, where + the first element is the index of the context set for which the + auxiliary data will be sampled at, and the second element is the + auxiliary data, which can be a single :class:`xarray.DataArray` or + :class:`xarray.Dataset`. Default: None. + aux_at_targets (:class:`xarray.DataArray` | :class:`xarray.Dataset`, optional): + Auxiliary data at target locations. Can be a single + :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default: + None. + links (Tuple[int, int] | List[Tuple[int, int]], optional): + Specifies links between context and target data. Each link is a + tuple of two integers, where the first integer is the index of the + context data and the second integer is the index of the target + data. Can be a single tuple in the case of a single link. If None, + no links are specified. Default: None. + context_delta_t (int | List[int], optional): + Time difference between context data and t=0 (task init time). Can + be a single int (same for all context data) or a list/tuple of + ints. Default is 0. + target_delta_t (int | List[int], optional): + Time difference between target data and t=0 (task init time). Can + be a single int (same for all target data) or a list/tuple of ints. + Default is 0. + time_freq (str, optional): + Time frequency of the data. Default: ``'D'`` (daily). + xarray_interp_method (str, optional): + Interpolation method to use when interpolating + :class:`xarray.DataArray`. Default is ``'linear'``. + discrete_xarray_sampling (bool, optional): + When randomly sampling xarray variables, whether to sample at + discrete points defined at grid cell centres, or at continuous + points within the grid. Default is ``False``. + dtype (object, optional): + Data type of the data. Used to cast the data to the specified + dtype. Default: ``'float32'``. """ if task_loader_ID is not None: self.task_loader_ID = task_loader_ID @@ -125,7 +126,9 @@ def __init__( self.target_delta_t = self.config["target_delta_t"] self.time_freq = self.config["time_freq"] self.xarray_interp_method = self.config["xarray_interp_method"] - self.discrete_xarray_sampling = self.config["discrete_xarray_sampling"] + self.discrete_xarray_sampling = self.config[ + "discrete_xarray_sampling" + ] self.dtype = self.config["dtype"] else: self.context = context @@ -278,7 +281,9 @@ def _load_pandas_or_xarray(path): def _load_data(data): if isinstance(data, (tuple, list)): - data = tuple([_load_pandas_or_xarray(data_i) for data_i in data]) + data = tuple( + [_load_pandas_or_xarray(data_i) for data_i in data] + ) else: data = _load_pandas_or_xarray(data) return data @@ -340,7 +345,9 @@ def cast_to_dtype(var): # Note: Numeric pandas indexes are always cast to float64, so we have to cast # x1/x2 coord dtypes during task sampling else: - raise ValueError(f"Unknown type {type(var)} for context set {var}") + raise ValueError( + f"Unknown type {type(var)} for context set {var}" + ) return var if var is None: @@ -354,11 +361,13 @@ def cast_to_dtype(var): def load_dask(self) -> None: """ - Load any ``dask`` data into memory. + Load any `dask` data into memory. - Returns - ------- - None. + This function triggers the computation and loading of any data that + is represented as dask arrays or datasets into memory. + + Returns: + None """ def load(datasets): @@ -381,17 +390,14 @@ def count_context_and_target_data_dims(self): """ Count the number of data dimensions in the context and target data. - Returns - ------- - context_dims : tuple. Tuple of data dimensions in the context data. - target_dims : tuple. Tuple of data dimensions in the target data. + Returns: + tuple: context_dims, Tuple of data dimensions in the context data. + tuple: target_dims, Tuple of data dimensions in the target data. - Raises - ------ - ValueError - If the context/target data is not a tuple/list of - :class:`xarray.DataArray`, :class:`xarray.Dataset` or - :class:`pandas.DataFrame`. + Raises: + ValueError: If the context/target data is not a tuple/list of + :class:`xarray.DataArray`, :class:`xarray.Dataset` or + :class:`pandas.DataFrame`. """ def count_data_dims_of_tuple_of_sets(datasets): @@ -406,22 +412,28 @@ def count_data_dims_of_tuple_of_sets(datasets): elif isinstance(var, xr.DataArray): dim = 1 # Single data variable elif isinstance(var, pd.DataFrame): - dim = len(var.columns) # Assumes all columns are data variables + dim = len( + var.columns + ) # Assumes all columns are data variables elif isinstance(var, pd.Series): dim = 1 # Single data variable else: - raise ValueError(f"Unknown type {type(var)} for context set {var}") + raise ValueError( + f"Unknown type {type(var)} for context set {var}" + ) dims.append(dim) return dims context_dims = count_data_dims_of_tuple_of_sets(self.context) target_dims = count_data_dims_of_tuple_of_sets(self.target) if self.aux_at_contexts is not None: - context_dims += count_data_dims_of_tuple_of_sets(self.aux_at_contexts) + context_dims += count_data_dims_of_tuple_of_sets( + self.aux_at_contexts + ) if self.aux_at_targets is not None: - aux_at_target_dims = count_data_dims_of_tuple_of_sets(self.aux_at_targets)[ - 0 - ] + aux_at_target_dims = count_data_dims_of_tuple_of_sets( + self.aux_at_targets + )[0] else: aux_at_target_dims = 0 @@ -431,17 +443,14 @@ def infer_context_and_target_var_IDs(self): """ Infer the variable IDs of the context and target data. - Returns - ------- - context_var_IDs : tuple. Tuple of variable IDs in the context data. - target_var_IDs : tuple. Tuple of variable IDs in the target data. + Returns: + tuple: context_var_IDs, Tuple of variable IDs in the context data. + tuple: target_var_IDs, Tuple of variable IDs in the target data. - Raises - ------ - ValueError - If the context/target data is not a tuple/list of - :class:`xarray.DataArray`, :class:`xarray.Dataset` or - :class:`pandas.DataFrame`. + Raises: + ValueError: If the context/target data is not a tuple/list of + :class:`xarray.DataArray`, :class:`xarray.Dataset` or + :class:`pandas.DataFrame`. """ def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None): @@ -455,13 +464,17 @@ def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None): if isinstance(var, xr.DataArray): var_ID = (var.name,) # Single data variable elif isinstance(var, xr.Dataset): - var_ID = tuple(var.data_vars.keys()) # Multiple data variables + var_ID = tuple( + var.data_vars.keys() + ) # Multiple data variables elif isinstance(var, pd.DataFrame): var_ID = tuple(var.columns) elif isinstance(var, pd.Series): var_ID = (var.name,) else: - raise ValueError(f"Unknown type {type(var)} for context set {var}") + raise ValueError( + f"Unknown type {type(var)} for context set {var}" + ) if delta_ts is not None: # Add delta_t to the variable ID @@ -485,15 +498,17 @@ def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None): ) if self.aux_at_contexts is not None: - context_var_IDs += infer_var_IDs_of_tuple_of_sets(self.aux_at_contexts) + context_var_IDs += infer_var_IDs_of_tuple_of_sets( + self.aux_at_contexts + ) context_var_IDs_and_delta_t += infer_var_IDs_of_tuple_of_sets( self.aux_at_contexts, [0] ) if self.aux_at_targets is not None: - aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.aux_at_targets)[ - 0 - ] + aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets( + self.aux_at_targets + )[0] else: aux_at_target_var_IDs = None @@ -505,7 +520,9 @@ def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None): aux_at_target_var_IDs, ) - def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]): + def _check_links( + self, links: Union[Tuple[int, int], List[Tuple[int, int]]] + ): """ Check that the context-target links are valid. @@ -534,7 +551,9 @@ def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]): assert isinstance( links, list ), f"Links must be a list of length-2 tuples, but got {type(links)}" - assert len(links) > 0, "If links is not None, it must be a non-empty list" + assert ( + len(links) > 0 + ), "If links is not None, it must be a non-empty list" assert all( isinstance(link, tuple) for link in links ), f"Links must be a list of tuples, but got {[type(link) for link in links]}" @@ -592,28 +611,24 @@ def sample_da( """ Sample a DataArray according to a given strategy. - Parameters - ---------- - da : :class:`xarray.DataArray` | :class:`xarray.Dataset` - DataArray to sample, assumed to be sliced for the task already. - sampling_strat : str | int | float | :class:`numpy:numpy.ndarray` - Sampling strategy, either "all" or an integer for random grid cell - sampling. - seed : int, optional - Seed for random sampling. Default: None. - - Returns - ------- - Data : Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`] - Tuple of sampled target data and sampled context data. - - Raises - ------ - InvalidSamplingStrategyError - If the sampling strategy is not valid. - InvalidSamplingStrategyError - If a numpy coordinate array is passed to sample an xarray object, - but the coordinates are out of bounds. + Args: + da (:class:`xarray.DataArray` | :class:`xarray.Dataset`): + DataArray to sample, assumed to be sliced for the task already. + sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`): + Sampling strategy, either "all" or an integer for random grid + cell sampling. + seed (int, optional): + Seed for random sampling. Default is None. + + Returns: + Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]: + Tuple of sampled target data and sampled context data. + + Raises: + InvalidSamplingStrategyError: + If the sampling strategy is not valid or if a numpy coordinate + array is passed to sample an xarray object, but the coordinates + are out of bounds. """ da = da.load() # Converts dask -> numpy if not already loaded if isinstance(da, xr.Dataset): @@ -639,9 +654,15 @@ def sample_da( dim = 1 # Single data variable Y_c = np.zeros((dim, 0), dtype=self.dtype) return X_c, Y_c - x1 = rng.uniform(da.coords["x1"].min(), da.coords["x1"].max(), N) - x2 = rng.uniform(da.coords["x2"].min(), da.coords["x2"].max(), N) - Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest") + x1 = rng.uniform( + da.coords["x1"].min(), da.coords["x1"].max(), N + ) + x2 = rng.uniform( + da.coords["x2"].min(), da.coords["x2"].max(), N + ) + Y_c = da.sel( + x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest" + ) Y_c = np.array(Y_c, dtype=self.dtype) X_c = np.array([x1, x2], dtype=self.dtype) if Y_c.ndim == 1: @@ -692,29 +713,25 @@ def sample_df( """ Sample a DataArray according to a given strategy. - Parameters - ---------- - df : :class:`pandas.DataFrame` | :class:`pandas.Series` - DataArray to sample, assumed to be time-sliced for the task - already. - sampling_strat : str | int | float | :class:`numpy:numpy.ndarray` - Sampling strategy, either "all" or an integer for random grid cell - sampling. - seed : int, optional - Seed for random sampling. Default: None. - - Returns - ------- - Data : Tuple[X_c, Y_c] - Tuple of sampled target data and sampled context data. - - Raises - ------ - InvalidSamplingStrategyError - If the sampling strategy is not valid. - InvalidSamplingStrategyError - If a numpy coordinate array is passed to sample a pandas object, - but the DataFrame does not contain all the requested samples. + Args: + df (:class:`pandas.DataFrame` | :class:`pandas.Series`): + DataArray to sample, assumed to be time-sliced for the task + already. + sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`): + Sampling strategy, either "all" or an integer for random grid + cell sampling. + seed (int, optional): + Seed for random sampling. Default is None. + + Returns: + Tuple[X_c, Y_c]: + Tuple of sampled target data and sampled context data. + + Raises: + InvalidSamplingStrategyError: + If the sampling strategy is not valid or if a numpy coordinate + array is passed to sample a pandas object, but the DataFrame + does not contain all the requested samples. """ df = df.dropna(how="any") # If any obs are NaN, drop them @@ -725,9 +742,16 @@ def sample_df( N = sampling_strat rng = np.random.default_rng(seed) idx = rng.choice(df.index, N) - X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype) + X_c = ( + df.loc[idx] + .reset_index()[["x1", "x2"]] + .values.T.astype(self.dtype) + ) Y_c = df.loc[idx].values.T - elif isinstance(sampling_strat, str) and sampling_strat in ["all", "split"]: + elif isinstance(sampling_strat, str) and sampling_strat in [ + "all", + "split", + ]: # NOTE if "split", we assume that the context-target split has already been applied to the df # in an earlier scope with access to both the context and target data. This is maybe risky! X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype) @@ -762,18 +786,20 @@ def sample_offgrid_aux( """ Sample auxiliary data at off-grid locations. - Parameters - ---------- - X_t : :class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`] - Off-grid locations at which to sample the auxiliary data. Can be a - tuple of two numpy arrays, or a single numpy array. - offgrid_aux : :class:`xarray.DataArray` | :class:`xarray.Dataset` - Auxiliary data at off-grid locations. + Args: + X_t (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]): + Off-grid locations at which to sample the auxiliary data. Can + be a tuple of two numpy arrays, or a single numpy array. + offgrid_aux (:class:`xarray.DataArray` | :class:`xarray.Dataset`): + Auxiliary data at off-grid locations. - Returns - ------- - :class:`numpy:numpy.ndarray` - ... + Returns: + :class:`numpy:numpy.ndarray`: + [Description of the returned numpy ndarray] + + Raises: + [ExceptionType]: + [Description of under what conditions this function raises an exception] """ if "time" in offgrid_aux.dims: raise ValueError( @@ -861,75 +887,65 @@ def task_generation( - int: Sample N observations uniformly at random. - float: Sample a fraction of observations uniformly at random. - :class:`numpy:numpy.ndarray`, shape (2, N): Sample N observations - at the given x1, x2 coordinates. Coords are assumed to be - unnormalised. - - Parameters - ---------- - date : :class:`pandas.Timestamp` - Date for which to generate the task. - context_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`] - Sampling strategy for the context data, either a list of sampling - strategies for each context set, or a single strategy applied to - all context sets. Default is ``"all"``. - target_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`] - Sampling strategy for the target data, either a list of sampling - strategies for each target set, or a single strategy applied to all - target sets. Default is ``"all"``. - split_frac : float - The fraction of observations to use for the context set with the - "split" sampling strategy for linked context and target set pairs. - The remaining observations are used for the target set. Default is - 0.5. - datewise_deterministic : bool - Whether random sampling is datewise_deterministic based on the - date. Default is ``False``. - seed_override : Optional[int] - Override the seed for random sampling. This can be used to use the - same random sampling at different ``date``. Default is None. - - Returns - ------- - task : :class:`~.data.task.Task` - Task object containing the context and target data. + at the given x1, x2 coordinates. Coords are assumed to be + unnormalized. + + Args: + date (:class:`pandas.Timestamp`): + Date for which to generate the task. + context_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional): + Sampling strategy for the context data, either a list of + sampling strategies for each context set, or a single strategy + applied to all context sets. Default is ``"all"``. + target_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional): + Sampling strategy for the target data, either a list of + sampling strategies for each target set, or a single strategy + applied to all target sets. Default is ``"all"``. + split_frac (float, optional): + The fraction of observations to use for the context set with + the "split" sampling strategy for linked context and target set + pairs. The remaining observations are used for the target set. + Default is 0.5. + datewise_deterministic (bool, optional): + Whether random sampling is datewise deterministic based on the + date. Default is ``False``. + seed_override (Optional[int], optional): + Override the seed for random sampling. This can be used to use + the same random sampling at different ``date``. Default is + None. + + Returns: + :class:`~.data.task.Task`: + Task object containing the context and target data. """ def check_sampling_strat(sampling_strat, set): """ - Check the sampling strategy - - Ensure ``sampling_strat`` is either a single strategy (broadcast to - all sets) or a list of length equal to the number of sets. Convert - to a tuple of length equal to the number of sets and return. - - Parameters - ---------- - sampling_strat : ... - Sampling strategy to check. - set : ... - Context or target set to check. - - Returns - ------- - sampling_strat : tuple - Tuple of sampling strategies, one for each set. - - Raises - ------ - InvalidSamplingStrategyError - If the sampling strategy is invalid. - InvalidSamplingStrategyError - If the length of the sampling strategy does not match the - number of sets. - InvalidSamplingStrategyError - If the sampling strategy is not a valid type. - InvalidSamplingStrategyError - If the sampling strategy is a float but not in [0, 1]. - InvalidSamplingStrategyError - If the sampling strategy is an int but not positive. - InvalidSamplingStrategyError - If the sampling strategy is a numpy array but not of shape - (2, N). + Check the sampling strategy. + + Ensure ``sampling_strat`` is either a single strategy (broadcast + to all sets) or a list of length equal to the number of sets. + Convert to a tuple of length equal to the number of sets and + return. + + Args: + sampling_strat: + Sampling strategy to check. + set: + Context or target set to check. + + Returns: + tuple: + Tuple of sampling strategies, one for each set. + + Raises: + InvalidSamplingStrategyError: + - If the sampling strategy is invalid. + - If the length of the sampling strategy does not match the number of sets. + - If the sampling strategy is not a valid type. + - If the sampling strategy is a float but not in [0, 1]. + - If the sampling strategy is an int but not positive. + - If the sampling strategy is a numpy array but not of shape (2, N). """ if not isinstance(sampling_strat, (list, tuple)): sampling_strat = tuple([sampling_strat] * len(set)) @@ -942,7 +958,9 @@ def check_sampling_strat(sampling_strat, set): ) for strat in sampling_strat: - if not isinstance(strat, (str, int, np.integer, float, np.ndarray)): + if not isinstance( + strat, (str, int, np.integer, float, np.ndarray) + ): raise InvalidSamplingStrategyError( f"Unknown sampling strategy {strat} of type {type(strat)}" ) @@ -976,31 +994,30 @@ def sample_variable(var, sampling_strat, seed): Sample a variable by a given sampling strategy to get input and output data. - Parameters - ---------- - var : ... - Variable to sample. - sampling_strat : ... - Sampling strategy to use. - seed : ... - Seed for random sampling. - - Returns - ------- - ... : Tuple[X, Y] - Tuple of input and output data. - - Raises - ------ - ValueError - If the variable is of an unknown type. + Args: + var: + Variable to sample. + sampling_strat: + Sampling strategy to use. + seed: + Seed for random sampling. + + Returns: + Tuple[X, Y]: + Tuple of input and output data. + + Raises: + ValueError: + If the variable is of an unknown type. """ if isinstance(var, (xr.Dataset, xr.DataArray)): X, Y = self.sample_da(var, sampling_strat, seed) elif isinstance(var, (pd.DataFrame, pd.Series)): X, Y = self.sample_df(var, sampling_strat, seed) else: - raise ValueError(f"Unknown type {type(var)} for context set " f"{var}") + raise ValueError( + f"Unknown type {type(var)} for context set " f"{var}" + ) return X, Y # Check that the sampling strategies are valid @@ -1008,7 +1025,9 @@ def sample_variable(var, sampling_strat, seed): target_sampling = check_sampling_strat(target_sampling, self.target) # Check `split_frac if split_frac < 0 or split_frac > 1: - raise ValueError(f"split_frac must be between 0 and 1, got {split_frac}") + raise ValueError( + f"split_frac must be between 0 and 1, got {split_frac}" + ) if self.links is None: b1 = any( [ @@ -1092,8 +1111,12 @@ def sample_variable(var, sampling_strat, seed): # Perform the split sampling strategy for linked context and target sets at this point # while we have the full context and target data in scope - context_split_idxs = np.where(np.array(context_sampling) == "split")[0] - target_split_idxs = np.where(np.array(target_sampling) == "split")[0] + context_split_idxs = np.where( + np.array(context_sampling) == "split" + )[0] + target_split_idxs = np.where(np.array(target_sampling) == "split")[ + 0 + ] assert len(context_split_idxs) == len(target_split_idxs), ( f"Number of context sets with 'split' sampling strategy " f"({len(context_split_idxs)}) must match number of target sets " @@ -1146,8 +1169,12 @@ def sample_variable(var, sampling_strat, seed): # Perform the gapfill sampling strategy for linked context and target sets at this point # while we have the full context and target data in scope - context_gapfill_idxs = np.where(np.array(context_sampling) == "gapfill")[0] - target_gapfill_idxs = np.where(np.array(target_sampling) == "gapfill")[0] + context_gapfill_idxs = np.where( + np.array(context_sampling) == "gapfill" + )[0] + target_gapfill_idxs = np.where( + np.array(target_sampling) == "gapfill" + )[0] assert len(context_gapfill_idxs) == len(target_gapfill_idxs), ( f"Number of context sets with 'gapfill' sampling strategy " f"({len(context_gapfill_idxs)}) must match number of target sets " @@ -1177,9 +1204,13 @@ def sample_variable(var, sampling_strat, seed): # Keep trying until we get a target set with at least one target point keep_searching = True while keep_searching: - added_mask_date = rng.choice(self.context[context_idx].time) + added_mask_date = rng.choice( + self.context[context_idx].time + ) added_mask = ( - self.context[context_idx].sel(time=added_mask_date).isnull() + self.context[context_idx] + .sel(time=added_mask_date) + .isnull() ) curr_mask = context_var.isnull() @@ -1190,7 +1221,9 @@ def sample_variable(var, sampling_strat, seed): # when we could just slice the target values here target_mask = added_mask & ~curr_mask if isinstance(target_var, xr.Dataset): - keep_searching = np.all(target_mask.to_array().data == False) + keep_searching = np.all( + target_mask.to_array().data == False + ) else: keep_searching = np.all(target_mask.data == False) if keep_searching: @@ -1210,7 +1243,9 @@ def sample_variable(var, sampling_strat, seed): X_c, Y_c = sample_variable(var, sampling_strat, context_seed) task[f"X_c"].append(X_c) task[f"Y_c"].append(Y_c) - for j, (var, sampling_strat) in enumerate(zip(target_slices, target_sampling)): + for j, (var, sampling_strat) in enumerate( + zip(target_slices, target_sampling) + ): target_seed = seed + i + j if seed is not None else None X_t, Y_t = sample_variable(var, sampling_strat, target_seed) task[f"X_t"].append(X_t) @@ -1218,7 +1253,9 @@ def sample_variable(var, sampling_strat, seed): if self.aux_at_contexts is not None: # Add auxiliary variable sampled at context set as a new context variable - X_c_offgrid = [X_c for X_c in task["X_c"] if not isinstance(X_c, tuple)] + X_c_offgrid = [ + X_c for X_c in task["X_c"] if not isinstance(X_c, tuple) + ] if len(X_c_offgrid) == 0: # No offgrid context sets X_c_offrid_all = np.empty((2, 0), dtype=self.dtype) @@ -1226,7 +1263,8 @@ def sample_variable(var, sampling_strat, seed): X_c_offrid_all = np.concatenate(X_c_offgrid, axis=1) Y_c_aux = ( self.sample_offgrid_aux( - X_c_offrid_all, self.time_slice_variable(self.aux_at_contexts, date) + X_c_offrid_all, + self.time_slice_variable(self.aux_at_contexts, date), ), ) task["X_c"].append(X_c_offrid_all) @@ -1248,20 +1286,21 @@ def sample_variable(var, sampling_strat, seed): def __call__(self, date, *args, **kwargs): """ - Generate a task for a given date (or a list of task for an iterable of dates). + Generate a task for a given date (or a list of task for an iterable of + dates). - Parameters - ---------- - date : ... - Date for which to generate the task. + Args: + date: + Date for which to generate the task. - Returns - ------- - task: Task | List[Task] - Task object or list of task objects for each date containing the - context and target data. + Returns: + Task | List[Task]: + Task object or list of task objects for each date containing + the context and target data. """ - if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): + if isinstance( + date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex) + ): return [self.task_generation(d, *args, **kwargs) for d in date] else: return self.task_generation(date, *args, **kwargs) diff --git a/deepsensor/data/processor.py b/deepsensor/data/processor.py index a59aca83..6e08034a 100644 --- a/deepsensor/data/processor.py +++ b/deepsensor/data/processor.py @@ -32,25 +32,24 @@ def __init__( """ Initialise a DataProcessor object. - Parameters - ---------- - folder : str, optional - Folder to load normalisation params from. Defaults to None. - x1_name : str, optional - Name of first spatial coord (e.g. "lat"). Defaults to "x1". - x2_name : str, optional - Name of second spatial coord (e.g. "lon"). Defaults to "x2". - x1_map : tuple, optional - 2-tuple of raw x1 coords to linearly map to (0, 1), respectively. - Defaults to (0, 1) (i.e. no normalisation). - x2_map : tuple, optional - 2-tuple of raw x2 coords to linearly map to (0, 1), respectively. - Defaults to (0, 1) (i.e. no normalisation). - deepcopy : bool, optional - Whether to make a deepcopy of raw data to ensure it is not changed - by reference when normalising. Defaults to True. - verbose : bool, optional - Whether to print verbose output. Defaults to False. + Args: + folder (str, optional): + Folder to load normalisation params from. Defaults to None. + x1_name (str, optional): + Name of first spatial coord (e.g. "lat"). Defaults to "x1". + x2_name (str, optional): + Name of second spatial coord (e.g. "lon"). Defaults to "x2". + x1_map (tuple, optional): + 2-tuple of raw x1 coords to linearly map to (0, 1), + respectively. Defaults to (0, 1) (i.e. no normalisation). + x2_map (tuple, optional): + 2-tuple of raw x2 coords to linearly map to (0, 1), + respectively. Defaults to (0, 1) (i.e. no normalisation). + deepcopy (bool, optional): + Whether to make a deepcopy of raw data to ensure it is not + changed by reference when normalising. Defaults to True. + verbose (bool, optional): + Whether to print verbose output. Defaults to False. """ if folder is not None: fpath = os.path.join(folder, self.config_fname) @@ -77,13 +76,17 @@ def __init__( if (self.x1_none and not self.x2_none) or ( not self.x1_none and self.x2_none ): - raise ValueError("Must provide both x1_map and x2_map, or neither.") + raise ValueError( + "Must provide both x1_map and x2_map, or neither." + ) elif not self.x1_none and not self.x2_none: x1_map, x2_map = self._validate_coord_mappings(x1_map, x2_map) if "coords" not in self.config: # Add coordinate normalisation info to config - self.set_coord_params(time_name, x1_name, x1_map, x2_name, x2_map) + self.set_coord_params( + time_name, x1_name, x1_map, x2_name, x2_map + ) self.raw_spatial_coord_names = [ self.config["coords"][coord]["name"] for coord in ["x1", "x2"] @@ -127,7 +130,8 @@ def _validate_coord_mappings(self, x1_map, x2_map): def _validate_xr(self, data: Union[xr.DataArray, xr.Dataset]): def _validate_da(da: xr.DataArray): coord_names = [ - self.config["coords"][coord]["name"] for coord in ["time", "x1", "x2"] + self.config["coords"][coord]["name"] + for coord in ["time", "x1", "x2"] ] if coord_names[0] not in da.dims: # We don't have a time dimension. @@ -146,7 +150,8 @@ def _validate_da(da: xr.DataArray): def _validate_pandas(self, df: Union[pd.DataFrame, pd.Series]): coord_names = [ - self.config["coords"][coord]["name"] for coord in ["time", "x1", "x2"] + self.config["coords"][coord]["name"] + for coord in ["time", "x1", "x2"] ] if coord_names[0] not in df.index.names: @@ -172,15 +177,12 @@ def load_dask(cls, data: Union[xr.DataArray, xr.Dataset]): """ Load dask data into memory. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` - ... + Args: + data (:class:`xarray.DataArray` | :class:`xarray.Dataset`): + Description of the parameter. - Returns - ------- - ... - ... + Returns: + [Type and description of the returned value(s) needed.] """ if isinstance(data, xr.DataArray): data.load() @@ -188,26 +190,26 @@ def load_dask(cls, data: Union[xr.DataArray, xr.Dataset]): data.load() return data - def set_coord_params(self, time_name, x1_name, x1_map, x2_name, x2_map) -> None: + def set_coord_params( + self, time_name, x1_name, x1_map, x2_name, x2_map + ) -> None: """ Set coordinate normalisation params. - Parameters - ---------- - time_name : ... - ... - x1_name : ... - ... - x1_map : ... - ... - x2_name : ... - ... - x2_map : ... - ... - - Returns - ------- - None. + Args: + time_name: + [Type] Description needed. + x1_name: + [Type] Description needed. + x1_map: + [Type] Description needed. + x2_name: + [Type] Description needed. + x2_map: + [Type] Description needed. + + Returns: + None. """ self.config["coords"] = {} self.config["coords"]["time"] = {"name": time_name} @@ -222,17 +224,15 @@ def check_params_computed(self, var_ID, method) -> bool: """ Check if normalisation params computed for a given variable. - Parameters - ---------- - var_ID : ... - ... - method : ... - ... + Args: + var_ID: + [Type] Description needed. + method: + [Type] Description needed. - Returns - ------- - bool - Whether normalisation params are computed for a given variable. + Returns: + bool: + Whether normalisation params are computed for a given variable. """ if ( var_ID in self.config @@ -252,23 +252,20 @@ def get_config(self, var_ID, data, method=None): Get pre-computed normalisation params or compute them for variable ``var_ID``. - .. note: + .. note:: + TODO do we need to pass var_ID? Can we just use the name of data? - TODO do we need to pass var_ID? Can we just use name of data? + Args: + var_ID: + [Type] Description needed. + data: + [Type] Description needed. + method (optional): + [Type] Description needed. Defaults to None. - Parameters - ---------- - var_ID : ... - ... - data : ... - ... - method : ..., optional - ..., by default None. - - Returns - ------- - ... - ... + Returns: + [Type]: + Description of the returned value(s) needed. """ if method not in self.valid_methods: raise ValueError( @@ -303,39 +300,39 @@ def map_coord_array(self, coord_array: np.ndarray, unnorm: bool = False): """ Normalise or unnormalise a coordinate array. - Parameters - ---------- - coord_array : :class:`numpy:numpy.ndarray` - Array of shape ``(2, N)`` containing coords. - unnorm : bool, optional - Whether to unnormalise. Defaults to ``False``. + Args: + coord_array (:class:`numpy:numpy.ndarray`): + Array of shape ``(2, N)`` containing coords. + unnorm (bool, optional): + Whether to unnormalise. Defaults to ``False``. - Returns - ------- - ... - ... + Returns: + [Type]: + Description of the returned value(s) needed. """ - x1, x2 = self.map_x1_and_x2(coord_array[0], coord_array[1], unnorm=unnorm) + x1, x2 = self.map_x1_and_x2( + coord_array[0], coord_array[1], unnorm=unnorm + ) new_coords = np.stack([x1, x2], axis=0) return new_coords - def map_x1_and_x2(self, x1: np.ndarray, x2: np.ndarray, unnorm: bool = False): + def map_x1_and_x2( + self, x1: np.ndarray, x2: np.ndarray, unnorm: bool = False + ): """ - Normalise or unnormalise spatial coords in a array. - - Parameters - ---------- - x1 : :class:`numpy:numpy.ndarray` - Array of shape ``(N_x1,)`` containing spatial coords of x1. - x2 : :class:`numpy:numpy.ndarray` - Array of shape ``(N_x2,)`` containing spatial coords of x2. - unnorm : bool, optional - Whether to unnormalise. Defaults to ``False``. - - Returns - ------- - Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`] - Normalised or unnormalised spatial coords of x1 and x2. + Normalise or unnormalise spatial coords in an array. + + Args: + x1 (:class:`numpy:numpy.ndarray`): + Array of shape ``(N_x1,)`` containing spatial coords of x1. + x2 (:class:`numpy:numpy.ndarray`): + Array of shape ``(N_x2,)`` containing spatial coords of x2. + unnorm (bool, optional): + Whether to unnormalise. Defaults to ``False``. + + Returns: + Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]: + Normalised or unnormalised spatial coords of x1 and x2. """ x11, x12 = self.config["coords"]["x1"]["map"] x21, x22 = self.config["coords"]["x2"]["map"] @@ -357,17 +354,15 @@ def map_coords( """ Normalise spatial coords in a pandas or xarray object. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series` - ... - unnorm : bool, optional - ... + Args: + data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, or :class:`pandas.Series`): + [Description Needed] + unnorm (bool, optional): + [Description Needed]. Defaults to [Default Value]. - Returns - ------- - ... - ... + Returns: + [Type]: + [Description Needed] """ if isinstance(data, (pd.DataFrame, pd.Series)): # Reset index to get coords as columns @@ -438,7 +433,9 @@ def map_coords( # Rename all dimensions. rename = { - old: new for old, new in zip(old_coord_IDs, new_coord_IDs) if old != new + old: new + for old, new in zip(old_coord_IDs, new_coord_IDs) + if old != new } data = data.rename(rename) @@ -452,7 +449,9 @@ def map_coords( def map_array( self, - data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, np.ndarray], + data: Union[ + xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, np.ndarray + ], var_ID: str, method: Optional[str] = None, unnorm: bool = False, @@ -462,27 +461,29 @@ def map_array( Normalise or unnormalise the data values in an xarray, pandas, or numpy object. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`numpy:numpy.ndarray` - ... - var_ID : str - ... - method : str, optional - ..., by default None. - unnorm : bool, optional - ..., by default False. - add_offset : bool, optional - ..., by default True. - - Returns - ------- - ... - ... + Args: + data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, :class:`pandas.Series`, or :class:`numpy:numpy.ndarray`): + [Description Needed] + var_ID (str): + [Description Needed] + method (str, optional): + [Description Needed]. Defaults to None. + unnorm (bool, optional): + [Description Needed]. Defaults to False. + add_offset (bool, optional): + [Description Needed]. Defaults to True. + + Returns: + [Type]: + [Description Needed] """ if not unnorm and method is None: raise ValueError("Must provide `method` if normalising data.") - elif unnorm and method is not None and self.config[var_ID]["method"] != method: + elif ( + unnorm + and method is not None + and self.config[var_ID]["method"] != method + ): # User has provided a different method to the one used for normalising raise ValueError( f"Variable '{var_ID}' has been normalised with method '{self.config[var_ID]['method']}', " @@ -539,21 +540,19 @@ def map( Normalise or unnormalise the data values and coords in an xarray or pandas object. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series` - ... - method : str, optional - ..., by default ``None``. - add_offset : bool, optional - ..., by default ``True``. - unnorm : bool, optional - ..., by default ``False``. - - Returns - ------- - ... - ... + Args: + data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, or :class:`pandas.Series`): + [Description Needed] + method (str, optional): + [Description Needed]. Defaults to None. + add_offset (bool, optional): + [Description Needed]. Defaults to True. + unnorm (bool, optional): + [Description Needed]. Defaults to False. + + Returns: + [Type]: + [Description Needed] """ if self.deepcopy: data = deepcopy(data) @@ -605,27 +604,30 @@ def __call__( """ Normalise data. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`] - Data to normalise. - method : str, optional - Normalisation method. Defaults to "mean_std". Options: - - "mean_std": Normalise to mean=0 and std=1 + Args: + data (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]): + Data to be normalised. Can be an xarray DataArray, xarray + Dataset, pandas DataFrame, or a list containing objects of + these types. + method (str, optional): Normalisation method. Options include: + - "mean_std": Normalise to mean=0 and std=1 (default) - "min_max": Normalise to min=-1 and max=1 - Returns - ------- - :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`] - Normalised data. + Returns: + :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]: + Normalised data. Type or structure depends on the input. """ if isinstance(data, list): return [ - self.map(d, method, unnorm=False, assert_computed=assert_computed) + self.map( + d, method, unnorm=False, assert_computed=assert_computed + ) for d in data ] else: - return self.map(data, method, unnorm=False, assert_computed=assert_computed) + return self.map( + data, method, unnorm=False, assert_computed=assert_computed + ) def unnormalise( self, @@ -660,12 +662,16 @@ def unnormalise( Unnormalised data. """ if isinstance(data, list): - return [self.map(d, add_offset=add_offset, unnorm=True) for d in data] + return [ + self.map(d, add_offset=add_offset, unnorm=True) for d in data + ] else: return self.map(data, add_offset=add_offset, unnorm=True) -def xarray_to_coord_array_normalised(da: Union[xr.Dataset, xr.DataArray]) -> np.ndarray: +def xarray_to_coord_array_normalised( + da: Union[xr.Dataset, xr.DataArray] +) -> np.ndarray: """ Convert xarray to normalised coordinate array.