# Post-change state of the two hunks touching
# src/rapids_singlecell/preprocessing/_neighbors.py (module preamble and the
# three approximate-nearest-neighbour helpers).
#
# NOTE(review): this is a fragment. `math`, `TYPE_CHECKING` and `_Metrics` are
# used below but are defined in parts of the module outside the visible hunks,
# as is everything between the preamble and `_cagra_knn`.

import cupy as cp
import numpy as np
import pylibraft
from cuml.manifold.simpl_set import fuzzy_simplicial_set
from cupyx.scipy import sparse as cp_sparse
from packaging.version import parse as parse_version
from pylibraft.common import DeviceResources
from scipy import sparse as sc_sparse

from rapids_singlecell.tools._utils import _choose_representation

# True when the installed pylibraft is newer than 24.10. The helpers below then
# import cagra / ivf_flat / ivf_pq from `cuvs.neighbors` instead of
# `pylibraft.neighbors` and no longer pass a `handle` to build/search.
CUVS_SWITCH = parse_version(pylibraft.__version__) > parse_version("24.10")

if TYPE_CHECKING:
    from collections.abc import Mapping


def _cagra_knn(
    X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
    """Find the `k` nearest rows of `X` for every row of `Y` using CAGRA.

    The index is always built with ``metric="sqeuclidean"`` and
    ``build_algo="nn_descent"``; when `metric` is ``"euclidean"`` the square
    root of the returned distances is taken. Any other `metric` value is
    returned as squared Euclidean distance.

    Parameters
    ----------
    X
        Data the index is built on.
    Y
        Query rows; searched in chunks of 65000 rows.
    k
        Number of neighbours per query row.
    metric
        Only compared against ``"euclidean"`` (see above).
    metric_kwds
        Unused; kept for a signature parallel to the other kNN helpers.

    Returns
    -------
    ``(neighbors, distances)``: an ``int32`` and a ``float32`` cupy array,
    both of shape ``(Y.shape[0], k)``.

    Raises
    ------
    ImportError
        If the pylibraft code path is selected and
        ``pylibraft.neighbors.cagra`` cannot be imported.
    """
    if not CUVS_SWITCH:
        try:
            from pylibraft.neighbors import cagra
        except ImportError as err:
            # Chain the original error so the underlying cause stays visible.
            raise ImportError(
                "The 'cagra' module is not available in your current RAFT installation. "
                "Please update RAFT to a version that supports 'cagra'."
            ) from err
        resources = DeviceResources()
        build_kwargs = {"handle": resources}  # pylibraft uses 'handle'
        search_kwargs = {"handle": resources}
    else:
        from cuvs.neighbors import cagra

        resources = None
        build_kwargs = {}  # cuvs is called without handle/resources
        search_kwargs = {}

    build_params = cagra.IndexParams(metric="sqeuclidean", build_algo="nn_descent")
    index = cagra.build(build_params, X, **build_kwargs)

    n_samples = Y.shape[0]
    all_neighbors = cp.zeros((n_samples, k), dtype=cp.int32)
    all_distances = cp.zeros((n_samples, k), dtype=cp.float32)

    batchsize = 65000
    # The search parameters are identical for every batch, so build them once.
    search_params = cagra.SearchParams()
    for start_idx in range(0, n_samples, batchsize):
        stop_idx = min(start_idx + batchsize, n_samples)
        batch_Y = Y[start_idx:stop_idx, :]

        distances, neighbors = cagra.search(
            search_params, index, batch_Y, k, **search_kwargs
        )
        all_neighbors[start_idx:stop_idx, :] = cp.asarray(neighbors)
        all_distances[start_idx:stop_idx, :] = cp.asarray(distances)

    # Only the pylibraft path owns a DeviceResources handle to synchronise.
    if resources is not None:
        resources.sync()

    if metric == "euclidean":
        all_distances = cp.sqrt(all_distances)

    return all_neighbors, all_distances


def _ivf_flat_knn(
    X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
    """Find the `k` nearest rows of `X` for every row of `Y` using IVF-Flat.

    The index is built on `X` with ``n_lists = int(sqrt(X.shape[0]))`` and
    `metric` passed straight to ``ivf_flat.IndexParams``; `Y` is searched in a
    single call with default ``SearchParams``.

    Parameters
    ----------
    X
        Data the index is built on.
    Y
        Query rows.
    k
        Number of neighbours per query row.
    metric
        Forwarded unchanged to ``ivf_flat.IndexParams``.
    metric_kwds
        Unused; kept for a signature parallel to the other kNN helpers.

    Returns
    -------
    ``(neighbors, distances)`` as cupy arrays.
    """
    if not CUVS_SWITCH:
        from pylibraft.neighbors import ivf_flat

        resources = DeviceResources()
        build_kwargs = {"handle": resources}  # pylibraft uses 'handle'
        search_kwargs = {"handle": resources}
    else:
        from cuvs.neighbors import ivf_flat

        resources = None
        build_kwargs = {}  # cuvs does not need handle/resources
        search_kwargs = {}

    n_lists = int(math.sqrt(X.shape[0]))
    index_params = ivf_flat.IndexParams(n_lists=n_lists, metric=metric)
    index = ivf_flat.build(index_params, X, **build_kwargs)
    distances, neighbors = ivf_flat.search(
        ivf_flat.SearchParams(), index, Y, k, **search_kwargs
    )

    # Only the pylibraft path owns a DeviceResources handle to synchronise.
    if resources is not None:
        resources.sync()

    distances = cp.asarray(distances)
    neighbors = cp.asarray(neighbors)

    return neighbors, distances


def _ivf_pq_knn(
    X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
    """Find the `k` nearest rows of `X` for every row of `Y` using IVF-PQ.

    The index is built on `X` with ``n_lists = int(sqrt(X.shape[0]))`` and
    `metric` passed straight to ``ivf_pq.IndexParams``; `Y` is searched in a
    single call with default ``SearchParams``.

    Parameters
    ----------
    X
        Data the index is built on.
    Y
        Query rows.
    k
        Number of neighbours per query row.
    metric
        Forwarded unchanged to ``ivf_pq.IndexParams``.
    metric_kwds
        Unused; kept for a signature parallel to the other kNN helpers.

    Returns
    -------
    ``(neighbors, distances)`` as cupy arrays.
    """
    if not CUVS_SWITCH:
        from pylibraft.neighbors import ivf_pq

        resources = DeviceResources()
        build_kwargs = {"handle": resources}  # pylibraft uses 'handle'
        search_kwargs = {"handle": resources}
    else:
        from cuvs.neighbors import ivf_pq

        resources = None
        build_kwargs = {}  # cuvs does not need handle/resources
        search_kwargs = {}

    n_lists = int(math.sqrt(X.shape[0]))
    index_params = ivf_pq.IndexParams(n_lists=n_lists, metric=metric)
    index = ivf_pq.build(index_params, X, **build_kwargs)
    distances, neighbors = ivf_pq.search(
        ivf_pq.SearchParams(), index, Y, k, **search_kwargs
    )

    # Only the pylibraft path owns a DeviceResources handle to synchronise.
    if resources is not None:
        resources.sync()

    distances = cp.asarray(distances)
    neighbors = cp.asarray(neighbors)

    return neighbors, distances