diff --git a/deepsensor/model/convnp.py b/deepsensor/model/convnp.py index c741f0a1..1b3649d1 100644 --- a/deepsensor/model/convnp.py +++ b/deepsensor/model/convnp.py @@ -222,6 +222,7 @@ def __init__( kwargs["decoder_scale"] = decoder_scale self.model, self.config = construct_neural_process(*args, **kwargs) + self._set_num_mixture_components() @dispatch def __init__( @@ -254,6 +255,7 @@ def __init__(self, model_ID: str): super().__init__() self.load(model_ID) + self._set_num_mixture_components() @dispatch def __init__( @@ -277,6 +279,18 @@ def __init__( super().__init__(data_processor, task_loader) self.load(model_ID) + self._set_num_mixture_components() + + def _set_num_mixture_components(self): + """ + Set the number of mixture components for the model based on the likelihood. + """ + if self.config["likelihood"] in ["spikes-beta"]: + self.N_mixture_components = 3 + elif self.config["likelihood"] in ["bernoulli-gamma"]: + self.N_mixture_components = 2 + else: + self.N_mixture_components = 1 def save(self, model_ID: str): """ @@ -450,7 +464,7 @@ def variance(self, task: Task): return self.variance(dist) @dispatch - def stddev(self, dist: AbstractMultiOutputDistribution): + def std(self, dist: AbstractMultiOutputDistribution): """ ... @@ -468,7 +482,7 @@ def stddev(self, dist: AbstractMultiOutputDistribution): return np.sqrt(variance) @dispatch - def stddev(self, task: Task): + def std(self, task: Task): """ ... @@ -480,7 +494,119 @@ def stddev(self, task: Task): ...: ... """ dist = self(task) - return self.stddev(dist) + return self.std(dist) + + @dispatch + def alpha( + self, dist: AbstractMultiOutputDistribution + ) -> Union[np.ndarray, List[np.ndarray]]: + if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]: + raise NotImplementedError( + f"ConvNP.alpha method not supported for likelihood {self.config['likelihood']}. " + f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'." + ) + alpha = dist.slab.alpha + if isinstance(alpha, backend.nps.Aggregate): + return [B.to_numpy(m)[0, 0] for m in alpha] + else: + return B.to_numpy(alpha)[0, 0] + + @dispatch + def alpha(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]: + """ + Alpha values of model's distribution at target locations in task. + + .. note:: + This method only works for models that return a distribution with + a ``dist.slab.alpha`` attribute, e.g. models with a Beta or + Bernoulli-Gamma likelihood. + + Args: + task (:class:`~.data.task.Task`): + The task containing the context and target data. + + Returns: + :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]: + Alpha values. + """ + dist = self(task) + return self.alpha(dist) + + @dispatch + def beta(self, dist: AbstractMultiOutputDistribution): + if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]: + raise NotImplementedError( + f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. " + f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'." + ) + beta = dist.slab.beta + if isinstance(beta, backend.nps.Aggregate): + return [B.to_numpy(m)[0, 0] for m in beta] + else: + return B.to_numpy(beta)[0, 0] + + @dispatch + def beta(self, task: Task): + """ + Beta values of model's distribution at target locations in task. + + .. note:: + This method only works for models that return a distribution with + a ``dist.slab.beta`` attribute, e.g. models with a Beta or + Bernoulli-Gamma likelihood. 
+ + Args: + task (:class:`~.data.task.Task`): + The task containing the context and target data. + + Returns: + :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]: + Beta values. + """ + dist = self(task) + return self.beta(dist) + + @dispatch + def mixture_probs(self, dist: AbstractMultiOutputDistribution): + """ + Probabilities of the components of a mixture distribution. + + Shape (N_components, N_features, N_targets) + + Args: + dist (neuralprocesses.dist.AbstractMultiOutputDistribution): + ... + + Returns: + ...: ... + """ + if self.N_mixture_components == 1: + raise NotImplementedError( + f"mixture_probs not supported if model attribute N_mixture_components == 1. " + f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta'." + ) + mixture_probs = dist.logprobs + if isinstance(mixture_probs, backend.nps.Aggregate): + return [ + np.moveaxis(np.exp(B.to_numpy(m)[0, 0]), -1, 0) for m in mixture_probs + ] + else: + return np.moveaxis(np.exp(B.to_numpy(mixture_probs)[0, 0]), -1, 0) + + @dispatch + def mixture_probs(self, task: Task): + """ + ... + + Args: + task (:class:`~.data.task.Task`): + ... + + Returns: + ...: ... + """ + dist = self(task) + return self.mixture_probs(dist) @dispatch def covariance(self, dist: AbstractMultiOutputDistribution): @@ -516,7 +642,6 @@ def sample( self, dist: AbstractMultiOutputDistribution, n_samples: int = 1, - noiseless: bool = True, ): """ Create samples from a ConvNP distribution. @@ -527,15 +652,12 @@ def sample( n_samples (int, optional): The number of samples to draw from the distribution, by default 1. - noiseless (bool, optional): - Whether to sample from the noiseless distribution, by default - True. Returns: :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]: The samples as an array or list of arrays. """ - if noiseless: + if self.config["likelihood"] in ["gnp", "lowrank"]: samples = dist.noiseless.sample(n_samples) else: samples = dist.sample(n_samples) @@ -546,7 +668,7 @@ def sample( return B.to_numpy(samples)[:, 0, 0] @dispatch - def sample(self, task: Task, n_samples: int = 1, noiseless: bool = True): + def sample(self, task: Task, n_samples: int = 1): """ Create samples from a ConvNP distribution. @@ -556,16 +678,13 @@ def sample(self, task: Task, n_samples: int = 1, noiseless: bool = True): n_samples (int, optional): The number of samples to draw from the distribution, by default 1. - noiseless (bool, optional): - Whether to sample from the noiseless distribution, by default - True. Returns: :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]: The samples as an array or list of arrays. """ dist = self(task) - return self.sample(dist, n_samples, noiseless) + return self.sample(dist, n_samples) @dispatch def slice_diag(self, task: Task): @@ -580,12 +699,15 @@ def slice_diag(self, task: Task): ...: ... """ dist = self(task) - dist_diag = backend.nps.MultiOutputNormal( - dist._mean, - B.zeros(dist._var), - Diagonal(B.diag(dist._noise + dist._var)), - dist.shape, - ) + if self.config["likelihood"] in ["spikes-beta"]: + dist_diag = dist + else: + dist_diag = backend.nps.MultiOutputNormal( + dist._mean, + B.zeros(dist._var), + Diagonal(B.diag(dist._noise + dist._var)), + dist.shape, + ) return dist_diag @dispatch @@ -600,12 +722,15 @@ def slice_diag(self, dist: AbstractMultiOutputDistribution): Returns: ...: ... 
""" - dist_diag = backend.nps.MultiOutputNormal( - dist._mean, - B.zeros(dist._var), - Diagonal(B.diag(dist._noise + dist._var)), - dist.shape, - ) + if self.config["likelihood"]: + dist_diag = dist + else: + dist_diag = backend.nps.MultiOutputNormal( + dist._mean, + B.zeros(dist._var), + Diagonal(B.diag(dist._noise + dist._var)), + dist.shape, + ) return dist_diag @dispatch @@ -669,8 +794,11 @@ def joint_entropy(self, task: Task): @dispatch def logpdf(self, dist: AbstractMultiOutputDistribution, task: Task): """ - Model outputs joint distribution over all targets: Concat targets along - observation dimension. + Joint logpdf over all target sets. + + .. note:: + If the model has multiple target sets, the returned logpdf is the + mean logpdf over all target sets. Args: dist (neuralprocesses.dist.AbstractMultiOutputDistribution): @@ -681,14 +809,20 @@ def logpdf(self, dist: AbstractMultiOutputDistribution, task: Task): Returns: float: The logpdf. """ - Y_t = B.concat(*task["Y_t"], axis=-1) + # Need to ensure `Y_t` is a tensor and, if multiple target sets, + # an nps.Aggregate object + task = ConvNP.modify_task(task) + _, _, Y_t, _ = convert_task_to_nps_args(task) return B.to_numpy(dist.logpdf(Y_t)).mean() @dispatch def logpdf(self, task: Task): """ - Model outputs joint distribution over all targets: Concat targets along - observation dimension. + Joint logpdf over all target sets. + + .. note:: + If the model has multiple target sets, the returned logpdf is the + mean logpdf over all target sets. Args: task (:class:`~.data.task.Task`): diff --git a/deepsensor/model/model.py b/deepsensor/model/model.py index a8a23522..315d8fe0 100644 --- a/deepsensor/model/model.py +++ b/deepsensor/model/model.py @@ -23,7 +23,6 @@ import xarray as xr import lab as B - # For dispatching with TF and PyTorch model types when they have not yet been loaded. # See https://beartype.github.io/plum/types.html#moduletype @@ -71,7 +70,7 @@ def variance(self, task: Task, *args, **kwargs): """ raise NotImplementedError() - def stddev(self, task: Task): + def std(self, task: Task): """ Model marginal standard deviation over target points given context points. Shape (N,). @@ -86,6 +85,9 @@ def stddev(self, task: Task): var = self.variance(task) return var**0.5 + def stddev(self, *args, **kwargs): + return self.std(*args, **kwargs) + def covariance(self, task: Task, *args, **kwargs): """ Computes the model covariance matrix over target points based on given @@ -215,6 +217,8 @@ class DeepSensorModel(ProbabilisticModel): TaskLoader object, used to determine target variables for unnormalising. """ + N_mixture_components = 1 # Number of mixture components for mixture likelihoods + def __init__( self, data_processor: Optional[DataProcessor] = None, @@ -239,6 +243,7 @@ def predict( aux_at_targets_override: Union[xr.Dataset, xr.DataArray] = None, aux_at_targets_override_is_normalised: bool = False, resolution_factor: int = 1, + pred_params: tuple[str] = ("mean", "std"), n_samples: int = 0, ar_sample: bool = False, ar_subsample_factor: int = 1, @@ -269,6 +274,10 @@ def predict( Whether the `aux_at_targets_override` coords are normalised. If False, the DataProcessor will normalise the coords before passing to model. Default False. + pred_params (tuple[str]): + Tuple of prediction parameters to return. The strings refer to methods + of the model class which will be called and stored in the Prediction object. + Default ("mean", "std"). resolution_factor (float): Optional factor to increase the resolution of the target grid by. E.g. 
2 will double the target resolution, 0.5 will halve @@ -420,13 +429,32 @@ def predict( elif mode == "off-grid": X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T + if isinstance(X_t_arr, tuple): + target_shape = (len(X_t_arr[0]), len(X_t_arr[1])) + else: + target_shape = (X_t_arr.shape[1],) + if not unnormalise: X_t = X_t_normalised + if "mixture_probs" in pred_params: + # Store each mixture component separately w/o overriding pred_params + pred_params_to_store = copy.deepcopy(pred_params) + pred_params_to_store.remove("mixture_probs") + for component_i in range(self.N_mixture_components): + pred_params_to_store.append(f"mixture_probs_{component_i}") + else: + pred_params_to_store = pred_params + # Dict to store predictions for each target variable - # Make this a subclass of dict like Task? And way to initialise cleanly with target_var_IDs? pred = Prediction( - target_var_IDs, dates, X_t, X_t_mask, coord_names, n_samples=n_samples + target_var_IDs, + pred_params_to_store, + dates, + X_t, + X_t_mask, + coord_names, + n_samples=n_samples, ) def unnormalise_pred_array(arr, **kwargs): @@ -436,7 +464,9 @@ def unnormalise_pred_array(arr, **kwargs): for var_IDs in self.task_loader.target_var_IDs for var_ID in var_IDs ] - assert arr.shape[0] == len(var_IDs_flattened) + assert arr.shape[0] == len( + var_IDs_flattened + ), f"{arr.shape[0]} != {len(var_IDs_flattened)}" for i, var_ID in enumerate(var_IDs_flattened): arr[i] = self.data_processor.map_array( arr[i], @@ -474,66 +504,115 @@ def unnormalise_pred_array(arr, **kwargs): X_t_arr, aux_at_targets_sliced ) + prediction_arrs = {} + prediction_methods = {} + for param in pred_params: + try: + method = getattr(self, param) + prediction_methods[param] = method + except AttributeError: + raise AttributeError( + f"Prediction method {param} not found in model class." + ) + if n_samples >= 1: + B.set_random_seed(seed) + np.random.seed(seed) + if ar_sample: + sample_method = getattr(self, "ar_sample") + sample_args = { + "n_samples": n_samples, + "ar_subsample_factor": ar_subsample_factor, + } + else: + sample_method = getattr(self, "sample") + sample_args = {"n_samples": n_samples} + # If `DeepSensor` model child has been sub-classed with a `__call__` method, # we assume this is a distribution-like object that can be used to compute # mean, std and samples. Otherwise, run the model with `Task` for each prediction type. if hasattr(self, "__call__"): # Run model forwards once to generate output distribution, which we re-use dist = self(task, n_samples=n_samples) - mean_arr = self.mean(dist) - std_arr = self.stddev(dist) - if n_samples >= 1: - B.set_random_seed(seed) - np.random.seed(seed) - if ar_sample: - samples_arr = self.ar_sample( - task, - n_samples=n_samples, - ar_subsample_factor=ar_subsample_factor, - ) - samples_arr = samples_arr.reshape((n_samples, *mean_arr.shape)) - else: - samples_arr = self.sample(dist, n_samples=n_samples) - # Repeated code not ideal here... 
+ for param, method in prediction_methods.items(): + prediction_arrs[param] = method(dist) + if n_samples >= 1 and not ar_sample: + samples_arr = sample_method(dist, **sample_args) + # samples_arr = samples_arr.reshape((n_samples, len(target_var_IDs), *target_shape)) + prediction_arrs["samples"] = samples_arr + elif n_samples >= 1 and ar_sample: + # Can't draw AR samples from distribution object, need to re-run with task + samples_arr = sample_method(task, **sample_args) + samples_arr = samples_arr.reshape( + (n_samples, len(target_var_IDs), *target_shape) + ) + prediction_arrs["samples"] = samples_arr else: # Re-run model for each prediction type - mean_arr = self.mean(task) - std_arr = self.stddev(task) + for param, method in prediction_methods.items(): + prediction_arrs[param] = method(task) if n_samples >= 1: - B.set_random_seed(seed) - np.random.seed(seed) + samples_arr = sample_method(task, **sample_args) if ar_sample: - samples_arr = self.ar_sample( - task, - n_samples=n_samples, - ar_subsample_factor=ar_subsample_factor, + samples_arr = samples_arr.reshape( + (n_samples, len(target_var_IDs), *target_shape) ) - samples_arr = samples_arr.reshape((n_samples, *mean_arr.shape)) - else: - samples_arr = self.sample(task, n_samples=n_samples) + prediction_arrs["samples"] = samples_arr # Concatenate multi-target predictions - if isinstance(mean_arr, (list, tuple)): - mean_arr = np.concatenate(mean_arr, axis=0) - std_arr = np.concatenate(std_arr, axis=0) - if n_samples >= 1: - # Axis 0 is sample dim, axis 1 is variable dim - samples_arr = np.concatenate(samples_arr, axis=1) + for param, arr in prediction_arrs.items(): + if isinstance(arr, (list, tuple)): + if param != "samples": + concat_axis = 0 + elif param == "samples": + # Axis 0 is sample dim, axis 1 is variable dim + concat_axis = 1 + prediction_arrs[param] = np.concatenate(arr, axis=concat_axis) # Unnormalise predictions - if unnormalise: - mean_arr = unnormalise_pred_array(mean_arr) - std_arr = unnormalise_pred_array(std_arr, add_offset=False) - if n_samples >= 1: - for sample_i in range(n_samples): - samples_arr[sample_i] = unnormalise_pred_array( - samples_arr[sample_i] + for param, arr in prediction_arrs.items(): + # TODO make class attributes? + scale_and_offset_params = ["mean"] + scale_only_params = ["std"] + scale_squared_only_params = ["variance"] + if unnormalise: + if param == "samples": + for sample_i in range(n_samples): + prediction_arrs["samples"][ + sample_i + ] = unnormalise_pred_array( + prediction_arrs["samples"][sample_i] + ) + elif param in scale_and_offset_params: + prediction_arrs[param] = unnormalise_pred_array(arr) + elif param in scale_only_params: + prediction_arrs[param] = unnormalise_pred_array( + arr, add_offset=False ) - - pred.assign("mean", task["time"], mean_arr) - pred.assign("std", task["time"], std_arr) - if n_samples >= 1: - pred.assign("samples", task["time"], samples_arr) + elif param in scale_squared_only_params: + # This is a horrible hack to repeat the scaling operation of the linear + # transform twice s.t. 
new_var = scale ^ 2 * var + prediction_arrs[param] = unnormalise_pred_array( + arr, add_offset=False + ) + prediction_arrs[param] = unnormalise_pred_array( + prediction_arrs[param], add_offset=False + ) + else: + # Assume prediction parameters not captured above are dimensionless + # quantities like probabilities and should not be unnormalised + pass + + # Assign predictions to Prediction object + for param, arr in prediction_arrs.items(): + if param != "mixture_probs": + pred.assign(param, task["time"], arr) + elif param == "mixture_probs": + assert arr.shape[0] == self.N_mixture_components, ( + f"Number of mixture components ({arr.shape[0]}) does not match " + f"model attribute N_mixture_components ({self.N_mixture_components})." + ) + for component_i, probs in enumerate(arr): + pred.assign(f"{param}_{component_i}", task["time"], probs) if verbose: dur = time.time() - tic diff --git a/deepsensor/model/nps.py b/deepsensor/model/nps.py index 44d5b5f4..a85b9987 100644 --- a/deepsensor/model/nps.py +++ b/deepsensor/model/nps.py @@ -237,6 +237,8 @@ def construct_neural_process( likelihood = "lowrank" elif likelihood == "cnp-spikes-beta": likelihood = "spikes-beta" + elif likelihood == "cnp-bernoulli-gamma": + likelihood = "bernoulli-gamma" # Log the call signature for `construct_convgnp` config = dict(locals()) diff --git a/deepsensor/model/pred.py b/deepsensor/model/pred.py index 53169cd1..00d49f81 100644 --- a/deepsensor/model/pred.py +++ b/deepsensor/model/pred.py @@ -37,6 +37,7 @@ class Prediction(dict): def __init__( self, target_var_IDs: List[str], + pred_params: List[str], dates: List[Union[str, pd.Timestamp]], X_t: Union[ xr.Dataset, @@ -59,10 +60,12 @@ def __init__( self.mode = infer_prediction_modality_from_X_t(X_t) - # TODO don't assume Gaussian distribution - self.pred_parameters = ["mean", "std"] + self.pred_params = pred_params if n_samples >= 1: - self.pred_parameters.extend([f"sample_{i}" for i in range(n_samples)]) + self.pred_params = [ + *pred_params, + *[f"sample_{i}" for i in range(n_samples)], + ] if self.mode == "on-grid": for var_ID in self.target_var_IDs: @@ -70,7 +73,7 @@ def __init__( self[var_ID] = create_empty_spatiotemporal_xarray( X_t, dates, - data_vars=self.pred_parameters, + data_vars=self.pred_params, coord_names=coord_names, ) if self.X_t_mask is None: @@ -86,7 +89,7 @@ def __init__( idxs = [(date, *idxs) for date in dates for idxs in X_t.index] index = pd.MultiIndex.from_tuples(idxs, names=["time", *X_t.index.names]) for var_ID in self.target_var_IDs: - self[var_ID] = pd.DataFrame(index=index, columns=self.pred_parameters) + self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params) def __getitem__(self, key): # Support self[i] syntax @@ -95,7 +98,7 @@ def __getitem__(self, key): return super().__getitem__(key) def __str__(self): - dict_repr = {var_ID: self.pred_parameters for var_ID in self.target_var_IDs} + dict_repr = {var_ID: self.pred_params for var_ID in self.target_var_IDs} return f"Prediction({dict_repr}), mode={self.mode}" def assign( diff --git a/deepsensor/plot.py b/deepsensor/plot.py index 4541e698..53e1eb77 100644 --- a/deepsensor/plot.py +++ b/deepsensor/plot.py @@ -309,7 +309,7 @@ def offgrid_context( context coordinates if provided. Args: - axes (:class:`numpy:numpy.ndarray` | List[:class:`matplotlib:matplotlib.axes.Axes`] | Tuple[:class:`matplotlib:matplotlib.axes.Axes`]: + axes (:class:`numpy:numpy.ndarray` | List[:class:`matplotlib:matplotlib.axes.Axes`] | Tuple[:class:`matplotlib:matplotlib.axes.Axes`]): Axes to plot on. 
task (:class:`~.data.task.Task`): Task containing the context set to plot. diff --git a/docs/_config.yml b/docs/_config.yml index e9e2005e..36ddd3b0 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -69,6 +69,7 @@ sphinx: autodoc_typehints: "none" autoclass_content: "class" bibtex_reference_style: author_year + napoleon_use_rtype: False intersphinx_mapping: python: - https://docs.python.org/3 diff --git a/tests/test_model.py b/tests/test_model.py index e6465276..822b5cae 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -69,8 +69,7 @@ def _gen_task_loader_call_args(self, n_context, n_target): ]: yield [sampling_method] * n_context, [sampling_method] * n_target - # TEMP only 1D because non-overlapping target sets are not yet supported - @parameterized.expand(range(1, 2)) + @parameterized.expand(range(1, 4)) def test_model_call(self, n_context_and_target): """Check ``ConvNP`` runs with all possible combinations of context/target sampling methods.""" @@ -120,68 +119,116 @@ def test_prediction_shapes_lowlevel(self, n_target_sets): context_sampling = 10 - model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False) + likelihoods = ["cnp", "gnp", "cnp-spikes-beta"] - for target_sampling, expected_obs_shape in ( - (10, (10,)), # expected shape is (10,) when target_sampling is 10 - ( - "all", - self.da.shape[-2:], - ), # expected shape is da.shape[-2:] when target_sampling is "all" - ): - task = tl("2020-01-01", context_sampling, target_sampling) - - n_targets = np.product(expected_obs_shape) - - # Tensors - mean = model.mean(task) - # TODO avoid repeated code - if isinstance(mean, (list, tuple)): - for m, dim_y in zip(mean, tl.target_dims): - assert_shape(m, (dim_y, *expected_obs_shape)) - else: - assert_shape(mean, (n_target_sets, *expected_obs_shape)) - - variance = model.variance(task) - if isinstance(variance, (list, tuple)): - for v, dim_y in zip(variance, tl.target_dims): - assert_shape(v, (dim_y, *expected_obs_shape)) - else: - assert_shape(variance, (n_target_sets, *expected_obs_shape)) - - stddev = model.stddev(task) - if isinstance(stddev, (list, tuple)): - for s, dim_y in zip(stddev, tl.target_dims): - assert_shape(s, (dim_y, *expected_obs_shape)) - else: - assert_shape(stddev, (n_target_sets, *expected_obs_shape)) - - n_samples = 5 - samples = model.sample(task, n_samples) - if isinstance(samples, (list, tuple)): - for s, dim_y in zip(samples, tl.target_dims): - assert_shape(s, (n_samples, dim_y, *expected_obs_shape)) - else: - assert_shape(samples, (n_samples, n_target_sets, *expected_obs_shape)) - - n_target_dims = np.product(tl.target_dims) - assert_shape( - model.covariance(task), - ( - n_targets * n_target_sets * n_target_dims, - n_targets * n_target_sets * n_target_dims, - ), + for likelihood in likelihoods: + model = ConvNP( + self.dp, + tl, + unet_channels=(5, 5, 5), + likelihood=likelihood, + verbose=False, ) - # Scalars - x = model.logpdf(task) - assert x.size == 1 and x.shape == () - x = model.joint_entropy(task) - assert x.size == 1 and x.shape == () - x = model.mean_marginal_entropy(task) - assert x.size == 1 and x.shape == () - x = B.to_numpy(model.loss_fn(task)) - assert x.size == 1 and x.shape == () + for target_sampling, expected_obs_shape in ( + # expected shape is (10,) when target_sampling is 10 + (10, (10,)), + # expected shape is da.shape[-2:] when target_sampling is "all" + ("all", self.da.shape[-2:]), + ): + task = tl("2020-01-01", context_sampling, target_sampling) + + n_targets = np.product(expected_obs_shape) + + # Tensors + mean = 
model.mean(task) + # TODO avoid repeated code by looping over methods + if isinstance(mean, (list, tuple)): + for m, dim_y in zip(mean, tl.target_dims): + assert_shape(m, (dim_y, *expected_obs_shape)) + else: + assert_shape(mean, (n_target_sets, *expected_obs_shape)) + + variance = model.variance(task) + if isinstance(variance, (list, tuple)): + for v, dim_y in zip(variance, tl.target_dims): + assert_shape(v, (dim_y, *expected_obs_shape)) + else: + assert_shape(variance, (n_target_sets, *expected_obs_shape)) + + stddev = model.stddev(task) + if isinstance(stddev, (list, tuple)): + for s, dim_y in zip(stddev, tl.target_dims): + assert_shape(s, (dim_y, *expected_obs_shape)) + else: + assert_shape(stddev, (n_target_sets, *expected_obs_shape)) + + n_samples = 5 + samples = model.sample(task, n_samples) + if isinstance(samples, (list, tuple)): + for s, dim_y in zip(samples, tl.target_dims): + assert_shape(s, (n_samples, dim_y, *expected_obs_shape)) + else: + assert_shape( + samples, (n_samples, n_target_sets, *expected_obs_shape) + ) + + if likelihood in ["cnp", "gnp"]: + n_target_dims = np.product(tl.target_dims) + assert_shape( + model.covariance(task), + ( + n_targets * n_target_sets * n_target_dims, + n_targets * n_target_sets * n_target_dims, + ), + ) + if likelihood in ["cnp-spikes-beta"]: + mixture_probs = model.mixture_probs(task) + if isinstance(mixture_probs, (list, tuple)): + for p, dim_y in zip(mixture_probs, tl.target_dims): + assert_shape( + p, + ( + model.N_mixture_components, + dim_y, + *expected_obs_shape, + ), + ) + else: + assert_shape( + mixture_probs, + ( + model.N_mixture_components, + n_target_sets, + *expected_obs_shape, + ), + ) + + x = model.alpha(task) + if isinstance(x, (list, tuple)): + for p, dim_y in zip(x, tl.target_dims): + assert_shape(p, (dim_y, *expected_obs_shape)) + else: + assert_shape(x, (n_target_sets, *expected_obs_shape)) + + x = model.beta(task) + if isinstance(x, (list, tuple)): + for p, dim_y in zip(x, tl.target_dims): + assert_shape(p, (dim_y, *expected_obs_shape)) + else: + assert_shape(x, (n_target_sets, *expected_obs_shape)) + + # Scalars + if likelihood in ["cnp", "gnp"]: + # Methods for Gaussian likelihoods only + x = model.joint_entropy(task) + assert x.size == 1 and x.shape == () + x = model.mean_marginal_entropy(task) + assert x.size == 1 and x.shape == () + x = model.logpdf(task) + assert x.size == 1 and x.shape == () + x = B.to_numpy(model.loss_fn(task)) + assert x.size == 1 and x.shape == () @parameterized.expand(range(1, 4)) def test_prediction_shapes_highlevel(self, target_dim): @@ -372,9 +419,7 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self): dp = DataProcessor( x1_name="latitude", - x1_map=lat_lims, x2_name="longitude", - x2_map=lon_lims, ) df = dp(df_raw) @@ -392,6 +437,81 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self): df_raw.reset_index()["longitude"], ) + def test_highlevel_predict_with_pred_params_pandas(self): + """ + Test that passing ``pred_params`` to ``.predict`` works with + a spikes-beta likelihood for prediction to pandas. 
+        """
+        tl = TaskLoader(context=self.da, target=self.da)
+        model = ConvNP(
+            self.dp,
+            tl,
+            unet_channels=(5, 5, 5),
+            verbose=False,
+            likelihood="cnp-spikes-beta",
+        )
+        task = tl("2020-01-01", context_sampling=10, target_sampling=10)
+
+        # Off-grid prediction
+        X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]])
+
+        # Check that nothing breaks and the correct parameters are returned
+        pred_params = ["mean", "std", "variance", "alpha", "beta"]
+        pred = model.predict(task, X_t=X_t, pred_params=pred_params)
+        for pred_param in pred_params:
+            assert pred_param in pred["var"]
+
+        # Test mixture probs special case (off-grid, matching the pandas prediction mode)
+        pred_params = ["mixture_probs"]
+        pred = model.predict(task, X_t=X_t, pred_params=pred_params)
+        for component in range(model.N_mixture_components):
+            pred_param = f"mixture_probs_{component}"
+            assert pred_param in pred["var"]
+
+    def test_highlevel_predict_with_pred_params_xarray(self):
+        """
+        Test that passing ``pred_params`` to ``.predict`` works with
+        a spikes-beta likelihood for prediction to xarray.
+        """
+        tl = TaskLoader(context=self.da, target=self.da)
+        model = ConvNP(
+            self.dp,
+            tl,
+            unet_channels=(5, 5, 5),
+            verbose=False,
+            likelihood="cnp-spikes-beta",
+        )
+        task = tl("2020-01-01", context_sampling=10, target_sampling=10)
+
+        # Check that nothing breaks and the correct parameters are returned
+        pred_params = ["mean", "std", "variance", "alpha", "beta"]
+        pred = model.predict(task, X_t=self.da, pred_params=pred_params)
+        for pred_param in pred_params:
+            assert pred_param in pred["var"]
+
+        # Test mixture probs special case
+        pred_params = ["mixture_probs"]
+        pred = model.predict(task, X_t=self.da, pred_params=pred_params)
+        for component in range(model.N_mixture_components):
+            pred_param = f"mixture_probs_{component}"
+            assert pred_param in pred["var"]
+
+    def test_highlevel_predict_with_invalid_pred_params(self):
+        """Test that passing an invalid parameter name in ``pred_params`` to
+        ``.predict`` raises an ``AttributeError``."""
+        tl = TaskLoader(context=self.da, target=self.da)
+        model = ConvNP(
+            self.dp,
+            tl,
+            unet_channels=(5, 5, 5),
+            verbose=False,
+            likelihood="cnp-spikes-beta",
+        )
+        task = tl("2020-01-01", context_sampling=10, target_sampling=10)
+
+        # Check that passing an invalid parameter raises an AttributeError
+        with self.assertRaises(AttributeError):
+            model.predict(task, X_t=self.da, pred_params=["invalid_param"])
+
     def test_saving_and_loading(self):
         """Test saving and loading of model"""
         with tempfile.TemporaryDirectory() as folder: