diff --git a/src/rapids_singlecell/preprocessing/_normalize.py b/src/rapids_singlecell/preprocessing/_normalize.py index e4f92912..0cc7df53 100644 --- a/src/rapids_singlecell/preprocessing/_normalize.py +++ b/src/rapids_singlecell/preprocessing/_normalize.py @@ -6,6 +6,7 @@ import cupy as cp from cupyx.scipy import sparse +from cupyx.scipy.sparse import csr_matrix from scanpy.get import _get_obs_rep, _set_obs_rep from rapids_singlecell._compat import ( @@ -18,7 +19,7 @@ if TYPE_CHECKING: from anndata import AnnData - from cupyx.scipy.sparse import csr_matrix, spmatrix + from cupyx.scipy.sparse import spmatrix from rapids_singlecell._utils import ArrayTypesDask @@ -211,6 +212,19 @@ def __sum(X_part): return target_sum +def _calc_log1p(X: ArrayTypesDask) -> ArrayTypesDask: + if isinstance(X, DaskArray): + meta = _meta_sparse if isinstance(X._meta, csr_matrix) else _meta_dense + X = X.map_blocks(_calc_log1p, meta=meta(X.dtype)) + else: + X = X.copy() + if sparse.issparse(X): + X = X.log1p() + else: + X = cp.log1p(X) + return X + + def log1p( adata: AnnData, *, @@ -253,8 +267,8 @@ def log1p( if not inplace: X = X.copy() - - if isinstance(X, cp.ndarray): + """ + if isinstance(X, cp.ndarray): X = cp.log1p(X) elif sparse.issparse(X): X = X.log1p() @@ -263,6 +277,8 @@ def log1p( X = X.map_blocks(lambda x: cp.log1p(x), meta=_meta_dense(X.dtype)) elif isinstance(X._meta, sparse.csr_matrix): X = X.map_blocks(lambda x: x.log1p(), meta=_meta_sparse(X.dtype)) + """ + X = _calc_log1p(X) adata.uns["log1p"] = {"base": None} if inplace: _set_obs_rep(adata, X, layer=layer, obsm=obsm)