Skip to content

Commit

Permalink
Specialize ExperimentAxisQuery.{obs,var} to pa.IntegerArray (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams authored Nov 5, 2024
1 parent 1f9317d commit 6243608
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
8 changes: 4 additions & 4 deletions python-spec/src/somacore/query/_fast_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

def read_csr(
matrix: scd.SparseNDArray,
obs_joinids: pa.Array,
var_joinids: pa.Array,
obs_joinids: pa.IntegerArray,
var_joinids: pa.IntegerArray,
index_factory: types.IndexFactory,
) -> "AccumulatedCSR":
if not isinstance(matrix, scd.SparseNDArray) or matrix.ndim != 2:
Expand Down Expand Up @@ -82,8 +82,8 @@ class _CSRAccumulator:

def __init__(
self,
obs_joinids: pa.Array,
var_joinids: pa.Array,
obs_joinids: pa.IntegerArray,
var_joinids: pa.IntegerArray,
pool: futures.Executor,
index_factory: types.IndexFactory,
):
Expand Down
18 changes: 9 additions & 9 deletions python-spec/src/somacore/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,14 @@ def var(
platform_config=platform_config,
)

def obs_joinids(self) -> pa.Array:
def obs_joinids(self) -> pa.IntegerArray:
"""Returns ``obs`` ``soma_joinids`` as an Arrow array.
Lifecycle: maturing
"""
return self._joinids.obs

def var_joinids(self) -> pa.Array:
def var_joinids(self) -> pa.IntegerArray:
"""Returns ``var`` ``soma_joinids`` as an Arrow array.
Lifecycle: maturing
Expand Down Expand Up @@ -750,8 +750,8 @@ class _JoinIDCache:

owner: ExperimentAxisQuery

_cached_obs: Optional[pa.Array] = None
_cached_var: Optional[pa.Array] = None
_cached_obs: Optional[pa.IntegerArray] = None
_cached_var: Optional[pa.IntegerArray] = None

def _is_cached(self, axis: _Axis) -> bool:
field = "_cached_" + axis.value
Expand All @@ -767,7 +767,7 @@ def preload(self, pool: futures.ThreadPoolExecutor) -> None:
var_ft.result()

@property
def obs(self) -> pa.Array:
def obs(self) -> pa.IntegerArray:
"""Join IDs for the obs axis. Will load and cache if not already."""
if not self._cached_obs:
self._cached_obs = _load_joinids(
Expand All @@ -776,11 +776,11 @@ def obs(self) -> pa.Array:
return self._cached_obs

@obs.setter
def obs(self, val: pa.Array) -> None:
def obs(self, val: pa.IntegerArray) -> None:
self._cached_obs = val

@property
def var(self) -> pa.Array:
def var(self) -> pa.IntegerArray:
"""Join IDs for the var axis. Will load and cache if not already."""
if not self._cached_var:
self._cached_var = _load_joinids(
Expand All @@ -789,11 +789,11 @@ def var(self) -> pa.Array:
return self._cached_var

@var.setter
def var(self, val: pa.Array) -> None:
def var(self, val: pa.IntegerArray) -> None:
self._cached_var = val


def _load_joinids(df: data.DataFrame, axq: axis.AxisQuery) -> pa.Array:
def _load_joinids(df: data.DataFrame, axq: axis.AxisQuery) -> pa.IntegerArray:
tbl = df.read(
axq.coords,
value_filter=axq.value_filter,
Expand Down
6 changes: 3 additions & 3 deletions python-spec/src/somacore/query/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pyarrow as pa
from typing_extensions import Protocol

_Array = Union[npt.NDArray[np.int64], pa.Array]
_IntegerArray = Union[npt.NDArray[np.int64], pa.IntegerArray]


class IndexLike(Protocol):
Expand All @@ -19,11 +19,11 @@ class IndexLike(Protocol):
not as a full specification of the types and behavior of ``get_indexer``.
"""

def get_indexer(self, target: _Array) -> Any:
def get_indexer(self, target: _IntegerArray) -> Any:
"""Something compatible with Pandas' Index.get_indexer method."""


IndexFactory = Callable[[_Array], "IndexLike"]
IndexFactory = Callable[[_IntegerArray], "IndexLike"]
"""Function that builds an index over the given NDArray.
This interface is implemented by the callable ``pandas.Index``.
Expand Down

0 comments on commit 6243608

Please sign in to comment.