
KNN on max series seems slower than cuda-based implementation on comparable devices ? #1441

Open
fcharras opened this issue Sep 6, 2023 · 4 comments
Labels: enhancement (New feature or request)


fcharras commented Sep 6, 2023

The initial report contained an error; please see the first comment below for a better explanation.

import numpy as np
from sklearn.neighbors import NearestNeighbors
import sklearn

device = "cpu"
# device = "gpu:0"
from sklearnex import patch_sklearn
patch_sklearn()
sklearn.set_config(target_offload=f"{device}")

seed = 123
rng = np.random.default_rng(seed)

n_samples = 10_000_000
dim = 100
n_queries = 10_000
k = 100

data = rng.random((n_samples, dim), dtype=np.float32)
query = rng.random((n_queries, dim), dtype=np.float32)

knn = NearestNeighbors(n_neighbors=k, algorithm="brute")
knn.fit(data)
%time knn.kneighbors(X=query)

gives the following results:

  • if device=cpu:
CPU times: user 25min 40s, sys: 18 s, total: 25min 58s
Wall time: 14.1 s
  • if device=gpu (Max Series on intel beta cloud):
CPU times: user 25min 42s, sys: 21.7 s, total: 26min 4s
Wall time: 14.1 s

However, one would expect a significant speedup on GPU.
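For context, an exact brute-force kNN query (algorithm="brute") essentially reduces to one large matrix product between the queries and the training data, followed by a per-query top-k selection, which is the kind of workload GPUs usually accelerate well. Below is a minimal NumPy sketch of that generic pattern, assuming squared Euclidean distances; the function name brute_force_knn and its chunk parameter are illustrative, and this is not the actual oneDAL, cuML or FAISS code:

```python
import numpy as np


def brute_force_knn(data, query, k, chunk=1024):
    """Exact k-nearest neighbours by brute force (Euclidean distance).

    Generic GEMM + top-k sketch. Returns (distances, indices), each of
    shape (n_queries, k), sorted by increasing distance.
    """
    data_sq = np.einsum("ij,ij->i", data, data)  # ||x||^2 for every sample
    n_queries = query.shape[0]
    out_dist = np.empty((n_queries, k), dtype=data.dtype)
    out_idx = np.empty((n_queries, k), dtype=np.int64)
    for start in range(0, n_queries, chunk):
        q = query[start:start + chunk]
        q_sq = np.einsum("ij,ij->i", q, q)
        # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2; the q @ data.T product
        # dominates the cost (about 2 * n_samples * dim flops per query).
        d2 = q_sq[:, None] - 2.0 * (q @ data.T) + data_sq[None, :]
        np.maximum(d2, 0.0, out=d2)  # clip tiny negatives caused by rounding
        # Partial selection of the k smallest per row, then sort only those k.
        part = np.argpartition(d2, k - 1, axis=1)[:, :k]
        part_d = np.take_along_axis(d2, part, axis=1)
        order = np.argsort(part_d, axis=1)
        out_idx[start:start + chunk] = np.take_along_axis(part, order, axis=1)
        out_dist[start:start + chunk] = np.sqrt(
            np.take_along_axis(part_d, order, axis=1)
        )
    return out_dist, out_idx
```

Each chunk materialises a chunk × n_samples distance matrix, so with the 10M samples used here even chunk=64 needs temporaries of several GB in float32; the sketch is meant to show the structure of the computation, not to serve as a benchmark baseline.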

For comparison, on an A100 with the cuML implementation (which is in fact inherited from the open-source FAISS implementation):

import numpy as np
from cuml.neighbors import NearestNeighbors
import cupy

seed = 123
rng = np.random.default_rng(seed)

n_samples = 10_000_000
dim = 100
n_queries = 10_000
k = 100

data = rng.random((n_samples, dim), dtype=np.float32)
query = rng.random((n_queries, dim), dtype=np.float32)

data = cupy.asarray(data)
query = cupy.asarray(query)

knn = NearestNeighbors(n_neighbors=k, algorithm="brute")
knn.fit(data)
%time knn.kneighbors(X=query)

it takes about 3 s:

CPU times: user 2.71 s, sys: 8.49 ms, total: 2.72 s
Wall time: 2.73 s
Also, looking at the total CPU times with scikit-learn-intelex, it is unexpected that I see 25+ minutes for both the CPU and GPU runs while the wall time is under 15 s. This suggests the CPU is also under heavy load during the GPU snippet; is this possibility really ruled out by https://github.com//issues/1416 ?
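The average number of busy cores can be estimated as total CPU time divided by wall time. A quick check on the two timings reported above (the helper names are only for illustration):

```python
def to_seconds(minutes=0, seconds=0.0):
    """Convert an IPython %time 'XminYs' reading to seconds."""
    return 60 * minutes + seconds


def busy_cores(cpu_seconds, wall_seconds):
    """Average number of cores kept busy: total CPU time / wall time."""
    return cpu_seconds / wall_seconds


# (total CPU time, wall time) from the two %time outputs above
runs = {
    "device=cpu": (to_seconds(25, 58), 14.1),
    "device=gpu": (to_seconds(26, 4), 14.1),
}
for name, (cpu_s, wall_s) in runs.items():
    print(f"{name}: ~{busy_cores(cpu_s, wall_s):.0f} cores busy on average")
# device=cpu: ~110 cores busy on average
# device=gpu: ~111 cores busy on average
```

Both runs keep roughly 110 cores busy, which is consistent with the two runs doing the same work on the CPU.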

Environment:

sklearn-intelex + dpcpp_cpp_rt installed with conda, with a Max Series GPU on the Intel beta cloud.

fcharras added the bug (Something isn't working) label on Sep 6, 2023

fcharras commented Sep 6, 2023

There is actually an error in my initial snippet: it imports the NearestNeighbors estimator before calling patch_sklearn. It should read:

import numpy as np
import sklearn

device = "cpu"
# device = "gpu:0"
from sklearnex import patch_sklearn, config_context
patch_sklearn()
from sklearn.neighbors import NearestNeighbors

seed = 123
rng = np.random.default_rng(seed)

n_samples = 10_000_000
dim = 100
n_queries = 10_000
k = 100

data = rng.random((n_samples, dim), dtype=np.float32)
query = rng.random((n_queries, dim), dtype=np.float32)

with config_context(target_offload=f"{device}"):
    knn = NearestNeighbors(n_neighbors=k, algorithm="brute")
    knn.fit(data)
    %time knn.kneighbors(X=query)
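Why the order matters: from sklearn.neighbors import NearestNeighbors binds whatever class object the module holds at that moment. Assuming patch_sklearn() works by replacing attributes on scikit-learn's modules (an assumption, not checked against the sklearnex source), a name imported beforehand keeps pointing at the stock estimator. A self-contained illustration with a stand-in module; every name here is hypothetical and none of it is sklearnex code:

```python
import sys
import types


class StockNearestNeighbors:  # hypothetical stand-in for the stock class
    pass


class PatchedNearestNeighbors:  # hypothetical stand-in for the replacement
    pass


# Hypothetical stand-in for sklearn.neighbors, registered in sys.modules so
# that a regular `from ... import ...` statement can find it.
fake_neighbors = types.ModuleType("fake_neighbors")
fake_neighbors.NearestNeighbors = StockNearestNeighbors
sys.modules["fake_neighbors"] = fake_neighbors


def fake_patch():
    """Hypothetical patcher: rebinds the module attribute in place."""
    fake_neighbors.NearestNeighbors = PatchedNearestNeighbors


# Wrong order (initial snippet): the name is bound before patching.
from fake_neighbors import NearestNeighbors as ImportedTooEarly

fake_patch()

# Right order (corrected snippet): the import happens after patching.
from fake_neighbors import NearestNeighbors as ImportedAfterPatch

print(ImportedTooEarly.__name__)    # StockNearestNeighbors
print(ImportedAfterPatch.__name__)  # PatchedNearestNeighbors
```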

With the corrected import order, the wall time on CPU improves significantly:

CPU times: user 6min 21s, sys: 4.6 s, total: 6min 26s
Wall time: 3.53 s

(NB: the CPU it runs on provides 254 cores, which is a lot; users usually have easier access to mid-range GPUs than to workstation CPUs with 64+ cores.)

But there is still no luck running it on GPU; now I get the following error:

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
INFO:sklearnex: sklearn.utils.validation._assert_all_finite: running accelerated version on CPU
INFO:sklearnex: sklearn.neighbors.NearestNeighbors.fit: running accelerated version on CPU
INFO:sklearnex: sklearn.utils.validation._assert_all_finite: running accelerated version on CPU
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[1], line 27
     25 with config_context(target_offload=f"{device}"):
     26     knn = NearestNeighbors(n_neighbors=k, algorithm="brute")
---> 27     knn.fit(data)
     28     get_ipython().run_line_magic('time', 'knn.kneighbors(X=query)')

File ~/mambaforge/envs/sklex/lib/python3.10/site-packages/sklearnex/neighbors/knn_unsupervised.py:91, in NearestNeighbors.fit(self, X, y)
     89 def fit(self, X, y=None):
     90     self._fit_validation(X, y)
---> 91     dispatch(self, 'fit', {
     92         'onedal': self.__class__._onedal_fit,
     93         'sklearn': sklearn_NearestNeighbors.fit,
     94     }, X, None)
     95     return self

File ~/mambaforge/envs/sklex/lib/python3.10/site-packages/sklearnex/_device_offload.py:161, in dispatch(obj, method_name, branches, *args, **kwargs)
    158 backend, q, cpu_fallback = _get_backend(obj, q, method_name, *hostargs)
    160 if backend == 'onedal':
--> 161     return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
    162 if backend == 'sklearn':
    163     return branches[backend](obj, *hostargs, **hostkwargs)

File ~/mambaforge/envs/sklex/lib/python3.10/site-packages/sklearnex/neighbors/knn_unsupervised.py:144, in NearestNeighbors._onedal_fit(self, X, y, queue)
    142 self._onedal_estimator.effective_metric_ = self.effective_metric_
    143 self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_
--> 144 self._onedal_estimator.fit(X, y, queue=queue)
    146 self._save_attributes()

File ~/mambaforge/envs/sklex/lib/python3.10/site-packages/onedal/neighbors/neighbors.py:722, in NearestNeighbors.fit(self, X, y, queue)
    721 def fit(self, X, y, queue=None):
--> 722     return super()._fit(X, y, queue=queue)

File ~/mambaforge/envs/sklex/lib/python3.10/site-packages/onedal/neighbors/neighbors.py:248, in NeighborsBase._fit(self, X, y, queue)
    246 if _is_classifier(self) or (_is_regressor(self) and gpu_device):
    247     _fit_y = self._validate_targets(self._y, X.dtype).reshape((-1, 1))
--> 248 result = self._onedal_fit(X, _fit_y, queue)
    250 if y is not None and _is_regressor(self):
    251     self._y = y if self._shape is None else y.reshape(self._shape)

File ~/mambaforge/envs/sklex/lib/python3.10/site-packages/onedal/neighbors/neighbors.py:690, in NearestNeighbors._onedal_fit(self, X, y, queue)
    686         train_alg = kdtree_knn_classification_training
    688     return train_alg(**params).compute(X, y).model
--> 690 policy = self._get_policy(queue, X, y)
    691 X, y = _convert_to_supported(policy, X, y)
    692 params = self._get_onedal_params(X, y)

File ~/mambaforge/envs/sklex/lib/python3.10/site-packages/onedal/neighbors/neighbors.py:48, in NeighborsCommonBase._get_policy(self, queue, *data)
     47 def _get_policy(self, queue, *data):
---> 48     return _get_policy(queue, *data)

File ~/mambaforge/envs/sklex/lib/python3.10/site-packages/onedal/common/_policy.py:33, in _get_policy(queue, *data)
     31         return _DataParallelInteropPolicy(data_queue)
     32     return _DataParallelInteropPolicy(queue)
---> 33 assert data_queue is None and queue is None
     34 return _HostInteropPolicy()

AssertionError: 

I also tried converting the data to on-device usm_ndarray beforehand:

import numpy as np
import sklearn
import dpctl.tensor as dpt

# device = "cpu"
device = "gpu"
from sklearnex import patch_sklearn, config_context
patch_sklearn()
from sklearn.neighbors import NearestNeighbors

seed = 123
rng = np.random.default_rng(seed)

n_samples = 10_000_000
dim = 100
n_queries = 10_000
k = 100

data = rng.random((n_samples, dim), dtype=np.float32)
query = rng.random((n_queries, dim), dtype=np.float32)

data = dpt.asarray(data)
query = dpt.asarray(query)

with config_context(target_offload=f"{device}"):
    knn = NearestNeighbors(n_neighbors=k, algorithm="brute")
    knn.fit(data)
    %time knn.kneighbors(X=query)

but then the computation just hangs and outputs nothing.

fcharras changed the title from "Same KNN performance on CPU and GPU ?" to "KNN on GPU errors out with AssertionError" on Sep 6, 2023

fcharras commented Sep 6, 2023

It turns out I had a version mismatch in the conda dependency tree when I did not install everything from the -c intel channel. Fixing it does not change the performance I get on CPU:

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
CPU times: user 6min 19s, sys: 4.03 s, total: 6min 23s
Wall time: 3.5 s

and here is the result on the GPU Max Series:

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)
CPU times: user 10.4 s, sys: 4.01 s, total: 14.4 s
Wall time: 14.5 s

This time it seems to work and to be properly dispatched to the GPU. There is roughly a 5x slowdown compared to the cuML backend on an NVIDIA A100 (see the report in the OP). The performance ceiling one can reach on Intel Max Series is unknown, but the gap still feels larger than it should be, judging by the respective GPU specs.
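To put the gap in numbers, a back-of-the-envelope sketch: counting only the query × data matrix product that brute-force kNN can be expressed as (2 · n_queries · n_samples · dim flops, ignoring the top-k selection and any data transfers), the wall times reported in this thread translate into the effective throughputs below. This is a rough estimate, not a profiling result:

```python
n_samples, dim, n_queries = 10_000_000, 100, 10_000

# One multiply-add (2 flops) per (query, sample, feature) triple in the
# query @ data.T product; the top-k selection cost is ignored.
flops = 2 * n_queries * n_samples * dim

wall_times = {  # seconds, as reported in this thread
    "cuML on A100": 2.73,
    "sklearnex on Max Series GPU": 14.5,
    "sklearnex on CPU (254 cores)": 3.5,
}
for name, seconds in wall_times.items():
    print(f"{name}: {flops / seconds / 1e12:.1f} TFLOP/s effective")

slowdown = wall_times["sklearnex on Max Series GPU"] / wall_times["cuML on A100"]
print(f"Max Series slowdown vs cuML on A100: {slowdown:.1f}x")
# cuML on A100: 7.3 TFLOP/s effective
# sklearnex on Max Series GPU: 1.4 TFLOP/s effective
# sklearnex on CPU (254 cores): 5.7 TFLOP/s effective
# Max Series slowdown vs cuML on A100: 5.3x
```

Note also that the GPU run now reports 14.4 s of total CPU time for 14.5 s of wall time, i.e. about one busy core, so the host is no longer saturated as it was in the OP.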

fcharras changed the title from "KNN on GPU errors out with AssertionError" to "KNN on max series seems slower than cuda-based implementation on comparable devices ?" on Sep 6, 2023
samir-nasibli (Contributor):

@fcharras thank you for the report. Let me reproduce and investigate the issue.

ethanglaser (Contributor):

Hi @fcharras, thank you for providing these results. We have reproduced the experiments and will create an internal feature request to identify ways to speed up this computation for more comparable results.

ethanglaser added the enhancement (New feature or request) label and removed the bug (Something isn't working) label on Feb 8, 2024