diff --git a/tests/models/test_gptorch.py b/tests/models/test_gptorch.py index 066b13c7..c435d881 100644 --- a/tests/models/test_gptorch.py +++ b/tests/models/test_gptorch.py @@ -13,13 +13,13 @@ @pytest.fixture def sample_data_y1d(): - X, y = make_regression(n_samples=10, n_features=5, n_targets=1, random_state=0) + X, y = make_regression(n_samples=20, n_features=5, n_targets=1, random_state=0) return X, y @pytest.fixture def sample_data_y2d(): - X, y = make_regression(n_samples=10, n_features=5, n_targets=2, random_state=0) + X, y = make_regression(n_samples=20, n_features=5, n_targets=2, random_state=0) return X, y @@ -28,7 +28,7 @@ def test_multi_output_gpmt(sample_data_y2d): X, y = sample_data_y2d gp = GaussianProcessMT(random_state=42) gp.fit(X, y) - assert gp.predict(X).shape == (10, 2) + assert gp.predict(X).shape == (20, 2) def test_predict_with_uncertainty_gpmt(sample_data_y1d): @@ -62,7 +62,7 @@ def test_multioutput_gp(sample_data_y2d): X, y = sample_data_y2d gp = GaussianProcess(random_state=42) gp.fit(X, y) - assert gp.predict(X).shape == (10, 2) + assert gp.predict(X).shape == (20, 2) def test_predict_with_uncertainty_gp(sample_data_y1d):