From 421b56e2b05da923066779543d4d729d5ea8d67f Mon Sep 17 00:00:00 2001 From: mastoffel Date: Wed, 13 Nov 2024 13:06:18 +0000 Subject: [PATCH] include random_state to make cv equal across models and cv checks --- autoemulate/compare.py | 7 +++++-- autoemulate/utils.py | 16 ++++++++++++++++ tests/test_ui.py | 2 +- tests/test_utils.py | 14 ++++++++++++++ 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/autoemulate/compare.py b/autoemulate/compare.py index 986ec21e..3ecbc42a 100644 --- a/autoemulate/compare.py +++ b/autoemulate/compare.py @@ -22,6 +22,7 @@ from autoemulate.save import ModelSerialiser from autoemulate.sensitivity_analysis import plot_sensitivity_analysis from autoemulate.sensitivity_analysis import sensitivity_analysis +from autoemulate.utils import _check_cv from autoemulate.utils import _ensure_2d from autoemulate.utils import _get_full_model_name from autoemulate.utils import _redirect_warnings @@ -54,7 +55,9 @@ def setup( scaler=StandardScaler(), reduce_dim=False, dim_reducer=PCA(), - cross_validator=KFold(n_splits=5, shuffle=True), + cross_validator=KFold( + n_splits=5, shuffle=True, random_state=np.random.randint(1e5) + ), n_jobs=None, models=None, verbose=0, @@ -121,7 +124,7 @@ def setup( dim_reducer=dim_reducer, ) self.metrics = self._get_metrics(METRIC_REGISTRY) - self.cross_validator = cross_validator + self.cross_validator = _check_cv(cross_validator) self.param_search = param_search self.search_type = param_search_type self.param_search_iters = param_search_iters diff --git a/autoemulate/utils.py b/autoemulate/utils.py index 3d1dfe78..09eec082 100644 --- a/autoemulate/utils.py +++ b/autoemulate/utils.py @@ -9,6 +9,7 @@ import torch from sklearn.base import RegressorMixin from sklearn.exceptions import ConvergenceWarning +from sklearn.model_selection import KFold from sklearn.multioutput import MultiOutputRegressor from sklearn.pipeline import Pipeline @@ -370,3 +371,18 @@ def _ensure_2d(arr): if arr.ndim == 1: arr = arr.reshape(-1, 1) return arr + + +# checkers for scikit-learn objects -------------------------------------------- + + +def _check_cv(cv): + """Ensure that cross-validation method is valid""" + if cv is None: + raise ValueError("cross_validator cannot be None") + if not isinstance(cv, KFold): + raise ValueError( + "cross_validator should be an instance of KFold cross-validation. We do not " + "currently support other cross-validation methods." + ) + return cv diff --git a/tests/test_ui.py b/tests/test_ui.py index 30223456..a022c7f8 100644 --- a/tests/test_ui.py +++ b/tests/test_ui.py @@ -44,7 +44,7 @@ def test_cross_validators(): X = np.random.rand(100, 5) y = np.random.rand(100, 1) - cross_validators = [KFold(n_splits=5), TimeSeriesSplit(n_splits=5)] + cross_validators = [KFold(n_splits=5)] for cross_validator in cross_validators: ae = AutoEmulate() diff --git a/tests/test_utils.py b/tests/test_utils.py index 2d7b9da6..6fc5fe73 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,8 @@ import pytest from sklearn.ensemble import GradientBoostingRegressor from sklearn.ensemble import RandomForestRegressor +from sklearn.model_selection import KFold +from sklearn.model_selection import LeaveOneOut from sklearn.multioutput import MultiOutputRegressor from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler @@ -12,6 +14,7 @@ from autoemulate.utils import _add_prefix_to_param_space from autoemulate.utils import _add_prefix_to_single_grid from autoemulate.utils import _adjust_param_space +from autoemulate.utils import _check_cv from autoemulate.utils import _denormalise_y from autoemulate.utils import _ensure_2d from autoemulate.utils import _get_full_model_name @@ -340,3 +343,14 @@ def test_ensure_2d_2d(): y = np.array([[1, 2], [3, 4], [5, 6]]) y_2d = _ensure_2d(y) assert y_2d.ndim == 2 + + +# test checkers for scikit-learn objects -------------------------------------- +def test_check_cv(): + cv = KFold(n_splits=5, shuffle=True, random_state=np.random.randint(1e5)) + _check_cv(cv) + + +def test_check_cv_error(): + with pytest.raises(ValueError): + _check_cv(LeaveOneOut())