Skip to content

Commit

Permalink
update to use cuvs for next rapids
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 committed Dec 6, 2024
1 parent 7d10feb commit 7e62240
Showing 1 changed file with 68 additions and 31 deletions.
99 changes: 68 additions & 31 deletions src/rapids_singlecell/preprocessing/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@

import cupy as cp
import numpy as np
import pylibraft
from cuml.manifold.simpl_set import fuzzy_simplicial_set
from cupyx.scipy import sparse as cp_sparse
from packaging.version import parse as parse_version
from pylibraft.common import DeviceResources
from scipy import sparse as sc_sparse

from rapids_singlecell.tools._utils import _choose_representation

# Backend selector for the approximate kNN helpers in this module: when True
# they import cagra / ivf_flat / ivf_pq from ``cuvs.neighbors``; when False
# they import them from ``pylibraft.neighbors`` and pass an explicit
# ``DeviceResources`` handle.
# NOTE(review): under PEP 440 ordering a patch release such as 24.10.1 also
# compares greater than "24.10" — confirm that switching to cuvs there is
# intended.
CUVS_SWITCH = parse_version(pylibraft.__version__) > parse_version("24.10")

if TYPE_CHECKING:
from collections.abc import Mapping

Expand Down Expand Up @@ -83,81 +87,114 @@ def _brute_knn(
def _cagra_knn(
    X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
    """Find the ``k`` nearest rows of ``X`` for each row of ``Y`` with CAGRA.

    The index is always built with the ``"sqeuclidean"`` metric; when
    ``metric == "euclidean"`` the square root of the returned distances is
    taken before returning. Queries are searched in batches of 65000 rows.

    Parameters
    ----------
    X
        Data the CAGRA index is built from.
    Y
        Query points; indexed as a 2D array (rows are queries).
    k
        Number of neighbors to return per query.
    metric
        Only ``"euclidean"`` changes behavior here (see above); any other
        value leaves the squared-euclidean distances untouched.
    metric_kwds
        Accepted for signature parity with the other kNN helpers; unused.

    Returns
    -------
    neighbors
        ``(Y.shape[0], k)`` int32 array of neighbor indices.
    distances
        ``(Y.shape[0], k)`` float32 array of distances.

    Raises
    ------
    ImportError
        If the pylibraft backend is selected and ``pylibraft.neighbors.cagra``
        cannot be imported.
    """
    if not CUVS_SWITCH:
        try:
            from pylibraft.neighbors import cagra
        except ImportError as err:
            # Chain the original error so the underlying import failure
            # remains visible in the traceback.
            raise ImportError(
                "The 'cagra' module is not available in your current RAFT installation. "
                "Please update RAFT to a version that supports 'cagra'."
            ) from err
        # pylibraft takes an explicit resources handle that must be synced
        # once all work has been queued.
        resources = DeviceResources()
        build_kwargs = {"handle": resources}
        search_kwargs = {"handle": resources}
    else:
        from cuvs.neighbors import cagra

        # cuvs is called without an explicit handle, so there is nothing to
        # sync afterwards.
        resources = None
        build_kwargs = {}
        search_kwargs = {}

    build_params = cagra.IndexParams(metric="sqeuclidean", build_algo="nn_descent")
    index = cagra.build(build_params, X, **build_kwargs)

    n_samples = Y.shape[0]
    all_neighbors = cp.zeros((n_samples, k), dtype=cp.int32)
    all_distances = cp.zeros((n_samples, k), dtype=cp.float32)

    # Default search parameters are the same for every batch; build them once.
    search_params = cagra.SearchParams()

    # NOTE(review): the batch size presumably keeps each search call under a
    # per-call query limit — confirm against the CAGRA documentation.
    batchsize = 65000
    n_batches = math.ceil(n_samples / batchsize)
    for batch in range(n_batches):
        start_idx = batch * batchsize
        stop_idx = min((batch + 1) * batchsize, n_samples)
        batch_Y = Y[start_idx:stop_idx, :]

        distances, neighbors = cagra.search(
            search_params, index, batch_Y, k, **search_kwargs
        )
        all_neighbors[start_idx:stop_idx, :] = cp.asarray(neighbors)
        all_distances[start_idx:stop_idx, :] = cp.asarray(distances)

    if resources is not None:
        resources.sync()

    if metric == "euclidean":
        # The index was built with squared distances; convert them back.
        all_distances = cp.sqrt(all_distances)

    return all_neighbors, all_distances


def _ivf_flat_knn(
    X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
    """Find the ``k`` nearest rows of ``X`` for each row of ``Y`` with IVF-Flat.

    The index is built with ``n_lists = int(sqrt(X.shape[0]))`` and searched
    with default ``SearchParams``.

    Parameters
    ----------
    X
        Data the IVF-Flat index is built from.
    Y
        Query points.
    k
        Number of neighbors to return per query.
    metric
        Distance metric forwarded to ``ivf_flat.IndexParams``.
    metric_kwds
        Accepted for signature parity with the other kNN helpers; unused.

    Returns
    -------
    neighbors
        Neighbor indices as a CuPy array.
    distances
        Matching distances as a CuPy array.
    """
    if not CUVS_SWITCH:
        from pylibraft.neighbors import ivf_flat

        resources = DeviceResources()
        build_kwargs = {"handle": resources}  # pylibraft uses 'handle'
        search_kwargs = {"handle": resources}
    else:
        from cuvs.neighbors import ivf_flat

        resources = None
        build_kwargs = {}  # cuvs does not need handle/resources
        search_kwargs = {}

    n_lists = int(math.sqrt(X.shape[0]))
    index_params = ivf_flat.IndexParams(n_lists=n_lists, metric=metric)
    index = ivf_flat.build(index_params, X, **build_kwargs)
    distances, neighbors = ivf_flat.search(
        ivf_flat.SearchParams(), index, Y, k, **search_kwargs
    )

    # Only the pylibraft path owns a handle that has to be synced.
    if resources is not None:
        resources.sync()

    distances = cp.asarray(distances)
    neighbors = cp.asarray(neighbors)

    return neighbors, distances


def _ivf_pq_knn(
    X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
    """Find the ``k`` nearest rows of ``X`` for each row of ``Y`` with IVF-PQ.

    The index is built with ``n_lists = int(sqrt(X.shape[0]))`` and searched
    with default ``SearchParams``.

    Parameters
    ----------
    X
        Data the IVF-PQ index is built from.
    Y
        Query points.
    k
        Number of neighbors to return per query.
    metric
        Distance metric forwarded to ``ivf_pq.IndexParams``.
    metric_kwds
        Accepted for signature parity with the other kNN helpers; unused.

    Returns
    -------
    neighbors
        Neighbor indices as a CuPy array.
    distances
        Matching distances as a CuPy array.
    """
    if not CUVS_SWITCH:
        from pylibraft.neighbors import ivf_pq

        # pylibraft takes an explicit resources handle.
        resources = DeviceResources()
        build_kwargs = {"handle": resources}
        search_kwargs = {"handle": resources}
    else:
        from cuvs.neighbors import ivf_pq

        # cuvs is called without an explicit handle.
        resources = None
        build_kwargs = {}
        search_kwargs = {}

    n_lists = int(math.sqrt(X.shape[0]))
    index_params = ivf_pq.IndexParams(n_lists=n_lists, metric=metric)
    index = ivf_pq.build(index_params, X, **build_kwargs)
    distances, neighbors = ivf_pq.search(
        ivf_pq.SearchParams(), index, Y, k, **search_kwargs
    )

    # Only the pylibraft path owns a handle that has to be synced.
    if resources is not None:
        resources.sync()

    distances = cp.asarray(distances)
    neighbors = cp.asarray(neighbors)

    return neighbors, distances


Expand Down

0 comments on commit 7e62240

Please sign in to comment.