Skip to content

Commit

Permalink
Sklearn 1.6 support (#2221)
Browse files Browse the repository at this point in the history
* Apply scipy array API support

* Deselect tests for unsupported skl1.6 features

* Add sklearn 1.6 to CI matrix

* Fix pairwise_distances dispatching

* Fix forbidden usage of sklearn_check_version

* Fix for pairwise_distances params validation

* Fix for pairwise_distances params validation

* Fix for pairwise_distances params validation

* Add SCIPY_ARRAY_API to test_estimators

* Update input validation in AdaBoost and GBT d4p estimators

* Pin sklearn 1.5 for py3.9

* Fix knn bf regr spmd example

* Update python-sklearn CI matrix

* Apply comments for AdaBoost and GBT estimators

* Add sklearn 1.6 to README badge

* Linting

* Update metric in knn bf regr spmd example

* Update CI matrix
  • Loading branch information
Alexsandruss authored Dec 12, 2024
1 parent ab143a2 commit 624f1cf
Show file tree
Hide file tree
Showing 14 changed files with 222 additions and 57 deletions.
38 changes: 16 additions & 22 deletions .ci/pipeline/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,18 @@ jobs:
Python3.9_Sklearn1.0:
PYTHON_VERSION: '3.9'
SKLEARN_VERSION: '1.0'
Python3.9_Sklearn1.1:
PYTHON_VERSION: '3.9'
SKLEARN_VERSION: '1.1'
Python3.10_Sklearn1.2:
Python3.10_Sklearn1.3:
PYTHON_VERSION: '3.10'
SKLEARN_VERSION: '1.2'
Python3.11_Sklearn1.3:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.3'
Python3.12_Sklearn1.4:
PYTHON_VERSION: '3.12'
Python3.11_Sklearn1.4:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.4'
Python3.13_Sklearn1.5:
PYTHON_VERSION: '3.13'
Python3.12_Sklearn1.5:
PYTHON_VERSION: '3.12'
SKLEARN_VERSION: '1.5'
Python3.13_Sklearn1.6:
PYTHON_VERSION: '3.13'
SKLEARN_VERSION: '1.6'
pool:
vmImage: 'ubuntu-22.04'
steps:
Expand All @@ -146,21 +143,18 @@ jobs:
Python3.9_Sklearn1.0:
PYTHON_VERSION: '3.9'
SKLEARN_VERSION: '1.0'
Python3.9_Sklearn1.1:
PYTHON_VERSION: '3.9'
SKLEARN_VERSION: '1.1'
Python3.10_Sklearn1.2:
Python3.10_Sklearn1.3:
PYTHON_VERSION: '3.10'
SKLEARN_VERSION: '1.2'
Python3.11_Sklearn1.3:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.3'
Python3.12_Sklearn1.4:
PYTHON_VERSION: '3.12'
Python3.11_Sklearn1.4:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.4'
Python3.13_Sklearn1.5:
PYTHON_VERSION: '3.13'
Python3.12_Sklearn1.5:
PYTHON_VERSION: '3.12'
SKLEARN_VERSION: '1.5'
Python3.13_Sklearn1.6:
PYTHON_VERSION: '3.13'
SKLEARN_VERSION: '1.6'
pool:
vmImage: 'windows-2022'
steps:
Expand Down
5 changes: 5 additions & 0 deletions .ci/scripts/run_sklearn_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import pytest
import sklearn

from daal4py.sklearn._utils import sklearn_check_version

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -43,6 +45,9 @@
if os.environ["SELECTED_TESTS"] == "all":
os.environ["SELECTED_TESTS"] = ""

if sklearn_check_version("1.6"):
os.environ["SCIPY_ARRAY_API"] = "1"

pytest_args = (
"--verbose --durations=100 --durations-min=0.01 "
f"--rootdir={sklearn_file_dir} "
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
[![PyPI Version](https://img.shields.io/pypi/v/scikit-learn-intelex)](https://pypi.org/project/scikit-learn-intelex/)
[![Conda Version](https://img.shields.io/conda/vn/conda-forge/scikit-learn-intelex)](https://anaconda.org/conda-forge/scikit-learn-intelex)
[![python version](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)
[![scikit-learn supported versions](https://img.shields.io/badge/sklearn-1.0%20%7C%201.2%20%7C%201.3%20%7C%201.4%20%7C%201.5-blue)](https://img.shields.io/badge/sklearn-1.0%20%7C%201.2%20%7C%201.3%20%7C%201.4%20%7C%201.5-blue)
[![scikit-learn supported versions](https://img.shields.io/badge/sklearn-1.0%20%7C%201.2%20%7C%201.3%20%7C%201.4%20%7C%201.5%20%7C%201.6-blue)](https://img.shields.io/badge/sklearn-1.0%20%7C%201.2%20%7C%201.3%20%7C%201.4%20%7C%201.5%20%7C%201.6-blue)

---
</h3>
Expand Down
14 changes: 9 additions & 5 deletions daal4py/sklearn/ensemble/AdaBoostClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,19 @@
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

import daal4py as d4p
from daal4py.sklearn._utils import sklearn_check_version

from .._n_jobs_support import control_n_jobs
from .._utils import getFPType

if sklearn_check_version("1.6"):
from sklearn.utils.validation import validate_data
else:
validate_data = BaseEstimator._validate_data


@control_n_jobs(decorated_methods=["fit", "predict"])
class AdaBoostClassifier(BaseEstimator, ClassifierMixin):
class AdaBoostClassifier(ClassifierMixin, BaseEstimator):
def __init__(
self,
split_criterion="gini",
Expand Down Expand Up @@ -89,7 +95,7 @@ def fit(self, X, y):
)

# Check that X and y have correct shape
X, y = check_X_y(X, y, y_numeric=False, dtype=[np.single, np.double])
X, y = check_X_y(X, y, y_numeric=False, dtype=[np.float64, np.float32])

check_classification_targets(y)

Expand Down Expand Up @@ -151,9 +157,7 @@ def predict(self, X):
check_is_fitted(self)

# Input validation
X = check_array(X, dtype=[np.single, np.double])
if X.shape[1] != self.n_features_in_:
raise ValueError("Shape of input is different from what was seen in `fit`")
X = validate_data(self, X, dtype=[np.float64, np.float32], reset=False)

# Trivial case
if self.n_classes_ == 1:
Expand Down
51 changes: 35 additions & 16 deletions daal4py/sklearn/ensemble/GBTDAAL.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

import daal4py as d4p
from daal4py.sklearn._utils import sklearn_check_version

from .._n_jobs_support import control_n_jobs
from .._utils import getFPType

if sklearn_check_version("1.6"):
from sklearn.utils.validation import validate_data
else:
validate_data = BaseEstimator._validate_data


class GBTDAALBase(BaseEstimator, d4p.mb.GBTDAALBaseModel):
def __init__(
Expand Down Expand Up @@ -128,15 +134,22 @@ def _check_params(self):
def _more_tags(self):
return {"allow_nan": self.allow_nan_}

if sklearn_check_version("1.6"):

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = self.allow_nan_
return tags


@control_n_jobs(decorated_methods=["fit", "predict"])
class GBTDAALClassifier(GBTDAALBase, ClassifierMixin):
class GBTDAALClassifier(ClassifierMixin, GBTDAALBase):
def fit(self, X, y):
# Check the algorithm parameters
self._check_params()

# Check that X and y have correct shape
X, y = check_X_y(X, y, y_numeric=False, dtype=[np.single, np.double])
X, y = check_X_y(X, y, y_numeric=False, dtype=[np.float64, np.float32])

check_classification_targets(y)

Expand Down Expand Up @@ -196,15 +209,18 @@ def fit(self, X, y):
def _predict(
self, X, resultsToEvaluate, pred_contribs=False, pred_interactions=False
):
# Input validation
if not self.allow_nan_:
X = check_array(X, dtype=[np.single, np.double])
else:
X = check_array(X, dtype=[np.single, np.double], force_all_finite="allow-nan")

# Check is fit had been called
check_is_fitted(self, ["n_features_in_", "n_classes_"])

# Input validation
X = validate_data(
self,
X,
dtype=[np.float64, np.float32],
force_all_finite="allow-nan" if self.allow_nan_ else True,
reset=False,
)

# Trivial case
if self.n_classes_ == 1:
return np.full(X.shape[0], self.classes_[0])
Expand Down Expand Up @@ -251,13 +267,13 @@ def convert_model(model):


@control_n_jobs(decorated_methods=["fit", "predict"])
class GBTDAALRegressor(GBTDAALBase, RegressorMixin):
class GBTDAALRegressor(RegressorMixin, GBTDAALBase):
def fit(self, X, y):
# Check the algorithm parameters
self._check_params()

# Check that X and y have correct shape
X, y = check_X_y(X, y, y_numeric=True, dtype=[np.single, np.double])
X, y = check_X_y(X, y, y_numeric=True, dtype=[np.float64, np.float32])

# Convert to 2d array
y_ = y.reshape((-1, 1))
Expand Down Expand Up @@ -297,15 +313,18 @@ def fit(self, X, y):
return self

def predict(self, X, pred_contribs=False, pred_interactions=False):
# Input validation
if not self.allow_nan_:
X = check_array(X, dtype=[np.single, np.double])
else:
X = check_array(X, dtype=[np.single, np.double], force_all_finite="allow-nan")

# Check is fit had been called
check_is_fitted(self, ["n_features_in_"])

# Input validation
X = validate_data(
self,
X,
dtype=[np.float64, np.float32],
force_all_finite="allow-nan" if self.allow_nan_ else True,
reset=False,
)

fptype = getFPType(X)
return self._predict_regression(X, fptype, pred_contribs, pred_interactions)

Expand Down
12 changes: 12 additions & 0 deletions daal4py/sklearn/linear_model/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@
# limitations under the License.
# ==============================================================================


from os import environ

from daal4py.sklearn._utils import sklearn_check_version

# sklearn requires manual enabling of Scipy array API support
# if `array-api-compat` package is present in environment
# TODO: create generic approach to handle this for all tests
if sklearn_check_version("1.6"):
environ["SCIPY_ARRAY_API"] = "1"


import numpy as np
import pytest
from sklearn.datasets import make_regression
Expand Down
101 changes: 91 additions & 10 deletions daal4py/sklearn/metrics/_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ def _precompute_metric_params(*args, **kwrds):
from .._utils import PatchingConditionsChain, getFPType, sklearn_check_version

if sklearn_check_version("1.3"):
from sklearn.utils._param_validation import Integral, StrOptions, validate_params
from sklearn.utils._param_validation import (
Hidden,
Integral,
StrOptions,
validate_params,
)


def _daal4py_cosine_distance_dense(X):
Expand All @@ -65,7 +70,7 @@ def _daal4py_correlation_distance_dense(X):
return res.correlationDistance


def pairwise_distances(
def _pairwise_distances(
X, Y=None, metric="euclidean", *, n_jobs=None, force_all_finite=True, **kwds
):
if metric not in _VALID_METRICS and not callable(metric) and metric != "precomputed":
Expand Down Expand Up @@ -140,16 +145,92 @@ def pairwise_distances(
return _parallel_pairwise(X, Y, func, n_jobs, **kwds)


# logic to deprecate `force_all_finite` from sklearn:
# it was renamed to `ensure_all_finite` since 1.6 and will be removed in 1.8
if sklearn_check_version("1.3"):
pairwise_distances_parameters = {
"X": ["array-like", "sparse matrix"],
"Y": ["array-like", "sparse matrix", None],
"metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable],
"n_jobs": [Integral, None],
"force_all_finite": [
"boolean",
StrOptions({"allow-nan"}),
Hidden(StrOptions({"deprecated"})),
],
"ensure_all_finite": [
"boolean",
StrOptions({"allow-nan"}),
Hidden(None),
],
}
if sklearn_check_version("1.6"):
if sklearn_check_version("1.8"):
del pairwise_distances_parameters["force_all_finite"]

def pairwise_distances(
X,
Y=None,
metric="euclidean",
*,
n_jobs=None,
ensure_all_finite=None,
**kwds,
):
return _pairwise_distances(
X,
Y,
metric,
n_jobs=n_jobs,
force_all_finite=ensure_all_finite,
**kwds,
)

else:
from sklearn.utils.deprecation import _deprecate_force_all_finite

def pairwise_distances(
X,
Y=None,
metric="euclidean",
*,
n_jobs=None,
force_all_finite="deprecated",
ensure_all_finite=None,
**kwds,
):
force_all_finite = _deprecate_force_all_finite(
force_all_finite, ensure_all_finite
)
return _pairwise_distances(
X, Y, metric, n_jobs=n_jobs, force_all_finite=force_all_finite, **kwds
)

else:
del pairwise_distances_parameters["ensure_all_finite"]

def pairwise_distances(
X,
Y=None,
metric="euclidean",
*,
n_jobs=None,
force_all_finite=True,
**kwds,
):
return _pairwise_distances(
X,
Y,
metric,
n_jobs=n_jobs,
force_all_finite=force_all_finite,
**kwds,
)

pairwise_distances = validate_params(
{
"X": ["array-like", "sparse matrix"],
"Y": ["array-like", "sparse matrix", None],
"metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable],
"n_jobs": [Integral, None],
"force_all_finite": ["boolean", StrOptions({"allow-nan"})],
},
pairwise_distances_parameters,
prefer_skip_nested_validation=True,
)(pairwise_distances)

else:
pairwise_distances = _pairwise_distances
pairwise_distances.__doc__ = pairwise_distances_original.__doc__
Loading

0 comments on commit 624f1cf

Please sign in to comment.