Skip to content

Commit

Permalink
remove dask.delayed
Browse files: browse the repository at this point in the history
  • Loading branch information
Intron7 committed Oct 16, 2024
1 parent 17ca2ef commit 06ce8e5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 63 deletions.
103 changes: 41 additions & 62 deletions src/rapids_singlecell/preprocessing/_qc.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -128,43 +128,33 @@ def _first_pass_qc(

block = (32,)
grid = (int(math.ceil(X.shape[0] / block[0])),)
sparse_qc_csr = _sparse_qc_csr(X.data.dtype)
sparse_qc_csr(
grid,
block,
(
X.indptr,
X.indices,
X.data,
sums_cells,
sums_genes,
cell_ex,
gene_ex,
X.shape[0],
),
)
call_shape = X.shape[0]
sparse_qc_kernel = _sparse_qc_csr(X.data.dtype)

elif sparse.isspmatrix_csc(X):
from ._kernels._qc_kernels import _sparse_qc_csc

block = (32,)
grid = (int(math.ceil(X.shape[1] / block[0])),)
sparse_qc_csc = _sparse_qc_csc(X.data.dtype)
sparse_qc_csc(
grid,
block,
(
X.indptr,
X.indices,
X.data,
sums_cells,
sums_genes,
cell_ex,
gene_ex,
X.shape[1],
),
)
call_shape = X.shape[1]
sparse_qc_kernel = _sparse_qc_csc(X.data.dtype)

else:
raise ValueError("Please use a csr or csc matrix")
sparse_qc_kernel(
grid,
block,
(
X.indptr,
X.indices,
X.data,
sums_cells,
sums_genes,
cell_ex,
gene_ex,
call_shape,
),
)
else:
from ._kernels._qc_kernels import _sparse_qc_dense

Expand All @@ -189,7 +179,6 @@ def _first_pass_qc_dask(
X: DaskArray,
) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray]:
import dask
import dask.array as da

if isinstance(X._meta, sparse.csr_matrix):
from ._kernels._qc_kernels_dask import (
Expand All @@ -200,7 +189,6 @@ def _first_pass_qc_dask(
sparse_qc_csr_cells = _sparse_qc_csr_dask_cells(X.dtype)
sparse_qc_csr_cells.compile()

@dask.delayed
def __qc_calc_1(X_part):
sums_cells = cp.zeros(X_part.shape[0], dtype=X_part.dtype)
cell_ex = cp.zeros(X_part.shape[0], dtype=cp.int32)
Expand All @@ -219,12 +207,11 @@ def __qc_calc_1(X_part):
X_part.shape[0],
),
)
return cp.vstack([sums_cells, cell_ex.astype(X_part.dtype)])
return cp.stack([sums_cells, cell_ex.astype(X_part.dtype)], axis=1)

sparse_qc_csr_genes = _sparse_qc_csr_dask_genes(X.dtype)
sparse_qc_csr_genes.compile()

@dask.delayed
def __qc_calc_2(X_part):
sums_genes = cp.zeros(X_part.shape[1], dtype=X_part.dtype)
gene_ex = cp.zeros(X_part.shape[1], dtype=cp.int32)
Expand All @@ -241,7 +228,7 @@ def __qc_calc_2(X_part):
X_part.nnz,
),
)
return cp.vstack([sums_genes, gene_ex.astype(X_part.dtype)])
return cp.vstack([sums_genes, gene_ex.astype(X_part.dtype)])[None, ...]

elif isinstance(X._meta, cp.ndarray):
from ._kernels._qc_kernels_dask import (
Expand All @@ -252,7 +239,6 @@ def __qc_calc_2(X_part):
sparse_qc_dense_cells = _sparse_qc_dense_cells(X.dtype)
sparse_qc_dense_cells.compile()

@dask.delayed
def __qc_calc_1(X_part):
sums_cells = cp.zeros(X_part.shape[0], dtype=X_part.dtype)
cell_ex = cp.zeros(X_part.shape[0], dtype=cp.int32)
Expand All @@ -274,12 +260,11 @@ def __qc_calc_1(X_part):
X_part.shape[1],
),
)
return cp.vstack([sums_cells, cell_ex.astype(X_part.dtype)])
return cp.stack([sums_cells, cell_ex.astype(X_part.dtype)], axis=1)

sparse_qc_dense_genes = _sparse_qc_dense_genes(X.dtype)
sparse_qc_dense_genes.compile()

@dask.delayed
def __qc_calc_2(X_part):
sums_genes = cp.zeros((X_part.shape[1]), dtype=X_part.dtype)
gene_ex = cp.zeros((X_part.shape[1]), dtype=cp.int32)
Expand All @@ -301,35 +286,29 @@ def __qc_calc_2(X_part):
X_part.shape[1],
),
)
return cp.vstack([sums_genes, gene_ex.astype(X_part.dtype)])
return cp.vstack([sums_genes, gene_ex.astype(X_part.dtype)])[None, ...]
else:
raise ValueError(
"Please use a cupy csr_matrix or cp.ndarray. csc_matrix are not supported with dask."
)

blocks = X.to_delayed().ravel()
cell_blocks = [
da.from_delayed(
__qc_calc_1(block),
shape=(2, X.chunks[0][ind]),
dtype=X.dtype,
meta=cp.array([]),
)
for ind, block in enumerate(blocks)
]

blocks = X.to_delayed().ravel()
gene_blocks = [
da.from_delayed(
__qc_calc_2(block),
shape=(2, X.shape[1]),
dtype=X.dtype,
meta=cp.array([]),
)
for ind, block in enumerate(blocks)
]
sums_cells, cell_ex = da.hstack(cell_blocks)
sums_genes, gene_ex = da.stack(gene_blocks, axis=1).sum(axis=1)
cell_results = X.map_blocks(
__qc_calc_1,
chunks=(X.chunks[0], (2,)),
dtype=X.dtype,
meta=cp.empty((0, 2), dtype=X.dtype),
)
sums_cells = cell_results[:, 0]
cell_ex = cell_results[:, 1]

n_blocks = X.blocks.size
sums_genes, gene_ex = X.map_blocks(
__qc_calc_2,
new_axis=(1,),
chunks=((1,) * n_blocks, (2,), (X.shape[1],)),
dtype=X.dtype,
meta=cp.array([]),
).sum(axis=0)

sums_cells, cell_ex, sums_genes, gene_ex = dask.compute(
sums_cells, cell_ex, sums_genes, gene_ex
Expand Down
2 changes: 1 addition & 1 deletion tests/dask/test_qc_dask.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,7 +14,7 @@


@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
def test_qc_metrics_sparse(client, data_kind):
def test_qc_metrics(client, data_kind):
adata = pbmc3k()
adata.var["mt"] = adata.var_names.str.startswith("MT-")
dask_data = adata.copy()
Expand Down

0 comments on commit 06ce8e5

Please sign in to comment.