Skip to content

Commit

Permalink
Unit tests for non-Gaussian likelihoods
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Dec 5, 2023
1 parent 9f6b921 commit ba55b55
Showing 1 changed file with 125 additions and 63 deletions.
188 changes: 125 additions & 63 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def _gen_task_loader_call_args(self, n_context, n_target):
]:
yield [sampling_method] * n_context, [sampling_method] * n_target

# TEMP only 1D because non-overlapping target sets are not yet supported
@parameterized.expand(range(1, 2))
@parameterized.expand(range(1, 4))
def test_model_call(self, n_context_and_target):
"""Check ``ConvNP`` runs with all possible combinations of context/target
sampling methods."""
Expand Down Expand Up @@ -120,68 +119,117 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):

context_sampling = 10

model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
likelihoods = ["cnp", "gnp", "cnp-spikes-beta"]

for target_sampling, expected_obs_shape in (
(10, (10,)), # expected shape is (10,) when target_sampling is 10
(
"all",
self.da.shape[-2:],
), # expected shape is da.shape[-2:] when target_sampling is "all"
):
task = tl("2020-01-01", context_sampling, target_sampling)

n_targets = np.product(expected_obs_shape)

# Tensors
mean = model.mean(task)
# TODO avoid repeated code
if isinstance(mean, (list, tuple)):
for m, dim_y in zip(mean, tl.target_dims):
assert_shape(m, (dim_y, *expected_obs_shape))
else:
assert_shape(mean, (n_target_sets, *expected_obs_shape))

variance = model.variance(task)
if isinstance(variance, (list, tuple)):
for v, dim_y in zip(variance, tl.target_dims):
assert_shape(v, (dim_y, *expected_obs_shape))
else:
assert_shape(variance, (n_target_sets, *expected_obs_shape))

stddev = model.stddev(task)
if isinstance(stddev, (list, tuple)):
for s, dim_y in zip(stddev, tl.target_dims):
assert_shape(s, (dim_y, *expected_obs_shape))
else:
assert_shape(stddev, (n_target_sets, *expected_obs_shape))

n_samples = 5
samples = model.sample(task, n_samples)
if isinstance(samples, (list, tuple)):
for s, dim_y in zip(samples, tl.target_dims):
assert_shape(s, (n_samples, dim_y, *expected_obs_shape))
else:
assert_shape(samples, (n_samples, n_target_sets, *expected_obs_shape))

n_target_dims = np.product(tl.target_dims)
assert_shape(
model.covariance(task),
(
n_targets * n_target_sets * n_target_dims,
n_targets * n_target_sets * n_target_dims,
),
for likelihood in likelihoods:
model = ConvNP(
self.dp,
tl,
unet_channels=(5, 5, 5),
likelihood=likelihood,
verbose=False,
)

# Scalars
x = model.logpdf(task)
assert x.size == 1 and x.shape == ()
x = model.joint_entropy(task)
assert x.size == 1 and x.shape == ()
x = model.mean_marginal_entropy(task)
assert x.size == 1 and x.shape == ()
x = B.to_numpy(model.loss_fn(task))
assert x.size == 1 and x.shape == ()
for target_sampling, expected_obs_shape in (
(10, (10,)), # expected shape is (10,) when target_sampling is 10
(
"all",
self.da.shape[-2:],
), # expected shape is da.shape[-2:] when target_sampling is "all"
):
task = tl("2020-01-01", context_sampling, target_sampling)

n_targets = np.product(expected_obs_shape)

# Tensors
mean = model.mean(task)
# TODO avoid repeated code
if isinstance(mean, (list, tuple)):
for m, dim_y in zip(mean, tl.target_dims):
assert_shape(m, (dim_y, *expected_obs_shape))
else:
assert_shape(mean, (n_target_sets, *expected_obs_shape))

variance = model.variance(task)
if isinstance(variance, (list, tuple)):
for v, dim_y in zip(variance, tl.target_dims):
assert_shape(v, (dim_y, *expected_obs_shape))
else:
assert_shape(variance, (n_target_sets, *expected_obs_shape))

stddev = model.stddev(task)
if isinstance(stddev, (list, tuple)):
for s, dim_y in zip(stddev, tl.target_dims):
assert_shape(s, (dim_y, *expected_obs_shape))
else:
assert_shape(stddev, (n_target_sets, *expected_obs_shape))

n_samples = 5
samples = model.sample(task, n_samples)
if isinstance(samples, (list, tuple)):
for s, dim_y in zip(samples, tl.target_dims):
assert_shape(s, (n_samples, dim_y, *expected_obs_shape))
else:
assert_shape(
samples, (n_samples, n_target_sets, *expected_obs_shape)
)

if likelihood in ["cnp", "gnp"]:
n_target_dims = np.product(tl.target_dims)
assert_shape(
model.covariance(task),
(
n_targets * n_target_sets * n_target_dims,
n_targets * n_target_sets * n_target_dims,
),
)
if likelihood in ["cnp-spikes-beta"]:
mixture_probs = model.mixture_probs(task)
if isinstance(mixture_probs, (list, tuple)):
for p, dim_y in zip(mixture_probs, tl.target_dims):
assert_shape(
p,
(
model.N_mixture_components,
dim_y,
*expected_obs_shape,
),
)
else:
assert_shape(
mixture_probs,
(
model.N_mixture_components,
n_target_sets,
*expected_obs_shape,
),
)

x = model.beta_dist_alpha(task)
if isinstance(x, (list, tuple)):
for p, dim_y in zip(x, tl.target_dims):
assert_shape(p, (dim_y, *expected_obs_shape))
else:
assert_shape(x, (n_target_sets, *expected_obs_shape))

x = model.beta_dist_beta(task)
if isinstance(x, (list, tuple)):
for p, dim_y in zip(x, tl.target_dims):
assert_shape(p, (dim_y, *expected_obs_shape))
else:
assert_shape(x, (n_target_sets, *expected_obs_shape))

# Scalars
if likelihood in ["cnp", "gnp"]:
# Methods for Gaussian likelihoods only
x = model.logpdf(task)
assert x.size == 1 and x.shape == ()
x = model.joint_entropy(task)
assert x.size == 1 and x.shape == ()
x = model.mean_marginal_entropy(task)
assert x.size == 1 and x.shape == ()
x = B.to_numpy(model.loss_fn(task))
assert x.size == 1 and x.shape == ()

@parameterized.expand(range(1, 4))
def test_prediction_shapes_highlevel(self, target_dim):
Expand Down Expand Up @@ -372,9 +420,7 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self):

dp = DataProcessor(
x1_name="latitude",
x1_map=lat_lims,
x2_name="longitude",
x2_map=lon_lims,
)
df = dp(df_raw)

Expand All @@ -392,6 +438,22 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self):
df_raw.reset_index()["longitude"],
)

def test_highlevel_predict_with_pred_params(self):
"""Test that passing ``pred_params`` to ``.predict`` works."""
tl = TaskLoader(context=self.da, target=self.da)
model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
task = tl("2020-01-01", context_sampling=10, target_sampling=10)

# Check that nothing breaks and the correct parameters are returned
pred_params = ["mean", "std", "variance"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for pred_param in pred_params:
assert pred_param in pred["air"]

# Check that passing an invalid parameter raises a ValueError
with self.assertRaises(ValueError):
model.predict(task, X_t=self.da, pred_params=["invalid_param"])

def test_saving_and_loading(self):
"""Test saving and loading of model"""
with tempfile.TemporaryDirectory() as folder:
Expand Down

0 comments on commit ba55b55

Please sign in to comment.