diff --git a/deepsensor/model/convnp.py b/deepsensor/model/convnp.py
index 1b3649d1..78c7f178 100644
--- a/deepsensor/model/convnp.py
+++ b/deepsensor/model/convnp.py
@@ -47,11 +47,22 @@ class ConvNP(DeepSensorModel):
     from the ``TaskLoader``.
 
     The ``ConvNP`` can optionally be instantiated with:
-    - a ``DataProcessor`` object to auto-unnormalise the data at inference
-      time with the ``.predict`` method.
-    - a ``TaskLoader`` object to infer sensible default model parameters
-      from the data.
+    - a ``DataProcessor`` object to auto-unnormalise the data at inference time with the ``.predict`` method.
+    - a ``TaskLoader`` object to infer sensible default model parameters from the data.
+
+    Many of the ``ConvNP`` class methods utilise multiple dispatch so that they
+    can either be run with a ``Task`` object or a ``neuralprocesses`` distribution
+    object. This allows for re-using the model's forward prediction object.
+
+    Dimension shapes are expressed in method docstrings in terms of:
+    - ``N_features``: number of features/dimensions in the target set.
+    - ``N_targets``: number of target points (1D for off-grid targets, 2D for gridded targets).
+    - ``N_components``: number of mixture components in the likelihood (for mixture likelihoods only).
+    - ``N_samples``: number of samples drawn from the distribution.
+
+    If the model has multiple target sets and the ``Task`` object
+    has different target locations for each set, a list of arrays is returned,
+    one per target set. Otherwise, a single array is returned.
 
     Examples:
         Instantiate a ``ConvNP`` with all hyperparameters set to their default values:
@@ -397,84 +408,102 @@ def __call__(self, task, n_samples=10, requires_grad=False):
         dist = run_nps_model(self.model, task, n_samples, requires_grad)
         return dist
 
-    @dispatch
-    def mean(self, dist: AbstractMultiOutputDistribution):
+    def _cast_numpy_and_squeeze(
+        self,
+        x: Union[B.Numeric, List[B.Numeric]],
+        squeeze_axes: List[int] = (0, 1),
+    ):
+        """Cast ``x`` to numpy and squeeze out the size-1 axes given by
+        ``squeeze_axes``, handling ``Aggregate`` inputs elementwise."""
+        if isinstance(x, backend.nps.Aggregate):
+            return [np.squeeze(B.to_numpy(xi), axis=squeeze_axes) for xi in x]
+        else:
+            return np.squeeze(B.to_numpy(x), axis=squeeze_axes)
+
+    def _maybe_concat_multi_targets(
+        self,
+        x: Union[np.ndarray, List[np.ndarray]],
+        concat_axis: int = 0,
+    ) -> Union[np.ndarray, List[np.ndarray]]:
         """
-        ...
+        Concatenate per-dimension arrays belonging to the same target set
+        along the feature dimension.
 
         Args:
-            dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
-                ...
+            x (:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]):
+                Array, or list of arrays with one entry per 1D target
+                dimension.
+            concat_axis (int, optional):
+                Axis to concatenate along (*after* squeezing arrays) when
+                merging multiple target sets. Defaults to 0.
 
         Returns:
-            ...: ...
-        """
-        mean = dist.mean
-        if isinstance(mean, backend.nps.Aggregate):
-            return [B.to_numpy(m)[0, 0] for m in mean]
+            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
+                Concatenated target sets.
+ """ + if isinstance(x, (list, tuple)): + new_list = [] + pos = 0 + for dim in self.task_loader.target_dims: + new_list.append(x[pos : pos + dim]) + pos += dim + return [ + B.concat(*[xi for xi in sub_list], axis=concat_axis) + for sub_list in new_list + ] else: - return B.to_numpy(mean)[0, 0] + return x + + @dispatch + def mean(self, dist: AbstractMultiOutputDistribution): + mean = dist.mean + mean = self._cast_numpy_and_squeeze(mean) + return self._maybe_concat_multi_targets(mean) @dispatch def mean(self, task: Task): """ - ... + Mean values of model's distribution at target locations in task. + + Returned numpy arrays have shape ``(N_features, *N_targets)``. Args: task (:class:`~.data.task.Task`): - ... + The task containing the context and target data. Returns: - ...: ... + :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]: + Mean values. """ dist = self(task) return self.mean(dist) @dispatch def variance(self, dist: AbstractMultiOutputDistribution): - """ - ... - - Args: - dist (neuralprocesses.dist.AbstractMultiOutputDistribution): - ... - - Returns: - ...: ... - """ variance = dist.var - if isinstance(variance, backend.nps.Aggregate): - return [B.to_numpy(v)[0, 0] for v in variance] - else: - return B.to_numpy(variance)[0, 0] + variance = self._cast_numpy_and_squeeze(variance) + return self._maybe_concat_multi_targets(variance) @dispatch def variance(self, task: Task): """ - ... + Variance values of model's distribution at target locations in task. + + Returned numpy arrays have shape ``(N_features, *N_targets)``. Args: task (:class:`~.data.task.Task`): - ... + The task containing the context and target data. Returns: - ...: ... + :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]: + Variance values. """ dist = self(task) return self.variance(dist) @dispatch def std(self, dist: AbstractMultiOutputDistribution): - """ - ... - - Args: - dist (neuralprocesses.dist.AbstractMultiOutputDistribution): - ... - - Returns: - ...: ... - """ variance = self.variance(dist) if isinstance(variance, (list, tuple)): return [np.sqrt(v) for v in variance] @@ -484,14 +513,17 @@ def std(self, dist: AbstractMultiOutputDistribution): @dispatch def std(self, task: Task): """ - ... + Standard deviation values of model's distribution at target locations in task. + + Returned numpy arrays have shape ``(N_features, *N_targets)``. Args: task (:class:`~.data.task.Task`): - ... + The task containing the context and target data. Returns: - ...: ... + :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]: + Standard deviation values. """ dist = self(task) return self.std(dist) @@ -506,20 +538,21 @@ def alpha( f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'." ) alpha = dist.slab.alpha - if isinstance(alpha, backend.nps.Aggregate): - return [B.to_numpy(m)[0, 0] for m in alpha] - else: - return B.to_numpy(alpha)[0, 0] + alpha = self._cast_numpy_and_squeeze(alpha) + return self._maybe_concat_multi_targets(alpha) @dispatch def alpha(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]: """ - Alpha values of model's distribution at target locations in task. + Alpha parameter values of model's distribution at target locations in task. + + Returned numpy arrays have shape ``(N_features, *N_targets)``. .. note:: This method only works for models that return a distribution with a ``dist.slab.alpha`` attribute, e.g. models with a Beta or - Bernoulli-Gamma likelihood. 
+           Bernoulli-Gamma likelihood, where it returns the alpha values of
+           the slab component of the mixture model.
 
         Args:
             task (:class:`~.data.task.Task`):
@@ -533,23 +566,25 @@ def alpha(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
         return self.alpha(dist)
 
     @dispatch
-    def beta(self, dist: AbstractMultiOutputDistribution):
+    def beta(
+        self, dist: AbstractMultiOutputDistribution
+    ) -> Union[np.ndarray, List[np.ndarray]]:
         if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
             raise NotImplementedError(
                 f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. "
                 f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
             )
         beta = dist.slab.beta
-        if isinstance(beta, backend.nps.Aggregate):
-            return [B.to_numpy(m)[0, 0] for m in beta]
-        else:
-            return B.to_numpy(beta)[0, 0]
+        beta = self._cast_numpy_and_squeeze(beta)
+        return self._maybe_concat_multi_targets(beta)
 
     @dispatch
-    def beta(self, task: Task):
+    def beta(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
         """
         Beta values of model's distribution at target locations in task.
 
+        Returned numpy arrays have shape ``(N_features, *N_targets)``.
+
         .. note::
            This method only works for models that return a distribution with
            a ``dist.slab.beta`` attribute, e.g. models with a Beta or
@@ -568,58 +603,39 @@ def beta(self, task: Task):
 
     @dispatch
     def mixture_probs(self, dist: AbstractMultiOutputDistribution):
-        """
-        Probabilities of the components of a mixture distribution.
-
-        Shape (N_components, N_features, N_targets)
-
-        Args:
-            dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
-                ...
-
-        Returns:
-            ...: ...
-        """
         if self.N_mixture_components == 1:
             raise NotImplementedError(
                 f"mixture_probs not supported if model attribute N_mixture_components == 1. "
                 f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta'."
             )
         mixture_probs = dist.logprobs
-        if isinstance(mixture_probs, backend.nps.Aggregate):
-            return [
-                np.moveaxis(np.exp(B.to_numpy(m)[0, 0]), -1, 0) for m in mixture_probs
-            ]
+        mixture_probs = self._cast_numpy_and_squeeze(mixture_probs)
+        mixture_probs = self._maybe_concat_multi_targets(mixture_probs)
+        if isinstance(mixture_probs, (list, tuple)):
+            return [np.moveaxis(np.exp(m), -1, 0) for m in mixture_probs]
         else:
-            return np.moveaxis(np.exp(B.to_numpy(mixture_probs)[0, 0]), -1, 0)
+            return np.moveaxis(np.exp(mixture_probs), -1, 0)
 
     @dispatch
     def mixture_probs(self, task: Task):
         """
-        ...
+        Mixture probabilities of model's distribution at target locations in task.
+
+        Returned numpy arrays have shape ``(N_components, N_features, *N_targets)``.
 
         Args:
             task (:class:`~.data.task.Task`):
-                ...
+                The task containing the context and target data.
 
         Returns:
-            ...: ...
+            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
+                Mixture probabilities.
         """
         dist = self(task)
         return self.mixture_probs(dist)
 
     @dispatch
     def covariance(self, dist: AbstractMultiOutputDistribution):
-        """
-        ...
-
-        Args:
-            dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
-                ...
-
-        Returns:
-            ...: ...
-        """
         return B.to_numpy(B.dense(dist.vectorised_normal.var))[0, 0]
 
     @dispatch
@@ -643,35 +659,21 @@ def sample(
         dist: AbstractMultiOutputDistribution,
         n_samples: int = 1,
     ):
-        """
-        Create samples from a ConvNP distribution.
-
-        Args:
-            dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
-                The distribution to sample from.
-            n_samples (int, optional):
-                The number of samples to draw from the distribution, by
-                default 1.
-
-        Returns:
-            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
-                The samples as an array or list of arrays.
-        """
         if self.config["likelihood"] in ["gnp", "lowrank"]:
             samples = dist.noiseless.sample(n_samples)
         else:
             samples = dist.sample(n_samples)
-
-        if isinstance(samples, backend.nps.Aggregate):
-            return [B.to_numpy(s)[:, 0, 0] for s in samples]
-        else:
-            return B.to_numpy(samples)[:, 0, 0]
+        # Be careful to keep the sample dimension in position 0
+        samples = self._cast_numpy_and_squeeze(samples, squeeze_axes=(1, 2))
+        return self._maybe_concat_multi_targets(samples, concat_axis=1)
 
     @dispatch
     def sample(self, task: Task, n_samples: int = 1):
         """
         Create samples from a ConvNP distribution.
 
+        Returned numpy arrays have shape ``(N_samples, N_features, *N_targets)``.
+
         Args:
             dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
                 The distribution to sample from.
@@ -894,6 +896,8 @@ def ar_sample(
         subset of the target set and then infill the rest of the sample with
         the model mean or joint sample conditioned on the AR samples.
 
+        Returned numpy arrays have shape ``(N_samples, N_features, *N_targets)``.
+
        .. note::
            AR sampling only works for 0th context/target set, and only for
            models with a single target set.
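The dispatch pattern described in the class docstring lets the expensive forward pass be shared across the statistics above. A minimal sketch of that usage (`model` and `task` are assumed to be a trained `ConvNP` and a `Task` from a `TaskLoader`; they are not part of this diff):

```python
# Run the forward pass once and reuse the resulting distribution object
# for several statistics (each `model.<stat>(task)` call would otherwise
# re-run the model internally).
dist = model(task)

mean = model.mean(dist)          # (N_features, *N_targets)
variance = model.variance(dist)  # (N_features, *N_targets)
std = model.std(dist)            # (N_features, *N_targets)
samples = model.sample(dist, n_samples=5)  # (N_samples, N_features, *N_targets)

# Convenience form; equivalent, but runs one forward pass per call:
mean = model.mean(task)
```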
diff --git a/deepsensor/model/nps.py b/deepsensor/model/nps.py
index a85b9987..5af4e5a7 100644
--- a/deepsensor/model/nps.py
+++ b/deepsensor/model/nps.py
@@ -40,10 +40,21 @@ def convert_task_to_nps_args(task: Task):
         yt = task["Y_t"][0]
     elif len(task["X_t"]) > 1 and len(task["Y_t"]) > 1:
         # Multiple target sets, different target locations
-        xt = backend.nps.AggregateInput(*[(xt, i) for i, xt in enumerate(task["X_t"])])
-        yt = backend.nps.Aggregate(*task["Y_t"])
+        assert len(task["X_t"]) == len(task["Y_t"])
+        xts = []
+        yts = []
+        target_dims = [yt.shape[1] for yt in task["Y_t"]]
+        # Map from ND target sets to 1D target sets
+        dim_counter = 0
+        for i, (xt, yt) in enumerate(zip(task["X_t"], task["Y_t"])):
+            # Repeat target locations for each target dimension in the target set
+            xts.extend([(xt, dim_counter + j) for j in range(target_dims[i])])
+            yts.extend([yt[:, j : j + 1] for j in range(target_dims[i])])
+            dim_counter += target_dims[i]
+        xt = backend.nps.AggregateInput(*xts)
+        yt = backend.nps.Aggregate(*yts)
     elif len(task["X_t"]) == 1 and len(task["Y_t"]) > 1:
-        # Multiple target sets, same target locations
+        # Multiple target sets, same target locations; concatenate `Y_t`s along the feature dim
         xt = task["X_t"][0]
         yt = B.concat(*task["Y_t"], axis=1)
     else:
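To see what the new loop in `convert_task_to_nps_args` produces, here is a toy walk-through with plain numpy arrays standing in for the task tensors. This is a sketch only; the `(batch, features, points)` shapes are assumptions inferred from the `yt[:, j : j + 1]` slicing above, not the actual `neuralprocesses` objects:

```python
import numpy as np

# Two target sets with 2 and 3 features, at 10 and 4 target points:
X_t = [np.zeros((1, 2, 10)), np.zeros((1, 2, 4))]
Y_t = [np.zeros((1, 2, 10)), np.zeros((1, 3, 4))]

target_dims = [yt.shape[1] for yt in Y_t]  # [2, 3]

xts, yts, dim_counter = [], [], 0
for xt, yt in zip(X_t, Y_t):
    dim = yt.shape[1]
    # Each feature becomes its own 1D output with a running output index,
    # repeating the target locations of its set:
    xts.extend([(xt, dim_counter + j) for j in range(dim)])
    yts.extend([yt[:, j : j + 1] for j in range(dim)])
    dim_counter += dim

print([i for _, i in xts])     # [0, 1, 2, 3, 4] -> five 1D outputs in total
print([y.shape for y in yts])  # (1, 1, 10) x2 and (1, 1, 4) x3
```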
+ """ + # Make dataset 5D for non-trivial target dimensions + ndim = 5 + ds = xr.Dataset( + {f"var{i}": self.da for i in range(ndim)}, + ) tl = TaskLoader( - context=self.da, - target=[self.da] * n_target_sets, + context=ds, + target=[ds] * n_target_sets, ) context_sampling = 10 likelihoods = ["cnp", "gnp", "cnp-spikes-beta"] + dim_y_combined = sum(tl.target_dims) + for likelihood in likelihoods: model = ConvNP( self.dp, @@ -129,11 +138,12 @@ def test_prediction_shapes_lowlevel(self, n_target_sets): likelihood=likelihood, verbose=False, ) + assert dim_y_combined == model.config["dim_yt"] for target_sampling, expected_obs_shape in ( - # expected shape is (10,) when target_sampling is 10 + # expected obs shape is (10,) when target_sampling is 10 (10, (10,)), - # expected shape is da.shape[-2:] when target_sampling is "all" + # expected obs shape is da.shape[-2:] when target_sampling is "all" ("all", self.da.shape[-2:]), ): task = tl("2020-01-01", context_sampling, target_sampling) @@ -147,21 +157,21 @@ def test_prediction_shapes_lowlevel(self, n_target_sets): for m, dim_y in zip(mean, tl.target_dims): assert_shape(m, (dim_y, *expected_obs_shape)) else: - assert_shape(mean, (n_target_sets, *expected_obs_shape)) + assert_shape(mean, (dim_y_combined, *expected_obs_shape)) variance = model.variance(task) if isinstance(variance, (list, tuple)): for v, dim_y in zip(variance, tl.target_dims): assert_shape(v, (dim_y, *expected_obs_shape)) else: - assert_shape(variance, (n_target_sets, *expected_obs_shape)) + assert_shape(variance, (dim_y_combined, *expected_obs_shape)) stddev = model.stddev(task) if isinstance(stddev, (list, tuple)): for s, dim_y in zip(stddev, tl.target_dims): assert_shape(s, (dim_y, *expected_obs_shape)) else: - assert_shape(stddev, (n_target_sets, *expected_obs_shape)) + assert_shape(stddev, (dim_y_combined, *expected_obs_shape)) n_samples = 5 samples = model.sample(task, n_samples) @@ -170,7 +180,7 @@ def test_prediction_shapes_lowlevel(self, n_target_sets): assert_shape(s, (n_samples, dim_y, *expected_obs_shape)) else: assert_shape( - samples, (n_samples, n_target_sets, *expected_obs_shape) + samples, (n_samples, dim_y_combined, *expected_obs_shape) ) if likelihood in ["cnp", "gnp"]: @@ -178,8 +188,8 @@ def test_prediction_shapes_lowlevel(self, n_target_sets): assert_shape( model.covariance(task), ( - n_targets * n_target_sets * n_target_dims, - n_targets * n_target_sets * n_target_dims, + n_targets * dim_y_combined * n_target_dims, + n_targets * dim_y_combined * n_target_dims, ), ) if likelihood in ["cnp-spikes-beta"]: @@ -199,7 +209,7 @@ def test_prediction_shapes_lowlevel(self, n_target_sets): mixture_probs, ( model.N_mixture_components, - n_target_sets, + dim_y_combined, *expected_obs_shape, ), ) @@ -209,14 +219,14 @@ def test_prediction_shapes_lowlevel(self, n_target_sets): for p, dim_y in zip(x, tl.target_dims): assert_shape(p, (dim_y, *expected_obs_shape)) else: - assert_shape(x, (n_target_sets, *expected_obs_shape)) + assert_shape(x, (dim_y_combined, *expected_obs_shape)) x = model.beta(task) if isinstance(x, (list, tuple)): for p, dim_y in zip(x, tl.target_dims): assert_shape(p, (dim_y, *expected_obs_shape)) else: - assert_shape(x, (n_target_sets, *expected_obs_shape)) + assert_shape(x, (dim_y_combined, *expected_obs_shape)) # Scalars if likelihood in ["cnp", "gnp"]: