-
Notifications
You must be signed in to change notification settings - Fork 179
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial draft of python knn SPMD interfaces * conditionally adding responses to regression * forgot this * trying moving get_queue * reverting last and restoring cpp to test * trying just infer * resolved issue and examples run * removing commented lines * ci update * reverting intel-dpcpp-cpp-compiler version change * reverting * reverting several commits... * addressing comments * flake8 * addressing comments * flake8 * specifying onedal version for spmd neighbors setup * temporary commit for debugging external CI fails * printing import error * adding conditional to neighbors.cp * adding version import * removing all logging and re-adding onedal version conditional in setup * addressing last comments * removing comm size warnings and cleaning up examples * example rename * update run_examples.py after rebase * flake8 * addressing non gpu spmd call, better multiline string
- Loading branch information
1 parent
cfc214c
commit 6ef0a5d
Showing
13 changed files
with
283 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
#=============================================================================== | ||
# Copyright 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
#=============================================================================== | ||
|
||
import numpy as np | ||
from sklearn.metrics import accuracy_score | ||
from warnings import warn | ||
from mpi4py import MPI | ||
import dpctl | ||
from sklearnex.spmd.neighbors import KNeighborsClassifier | ||
|
||
|
||
def generate_X_y(par, seed): | ||
ns, nf = par['ns'], par['nf'] | ||
|
||
drng = np.random.default_rng(seed) | ||
data = drng.uniform(-1, 1, size=(ns, nf)) | ||
resp = (data > 0) @ (2 ** np.arange(nf)) | ||
|
||
return data, resp | ||
|
||
|
||
comm = MPI.COMM_WORLD | ||
rank = comm.Get_rank() | ||
size = comm.Get_size() | ||
|
||
if dpctl.has_gpu_devices: | ||
q = dpctl.SyclQueue("gpu") | ||
else: | ||
raise RuntimeError('GPU devices unavailable. Currently, ' | ||
'SPMD execution mode is implemented only for this device type.') | ||
|
||
params_train = {'ns': 100000, 'nf': 8} | ||
params_test = {'ns': 100, 'nf': 8} | ||
|
||
X_train, y_train = generate_X_y(params_train, rank) | ||
X_test, y_test = generate_X_y(params_test, rank + 99) | ||
|
||
model_spmd = KNeighborsClassifier(algorithm='brute', | ||
n_neighbors=20, | ||
weights='uniform', | ||
p=2, | ||
metric='minkowski') | ||
model_spmd.fit(X_train, y_train, queue=q) | ||
|
||
y_predict = model_spmd.predict(X_test, queue=q) | ||
|
||
print("Brute Force Distributed kNN classification results:") | ||
print("Ground truth (first 5 observations on rank {}):\n{}".format(rank, y_test[:5])) | ||
print("Classification results (first 5 observations on rank {}):\n{}" | ||
.format(rank, y_predict[:5])) | ||
print("Accuracy for entire rank {} (256 classes): {}\n" | ||
.format(rank, accuracy_score(y_test, y_predict))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#=============================================================================== | ||
# Copyright 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
#=============================================================================== | ||
|
||
import numpy as np | ||
from sklearn.metrics import mean_squared_error | ||
from warnings import warn | ||
from mpi4py import MPI | ||
import dpctl | ||
from numpy.testing import assert_allclose | ||
from sklearnex.spmd.neighbors import KNeighborsRegressor | ||
|
||
|
||
def generate_X_y(par, coef_seed, data_seed): | ||
ns, nf = par['ns'], par['nf'] | ||
|
||
crng = np.random.default_rng(coef_seed) | ||
coef = crng.uniform(-10, 10, size=(nf,)) | ||
|
||
drng = np.random.default_rng(data_seed) | ||
data = drng.uniform(-100, 100, size=(ns, nf)) | ||
resp = data @ coef | ||
|
||
return data, resp, coef | ||
|
||
|
||
comm = MPI.COMM_WORLD | ||
rank = comm.Get_rank() | ||
size = comm.Get_size() | ||
|
||
if dpctl.has_gpu_devices: | ||
q = dpctl.SyclQueue("gpu") | ||
else: | ||
raise RuntimeError('GPU devices unavailable. Currently, ' | ||
'SPMD execution mode is implemented only for this device type.') | ||
|
||
params_train = {'ns': 1000000, 'nf': 3} | ||
params_test = {'ns': 100, 'nf': 3} | ||
|
||
X_train, y_train, coef_train = generate_X_y(params_train, 10, rank) | ||
X_test, y_test, coef_test = generate_X_y(params_test, 10, rank + 99) | ||
|
||
assert_allclose(coef_train, coef_test) | ||
|
||
model_spmd = KNeighborsRegressor(algorithm='brute', | ||
n_neighbors=5, | ||
weights='uniform', | ||
p=2, | ||
metric='minkowski') | ||
model_spmd.fit(X_train, y_train, queue=q) | ||
|
||
y_predict = model_spmd.predict(X_test, queue=q) | ||
|
||
print("Brute Force Distributed kNN regression results:") | ||
print("Ground truth (first 5 observations on rank {}):\n{}".format(rank, y_test[:5])) | ||
print("Regression results (first 5 observations on rank {}):\n{}" | ||
.format(rank, y_predict[:5])) | ||
print("RMSE for entire rank {}: {}\n" | ||
.format(rank, mean_squared_error(y_test, y_predict, squared=False))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#=============================================================================== | ||
# Copyright 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
#=============================================================================== | ||
|
||
from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors | ||
|
||
__all__ = ['KNeighborsClassifier', 'KNeighborsRegressor', 'NearestNeighbors'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#=============================================================================== | ||
# Copyright 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
#=============================================================================== | ||
|
||
from abc import ABC | ||
from ...common._spmd_policy import _get_spmd_policy | ||
from onedal.neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch | ||
from onedal.neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch | ||
|
||
|
||
class NeighborsCommonBaseSPMD(ABC): | ||
def _get_policy(self, queue, *data): | ||
return _get_spmd_policy(queue) | ||
|
||
|
||
class KNeighborsClassifier(NeighborsCommonBaseSPMD, KNeighborsClassifier_Batch): | ||
def predict_proba(self, X, queue=None): | ||
raise NotImplementedError("predict_proba not supported in distributed mode.") | ||
|
||
|
||
class KNeighborsRegressor(NeighborsCommonBaseSPMD, KNeighborsRegressor_Batch): | ||
def fit(self, X, y, queue=None): | ||
if queue is not None and queue.sycl_device.is_gpu: | ||
return super()._fit(X, y, queue=queue) | ||
else: | ||
raise ValueError('SPMD version of kNN is not implemented for ' | ||
'CPU. Consider running on it on GPU.') | ||
|
||
def predict(self, X, queue=None): | ||
return self._predict_gpu(X, queue=queue) | ||
|
||
def _get_onedal_params(self, X, y=None): | ||
params = super()._get_onedal_params(X, y) | ||
if 'responses' not in params['result_option']: | ||
params['result_option'] += '|responses' | ||
return params | ||
|
||
|
||
class NearestNeighbors(NeighborsCommonBaseSPMD): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#=============================================================================== | ||
# Copyright 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
#=============================================================================== | ||
|
||
from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors | ||
|
||
__all__ = ['KNeighborsClassifier', 'KNeighborsRegressor', 'NearestNeighbors'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#=============================================================================== | ||
# Copyright 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
#=============================================================================== | ||
|
||
from onedal.spmd.neighbors import ( | ||
KNeighborsClassifier, | ||
KNeighborsRegressor, | ||
NearestNeighbors | ||
) | ||
|
||
# TODO: | ||
# Currently it uses `onedal` module interface. | ||
# Add sklearnex dispatching. |
Oops, something went wrong.