diff --git a/deepsensor/active_learning/acquisition_fns.py b/deepsensor/active_learning/acquisition_fns.py index 9bd71319..b069aabb 100644 --- a/deepsensor/active_learning/acquisition_fns.py +++ b/deepsensor/active_learning/acquisition_fns.py @@ -20,15 +20,14 @@ def __init__( target_set_idx: int = 0, ): """ - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... - context_set_idx : int - Index of context set to add new observations to when computing the - acquisition function. - target_set_idx : int - Index of target set to compute acquisition function for. + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] + context_set_idx (int): + Index of context set to add new observations to when computing + the acquisition function. + target_set_idx (int): + Index of target set to compute acquisition function for. """ self.model = model self.context_set_idx = context_set_idx @@ -39,21 +38,18 @@ def __call__(self, task: Task) -> np.ndarray: """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - Task object containing context and target sets. + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - :class:`numpy:numpy.ndarray` - Acquisition function value/s. Shape (). + Returns: + :class:`numpy:numpy.ndarray`: + Acquisition function value/s. Shape (). - Raises - ------ - NotImplementedError - Because this is an abstract method, it must be implemented by the - subclass. + Raises: + NotImplementedError: + Because this is an abstract method, it must be implemented by + the subclass. """ raise NotImplementedError @@ -75,23 +71,20 @@ def __call__(self, task: Task, X_s: np.ndarray) -> np.ndarray: """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - Task object containing context and target sets. - X_s : :class:`numpy:numpy.ndarray` - Search points. Shape (2, N_search). + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. + X_s (:class:`numpy:numpy.ndarray`): + Search points. Shape (2, N_search). - Returns - ------- - :class:`numpy:numpy.ndarray` - Should return acquisition function value/s. Shape (N_search,). + Returns: + :class:`numpy:numpy.ndarray`: + Should return acquisition function value/s. Shape (N_search,). - Raises - ------ - NotImplementedError - Because this is an abstract method, it must be implemented by the - subclass. + Raises: + NotImplementedError: + Because this is an abstract method, it must be implemented by + the subclass. """ raise NotImplementedError @@ -103,10 +96,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -115,15 +107,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return np.mean(self.model.stddev(task)[self.target_set_idx]) @@ -135,10 +125,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... 
+ Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -146,16 +135,14 @@ def __init__(self, model: ProbabilisticModel): def __call__(self, task: Task): """ ... + + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return np.mean(self.model.variance(task)[self.target_set_idx]) @@ -167,10 +154,9 @@ def __init__(self, *args, p: int = 1, **kwargs): """ ... - Parameters - ---------- - p : int, optional - ..., by default 1 + Args: + p (int, optional): + [Description of the parameter p.], default is 1 """ super().__init__(*args, **kwargs) self.p = p @@ -180,15 +166,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return np.linalg.norm( self.model.stddev(task)[self.target_set_idx].ravel(), ord=self.p @@ -202,10 +186,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -214,15 +197,13 @@ def __call__(self, task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ marginal_entropy = self.model.mean_marginal_entropy(task) return marginal_entropy @@ -235,10 +216,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -247,15 +227,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return self.model.joint_entropy(task) @@ -267,10 +245,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -279,15 +256,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ pred = self.model.mean(task) if isinstance(pred, list): @@ -303,10 +278,9 @@ def __init__(self, model: ProbabilisticModel): """ ... 
- Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -315,15 +289,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ pred = self.model.mean(task) if isinstance(pred, list): @@ -339,10 +311,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -351,15 +322,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ pred = self.model.mean(task) if isinstance(pred, list): @@ -375,10 +344,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "min" @@ -387,15 +355,13 @@ def __call__(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ return -self.model.logpdf(task) @@ -407,10 +373,9 @@ def __init__(self, seed: int = 42): """ ... - Parameters - ---------- - seed : int, optional - Random seed, by default 42. + Args: + seed (int, optional): + Random seed, defaults to 42. """ self.rng = np.random.default_rng(seed) self.min_or_max = "max" @@ -419,17 +384,15 @@ def __call__(self, task: Task, X_s: np.ndarray): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - X_s : :class:`numpy:numpy.ndarray` - ... + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] + X_s (:class:`numpy:numpy.ndarray`): + [Description of the X_s parameter.] - Returns - ------- - float - A random acquisition function value. + Returns: + float: + A random acquisition function value. """ return self.rng.random(X_s.shape[1]) @@ -438,30 +401,21 @@ class ContextDist(AcquisitionFunctionParallel): """Distance to closest context point.""" def __init__(self): - """ - ... - - Parameters - ---------- - ... - """ self.min_or_max = "max" def __call__(self, task: Task, X_s: np.ndarray): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - X_s : :class:`numpy:numpy.ndarray` - ... + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] + X_s (:class:`numpy:numpy.ndarray`): + [Description of the X_s parameter.] - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] 
""" X_c = task["X_c"][self.context_set_idx] @@ -489,10 +443,9 @@ def __init__(self, model: ProbabilisticModel): """ ... - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "max" @@ -501,17 +454,15 @@ def __call__(self, task: Task, X_s: np.ndarray): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - X_s : :class:`numpy:numpy.ndarray` - ... + Args: + task (:class:`~.data.task.Task`): + [Description of the task parameter.] + X_s (:class:`numpy:numpy.ndarray`): + [Description of the X_s parameter.] - Returns - ------- - ... - ... + Returns: + [Type of the return value]: + [Description of the return value.] """ # Set the target points to the search points task = copy.deepcopy(task) @@ -532,10 +483,9 @@ class ExpectedImprovement(AcquisitionFunctionParallel): def __init__(self, model: ProbabilisticModel): """ - Parameters - ---------- - model : :class:`~.model.model.ProbabilisticModel` - ... + Args: + model (:class:`~.model.model.ProbabilisticModel`): + [Description of the model parameter.] """ super().__init__(model) self.min_or_max = "max" @@ -546,17 +496,15 @@ def __call__( X_s: np.ndarray, ) -> np.ndarray: """ - Parameters - ---------- - task : :class:`~.data.task.Task` - Task object containing context and target sets. - X_s : :class:`numpy:numpy.ndarray` - Search points. Shape (2, N_search). + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. + X_s (:class:`numpy:numpy.ndarray`): + Search points. Shape (2, N_search). - Returns - ------- - :class:`numpy:numpy.ndarray` - Acquisition function value/s. Shape (N_search,). + Returns: + :class:`numpy:numpy.ndarray`: + Acquisition function value/s. Shape (N_search,). """ # Set the target points to the search points task = copy.deepcopy(task) @@ -578,6 +526,8 @@ def __call__( # Compute the expected improvement Z = (mean - best_target_value) / stddev - ei = stddev * (mean - best_target_value) * norm.cdf(Z) + stddev * norm.pdf(Z) + ei = stddev * (mean - best_target_value) * norm.cdf( + Z + ) + stddev * norm.pdf(Z) return ei diff --git a/deepsensor/active_learning/algorithms.py b/deepsensor/active_learning/algorithms.py index 51f474d3..18ec3534 100644 --- a/deepsensor/active_learning/algorithms.py +++ b/deepsensor/active_learning/algorithms.py @@ -33,8 +33,12 @@ class GreedyAlgorithm: def __init__( self, model: DeepSensorModel, - X_s: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index], - X_t: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index], + X_s: Union[ + xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index + ], + X_t: Union[ + xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index + ], X_s_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, N_new_context: int = 1, @@ -53,47 +57,45 @@ def __init__( """ ... - Parameters - ---------- - model : :class:`~.model.model.DeepSensorModel` - Trained model to use for proposing new context points. - X_s : :class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` - Search coordinates. - X_t : :class:`xarray.Dataset` | :class:`xarray.DataArray` - Target coordinates. - X_s_mask : :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional - Mask for search coordinates. 
If provided, only points where mask - is True will be considered. Defaults to None. - X_t_mask : :class:`xarray.Dataset` | :class:`xarray.DataArray`, optional - ..., by default None. - N_new_context : int, optional - ..., by default 1. - X_normalised : bool, optional - ..., by default False. - model_infill_method : str, optional - ..., by default "mean". - query_infill : :class:`xarray.DataArray`, optional - ..., by default None. - proposed_infill : :class:`xarray.DataArray`, optional - ..., by default None. - context_set_idx : int, optional - ..., by default 0. - target_set_idx : int, optional - ..., by default 0. - progress_bar : bool, optional - ..., by default False. - min_or_max : str, optional - ..., by default "min". - task_loader : :class:`~.data.loader.TaskLoader`, optional - ..., by default None. - verbose : bool, optional - ..., by default False. - - Raises - ------ - ValueError - If the ``model`` passed does not inherit from - :class:`~.model.model.DeepSensorModel`. + Args: + model (:class:`~.model.model.DeepSensorModel`): + Trained model to use for proposing new context points. + X_s (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index`): + Search coordinates. + X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray`): + Target coordinates. + X_s_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional): + Mask for search coordinates. If provided, only points where mask + is True will be considered. Defaults to None. + X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional): + [Description of the X_t_mask parameter.], defaults to None. + N_new_context (int, optional): + [Description of the N_new_context parameter.], defaults to 1. + X_normalised (bool, optional): + [Description of the X_normalised parameter.], defaults to False. + model_infill_method (str, optional): + [Description of the model_infill_method parameter.], defaults to "mean". + query_infill (:class:`xarray.DataArray`, optional): + [Description of the query_infill parameter.], defaults to None. + proposed_infill (:class:`xarray.DataArray`, optional): + [Description of the proposed_infill parameter.], defaults to None. + context_set_idx (int, optional): + [Description of the context_set_idx parameter.], defaults to 0. + target_set_idx (int, optional): + [Description of the target_set_idx parameter.], defaults to 0. + progress_bar (bool, optional): + [Description of the progress_bar parameter.], defaults to False. + min_or_max (str, optional): + [Description of the min_or_max parameter.], defaults to "min". + task_loader (:class:`~.data.loader.TaskLoader`, optional): + [Description of the task_loader parameter.], defaults to None. + verbose (bool, optional): + [Description of the verbose parameter.], defaults to False. + + Raises: + ValueError: + If the ``model`` passed does not inherit from + :class:`~.model.model.DeepSensorModel`. 
""" if not isinstance(model, DeepSensorModel): raise ValueError( @@ -136,11 +138,15 @@ def __init__( self.X_t_mask = process_X_mask_for_X(self.X_t_mask, self.X_t) # Interpolate overridden infill datasets at search points if necessary - if query_infill is not None and not da1_da2_same_grid(query_infill, X_s): + if query_infill is not None and not da1_da2_same_grid( + query_infill, X_s + ): if verbose: print("query_infill not on search grid, interpolating.") query_infill = interp_da1_to_da2(query_infill, self.X_s) - if proposed_infill is not None and not da1_da2_same_grid(proposed_infill, X_s): + if proposed_infill is not None and not da1_da2_same_grid( + proposed_infill, X_s + ): if verbose: print("proposed_infill not on search grid, interpolating.") proposed_infill = interp_da1_to_da2(proposed_infill, self.X_s) @@ -153,7 +159,9 @@ def __init__( self.X_t_arr = xarray_to_coord_array_normalised(X_t) if self.X_t_mask is not None: # Remove points that lie outside the mask - self.X_t_arr = mask_coord_array_normalised(self.X_t_arr, self.X_t_mask) + self.X_t_arr = mask_coord_array_normalised( + self.X_t_arr, self.X_t_mask + ) elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index)): # Targets off-grid self.X_t_arr = X_t.reset_index()[["x1", "x2"]].values.T @@ -200,7 +208,9 @@ def _get_times_from_tasks(self): def _model_infill_at_search_points( self, - X_s: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index], + X_s: Union[ + xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index + ], ): """ Computes and sets the model infill y-values over whole search grid @@ -345,7 +355,9 @@ def _search(self, acquisition_fn: AcquisitionFunction): and self.task_loader.aux_at_contexts ): # Add auxiliary variable sampled at context set as a new context variable - X_c = task_with_new["X_c"][self.task_loader.aux_at_contexts[0]] + X_c = task_with_new["X_c"][ + self.task_loader.aux_at_contexts[0] + ] Y_c_aux = self.task_loader.sample_offgrid_aux( X_c, self.task_loader.aux_at_contexts[1] ) @@ -426,32 +438,25 @@ def __call__( """ Iteratively... docstring TODO - Returns a tensor of proposed new sensor locations (in greedy - iteration/priority order) and their corresponding list of indexes in - the search space. - - Parameters - ---------- - acquisition_fn: :class:`~.active_learning.acquisition_fns.AcquisitionFunction` - ... - tasks: List[:class:`~.data.task.Task`] | :class:`~.data.task.Task` - ... - - Returns - ------- - X_new_df, acquisition_fn_ds: Tuple[:class:`pandas.DataFrame`, :class:`xarray.Dataset`] - ... - - Raises - ------ - ValueError - If ``acquisition_fn`` is an - :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle` - and ``task_loader`` is None. - ValueError - If ``min_or_max`` is not ``"min"`` or ``"max"``. - ValueError - If ``Y_t_aux`` is in ``tasks`` but ``task_loader`` is None. + Args: + acquisition_fn (:class:`~.active_learning.acquisition_fns.AcquisitionFunction`): + [Description of the acquisition_fn parameter.] + tasks (List[:class:`~.data.task.Task`] | :class:`~.data.task.Task`): + [Description of the tasks parameter.] + + Returns: + Tuple[:class:`pandas.DataFrame`, :class:`xarray.Dataset`]: + X_new_df, acquisition_fn_ds - [Description of the return values.] + + Raises: + ValueError: + If ``acquisition_fn`` is an + :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle` + and ``task_loader`` is None. + ValueError: + If ``min_or_max`` is not ``"min"`` or ``"max"``. + ValueError: + If ``Y_t_aux`` is in ``tasks`` but ``task_loader`` is None. 
""" if ( isinstance(acquisition_fn, AcquisitionFunctionOracle) @@ -465,7 +470,8 @@ def __call__( self.min_or_max = acquisition_fn.min_or_max if self.min_or_max not in ["min", "max"]: raise ValueError( - f"min_or_max must be either 'min' or 'max', got " f"{self.min_or_max}." + f"min_or_max must be either 'min' or 'max', got " + f"{self.min_or_max}." ) if diff and isinstance(acquisition_fn, AcquisitionFunctionParallel): @@ -496,7 +502,10 @@ def __call__( "Model expects Y_t_aux data but a TaskLoader isn't " "provided to GreedyAlgorithm." ) - if self.task_loader is not None and self.task_loader.aux_at_target_dims > 0: + if ( + self.task_loader is not None + and self.task_loader.aux_at_target_dims > 0 + ): tasks[i]["Y_t_aux"] = self.task_loader.sample_offgrid_aux( self.X_t_arr, self.task_loader.aux_at_targets ) @@ -532,7 +541,9 @@ def __call__( if self.model_infill_method == "sample": total_iterations *= self.n_samples - with tqdm(total=total_iterations, disable=not self.progress_bar) as self.pbar: + with tqdm( + total=total_iterations, disable=not self.progress_bar + ) as self.pbar: for iteration in range(self.N_new_context): self.iteration = iteration x_new = self._single_greedy_iteration(acquisition_fn) diff --git a/deepsensor/data/loader.py b/deepsensor/data/loader.py index 08c66456..71a362a0 100644 --- a/deepsensor/data/loader.py +++ b/deepsensor/data/loader.py @@ -33,7 +33,9 @@ def __init__( str, List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]], ] = None, - aux_at_contexts: Optional[Tuple[int, Union[xr.DataArray, xr.Dataset]]] = None, + aux_at_contexts: Optional[ + Tuple[int, Union[xr.DataArray, xr.Dataset]] + ] = None, aux_at_targets: Optional[ Union[ xr.DataArray, @@ -56,56 +58,55 @@ def __init__( - Either all data is passed as paths, or all data is passed as loaded data (else ValueError) - If all data passed as paths, the TaskLoader can be saved with the `save` method (using config) - Parameters - ---------- - task_loader_ID : ... - If loading a TaskLoader from a config file, this is the folder the - TaskLoader was saved in (using `.save`). If this argument is passed, all other - arguments are ignored. - context : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`] - Context data. Can be a single :class:`xarray.DataArray`, - :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a - list/tuple of these. - target : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`] - Target data. Can be a single :class:`xarray.DataArray`, - :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a - list/tuple of these. - aux_at_contexts : Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional - Auxiliary data at context locations. Tuple of two elements, where - the first element is the index of the context set for which the - auxiliary data will be sampled at, and the second element is the - auxiliary data, which can be a single :class:`xarray.DataArray` or - :class:`xarray.Dataset`. Default: None. - aux_at_targets : :class:`xarray.DataArray` | :class:`xarray.Dataset`, optional - Auxiliary data at target locations. Can be a single - :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default: - None. - links : Tuple[int, int] | List[Tuple[int, int]], optional - Specifies links between context and target data. 
Each link is a - tuple of two integers, where the first integer is the index of the - context data and the second integer is the index of the target - data. Can be a single tuple in the case of a single link. If None, - no links are specified. Default: None. - context_delta_t : int | List[int], optional - Time difference between context data and t=0 (task init time). Can - be a single int (same for all context data) or a list/tuple of - ints. Default is 0. - target_delta_t : int | List[int], optional - Time difference between target data and t=0 (task init time). Can - be a single int (same for all target data) or a list/tuple of ints. - Default is 0. - time_freq : str, optional - Time frequency of the data. Default: ``'D'`` (daily). - xarray_interp_method : str, optional - Interpolation method to use when interpolating - :class:`xarray.DataArray`. Default is ``'linear'``. - discrete_xarray_sampling : bool, optional - When randomly sampling xarray variables, whether to sample at - discrete points defined at grid cell centres, or at continuous - points within the grid. Default is ``False``. - dtype : object, optional - Data type of the data. Used to cast the data to the specified - dtype. Default: ``'float32'``. + Args: + task_loader_ID: + If loading a TaskLoader from a config file, this is the folder the + TaskLoader was saved in (using `.save`). If this argument is passed, all other + arguments are ignored. + context (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`]): + Context data. Can be a single :class:`xarray.DataArray`, + :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a + list/tuple of these. + target (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`]): + Target data. Can be a single :class:`xarray.DataArray`, + :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a + list/tuple of these. + aux_at_contexts (Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional): + Auxiliary data at context locations. Tuple of two elements, where + the first element is the index of the context set for which the + auxiliary data will be sampled at, and the second element is the + auxiliary data, which can be a single :class:`xarray.DataArray` or + :class:`xarray.Dataset`. Default: None. + aux_at_targets (:class:`xarray.DataArray` | :class:`xarray.Dataset`, optional): + Auxiliary data at target locations. Can be a single + :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default: + None. + links (Tuple[int, int] | List[Tuple[int, int]], optional): + Specifies links between context and target data. Each link is a + tuple of two integers, where the first integer is the index of the + context data and the second integer is the index of the target + data. Can be a single tuple in the case of a single link. If None, + no links are specified. Default: None. + context_delta_t (int | List[int], optional): + Time difference between context data and t=0 (task init time). Can + be a single int (same for all context data) or a list/tuple of + ints. Default is 0. + target_delta_t (int | List[int], optional): + Time difference between target data and t=0 (task init time). Can + be a single int (same for all target data) or a list/tuple of ints. + Default is 0. + time_freq (str, optional): + Time frequency of the data. Default: ``'D'`` (daily). 
+ xarray_interp_method (str, optional): + Interpolation method to use when interpolating + :class:`xarray.DataArray`. Default is ``'linear'``. + discrete_xarray_sampling (bool, optional): + When randomly sampling xarray variables, whether to sample at + discrete points defined at grid cell centres, or at continuous + points within the grid. Default is ``False``. + dtype (object, optional): + Data type of the data. Used to cast the data to the specified + dtype. Default: ``'float32'``. """ if task_loader_ID is not None: self.task_loader_ID = task_loader_ID @@ -125,7 +126,9 @@ def __init__( self.target_delta_t = self.config["target_delta_t"] self.time_freq = self.config["time_freq"] self.xarray_interp_method = self.config["xarray_interp_method"] - self.discrete_xarray_sampling = self.config["discrete_xarray_sampling"] + self.discrete_xarray_sampling = self.config[ + "discrete_xarray_sampling" + ] self.dtype = self.config["dtype"] else: self.context = context @@ -278,7 +281,9 @@ def _load_pandas_or_xarray(path): def _load_data(data): if isinstance(data, (tuple, list)): - data = tuple([_load_pandas_or_xarray(data_i) for data_i in data]) + data = tuple( + [_load_pandas_or_xarray(data_i) for data_i in data] + ) else: data = _load_pandas_or_xarray(data) return data @@ -312,18 +317,17 @@ def _cast_to_dtype( """ Cast context and target data to the default dtype. - Parameters - ---------- - var : ... - ... + .. + TODO unit test this by passing in a variety of data types and + checking that they are cast correctly. - TODO unit test this by passing in a variety of data types and checking that they are - cast correctly. + Args: + var : ... + ... - Returns - ------- - context : tuple. Tuple of context data with specified dtype. - target : tuple. Tuple of target data with specified dtype. + Returns: + tuple: Tuple of context data with specified dtype. + tuple: Tuple of target data with specified dtype. """ def cast_to_dtype(var): @@ -340,7 +344,9 @@ def cast_to_dtype(var): # Note: Numeric pandas indexes are always cast to float64, so we have to cast # x1/x2 coord dtypes during task sampling else: - raise ValueError(f"Unknown type {type(var)} for context set {var}") + raise ValueError( + f"Unknown type {type(var)} for context set {var}" + ) return var if var is None: @@ -354,11 +360,13 @@ def cast_to_dtype(var): def load_dask(self) -> None: """ - Load any ``dask`` data into memory. + Load any `dask` data into memory. - Returns - ------- - None. + This function triggers the computation and loading of any data that + is represented as dask arrays or datasets into memory. + + Returns: + None """ def load(datasets): @@ -381,17 +389,14 @@ def count_context_and_target_data_dims(self): """ Count the number of data dimensions in the context and target data. - Returns - ------- - context_dims : tuple. Tuple of data dimensions in the context data. - target_dims : tuple. Tuple of data dimensions in the target data. - - Raises - ------ - ValueError - If the context/target data is not a tuple/list of - :class:`xarray.DataArray`, :class:`xarray.Dataset` or - :class:`pandas.DataFrame`. + Returns: + tuple: context_dims, Tuple of data dimensions in the context data. + tuple: target_dims, Tuple of data dimensions in the target data. + + Raises: + ValueError: If the context/target data is not a tuple/list of + :class:`xarray.DataArray`, :class:`xarray.Dataset` or + :class:`pandas.DataFrame`. 
""" def count_data_dims_of_tuple_of_sets(datasets): @@ -406,22 +411,28 @@ def count_data_dims_of_tuple_of_sets(datasets): elif isinstance(var, xr.DataArray): dim = 1 # Single data variable elif isinstance(var, pd.DataFrame): - dim = len(var.columns) # Assumes all columns are data variables + dim = len( + var.columns + ) # Assumes all columns are data variables elif isinstance(var, pd.Series): dim = 1 # Single data variable else: - raise ValueError(f"Unknown type {type(var)} for context set {var}") + raise ValueError( + f"Unknown type {type(var)} for context set {var}" + ) dims.append(dim) return dims context_dims = count_data_dims_of_tuple_of_sets(self.context) target_dims = count_data_dims_of_tuple_of_sets(self.target) if self.aux_at_contexts is not None: - context_dims += count_data_dims_of_tuple_of_sets(self.aux_at_contexts) + context_dims += count_data_dims_of_tuple_of_sets( + self.aux_at_contexts + ) if self.aux_at_targets is not None: - aux_at_target_dims = count_data_dims_of_tuple_of_sets(self.aux_at_targets)[ - 0 - ] + aux_at_target_dims = count_data_dims_of_tuple_of_sets( + self.aux_at_targets + )[0] else: aux_at_target_dims = 0 @@ -431,17 +442,14 @@ def infer_context_and_target_var_IDs(self): """ Infer the variable IDs of the context and target data. - Returns - ------- - context_var_IDs : tuple. Tuple of variable IDs in the context data. - target_var_IDs : tuple. Tuple of variable IDs in the target data. - - Raises - ------ - ValueError - If the context/target data is not a tuple/list of - :class:`xarray.DataArray`, :class:`xarray.Dataset` or - :class:`pandas.DataFrame`. + Returns: + tuple: context_var_IDs, Tuple of variable IDs in the context data. + tuple: target_var_IDs, Tuple of variable IDs in the target data. + + Raises: + ValueError: If the context/target data is not a tuple/list of + :class:`xarray.DataArray`, :class:`xarray.Dataset` or + :class:`pandas.DataFrame`. 
""" def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None): @@ -455,13 +463,17 @@ def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None): if isinstance(var, xr.DataArray): var_ID = (var.name,) # Single data variable elif isinstance(var, xr.Dataset): - var_ID = tuple(var.data_vars.keys()) # Multiple data variables + var_ID = tuple( + var.data_vars.keys() + ) # Multiple data variables elif isinstance(var, pd.DataFrame): var_ID = tuple(var.columns) elif isinstance(var, pd.Series): var_ID = (var.name,) else: - raise ValueError(f"Unknown type {type(var)} for context set {var}") + raise ValueError( + f"Unknown type {type(var)} for context set {var}" + ) if delta_ts is not None: # Add delta_t to the variable ID @@ -485,15 +497,17 @@ def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None): ) if self.aux_at_contexts is not None: - context_var_IDs += infer_var_IDs_of_tuple_of_sets(self.aux_at_contexts) + context_var_IDs += infer_var_IDs_of_tuple_of_sets( + self.aux_at_contexts + ) context_var_IDs_and_delta_t += infer_var_IDs_of_tuple_of_sets( self.aux_at_contexts, [0] ) if self.aux_at_targets is not None: - aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.aux_at_targets)[ - 0 - ] + aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets( + self.aux_at_targets + )[0] else: aux_at_target_var_IDs = None @@ -505,28 +519,27 @@ def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None): aux_at_target_var_IDs, ) - def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]): + def _check_links( + self, links: Union[Tuple[int, int], List[Tuple[int, int]]] + ): """ Check that the context-target links are valid. - Parameters - ---------- - links : Tuple[int, int] | List[Tuple[int, int]] - Specifies links between context and target data. Each link is a - tuple of two integers, where the first integer is the index of the - context data and the second integer is the index of the target - data. Can be a single tuple in the case of a single link. If None, - no links are specified. Default: None. - - Returns - ------- - links : Tuple[int, int] | List[Tuple[int, int]] - The input links, if valid. - - Raises - ------ - ValueError - If the links are not valid. + Args: + links (Tuple[int, int] | List[Tuple[int, int]]): + Specifies links between context and target data. Each link is a + tuple of two integers, where the first integer is the index of + the context data and the second integer is the index of the + target data. Can be a single tuple in the case of a single + link. If None, no links are specified. Default: None. + + Returns: + Tuple[int, int] | List[Tuple[int, int]] + The input links, if valid. + + Raises: + ValueError + If the links are not valid. """ if links is None: return None @@ -534,7 +547,9 @@ def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]): assert isinstance( links, list ), f"Links must be a list of length-2 tuples, but got {type(links)}" - assert len(links) > 0, "If links is not None, it must be a non-empty list" + assert ( + len(links) > 0 + ), "If links is not None, it must be a non-empty list" assert all( isinstance(link, tuple) for link in links ), f"Links must be a list of tuples, but got {[type(link) for link in links]}" @@ -592,28 +607,24 @@ def sample_da( """ Sample a DataArray according to a given strategy. - Parameters - ---------- - da : :class:`xarray.DataArray` | :class:`xarray.Dataset` - DataArray to sample, assumed to be sliced for the task already. 
- sampling_strat : str | int | float | :class:`numpy:numpy.ndarray` - Sampling strategy, either "all" or an integer for random grid cell - sampling. - seed : int, optional - Seed for random sampling. Default: None. - - Returns - ------- - Data : Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`] - Tuple of sampled target data and sampled context data. - - Raises - ------ - InvalidSamplingStrategyError - If the sampling strategy is not valid. - InvalidSamplingStrategyError - If a numpy coordinate array is passed to sample an xarray object, - but the coordinates are out of bounds. + Args: + da (:class:`xarray.DataArray` | :class:`xarray.Dataset`): + DataArray to sample, assumed to be sliced for the task already. + sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`): + Sampling strategy, either "all" or an integer for random grid + cell sampling. + seed (int, optional): + Seed for random sampling. Default is None. + + Returns: + Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]: + Tuple of sampled target data and sampled context data. + + Raises: + InvalidSamplingStrategyError: + If the sampling strategy is not valid or if a numpy coordinate + array is passed to sample an xarray object, but the coordinates + are out of bounds. """ da = da.load() # Converts dask -> numpy if not already loaded if isinstance(da, xr.Dataset): @@ -639,9 +650,15 @@ def sample_da( dim = 1 # Single data variable Y_c = np.zeros((dim, 0), dtype=self.dtype) return X_c, Y_c - x1 = rng.uniform(da.coords["x1"].min(), da.coords["x1"].max(), N) - x2 = rng.uniform(da.coords["x2"].min(), da.coords["x2"].max(), N) - Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest") + x1 = rng.uniform( + da.coords["x1"].min(), da.coords["x1"].max(), N + ) + x2 = rng.uniform( + da.coords["x2"].min(), da.coords["x2"].max(), N + ) + Y_c = da.sel( + x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest" + ) Y_c = np.array(Y_c, dtype=self.dtype) X_c = np.array([x1, x2], dtype=self.dtype) if Y_c.ndim == 1: @@ -692,29 +709,25 @@ def sample_df( """ Sample a DataArray according to a given strategy. - Parameters - ---------- - df : :class:`pandas.DataFrame` | :class:`pandas.Series` - DataArray to sample, assumed to be time-sliced for the task - already. - sampling_strat : str | int | float | :class:`numpy:numpy.ndarray` - Sampling strategy, either "all" or an integer for random grid cell - sampling. - seed : int, optional - Seed for random sampling. Default: None. - - Returns - ------- - Data : Tuple[X_c, Y_c] - Tuple of sampled target data and sampled context data. - - Raises - ------ - InvalidSamplingStrategyError - If the sampling strategy is not valid. - InvalidSamplingStrategyError - If a numpy coordinate array is passed to sample a pandas object, - but the DataFrame does not contain all the requested samples. + Args: + df (:class:`pandas.DataFrame` | :class:`pandas.Series`): + DataArray to sample, assumed to be time-sliced for the task + already. + sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`): + Sampling strategy, either "all" or an integer for random grid + cell sampling. + seed (int, optional): + Seed for random sampling. Default is None. + + Returns: + Tuple[X_c, Y_c]: + Tuple of sampled target data and sampled context data. + + Raises: + InvalidSamplingStrategyError: + If the sampling strategy is not valid or if a numpy coordinate + array is passed to sample a pandas object, but the DataFrame + does not contain all the requested samples. 
""" df = df.dropna(how="any") # If any obs are NaN, drop them @@ -725,9 +738,16 @@ def sample_df( N = sampling_strat rng = np.random.default_rng(seed) idx = rng.choice(df.index, N) - X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype) + X_c = ( + df.loc[idx] + .reset_index()[["x1", "x2"]] + .values.T.astype(self.dtype) + ) Y_c = df.loc[idx].values.T - elif isinstance(sampling_strat, str) and sampling_strat in ["all", "split"]: + elif isinstance(sampling_strat, str) and sampling_strat in [ + "all", + "split", + ]: # NOTE if "split", we assume that the context-target split has already been applied to the df # in an earlier scope with access to both the context and target data. This is maybe risky! X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype) @@ -762,18 +782,20 @@ def sample_offgrid_aux( """ Sample auxiliary data at off-grid locations. - Parameters - ---------- - X_t : :class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`] - Off-grid locations at which to sample the auxiliary data. Can be a - tuple of two numpy arrays, or a single numpy array. - offgrid_aux : :class:`xarray.DataArray` | :class:`xarray.Dataset` - Auxiliary data at off-grid locations. - - Returns - ------- - :class:`numpy:numpy.ndarray` - ... + Args: + X_t (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]): + Off-grid locations at which to sample the auxiliary data. Can + be a tuple of two numpy arrays, or a single numpy array. + offgrid_aux (:class:`xarray.DataArray` | :class:`xarray.Dataset`): + Auxiliary data at off-grid locations. + + Returns: + :class:`numpy:numpy.ndarray`: + [Description of the returned numpy ndarray] + + Raises: + [ExceptionType]: + [Description of under what conditions this function raises an exception] """ if "time" in offgrid_aux.dims: raise ValueError( @@ -801,22 +823,19 @@ def time_slice_variable(self, var, date, delta_t=0): """ Slice a variable by a given time delta. - Parameters - ---------- - var : ... - Variable to slice. - delta_t : ... - Time delta to slice by. - - Returns - ------- - var : ... - Sliced variable. - - Raises - ------ - ValueError - If the variable is of an unknown type. + Args: + var (...): + Variable to slice. + delta_t (...): + Time delta to slice by. + + Returns: + var (...) + Sliced variable. + + Raises: + ValueError + If the variable is of an unknown type. """ # TODO: Does this work with instantaneous time? delta_t = pd.Timedelta(delta_t, unit=self.time_freq) @@ -861,75 +880,65 @@ def task_generation( - int: Sample N observations uniformly at random. - float: Sample a fraction of observations uniformly at random. - :class:`numpy:numpy.ndarray`, shape (2, N): Sample N observations - at the given x1, x2 coordinates. Coords are assumed to be - unnormalised. - - Parameters - ---------- - date : :class:`pandas.Timestamp` - Date for which to generate the task. - context_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`] - Sampling strategy for the context data, either a list of sampling - strategies for each context set, or a single strategy applied to - all context sets. Default is ``"all"``. - target_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`] - Sampling strategy for the target data, either a list of sampling - strategies for each target set, or a single strategy applied to all - target sets. 
Default is ``"all"``. - split_frac : float - The fraction of observations to use for the context set with the - "split" sampling strategy for linked context and target set pairs. - The remaining observations are used for the target set. Default is - 0.5. - datewise_deterministic : bool - Whether random sampling is datewise_deterministic based on the - date. Default is ``False``. - seed_override : Optional[int] - Override the seed for random sampling. This can be used to use the - same random sampling at different ``date``. Default is None. - - Returns - ------- - task : :class:`~.data.task.Task` - Task object containing the context and target data. + at the given x1, x2 coordinates. Coords are assumed to be + unnormalized. + + Args: + date (:class:`pandas.Timestamp`): + Date for which to generate the task. + context_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional): + Sampling strategy for the context data, either a list of + sampling strategies for each context set, or a single strategy + applied to all context sets. Default is ``"all"``. + target_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional): + Sampling strategy for the target data, either a list of + sampling strategies for each target set, or a single strategy + applied to all target sets. Default is ``"all"``. + split_frac (float, optional): + The fraction of observations to use for the context set with + the "split" sampling strategy for linked context and target set + pairs. The remaining observations are used for the target set. + Default is 0.5. + datewise_deterministic (bool, optional): + Whether random sampling is datewise deterministic based on the + date. Default is ``False``. + seed_override (Optional[int], optional): + Override the seed for random sampling. This can be used to use + the same random sampling at different ``date``. Default is + None. + + Returns: + :class:`~.data.task.Task`: + Task object containing the context and target data. """ def check_sampling_strat(sampling_strat, set): """ - Check the sampling strategy - - Ensure ``sampling_strat`` is either a single strategy (broadcast to - all sets) or a list of length equal to the number of sets. Convert - to a tuple of length equal to the number of sets and return. - - Parameters - ---------- - sampling_strat : ... - Sampling strategy to check. - set : ... - Context or target set to check. - - Returns - ------- - sampling_strat : tuple - Tuple of sampling strategies, one for each set. - - Raises - ------ - InvalidSamplingStrategyError - If the sampling strategy is invalid. - InvalidSamplingStrategyError - If the length of the sampling strategy does not match the - number of sets. - InvalidSamplingStrategyError - If the sampling strategy is not a valid type. - InvalidSamplingStrategyError - If the sampling strategy is a float but not in [0, 1]. - InvalidSamplingStrategyError - If the sampling strategy is an int but not positive. - InvalidSamplingStrategyError - If the sampling strategy is a numpy array but not of shape - (2, N). + Check the sampling strategy. + + Ensure ``sampling_strat`` is either a single strategy (broadcast + to all sets) or a list of length equal to the number of sets. + Convert to a tuple of length equal to the number of sets and + return. + + Args: + sampling_strat: + Sampling strategy to check. + set: + Context or target set to check. 
+ + Returns: + tuple: + Tuple of sampling strategies, one for each set. + + Raises: + InvalidSamplingStrategyError: + - If the sampling strategy is invalid. + - If the length of the sampling strategy does not match the number of sets. + - If the sampling strategy is not a valid type. + - If the sampling strategy is a float but not in [0, 1]. + - If the sampling strategy is an int but not positive. + - If the sampling strategy is a numpy array but not of shape (2, N). """ if not isinstance(sampling_strat, (list, tuple)): sampling_strat = tuple([sampling_strat] * len(set)) @@ -942,7 +951,9 @@ def check_sampling_strat(sampling_strat, set): ) for strat in sampling_strat: - if not isinstance(strat, (str, int, np.integer, float, np.ndarray)): + if not isinstance( + strat, (str, int, np.integer, float, np.ndarray) + ): raise InvalidSamplingStrategyError( f"Unknown sampling strategy {strat} of type {type(strat)}" ) @@ -976,31 +987,30 @@ def sample_variable(var, sampling_strat, seed): Sample a variable by a given sampling strategy to get input and output data. - Parameters - ---------- - var : ... - Variable to sample. - sampling_strat : ... - Sampling strategy to use. - seed : ... - Seed for random sampling. - - Returns - ------- - ... : Tuple[X, Y] - Tuple of input and output data. - - Raises - ------ - ValueError - If the variable is of an unknown type. + Args: + var: + Variable to sample. + sampling_strat: + Sampling strategy to use. + seed: + Seed for random sampling. + + Returns: + Tuple[X, Y]: + Tuple of input and output data. + + Raises: + ValueError: + If the variable is of an unknown type. """ if isinstance(var, (xr.Dataset, xr.DataArray)): X, Y = self.sample_da(var, sampling_strat, seed) elif isinstance(var, (pd.DataFrame, pd.Series)): X, Y = self.sample_df(var, sampling_strat, seed) else: - raise ValueError(f"Unknown type {type(var)} for context set " f"{var}") + raise ValueError( + f"Unknown type {type(var)} for context set " f"{var}" + ) return X, Y # Check that the sampling strategies are valid @@ -1008,7 +1018,9 @@ def sample_variable(var, sampling_strat, seed): target_sampling = check_sampling_strat(target_sampling, self.target) # Check `split_frac if split_frac < 0 or split_frac > 1: - raise ValueError(f"split_frac must be between 0 and 1, got {split_frac}") + raise ValueError( + f"split_frac must be between 0 and 1, got {split_frac}" + ) if self.links is None: b1 = any( [ @@ -1092,8 +1104,12 @@ def sample_variable(var, sampling_strat, seed): # Perform the split sampling strategy for linked context and target sets at this point # while we have the full context and target data in scope - context_split_idxs = np.where(np.array(context_sampling) == "split")[0] - target_split_idxs = np.where(np.array(target_sampling) == "split")[0] + context_split_idxs = np.where( + np.array(context_sampling) == "split" + )[0] + target_split_idxs = np.where(np.array(target_sampling) == "split")[ + 0 + ] assert len(context_split_idxs) == len(target_split_idxs), ( f"Number of context sets with 'split' sampling strategy " f"({len(context_split_idxs)}) must match number of target sets " @@ -1146,8 +1162,12 @@ def sample_variable(var, sampling_strat, seed): # Perform the gapfill sampling strategy for linked context and target sets at this point # while we have the full context and target data in scope - context_gapfill_idxs = np.where(np.array(context_sampling) == "gapfill")[0] - target_gapfill_idxs = np.where(np.array(target_sampling) == "gapfill")[0] + context_gapfill_idxs = np.where( + 
np.array(context_sampling) == "gapfill" + )[0] + target_gapfill_idxs = np.where( + np.array(target_sampling) == "gapfill" + )[0] assert len(context_gapfill_idxs) == len(target_gapfill_idxs), ( f"Number of context sets with 'gapfill' sampling strategy " f"({len(context_gapfill_idxs)}) must match number of target sets " @@ -1177,9 +1197,13 @@ def sample_variable(var, sampling_strat, seed): # Keep trying until we get a target set with at least one target point keep_searching = True while keep_searching: - added_mask_date = rng.choice(self.context[context_idx].time) + added_mask_date = rng.choice( + self.context[context_idx].time + ) added_mask = ( - self.context[context_idx].sel(time=added_mask_date).isnull() + self.context[context_idx] + .sel(time=added_mask_date) + .isnull() ) curr_mask = context_var.isnull() @@ -1190,7 +1214,9 @@ def sample_variable(var, sampling_strat, seed): # when we could just slice the target values here target_mask = added_mask & ~curr_mask if isinstance(target_var, xr.Dataset): - keep_searching = np.all(target_mask.to_array().data == False) + keep_searching = np.all( + target_mask.to_array().data == False + ) else: keep_searching = np.all(target_mask.data == False) if keep_searching: @@ -1210,7 +1236,9 @@ def sample_variable(var, sampling_strat, seed): X_c, Y_c = sample_variable(var, sampling_strat, context_seed) task[f"X_c"].append(X_c) task[f"Y_c"].append(Y_c) - for j, (var, sampling_strat) in enumerate(zip(target_slices, target_sampling)): + for j, (var, sampling_strat) in enumerate( + zip(target_slices, target_sampling) + ): target_seed = seed + i + j if seed is not None else None X_t, Y_t = sample_variable(var, sampling_strat, target_seed) task[f"X_t"].append(X_t) @@ -1218,7 +1246,9 @@ def sample_variable(var, sampling_strat, seed): if self.aux_at_contexts is not None: # Add auxiliary variable sampled at context set as a new context variable - X_c_offgrid = [X_c for X_c in task["X_c"] if not isinstance(X_c, tuple)] + X_c_offgrid = [ + X_c for X_c in task["X_c"] if not isinstance(X_c, tuple) + ] if len(X_c_offgrid) == 0: # No offgrid context sets X_c_offrid_all = np.empty((2, 0), dtype=self.dtype) @@ -1226,7 +1256,8 @@ def sample_variable(var, sampling_strat, seed): X_c_offrid_all = np.concatenate(X_c_offgrid, axis=1) Y_c_aux = ( self.sample_offgrid_aux( - X_c_offrid_all, self.time_slice_variable(self.aux_at_contexts, date) + X_c_offrid_all, + self.time_slice_variable(self.aux_at_contexts, date), ), ) task["X_c"].append(X_c_offrid_all) @@ -1248,20 +1279,21 @@ def sample_variable(var, sampling_strat, seed): def __call__(self, date, *args, **kwargs): """ - Generate a task for a given date (or a list of task for an iterable of dates). - - Parameters - ---------- - date : ... - Date for which to generate the task. - - Returns - ------- - task: Task | List[Task] - Task object or list of task objects for each date containing the - context and target data. + Generate a task for a given date (or a list of task for an iterable of + dates). + + Args: + date: + Date for which to generate the task. + + Returns: + Task | List[Task]: + Task object or list of task objects for each date containing + the context and target data. 
""" - if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)): + if isinstance( + date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex) + ): return [self.task_generation(d, *args, **kwargs) for d in date] else: return self.task_generation(date, *args, **kwargs) diff --git a/deepsensor/data/processor.py b/deepsensor/data/processor.py index dc310d55..e817db46 100644 --- a/deepsensor/data/processor.py +++ b/deepsensor/data/processor.py @@ -33,25 +33,24 @@ def __init__( """ Initialise a DataProcessor object. - Parameters - ---------- - folder : str, optional - Folder to load normalisation params from. Defaults to None. - x1_name : str, optional - Name of first spatial coord (e.g. "lat"). Defaults to "x1". - x2_name : str, optional - Name of second spatial coord (e.g. "lon"). Defaults to "x2". - x1_map : tuple, optional - 2-tuple of raw x1 coords to linearly map to (0, 1), respectively. - Defaults to (0, 1) (i.e. no normalisation). - x2_map : tuple, optional - 2-tuple of raw x2 coords to linearly map to (0, 1), respectively. - Defaults to (0, 1) (i.e. no normalisation). - deepcopy : bool, optional - Whether to make a deepcopy of raw data to ensure it is not changed - by reference when normalising. Defaults to True. - verbose : bool, optional - Whether to print verbose output. Defaults to False. + Args: + folder (str, optional): + Folder to load normalisation params from. Defaults to None. + x1_name (str, optional): + Name of first spatial coord (e.g. "lat"). Defaults to "x1". + x2_name (str, optional): + Name of second spatial coord (e.g. "lon"). Defaults to "x2". + x1_map (tuple, optional): + 2-tuple of raw x1 coords to linearly map to (0, 1), + respectively. Defaults to (0, 1) (i.e. no normalisation). + x2_map (tuple, optional): + 2-tuple of raw x2 coords to linearly map to (0, 1), + respectively. Defaults to (0, 1) (i.e. no normalisation). + deepcopy (bool, optional): + Whether to make a deepcopy of raw data to ensure it is not + changed by reference when normalising. Defaults to True. + verbose (bool, optional): + Whether to print verbose output. Defaults to False. """ if folder is not None: fpath = os.path.join(folder, self.config_fname) @@ -89,7 +88,9 @@ def __init__( if "coords" not in self.config: # Add coordinate normalisation info to config - self.set_coord_params(time_name, x1_name, x1_map, x2_name, x2_map) + self.set_coord_params( + time_name, x1_name, x1_map, x2_name, x2_map + ) self.raw_spatial_coord_names = [ self.config["coords"][coord]["name"] for coord in ["x1", "x2"] @@ -141,7 +142,8 @@ def _validate_coord_mappings(self, x1_map, x2_map): def _validate_xr(self, data: Union[xr.DataArray, xr.Dataset]): def _validate_da(da: xr.DataArray): coord_names = [ - self.config["coords"][coord]["name"] for coord in ["time", "x1", "x2"] + self.config["coords"][coord]["name"] + for coord in ["time", "x1", "x2"] ] if coord_names[0] not in da.dims: # We don't have a time dimension. @@ -160,7 +162,8 @@ def _validate_da(da: xr.DataArray): def _validate_pandas(self, df: Union[pd.DataFrame, pd.Series]): coord_names = [ - self.config["coords"][coord]["name"] for coord in ["time", "x1", "x2"] + self.config["coords"][coord]["name"] + for coord in ["time", "x1", "x2"] ] if coord_names[0] not in df.index.names: @@ -186,15 +189,12 @@ def load_dask(cls, data: Union[xr.DataArray, xr.Dataset]): """ Load dask data into memory. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` - ... 
+ Args: + data (:class:`xarray.DataArray` | :class:`xarray.Dataset`): + Description of the parameter. - Returns - ------- - ... - ... + Returns: + [Type and description of the returned value(s) needed.] """ if isinstance(data, xr.DataArray): data.load() @@ -202,26 +202,26 @@ def load_dask(cls, data: Union[xr.DataArray, xr.Dataset]): data.load() return data - def set_coord_params(self, time_name, x1_name, x1_map, x2_name, x2_map) -> None: + def set_coord_params( + self, time_name, x1_name, x1_map, x2_name, x2_map + ) -> None: """ Set coordinate normalisation params. - Parameters - ---------- - time_name : ... - ... - x1_name : ... - ... - x1_map : ... - ... - x2_name : ... - ... - x2_map : ... - ... - - Returns - ------- - None. + Args: + time_name: + [Type] Description needed. + x1_name: + [Type] Description needed. + x1_map: + [Type] Description needed. + x2_name: + [Type] Description needed. + x2_map: + [Type] Description needed. + + Returns: + None. """ self.config["coords"] = {} self.config["coords"]["time"] = {"name": time_name} @@ -236,17 +236,15 @@ def check_params_computed(self, var_ID, method) -> bool: """ Check if normalisation params computed for a given variable. - Parameters - ---------- - var_ID : ... - ... - method : ... - ... + Args: + var_ID: + [Type] Description needed. + method: + [Type] Description needed. - Returns - ------- - bool - Whether normalisation params are computed for a given variable. + Returns: + bool: + Whether normalisation params are computed for a given variable. """ if ( var_ID in self.config @@ -266,23 +264,20 @@ def get_config(self, var_ID, data, method=None): Get pre-computed normalisation params or compute them for variable ``var_ID``. - .. note: + .. note:: + TODO do we need to pass var_ID? Can we just use the name of data? - TODO do we need to pass var_ID? Can we just use name of data? + Args: + var_ID: + [Type] Description needed. + data: + [Type] Description needed. + method (optional): + [Type] Description needed. Defaults to None. - Parameters - ---------- - var_ID : ... - ... - data : ... - ... - method : ..., optional - ..., by default None. - - Returns - ------- - ... - ... + Returns: + [Type]: + Description of the returned value(s) needed. """ if method not in self.valid_methods: raise ValueError( @@ -317,39 +312,39 @@ def map_coord_array(self, coord_array: np.ndarray, unnorm: bool = False): """ Normalise or unnormalise a coordinate array. - Parameters - ---------- - coord_array : :class:`numpy:numpy.ndarray` - Array of shape ``(2, N)`` containing coords. - unnorm : bool, optional - Whether to unnormalise. Defaults to ``False``. + Args: + coord_array (:class:`numpy:numpy.ndarray`): + Array of shape ``(2, N)`` containing coords. + unnorm (bool, optional): + Whether to unnormalise. Defaults to ``False``. - Returns - ------- - ... - ... + Returns: + [Type]: + Description of the returned value(s) needed. """ - x1, x2 = self.map_x1_and_x2(coord_array[0], coord_array[1], unnorm=unnorm) + x1, x2 = self.map_x1_and_x2( + coord_array[0], coord_array[1], unnorm=unnorm + ) new_coords = np.stack([x1, x2], axis=0) return new_coords - def map_x1_and_x2(self, x1: np.ndarray, x2: np.ndarray, unnorm: bool = False): + def map_x1_and_x2( + self, x1: np.ndarray, x2: np.ndarray, unnorm: bool = False + ): """ - Normalise or unnormalise spatial coords in a array. - - Parameters - ---------- - x1 : :class:`numpy:numpy.ndarray` - Array of shape ``(N_x1,)`` containing spatial coords of x1. 
- x2 : :class:`numpy:numpy.ndarray` - Array of shape ``(N_x2,)`` containing spatial coords of x2. - unnorm : bool, optional - Whether to unnormalise. Defaults to ``False``. - - Returns - ------- - Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`] - Normalised or unnormalised spatial coords of x1 and x2. + Normalise or unnormalise spatial coords in an array. + + Args: + x1 (:class:`numpy:numpy.ndarray`): + Array of shape ``(N_x1,)`` containing spatial coords of x1. + x2 (:class:`numpy:numpy.ndarray`): + Array of shape ``(N_x2,)`` containing spatial coords of x2. + unnorm (bool, optional): + Whether to unnormalise. Defaults to ``False``. + + Returns: + Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]: + Normalised or unnormalised spatial coords of x1 and x2. """ x11, x12 = self.config["coords"]["x1"]["map"] x21, x22 = self.config["coords"]["x2"]["map"] @@ -371,17 +366,15 @@ def map_coords( """ Normalise spatial coords in a pandas or xarray object. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series` - ... - unnorm : bool, optional - ... + Args: + data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, or :class:`pandas.Series`): + [Description Needed] + unnorm (bool, optional): + [Description Needed]. Defaults to [Default Value]. - Returns - ------- - ... - ... + Returns: + [Type]: + [Description Needed] """ if isinstance(data, (pd.DataFrame, pd.Series)): # Reset index to get coords as columns @@ -457,7 +450,9 @@ def map_coords( # Rename all dimensions. rename = { - old: new for old, new in zip(old_coord_IDs, new_coord_IDs) if old != new + old: new + for old, new in zip(old_coord_IDs, new_coord_IDs) + if old != new } data = data.rename(rename) @@ -471,7 +466,9 @@ def map_coords( def map_array( self, - data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, np.ndarray], + data: Union[ + xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, np.ndarray + ], var_ID: str, method: Optional[str] = None, unnorm: bool = False, @@ -481,27 +478,29 @@ def map_array( Normalise or unnormalise the data values in an xarray, pandas, or numpy object. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`numpy:numpy.ndarray` - ... - var_ID : str - ... - method : str, optional - ..., by default None. - unnorm : bool, optional - ..., by default False. - add_offset : bool, optional - ..., by default True. - - Returns - ------- - ... - ... + Args: + data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, :class:`pandas.Series`, or :class:`numpy:numpy.ndarray`): + [Description Needed] + var_ID (str): + [Description Needed] + method (str, optional): + [Description Needed]. Defaults to None. + unnorm (bool, optional): + [Description Needed]. Defaults to False. + add_offset (bool, optional): + [Description Needed]. Defaults to True. 
+ + Returns: + [Type]: + [Description Needed] """ if not unnorm and method is None: raise ValueError("Must provide `method` if normalising data.") - elif unnorm and method is not None and self.config[var_ID]["method"] != method: + elif ( + unnorm + and method is not None + and self.config[var_ID]["method"] != method + ): # User has provided a different method to the one used for normalising raise ValueError( f"Variable '{var_ID}' has been normalised with method '{self.config[var_ID]['method']}', " @@ -558,21 +557,19 @@ def map( Normalise or unnormalise the data values and coords in an xarray or pandas object. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series` - ... - method : str, optional - ..., by default ``None``. - add_offset : bool, optional - ..., by default ``True``. - unnorm : bool, optional - ..., by default ``False``. - - Returns - ------- - ... - ... + Args: + data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, or :class:`pandas.Series`): + [Description Needed] + method (str, optional): + [Description Needed]. Defaults to None. + add_offset (bool, optional): + [Description Needed]. Defaults to True. + unnorm (bool, optional): + [Description Needed]. Defaults to False. + + Returns: + [Type]: + [Description Needed] """ if self.deepcopy: data = deepcopy(data) @@ -624,27 +621,30 @@ def __call__( """ Normalise data. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`] - Data to normalise. - method : str, optional - Normalisation method. Defaults to "mean_std". Options: - - "mean_std": Normalise to mean=0 and std=1 + Args: + data (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]): + Data to be normalised. Can be an xarray DataArray, xarray + Dataset, pandas DataFrame, or a list containing objects of + these types. + method (str, optional): Normalisation method. Options include: + - "mean_std": Normalise to mean=0 and std=1 (default) - "min_max": Normalise to min=-1 and max=1 - Returns - ------- - :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`] - Normalised data. + Returns: + :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]: + Normalised data. Type or structure depends on the input. """ if isinstance(data, list): return [ - self.map(d, method, unnorm=False, assert_computed=assert_computed) + self.map( + d, method, unnorm=False, assert_computed=assert_computed + ) for d in data ] else: - return self.map(data, method, unnorm=False, assert_computed=assert_computed) + return self.map( + data, method, unnorm=False, assert_computed=assert_computed + ) def unnormalise( self, @@ -664,49 +664,61 @@ def unnormalise( """ Unnormalise data. - Parameters - ---------- - data : :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`] - Data to unnormalise. - add_offset : bool, optional - Whether to add the offset to the data when unnormalising. 
Set to - False to unnormalise uncertainty values (e.g. std dev). Defaults to - True. - - Returns - ------- - :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`] - Unnormalised data. + Args: + data (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]): + Data to unnormalise. + add_offset (bool, optional): + Whether to add the offset to the data when unnormalising. Set + to False to unnormalise uncertainty values (e.g. std dev). + Defaults to True. + + Returns: + :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]: + Unnormalised data. """ if isinstance(data, list): - return [self.map(d, add_offset=add_offset, unnorm=True) for d in data] + return [ + self.map(d, add_offset=add_offset, unnorm=True) for d in data + ] else: return self.map(data, add_offset=add_offset, unnorm=True) -def xarray_to_coord_array_normalised(da: Union[xr.Dataset, xr.DataArray]) -> np.ndarray: +def xarray_to_coord_array_normalised( + da: Union[xr.Dataset, xr.DataArray] +) -> np.ndarray: """ Convert xarray to normalised coordinate array. - Parameters - ---------- - da : :class:`xarray.Dataset` | :class:`xarray.DataArray` - ... + Args: + da (:class:`xarray.Dataset` | :class:`xarray.DataArray`) + ... - Returns - ------- - :class:`numpy:numpy.ndarray` - A normalised coordinate array of shape ``(2, N)``. + Returns: + :class:`numpy:numpy.ndarray` + A normalised coordinate array of shape ``(2, N)``. """ x1, x2 = da["x1"].values, da["x2"].values X1, X2 = np.meshgrid(x1, x2, indexing="ij") return np.stack([X1.ravel(), X2.ravel()], axis=0) -def process_X_mask_for_X(X_mask: xr.DataArray, X: xr.DataArray): +def process_X_mask_for_X( + X_mask: xr.DataArray, X: xr.DataArray +) -> xr.DataArray: """Process X_mask by interpolating to X and converting to boolean. Both X_mask and X are xarray DataArrays with the same spatial coords. + + Args: + X_mask (:class:`xarray.DataArray`): + ... + X (:class:`xarray.DataArray`): + ... + + Returns: + :class:`xarray.DataArray` + ... """ X_mask = X_mask.astype(float).interp_like( X, method="nearest", kwargs={"fill_value": 0} @@ -718,24 +730,23 @@ def process_X_mask_for_X(X_mask: xr.DataArray, X: xr.DataArray): def mask_coord_array_normalised( coord_arr: np.ndarray, mask_da: Union[xr.DataArray, xr.Dataset, None] -): +) -> np.ndarray: """ - Remove points from (2, N) numpy array that are outside gridded xarray boolean mask. - - If `coord_arr` is shape `(2, N)`, then `mask_da` is a shape `(N,)` boolean array - (True if point is inside mask, False if outside). - - Parameters - ---------- - coord_arr : ... - ... - mask_da : ... - ... - - Returns - ------- - ... - ... + Remove points from (2, N) numpy array that are outside gridded xarray + boolean mask. + + If `coord_arr` is shape `(2, N)`, then `mask_da` is a shape `(N,)` boolean + array (True if point is inside mask, False if outside). + + Args: + coord_arr (:class:`numpy:numpy.ndarray`): + ... + mask_da (:class:`xarray.Dataset` | :class:`xarray.DataArray`): + ... + + Returns: + :class:`numpy:numpy.ndarray` + ... """ if mask_da is None: return coord_arr @@ -752,17 +763,15 @@ def da1_da2_same_grid(da1: xr.DataArray, da2: xr.DataArray) -> bool: .. 
note:: ``da1`` and ``da2`` are assumed normalised by ``DataProcessor``. - Parameters - ---------- - da1 : :class:`xarray.DataArray` - ... - da2 : :class:`xarray.DataArray` - ... - - Returns - ------- - bool - Whether ``da1`` and ``da2`` are on the same grid. + Args: + da1 (:class:`xarray.DataArray`): + ... + da2 (:class:`xarray.DataArray`): + ... + + Returns: + bool + Whether ``da1`` and ``da2`` are on the same grid. """ x1equal = np.array_equal(da1["x1"].values, da2["x1"].values) x2equal = np.array_equal(da1["x2"].values, da2["x2"].values) @@ -776,16 +785,14 @@ def interp_da1_to_da2(da1: xr.DataArray, da2: xr.DataArray) -> xr.DataArray: .. note:: ``da1`` and ``da2`` are assumed normalised by ``DataProcessor``. - Parameters - ---------- - da1 : :class:`xarray.DataArray` - ... - da2 : :class:`xarray.DataArray` - ... - - Returns - ------- - :class:`xarray.DataArray` - Interpolated xarray. + Args: + da1 (:class:`xarray.DataArray`): + ... + da2 (:class:`xarray.DataArray`): + ... + + Returns: + :class:`xarray.DataArray` + Interpolated xarray. """ return da1.interp(x1=da2["x1"], x2=da2["x2"], method="nearest") diff --git a/deepsensor/data/task.py b/deepsensor/data/task.py index 1b5d53f5..21998c70 100644 --- a/deepsensor/data/task.py +++ b/deepsensor/data/task.py @@ -1,6 +1,6 @@ import deepsensor -from typing import Union, Tuple, List +from typing import Callable, Union, Tuple, List, Optional import numpy as np import lab as B import plum @@ -21,10 +21,9 @@ def __init__(self, task_dict: dict) -> None: """ Initialise a Task object. - Parameters - ---------- - task_dict : dict - Dictionary containing the task. + Args: + task_dict (dict): + Dictionary containing the task. """ super().__init__(task_dict) @@ -45,7 +44,22 @@ def summarise_str(cls, k, v): return v @classmethod - def summarise_repr(cls, k, v): + def summarise_repr(cls, k, v) -> str: + """ + Summarise the task in a representation that can be printed. + + Args: + cls (:class:`deepsensor.data.task.Task`:): + Task class. + k (str): + Key of the task dictionary. + v (object): + Value of the task dictionary. + + Returns: + str: + String representation of the task. + """ if plum.isinstance(v, B.Numeric): return f"{type(v).__name__}/{v.dtype}/{v.shape}" if plum.isinstance(v, deepsensor.backend.nps.mask.Masked): @@ -58,7 +72,7 @@ def summarise_repr(cls, k, v): else: return f"{type(v).__name__}/{v}" - def __str__(self): + def __str__(self) -> str: """ Print a convenient summary of the task dictionary. @@ -69,7 +83,7 @@ def __str__(self): s += f"{k}: {Task.summarise_str(k, v)}\n" return s - def __repr__(self): + def __repr__(self) -> str: """ Print a convenient summary of the task dictionary. @@ -81,26 +95,23 @@ def __repr__(self): s += f"{k}: {Task.summarise_repr(k, v)}\n" return s - def op(self, f, op_flag=None): - """Apply function f to the array elements of a task dictionary. + def op(self, f: Callable, op_flag: Optional[str] = None): + """ + Apply function f to the array elements of a task dictionary. Useful for recasting to a different dtype or reshaping (e.g. adding a batch dimension). - Parameters - ---------- - f : function - Function to apply to the array elements of the task. - task : dict - Task dictionary. - op_flag : str - Flag to set in the task dictionary's `ops` key. - - Returns - ------- - task : dict. - Task dictionary with f applied to the array elements and - op_flag set in the ``ops`` key. + Args: + f (callable): + Function to apply to the array elements of the task. 
+ op_flag (str): + Flag to set in the task dictionary's `ops` key. + + Returns: + :class:`deepsensor.data.task.Task`: + Task with f applied to the array elements and op_flag set in + the ``ops`` key. """ def recurse(k, v): @@ -109,7 +120,8 @@ def recurse(k, v): elif type(v) is tuple: return (recurse(k, v[0]), recurse(k, v[1])) elif isinstance( - v, (np.ndarray, np.ma.MaskedArray, deepsensor.backend.nps.Masked) + v, + (np.ndarray, np.ma.MaskedArray, deepsensor.backend.nps.Masked), ): return f(v) else: @@ -123,25 +135,33 @@ def recurse(k, v): return self # altered by reference, but return anyway def add_batch_dim(self): - """Add a batch dimension to the arrays in the task dictionary. + """ + Add a batch dimension to the arrays in the task dictionary. - Returns - ------- - task : dict. Task dictionary with batch dimension added to the array elements. + Returns: + :class:`deepsensor.data.task.Task`: + Task with batch dimension added to the array elements. """ return self.op(lambda x: x[None, ...], op_flag="batch_dim") def cast_to_float32(self): - """Cast the arrays in the task dictionary to float32. + """ + Cast the arrays in the task dictionary to float32. - Returns - ------- - task : dict. Task dictionary with arrays cast to float32. + Returns: + :class:`deepsensor.data.task.Task`: + Task with arrays cast to float32. """ return self.op(lambda x: x.astype(np.float32), op_flag="float32") def remove_any_nans_from_Y_t(self): - """If NaNs are present in task["Y_t"], remove them (and corresponding task["X_t"])""" + """ + If NaNs are present in task["Y_t"], remove them (and corresponding task["X_t"]) + + Returns: + :class:`deepsensor.data.task.Task`: + ... + """ if "batch_dim" in self["ops"]: raise ValueError( "Cannot remove NaNs from task if a batch dim has been added." @@ -179,14 +199,19 @@ def remove_any_nans_from_Y_t(self): return self def mask_nans_numpy(self): - """Replace NaNs with zeroes and set a mask to indicate where the NaNs were. + """ + Replace NaNs with zeroes and set a mask to indicate where the NaNs + were. - Returns - ------- - task : dict. Task with NaNs set to zeros and a mask indicating where the missing values are. + Returns: + :class:`deepsensor.data.task.Task`: + Task with NaNs set to zeros and a mask indicating where the + missing values are. """ if "batch_dim" not in self["ops"]: - raise ValueError("Must call `add_batch_dim` before `mask_nans_numpy`") + raise ValueError( + "Must call `add_batch_dim` before `mask_nans_numpy`" + ) def f(arr): if isinstance(arr, deepsensor.backend.nps.Masked): @@ -207,10 +232,21 @@ def f(arr): return self.op(lambda x: f(x), op_flag="numpy_mask") def mask_nans_nps(self): + """ + ... + + Returns: + :class:`deepsensor.data.task.Task`: + ... + """ if "batch_dim" not in self["ops"]: - raise ValueError("Must call `add_batch_dim` before `mask_nans_nps`") + raise ValueError( + "Must call `add_batch_dim` before `mask_nans_nps`" + ) if "numpy_mask" not in self["ops"]: - raise ValueError("Must call `mask_nans_numpy` before `mask_nans_nps`") + raise ValueError( + "Must call `mask_nans_numpy` before `mask_nans_nps`" + ) def f(arr): if isinstance(arr, np.ma.MaskedArray): @@ -223,10 +259,12 @@ def f(arr): return self.op(lambda x: f(x), op_flag="nps_mask") def convert_to_tensor(self): - """Convert to tensor object based on deep learning backend + """ + Convert to tensor object based on deep learning backend. - Returns - task: dict. 
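The ordering enforced by the errors above (a batch dimension must be added before numpy NaN masking, and numpy masking must precede the neuralprocesses masking) implies a fixed preprocessing chain. A sketch of that chain, assuming `task` is a Task built from numpy arrays; each op mutates the task in place and also returns it, so the calls can be sequenced directly:

    task = task.add_batch_dim()       # prepend a batch dimension to every array
    task = task.cast_to_float32()     # cast arrays to float32
    task = task.mask_nans_numpy()     # zero-fill NaNs and record a numpy mask
    task = task.mask_nans_nps()       # wrap masked arrays for neuralprocesses
    task = task.convert_to_tensor()   # numpy arrays -> backend tensors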
Task dictionary with arrays converted to deep learning tensor objects + Returns: + :class:`deepsensor.data.task.Task`: + Task with arrays converted to deep learning tensor objects. """ def f(arr): @@ -257,6 +295,20 @@ def append_obs_to_task( .. TODO: for speed during active learning algs, consider a shallow copy option plus ability to remove observations. + + Args: + task (:class:`deepsensor.data.task.Task`:): + The task to modify. + X_new (array-like): + New observation coordinates. + Y_new (array-like): + New observation values. + context_set_idx (int): + Index of the context set to append to. + + Returns: + :class:`deepsensor.data.task.Task`: + Task with new observation appended to the context set. """ if not 0 <= context_set_idx <= len(task["X_c"]) - 1: raise TaskSetIndexError(context_set_idx, len(task["X_c"]), "context") @@ -289,19 +341,19 @@ def append_obs_to_task( return task_with_new -def flatten_X(X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray: +def flatten_X( + X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]] +) -> np.ndarray: """ Convert tuple of gridded coords to (2, N) array if necessary. - Parameters - ---------- - X : :class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`] - ... + Args: + X (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]): + ... - Returns - ---------- - :class:`numpy:numpy.ndarray` - ... + Returns: + :class:`numpy:numpy.ndarray` + ... """ if type(X) is tuple: X1, X2 = np.meshgrid(X[0], X[1], indexing="ij") @@ -309,20 +361,20 @@ def flatten_X(X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray return X -def flatten_Y(Y: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray: +def flatten_Y( + Y: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]] +) -> np.ndarray: """ Convert gridded data of shape (N_dim, N_x1, N_x2) to (N_dim, N_x1 * N_x2) array if necessary. - Parameters - ---------- - Y : :class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`] - ... + Args: + Y (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]): + ... - Returns - ------- - :class:`numpy:numpy.ndarray` - ... + Returns: + :class:`numpy:numpy.ndarray` + ... """ if Y.ndim == 3: Y = Y.reshape(*Y.shape[:-2], -1) @@ -335,15 +387,13 @@ def flatten_gridded_data_in_task(task: Task) -> Task: Necessary for AR sampling, which doesn't yet permit gridded context sets. - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task : :class:`~.data.task.Task` + ... - Returns - ------- - Task - ... + Returns: + :class:`deepsensor.data.task.Task`: + ... """ task_flattened = copy.deepcopy(task) @@ -365,31 +415,29 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task: functionality. - Raise error if ``aux_t`` values passed (not supported I don't think) - Parameters - ---------- - tasks : List[Task] - List of tasks to concatenate into a single task. - multiple : int, optional - Contexts are padded to the smallest multiple of this number that is - greater than the number of contexts in each task. Defaults to 1 - (padded to the largest number of contexts in the tasks). Setting to a - larger number will increase the amount of padding but decrease the - range of tensor shapes presented to the model, which simplifies the - computational graph in graph mode. - - Returns - ------- - merged_task : :class:`~.data.task.Task` - Task containing multiple batches. 
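A short sketch of `append_obs_to_task` with made-up values; the (2, 1) and (1, 1) shapes follow the coordinate-array convention used elsewhere in this module but are an assumption here, as is the pre-existing `task`:

    import numpy as np
    from deepsensor.data.task import append_obs_to_task

    X_new = np.array([[0.4], [0.6]])  # one new context location (x1, x2)
    Y_new = np.array([[1.3]])         # its observed value
    task_new = append_obs_to_task(task, X_new, Y_new, context_set_idx=0)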
- - Raises - ------ - ValueError - If the tasks have different numbers of target sets. - ValueError - If the tasks have different numbers of targets. - ValueError - If the tasks have different types of target sets (gridded/non-gridded). + Args: + tasks (List[:class:`deepsensor.data.task.Task`:]): + List of tasks to concatenate into a single task. + multiple (int, optional): + Contexts are padded to the smallest multiple of this number that is + greater than the number of contexts in each task. Defaults to 1 + (padded to the largest number of contexts in the tasks). Setting + to a larger number will increase the amount of padding but decrease + the range of tensor shapes presented to the model, which simplifies + the computational graph in graph mode. + + Returns: + :class:`~.data.task.Task` + Task containing multiple batches. + + Raises: + ValueError + If the tasks have different numbers of target sets. + ValueError + If the tasks have different numbers of targets. + ValueError + If the tasks have different types of target sets (gridded/ + non-gridded). """ if len(tasks) == 1: return tasks[0] @@ -466,7 +514,9 @@ def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task: ) else: # Target set is off-the-grid with tensor for `X_t` - merged_task["X_t"][i] = B.concat(*[t["X_t"][i] for t in tasks], axis=0) + merged_task["X_t"][i] = B.concat( + *[t["X_t"][i] for t in tasks], axis=0 + ) merged_task["Y_t"][i] = B.concat(*[t["Y_t"][i] for t in tasks], axis=0) merged_task["time"] = [t["time"] for t in tasks] diff --git a/deepsensor/data/utils.py b/deepsensor/data/utils.py index bd869561..08ac27a7 100644 --- a/deepsensor/data/utils.py +++ b/deepsensor/data/utils.py @@ -12,15 +12,13 @@ def construct_x1x2_ds(gridded_ds): a 2D gridded channel whose values contain the x_1 and x_2 coordinate values, respectively. - Parameters - ---------- - gridded_ds : :class:`xarray.Dataset` - ... - - Returns - ------- - :class:`xarray.Dataset` - ... + Args: + gridded_ds (:class:`xarray.Dataset`): + ... + + Returns: + :class:`xarray.Dataset` + ... """ X1, X2 = np.meshgrid(gridded_ds.x1, gridded_ds.x2, indexing="ij") ds = xr.Dataset( @@ -40,17 +38,15 @@ def construct_circ_time_ds(dates, freq): - ``'D'``: cycles once per year at daily intervals - ``'M'``: cycles once per year at monthly intervals - Parameters - ---------- - dates: ... - ... - freq : ... - ... - - Returns - ------- - :class:`xarray.Dataset` - ... + Args: + dates (...): + ... + freq (...): + ... + + Returns: + :class:`xarray.Dataset` + ... """ if freq == "D": time_var = dates.dayofyear @@ -79,7 +75,9 @@ def construct_circ_time_ds(dates, freq): return ds -def compute_xarray_data_resolution(ds: Union[xr.DataArray, xr.Dataset]) -> float: +def compute_xarray_data_resolution( + ds: Union[xr.DataArray, xr.Dataset] +) -> float: """ Computes the resolution of an xarray object with coordinates x1 and x2. @@ -88,15 +86,12 @@ def compute_xarray_data_resolution(ds: Union[xr.DataArray, xr.Dataset]) -> float resolution of 0.2 degrees, the data resolution returned will be 0.1 degrees. - Parameters - ---------- - ds : :class:`xarray.DataArray` | :class:`xarray.Dataset` - Xarray object with coordinates x1 and x2. + Args: + ds (:class:`xarray.DataArray` | :class:`xarray.Dataset`): + Xarray object with coordinates x1 and x2. - Returns - ------- - data_resolution : float - Resolution of the data (in spatial units, e.g. 0.1 degrees). + Returns: + float: Resolution of the data (in spatial units, e.g. 0.1 degrees). 
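For `concat_tasks` above, a brief batching sketch; the `task_loader` and `dates` objects are assumed to exist, and `multiple=32` is an illustrative choice:

    from deepsensor.data.task import concat_tasks

    tasks = [task_loader(d) for d in dates]
    batch = concat_tasks(tasks, multiple=32)  # pad context sizes to a multiple of 32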
""" x1_res = np.abs(np.mean(np.diff(ds["x1"]))) x2_res = np.abs(np.mean(np.diff(ds["x2"]))) @@ -119,21 +114,18 @@ def compute_pandas_data_resolution( than 1000) and to use the 5th percentile. This means that the resolution is the distance between the closest 5% of neighbouring observations. - Parameters - ---------- - df : :class:`pandas.DataFrame` | :class:`pandas.Series` - Dataframe or series with indexes time, x1, and x2. - n_times : int, optional - Number of dates to sample. Defaults to 1000. If "all", all dates are - used. - percentile : int, optional - Percentile of pairwise distances for computing the resolution. - Defaults to 5. - - Returns - ------- - data_resolution : float - Resolution of the data (in spatial units, e.g. 0.1 degrees). + Args: + df (:class:`pandas.DataFrame` | :class:`pandas.Series`): + Dataframe or series with indexes time, x1, and x2. + n_times (int, optional): + Number of dates to sample. Defaults to 1000. If "all", all dates + are used. + percentile (int, optional): + Percentile of pairwise distances for computing the resolution. + Defaults to 5. + + Returns: + float: Resolution of the data (in spatial units, e.g. 0.1 degrees). """ dates = df.index.get_level_values("time").unique() @@ -149,10 +141,14 @@ def compute_pandas_data_resolution( if X.shape[0] < 2: # Skip this time if there are fewer than 2 stationS continue - X_unique = np.unique(X, axis=0) # (N_unique, 2) array of unique coordinates + X_unique = np.unique( + X, axis=0 + ) # (N_unique, 2) array of unique coordinates pairwise_distances = scipy.spatial.distance.cdist(X_unique, X_unique) - percentile_distances_without_self = np.ma.masked_equal(pairwise_distances, 0) + percentile_distances_without_self = np.ma.masked_equal( + pairwise_distances, 0 + ) # Compute the closest distance from each station to each other station closest_distances_t = np.min(percentile_distances_without_self, axis=1) diff --git a/deepsensor/model/convnp.py b/deepsensor/model/convnp.py index 9e4bf5a9..f191347f 100644 --- a/deepsensor/model/convnp.py +++ b/deepsensor/model/convnp.py @@ -62,75 +62,74 @@ class ConvNP(DeepSensorModel): customise the model, which will override any defaults inferred from a ``TaskLoader``. - Parameters - ---------- - points_per_unit : int, optional - Density of the internal discretisation. Defaults to 100. - likelihood : str, optional - Likelihood. Must be one of ``"cnp"`` (equivalently ``"het"``), - ``"gnp"`` (equivalently ``"lowrank"``), or ``"cnp-spikes-beta"`` - (equivalently ``"spikes-beta"``). Defaults to ``"cnp"``. - dim_x : int, optional - Dimensionality of the inputs. Defaults to 1. - dim_y : int, optional - Dimensionality of the outputs. Defaults to 1. - dim_yc : int or tuple[int], optional - Dimensionality of the outputs of the context set. You should set this - if the dimensionality of the outputs of the context set is not equal - to the dimensionality of the outputs of the target set. You should - also set this if you want to use multiple context sets. In that case, - set this equal to a tuple of integers indicating the respective output - dimensionalities. - dim_yt : int, optional - Dimensionality of the outputs of the target set. You should set this - if the dimensionality of the outputs of the target set is not equal to - the dimensionality of the outputs of the context set. - dim_aux_t : int, optional - Dimensionality of target-specific auxiliary variables. - conv_arch : str, optional - Convolutional architecture to use. Must be one of - ``"unet[-res][-sep]"`` or ``"conv[-res][-sep]"``. 
Defaults to - ``"unet"``. - unet_channels : tuple[int], optional - Channels of every layer of the UNet. Defaults to six layers each with - 64 channels. - unet_kernels : int or tuple[int], optional - Sizes of the kernels in the UNet. Defaults to 5. - unet_resize_convs : bool, optional - Use resize convolutions rather than transposed convolutions in the - UNet. Defaults to ``False``. - unet_resize_conv_interp_method : str, optional - Interpolation method for the resize convolutions in the UNet. Can be - set to ``"bilinear"``. Defaults to "bilinear". - num_basis_functions : int, optional - Number of basis functions for the low-rank likelihood. Defaults to - 64. - dim_lv : int, optional - Dimensionality of the latent variable. Setting to >0 constructs a - latent neural process. Defaults to 0. - encoder_scales : float or tuple[float], optional - Initial value for the length scales of the set convolutions for the - context sets embeddings. Set to a tuple equal to the number of context - sets to use different values for each set. Set to a single value to use - the same value for all context sets. Defaults to - ``1 / points_per_unit``. - encoder_scales_learnable : bool, optional - Whether the encoder SetConv length scale(s) are learnable. Defaults to - ``False``. - decoder_scale : float, optional - Initial value for the length scale of the set convolution in the - decoder. Defaults to ``1 / points_per_unit``. - decoder_scale_learnable : bool, optional - Whether the decoder SetConv length scale(s) are learnable. Defaults to - ``False``. - aux_t_mlp_layers : tuple[int], optional - Widths of the layers of the MLP for the target-specific auxiliary - variable. Defaults to three layers of width 128. - epsilon : float, optional - Epsilon added by the set convolutions before dividing by the density - channel. Defaults to ``1e-2``. - dtype : dtype, optional - Data type. + Args: + points_per_unit (int, optional): + Density of the internal discretisation. Defaults to 100. + likelihood (str, optional): + Likelihood. Must be one of ``"cnp"`` (equivalently ``"het"``), + ``"gnp"`` (equivalently ``"lowrank"``), or ``"cnp-spikes-beta"`` + (equivalently ``"spikes-beta"``). Defaults to ``"cnp"``. + dim_x (int, optional): + Dimensionality of the inputs. Defaults to 1. + dim_y (int, optional): + Dimensionality of the outputs. Defaults to 1. + dim_yc (int or tuple[int], optional): + Dimensionality of the outputs of the context set. You should set this + if the dimensionality of the outputs of the context set is not equal + to the dimensionality of the outputs of the target set. You should + also set this if you want to use multiple context sets. In that case, + set this equal to a tuple of integers indicating the respective output + dimensionalities. + dim_yt (int, optional): + Dimensionality of the outputs of the target set. You should set this + if the dimensionality of the outputs of the target set is not equal to + the dimensionality of the outputs of the context set. + dim_aux_t (int, optional): + Dimensionality of target-specific auxiliary variables. + conv_arch (str, optional): + Convolutional architecture to use. Must be one of + ``"unet[-res][-sep]"`` or ``"conv[-res][-sep]"``. Defaults to + ``"unet"``. + unet_channels (tuple[int], optional): + Channels of every layer of the UNet. Defaults to six layers each with + 64 channels. + unet_kernels (int or tuple[int], optional): + Sizes of the kernels in the UNet. Defaults to 5. 
+ unet_resize_convs (bool, optional): + Use resize convolutions rather than transposed convolutions in the + UNet. Defaults to ``False``. + unet_resize_conv_interp_method (str, optional): + Interpolation method for the resize convolutions in the UNet. Can be + set to ``"bilinear"``. Defaults to "bilinear". + num_basis_functions (int, optional): + Number of basis functions for the low-rank likelihood. Defaults to + 64. + dim_lv (int, optional): + Dimensionality of the latent variable. Setting to >0 constructs a + latent neural process. Defaults to 0. + encoder_scales (float or tuple[float], optional): + Initial value for the length scales of the set convolutions for the + context sets embeddings. Set to a tuple equal to the number of context + sets to use different values for each set. Set to a single value to use + the same value for all context sets. Defaults to + ``1 / points_per_unit``. + encoder_scales_learnable (bool, optional): + Whether the encoder SetConv length scale(s) are learnable. Defaults to + ``False``. + decoder_scale (float, optional): + Initial value for the length scale of the set convolution in the + decoder. Defaults to ``1 / points_per_unit``. + decoder_scale_learnable (bool, optional): + Whether the decoder SetConv length scale(s) are learnable. Defaults to + ``False``. + aux_t_mlp_layers (tuple[int], optional): + Widths of the layers of the MLP for the target-specific auxiliary + variable. Defaults to three layers of width 128. + epsilon (float, optional): + Epsilon added by the set convolutions before dividing by the density + channel. Defaults to ``1e-2``. + dtype (dtype, optional): + Data type. """ @dispatch @@ -162,14 +161,13 @@ def __init__( Instantiate model from TaskLoader, using data to infer model parameters (unless overridden). - Parameters - ---------- - data_processor : :class:`~.data.processor.DataProcessor` - DataProcessor object. - task_loader : :class:`~.data.loader.TaskLoader` - TaskLoader object. - verbose : bool, optional - Whether to print inferred model parameters, by default True. + Args: + data_processor (:class:`~.data.processor.DataProcessor`): + DataProcessor object. + task_loader (:class:`~.data.loader.TaskLoader`): + TaskLoader object. + verbose (bool, optional): + Whether to print inferred model parameters, by default True. 
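A construction sketch for the TaskLoader-based initialiser above, spelling out two of the documented defaults; the `data_processor` and `task_loader` objects are assumed to have been built already:

    from deepsensor.model.convnp import ConvNP

    model = ConvNP(
        data_processor,
        task_loader,
        likelihood="cnp",          # documented default; "gnp" gives the low-rank likelihood
        unet_channels=(64,) * 6,   # documented default: six UNet layers of 64 channels
        verbose=True,              # print any parameters inferred from the TaskLoader
    )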
""" super().__init__(data_processor, task_loader) @@ -191,21 +189,29 @@ def __init__( if "aux_t_mlp_layers" not in kwargs and kwargs["dim_aux_t"] > 0: kwargs["aux_t_mlp_layers"] = (64,) * 3 if verbose: - print(f"Setting aux_t_mlp_layers: {kwargs['aux_t_mlp_layers']}") + print( + f"Setting aux_t_mlp_layers: {kwargs['aux_t_mlp_layers']}" + ) if "points_per_unit" not in kwargs: ppu = gen_ppu(task_loader) if verbose: print(f"points_per_unit inferred from TaskLoader: {ppu}") kwargs["points_per_unit"] = ppu if "encoder_scales" not in kwargs: - encoder_scales = gen_encoder_scales(kwargs["points_per_unit"], task_loader) + encoder_scales = gen_encoder_scales( + kwargs["points_per_unit"], task_loader + ) if verbose: - print(f"encoder_scales inferred from TaskLoader: {encoder_scales}") + print( + f"encoder_scales inferred from TaskLoader: {encoder_scales}" + ) kwargs["encoder_scales"] = encoder_scales if "decoder_scale" not in kwargs: decoder_scale = gen_decoder_scale(kwargs["points_per_unit"]) if verbose: - print(f"decoder_scale inferred from TaskLoader: {decoder_scale}") + print( + f"decoder_scale inferred from TaskLoader: {decoder_scale}" + ) kwargs["decoder_scale"] = decoder_scale self.model, self.config = construct_neural_process(*args, **kwargs) @@ -220,14 +226,13 @@ def __init__( """ Instantiate with a pre-defined neural process model. - Parameters - ---------- - data_processor : :class:`~.data.processor.DataProcessor` - DataProcessor object. - task_loader : :class:`~.data.loader.TaskLoader` - TaskLoader object. - neural_process : TFModel | TorchModel - Pre-defined neural process model. + Args: + data_processor (:class:`~.data.processor.DataProcessor`): + DataProcessor object. + task_loader (:class:`~.data.loader.TaskLoader`): + TaskLoader object. + neural_process (TFModel | TorchModel): + Pre-defined neural process model. """ super().__init__(data_processor, task_loader) @@ -253,13 +258,24 @@ def __init__( self.load(model_ID) def save(self, model_ID: str): - """Save the model weights and config to a folder.""" + """ + Save the model weights and config to a folder. + + Args: + model_ID (str): + Folder to save the model to. + + Returns: + None. + """ os.makedirs(model_ID, exist_ok=True) if backend.str == "torch": import torch - torch.save(self.model.state_dict(), os.path.join(model_ID, "model.pt")) + torch.save( + self.model.state_dict(), os.path.join(model_ID, "model.pt") + ) elif backend.str == "tf": self.model.save_weights(os.path.join(model_ID, "model")) else: @@ -270,7 +286,16 @@ def save(self, model_ID: str): json.dump(self.config, f, indent=4, sort_keys=False) def load(self, model_ID: str): - """Load a model from a folder containing model weights and config.""" + """ + Load a model from a folder containing model weights and config. + + Args: + model_ID (str): + Folder to load the model from. + + Returns: + None. + """ config_fpath = os.path.join(model_ID, "model_config.json") with open(config_fpath, "r") as f: self.config = json.load(f) @@ -280,7 +305,9 @@ def load(self, model_ID: str): if backend.str == "torch": import torch - self.model.load_state_dict(torch.load(os.path.join(model_ID, "model.pt"))) + self.model.load_state_dict( + torch.load(os.path.join(model_ID, "model.pt")) + ) elif backend.str == "tf": self.model.load_weights(os.path.join(model_ID, "model")) else: @@ -292,15 +319,12 @@ def modify_task(cls, task: Task): Cast numpy arrays to TensorFlow or PyTorch tensors, add batch dim, and mask NaNs. - Parameters - ---------- - task : :class:`~.data.task.Task` - ... 
+ Args: + task (:class:`~.data.task.Task`): + ... - Returns - ------- - ... - ... + Returns: + ...: ... """ if "batch_dim" not in task["ops"]: @@ -320,19 +344,16 @@ def __call__(self, task, n_samples=10, requires_grad=False): """ Compute ConvNP distribution. - Parameters - ---------- - task : :class:`~.data.task.Task` - ... - n_samples : int, optional - Number of samples to draw from the distribution, by default 10. - requires_grad : bool, optional - Whether to compute gradients, by default False. - - Returns - ------- - ... - The ConvNP distribution. + Args: + task (:class:`~.data.task.Task`): + ... + n_samples (int, optional): + Number of samples to draw from the distribution, by default 10. + requires_grad (bool, optional): + Whether to compute gradients, by default False. + + Returns: + ...: The ConvNP distribution. """ task = ConvNP.modify_task(task) dist = run_nps_model(self.model, task, n_samples, requires_grad) @@ -343,15 +364,12 @@ def mean(self, dist: AbstractMultiOutputDistribution): """ ... - Parameters - ---------- - dist : neuralprocesses.dist.AbstractMultiOutputDistribution - ... + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + ... - Returns - ------- - ... - ... + Returns: + ...: ... """ mean = dist.mean if isinstance(mean, backend.nps.Aggregate): @@ -364,15 +382,12 @@ def mean(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + ... - Returns - ------- - ... - ... + Returns: + ...: ... """ dist = self(task) return self.mean(dist) @@ -382,15 +397,12 @@ def variance(self, dist: AbstractMultiOutputDistribution): """ ... - Parameters - ---------- - dist : neuralprocesses.dist.AbstractMultiOutputDistribution - ... + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + ... - Returns - ------- - ... - ... + Returns: + ...: ... """ variance = dist.var if isinstance(variance, backend.nps.Aggregate): @@ -403,15 +415,12 @@ def variance(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + ... - Returns - ------- - ... - ... + Returns: + ...: ... """ dist = self(task) return self.variance(dist) @@ -421,15 +430,12 @@ def stddev(self, dist: AbstractMultiOutputDistribution): """ ... - Parameters - ---------- - dist : neuralprocesses.dist.AbstractMultiOutputDistribution - ... + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + ... - Returns - ------- - ... - ... + Returns: + ...: ... """ variance = self.variance(dist) if isinstance(variance, (list, tuple)): @@ -442,15 +448,12 @@ def stddev(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + ... - Returns - ------- - ... - ... + Returns: + ...: ... """ dist = self(task) return self.stddev(dist) @@ -460,15 +463,12 @@ def covariance(self, dist: AbstractMultiOutputDistribution): """ ... - Parameters - ---------- - dist : neuralprocesses.dist.AbstractMultiOutputDistribution - ... + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + ... - Returns - ------- - ... - ... + Returns: + ...: ... """ return B.to_numpy(B.dense(dist.vectorised_normal.var))[0, 0] @@ -477,15 +477,12 @@ def covariance(self, task: Task): """ ... - Parameters - ---------- - task : :class:`~.data.task.Task` - ... + Args: + task (:class:`~.data.task.Task`): + ... - Returns - ------- - ... - ... + Returns: + ...: ... 
""" dist = self(task) return self.covariance(dist) @@ -500,19 +497,19 @@ def sample( """ Create samples from a ConvNP distribution. - Parameters - ---------- - dist : neuralprocesses.dist.AbstractMultiOutputDistribution - The distribution to sample from. - n_samples : int, optional - The number of samples to draw from the distribution, by default 1. - noiseless : bool, optional - Whether to sample from the noiseless distribution, by default True. + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + The distribution to sample from. + n_samples (int, optional): + The number of samples to draw from the distribution, by + default 1. + noiseless (bool, optional): + Whether to sample from the noiseless distribution, by default + True. - Returns - ------- - :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`] - The samples as an array or list of arrays. + Returns: + :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]: + The samples as an array or list of arrays. """ if noiseless: samples = dist.noiseless.sample(n_samples) @@ -529,19 +526,19 @@ def sample(self, task: Task, n_samples: int = 1, noiseless: bool = True): """ Create samples from a ConvNP distribution. - Parameters - ---------- - task : :class:`~.data.task.Task` - The task to sample from. - n_samples : int, optional - The number of samples to draw from the distribution, by default 1. - noiseless : bool, optional - Whether to sample from the noiseless distribution, by default True. + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + The distribution to sample from. + n_samples (int, optional): + The number of samples to draw from the distribution, by + default 1. + noiseless (bool, optional): + Whether to sample from the noiseless distribution, by default + True. - Returns - ------- - :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`] - The samples as an array or list of arrays. + Returns: + :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]: + The samples as an array or list of arrays. """ dist = self(task) return self.sample(dist, n_samples, noiseless) @@ -551,15 +548,12 @@ def slice_diag(self, task: Task): """ Slice out the ConvCNP part of the ConvNP distribution. - Parameters - ---------- - task : :class:`~.data.task.Task` - The task to slice. + Args: + task (:class:`~.data.task.Task`): + The task to slice. - Returns - ------- - ... - ... + Returns: + ...: ... """ dist = self(task) dist_diag = backend.nps.MultiOutputNormal( @@ -575,15 +569,12 @@ def slice_diag(self, dist: AbstractMultiOutputDistribution): """ Slice out the ConvCNP part of the ConvNP distribution. - Parameters - ---------- - dist : neuralprocesses.dist.AbstractMultiOutputDistribution - The distribution to slice. + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + The distribution to slice. - Returns - ------- - ... - ... + Returns: + ...: ... """ dist_diag = backend.nps.MultiOutputNormal( dist._mean, @@ -598,15 +589,12 @@ def mean_marginal_entropy(self, dist: AbstractMultiOutputDistribution): """ Mean marginal entropy over target points given context points. - Parameters - ---------- - dist : neuralprocesses.dist.AbstractMultiOutputDistribution - The distribution to compute the entropy of. + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + The distribution to compute the entropy of. - Returns - ------- - float - The mean marginal entropy. + Returns: + float: The mean marginal entropy. 
""" dist_diag = self.slice_diag(dist) return B.mean(B.to_numpy(dist_diag.entropy())[0, 0]) @@ -616,15 +604,12 @@ def mean_marginal_entropy(self, task: Task): """ Mean marginal entropy over target points given context points. - Parameters - ---------- - task : :class:`~.data.task.Task` - The task to compute the entropy of. + Args: + task (:class:`~.data.task.Task`): + The task to compute the entropy of. - Returns - ------- - float - The mean marginal entropy. + Returns: + float: The mean marginal entropy. """ dist_diag = self.slice_diag(task) return B.mean(B.to_numpy(dist_diag.entropy())[0, 0]) @@ -634,15 +619,12 @@ def joint_entropy(self, dist: AbstractMultiOutputDistribution): """ Model entropy over target points given context points. - Parameters - ---------- - dist : neuralprocesses.dist.AbstractMultiOutputDistribution - The distribution to compute the entropy of. + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + The distribution to compute the entropy of. - Returns - ------- - float - The model entropy. + Returns: + float: The model entropy. """ return B.to_numpy(dist.entropy())[0, 0] @@ -651,15 +633,12 @@ def joint_entropy(self, task: Task): """ Model entropy over target points given context points. - Parameters - ---------- - task : :class:`~.data.task.Task` - The task to compute the entropy of. + Args: + task (:class:`~.data.task.Task`): + The task to compute the entropy of. - Returns - ------- - float - The model entropy. + Returns: + float: The model entropy. """ return B.to_numpy(self(task).entropy())[0, 0] @@ -669,17 +648,14 @@ def logpdf(self, dist: AbstractMultiOutputDistribution, task: Task): Model outputs joint distribution over all targets: Concat targets along observation dimension. - Parameters - ---------- - dist : neuralprocesses.dist.AbstractMultiOutputDistribution - The distribution to compute the logpdf of. - task : :class:`~.data.task.Task` - The task to compute the logpdf of. + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + The distribution to compute the logpdf of. + task (:class:`~.data.task.Task`): + The task to compute the logpdf of. - Returns - ------- - float - The logpdf. + Returns: + float: The logpdf. """ Y_t = B.concat(*task["Y_t"], axis=-1) return B.to_numpy(dist.logpdf(Y_t)).mean() @@ -690,15 +666,12 @@ def logpdf(self, task: Task): Model outputs joint distribution over all targets: Concat targets along observation dimension. - Parameters - ---------- - task : :class:`~.data.task.Task` - The task to compute the logpdf of. + Args: + task (:class:`~.data.task.Task`): + The task to compute the logpdf of. - Returns - ------- - float - The logpdf. + Returns: + float: The logpdf. """ dist = self(task) return self.logpdf(dist, task) @@ -713,24 +686,21 @@ def loss_fn( """ Compute the loss of a task. - Parameters - ---------- - task : :class:`~.data.task.Task` - The task to compute the loss of. - fix_noise : ... - Whether to fix the noise to the value specified in the model - config. - num_lv_samples : int, optional - If latent variable model, number of lv samples for evaluating the - loss, by default 8. - normalise : bool, optional - Whether to normalise the loss by the number of target points, by - default False. - - Returns - ------- - float - The loss. + Args: + task (:class:`~.data.task.Task`): + The task to compute the loss of. + fix_noise (...): + Whether to fix the noise to the value specified in the model + config. 
+ num_lv_samples (int, optional): + If latent variable model, number of lv samples for evaluating + the loss, by default 8. + normalise (bool, optional): + Whether to normalise the loss by the number of target points, + by default False. + + Returns: + float: The loss. """ task = ConvNP.modify_task(task) @@ -769,25 +739,25 @@ def ar_sample( .. note:: AR sampling only works for 0th context/target set - Parameters - ---------- - task : :class:`~.data.task.Task` - The task to sample from. - n_samples : int, optional - The number of samples to draw from the distribution, by default 1. - X_target_AR : :class:`numpy:numpy.ndarray`, optional - Locations to draw AR samples over. If None, AR samples will be - drawn over the target locations in the task. Defaults to None. - ar_subsample_factor : int, optional - Subsample target locations to draw AR samples over. Defaults to 1. - fill_type : Literal["mean", "sample"], optional - How to infill the rest of the sample. Must be one of "mean" or - "sample". Defaults to "mean". - - Returns - ------- - :class:`numpy:numpy.ndarray` - The samples. + Args: + task (:class:`~.data.task.Task`): + The task to sample from. + n_samples (int, optional): + The number of samples to draw from the distribution, by + default 1. + X_target_AR (:class:`numpy:numpy.ndarray`, optional): + Locations to draw AR samples over. If None, AR samples will be + drawn over the target locations in the task. Defaults to None. + ar_subsample_factor (int, optional): + Subsample target locations to draw AR samples over. Defaults + to 1. + fill_type (Literal["mean", "sample"], optional): + How to infill the rest of the sample. Must be one of "mean" or + "sample". Defaults to "mean". + + Returns: + :class:`numpy:numpy.ndarray` + The samples. """ # AR sampling requires gridded data to be flattened, not coordinate tuples @@ -828,14 +798,18 @@ def ar_sample( variance, noiseless_samples, noisy_samples, - ) = run_nps_model_ar(self.model, task_arsample, num_samples=n_samples) + ) = run_nps_model_ar( + self.model, task_arsample, num_samples=n_samples + ) else: ( mean, variance, noiseless_samples, noisy_samples, - ) = run_nps_model_ar(self.model, task_arsample, num_samples=n_samples) + ) = run_nps_model_ar( + self.model, task_arsample, num_samples=n_samples + ) # Slice out first (and assumed only) target entry in nps.Aggregate object noiseless_samples = B.to_numpy(noiseless_samples) @@ -849,7 +823,9 @@ def ar_sample( task_with_sample["X_c"][0] = B.concat( task["X_c"][0], task_arsample["X_t"][0], axis=-1 ) - task_with_sample["Y_c"][0] = B.concat(task["Y_c"][0], sample, axis=-1) + task_with_sample["Y_c"][0] = B.concat( + task["Y_c"][0], sample, axis=-1 + ) if fill_type == "mean": # Compute the mean conditioned on the AR samples diff --git a/deepsensor/model/defaults.py b/deepsensor/model/defaults.py index cacf54af..47993829 100644 --- a/deepsensor/model/defaults.py +++ b/deepsensor/model/defaults.py @@ -21,16 +21,14 @@ def gen_ppu(task_loader: TaskLoader) -> int: computes the data resolution for each. The model ppu is then set to the maximum data ppu. - Parameters - ---------- - task_loader : :class:`~.data.loader.TaskLoader` - TaskLoader object containing context and target sets. - - Returns - ------- - model_ppu : int - Model ppu (points per unit), i.e. the number of points per unit of - input space. + Args: + task_loader (:class:`~.data.loader.TaskLoader`): + TaskLoader object containing context and target sets. + + Returns: + int: + Model ppu (points per unit), i.e. 
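A usage sketch for `ar_sample`, assuming `model` is a ConvNP and `task` a Task whose 0th target set holds the locations of interest (argument values are illustrative):

    ar_samples = model.ar_sample(
        task,
        n_samples=3,
        ar_subsample_factor=4,  # AR-sample over a 1/4 subsample of the target points
        fill_type="mean",       # infill the remaining targets with the conditional mean
    )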
the number of points per unit of + input space. """ # List of data resolutions for each context/target variable (in points-per-unit) data_ppus = [] @@ -63,16 +61,13 @@ def gen_decoder_scale(model_ppu: int) -> float: internal grid. The value chosen is 1 / model_ppu (i.e. the length scale is equal to the model's internal grid spacing). - Parameters - ---------- - model_ppu : int - Model ppu (points per unit), i.e. the number of points per unit of - input space. + Args: + model_ppu (int): + Model ppu (points per unit), i.e. the number of points per unit of + input space. - Returns - ------- - decoder_scale : float - Decoder scale. + Returns: + float: Decoder scale. """ return 1 / model_ppu @@ -95,18 +90,15 @@ def gen_encoder_scales(model_ppu: int, task_loader: TaskLoader) -> List[float]: points) for each context variable. The encoder scale is then set to 0.5 * data_resolution. - Parameters - ---------- - model_ppu : int - Model ppu (points per unit), i.e. the number of points per unit of - input space. - task_loader : :class:`~.data.loader.TaskLoader` - TaskLoader object containing context and target sets. - - Returns - ------- - encoder_scales : list[float] - List of encoder scales for each context set. + Args: + model_ppu (int): + Model ppu (points per unit), i.e. the number of points per unit of + input space. + task_loader (:class:`~.data.loader.TaskLoader`): + TaskLoader object containing context and target sets. + + Returns: + list[float]: List of encoder scales for each context set. """ encoder_scales = [] for var in task_loader.context: diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index 7187cafb..80a06e7d 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -34,36 +34,33 @@ def create_empty_spatiotemporal_xarray( """ ... - Parameters - ---------- - X : :class:`xarray.Dataset` | :class:`xarray.DataArray` - ... - dates : List[...] - ... - coord_names : dict, optional - ..., by default {"x1": "x1", "x2": "x2"} - data_vars : List[str], optional - ..., by default ["var"] - prepend_dims : List[str], optional - ..., by default None - prepend_coords : dict, optional - ..., by default None - - Returns - ------- - ... + Args: + X (:class:`xarray.Dataset` | :class:`xarray.DataArray`): + ... + dates (List[...]): + ... + coord_names (dict, optional): + ..., by default {"x1": "x1", "x2": "x2"} + data_vars (List[str], optional): + ..., by default ["var"] + prepend_dims (List[str], optional): + ..., by default None + prepend_coords (dict, optional): + ..., by default None + + Returns: ... + ... - Raises - ------ - ValueError - If ``data_vars`` contains duplicate values. - ValueError - If ``coord_names["x1"]`` is not uniformly spaced. - ValueError - If ``coord_names["x2"]`` is not uniformly spaced. - ValueError - If ``prepend_dims`` and ``prepend_coords`` are not the same length. + Raises: + ValueError + If ``data_vars`` contains duplicate values. + ValueError + If ``coord_names["x1"]`` is not uniformly spaced. + ValueError + If ``coord_names["x2"]`` is not uniformly spaced. + ValueError + If ``prepend_dims`` and ``prepend_coords`` are not the same length. """ if prepend_dims is None: prepend_dims = [] @@ -82,9 +79,13 @@ def create_empty_spatiotemporal_xarray( # Assert uniform spacing if not np.allclose(np.diff(x1_predict), np.diff(x1_predict)[0]): - raise ValueError(f"Coordinate {coord_names['x1']} must be uniformly spaced.") + raise ValueError( + f"Coordinate {coord_names['x1']} must be uniformly spaced." 
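The defaults-module functions documented above compose as follows; `task_loader` is assumed to exist:

    from deepsensor.model.defaults import gen_ppu, gen_encoder_scales, gen_decoder_scale

    ppu = gen_ppu(task_loader)                             # max data ppu over all variables
    encoder_scales = gen_encoder_scales(ppu, task_loader)  # 0.5 * data resolution per context set
    decoder_scale = gen_decoder_scale(ppu)                 # 1 / ppu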
+ ) if not np.allclose(np.diff(x2_predict), np.diff(x2_predict)[0]): - raise ValueError(f"Coordinate {coord_names['x2']} must be uniformly spaced.") + raise ValueError( + f"Coordinate {coord_names['x2']} must be uniformly spaced." + ) if len(prepend_dims) != len(set(prepend_dims)): # TODO unit test @@ -102,7 +103,10 @@ def create_empty_spatiotemporal_xarray( } pred_ds = xr.Dataset( - {data_var: xr.DataArray(dims=dims, coords=coords) for data_var in data_vars} + { + data_var: xr.DataArray(dims=dims, coords=coords) + for data_var in data_vars + } ).astype("float32") # Convert time coord to pandas timestamps @@ -126,26 +130,28 @@ def increase_spatial_resolution( .. # TODO wasteful to interpolate X_t_normalised - Parameters - ---------- - X_t_normalised : ... - ... - resolution_factor : ... - ... - coord_names : dict, optional - ..., by default {"x1": "x1", "x2": "x2"} + Args: + X_t_normalised (...): + ... + resolution_factor (...): + ... + coord_names (dict, optional): + ..., by default {"x1": "x1", "x2": "x2"} - Returns - ------- - ... + Returns: ... + ... """ assert isinstance(resolution_factor, (float, int)) assert isinstance(X_t_normalised, (xr.DataArray, xr.Dataset)) x1_name, x2_name = coord_names["x1"], coord_names["x2"] x1, x2 = X_t_normalised.coords[x1_name], X_t_normalised.coords[x2_name] - x1 = np.linspace(x1[0], x1[-1], int(x1.size * resolution_factor), dtype="float64") - x2 = np.linspace(x2[0], x2[-1], int(x2.size * resolution_factor), dtype="float64") + x1 = np.linspace( + x1[0], x1[-1], int(x1.size * resolution_factor), dtype="float64" + ) + x2 = np.linspace( + x2[0], x2[-1], int(x2.size * resolution_factor), dtype="float64" + ) X_t_normalised = X_t_normalised.interp( **{x1_name: x1, x2_name: x2}, method="nearest" ) @@ -164,20 +170,16 @@ def mean(self, task: Task, *args, **kwargs): Computes the model mean prediction over target points based on given context data. - Parameters - ---------- - task : :class:`~.data.task.Task` - Task containing context data. + Args: + task (:class:`~.data.task.Task`): + Task containing context data. - Returns - ------- - mean : :class:`numpy:numpy.ndarray` - Should return mean prediction over target points. + Returns: + :class:`numpy:numpy.ndarray`: Mean prediction over target points. - Raises - ------ - NotImplementedError - If not implemented by child class. + Raises: + NotImplementedError + If not implemented by child class. """ raise NotImplementedError() @@ -186,20 +188,16 @@ def variance(self, task: Task, *args, **kwargs): Model marginal variance over target points given context points. Shape (N,). - Parameters - ---------- - task : :class:`~.data.task.Task` - Task containing context data. + Args: + task (:class:`~.data.task.Task`): + Task containing context data. - Returns - ------- - var : :class:`numpy:numpy.ndarray` - Should return marginal variance over target points. + Returns: + :class:`numpy:numpy.ndarray`: Marginal variance over target points. - Raises - ------ - NotImplementedError - If not implemented by child class. + Raises: + NotImplementedError + If not implemented by child class. """ raise NotImplementedError() @@ -208,15 +206,12 @@ def stddev(self, task: Task): Model marginal standard deviation over target points given context points. Shape (N,). - Parameters - ---------- - task : :class:`~.data.task.Task` - Task containing context data. + Args: + task (:class:`~.data.task.Task`): + Task containing context data. 
- Returns - ------- - std : :class:`numpy:numpy.ndarray` - Should return marginal standard deviation over target points. + Returns: + :class:`numpy:numpy.ndarray`: Marginal standard deviation over target points. """ var = self.variance(task) return var**0.5 @@ -226,20 +221,16 @@ def covariance(self, task: Task, *args, **kwargs): Computes the model covariance matrix over target points based on given context data. Shape (N, N). - Parameters - ---------- - task : :class:`~.data.task.Task` - Task containing context data. + Args: + task (:class:`~.data.task.Task`): + Task containing context data. - Returns - ------- - cov : :class:`numpy:numpy.ndarray` - Should return covariance matrix over target points. + Returns: + :class:`numpy:numpy.ndarray`: Covariance matrix over target points. - Raises - ------ - NotImplementedError - If not implemented by child class. + Raises: + NotImplementedError + If not implemented by child class. """ raise NotImplementedError() @@ -251,20 +242,17 @@ def mean_marginal_entropy(self, task: Task, *args, **kwargs): .. note:: Note: Getting a vector of marginal entropies would be useful too. - Parameters - ---------- - task : :class:`~.data.task.Task` - Task containing context data. - Returns - ------- - mean_marginal_entropy : float - Should return mean marginal entropy over target points. + Args: + task (:class:`~.data.task.Task`): + Task containing context data. - Raises - ------ - NotImplementedError - If not implemented by child class. + Returns: + float: Mean marginal entropy over target points. + + Raises: + NotImplementedError + If not implemented by child class. """ raise NotImplementedError() @@ -273,20 +261,17 @@ def joint_entropy(self, task: Task, *args, **kwargs): Computes the model joint entropy over target points based on given context data. - Parameters - ---------- - task : :class:`~.data.task.Task` - Task containing context data. - Returns - ------- - joint_entropy : float - Should return joint entropy over target points. + Args: + task (:class:`~.data.task.Task`): + Task containing context data. + + Returns: + float: Joint entropy over target points. - Raises - ------ - NotImplementedError - If not implemented by child class. + Raises: + NotImplementedError + If not implemented by child class. """ raise NotImplementedError() @@ -295,20 +280,16 @@ def logpdf(self, task: Task, *args, **kwargs): Computes the joint model logpdf over target points based on given context data. - Parameters - ---------- - task : :class:`~.data.task.Task` - Task containing context data. + Args: + task (:class:`~.data.task.Task`): + Task containing context data. - Returns - ------- - logpdf : float - Should return joint logpdf over target points. + Returns: + float: Joint logpdf over target points. - Raises - ------ - NotImplementedError - If not implemented by child class. + Raises: + NotImplementedError + If not implemented by child class. """ raise NotImplementedError() @@ -316,20 +297,16 @@ def loss(self, task: Task, *args, **kwargs): """ Computes the model loss over target points based on given context data. - Parameters - ---------- - task : :class:`~.data.task.Task` - Task containing context data. + Args: + task (:class:`~.data.task.Task`): + Task containing context data. - Returns - ------- - loss : float - Should return loss over target points. + Returns: + float: Loss over target points. - Raises - ------ - NotImplementedError - If not implemented by child class. + Raises: + NotImplementedError + If not implemented by child class. 
""" raise NotImplementedError() @@ -338,22 +315,19 @@ def sample(self, task: Task, n_samples=1, *args, **kwargs): Draws ``n_samples`` joint samples over target points based on given context data. Returned shape is ``(n_samples, n_target)``. - Parameters - ---------- - task : :class:`~.data.task.Task` - Task containing context data. - n_samples : int - Number of samples to draw. - - Returns - ------- - samples : Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`] - Should return joint samples over target points. - - Raises - ------ - NotImplementedError - If not implemented by child class. + + Args: + task (:class:`~.data.task.Task`): + Task containing context data. + n_samples (int, optional): + Number of samples to draw. Defaults to 1. + + Returns: + tuple[:class:`numpy:numpy.ndarray`]: Joint samples over target points. + + Raises: + NotImplementedError + If not implemented by child class. """ raise NotImplementedError() @@ -373,13 +347,12 @@ def __init__( """ Initialise DeepSensorModel. - Parameters - ---------- - data_processor : :class:`~.data.processor.DataProcessor` - DataProcessor object, used to unnormalise predictions. - task_loader : :class:`~.data.loader.TaskLoader` - TaskLoader object, used to determine target variables for - unnormalising. + Args: + data_processor (:class:`~.data.processor.DataProcessor`): + DataProcessor object, used to unnormalise predictions. + task_loader (:class:`~.data.loader.TaskLoader`): + TaskLoader object, used to determine target variables for + unnormalising. """ self.task_loader = task_loader self.data_processor = data_processor @@ -416,73 +389,77 @@ def predict( TODO: - Test with multiple targets model - Parameters - ---------- - tasks : List[Task] | Task - List of tasks containing context data. - X_t : :class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray` - Target locations to predict at. Can be an xarray object containing - on-grid locations or a pandas object containing off-grid locations. - X_t_mask: :class:`xarray.Dataset` | :class:`xarray.DataArray` - Optional 2D mask to apply to X_t (zero/False will be NaNs). Will be interpolated - to the same grid as X_t. Default None (no mask). - X_t_is_normalised : bool - Whether the ``X_t`` coords are normalised. If False, will normalise - the coords before passing to model. Default ``False``. - aux_at_targets_override : :class:`xarray.Dataset` | :class:`xarray.DataArray` - Optional auxiliary xarray data to override from the task_loader. - aux_at_targets_override_is_normalised : bool - Whether the `aux_at_targets_override` coords are normalised. - If False, the DataProcessor will normalise the coords before passing to model. - Default False. - resolution_factor : float - Optional factor to increase the resolution of the target grid by. - E.g. 2 will double the target resolution, 0.5 will halve it. - Applies to on-grid predictions only. Default 1. - n_samples : int - Number of joint samples to draw from the model. If 0, will not - draw samples. Default 0. - ar_sample : bool - Whether to use autoregressive sampling. Default ``False``. - unnormalise : bool - Whether to unnormalise the predictions. Only works if ``self`` has - a ``data_processor`` and ``task_loader`` attribute. Default - ``True``. - seed : int - Random seed for deterministic sampling. Default 0. - append_indexes : dict - Dictionary of index metadata to append to pandas indexes in the - off-grid case. Default ``None``. 
- progress_bar : int - Whether to display a progress bar over tasks. Default 0. - verbose : bool - Whether to print time taken for prediction. Default ``False``. - - Returns - ------- - predictions : :class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` - If ``X_t`` is a pandas object, returns pandas objects containing - off-grid predictions. - - If ``X_t`` is an xarray object, returns xarray object containing - on-grid predictions. - - If ``n_samples`` == 0, returns only mean and std predictions. - - If ``n_samples`` > 0, returns mean, std and samples predictions. - - Raises - ------ - ValueError - If ``X_t`` is not an xarray object and - ``resolution_factor`` is not 1 or ``ar_subsample_factor`` is not 1. - ValueError - If ``X_t`` is not a pandas object and ``append_indexes`` is not - ``None``. - ValueError - If ``X_t`` is not an xarray, pandas or numpy object. - ValueError - If ``append_indexes`` are not all the same length as ``X_t``. + Args: + tasks (List[Task] | Task): + List of tasks containing context data. + X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`): + Target locations to predict at. Can be an xarray object + containing on-grid locations or a pandas object containing + off-grid locations. + X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`): + Optional 2D mask to apply to X_t (zero/False will be NaNs). + Will be interpolated to the same grid as X_t. Default None (no + mask). + X_t_is_normalised (bool): + Whether the ``X_t`` coords are normalised. If False, the + coords will be normalised before being passed to the model. + Default ``False``. + aux_at_targets_override (:class:`xarray.Dataset` | :class:`xarray.DataArray`): + Optional auxiliary xarray data to override from the + task_loader. + aux_at_targets_override_is_normalised (bool): + Whether the `aux_at_targets_override` coords are normalised. + If False, the DataProcessor will normalise the coords before + passing to model. + Default ``False``. + resolution_factor (float): + Optional factor to increase the resolution of the target grid + by. E.g. 2 will double the target resolution, 0.5 will halve + it. Applies to on-grid predictions only. Default 1. + n_samples (int): + Number of joint samples to draw from the model. If 0, will not + draw samples. Default 0. + ar_sample (bool): + Whether to use autoregressive sampling. Default ``False``. + unnormalise (bool): + Whether to unnormalise the predictions. Only works if ``self`` + has a ``data_processor`` and ``task_loader`` attribute. Default + ``True``. + seed (int): + Random seed for deterministic sampling. Default 0. + append_indexes (dict): + Dictionary of index metadata to append to pandas indexes in the + off-grid case. Default ``None``. + progress_bar (int): + Whether to display a progress bar over tasks. Default 0. + verbose (bool): + Whether to print time taken for prediction. Default ``False``. + + Returns: + :class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index`: + If ``X_t`` is a pandas object, returns pandas objects + containing off-grid predictions. + + If ``X_t`` is an xarray object, returns xarray object + containing on-grid predictions. + + If ``n_samples`` == 0, returns only mean and std predictions. + + If ``n_samples`` > 0, returns mean, std and samples + predictions. 
+ + Raises: + ValueError + If ``X_t`` is not an xarray object and + ``resolution_factor`` is not 1 or ``ar_subsample_factor`` is + not 1. + ValueError + If ``X_t`` is not a pandas object and ``append_indexes`` is not + ``None``. + ValueError + If ``X_t`` is not an xarray, pandas or numpy object. + ValueError + If ``append_indexes`` are not all the same length as ``X_t``. """ tic = time.time() @@ -495,7 +472,9 @@ def predict( raise ValueError( "ar_subsample_factor can only be used with on-grid predictions." ) - if not isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)): + if not isinstance( + X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray) + ): if append_indexes is not None: raise ValueError( "append_indexes can only be used with off-grid predictions." @@ -512,7 +491,9 @@ def predict( if mode == "off-grid" and X_t_mask is not None: # TODO: Unit test this - raise ValueError("X_t_mask can only be used with on-grid predictions.") + raise ValueError( + "X_t_mask can only be used with on-grid predictions." + ) if type(tasks) is Task: tasks = [tasks] @@ -571,9 +552,14 @@ def predict( X_t_mask_normalised = self.data_processor.map_coords(X_t_mask) X_t_arr = xarray_to_coord_array_normalised(X_t_normalised) # Remove points that lie outside the mask - X_t_arr = mask_coord_array_normalised(X_t_arr, X_t_mask_normalised) + X_t_arr = mask_coord_array_normalised( + X_t_arr, X_t_mask_normalised + ) else: - X_t_arr = (X_t_normalised["x1"].values, X_t_normalised["x2"].values) + X_t_arr = ( + X_t_normalised["x1"].values, + X_t_normalised["x2"].values, + ) elif mode == "off-grid": X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T @@ -617,7 +603,9 @@ def predict( elif mode == "off-grid": # Repeat target locs for each date to create multiindex idxs = [(date, *idxs) for date in dates for idxs in X_t.index] - index = pd.MultiIndex.from_tuples(idxs, names=["time", *X_t.index.names]) + index = pd.MultiIndex.from_tuples( + idxs, names=["time", *X_t.index.names] + ) mean = pd.DataFrame(index=index, columns=target_var_IDs) std = pd.DataFrame(index=index, columns=target_var_IDs) if n_samples >= 1: @@ -630,7 +618,9 @@ def predict( index_samples = pd.MultiIndex.from_tuples( idxs_samples, names=["sample", "time", *X_t.index.names] ) - samples = pd.DataFrame(index=index_samples, columns=target_var_IDs) + samples = pd.DataFrame( + index=index_samples, columns=target_var_IDs + ) def unnormalise_pred_array(arr, **kwargs): var_IDs_flattened = [ @@ -663,7 +653,9 @@ def unnormalise_pred_array(arr, **kwargs): else: aux_at_targets = self.task_loader.aux_at_targets - for task in tqdm(tasks, position=0, disable=progress_bar < 1, leave=True): + for task in tqdm( + tasks, position=0, disable=progress_bar < 1, leave=True + ): task["X_t"] = [X_t_arr for _ in range(len(task["X_t"]))] # If passing auxiliary data, need to sample it at target locations @@ -692,7 +684,9 @@ def unnormalise_pred_array(arr, **kwargs): n_samples=n_samples, ar_subsample_factor=ar_subsample_factor, ) - samples_arr = samples_arr.reshape((n_samples, *mean_arr.shape)) + samples_arr = samples_arr.reshape( + (n_samples, *mean_arr.shape) + ) else: samples_arr = self.sample(dist, n_samples=n_samples) else: @@ -708,7 +702,9 @@ def unnormalise_pred_array(arr, **kwargs): n_samples=n_samples, ar_subsample_factor=ar_subsample_factor, ) - samples_arr = samples_arr.reshape((n_samples, *mean_arr.shape)) + samples_arr = samples_arr.reshape( + (n_samples, *mean_arr.shape) + ) else: samples_arr = self.sample(task, n_samples=n_samples) @@ -734,12 +730,16 
@@ def unnormalise_pred_array(arr, **kwargs): std.loc[:, task["time"], :, :] = std_arr if n_samples >= 1: for sample_i in range(n_samples): - samples.loc[:, sample_i, task["time"], :, :] = samples_arr[ - sample_i - ] + samples.loc[ + :, sample_i, task["time"], :, : + ] = samples_arr[sample_i] else: - mean.loc[:, task["time"], :, :].data[:, X_t_mask.data] = mean_arr - std.loc[:, task["time"], :, :].data[:, X_t_mask.data] = std_arr + mean.loc[:, task["time"], :, :].data[ + :, X_t_mask.data + ] = mean_arr + std.loc[:, task["time"], :, :].data[ + :, X_t_mask.data + ] = std_arr if n_samples >= 1: for sample_i in range(n_samples): samples.loc[:, sample_i, task["time"], :, :].data[ @@ -751,7 +751,9 @@ def unnormalise_pred_array(arr, **kwargs): std.loc[task["time"]] = std_arr.T if n_samples >= 1: for sample_i in range(n_samples): - samples.loc[sample_i, task["time"]] = samples_arr[sample_i].T + samples.loc[sample_i, task["time"]] = samples_arr[ + sample_i + ].T if mode == "on-grid": mean = mean.to_dataset(dim="data_var") @@ -777,6 +779,36 @@ def create_empty_spatiotemporal_xarray( prepend_dims: List[str] = None, prepend_coords: dict = None, ): + """ + ... + + Args: + X (xr.Dataset | xr.DataArray): + _description_ + dates (List): + _description_ + coord_names (..., optional): + _description_, by default {"x1": "x1", "x2": "x2"} + data_vars (List, optional): + _description_, by default ["var"] + prepend_dims (List[str], optional): + _description_, by default None + prepend_coords (dict, optional): + _description_, by default None + + Returns: + ...: ... + + Raises: + ValueError + ... + ValueError + ... + ValueError + ... + ValueError + ... + """ if prepend_dims is None: prepend_dims = [] if prepend_coords is None: @@ -794,9 +826,13 @@ def create_empty_spatiotemporal_xarray( # Assert uniform spacing if not np.allclose(np.diff(x1_predict), np.diff(x1_predict)[0]): - raise ValueError(f"Coordinate {coord_names['x1']} must be uniformly spaced.") + raise ValueError( + f"Coordinate {coord_names['x1']} must be uniformly spaced." + ) if not np.allclose(np.diff(x2_predict), np.diff(x2_predict)[0]): - raise ValueError(f"Coordinate {coord_names['x2']} must be uniformly spaced.") + raise ValueError( + f"Coordinate {coord_names['x2']} must be uniformly spaced." + ) if len(prepend_dims) != len(set(prepend_dims)): # TODO unit test @@ -814,7 +850,10 @@ def create_empty_spatiotemporal_xarray( } pred_ds = xr.Dataset( - {data_var: xr.DataArray(dims=dims, coords=coords) for data_var in data_vars} + { + data_var: xr.DataArray(dims=dims, coords=coords) + for data_var in data_vars + } ).astype("float32") # Convert time coord to pandas timestamps @@ -828,15 +867,35 @@ def create_empty_spatiotemporal_xarray( def increase_spatial_resolution( - X_t_normalised, resolution_factor, coord_names: dict = {"x1": "x1", "x2": "x2"} + X_t_normalised, + resolution_factor, + coord_names: dict = {"x1": "x1", "x2": "x2"}, ): + """ + ... + + Args: + X_t_normalised (...): + ... + resolution_factor (...): + ... + coord_names (..., optional): + ..., by default {"x1": "x1", "x2": "x2"} + + Returns: + ...: ... 
+ """ # TODO wasteful to interpolate X_t_normalised assert isinstance(resolution_factor, (float, int)) assert isinstance(X_t_normalised, (xr.DataArray, xr.Dataset)) x1_name, x2_name = coord_names["x1"], coord_names["x2"] x1, x2 = X_t_normalised.coords[x1_name], X_t_normalised.coords[x2_name] - x1 = np.linspace(x1[0], x1[-1], int(x1.size * resolution_factor), dtype="float64") - x2 = np.linspace(x2[0], x2[-1], int(x2.size * resolution_factor), dtype="float64") + x1 = np.linspace( + x1[0], x1[-1], int(x1.size * resolution_factor), dtype="float64" + ) + x2 = np.linspace( + x2[0], x2[-1], int(x2.size * resolution_factor), dtype="float64" + ) X_t_normalised = X_t_normalised.interp( **{x1_name: x1, x2_name: x2}, method="nearest" ) diff --git a/deepsensor/model/nps.py b/deepsensor/model/nps.py index 2fb993e8..f99dc3c2 100644 --- a/deepsensor/model/nps.py +++ b/deepsensor/model/nps.py @@ -12,15 +12,13 @@ def convert_task_to_nps_args(task: Task): .. TODO move to ConvNP class? - Parameters - ---------- - task : :class:`~.data.task.Task` - Task object containing context and target sets. - - Returns - ------- - ... - ... + Args: + task (:class:`~.data.task.Task`): + Task object containing context and target sets. + + Returns: + tuple[list[tuple[numpy.ndarray, numpy.ndarray]], numpy.ndarray, numpy.ndarray, dict]: + ... """ context_data = list(zip(task["X_c"], task["Y_c"])) @@ -31,7 +29,9 @@ def convert_task_to_nps_args(task: Task): yt = task["Y_t"][0] elif len(task["X_t"]) > 1 and len(task["Y_t"]) > 1: # Multiple target sets, different target locations - xt = backend.nps.AggregateInput(*[(xt, i) for i, xt in enumerate(task["X_t"])]) + xt = backend.nps.AggregateInput( + *[(xt, i) for i, xt in enumerate(task["X_t"])] + ) yt = backend.nps.Aggregate(*task["Y_t"]) elif len(task["X_t"]) == 1 and len(task["Y_t"]) > 1: # Multiple target sets, same target locations @@ -58,22 +58,20 @@ def run_nps_model( """ Run ``neuralprocesses`` model. - Parameters - ---------- - neural_process : neuralprocesses.Model - Neural process model. - task : :class:`~.data.task.Task` - Task object containing context and target sets. - n_samples : int, optional - Number of samples to draw from the model. Defaults to ``None`` (single - sample). - requires_grad : bool, optional - Whether to require gradients. Defaults to ``False``. - - Returns - ------- - dist : neuralprocesses.distributions.Distribution - Distribution object containing the model's predictions. + Args: + neural_process (neuralprocesses.Model): + Neural process model. + task (:class:`~.data.task.Task`): + Task object containing context and target sets. + n_samples (int, optional): + Number of samples to draw from the model. Defaults to ``None`` + (single sample). + requires_grad (bool, optional): + Whether to require gradients. Defaults to ``False``. + + Returns: + neuralprocesses.distributions.Distribution: + Distribution object containing the model's predictions. """ context_data, xt, _, model_kwargs = convert_task_to_nps_args(task) if backend.str == "torch" and not requires_grad: @@ -85,7 +83,9 @@ def run_nps_model( context_data, xt, **model_kwargs, num_samples=n_samples ) else: - dist = neural_process(context_data, xt, **model_kwargs, num_samples=n_samples) + dist = neural_process( + context_data, xt, **model_kwargs, num_samples=n_samples + ) return dist @@ -93,19 +93,17 @@ def run_nps_model_ar(neural_process, task: Task, num_samples: int = 1): """ Run ``neural_process`` in AR mode. 
- Parameters - ---------- - neural_process : neuralprocesses.Model - Neural process model. - task : :class:`~.data.task.Task` - Task object containing context and target sets. - num_samples : int, optional - Number of samples to draw from the model. Defaults to 1. - - Returns - ------- - Tuple[..., ..., ..., ...] - Tuple of mean, variance, noiseless samples, and noisy samples. + Args: + neural_process (neuralprocesses.Model): + Neural process model. + task (:class:`~.data.task.Task`): + Task object containing context and target sets. + num_samples (int, optional): + Number of samples to draw from the model. Defaults to 1. + + Returns: + tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]: + Tuple of mean, variance, noiseless samples, and noisy samples. """ context_data, xt, _, _ = convert_task_to_nps_args(task) @@ -150,83 +148,80 @@ def construct_neural_process( needed, they must be explicitly passed to ``neuralprocesses`` constructor (not currently safe to use `**kwargs` here). - Parameters - ---------- - dim_x : int, optional - Dimensionality of the inputs. Defaults to 1. - dim_y : int, optional - Dimensionality of the outputs. Defaults to 1. - dim_yc : int or tuple[int], optional - Dimensionality of the outputs of the context set. You should set this - if the dimensionality of the outputs of the context set is not equal to - the dimensionality of the outputs of the target set. You should also - set this if you want to use multiple context sets. In that case, set - this equal to a tuple of integers indicating the respective output - dimensionalities. - dim_yt : int, optional - Dimensionality of the outputs of the target set. You should set this if - the dimensionality of the outputs of the target set is not equal to the - dimensionality of the outputs of the context set. - dim_aux_t : int, optional - Dimensionality of target-specific auxiliary variables. - points_per_unit : int, optional - Density of the internal discretisation. Defaults to 100. - likelihood : str, optional - Likelihood. Must be one of ``"cnp"`` (equivalently ``"het"``), - ``"gnp"`` (equivalently ``"lowrank"``), or ``"cnp-spikes-beta"`` - (equivalently ``"spikes-beta"``). Defaults to ``"cnp"``. - conv_arch : str, optional - Convolutional architecture to use. Must be one of - ``"unet[-res][-sep]"`` or ``"conv[-res][-sep]"``. Defaults to - ``"unet"``. - unet_channels: tuple[int], optional - Channels of every layer of the UNet. Defaults to six layers each with - 64 channels. - unet_kernels : int or tuple[int], optional - Sizes of the kernels in the UNet. Defaults to 5. - unet_resize_convs : bool, optional - Use resize convolutions rather than transposed convolutions in the - UNet. Defaults to ``False``. - unet_resize_conv_interp_method : str, optional - Interpolation method for the resize convolutions in the UNet. Can be - set to ``"bilinear"``. Defaults to "bilinear". - num_basis_functions : int, optional - Number of basis functions for the low-rank likelihood. Defaults to - 64. - dim_lv : int, optional - Dimensionality of the latent variable. Setting to >0 constructs a - latent neural process. Defaults to 0. - encoder_scales : float or tuple[float], optional - Initial value for the length scales of the set convolutions for the - context sets embeddings. Set to a tuple equal to the number of context - sets to use different values for each set. Set to a single value to use - the same value for all context sets. Defaults to - ``1 / points_per_unit``. 
- encoder_scales_learnable : bool, optional - Whether the encoder SetConv length scale(s) are learnable. Defaults to - ``False``. - decoder_scale : float, optional - Initial value for the length scale of the set convolution in the - decoder. Defaults to ``1 / points_per_unit``. - decoder_scale_learnable : bool, optional - Whether the decoder SetConv length scale(s) are learnable. Defaults to - ``False``. - aux_t_mlp_layers : tuple[int], optional - Widths of the layers of the MLP for the target-specific auxiliary - variable. Defaults to three layers of width 128. - epsilon : float, optional - Epsilon added by the set convolutions before dividing by the density - channel. Defaults to ``1e-2``. - - Returns - ------- - :class:`.model.Model`: - ConvNP model. - - Raises - ------ - NotImplementedError - If specified backend has no default dtype. + Args: + dim_x (int, optional): + Dimensionality of the inputs. Defaults to 1. + dim_y (int, optional): + Dimensionality of the outputs. Defaults to 1. + dim_yc (int or tuple[int], optional): + Dimensionality of the outputs of the context set. You should set + this if the dimensionality of the outputs of the context set is not + equal to the dimensionality of the outputs of the target set. You + should also set this if you want to use multiple context sets. In + that case, set this equal to a tuple of integers indicating the + respective output dimensionalities. + dim_yt (int, optional): + Dimensionality of the outputs of the target set. You should set + this if the dimensionality of the outputs of the target set is not + equal to the dimensionality of the outputs of the context set. + dim_aux_t (int, optional): + Dimensionality of target-specific auxiliary variables. + points_per_unit (int, optional): + Density of the internal discretisation. Defaults to 100. + likelihood (str, optional): + Likelihood. Must be one of ``"cnp"`` (equivalently ``"het"``), + ``"gnp"`` (equivalently ``"lowrank"``), or ``"cnp-spikes-beta"`` + (equivalently ``"spikes-beta"``). Defaults to ``"cnp"``. + conv_arch (str, optional): + Convolutional architecture to use. Must be one of + ``"unet[-res][-sep]"`` or ``"conv[-res][-sep]"``. Defaults to + ``"unet"``. + unet_channels (tuple[int], optional): + Channels of every layer of the UNet. Defaults to six layers each + with 64 channels. + unet_kernels (int or tuple[int], optional): + Sizes of the kernels in the UNet. Defaults to 5. + unet_resize_convs (bool, optional): + Use resize convolutions rather than transposed convolutions in the + UNet. Defaults to ``False``. + unet_resize_conv_interp_method (str, optional): + Interpolation method for the resize convolutions in the UNet. Can + be set to ``"bilinear"``. Defaults to "bilinear". + num_basis_functions (int, optional): + Number of basis functions for the low-rank likelihood. Defaults to + 64. + dim_lv (int, optional): + Dimensionality of the latent variable. Setting to >0 constructs a + latent neural process. Defaults to 0. + encoder_scales (float or tuple[float], optional): + Initial value for the length scales of the set convolutions for the + context sets embeddings. Set to a tuple equal to the number of + context sets to use different values for each set. Set to a single + value to use the same value for all context sets. Defaults to + ``1 / points_per_unit``. + encoder_scales_learnable (bool, optional): + Whether the encoder SetConv length scale(s) are learnable. + Defaults to ``False``. 
+ decoder_scale (float, optional): + Initial value for the length scale of the set convolution in the + decoder. Defaults to ``1 / points_per_unit``. + decoder_scale_learnable (bool, optional): + Whether the decoder SetConv length scale(s) are learnable. Defaults + to ``False``. + aux_t_mlp_layers (tuple[int], optional): + Widths of the layers of the MLP for the target-specific auxiliary + variable. Defaults to three layers of width 128. + epsilon (float, optional): + Epsilon added by the set convolutions before dividing by the + density channel. Defaults to ``1e-2``. + + Returns: + :class:`.model.Model`: + ConvNP model. + + Raises: + NotImplementedError: + If the specified backend has no default dtype. """ if likelihood == "cnp": likelihood = "het" @@ -247,7 +242,9 @@ def construct_neural_process( dtype = tf.float32 else: - raise NotImplementedError(f"Backend {backend.str} has no default dtype.") + raise NotImplementedError( + f"Backend {backend.str} has no default dtype." + ) neural_process = backend.nps.construct_convgnp( dim_x=dim_x, @@ -281,19 +278,19 @@ def compute_encoding_tensor(model, task: Task): """ Compute the encoding tensor for a given task. - Parameters - ---------- - model : ... - Model object. - task : :class:`~.data.task.Task` - Task object containing context and target sets. - - Returns - ------- - encoding : :class:`numpy:numpy.ndarray` - Encoding tensor? #TODO + Args: + model (...): + Model object. + task (:class:`~.data.task.Task`): + Task object containing context and target sets. + + Returns: + :class:`numpy:numpy.ndarray`: + Encoding tensor? #TODO """ - neural_process_encoder = backend.nps.Model(model.model.encoder, lambda x: x) + neural_process_encoder = backend.nps.Model( + model.model.encoder, lambda x: x + ) task = model.modify_task(task) encoding = B.to_numpy(run_nps_model(neural_process_encoder, task)) return encoding diff --git a/deepsensor/train/train.py b/deepsensor/train/train.py index 366ec87a..aa60bbb5 100644 --- a/deepsensor/train/train.py +++ b/deepsensor/train/train.py @@ -13,18 +13,14 @@ def set_gpu_default_device() -> None: """ Set default GPU device for the backend. - Raises - ------ - RuntimeError - If no GPU is available. - RuntimeError - If backend is not supported. - NotImplementedError - If backend is not supported. - - Returns - ------- - None. + Raises: + RuntimeError: + If no GPU is available. + NotImplementedError: + If backend is not supported. + + Returns: + None. """ if deepsensor.backend.str == "torch": # Run on GPU if available @@ -35,7 +33,9 @@ def set_gpu_default_device() -> None: torch.set_default_device("cuda") B.set_global_device("cuda:0") else: - raise RuntimeError("No GPU available: torch.cuda.is_available() == False") + raise RuntimeError( + "No GPU available: torch.cuda.is_available() == False" + ) elif deepsensor.backend.str == "tf": # Run on GPU if available import tensorflow as tf @@ -47,10 +47,14 @@ def set_gpu_default_device() -> None: ) B.set_global_device("GPU:0") else: - raise RuntimeError("No GPU available: tf.test.is_gpu_available() == False") + raise RuntimeError( + "No GPU available: tf.test.is_gpu_available() == False" + ) else: - raise NotImplementedError(f"Backend {deepsensor.backend.str} not implemented") + raise NotImplementedError( + f"Backend {deepsensor.backend.str} not implemented" + ) def train_epoch( @@ -65,28 +69,25 @@ """ Train model for one epoch. - Parameters - ---------- - model : :class:`~.model.convnp.ConvNP` - Model to train. 
- tasks : List[:class:`~.data.task.Task`] - List of tasks to train on. - lr : float, optional - Learning rate, by default 5e-5. - batch_size : int, optional - Batch size. Defaults to None. If None, no batching is performed. - opt : Optimizer, optional - TF or Torch optimizer. Defaults to None. If None, - :class:`tensorflow:tensorflow.keras.optimizer.Adam` is used. - progress_bar : bool, optional - Whether to display a progress bar. Defaults to False. - tqdm_notebook : bool, optional - Whether to use a notebook progress bar. Defaults to False. - - Returns - ------- - List[float] - List of losses for each task/batch. + Args: + model (:class:`~.model.convnp.ConvNP`): + Model to train. + tasks (List[:class:`~.data.task.Task`]): + List of tasks to train on. + lr (float, optional): + Learning rate, by default 5e-5. + batch_size (int, optional): + Batch size. Defaults to None. If None, no batching is performed. + opt (Optimizer, optional): + TF or Torch optimizer. Defaults to None. If None, + :class:`tensorflow:tensorflow.keras.optimizer.Adam` is used. + progress_bar (bool, optional): + Whether to display a progress bar. Defaults to False. + tqdm_notebook (bool, optional): + Whether to use a notebook progress bar. Defaults to False. + + Returns: + List[float]: List of losses for each task/batch. """ if deepsensor.backend.str == "tf": import tensorflow as tf @@ -102,7 +103,9 @@ def train_step(tasks): for task in tasks: task_losses.append(model.loss_fn(task, normalise=True)) mean_batch_loss = B.mean(B.stack(*task_losses)) - grads = tape.gradient(mean_batch_loss, model.model.trainable_weights) + grads = tape.gradient( + mean_batch_loss, model.model.trainable_weights + ) opt.apply_gradients(zip(grads, model.model.trainable_weights)) return mean_batch_loss @@ -125,12 +128,16 @@ def train_step(tasks): return mean_batch_loss.detach().cpu().numpy() else: - raise NotImplementedError(f"Backend {deepsensor.backend.str} not implemented") + raise NotImplementedError( + f"Backend {deepsensor.backend.str} not implemented" + ) tasks = np.random.permutation(tasks) if batch_size is not None: - n_batches = len(tasks) // batch_size # Note that this will drop the remainder + n_batches = ( + len(tasks) // batch_size + ) # Note that this will drop the remainder else: n_batches = len(tasks) diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py index c329b206..53142345 100644 --- a/tests/test_task_loader.py +++ b/tests/test_task_loader.py @@ -74,7 +74,9 @@ def _gen_task_loader_call_args(self, n_context_sets, n_target_sets): "all", np.zeros((2, 1)), ]: - yield [sampling_method] * n_context_sets, [sampling_method] * n_target_sets + yield [sampling_method] * n_context_sets, [ + sampling_method + ] * n_target_sets def test_load_dask(self): """Test loading dask data""" @@ -107,16 +109,6 @@ def data_type_ID_to_data(set_list): E.g. ["xr", "pd", "xr"] -> [self.da, self.df, self.da] E.g. "xr" -> self.da - - Parameters - ---------- - set_list : list[str] | str - List of data type IDs or single data type ID. - - Returns - ------- - list[xr.DataArray] | list[pd.DataFrame] | xr.DataArray | pd.DataFrame - List of data objects or single data object. 
""" if set_list == "xr": return self.da @@ -198,7 +190,9 @@ def test_invalid_sampling_strat(self): target=self.df, ), ]: - for invalid_sampling_strategy in invalid_context_sampling_strategies: + for ( + invalid_sampling_strategy + ) in invalid_context_sampling_strategies: with self.assertRaises(InvalidSamplingStrategyError): task = tl("2020-01-01", invalid_sampling_strategy) @@ -212,7 +206,9 @@ def test_links_gapfill_da(self) -> None: da_with_nans = copy.deepcopy(self.da) nan_idxs = np.random.randint(0, da_with_nans.size, size=10_000) da_with_nans.data.ravel()[nan_idxs] = np.nan - tl = TaskLoader(context=da_with_nans, target=da_with_nans, links=[(0, 0)]) + tl = TaskLoader( + context=da_with_nans, target=da_with_nans, links=[(0, 0)] + ) # This should not raise an error task = tl("2020-01-01", "gapfill", "gapfill") diff --git a/tests/utils.py b/tests/utils.py index bdc653b0..99fcb470 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,23 +11,21 @@ def gen_random_data_xr( """ Generate random xarray data. - Parameters - ---------- - coords : dict - Coordinates of the data. - dims : list, optional - Dimensions of the data. Defaults to None. If None, dims is inferred - from coords. This arg can be used to change the order of the - dimensions. - data_vars : list, optional - Data variables. Defaults to None. If None, variable is an - :class:`xarray.DataArray`. If not None, variable is an - :class:`xarray.Dataset` containing the data_vars. + Args: + coords (dict): + Coordinates of the data. + dims (list, optional): + Dimensions of the data. Defaults to None. If None, dims is inferred + from coords. This arg can be used to change the order of the + dimensions. + data_vars (list, optional): + Data variables. Defaults to None. If None, variable is an + :class:`xarray.DataArray`. If not None, variable is an + :class:`xarray.Dataset` containing the data_vars. - Returns - ------- - da : :class:`xarray.DataArray` | :class:`xarray.Dataset` - Random xarray data. + Returns: + da (:class:`xarray.DataArray` | :class:`xarray.Dataset`): + Random xarray data. """ if dims is None: shape = tuple([len(coords[dim]) for dim in coords]) @@ -47,24 +45,22 @@ def gen_random_data_pandas(coords: dict, dims: list = None, cols: list = None): """ Generate random pandas data. - Parameters - ---------- - coords : dict - Coordinates of the data. This will be used to construct a MultiIndex - using pandas.MultiIndex.from_product. - dims : list, optional - Dimensions of the data. Defaults to None. If None, dims is inferred - from coords. This arg can be used to change the order of the - MultiIndex. - cols : list, optional - Columns of the data. Defaults to None. If None, generate a - :class:`pandas.Series` with an arbitrary name. If not None, cols is - used to construct a :class:`pandas.DataFrame`. + Args: + coords (dict): + Coordinates of the data. This will be used to construct a + MultiIndex using pandas.MultiIndex.from_product. + dims (list, optional): + Dimensions of the data. Defaults to None. If None, dims is inferred + from coords. This arg can be used to change the order of the + MultiIndex. + cols (list, optional): + Columns of the data. Defaults to None. If None, generate a + :class:`pandas.Series` with an arbitrary name. If not None, cols is + used to construct a :class:`pandas.DataFrame`. - Returns - ------- - df : :class:`pandas.Series` | :class:`pandas.DataFrame` - Random pandas data. + Returns: + :class:`pandas.Series` | :class:`pandas.DataFrame` + Random pandas data. 
""" if dims is None: dims = list(coords.keys())