Added Support for Reading Unfiltered Cell Ranger Feature Barcode Matrices (MEX Format) (#123)

* Added functionality for valid barcode filtering on CrDirReader.
_get_valid_barcodes() -> Returns a list of valid barcodes (indices into barcodes.tsv.gz) after filtering out background barcodes at a given threshold. Iteratively reads the matrix in chunks and performs the following operations:
1. Create a collection of 'batch_size' entries
2. Sum the count data per barcode and apply the threshold
3. Append the surviving barcodes to the valid barcodes
Returns the list of valid barcodes (see the sketch below).
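
A minimal sketch of the thresholding step (assuming a Polars frame of MTX triplets with 'barcode' and 'count' columns, mirroring process_batch() in the diff below):

```python
import polars as pl

# Toy chunk of MTX triplets; the gene column is omitted for brevity
chunk = pl.DataFrame(
    {
        "barcode": [1, 1, 2, 3, 3, 3],
        "count": [300, 250, 10, 200, 200, 200],
    }
)

# Sum counts per barcode and keep only barcodes above the cutoff
cutoff = 500
kept = (
    chunk.group_by("barcode")
    .agg(pl.sum("count"))
    .filter(pl.col("count") > cutoff)
)
print(sorted(kept["barcode"]))  # [1, 3] -- barcode 2 is background
```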

consume() -> Yields chunks of data from the MTX file. Iteratively reads the matrix in chunks and performs the following operations:
1. Keep adding chunks to the collection until the number of unique barcodes reaches 'batch_size'
2. Rename the collected barcodes (0 to batch_size-1)
3. Convert the data to a sparse matrix and yield it (a usage sketch follows below)
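
A usage sketch for the unfiltered path (the directory name is hypothetical; the constructor arguments match the diff below):

```python
from scarf.readers import CrDirReader

# Reading a raw (unfiltered) MEX directory: background barcodes are
# filtered out at load time using the count cutoff
reader = CrDirReader(
    "raw_feature_bc_matrix",  # hypothetical path to a MEX directory
    is_filtered=False,
    filtering_cutoff=500,
)
print(reader.nCells)  # number of barcodes that survived the cutoff

# consume() yields scipy COO chunks of at most batch_size barcodes
for mtx in reader.consume(batch_size=1000):
    print(mtx.shape)
```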

* Comments and cleanup

* Updated the consume function to use Polars

Updated the header reader method to use comment_prefix instead of skip_rows. Updated the other Polars readers as well. Added a test for empty barcodes, and a small dataset for the empty-barcode test.
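
For reference, read_header() parses the standard MatrixMarket size line that follows the '%' comment lines. For the empty-barcode test dataset it would look roughly like this (illustrative):

```
%%MatrixMarket matrix coordinate integer general
%
4 0 0
```

The three integers are nFeatures, nCells, and nCounts; a mismatch between the header's barcode count and the contents of the barcodes file now raises a ValueError.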
Gautam8387 authored Sep 30, 2024
1 parent 6ed28f7 commit dc08e06
Showing 6 changed files with 210 additions and 26 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -18,3 +18,4 @@ docs/jupyter_execute
.pytest_cache
interrogate_badge.svg
scarf/tests/datasets/*
test_data
1 change: 1 addition & 0 deletions requirements.txt
@@ -19,3 +19,4 @@ loguru
setuptools
packaging
importlib_metadata
polars
201 changes: 175 additions & 26 deletions scarf/readers.py
@@ -13,10 +13,11 @@
from abc import ABC, abstractmethod
from typing import Generator, Dict, List, Optional, Tuple
from typing import IO

import math
import h5py
import numpy as np
import pandas as pd
import polars as pl
from scipy.sparse import coo_matrix

from .utils import logger, tqdmbar
@@ -245,7 +246,7 @@ def _get_valid_barcodes(
indptr = self.grp["indptr"][:]
for s in tqdmbar(
range(0, len(indptr) - 1, batch_size),
desc=f"Filtering out background barcodes",
desc=f"Filtering out background barcodes", # noqa: F541
):
idx = indptr[s : s + batch_size + 1]
data = self.grp["data"][idx[0] : idx[-1]]
@@ -315,12 +316,21 @@ def __init__(
loc,
mtx_separator: str = " ",
index_offset: int = -1,
is_filtered: bool = True,
filtering_cutoff: int = 500,
):
self.loc: str = loc.rstrip("/") + "/"
self.matFn = None
self.sep = mtx_separator
self.indexOffset = index_offset
self.validBarcodeIdx = None
super().__init__(self._handle_version())
if is_filtered:
self.validBarcodeIdx = np.array(range(self.nCells))
self.validBarcodeIdx -= self.indexOffset
else:
self.validBarcodeIdx = self._get_valid_barcodes(filtering_cutoff)
self.nCells = len(self.validBarcodeIdx)

def _handle_version(self):
show_error = False
@@ -382,6 +392,97 @@ def _read_dataset(self, key: Optional[str] = None):
vals = None
return vals

def read_header(self) -> pl.DataFrame:
header = pl.read_csv(
self.matFn,
comment_prefix='%',
separator=self.sep,
has_header=False,
n_rows=1,
new_columns=["nFeatures", "nCells", "nCounts"],
)
if header['nCells'][0] == 0 and self.nCells > 0:
raise ValueError("ERROR: Barcode count in MTX header is 0 but barcodes are present in the barcodes file")
if header['nCells'][0] > 0 and self.nCells == 0:
raise ValueError("ERROR: Barcode count in MTX header is greater than 0 but no barcodes are present in the barcodes file")
if header['nCells'][0] == 0 and self.nCells == 0:
raise ValueError("ERROR: Barcode count in MTX header and barcodes file is 0. No data to read")
return header

def process_batch(self, dfs: pl.DataFrame, filtering_cutoff: int) -> np.ndarray:
"""Returns a sorted array of valid barcodes after filtering out background barcodes for a given batch.
Args:
dfs: A Polars DataFrame containing a chunk of data from the MTX file.
filtering_cutoff: The cutoff value for filtering out background barcodes.
"""
dfs_ = dfs.group_by('barcode').agg(pl.sum('count'))
dfs_ = dfs_.filter(pl.col('count') > filtering_cutoff)
return np.sort(dfs_['barcode'])
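# Illustrative example (not part of the commit): given the rows
# (barcode=1, count=300), (barcode=1, count=400), (barcode=2, count=100)
# and filtering_cutoff=500, barcode 1 sums to 700 and is kept while
# barcode 2 is dropped, so the returned array is [1].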

def _get_valid_barcodes(
self, filtering_cutoff: int,
batch_size: int = int(10e4),
lines_in_mem: int = int(10e6)
) -> np.ndarray:
"""Returns a list of valid barcodes after filtering out background barcodes.
Args:
filtering_cutoff: The cutoff value for filtering out background barcodes.
batch_size: The number of barcodes to process in each batch.
lines_in_mem: The number of lines to read into memory
"""
matrixIO = pl.scan_csv(
self.matFn,
comment_prefix='%',
# skip_rows=3,
skip_rows_after_header=1,
separator=self.sep,
has_header=False,
)
assert len(matrixIO.collect_schema().names()) == 3
matrixIO = matrixIO.rename({'column_1': 'gene', 'column_2': 'barcode', 'column_3': 'count'})
header = self.read_header()
nChunks = math.ceil(header["nCounts"][0] / lines_in_mem)
test_counter = 0
valid_idx = []
start = 1
dfs = pl.DataFrame()
for i in tqdmbar(
range(nChunks), desc="Filtering out background barcodes"
):
chunk = matrixIO.slice(i*lines_in_mem, lines_in_mem).collect()
# Check if we've reached or exceeded the current batch boundary
if (chunk[-1]['barcode'][0] - start) >= batch_size:  # the last barcode in this chunk is at or beyond the current batch boundary
# Filter rows in the current chunk that belong to the current batch
idx = np.array(chunk['barcode'] < (batch_size + start))
# Crucial step: a barcode whose rows are spread over multiple chunks
# is not missed, because every row with a barcode below
# start + batch_size is still included in the current batch.
# If no rows belong to the current batch, move to the next batch.
if idx.sum() == 0:
dfs = pl.concat([dfs, chunk])
start += batch_size
test_counter += len(chunk)
continue
# Process the rows belonging to the current batch
mask_pos = np.where(idx)[0]
mask_neg = np.where(~idx)[0]
dfs = pl.concat([dfs, chunk[mask_pos]])
valid_idx.append(self.process_batch(dfs, filtering_cutoff))
# Prepare for the next batch
del dfs
dfs = chunk[mask_neg]
start += batch_size
test_counter += len(chunk)  # count these rows so the final sanity check holds
else:
# If we haven't reached the batch boundary, accumulate the chunk
dfs = pl.concat([dfs, chunk])
test_counter += len(chunk)
# Process any remaining data after the main loop
if len(dfs) > 0:
valid_idx.append(self.process_batch(dfs, filtering_cutoff))
# Verify that all rows were processed
assert test_counter == header["nCounts"][0]
return np.sort(np.unique(np.hstack(valid_idx)))

def to_sparse(self, a: np.ndarray, dtype) -> coo_matrix:
"""Returns the input data as a sparse (COO) matrix.
@@ -402,28 +503,77 @@ def to_sparse(self, a: np.ndarray, dtype) -> coo_matrix:
dtype=dtype,
)

def cell_names(self) -> List[str]:
"""Returns a list of names of the cells in the dataset."""
vals = np.array(self._read_dataset("cell_names"))
if self.validBarcodeIdx is not None:
vals = vals[(self.validBarcodeIdx + self.indexOffset)]
return list(vals)

def rename_batches(self, collect: List[pl.DataFrame], batch_size: int) -> np.ndarray:
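"""Re-indexes the barcodes in the collected chunks to a dense, 0-based range.
For example, barcodes [5, 9, 9, 42] are renamed to [0, 1, 1, 2].
Args:
collect: A list of Polars DataFrames holding MTX triplets.
batch_size: The number of barcodes in the batch (currently unused).
"""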
df = pl.concat(collect)
barcodes = np.array(df['barcode'])
count_hash = {}
for i, x in enumerate(np.unique(barcodes)):
count_hash[x] = i
cell_idx = np.array([count_hash[x] for x in barcodes])
df = df.with_columns([pl.Series("barcode", cell_idx)])
return np.array(df)

# noinspection DuplicatedCode
def consume(
self,
batch_size: int,
lines_in_mem: int = int(1e5),
dtype=np.uint32,
) -> Generator[coo_matrix, None, None]:
stream = pd.read_csv(
self.matFn, skiprows=3, sep=self.sep, header=None, chunksize=lines_in_mem
"""Yields chunks of data from the MTX file.
Args:
batch_size: The number of barcodes to process in each batch.
lines_in_mem: The number of lines to read into memory.
dtype: The data type of the matrix.
"""
matrixIO = pl.read_csv_batched(
self.matFn,
has_header=False,
separator=self.sep,
comment_prefix="%",
skip_rows_after_header=1,
new_columns=['gene', 'barcode', 'count'],
schema_overrides={'gene': pl.Int64, 'barcode': pl.Int64, 'count': pl.Int64},
batch_size=lines_in_mem
)
start = 1
dfs = []
for df in stream:
if (df.iloc[-1, 1] - start) >= batch_size:
idx = df[1] < (batch_size + start)
dfs.append(df[idx])
yield self.to_sparse(np.vstack(dfs), dtype=dtype)
dfs = [df[~idx]]
start += batch_size
unique_list = []
collect = []
while True:
chunk = matrixIO.next_batches(1)
if chunk is None:
break
chunk = chunk[0]
chunk = chunk.filter(pl.col('barcode').is_in(self.validBarcodeIdx))
in_uniques = np.unique(chunk['barcode'])
unique_list.extend(in_uniques)
unique_list = list(set(unique_list))
if len(unique_list) > batch_size:
diff = batch_size - (len(unique_list) - len(in_uniques))
mask_pos = in_uniques[:diff]
mask_neg = in_uniques[diff:]
extra = chunk.filter(pl.col('barcode').is_in(mask_pos))
collect.append(extra)
collect = self.rename_batches(collect, batch_size)
mtx = self.to_sparse(np.array(collect), dtype=dtype)
yield mtx
left_out = chunk.filter(pl.col('barcode').is_in(mask_neg))
collect = []
unique_list = list(mask_neg)
collect.append(left_out)
else:
dfs.append(df)
yield self.to_sparse(np.vstack(dfs), dtype=dtype)
collect.append(chunk)
if len(collect) > 0:
collect = self.rename_batches(collect, batch_size)
mtx = self.to_sparse(np.array(collect), dtype=dtype)
yield mtx


class H5adReader:
@@ -498,9 +648,9 @@ def _validate_group(self, group: str) -> int:
if group not in self.h5:
logger.warning(f"`{group}` group not found in the H5ad file")
ret_val = 0
elif type(self.h5[group]) == h5py.Dataset:
elif isinstance(self.h5[group], h5py.Dataset):
ret_val = 1
elif type(self.h5[group]) == h5py.Group:
elif isinstance(self.h5[group], h5py.Group):
ret_val = 2
else:
logger.warning(
@@ -518,7 +668,7 @@ def _validate_group(self, group: str) -> int:
[
self.h5[group][x].shape[0]
for x in self.h5[group].keys()
if type(self.h5[group][x]) == h5py.Dataset
if isinstance(self.h5[group][x], h5py.Dataset)
]
)
)
@@ -575,7 +725,7 @@ def _get_n(self, group: str) -> int:
return self.h5[group].shape[0]
else:
for i in self.h5[group].keys():
if type(self.h5[group][i]) == h5py.Dataset:
if isinstance(self.h5[group][i], h5py.Dataset):
return self.h5[group][i].shape[0]
raise KeyError(
f"ERROR: `{group}` key doesn't contain any child node of Dataset type."
@@ -625,7 +775,7 @@ def _replace_category_values(self, v: np.ndarray, key: str, group: str):
if self.catNamesKey is not None:
if self._check_exists(group, self.catNamesKey):
cat_g = self.h5[group][self.catNamesKey]
if type(cat_g) == h5py.Group:
if isinstance(cat_g, h5py.Group):
if key in cat_g:
c = cat_g[key][:]
try:
@@ -658,7 +808,7 @@ def _get_col_data(
):
if i in ignore_keys:
continue
if type(self.h5[group][i]) == h5py.Dataset:
if isinstance(self.h5[group][i], h5py.Dataset):
yield i, self._replace_category_values(
self.h5[group][i][:], i, group
)
@@ -677,12 +827,12 @@ def _get_obsm_data(
f" Will not save this specific slot into Zarr."
)
continue
if type(g) == h5py.Dataset:
if isinstance(g, h5py.Dataset):
for j in range(g.shape[1]):
yield f"{i}{j+1}", g[:, j]
else:
logger.warning(
f"Reading of obsm failed because it either does not exist or is not in expected format"
f"Reading of obsm failed because it either does not exist or is not in expected format" # noqa: F541
)

def get_cell_columns(self) -> Generator[Tuple[str, np.ndarray], None, None]:
@@ -787,8 +937,7 @@ def consume(self, batch_size: int = 100) -> Generator[np.ndarray, None, None]:
a[v["idx"]] = v["val"]
batch.append(a)
if len(batch) >= batch_size:
batch = np.array(batch)
yield batch
yield np.array(batch)
batch = []
if len(batch) > 0:
yield np.array(batch)
@@ -985,7 +1134,7 @@ def __init__(
if pandas_kwargs is None:
pandas_kwargs = {}
else:
if type(pandas_kwargs) != dict:
if not isinstance(pandas_kwargs, dict):
logger.error("")
if has_header is False:
has_header = None
Binary file added scarf/tests/datasets/toy_cr_dir_empty.tar.gz
13 changes: 13 additions & 0 deletions scarf/tests/fixtures_readers.py
@@ -18,6 +18,19 @@ def toy_crdir_reader():
yield reader
remove(out_fn)

@pytest.fixture(scope="session")
def toy_crdir_empty():
from ..readers import CrDirReader
import tarfile

fn = full_path("toy_cr_dir_empty.tar.gz")
out_fn = fn.replace(".tar.gz", "")
remove(out_fn)
tar = tarfile.open(fn, "r:gz")
tar.extractall(out_fn)
reader = CrDirReader(out_fn)
yield reader
remove(out_fn)

@pytest.fixture(scope="session")
def crh5_reader():
20 changes: 20 additions & 0 deletions scarf/tests/test_readers.py
@@ -37,6 +37,26 @@ def test_toy_crdir_reader_cells_feats(toy_crdir_reader):
"a3",
]

def test_toy_crdir_empty(toy_crdir_empty):
assert toy_crdir_empty.nCells == 0
assert toy_crdir_empty.nFeatures == 4
assert toy_crdir_empty.feature_names() == [
"g1",
"a1",
"a2",
"g2",
]
assert toy_crdir_empty.feature_ids() == [
"g1",
"a1",
"a2",
"g2",
]
# read_header() must raise a ValueError when the dataset contains no barcodes
try:
toy_crdir_empty.read_header()
raise AssertionError("read_header() did not raise ValueError")
except ValueError:
pass

def test_crh5reader(crh5_reader):
assert crh5_reader.nCells == 892
Expand Down
