Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add as_sklearn and from_sklearn APIs to serialize to CPU sklearn-estimators for supported models #6102

Open
wants to merge 12 commits into
base: branch-25.02
Choose a base branch
from
40 changes: 39 additions & 1 deletion python/cuml/cuml/experimental/accel/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import click
import code
import joblib
import pickle
import os
import runpy
import sys
Expand All @@ -31,14 +33,50 @@
default=False,
help="Turn strict mode for hyperparameters on.",
)
@click.option(
"--convert_to_sklearn",
type=click.Path(exists=True),
required=False,
help="Path to a pickled accelerated estimator to convert to a sklearn estimator.",
)
@click.option(
"--format",
"format",
type=click.Choice(["pickle", "joblib"], case_sensitive=False),
default="pickle",
help="Format to save the converted sklearn estimator.",
)
@click.option(
"--output",
type=click.Path(writable=True),
default="converted_sklearn_model.pkl",
help="Output path for the converted sklearn estimator file.",
)
@click.argument("args", nargs=-1)
def main(module, strict, args):
def main(module, strict, convert_to_sklearn, format, output, args):

if strict:
os.environ["CUML_ACCEL_STRICT_MODE"] = "ON"

install()

# If the user requested a conversion, handle it and exit
if convert_to_sklearn:

with open(convert_to_sklearn, "rb") as f:
if format == "pickle":
serializer = pickle
elif format == "joblib":
serializer = joblib
accelerated_estimator = serializer.load(f)

sklearn_estimator = accelerated_estimator.as_sklearn()

with open(output, "wb") as f:
serializer.dump(sklearn_estimator, f)

sys.exit()

if module:
(module,) = module
# run the module passing the remaining arguments
Expand Down
80 changes: 79 additions & 1 deletion python/cuml/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# distutils: language = c++

import copy
import os
import inspect
import numbers
Expand All @@ -24,7 +25,7 @@ from cuml.internals.device_support import GPU_ENABLED
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import_from,
null_decorator
null_decorator,
)
np = cpu_only_import('numpy')
nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)
Expand Down Expand Up @@ -851,3 +852,80 @@ class UniversalBase(Base):
raise ex

raise ex

def as_sklearn(self, deepcopy=False):
    """
    Return the scikit-learn counterpart of this estimator.

    The CPU model is imported and built, and ``gpu_to_cpu`` is then called
    to synchronize it with the GPU estimator. The resulting object is
    intended to be usable wherever a regular scikit-learn estimator is
    expected (pipelines, serialization, and so on).

    Parameters
    ----------
    deepcopy : boolean (default=False)
        If False, the internal CPU estimator held by this cuML model is
        returned directly, so later updates made through either object are
        visible through the other. If True, an independent deep copy is
        returned instead; use this when both the cuML estimator and the
        scikit-learn estimator will keep being used side by side.

    Returns
    -------
    sklearn.base.BaseEstimator
        The CPU estimator stored on this instance (or a deep copy of it).

    """
    # Order matters: the CPU class must be importable before an instance
    # can be built, and the instance must exist before state is copied.
    self.import_cpu_model()
    self.build_cpu_model()
    self.gpu_to_cpu()

    cpu_model = self._cpu_model
    return copy.deepcopy(cpu_model) if deepcopy else cpu_model

@classmethod
def from_sklearn(cls, model, deepcopy=False):
    """
    Create a GPU-accelerated estimator from a scikit-learn estimator.

    This class method takes an existing scikit-learn estimator and converts
    it into the corresponding GPU-backed estimator. It imports the CPU model
    definition, stores the given scikit-learn model internally, and then
    calls ``cpu_to_gpu`` to transfer the model state to the new estimator.

    Parameters
    ----------
    model : sklearn.base.BaseEstimator
        A fitted scikit-learn estimator from which to create the
        GPU-accelerated version.
    deepcopy : boolean (default=False)
        Whether to store a deep copy of `model` instead of `model` itself.
        With the default, the returned estimator keeps a reference to the
        caller's object. If you intend to keep using `model` independently
        after the conversion, set this to True.

    Returns
    -------
    cls
        A new instance of the GPU-accelerated estimator class that mirrors
        the state of the input scikit-learn estimator.

    Notes
    -----
    - `output_type` of the estimator is set to "numpy" by default, as it
      cannot be inferred from training arguments. If something different
      is required, then please use cuML's output_type configuration
      utilities.
    """
    # NOTE(review): `cls` is provided by the classmethod call itself (the
    # user never passes it). A global sklearn-to-cuML conversion table was
    # suggested in review as a follow-up, to support a library-level
    # `cuml.from_sklearn` style entry point.
    estimator = cls()
    estimator.import_cpu_model()

    # NOTE(review): an optional deep copy of the incoming scikit-learn model
    # was requested in review, mirroring `as_sklearn(deepcopy=...)`.
    estimator._cpu_model = copy.deepcopy(model) if deepcopy else model
    estimator.cpu_to_gpu()

    # we need to set an output type here since
    # we cannot infer from training args.
    # Setting to numpy seems like a reasonable default for matching the
    # deserialized class by default.
    estimator.output_type = "numpy"
    estimator.output_mem_type = MemoryType.host

    return estimator
