From 6ef0a5d9af7233f9a77d3996e75650171724dd53 Mon Sep 17 00:00:00 2001 From: "glasere@purdue.edu" <42726565+ethanglaser@users.noreply.github.com> Date: Thu, 23 Mar 2023 08:39:19 -0400 Subject: [PATCH] KNN SPMD python interfaces (#1208) * initial draft of python knn SPMD interfaces * conditionally adding responses to regression * forgot this * trying moving get_queue * reverting last and restoring cpp to test * trying just infer * resolved issue and examples run * removing commented lines * ci update * reverting intel-dpcpp-cpp-compiler version change * reverting * reverting several commits... * addressing comments * flake8 * addressing comments * flake8 * specifying onedal version for spmd neighbors setup * temporary commit for debugging external CI fails * printing import error * adding conditional to neighbors.cp * adding version import * removing all logging and re-adding onedal version conditional in setup * addressing last comments * removing comm size warnings and cleaning up examples * example rename * update run_examples.py after rebase * flake8 * addressing non gpu spmd call, better multiline string --- .../sklearnex/knn_bf_classification_spmd.py | 65 +++++++++++++++++ examples/sklearnex/knn_bf_regression_spmd.py | 71 +++++++++++++++++++ onedal/__init__.py | 3 +- onedal/neighbors/neighbors.cpp | 6 ++ onedal/neighbors/neighbors.py | 19 +++-- onedal/spmd/__init__.py | 2 +- onedal/spmd/neighbors/__init__.py | 19 +++++ onedal/spmd/neighbors/neighbors.py | 52 ++++++++++++++ setup.py | 4 +- sklearnex/spmd/__init__.py | 2 +- sklearnex/spmd/neighbors/__init__.py | 19 +++++ sklearnex/spmd/neighbors/neighbors.py | 25 +++++++ tests/run_examples.py | 6 ++ 13 files changed, 283 insertions(+), 10 deletions(-) create mode 100644 examples/sklearnex/knn_bf_classification_spmd.py create mode 100644 examples/sklearnex/knn_bf_regression_spmd.py create mode 100644 onedal/spmd/neighbors/__init__.py create mode 100644 onedal/spmd/neighbors/neighbors.py create mode 100644 sklearnex/spmd/neighbors/__init__.py create mode 100644 sklearnex/spmd/neighbors/neighbors.py diff --git a/examples/sklearnex/knn_bf_classification_spmd.py b/examples/sklearnex/knn_bf_classification_spmd.py new file mode 100644 index 0000000000..b2a1e40e5d --- /dev/null +++ b/examples/sklearnex/knn_bf_classification_spmd.py @@ -0,0 +1,65 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +import numpy as np +from sklearn.metrics import accuracy_score +from warnings import warn +from mpi4py import MPI +import dpctl +from sklearnex.spmd.neighbors import KNeighborsClassifier + + +def generate_X_y(par, seed): + ns, nf = par['ns'], par['nf'] + + drng = np.random.default_rng(seed) + data = drng.uniform(-1, 1, size=(ns, nf)) + resp = (data > 0) @ (2 ** np.arange(nf)) + + return data, resp + + +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +if dpctl.has_gpu_devices: + q = dpctl.SyclQueue("gpu") +else: + raise RuntimeError('GPU devices unavailable. Currently, ' + 'SPMD execution mode is implemented only for this device type.') + +params_train = {'ns': 100000, 'nf': 8} +params_test = {'ns': 100, 'nf': 8} + +X_train, y_train = generate_X_y(params_train, rank) +X_test, y_test = generate_X_y(params_test, rank + 99) + +model_spmd = KNeighborsClassifier(algorithm='brute', + n_neighbors=20, + weights='uniform', + p=2, + metric='minkowski') +model_spmd.fit(X_train, y_train, queue=q) + +y_predict = model_spmd.predict(X_test, queue=q) + +print("Brute Force Distributed kNN classification results:") +print("Ground truth (first 5 observations on rank {}):\n{}".format(rank, y_test[:5])) +print("Classification results (first 5 observations on rank {}):\n{}" + .format(rank, y_predict[:5])) +print("Accuracy for entire rank {} (256 classes): {}\n" + .format(rank, accuracy_score(y_test, y_predict))) diff --git a/examples/sklearnex/knn_bf_regression_spmd.py b/examples/sklearnex/knn_bf_regression_spmd.py new file mode 100644 index 0000000000..223cd910dd --- /dev/null +++ b/examples/sklearnex/knn_bf_regression_spmd.py @@ -0,0 +1,71 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +import numpy as np +from sklearn.metrics import mean_squared_error +from warnings import warn +from mpi4py import MPI +import dpctl +from numpy.testing import assert_allclose +from sklearnex.spmd.neighbors import KNeighborsRegressor + + +def generate_X_y(par, coef_seed, data_seed): + ns, nf = par['ns'], par['nf'] + + crng = np.random.default_rng(coef_seed) + coef = crng.uniform(-10, 10, size=(nf,)) + + drng = np.random.default_rng(data_seed) + data = drng.uniform(-100, 100, size=(ns, nf)) + resp = data @ coef + + return data, resp, coef + + +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +if dpctl.has_gpu_devices: + q = dpctl.SyclQueue("gpu") +else: + raise RuntimeError('GPU devices unavailable. Currently, ' + 'SPMD execution mode is implemented only for this device type.') + +params_train = {'ns': 1000000, 'nf': 3} +params_test = {'ns': 100, 'nf': 3} + +X_train, y_train, coef_train = generate_X_y(params_train, 10, rank) +X_test, y_test, coef_test = generate_X_y(params_test, 10, rank + 99) + +assert_allclose(coef_train, coef_test) + +model_spmd = KNeighborsRegressor(algorithm='brute', + n_neighbors=5, + weights='uniform', + p=2, + metric='minkowski') +model_spmd.fit(X_train, y_train, queue=q) + +y_predict = model_spmd.predict(X_test, queue=q) + +print("Brute Force Distributed kNN regression results:") +print("Ground truth (first 5 observations on rank {}):\n{}".format(rank, y_test[:5])) +print("Regression results (first 5 observations on rank {}):\n{}" + .format(rank, y_predict[:5])) +print("RMSE for entire rank {}: {}\n" + .format(rank, mean_squared_error(y_test, y_predict, squared=False))) diff --git a/onedal/__init__.py b/onedal/__init__.py index db30bf0334..8924e2e99c 100644 --- a/onedal/__init__.py +++ b/onedal/__init__.py @@ -49,4 +49,5 @@ __all__ += ['basic_statistics', 'linear_model'] if _is_dpc_backend: - __all__ += ['spmd.basic_statistics', 'spmd.decomposition', 'spmd.linear_model',] + __all__ += ['spmd.basic_statistics', 'spmd.decomposition', + 'spmd.linear_model', 'spmd.neighbors'] diff --git a/onedal/neighbors/neighbors.cpp b/onedal/neighbors/neighbors.cpp index 8e315a8cd4..5af3ecc8b0 100644 --- a/onedal/neighbors/neighbors.cpp +++ b/onedal/neighbors/neighbors.cpp @@ -17,6 +17,7 @@ #include "oneapi/dal/algo/knn.hpp" #include "onedal/common.hpp" +#include "onedal/version.hpp" #include "onedal/primitives/pairwise_distances.hpp" #include @@ -313,8 +314,13 @@ ONEDAL_PY_INIT_MODULE(neighbors) { using task_list = types; auto sub = m.def_submodule("neighbors"); +#if defined(ONEDAL_DATA_PARALLEL_SPMD) && defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100 + ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list_spmd, task_list); + ONEDAL_PY_INSTANTIATE(init_infer_ops, sub, policy_list_spmd, task_list); +#else // defined(ONEDAL_DATA_PARALLEL_SPMD) && defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100 ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list, task_list); ONEDAL_PY_INSTANTIATE(init_infer_ops, sub, policy_list, task_list); +#endif // defined(ONEDAL_DATA_PARALLEL_SPMD) && defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100 ONEDAL_PY_INSTANTIATE(init_model, sub, task_list); ONEDAL_PY_INSTANTIATE(init_train_result, sub, task_list); diff --git a/onedal/neighbors/neighbors.py b/onedal/neighbors/neighbors.py index dded3e3c29..459c761c28 100755 --- a/onedal/neighbors/neighbors.py +++ b/onedal/neighbors/neighbors.py @@ -44,6 +44,9 @@ class NeighborsCommonBase(metaclass=ABCMeta): + def _get_policy(self, queue, *data): + return _get_policy(queue, *data) + def _parse_auto_method(self, method, n_samples, n_features): result_method = method @@ -402,7 +405,7 @@ def _onedal_fit(self, X, y, queue): return train_alg(**params).compute(X, y).model - policy = _get_policy(queue, X, y) + policy = self._get_policy(queue, X, y) X, y = _convert_to_supported(policy, X, y) params = self._get_onedal_params(X, y) train_alg = _backend.neighbors.classification.train(policy, params, @@ -421,7 +424,7 @@ def _onedal_predict(self, model, X, params, queue): return predict_alg(**params).compute(X, model) - policy = _get_policy(queue, X) + policy = self._get_policy(queue, X) X = _convert_to_supported(policy, X) if hasattr(self, '_onedal_model'): model = self._onedal_model @@ -549,7 +552,8 @@ def _onedal_fit(self, X, y, queue): return train_alg(**params).compute(X, y).model - policy = _get_policy(queue, X, y) + policy = self._get_policy(queue, X, y) + X, y = _convert_to_supported(policy, X, y) params = self._get_onedal_params(X, y) train_alg_regr = _backend.neighbors.regression.train train_alg_srch = _backend.neighbors.search.train @@ -568,7 +572,8 @@ def _onedal_predict(self, model, X, params, queue): return predict_alg(**params).compute(X, model) - policy = _get_policy(queue, X) + policy = self._get_policy(queue, X) + X = _convert_to_supported(policy, X) backend = _backend.neighbors.regression if gpu_device \ else _backend.neighbors.search @@ -678,7 +683,8 @@ def _onedal_fit(self, X, y, queue): return train_alg(**params).compute(X, y).model - policy = _get_policy(queue, X, y) + policy = self._get_policy(queue, X, y) + X, y = _convert_to_supported(policy, X, y) params = self._get_onedal_params(X, y) train_alg = _backend.neighbors.search.train(policy, params, to_table(X)) @@ -696,7 +702,8 @@ def _onedal_predict(self, model, X, params, queue): return predict_alg(**params).compute(X, model) - policy = _get_policy(queue, X) + policy = self._get_policy(queue, X) + X = _convert_to_supported(policy, X) if hasattr(self, '_onedal_model'): model = self._onedal_model else: diff --git a/onedal/spmd/__init__.py b/onedal/spmd/__init__.py index 9ac25b4370..f71ffd55df 100644 --- a/onedal/spmd/__init__.py +++ b/onedal/spmd/__init__.py @@ -14,4 +14,4 @@ # limitations under the License. #=============================================================================== -__all__ = ['basic_statistics', 'decomposition', 'linear_model'] +__all__ = ['basic_statistics', 'decomposition', 'linear_model', 'neighbors'] diff --git a/onedal/spmd/neighbors/__init__.py b/onedal/spmd/neighbors/__init__.py new file mode 100644 index 0000000000..99099fa51c --- /dev/null +++ b/onedal/spmd/neighbors/__init__.py @@ -0,0 +1,19 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors + +__all__ = ['KNeighborsClassifier', 'KNeighborsRegressor', 'NearestNeighbors'] diff --git a/onedal/spmd/neighbors/neighbors.py b/onedal/spmd/neighbors/neighbors.py new file mode 100644 index 0000000000..067bac6d3c --- /dev/null +++ b/onedal/spmd/neighbors/neighbors.py @@ -0,0 +1,52 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +from abc import ABC +from ...common._spmd_policy import _get_spmd_policy +from onedal.neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch +from onedal.neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch + + +class NeighborsCommonBaseSPMD(ABC): + def _get_policy(self, queue, *data): + return _get_spmd_policy(queue) + + +class KNeighborsClassifier(NeighborsCommonBaseSPMD, KNeighborsClassifier_Batch): + def predict_proba(self, X, queue=None): + raise NotImplementedError("predict_proba not supported in distributed mode.") + + +class KNeighborsRegressor(NeighborsCommonBaseSPMD, KNeighborsRegressor_Batch): + def fit(self, X, y, queue=None): + if queue is not None and queue.sycl_device.is_gpu: + return super()._fit(X, y, queue=queue) + else: + raise ValueError('SPMD version of kNN is not implemented for ' + 'CPU. Consider running on it on GPU.') + + def predict(self, X, queue=None): + return self._predict_gpu(X, queue=queue) + + def _get_onedal_params(self, X, y=None): + params = super()._get_onedal_params(X, y) + if 'responses' not in params['result_option']: + params['result_option'] += '|responses' + return params + + +class NearestNeighbors(NeighborsCommonBaseSPMD): + pass diff --git a/setup.py b/setup.py index dff6ffe3a3..1e912ee547 100644 --- a/setup.py +++ b/setup.py @@ -478,7 +478,9 @@ def run(self): 'onedal.spmd.basic_statistics', 'onedal.spmd.decomposition', 'onedal.spmd.linear_model' - ] if build_distribute else [])), + ] + (['onedal.spmd.neighbors'] + if ONEDAL_VERSION >= 20230100 else []) + if build_distribute else [])), package_data={ 'daal4py.oneapi': [ 'liboneapi_backend.so', diff --git a/sklearnex/spmd/__init__.py b/sklearnex/spmd/__init__.py index 9ac25b4370..f71ffd55df 100644 --- a/sklearnex/spmd/__init__.py +++ b/sklearnex/spmd/__init__.py @@ -14,4 +14,4 @@ # limitations under the License. #=============================================================================== -__all__ = ['basic_statistics', 'decomposition', 'linear_model'] +__all__ = ['basic_statistics', 'decomposition', 'linear_model', 'neighbors'] diff --git a/sklearnex/spmd/neighbors/__init__.py b/sklearnex/spmd/neighbors/__init__.py new file mode 100644 index 0000000000..99099fa51c --- /dev/null +++ b/sklearnex/spmd/neighbors/__init__.py @@ -0,0 +1,19 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors + +__all__ = ['KNeighborsClassifier', 'KNeighborsRegressor', 'NearestNeighbors'] diff --git a/sklearnex/spmd/neighbors/neighbors.py b/sklearnex/spmd/neighbors/neighbors.py new file mode 100644 index 0000000000..7eaa5e9f62 --- /dev/null +++ b/sklearnex/spmd/neighbors/neighbors.py @@ -0,0 +1,25 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +from onedal.spmd.neighbors import ( + KNeighborsClassifier, + KNeighborsRegressor, + NearestNeighbors +) + +# TODO: +# Currently it uses `onedal` module interface. +# Add sklearnex dispatching. diff --git a/tests/run_examples.py b/tests/run_examples.py index a5f43b7f07..8822fa478b 100755 --- a/tests/run_examples.py +++ b/tests/run_examples.py @@ -139,10 +139,14 @@ def check_library(rule): req_version['decision_forest_classification_traverse_batch.py'] = (2023, 'P', 1) req_version['decision_forest_regression_hist_batch.py'] = (2021, 'P', 200) req_version['basic_statistics_spmd.py'] = (2023, 'P', 1) +req_version['knn_bf_classification_spmd.py'] = (2023, 'P', 1) +req_version['knn_bf_regression_spmd.py'] = (2023, 'P', 1) req_version['linear_regression_spmd.py'] = (2023, 'P', 1) req_device = defaultdict(lambda: []) req_device['basic_statistics_spmd.py'] = ["gpu"] +req_device['knn_bf_classification_spmd.py'] = ["gpu"] +req_device['knn_bf_regression_spmd.py'] = ["gpu"] req_device['linear_regression_spmd.py'] = ["gpu"] req_device['pca_spmd.py'] = ["gpu"] req_device['sycl/gradient_boosted_regression_batch.py'] = ["gpu"] @@ -152,6 +156,8 @@ def check_library(rule): req_library['gbt_cls_model_create_from_lightgbm_batch.py'] = ['lightgbm'] req_library['gbt_cls_model_create_from_xgboost_batch.py'] = ['xgboost'] req_library['gbt_cls_model_create_from_catboost_batch.py'] = ['catboost'] +req_library['knn_bf_classification_spmd.py'] = ['dpctl', 'mpi4py'] +req_library['knn_bf_regression_spmd.py'] = ['dpctl', 'mpi4py'] req_library['linear_regression_spmd.py'] = ['dpctl', 'mpi4py'] req_library['pca_spmd.py'] = ['dpctl', 'mpi4py']