Support non-Gaussian ConvNP likelihoods in low-level and high-level prediction interfaces (#97)

This new `deepsensor` feature enables non-Gaussian likelihoods in high-level xarray/pandas predictions:

* Rename `model.stddev` to `model.std` (keeping `.stddev` as an alias for backwards compatibility)

* Remove the `noiseless` kwarg from `ConvNP.sample`; samples are now noiseless only for the Gaussian (`"gnp"`/`"lowrank"`) likelihoods

* Add `N_mixture_components` model class attribute

* Add `"spikes-beta"` likelihood methods in `ConvNP`

* Anticipate `neuralprocesses` Bernoulli-Gamma likelihood

* Add `pred_params` kwarg to `model.predict`, dictating which model methods are called (see the sketch after this list)

* Unit tests for non-Gaussian likelihoods

* Fix tests

* Fix `model.logpdf` for `"spikes-beta"` likelihood

* Rename `model.beta_dist_alpha` to `model.alpha` for generality (and same for `.beta_dist_beta`)

* Fix missing bracket in API

* Fix return type

* Test non-Gaussian pred_params in high-level predict

* Add off-grid non-Gaussian unit test
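
To make the new interface concrete, here is a minimal editorial sketch (not part of the commit) of a non-Gaussian prediction. It assumes a fitted `DataProcessor` (`data_processor`), a `TaskLoader` (`task_loader`), and a target xarray object `target_ds`; the `pred_params` spellings mirror the methods added in this diff (`mean`, `std`, `alpha`, `beta`, `mixture_probs`), but their exact names and container type are assumptions.

```python
# Editorial sketch: high-level non-Gaussian prediction with deepsensor.
# `data_processor`, `task_loader`, and `target_ds` are assumed to exist.
from deepsensor.model import ConvNP

# "spikes-beta" is the mixture likelihood added in this commit (a Beta
# slab plus point masses), suited to variables bounded in [0, 1].
model = ConvNP(data_processor, task_loader, likelihood="spikes-beta")

task = task_loader("2023-01-01", context_sampling="all")

# `pred_params` dictates which model methods are called when building
# the xarray/pandas prediction object.
pred = model.predict(
    task,
    X_t=target_ds,
    pred_params=["mean", "std", "alpha", "beta", "mixture_probs"],
)
```

Requesting `alpha`, `beta`, or `mixture_probs` under a Gaussian likelihood raises `NotImplementedError`, per the checks in the diff below.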
tom-andersson authored Dec 10, 2023
1 parent 16b8c04 commit 642f5ab
Showing 7 changed files with 488 additions and 149 deletions.
194 changes: 164 additions & 30 deletions deepsensor/model/convnp.py
@@ -222,6 +222,7 @@ def __init__(
kwargs["decoder_scale"] = decoder_scale

self.model, self.config = construct_neural_process(*args, **kwargs)
self._set_num_mixture_components()

@dispatch
def __init__(
@@ -254,6 +255,7 @@ def __init__(self, model_ID: str):
super().__init__()

self.load(model_ID)
self._set_num_mixture_components()

@dispatch
def __init__(
@@ -277,6 +279,18 @@ def __init__(
super().__init__(data_processor, task_loader)

self.load(model_ID)
self._set_num_mixture_components()

def _set_num_mixture_components(self):
"""
Set the number of mixture components for the model based on the likelihood.
"""
if self.config["likelihood"] in ["spikes-beta"]:
self.N_mixture_components = 3
elif self.config["likelihood"] in ["bernoulli-gamma"]:
self.N_mixture_components = 2
else:
self.N_mixture_components = 1
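# Editorial note, not in the diff: the component counts above reflect the
# assumed structure of each likelihood -- "spikes-beta" mixes a Beta slab
# with point masses at 0 and 1 (three components), "bernoulli-gamma" mixes
# a point mass at 0 with a Gamma slab (two), and the Gaussian likelihoods
# are single-component.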

def save(self, model_ID: str):
"""
@@ -450,7 +464,7 @@ def variance(self, task: Task):
return self.variance(dist)

@dispatch
def stddev(self, dist: AbstractMultiOutputDistribution):
def std(self, dist: AbstractMultiOutputDistribution):
"""
...
@@ -468,7 +482,7 @@ def stddev(self, dist: AbstractMultiOutputDistribution):
return np.sqrt(variance)

@dispatch
def stddev(self, task: Task):
def std(self, task: Task):
"""
...
@@ -480,7 +494,119 @@ def stddev(self, task: Task):
...: ...
"""
dist = self(task)
return self.stddev(dist)
return self.std(dist)

@dispatch
def alpha(
self, dist: AbstractMultiOutputDistribution
) -> Union[np.ndarray, List[np.ndarray]]:
if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
raise NotImplementedError(
f"ConvNP.alpha method not supported for likelihood {self.config['likelihood']}. "
f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
)
alpha = dist.slab.alpha
if isinstance(alpha, backend.nps.Aggregate):
return [B.to_numpy(m)[0, 0] for m in alpha]
else:
return B.to_numpy(alpha)[0, 0]

@dispatch
def alpha(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
"""
Alpha values of model's distribution at target locations in task.
.. note::
This method only works for models that return a distribution with
a ``dist.slab.alpha`` attribute, e.g. models with a Beta or
Bernoulli-Gamma likelihood.
Args:
task (:class:`~.data.task.Task`):
The task containing the context and target data.
Returns:
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
Alpha values.
"""
dist = self(task)
return self.alpha(dist)
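# Editorial note, not in the diff: with a single target set, this is
# assumed to return a 1-D array of Beta-slab alpha values over targets;
# with multiple target sets (an `nps.Aggregate`), a list of such arrays.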

@dispatch
def beta(self, dist: AbstractMultiOutputDistribution):
if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
raise NotImplementedError(
f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. "
f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
)
beta = dist.slab.beta
if isinstance(beta, backend.nps.Aggregate):
return [B.to_numpy(m)[0, 0] for m in beta]
else:
return B.to_numpy(beta)[0, 0]

@dispatch
def beta(self, task: Task):
"""
Beta values of model's distribution at target locations in task.
.. note::
This method only works for models that return a distribution with
a ``dist.slab.beta`` attribute, e.g. models with a Beta or
Bernoulli-Gamma likelihood.
Args:
task (:class:`~.data.task.Task`):
The task containing the context and target data.
Returns:
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
Beta values.
"""
dist = self(task)
return self.beta(dist)

@dispatch
def mixture_probs(self, dist: AbstractMultiOutputDistribution):
"""
Probabilities of the components of a mixture distribution.
Shape (N_components, N_features, N_targets)
Args:
dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
...
Returns:
...: ...
"""
if self.N_mixture_components == 1:
raise NotImplementedError(
f"mixture_probs not supported if model attribute N_mixture_components == 1. "
f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta'."
)
mixture_probs = dist.logprobs
if isinstance(mixture_probs, backend.nps.Aggregate):
return [
np.moveaxis(np.exp(B.to_numpy(m)[0, 0]), -1, 0) for m in mixture_probs
]
else:
return np.moveaxis(np.exp(B.to_numpy(mixture_probs)[0, 0]), -1, 0)

@dispatch
def mixture_probs(self, task: Task):
"""
...
Args:
task (:class:`~.data.task.Task`):
...
Returns:
...: ...
"""
dist = self(task)
return self.mixture_probs(dist)

@dispatch
def covariance(self, dist: AbstractMultiOutputDistribution):
@@ -516,7 +642,6 @@ def sample(
self,
dist: AbstractMultiOutputDistribution,
n_samples: int = 1,
noiseless: bool = True,
):
"""
Create samples from a ConvNP distribution.
@@ -527,15 +652,12 @@
n_samples (int, optional):
The number of samples to draw from the distribution, by
default 1.
noiseless (bool, optional):
Whether to sample from the noiseless distribution, by default
True.
Returns:
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
The samples as an array or list of arrays.
"""
if noiseless:
if self.config["likelihood"] in ["gnp", "lowrank"]:
samples = dist.noiseless.sample(n_samples)
else:
samples = dist.sample(n_samples)
@@ -546,7 +668,7 @@
return B.to_numpy(samples)[:, 0, 0]

@dispatch
def sample(self, task: Task, n_samples: int = 1, noiseless: bool = True):
def sample(self, task: Task, n_samples: int = 1):
"""
Create samples from a ConvNP distribution.
@@ -556,16 +678,13 @@ def sample(self, task: Task, n_samples: int = 1, noiseless: bool = True):
n_samples (int, optional):
The number of samples to draw from the distribution, by
default 1.
noiseless (bool, optional):
Whether to sample from the noiseless distribution, by default
True.
Returns:
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
The samples as an array or list of arrays.
"""
dist = self(task)
return self.sample(dist, n_samples, noiseless)
return self.sample(dist, n_samples)

@dispatch
def slice_diag(self, task: Task):
@@ -580,12 +699,15 @@ def slice_diag(self, task: Task):
...: ...
"""
dist = self(task)
dist_diag = backend.nps.MultiOutputNormal(
dist._mean,
B.zeros(dist._var),
Diagonal(B.diag(dist._noise + dist._var)),
dist.shape,
)
if self.config["likelihood"] in ["spikes-beta"]:
dist_diag = dist
else:
dist_diag = backend.nps.MultiOutputNormal(
dist._mean,
B.zeros(dist._var),
Diagonal(B.diag(dist._noise + dist._var)),
dist.shape,
)
return dist_diag

@dispatch
@@ -600,12 +722,15 @@ def slice_diag(self, dist: AbstractMultiOutputDistribution):
Returns:
...: ...
"""
dist_diag = backend.nps.MultiOutputNormal(
dist._mean,
B.zeros(dist._var),
Diagonal(B.diag(dist._noise + dist._var)),
dist.shape,
)
if self.config["likelihood"] in ["spikes-beta"]:
dist_diag = dist
else:
dist_diag = backend.nps.MultiOutputNormal(
dist._mean,
B.zeros(dist._var),
Diagonal(B.diag(dist._noise + dist._var)),
dist.shape,
)
return dist_diag

@dispatch
@@ -669,8 +794,11 @@ def joint_entropy(self, task: Task):
@dispatch
def logpdf(self, dist: AbstractMultiOutputDistribution, task: Task):
"""
Model outputs joint distribution over all targets: Concat targets along
observation dimension.
Joint logpdf over all target sets.
.. note::
If the model has multiple target sets, the returned logpdf is the
mean logpdf over all target sets.
Args:
dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
@@ -681,14 +809,20 @@ def logpdf(self, dist: AbstractMultiOutputDistribution, task: Task):
Returns:
float: The logpdf.
"""
Y_t = B.concat(*task["Y_t"], axis=-1)
# Need to ensure `Y_t` is a tensor and, if multiple target sets,
# an nps.Aggregate object
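# Editorial note, not in the diff: concatenating targets along the
# observation dimension (the removed line above) is what broke `logpdf`
# for "spikes-beta"; `nps.Aggregate` keeps each target set paired with
# its own output distribution.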
task = ConvNP.modify_task(task)
_, _, Y_t, _ = convert_task_to_nps_args(task)
return B.to_numpy(dist.logpdf(Y_t)).mean()

@dispatch
def logpdf(self, task: Task):
"""
Model outputs joint distribution over all targets: Concat targets along
observation dimension.
Joint logpdf over all target sets.
.. note::
If the model has multiple target sets, the returned logpdf is the
mean logpdf over all target sets.
Args:
task (:class:`~.data.task.Task`):
