Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Dec 5, 2023
1 parent ba55b55 commit ae7dc86
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions deepsensor/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ def unnormalise_pred_array(arr, **kwargs):
try:
method = getattr(self, param)
prediction_methods[param] = method
except ValueError:
raise ValueError(
except AttributeError:
raise AttributeError(
f"Prediction method {param} not found in model class."
)
if n_samples >= 1:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,10 @@ def test_highlevel_predict_with_pred_params(self):
pred_params = ["mean", "std", "variance"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for pred_param in pred_params:
assert pred_param in pred["air"]
assert pred_param in pred["var"]

# Check that passing an invalid parameter raises a ValueError
with self.assertRaises(ValueError):
# Check that passing an invalid parameter raises an AttributeError
with self.assertRaises(AttributeError):
model.predict(task, X_t=self.da, pred_params=["invalid_param"])

def test_saving_and_loading(self):
Expand Down

0 comments on commit ae7dc86

Please sign in to comment.