Skip to content

Commit

Permalink
add log1p wraper
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 committed Dec 17, 2024
1 parent 14268c5 commit 0704cac
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/rapids_singlecell/preprocessing/_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import cupy as cp
from cupyx.scipy import sparse
from cupyx.scipy.sparse import csr_matrix
from scanpy.get import _get_obs_rep, _set_obs_rep

from rapids_singlecell._compat import (
Expand All @@ -18,7 +19,7 @@

if TYPE_CHECKING:
from anndata import AnnData
from cupyx.scipy.sparse import csr_matrix, spmatrix
from cupyx.scipy.sparse import spmatrix

from rapids_singlecell._utils import ArrayTypesDask

Expand Down Expand Up @@ -211,6 +212,19 @@ def __sum(X_part):
return target_sum


def _calc_log1p(X: ArrayTypesDask) -> ArrayTypesDask:
if isinstance(X, DaskArray):
meta = _meta_sparse if isinstance(X._meta, csr_matrix) else _meta_dense
X = X.map_blocks(_calc_log1p, meta=meta(X.dtype))
else:
X = X.copy()
if sparse.issparse(X):
X = X.log1p()
else:
X = cp.log1p(X)
return X


def log1p(
adata: AnnData,
*,
Expand Down Expand Up @@ -253,8 +267,8 @@ def log1p(

if not inplace:
X = X.copy()

if isinstance(X, cp.ndarray):
"""
if isinstance(X, cp.ndarray):
X = cp.log1p(X)
elif sparse.issparse(X):
X = X.log1p()
Expand All @@ -263,6 +277,8 @@ def log1p(
X = X.map_blocks(lambda x: cp.log1p(x), meta=_meta_dense(X.dtype))
elif isinstance(X._meta, sparse.csr_matrix):
X = X.map_blocks(lambda x: x.log1p(), meta=_meta_sparse(X.dtype))
"""
X = _calc_log1p(X)
adata.uns["log1p"] = {"base": None}
if inplace:
_set_obs_rep(adata, X, layer=layer, obsm=obsm)
Expand Down

0 comments on commit 0704cac

Please sign in to comment.