Skip to content

Commit

Permalink
Extending CSR Accumulator to reuse query context instead of creating …
Browse files Browse the repository at this point in the history
…its own
  • Loading branch information
beroy committed Dec 5, 2023
1 parent 1d71cc1 commit 4b18ba8
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
9 changes: 6 additions & 3 deletions python-spec/src/somacore/query/_fast_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy.typing as npt
import pyarrow as pa
from scipy import sparse

from typing import Optional, Any
from .. import data as scd
from . import _eager_iter
from . import types
Expand All @@ -19,6 +19,7 @@ def read_csr(
obs_joinids: pa.Array,
var_joinids: pa.Array,
index_factory: types.IndexFactory,
context: Optional[Any],
) -> "AccumulatedCSR":
if not isinstance(matrix, scd.SparseNDArray) or matrix.ndim != 2:
raise TypeError("Can only read from a 2D SparseNDArray")
Expand All @@ -30,6 +31,7 @@ def read_csr(
var_joinids=var_joinids,
pool=pool,
index_factory=index_factory,
context=context,
)
for tbl in _eager_iter.EagerIterator(
matrix.read((obs_joinids, var_joinids)).tables(),
Expand Down Expand Up @@ -86,14 +88,15 @@ def __init__(
var_joinids: npt.NDArray[np.int64],
pool: futures.Executor,
index_factory: types.IndexFactory,
context: Optional[Any],
):
self.obs_joinids = obs_joinids
self.var_joinids = var_joinids
self.pool = pool

self.shape: Tuple[int, int] = (len(self.obs_joinids), len(self.var_joinids))
self.obs_indexer = index_factory(self.obs_joinids)
self.var_indexer = index_factory(self.var_joinids)
self.obs_indexer = index_factory(self.obs_joinids, context)
self.var_indexer = index_factory(self.var_joinids, context)
self.row_length: npt.NDArray[np.int64] = np.zeros(
(self.shape[0],), dtype=_select_dtype(self.shape[1])
)
Expand Down
1 change: 1 addition & 0 deletions python-spec/src/somacore/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def _read(
self.obs_joinids(),
self.var_joinids(),
index_factory=self._index_factory,
context=self.experiment.context,
).to_scipy()
for _xname in all_x_arrays
}
Expand Down
6 changes: 3 additions & 3 deletions python-spec/src/somacore/query/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Common types used across SOMA query modules."""

from typing import Any, Callable

import numpy as np
import numpy.typing as npt

from typing import Optional, Any, Callable
from typing_extensions import Protocol


Expand All @@ -26,7 +26,7 @@ def get_indexer(
"""Something compatible with Pandas' Index.get_indexer method."""


IndexFactory = Callable[[npt.NDArray[np.int64]], "IndexLike"]
IndexFactory = Callable[[npt.NDArray[np.int64], Optional[Any]], "IndexLike"]
"""Function that builds an index over the given NDArray.
This interface is implemented by the callable ``pandas.Index``.
Expand Down

0 comments on commit 4b18ba8

Please sign in to comment.