Skip to content

Commit

Permalink
include random_state to make cv equal across models and cv checks
Browse files Browse the repository at this point in the history
  • Loading branch information
mastoffel committed Nov 13, 2024
1 parent 3daa3ed commit 421b56e
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 3 deletions.
7 changes: 5 additions & 2 deletions autoemulate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from autoemulate.save import ModelSerialiser
from autoemulate.sensitivity_analysis import plot_sensitivity_analysis
from autoemulate.sensitivity_analysis import sensitivity_analysis
from autoemulate.utils import _check_cv
from autoemulate.utils import _ensure_2d
from autoemulate.utils import _get_full_model_name
from autoemulate.utils import _redirect_warnings
Expand Down Expand Up @@ -54,7 +55,9 @@ def setup(
scaler=StandardScaler(),
reduce_dim=False,
dim_reducer=PCA(),
cross_validator=KFold(n_splits=5, shuffle=True),
cross_validator=KFold(
n_splits=5, shuffle=True, random_state=np.random.randint(1e5)
),
n_jobs=None,
models=None,
verbose=0,
Expand Down Expand Up @@ -121,7 +124,7 @@ def setup(
dim_reducer=dim_reducer,
)
self.metrics = self._get_metrics(METRIC_REGISTRY)
self.cross_validator = cross_validator
self.cross_validator = _check_cv(cross_validator)
self.param_search = param_search
self.search_type = param_search_type
self.param_search_iters = param_search_iters
Expand Down
16 changes: 16 additions & 0 deletions autoemulate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from sklearn.base import RegressorMixin
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import KFold
from sklearn.multioutput import MultiOutputRegressor
from sklearn.pipeline import Pipeline

Expand Down Expand Up @@ -370,3 +371,18 @@ def _ensure_2d(arr):
if arr.ndim == 1:
arr = arr.reshape(-1, 1)
return arr


# checkers for scikit-learn objects --------------------------------------------


def _check_cv(cv):
"""Ensure that cross-validation method is valid"""
if cv is None:
raise ValueError("cross_validator cannot be None")
if not isinstance(cv, KFold):
raise ValueError(
"cross_validator should be an instance of KFold cross-validation. We do not "
"currently support other cross-validation methods."
)
return cv
2 changes: 1 addition & 1 deletion tests/test_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_cross_validators():
X = np.random.rand(100, 5)
y = np.random.rand(100, 1)

cross_validators = [KFold(n_splits=5), TimeSeriesSplit(n_splits=5)]
cross_validators = [KFold(n_splits=5)]

for cross_validator in cross_validators:
ae = AutoEmulate()
Expand Down
14 changes: 14 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import KFold
from sklearn.model_selection import LeaveOneOut
from sklearn.multioutput import MultiOutputRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
Expand All @@ -12,6 +14,7 @@
from autoemulate.utils import _add_prefix_to_param_space
from autoemulate.utils import _add_prefix_to_single_grid
from autoemulate.utils import _adjust_param_space
from autoemulate.utils import _check_cv
from autoemulate.utils import _denormalise_y
from autoemulate.utils import _ensure_2d
from autoemulate.utils import _get_full_model_name
Expand Down Expand Up @@ -340,3 +343,14 @@ def test_ensure_2d_2d():
y = np.array([[1, 2], [3, 4], [5, 6]])
y_2d = _ensure_2d(y)
assert y_2d.ndim == 2


# test checkers for scikit-learn objects --------------------------------------
def test_check_cv():
cv = KFold(n_splits=5, shuffle=True, random_state=np.random.randint(1e5))
_check_cv(cv)


def test_check_cv_error():
with pytest.raises(ValueError):
_check_cv(LeaveOneOut())

0 comments on commit 421b56e

Please sign in to comment.