diff --git a/tests/test_model.py b/tests/test_model.py index e6465276..38a626a2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -69,8 +69,7 @@ def _gen_task_loader_call_args(self, n_context, n_target): ]: yield [sampling_method] * n_context, [sampling_method] * n_target - # TEMP only 1D because non-overlapping target sets are not yet supported - @parameterized.expand(range(1, 2)) + @parameterized.expand(range(1, 4)) def test_model_call(self, n_context_and_target): """Check ``ConvNP`` runs with all possible combinations of context/target sampling methods.""" @@ -120,68 +119,117 @@ def test_prediction_shapes_lowlevel(self, n_target_sets): context_sampling = 10 - model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False) + likelihoods = ["cnp", "gnp", "cnp-spikes-beta"] - for target_sampling, expected_obs_shape in ( - (10, (10,)), # expected shape is (10,) when target_sampling is 10 - ( - "all", - self.da.shape[-2:], - ), # expected shape is da.shape[-2:] when target_sampling is "all" - ): - task = tl("2020-01-01", context_sampling, target_sampling) - - n_targets = np.product(expected_obs_shape) - - # Tensors - mean = model.mean(task) - # TODO avoid repeated code - if isinstance(mean, (list, tuple)): - for m, dim_y in zip(mean, tl.target_dims): - assert_shape(m, (dim_y, *expected_obs_shape)) - else: - assert_shape(mean, (n_target_sets, *expected_obs_shape)) - - variance = model.variance(task) - if isinstance(variance, (list, tuple)): - for v, dim_y in zip(variance, tl.target_dims): - assert_shape(v, (dim_y, *expected_obs_shape)) - else: - assert_shape(variance, (n_target_sets, *expected_obs_shape)) - - stddev = model.stddev(task) - if isinstance(stddev, (list, tuple)): - for s, dim_y in zip(stddev, tl.target_dims): - assert_shape(s, (dim_y, *expected_obs_shape)) - else: - assert_shape(stddev, (n_target_sets, *expected_obs_shape)) - - n_samples = 5 - samples = model.sample(task, n_samples) - if isinstance(samples, (list, tuple)): - for s, dim_y in zip(samples, tl.target_dims): - assert_shape(s, (n_samples, dim_y, *expected_obs_shape)) - else: - assert_shape(samples, (n_samples, n_target_sets, *expected_obs_shape)) - - n_target_dims = np.product(tl.target_dims) - assert_shape( - model.covariance(task), - ( - n_targets * n_target_sets * n_target_dims, - n_targets * n_target_sets * n_target_dims, - ), + for likelihood in likelihoods: + model = ConvNP( + self.dp, + tl, + unet_channels=(5, 5, 5), + likelihood=likelihood, + verbose=False, ) - # Scalars - x = model.logpdf(task) - assert x.size == 1 and x.shape == () - x = model.joint_entropy(task) - assert x.size == 1 and x.shape == () - x = model.mean_marginal_entropy(task) - assert x.size == 1 and x.shape == () - x = B.to_numpy(model.loss_fn(task)) - assert x.size == 1 and x.shape == () + for target_sampling, expected_obs_shape in ( + (10, (10,)), # expected shape is (10,) when target_sampling is 10 + ( + "all", + self.da.shape[-2:], + ), # expected shape is da.shape[-2:] when target_sampling is "all" + ): + task = tl("2020-01-01", context_sampling, target_sampling) + + n_targets = np.product(expected_obs_shape) + + # Tensors + mean = model.mean(task) + # TODO avoid repeated code + if isinstance(mean, (list, tuple)): + for m, dim_y in zip(mean, tl.target_dims): + assert_shape(m, (dim_y, *expected_obs_shape)) + else: + assert_shape(mean, (n_target_sets, *expected_obs_shape)) + + variance = model.variance(task) + if isinstance(variance, (list, tuple)): + for v, dim_y in zip(variance, tl.target_dims): + assert_shape(v, (dim_y, *expected_obs_shape)) + else: + assert_shape(variance, (n_target_sets, *expected_obs_shape)) + + stddev = model.stddev(task) + if isinstance(stddev, (list, tuple)): + for s, dim_y in zip(stddev, tl.target_dims): + assert_shape(s, (dim_y, *expected_obs_shape)) + else: + assert_shape(stddev, (n_target_sets, *expected_obs_shape)) + + n_samples = 5 + samples = model.sample(task, n_samples) + if isinstance(samples, (list, tuple)): + for s, dim_y in zip(samples, tl.target_dims): + assert_shape(s, (n_samples, dim_y, *expected_obs_shape)) + else: + assert_shape( + samples, (n_samples, n_target_sets, *expected_obs_shape) + ) + + if likelihood in ["cnp", "gnp"]: + n_target_dims = np.product(tl.target_dims) + assert_shape( + model.covariance(task), + ( + n_targets * n_target_sets * n_target_dims, + n_targets * n_target_sets * n_target_dims, + ), + ) + if likelihood in ["cnp-spikes-beta"]: + mixture_probs = model.mixture_probs(task) + if isinstance(mixture_probs, (list, tuple)): + for p, dim_y in zip(mixture_probs, tl.target_dims): + assert_shape( + p, + ( + model.N_mixture_components, + dim_y, + *expected_obs_shape, + ), + ) + else: + assert_shape( + mixture_probs, + ( + model.N_mixture_components, + n_target_sets, + *expected_obs_shape, + ), + ) + + x = model.beta_dist_alpha(task) + if isinstance(x, (list, tuple)): + for p, dim_y in zip(x, tl.target_dims): + assert_shape(p, (dim_y, *expected_obs_shape)) + else: + assert_shape(x, (n_target_sets, *expected_obs_shape)) + + x = model.beta_dist_beta(task) + if isinstance(x, (list, tuple)): + for p, dim_y in zip(x, tl.target_dims): + assert_shape(p, (dim_y, *expected_obs_shape)) + else: + assert_shape(x, (n_target_sets, *expected_obs_shape)) + + # Scalars + if likelihood in ["cnp", "gnp"]: + # Methods for Gaussian likelihoods only + x = model.logpdf(task) + assert x.size == 1 and x.shape == () + x = model.joint_entropy(task) + assert x.size == 1 and x.shape == () + x = model.mean_marginal_entropy(task) + assert x.size == 1 and x.shape == () + x = B.to_numpy(model.loss_fn(task)) + assert x.size == 1 and x.shape == () @parameterized.expand(range(1, 4)) def test_prediction_shapes_highlevel(self, target_dim): @@ -372,9 +420,7 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self): dp = DataProcessor( x1_name="latitude", - x1_map=lat_lims, x2_name="longitude", - x2_map=lon_lims, ) df = dp(df_raw) @@ -392,6 +438,22 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self): df_raw.reset_index()["longitude"], ) + def test_highlevel_predict_with_pred_params(self): + """Test that passing ``pred_params`` to ``.predict`` works.""" + tl = TaskLoader(context=self.da, target=self.da) + model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False) + task = tl("2020-01-01", context_sampling=10, target_sampling=10) + + # Check that nothing breaks and the correct parameters are returned + pred_params = ["mean", "std", "variance"] + pred = model.predict(task, X_t=self.da, pred_params=pred_params) + for pred_param in pred_params: + assert pred_param in pred["air"] + + # Check that passing an invalid parameter raises a ValueError + with self.assertRaises(ValueError): + model.predict(task, X_t=self.da, pred_params=["invalid_param"]) + def test_saving_and_loading(self): """Test saving and loading of model""" with tempfile.TemporaryDirectory() as folder: