From 64551aaf27d79741aede841d40029a8264d32c4a Mon Sep 17 00:00:00 2001 From: Intron7 Date: Tue, 17 Dec 2024 17:35:32 +0100 Subject: [PATCH] fix updating var with hvg multibatch --- src/rapids_singlecell/preprocessing/_hvg.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/rapids_singlecell/preprocessing/_hvg.py b/src/rapids_singlecell/preprocessing/_hvg.py index 54d6785d..87a6bcd7 100644 --- a/src/rapids_singlecell/preprocessing/_hvg.py +++ b/src/rapids_singlecell/preprocessing/_hvg.py @@ -14,7 +14,7 @@ from rapids_singlecell._compat import DaskArray, _meta_dense, _meta_sparse -from ._qc import calculate_qc_metrics +from ._qc import _basic_qc from ._utils import _check_gpu_X, _check_nonnegative_integers, _get_mean_var if TYPE_CHECKING: @@ -434,8 +434,10 @@ def _highly_variable_genes_batched( for batch in batches: adata_subset = adata[adata.obs[batch_key] == batch] - calculate_qc_metrics(adata_subset, layer=layer) - filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0 + X = _get_obs_rep(adata_subset, layer=layer) + _check_gpu_X(X, allow_dask=True) + _, _, _, n_cells_per_gene = _basic_qc(X=X) + filt = (n_cells_per_gene > 0).get() adata_subset = adata_subset[:, filt] hvg = _highly_variable_genes_single_batch(