Fix multiple N-D target sets case in ConvNP; Update ConvNP API
tom-andersson committed Dec 10, 2023
1 parent 642f5ab commit 4241544
Showing 3 changed files with 152 additions and 127 deletions.
222 changes: 113 additions & 109 deletions deepsensor/model/convnp.py
@@ -47,11 +47,22 @@ class ConvNP(DeepSensorModel):
from the ``TaskLoader``.
The ``ConvNP`` can optionally be instantiated with:
- a ``DataProcessor`` object to auto-unnormalise the data at inference time with the ``.predict`` method.
- a ``TaskLoader`` object to infer sensible default model parameters from the data.
Many of the ``ConvNP`` class methods utilise multiple dispatch so that they
can either be run with a ``Task`` object or a ``neuralprocesses`` distribution
object. This allows the distribution from a single forward pass to be re-used across method calls.
Dimension shapes are expressed in method docstrings in terms of:
- ``N_features``: number of features/dimensions in the target set.
- ``N_targets``: number of target points (1D for off-grid targets, 2D for gridded targets).
- ``N_components``: number of mixture components in the likelihood (for mixture likelihoods only).
- ``N_samples``: number of samples drawn from the distribution.
If the model has multiple target sets and the ``Task`` object
has different target locations for each set, a list of arrays is returned,
one per target set. Otherwise, a single array is returned.
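For example, a hypothetical sketch of these return conventions (``model`` and ``task`` are assumed names; assumes a trained ``ConvNP`` and a task with two target sets at different off-grid locations):

    # Hypothetical sketch of the return conventions described above.
    # Assumes `task` has two target sets (2 features and 1 feature) at
    # different target locations, so a list of arrays comes back.
    mean = model.mean(task)
    # mean[0].shape == (2, n_targets_0)   # (N_features, *N_targets)
    # mean[1].shape == (1, n_targets_1)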
Examples:
Instantiate a ``ConvNP`` with all hyperparameters set to their default values:
@@ -397,84 +408,102 @@ def __call__(self, task, n_samples=10, requires_grad=False):
dist = run_nps_model(self.model, task, n_samples, requires_grad)
return dist

def _cast_numpy_and_squeeze(
self,
x: Union[B.Numeric, List[B.Numeric]],
squeeze_axes: List[int] = (0, 1),
):
"""TODO docstring"""
if isinstance(x, backend.nps.Aggregate):
return [np.squeeze(B.to_numpy(xi), axis=squeeze_axes) for xi in x]
else:
return np.squeeze(B.to_numpy(x), axis=squeeze_axes)

def _maybe_concat_multi_targets(
self,
x: Union[np.ndarray, List[np.ndarray]],
concat_axis: int = 0,
) -> Union[np.ndarray, List[np.ndarray]]:
"""
...
Concatenate multiple target sets into a single tensor along feature dimension
and remove size-1 dimensions.
Args:
dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
...
x (:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]):
List of target sets.
squeeze_axes (List[int], optional):
Axes to squeeze out of the concatenated target sets. Defaults to (0, 1).
concat_axis (int, optional):
Axis to concatenate along (*after* squeezing arrays) when
merging multiple target sets. Defaults to 0.
Returns:
...: ...
"""
mean = dist.mean
if isinstance(mean, backend.nps.Aggregate):
return [B.to_numpy(m)[0, 0] for m in mean]
(:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]):
Concatenated target sets.
"""
if isinstance(x, (list, tuple)):
new_list = []
pos = 0
for dim in self.task_loader.target_dims:
new_list.append(x[pos : pos + dim])
pos += dim
return [
B.concat(*[xi for xi in sub_list], axis=concat_axis)
for sub_list in new_list
]
else:
return x
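To make the regrouping concrete, here is a minimal standalone numpy sketch of the same logic. The shapes are assumptions, not taken from the diff: it supposes ``task_loader.target_dims == [2, 1]`` and that each squeezed per-dimension array arrives with a leading size-1 feature axis.

    import numpy as np

    # Minimal sketch of the regrouping, assuming target_dims == [2, 1] and
    # per-dimension arrays of shape (1, n_targets) after squeezing.
    target_dims = [2, 1]
    x = [np.zeros((1, 30)), np.ones((1, 30)), np.full((1, 20), 2.0)]

    groups, pos = [], 0
    for dim in target_dims:
        groups.append(x[pos : pos + dim])  # group per-dim arrays by target set
        pos += dim
    out = [np.concatenate(g, axis=0) for g in groups]
    print([o.shape for o in out])  # [(2, 30), (1, 20)] -> (N_features, N_targets)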

@dispatch
def mean(self, dist: AbstractMultiOutputDistribution):
mean = dist.mean
mean = self._cast_numpy_and_squeeze(mean)
return self._maybe_concat_multi_targets(mean)

@dispatch
def mean(self, task: Task):
"""
...
Mean values of model's distribution at target locations in task.
Returned numpy arrays have shape ``(N_features, *N_targets)``.
Args:
task (:class:`~.data.task.Task`):
...
The task containing the context and target data.
Returns:
...: ...
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
Mean values.
"""
dist = self(task)
return self.mean(dist)
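A hedged usage sketch of the dispatch pattern (``model`` and ``task`` are assumed names; ``variance`` and ``std`` below follow the same pattern as ``mean``):

    # Hypothetical usage: both dispatch forms give the same result.
    mean = model.mean(task)        # forward pass + mean in one call
    dist = model(task)             # or run the forward pass once...
    mean = model.mean(dist)        # ...and re-use the distribution object
    var = model.variance(dist)     # same shapes: (N_features, *N_targets)
    std = model.std(dist)          # elementwise sqrt of the variance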

@dispatch
def variance(self, dist: AbstractMultiOutputDistribution):
"""
...
Args:
dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
...
Returns:
...: ...
"""
variance = dist.var
if isinstance(variance, backend.nps.Aggregate):
return [B.to_numpy(v)[0, 0] for v in variance]
else:
return B.to_numpy(variance)[0, 0]
variance = self._cast_numpy_and_squeeze(variance)
return self._maybe_concat_multi_targets(variance)

@dispatch
def variance(self, task: Task):
"""
...
Variance values of model's distribution at target locations in task.
Returned numpy arrays have shape ``(N_features, *N_targets)``.
Args:
task (:class:`~.data.task.Task`):
...
The task containing the context and target data.
Returns:
...: ...
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
Variance values.
"""
dist = self(task)
return self.variance(dist)

@dispatch
def std(self, dist: AbstractMultiOutputDistribution):
"""
...
Args:
dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
...
Returns:
...: ...
"""
variance = self.variance(dist)
if isinstance(variance, (list, tuple)):
return [np.sqrt(v) for v in variance]
@@ -484,14 +513,17 @@ def std(self, dist: AbstractMultiOutputDistribution):
@dispatch
def std(self, task: Task):
"""
...
Standard deviation values of model's distribution at target locations in task.
Returned numpy arrays have shape ``(N_features, *N_targets)``.
Args:
task (:class:`~.data.task.Task`):
...
The task containing the context and target data.
Returns:
...: ...
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
Standard deviation values.
"""
dist = self(task)
return self.std(dist)
@@ -506,20 +538,21 @@ def alpha(
f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
)
alpha = dist.slab.alpha
alpha = self._cast_numpy_and_squeeze(alpha)
return self._maybe_concat_multi_targets(alpha)

@dispatch
def alpha(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
"""
Alpha parameter values of the model's distribution at the target locations in the task.
Returned numpy arrays have shape ``(N_features, *N_targets)``.
.. note::
This method only works for models that return a distribution with
a ``dist.slab.alpha`` attribute, e.g. models with a Beta or
Bernoulli-Gamma likelihood, where it returns the alpha values of
the slab component of the mixture model.
Args:
task (:class:`~.data.task.Task`):
@@ -533,23 +566,25 @@ def alpha(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
return self.alpha(dist)
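A hedged sketch of the ``alpha``/``beta`` accessors (``model`` and ``task`` are assumed names; per the note above, this assumes the model was built with a mixture likelihood such as ``'bernoulli-gamma'``):

    # Hypothetical sketch; alpha/beta raise NotImplementedError unless the
    # likelihood is a mixture model ('spikes-beta' or 'bernoulli-gamma').
    dist = model(task)
    alpha = model.alpha(dist)  # slab-component alpha, (N_features, *N_targets)
    beta = model.beta(dist)    # slab-component beta, same shape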

@dispatch
def beta(
self, dist: AbstractMultiOutputDistribution
) -> Union[np.ndarray, List[np.ndarray]]:
if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
raise NotImplementedError(
f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. "
f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
)
beta = dist.slab.beta
beta = self._cast_numpy_and_squeeze(beta)
return self._maybe_concat_multi_targets(beta)

@dispatch
def beta(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
"""
Beta parameter values of the model's distribution at the target locations in the task.
Returned numpy arrays have shape ``(N_features, *N_targets)``.
.. note::
This method only works for models that return a distribution with
a ``dist.slab.beta`` attribute, e.g. models with a Beta or
@@ -568,58 +603,39 @@ def beta(self, task: Task):

@dispatch
def mixture_probs(self, dist: AbstractMultiOutputDistribution):
"""
Probabilities of the components of a mixture distribution.
Returned arrays have shape ``(N_components, N_features, N_targets)``.
Args:
dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
The distribution to compute mixture probabilities from.
Returns:
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
Mixture probabilities.
"""
if self.N_mixture_components == 1:
raise NotImplementedError(
f"mixture_probs not supported if model attribute N_mixture_components == 1. "
f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta'."
)
mixture_probs = dist.logprobs
mixture_probs = self._cast_numpy_and_squeeze(mixture_probs)
mixture_probs = self._maybe_concat_multi_targets(mixture_probs)
if isinstance(mixture_probs, (list, tuple)):
return [np.moveaxis(np.exp(m), -1, 0) for m in mixture_probs]
else:
return np.moveaxis(np.exp(mixture_probs), -1, 0)

@dispatch
def mixture_probs(self, task: Task):
"""
...
Mixture probabilities of model's distribution at target locations in task.
Returned numpy arrays have shape ``(N_components, N_features, *N_targets)``.
Args:
task (:class:`~.data.task.Task`):
...
The task containing the context and target data.
Returns:
...: ...
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
Mixture probabilities.
"""
dist = self(task)
return self.mixture_probs(dist)
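For instance (hypothetical names; assumes ``N_mixture_components > 1``, e.g. a ``'spikes-beta'`` likelihood, since the method raises otherwise):

    # Hypothetical sketch: component probabilities at each target point.
    probs = model.mixture_probs(task)  # (N_components, N_features, *N_targets)
    # The component probabilities of a mixture sum to one, so
    # probs.sum(axis=0) should be ~1 at every target point.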

@dispatch
def covariance(self, dist: AbstractMultiOutputDistribution):
"""
...
Args:
dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
...
Returns:
...: ...
"""
return B.to_numpy(B.dense(dist.vectorised_normal.var))[0, 0]

@dispatch
@@ -643,35 +659,21 @@ def sample(
dist: AbstractMultiOutputDistribution,
n_samples: int = 1,
):
"""
Create samples from a ConvNP distribution.
Args:
dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
The distribution to sample from.
n_samples (int, optional):
The number of samples to draw from the distribution, by
default 1.
Returns:
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
The samples as an array or list of arrays.
"""
if self.config["likelihood"] in ["gnp", "lowrank"]:
samples = dist.noiseless.sample(n_samples)
else:
samples = dist.sample(n_samples)

# Be careful to keep sample dimension in position 0
samples = self._cast_numpy_and_squeeze(samples, squeeze_axes=(1, 2))
return self._maybe_concat_multi_targets(samples, concat_axis=1)
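A hedged usage sketch (``model`` and ``task`` are assumed names):

    # Hypothetical sketch: the sample dimension stays in position 0.
    samples = model.sample(task, n_samples=5)
    # Single target set (or shared target locations): one array of shape
    # (5, N_features, *N_targets). Different locations per target set:
    # a list of such arrays, one per target set.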

@dispatch
def sample(self, task: Task, n_samples: int = 1):
"""
Create samples from a ConvNP distribution.
Returned numpy arrays have shape ``(N_samples, N_features, *N_targets)``.
Args:
task (:class:`~.data.task.Task`):
The task containing the context and target data.
@@ -894,6 +896,8 @@ def ar_sample(
subset of the target set and then infill the rest of the sample with
the model mean or joint sample conditioned on the AR samples.
Returned numpy arrays have shape ``(N_samples, N_features, *N_targets)``.
.. note::
AR sampling only works for 0th context/target set, and only for models with
a single target set.
17 changes: 14 additions & 3 deletions deepsensor/model/nps.py
@@ -40,10 +40,21 @@ def convert_task_to_nps_args(task: Task):
yt = task["Y_t"][0]
elif len(task["X_t"]) > 1 and len(task["Y_t"]) > 1:
# Multiple target sets, different target locations
assert len(task["X_t"]) == len(task["Y_t"])
xts = []
yts = []
target_dims = [yt.shape[1] for yt in task["Y_t"]]
# Map from ND target sets to 1D target sets
dim_counter = 0
for i, (xt, yt) in enumerate(zip(task["X_t"], task["Y_t"])):
# Repeat target locations for each target dimension in target set
xts.extend([(xt, dim_counter + j) for j in range(target_dims[i])])
yts.extend([yt[:, j : j + 1] for j in range(target_dims[i])])
dim_counter += target_dims[i]
xt = backend.nps.AggregateInput(*xts)
yt = backend.nps.Aggregate(*yts)
elif len(task["X_t"]) == 1 and len(task["Y_t"]) > 1:
# Multiple target sets, same target locations; `Y_t`s along feature dim
xt = task["X_t"][0]
yt = B.concat(*task["Y_t"], axis=1)
else:
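To illustrate the N-D-to-1-D mapping above, here is a standalone numpy sketch with hypothetical shapes (dummy data, not taken from the repository's tests; it assumes ``Y_t`` arrays are batched as ``(batch, n_dims, n_targets)``, matching the ``yt.shape[1]`` indexing in the diff):

    import numpy as np

    # Hypothetical task: target set 0 has 2 feature dims at 30 locations,
    # target set 1 has 1 feature dim at 20 locations.
    X_t = [np.zeros((2, 30)), np.zeros((2, 20))]        # (coord_dim, n_targets)
    Y_t = [np.zeros((1, 2, 30)), np.zeros((1, 1, 20))]  # (batch, n_dims, n_targets)

    target_dims = [yt.shape[1] for yt in Y_t]  # [2, 1]
    xts, yts, dim_counter = [], [], 0
    for xt, yt in zip(X_t, Y_t):
        # one 1D aggregate element per target dimension in this set
        xts.extend([(xt, dim_counter + j) for j in range(yt.shape[1])])
        yts.extend([yt[:, j : j + 1] for j in range(yt.shape[1])])
        dim_counter += yt.shape[1]

    print([i for _, i in xts])     # [0, 1, 2] -> flat 1D output indices
    print([y.shape for y in yts])  # [(1, 1, 30), (1, 1, 30), (1, 1, 20)]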