Skip to content

Commit

Permalink
fix gp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mastoffel committed Oct 11, 2024
1 parent 16d80f1 commit 6a4ca01
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/models/test_gptorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@

@pytest.fixture
def sample_data_y1d():
X, y = make_regression(n_samples=10, n_features=5, n_targets=1, random_state=0)
X, y = make_regression(n_samples=20, n_features=5, n_targets=1, random_state=0)
return X, y


@pytest.fixture
def sample_data_y2d():
X, y = make_regression(n_samples=10, n_features=5, n_targets=2, random_state=0)
X, y = make_regression(n_samples=20, n_features=5, n_targets=2, random_state=0)
return X, y


Expand All @@ -28,7 +28,7 @@ def test_multi_output_gpmt(sample_data_y2d):
X, y = sample_data_y2d
gp = GaussianProcessMT(random_state=42)
gp.fit(X, y)
assert gp.predict(X).shape == (10, 2)
assert gp.predict(X).shape == (20, 2)


def test_predict_with_uncertainty_gpmt(sample_data_y1d):
Expand Down Expand Up @@ -62,7 +62,7 @@ def test_multioutput_gp(sample_data_y2d):
X, y = sample_data_y2d
gp = GaussianProcess(random_state=42)
gp.fit(X, y)
assert gp.predict(X).shape == (10, 2)
assert gp.predict(X).shape == (20, 2)


def test_predict_with_uncertainty_gp(sample_data_y1d):
Expand Down

0 comments on commit 6a4ca01

Please sign in to comment.