Skip to content

Commit

Permalink
add update
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 committed Dec 16, 2024
1 parent 19e3602 commit e6f3c0d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/rapids_singlecell/preprocessing/_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _normalize_total(X: ArrayTypesDask, target_sum: int):
return _normalize_total_csr(X, target_sum)
elif isinstance(X, DaskArray):
return _normalize_total_dask(X, target_sum)
else:
elif isinstance(X, cp.ndarray):
from ._kernels._norm_kernel import _mul_dense

if not X.flags.c_contiguous:
Expand All @@ -100,6 +100,8 @@ def _normalize_total(X: ArrayTypesDask, target_sum: int):
(X, X.shape[0], X.shape[1], int(target_sum)),
)
return X
else:
raise ValueError(f"Cannot normalize {type(X)}")


def _normalize_total_csr(X: sparse.csr_matrix, target_sum: int) -> sparse.csr_matrix:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file will be removed in Q3 2025 when in favor of the CUML implementation

from __future__ import annotations

import math
Expand Down
6 changes: 6 additions & 0 deletions src/rapids_singlecell/preprocessing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def _get_mean_var(X, axis=0):
major = X.shape[1]
minor = X.shape[0]
mean, var = _mean_var_minor(X, major, minor)
else:
raise ValueError("axis must be either 0 or 1")
elif isinstance(X, DaskArray):
if isspmatrix_csr(X._meta):
if axis == 0:
Expand All @@ -243,6 +245,10 @@ def _get_mean_var(X, axis=0):
mean, var = _mean_var_major_dask(X, major, minor)
elif isinstance(X._meta, cp.ndarray):
mean, var = _mean_var_dense_dask(X, axis)
else:
raise ValueError(
"Type not supported. Please provide a CuPy ndarray or a CuPy sparse matrix. Or a Dask array with a CuPy ndarray or a CuPy sparse matrix as meta."
)
else:
mean, var = _mean_var_dense(X, axis)
return mean, var
Expand Down

0 comments on commit e6f3c0d

Please sign in to comment.