213 changes: 213 additions & 0 deletions python/cuml/cuml/tests/test_sklearn_import_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest
import numpy as np

from cuml.cluster import KMeans, DBSCAN
from cuml.decomposition import PCA, TruncatedSVD
from cuml.linear_model import (
LinearRegression,
LogisticRegression,
ElasticNet,
Ridge,
Lasso,
)
from cuml.manifold import TSNE
from cuml.neighbors import NearestNeighbors

from cuml.testing.utils import array_equal

from numpy.testing import assert_allclose

from sklearn.datasets import make_blobs, make_classification, make_regression
from sklearn.utils.validation import check_is_fitted
from sklearn.cluster import KMeans as SkKMeans, DBSCAN as SkDBSCAN
from sklearn.decomposition import PCA as SkPCA, TruncatedSVD as SkTruncatedSVD
from sklearn.linear_model import (
LinearRegression as SkLinearRegression,
LogisticRegression as SkLogisticRegression,
ElasticNet as SkElasticNet,
Ridge as SkRidge,
Lasso as SkLasso,
)
from sklearn.manifold import TSNE as SkTSNE
from sklearn.neighbors import NearestNeighbors as SkNearestNeighbors

###############################################################################
# Helper functions #
###############################################################################


@pytest.fixture
def random_state():
    """Provide the fixed seed used by the tests in this module."""
    seed = 42
    return seed
Comment on lines +54 to +56
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this instead of using 42 in the tests directly?

We could have a global version of this that allows us to run the tests with several seeds, but maybe something to tackle in the future/new PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really need it at all



def assert_estimator_roundtrip(
    cuml_model, sklearn_class, X, y=None, transform=False
):
    """
    Assert that a cuML estimator survives a scikit-learn round trip.

    Steps performed:
      1. fit the original cuML model,
      2. convert it to a scikit-learn model with ``as_sklearn()``,
      3. convert that back with ``from_sklearn()``,
      4. compare the outputs of the original and round-tripped models.

    Parameters
    ----------
    cuml_model : cuML estimator
        Unfitted estimator; it is fitted in place by this helper.
    sklearn_class : type
        The scikit-learn class that ``as_sklearn()`` is expected to return.
    X : array-like
        Data used for fitting and for producing the compared outputs.
    y : array-like, optional
        Targets; when given, ``fit(X, y)`` is used instead of ``fit(X)``.
    transform : bool (default=False)
        If True compare ``transform(X)`` outputs; otherwise compare
        ``predict(X)`` outputs, falling back to the ``labels_`` attribute.

    Raises
    ------
    AssertionError
        If the converted model has the wrong type or the outputs differ.
    NotImplementedError
        If the model offers neither ``predict`` nor ``labels_`` and
        ``transform`` is False.
    """
    # Fit original model
    if y is not None:
        cuml_model.fit(X, y)
    else:
        cuml_model.fit(X)

    # Convert to sklearn model
    sklearn_model = cuml_model.as_sklearn()
    check_is_fitted(sklearn_model)

    assert isinstance(sklearn_model, sklearn_class)

    # Convert back
    roundtrip_model = type(cuml_model).from_sklearn(sklearn_model)

    # Ensure roundtrip model is fitted
    check_is_fitted(roundtrip_model)

    # Compare predictions or transforms. The result of array_equal must be
    # asserted: calling it on its own discards the comparison and the test
    # could never fail.
    if transform:
        original_output = cuml_model.transform(X)
        roundtrip_output = roundtrip_model.transform(X)
        assert array_equal(original_output, roundtrip_output)
    elif hasattr(cuml_model, "predict"):
        original_pred = cuml_model.predict(X)
        roundtrip_pred = roundtrip_model.predict(X)
        assert array_equal(original_pred, roundtrip_pred)
    elif hasattr(cuml_model, "labels_"):
        # Models that only expose labels_ (e.g., clustering without predict)
        assert array_equal(cuml_model.labels_, roundtrip_model.labels_)
    else:
        # If we get here, need a custom handling for that type
        raise NotImplementedError(
            "No known method to compare outputs of this model."
        )


###############################################################################
# Tests #
###############################################################################


def test_kmeans(random_state):
    """Round-trip a cuML KMeans model fitted on three small blobs."""
    blobs, _ = make_blobs(
        n_samples=50, n_features=2, centers=3, random_state=random_state
    )
    assert_estimator_roundtrip(
        KMeans(n_clusters=3, random_state=random_state), SkKMeans, blobs
    )


def test_dbscan(random_state):
    """Round-trip a fitted cuML DBSCAN and compare its cluster labels."""
    X, _ = make_blobs(
        n_samples=50, n_features=2, centers=3, random_state=random_state
    )
    original = DBSCAN(eps=0.5, min_samples=5)
    # DBSCAN assigns labels_ after fit
    original.fit(X)
    sklearn_model = original.as_sklearn()
    # Same type check the shared round-trip helper performs
    assert isinstance(sklearn_model, SkDBSCAN)
    roundtrip_model = DBSCAN.from_sklearn(sklearn_model)
    # The comparison result must be asserted, otherwise it is discarded
    assert array_equal(original.labels_, roundtrip_model.labels_)


def test_pca(random_state):
    """Round-trip a cuML PCA model and compare transform outputs."""
    rng = np.random.RandomState(random_state)
    data = rng.rand(50, 5)
    assert_estimator_roundtrip(
        PCA(n_components=2, random_state=random_state),
        SkPCA,
        data,
        transform=True,
    )


def test_truncated_svd(random_state):
    """Round-trip a cuML TruncatedSVD model and compare transform outputs."""
    rng = np.random.RandomState(random_state)
    data = rng.rand(50, 5)
    assert_estimator_roundtrip(
        TruncatedSVD(n_components=2, random_state=random_state),
        SkTruncatedSVD,
        data,
        transform=True,
    )


def test_linear_regression(random_state):
    """Round-trip a cuML LinearRegression model and compare predictions."""
    features, targets = make_regression(
        n_samples=50, n_features=5, noise=0.1, random_state=random_state
    )
    assert_estimator_roundtrip(
        LinearRegression(), SkLinearRegression, features, targets
    )


def test_logistic_regression(random_state):
    """Round-trip a cuML LogisticRegression model and compare predictions."""
    features, classes = make_classification(
        n_samples=50, n_features=5, n_informative=3, random_state=random_state
    )
    assert_estimator_roundtrip(
        LogisticRegression(random_state=random_state, max_iter=500),
        SkLogisticRegression,
        features,
        classes,
    )


def test_elasticnet(random_state):
    """Round-trip a cuML ElasticNet model and compare predictions."""
    features, targets = make_regression(
        n_samples=50, n_features=5, noise=0.1, random_state=random_state
    )
    assert_estimator_roundtrip(
        ElasticNet(random_state=random_state), SkElasticNet, features, targets
    )


def test_ridge(random_state):
    """Round-trip a cuML Ridge model and compare predictions."""
    features, targets = make_regression(
        n_samples=50, n_features=5, noise=0.1, random_state=random_state
    )
    assert_estimator_roundtrip(
        Ridge(alpha=1.0, random_state=random_state),
        SkRidge,
        features,
        targets,
    )


def test_lasso(random_state):
    """Round-trip a cuML Lasso model and compare predictions."""
    features, targets = make_regression(
        n_samples=50, n_features=5, noise=0.1, random_state=random_state
    )
    assert_estimator_roundtrip(
        Lasso(alpha=0.1, random_state=random_state),
        SkLasso,
        features,
        targets,
    )


def test_tsne(random_state):
    """Round-trip a fitted cuML TSNE and compare the stored embeddings."""
    X = np.random.RandomState(random_state).rand(50, 5)
    original = TSNE(n_components=2, random_state=random_state)
    original.fit(X)
    sklearn_model = original.as_sklearn()
    # Same type check the shared round-trip helper performs
    assert isinstance(sklearn_model, SkTSNE)
    roundtrip_model = TSNE.from_sklearn(sklearn_model)

    # TSNE fitting is non-deterministic, so nothing is re-fitted here: only
    # the embedding_ attribute already stored on each object is compared.
    # The conversions are expected to carry it over unchanged.
    original_embedding = original.embedding_
    sklearn_embedding = sklearn_model.embedding_
    roundtrip_embedding = roundtrip_model.embedding_

    # The comparison results must be asserted, otherwise they are discarded
    assert array_equal(original_embedding, sklearn_embedding)
    assert array_equal(original_embedding, roundtrip_embedding)


def test_nearest_neighbors(random_state):
    """Round-trip a fitted cuML NearestNeighbors and compare kneighbors."""
    data = np.random.RandomState(random_state).rand(50, 5)
    fitted = NearestNeighbors(n_neighbors=5)
    fitted.fit(data)

    exported = fitted.as_sklearn()
    restored = NearestNeighbors.from_sklearn(exported)

    # Both distances and neighbor indices must agree after the round trip
    expected_dist, expected_ind = fitted.kneighbors(data)
    actual_dist, actual_ind = restored.kneighbors(data)
    assert_allclose(expected_dist, actual_dist)
    assert_allclose(expected_ind, actual_ind)
Loading