diff --git a/src/rapids_singlecell/preprocessing/_kernels/_bbknn.py b/src/rapids_singlecell/preprocessing/_kernels/_bbknn.py index d6bb2b07..e8607efc 100644 --- a/src/rapids_singlecell/preprocessing/_kernels/_bbknn.py +++ b/src/rapids_singlecell/preprocessing/_kernels/_bbknn.py @@ -64,6 +64,7 @@ extern "C" __global__ void cut_smaller( const int *indptr, + const int * index, float *data, float* vals, int n_rows) { @@ -74,13 +75,15 @@ int start_idx = indptr[row_id]; int stop_idx = indptr[row_id+1]; - float cut = vals[row_id]; + float cut_row = vals[row_id]; for(int i = start_idx+threadIdx.x; i < stop_idx; i+= blockDim.x){ + float cut = max(vals[index[i]], cut_row); if(data[i] cp_sparse.csr_matrix: (cnts.data, cnts.indptr, cnts.shape[0], trim, vals_gpu), shared_mem=shared_mem_size, ) - - for _ in range(2): - cut_smaller_func( - (cnts.shape[0],), - (64,), - (cnts.indptr, cnts.data, vals_gpu, cnts.shape[0]), - ) - cnts.eliminate_zeros() - cnts = cnts.T.tocsr() + cut_smaller_func( + (cnts.shape[0],), + (64,), + (cnts.indptr, cnts.indices, cnts.data, vals_gpu, cnts.shape[0]), + ) + cnts.eliminate_zeros() return cnts