diff --git a/src/rapids_singlecell/tools/_umap.py b/src/rapids_singlecell/tools/_umap.py index e075dc8c..7b2dffbf 100644 --- a/src/rapids_singlecell/tools/_umap.py +++ b/src/rapids_singlecell/tools/_umap.py @@ -2,8 +2,13 @@ from typing import TYPE_CHECKING, Literal +import cuml +import cupy as cp +from cuml.manifold.simpl_set import simplicial_set_embedding from cuml.manifold.umap import UMAP from cuml.manifold.umap_utils import find_ab_params +from cupyx.scipy import sparse +from packaging.version import parse as parse_version from scanpy._utils import NeighborsView from sklearn.utils import check_random_state @@ -149,37 +154,61 @@ def umap( ) # 0 is not a valid value for rapids, unlike original umap n_obs = adata.shape[0] - n_neighbors = neigh_params["n_neighbors"] - if neigh_params.get("method") == "rapids": - knn_dist = neighbors["distances"].data.reshape(n_obs, n_neighbors) - knn_indices = neighbors["distances"].indices.reshape(n_obs, n_neighbors) - pre_knn = (knn_indices, knn_dist) + if parse_version(cuml.__version__) < parse_version("24.10"): + n_neighbors = neigh_params["n_neighbors"] + if neigh_params.get("method") == "rapids": + knn_dist = neighbors["distances"].data.reshape(n_obs, n_neighbors) + knn_indices = neighbors["distances"].indices.reshape(n_obs, n_neighbors) + pre_knn = (knn_indices, knn_dist) + else: + pre_knn = None + + if init_pos == "auto": + init_pos = "spectral" if n_obs < 1000000 else "random" + + umap = UMAP( + n_neighbors=n_neighbors, + n_components=n_components, + metric=neigh_params.get("metric", "euclidean"), + metric_kwds=neigh_params.get("metric_kwds", None), + n_epochs=n_epochs, + learning_rate=alpha, + init=init_pos, + min_dist=min_dist, + spread=spread, + negative_sample_rate=negative_sample_rate, + a=a, + b=b, + random_state=random_state, + output_type="numpy", + precomputed_knn=pre_knn, + ) + + X_umap = umap.fit_transform(X) else: - pre_knn = None - - if init_pos == "auto": - init_pos = "spectral" if n_obs < 1000000 else "random" - - umap = UMAP( - n_neighbors=n_neighbors, - n_components=n_components, - metric=neigh_params.get("metric", "euclidean"), - metric_kwds=neigh_params.get("metric_kwds", None), - n_epochs=n_epochs, - learning_rate=alpha, - init=init_pos, - min_dist=min_dist, - spread=spread, - negative_sample_rate=negative_sample_rate, - a=a, - b=b, - random_state=random_state, - output_type="numpy", - precomputed_knn=pre_knn, - ) + pre_knn = neighbors["connectivities"] + + if init_pos == "auto": + init_pos = "spectral" if n_obs < 1000000 else "random" + + X_umap = simplicial_set_embedding( + data=cp.array(X), + graph=sparse.coo_matrix(pre_knn), + n_components=n_components, + initial_alpha=alpha, + a=a, + b=b, + negative_sample_rate=negative_sample_rate, + n_epochs=n_epochs, + init=init_pos, + random_state=random_state, + metric=neigh_params.get("metric", "euclidean"), + metric_kwds=neigh_params.get("metric_kwds", None), + ) + X_umap = cp.asarray(X_umap).get() key_obsm, key_uns = ("X_umap", "umap") if key_added is None else [key_added] * 2 - adata.obsm[key_obsm] = umap.fit_transform(X) + adata.obsm[key_obsm] = X_umap adata.uns[key_uns] = {"params": stored_params} return adata if copy else None