Skip to content

Commit

Permalink
Merge branch 'main' into test-uv-docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 authored Dec 6, 2024
2 parents 8b2c816 + 6962073 commit e02e25a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 32 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@

autosummary_generate = True
autodoc_member_order = "bysource"
autodoc_mock_imports = ["cudf", "cuml", "cugraph", "cupy", "cupyx", "pylibraft"]
autodoc_mock_imports = ["cudf", "cuml", "cugraph", "cupy", "cupyx", "pylibraft", "cuvs"]
default_role = "literal"
napoleon_google_docstring = False
napoleon_numpy_docstring = True
Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/0.10.12.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

```{rubric} Features
```
* use `cuvs` over `raft` for `pp.neighbors` for `rapids>=24.12`{pr}`304` {smaller}`S Dicks`
```{rubric} Performance
```
```{rubric} Bug fixes
Expand Down
101 changes: 70 additions & 31 deletions src/rapids_singlecell/preprocessing/_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import cupy as cp
import numpy as np
import pylibraft
from cuml.manifold.simpl_set import fuzzy_simplicial_set
from cupyx.scipy import sparse as cp_sparse
from packaging.version import parse as parse_version
from pylibraft.common import DeviceResources
from scipy import sparse as sc_sparse

Expand Down Expand Up @@ -59,6 +61,10 @@
_Metrics = _MetricsDense | _MetricsSparse


def _cuvs_switch():
return parse_version(pylibraft.__version__) > parse_version("24.10")


def _brute_knn(
X: cp_sparse.spmatrix | cp.ndarray,
Y: cp_sparse.spmatrix | cp.ndarray,
Expand All @@ -83,81 +89,114 @@ def _brute_knn(
def _cagra_knn(
X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
try:
from pylibraft.neighbors import cagra
except ImportError:
raise ImportError(
"The 'cagra' module is not available in your current RAFT installation. "
"Please update RAFT to a version that supports 'cagra'."
)
if not _cuvs_switch():
try:
from pylibraft.neighbors import cagra
except ImportError:
raise ImportError(
"The 'cagra' module is not available in your current RAFT installation. "
"Please update RAFT to a version that supports 'cagra'."
)
resources = DeviceResources()
build_kwargs = {"handle": resources}
search_kwargs = {"handle": resources}
else:
from cuvs.neighbors import cagra

resources = None
build_kwargs = {}
search_kwargs = {}

handle = DeviceResources()
build_params = cagra.IndexParams(metric="sqeuclidean", build_algo="nn_descent")
index = cagra.build(build_params, X, handle=handle)
index = cagra.build(build_params, X, **build_kwargs)

n_samples = Y.shape[0]
all_neighbors = cp.zeros((n_samples, k), dtype=cp.int32)
all_distances = cp.zeros((n_samples, k), dtype=cp.float32)

batchsize = 65000
n_batches = math.ceil(Y.shape[0] / batchsize)
n_batches = math.ceil(n_samples / batchsize)
for batch in range(n_batches):
start_idx = batch * batchsize
stop_idx = min(batch * batchsize + batchsize, Y.shape[0])
stop_idx = min((batch + 1) * batchsize, n_samples)
batch_Y = Y[start_idx:stop_idx, :]

search_params = cagra.SearchParams()
distances, neighbors = cagra.search(
search_params, index, batch_Y, k, handle=handle
search_params, index, batch_Y, k, **search_kwargs
)
all_neighbors[start_idx:stop_idx, :] = cp.asarray(neighbors)
all_distances[start_idx:stop_idx, :] = cp.asarray(distances)
handle.sync()

if resources is not None:
resources.sync()

if metric == "euclidean":
all_distances = cp.sqrt(all_distances)

return all_neighbors, all_distances


def _ivf_flat_knn(
X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
from pylibraft.neighbors import ivf_flat
if not _cuvs_switch():
from pylibraft.neighbors import ivf_flat

handle = DeviceResources()
if X.shape[0] < 2048:
n_lists = X.shape[0] // 2
resources = DeviceResources()
build_kwargs = {"handle": resources} # pylibraft uses 'handle'
search_kwargs = {"handle": resources}
else:
n_lists = 1024
index_params = ivf_flat.IndexParams(n_lists=n_lists, metric=metric)
index = ivf_flat.build(index_params, X, handle=handle)
from cuvs.neighbors import ivf_flat

resources = None
build_kwargs = {} # cuvs does not need handle/resources
search_kwargs = {}

n_lists = int(math.sqrt(X.shape[0]))
index_params = ivf_flat.IndexParams(n_lists=n_lists, metric=metric)
index = ivf_flat.build(index_params, X, **build_kwargs)
distances, neighbors = ivf_flat.search(
ivf_flat.SearchParams(), index, Y, k, handle=handle
ivf_flat.SearchParams(), index, Y, k, **search_kwargs
)

if resources is not None:
resources.sync()

distances = cp.asarray(distances)
neighbors = cp.asarray(neighbors)
handle.sync()

return neighbors, distances


def _ivf_pq_knn(
X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping
) -> tuple[cp.ndarray, cp.ndarray]:
from pylibraft.neighbors import ivf_pq
if not _cuvs_switch():
from pylibraft.neighbors import ivf_pq

handle = DeviceResources()
if X.shape[0] < 2048:
n_lists = X.shape[0] // 2
resources = DeviceResources()
build_kwargs = {"handle": resources}
search_kwargs = {"handle": resources}
else:
n_lists = 1024
index_params = ivf_pq.IndexParams(n_lists=n_lists, metric=metric)
index = ivf_pq.build(index_params, X, handle=handle)
from cuvs.neighbors import ivf_pq

resources = None
build_kwargs = {}
search_kwargs = {}

n_lists = int(math.sqrt(X.shape[0]))
index_params = ivf_pq.IndexParams(n_lists=n_lists, metric=metric)
index = ivf_pq.build(index_params, X, **build_kwargs)
distances, neighbors = ivf_pq.search(
ivf_pq.SearchParams(), index, Y, k, handle=handle
ivf_pq.SearchParams(), index, Y, k, **search_kwargs
)
if resources is not None:
resources.sync()

distances = cp.asarray(distances)
neighbors = cp.asarray(neighbors)
handle.sync()

return neighbors, distances


Expand Down

0 comments on commit e02e25a

Please sign in to comment.