Skip to content

Commit

Permalink
KNN SPMD python interfaces (#1208)
Browse files Browse the repository at this point in the history
* initial draft of python knn SPMD interfaces

* conditionally adding responses to regression

* forgot this

* trying moving get_queue

* reverting last and restoring cpp to test

* trying just infer

* resolved issue and examples run

* removing commented lines

* ci update

* reverting intel-dpcpp-cpp-compiler version change

* reverting

* reverting several commits...

* addressing comments

* flake8

* addressing comments

* flake8

* specifying onedal version for spmd neighbors setup

* temporary commit for debugging external CI fails

* printing import error

* adding conditional to neighbors.cpp

* adding version import

* removing all logging and re-adding onedal version conditional in setup

* addressing last comments

* removing comm size warnings and cleaning up examples

* example rename

* update run_examples.py after rebase

* flake8

* addressing non gpu spmd call, better multiline string
  • Loading branch information
ethanglaser authored and ahuber21 committed Mar 23, 2023
1 parent cfc214c commit 6ef0a5d
Show file tree
Hide file tree
Showing 13 changed files with 283 additions and 10 deletions.
65 changes: 65 additions & 0 deletions examples/sklearnex/knn_bf_classification_spmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

import numpy as np
from sklearn.metrics import accuracy_score
from warnings import warn
from mpi4py import MPI
import dpctl
from sklearnex.spmd.neighbors import KNeighborsClassifier


def generate_X_y(par, seed):
    """Create a random feature matrix and matching integer class labels.

    Parameters
    ----------
    par : dict
        Must contain ``'ns'`` (number of rows) and ``'nf'`` (number of columns)
        of the generated matrix.
    seed : int
        Seed passed to ``numpy.random.default_rng``.

    Returns
    -------
    tuple
        ``(data, resp)`` where ``data`` has shape ``(ns, nf)`` with values drawn
        uniformly from ``[-1, 1)`` and ``resp`` has shape ``(ns,)``. Each label
        is the sign pattern of its row read as a binary number: feature ``j``
        contributes ``2 ** j`` when it is positive.
    """
    n_rows = par['ns']
    n_cols = par['nf']

    rng = np.random.default_rng(seed)
    features = rng.uniform(-1, 1, size=(n_rows, n_cols))

    # One bit per feature, so there are 2 ** nf possible classes.
    bit_weights = 2 ** np.arange(n_cols)
    labels = (features > 0) @ bit_weights

    return features, labels


# MPI context of this process; `rank` seeds the per-process data below and labels
# the printed output.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# SPMD execution is implemented only for GPU devices, so fail early with a clear
# message when none is available. `dpctl.has_gpu_devices` is a function and has
# to be *called*: the bare function object is always truthy, which would make
# the `else` branch unreachable.
if dpctl.has_gpu_devices():
    q = dpctl.SyclQueue("gpu")
else:
    raise RuntimeError('GPU devices unavailable. Currently, '
                       'SPMD execution mode is implemented only for this device type.')

# Every rank generates its own chunk of data: the rank number is used as the
# seed, with an offset for the test set so it differs from the training set.
params_train = {'ns': 100000, 'nf': 8}
params_test = {'ns': 100, 'nf': 8}

X_train, y_train = generate_X_y(params_train, rank)
X_test, y_test = generate_X_y(params_test, rank + 99)

# Brute-force kNN with Euclidean distance (Minkowski metric with p=2).
knn_settings = dict(algorithm='brute',
                    n_neighbors=20,
                    weights='uniform',
                    p=2,
                    metric='minkowski')
model_spmd = KNeighborsClassifier(**knn_settings)
model_spmd.fit(X_train, y_train, queue=q)

y_predict = model_spmd.predict(X_test, queue=q)

# Each rank reports on its own test chunk.
print("Brute Force Distributed kNN classification results:")
truth_report = "Ground truth (first 5 observations on rank {}):\n{}"
print(truth_report.format(rank, y_test[:5]))
prediction_report = "Classification results (first 5 observations on rank {}):\n{}"
print(prediction_report.format(rank, y_predict[:5]))
accuracy_report = "Accuracy for entire rank {} (256 classes): {}\n"
print(accuracy_report.format(rank, accuracy_score(y_test, y_predict)))
71 changes: 71 additions & 0 deletions examples/sklearnex/knn_bf_regression_spmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

import numpy as np
from sklearn.metrics import mean_squared_error
from warnings import warn
from mpi4py import MPI
import dpctl
from numpy.testing import assert_allclose
from sklearnex.spmd.neighbors import KNeighborsRegressor


def generate_X_y(par, coef_seed, data_seed):
    """Create a random linear-regression data set.

    Parameters
    ----------
    par : dict
        Must contain ``'ns'`` (number of rows) and ``'nf'`` (number of columns)
        of the generated matrix.
    coef_seed : int
        Seed for the generator that draws the coefficient vector. Calls that
        share this seed and ``nf`` get the same coefficients.
    data_seed : int
        Seed for the generator that draws the feature matrix.

    Returns
    -------
    tuple
        ``(data, resp, coef)``: ``data`` has shape ``(ns, nf)`` with values
        drawn uniformly from ``[-100, 100)``, ``coef`` has shape ``(nf,)`` with
        values drawn uniformly from ``[-10, 10)``, and ``resp = data @ coef``.
    """
    n_rows = par['ns']
    n_cols = par['nf']

    # Coefficients and features come from separate generators so the model can
    # be held fixed while the data varies.
    coef_rng = np.random.default_rng(coef_seed)
    weights = coef_rng.uniform(-10, 10, size=(n_cols,))

    data_rng = np.random.default_rng(data_seed)
    features = data_rng.uniform(-100, 100, size=(n_rows, n_cols))

    targets = features @ weights

    return features, targets, weights


# MPI context of this process; `rank` seeds the per-process data below and labels
# the printed output.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# SPMD execution is implemented only for GPU devices, so fail early with a clear
# message when none is available. `dpctl.has_gpu_devices` is a function and has
# to be *called*: the bare function object is always truthy, which would make
# the `else` branch unreachable.
if dpctl.has_gpu_devices():
    q = dpctl.SyclQueue("gpu")
else:
    raise RuntimeError('GPU devices unavailable. Currently, '
                       'SPMD execution mode is implemented only for this device type.')

# Every rank generates its own chunk of data (data seed derived from the rank)
# while sharing one coefficient vector (fixed coefficient seed).
params_train = {'ns': 1000000, 'nf': 3}
params_test = {'ns': 100, 'nf': 3}

X_train, y_train, coef_train = generate_X_y(params_train, 10, rank)
X_test, y_test, coef_test = generate_X_y(params_test, 10, rank + 99)

# Sanity check: train and test targets must come from the same linear model.
assert_allclose(coef_train, coef_test)

# Brute-force kNN with Euclidean distance (Minkowski metric with p=2).
model_spmd = KNeighborsRegressor(algorithm='brute',
                                 n_neighbors=5,
                                 weights='uniform',
                                 p=2,
                                 metric='minkowski')
model_spmd.fit(X_train, y_train, queue=q)

y_predict = model_spmd.predict(X_test, queue=q)

# RMSE is computed as sqrt(MSE) rather than through the `squared=False` keyword,
# which newer scikit-learn releases deprecated and then removed. For the 1-D
# targets used here both forms give the same value.
rmse = np.sqrt(mean_squared_error(y_test, y_predict))

print("Brute Force Distributed kNN regression results:")
print("Ground truth (first 5 observations on rank {}):\n{}".format(rank, y_test[:5]))
print("Regression results (first 5 observations on rank {}):\n{}"
      .format(rank, y_predict[:5]))
print("RMSE for entire rank {}: {}\n"
      .format(rank, rmse))
3 changes: 2 additions & 1 deletion onedal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@
__all__ += ['basic_statistics', 'linear_model']

if _is_dpc_backend:
__all__ += ['spmd.basic_statistics', 'spmd.decomposition', 'spmd.linear_model',]
__all__ += ['spmd.basic_statistics', 'spmd.decomposition',
'spmd.linear_model', 'spmd.neighbors']
6 changes: 6 additions & 0 deletions onedal/neighbors/neighbors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "oneapi/dal/algo/knn.hpp"

#include "onedal/common.hpp"
#include "onedal/version.hpp"
#include "onedal/primitives/pairwise_distances.hpp"
#include <regex>

Expand Down Expand Up @@ -313,8 +314,13 @@ ONEDAL_PY_INIT_MODULE(neighbors) {
using task_list = types<task::classification, task::regression, task::search>;
auto sub = m.def_submodule("neighbors");

#if defined(ONEDAL_DATA_PARALLEL_SPMD) && defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100
ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list_spmd, task_list);
ONEDAL_PY_INSTANTIATE(init_infer_ops, sub, policy_list_spmd, task_list);
#else // defined(ONEDAL_DATA_PARALLEL_SPMD) && defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100
ONEDAL_PY_INSTANTIATE(init_train_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_infer_ops, sub, policy_list, task_list);
#endif // defined(ONEDAL_DATA_PARALLEL_SPMD) && defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100

ONEDAL_PY_INSTANTIATE(init_model, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_train_result, sub, task_list);
Expand Down
19 changes: 13 additions & 6 deletions onedal/neighbors/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@


class NeighborsCommonBase(metaclass=ABCMeta):
def _get_policy(self, queue, *data):
return _get_policy(queue, *data)

def _parse_auto_method(self, method, n_samples, n_features):
result_method = method

Expand Down Expand Up @@ -402,7 +405,7 @@ def _onedal_fit(self, X, y, queue):

return train_alg(**params).compute(X, y).model

policy = _get_policy(queue, X, y)
policy = self._get_policy(queue, X, y)
X, y = _convert_to_supported(policy, X, y)
params = self._get_onedal_params(X, y)
train_alg = _backend.neighbors.classification.train(policy, params,
Expand All @@ -421,7 +424,7 @@ def _onedal_predict(self, model, X, params, queue):

return predict_alg(**params).compute(X, model)

policy = _get_policy(queue, X)
policy = self._get_policy(queue, X)
X = _convert_to_supported(policy, X)
if hasattr(self, '_onedal_model'):
model = self._onedal_model
Expand Down Expand Up @@ -549,7 +552,8 @@ def _onedal_fit(self, X, y, queue):

return train_alg(**params).compute(X, y).model

policy = _get_policy(queue, X, y)
policy = self._get_policy(queue, X, y)
X, y = _convert_to_supported(policy, X, y)
params = self._get_onedal_params(X, y)
train_alg_regr = _backend.neighbors.regression.train
train_alg_srch = _backend.neighbors.search.train
Expand All @@ -568,7 +572,8 @@ def _onedal_predict(self, model, X, params, queue):

return predict_alg(**params).compute(X, model)

policy = _get_policy(queue, X)
policy = self._get_policy(queue, X)
X = _convert_to_supported(policy, X)
backend = _backend.neighbors.regression if gpu_device \
else _backend.neighbors.search

Expand Down Expand Up @@ -678,7 +683,8 @@ def _onedal_fit(self, X, y, queue):

return train_alg(**params).compute(X, y).model

policy = _get_policy(queue, X, y)
policy = self._get_policy(queue, X, y)
X, y = _convert_to_supported(policy, X, y)
params = self._get_onedal_params(X, y)
train_alg = _backend.neighbors.search.train(policy, params,
to_table(X))
Expand All @@ -696,7 +702,8 @@ def _onedal_predict(self, model, X, params, queue):

return predict_alg(**params).compute(X, model)

policy = _get_policy(queue, X)
policy = self._get_policy(queue, X)
X = _convert_to_supported(policy, X)
if hasattr(self, '_onedal_model'):
model = self._onedal_model
else:
Expand Down
2 changes: 1 addition & 1 deletion onedal/spmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# limitations under the License.
#===============================================================================

__all__ = ['basic_statistics', 'decomposition', 'linear_model']
__all__ = ['basic_statistics', 'decomposition', 'linear_model', 'neighbors']
19 changes: 19 additions & 0 deletions onedal/spmd/neighbors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors

__all__ = ['KNeighborsClassifier', 'KNeighborsRegressor', 'NearestNeighbors']
52 changes: 52 additions & 0 deletions onedal/spmd/neighbors/neighbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

from abc import ABC
from ...common._spmd_policy import _get_spmd_policy
from onedal.neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch
from onedal.neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch


class NeighborsCommonBaseSPMD(ABC):
    """Mixin that makes a neighbors estimator use an SPMD policy.

    It supplies a ``_get_policy`` hook with the same signature as the one on
    the batch neighbors base class, so listing this mixin first in the bases
    routes policy creation through ``_get_spmd_policy``.
    """

    def _get_policy(self, queue, *data):
        """Return the SPMD policy built from ``queue``.

        ``data`` is accepted for signature compatibility only; it is not used
        when building the policy.
        """
        spmd_policy = _get_spmd_policy(queue)
        return spmd_policy


class KNeighborsClassifier(NeighborsCommonBaseSPMD, KNeighborsClassifier_Batch):
    """SPMD variant of the batch kNN classifier.

    Everything is inherited from the batch classifier, except that the policy
    comes from the SPMD mixin and ``predict_proba`` is disabled.
    """

    def predict_proba(self, X, queue=None):
        """Always raise ``NotImplementedError``: unavailable in distributed mode."""
        message = "predict_proba not supported in distributed mode."
        raise NotImplementedError(message)


class KNeighborsRegressor(NeighborsCommonBaseSPMD, KNeighborsRegressor_Batch):
    """SPMD variant of the batch kNN regressor.

    The policy comes from the SPMD mixin. Training and inference are restricted
    to the GPU code path: ``fit`` rejects anything but a GPU queue, and
    ``predict`` always dispatches to ``_predict_gpu``.
    """

    def fit(self, X, y, queue=None):
        """Fit the regressor on a GPU queue.

        Parameters
        ----------
        X, y
            Training data and targets, forwarded unchanged to ``_fit``.
        queue
            SYCL queue to run on. It must be set and refer to a GPU device.

        Returns
        -------
        Whatever the inherited ``_fit`` returns.

        Raises
        ------
        ValueError
            If ``queue`` is ``None`` or its device is not a GPU.
        """
        if queue is None or not queue.sycl_device.is_gpu:
            raise ValueError('SPMD version of kNN is not implemented for '
                             'CPU. Consider running it on GPU.')
        return super()._fit(X, y, queue=queue)

    def predict(self, X, queue=None):
        """Predict targets for ``X`` through the GPU inference path."""
        return self._predict_gpu(X, queue=queue)

    def _get_onedal_params(self, X, y=None):
        """Return the inherited parameters with ``'responses'`` requested.

        ``'|responses'`` is appended to ``params['result_option']`` when it is
        not already part of that string.
        """
        # NOTE(review): presumably needed so the backend returns predicted
        # responses directly in SPMD mode -- confirm against the backend.
        params = super()._get_onedal_params(X, y)
        if 'responses' not in params['result_option']:
            params['result_option'] += '|responses'
        return params


class NearestNeighbors(NeighborsCommonBaseSPMD):
    """Placeholder for an SPMD nearest-neighbors search estimator.

    It only inherits the SPMD policy hook; it defines no fit or query methods
    of its own.
    """
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,9 @@ def run(self):
'onedal.spmd.basic_statistics',
'onedal.spmd.decomposition',
'onedal.spmd.linear_model'
] if build_distribute else [])),
] + (['onedal.spmd.neighbors']
if ONEDAL_VERSION >= 20230100 else [])
if build_distribute else [])),
package_data={
'daal4py.oneapi': [
'liboneapi_backend.so',
Expand Down
2 changes: 1 addition & 1 deletion sklearnex/spmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
# limitations under the License.
#===============================================================================

__all__ = ['basic_statistics', 'decomposition', 'linear_model']
__all__ = ['basic_statistics', 'decomposition', 'linear_model', 'neighbors']
19 changes: 19 additions & 0 deletions sklearnex/spmd/neighbors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors

__all__ = ['KNeighborsClassifier', 'KNeighborsRegressor', 'NearestNeighbors']
25 changes: 25 additions & 0 deletions sklearnex/spmd/neighbors/neighbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

from onedal.spmd.neighbors import (
KNeighborsClassifier,
KNeighborsRegressor,
NearestNeighbors
)

# TODO:
# These estimators are currently re-exported directly from the `onedal`
# module interface. Add sklearnex dispatching.
Loading

0 comments on commit 6ef0a5d

Please sign in to comment.