Skip to content

Commit

Permalink
Make creation of aligned mapping lazy
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup committed Sep 4, 2023
1 parent 22f33bb commit a9a10af
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 171 deletions.
61 changes: 47 additions & 14 deletions anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
from scipy.sparse import spmatrix

from ..utils import deprecated, ensure_df_homogeneous, dim_len
from ..utils import deprecated, ensure_df_homogeneous, dim_len, convert_to_dict
from . import raw, anndata
from .views import as_view, view_update
from .access import ElementRef
Expand Down Expand Up @@ -108,7 +108,7 @@ def parent(self) -> Union["anndata.AnnData", "raw.Raw"]:
return self._parent

def copy(self):
d = self._actual_class(self.parent, self._axis)
d = self._actual_class(self.parent, self._axis, {})
for k, v in self.items():
if isinstance(v, AwkArray):
# Shallow copy since awkward array buffers are immutable
Expand Down Expand Up @@ -262,6 +262,7 @@ def dim_names(self) -> pd.Index:
return (self.parent.obs_names, self.parent.var_names)[self._axis]


# TODO: vals can't be None
class AxisArrays(AlignedActualMixin, AxisArraysBase):
def __init__(
self,
Expand All @@ -273,9 +274,9 @@ def __init__(
if axis not in (0, 1):
raise ValueError()
self._axis = axis
self._data = dict()
if vals is not None:
self.update(vals)
for k, v in vals.items():
vals[k] = self._validate_value(v, k)
self._data = vals


class AxisArraysView(AlignedViewMixin, AxisArraysBase):
Expand Down Expand Up @@ -307,18 +308,21 @@ class LayersBase(AlignedMapping):

# TODO: I thought I had a more elegant solution to overriding this...
def copy(self) -> "Layers":
d = self._actual_class(self.parent)
d = self._actual_class(self.parent, vals={})
for k, v in self.items():
d[k] = v.copy()
return d


class Layers(AlignedActualMixin, LayersBase):
def __init__(self, parent: "anndata.AnnData", vals: Optional[Mapping] = None):
def __init__(
self, parent: "anndata.AnnData", axis=(0, 1), vals: Optional[Mapping] = None
):
assert axis == (0, 1), axis
self._parent = parent
self._data = dict()
if vals is not None:
self.update(vals)
for k, v in vals.items():
vals[k] = self._validate_value(v, k)
self._data = vals


class LayersView(AlignedViewMixin, LayersBase):
Expand Down Expand Up @@ -372,9 +376,9 @@ def __init__(
if axis not in (0, 1):
raise ValueError()
self._axis = axis
self._data = dict()
if vals is not None:
self.update(vals)
for k, v in vals.items():
vals[k] = self._validate_value(v, k)
self._data = vals


class PairwiseArraysView(AlignedViewMixin, PairwiseArraysBase):
Expand All @@ -386,9 +390,38 @@ def __init__(
):
self.parent_mapping = parent_mapping
self._parent = parent_view
self.subset_idx = (subset_idx, subset_idx)
self.subset_idx = subset_idx
self._axis = parent_mapping._axis


PairwiseArraysBase._view_class = PairwiseArraysView
PairwiseArraysBase._actual_class = PairwiseArrays


class AlignedMappingProperty:
    """Descriptor that builds an aligned mapping on each attribute access.

    The owning object stores only the raw backing mapping under
    ``"_" + name`` (e.g. ``_obsm``); the wrapper ``cls`` instance (or a view
    of the parent's mapping) is constructed on every read rather than being
    kept on the instance.

    Parameters
    ----------
    name
        Public attribute name this descriptor is bound to (e.g. ``"obsm"``).
    cls
        Callable invoked as ``cls(parent, axis, vals)`` to wrap the stored
        values.
    axis
        Axis (or tuple of axes) forwarded unchanged to ``cls``.
    """

    def __init__(self, name, cls, axis):
        self.name = name
        self.axis = axis
        self.cls = cls

    def __get__(self, obj, objtype=None):
        """Return the aligned mapping for ``obj``.

        Accessed on the owning class (``obj is None``), the descriptor itself
        is returned, as ``property`` does, so that class-level introspection
        does not fail on a missing instance.
        """
        if obj is None:
            return self
        if obj.is_view:
            # Views hold no values of their own: subset the parent's mapping
            # with the indices for whichever axes that mapping is aligned to.
            parent_anndata = obj._adata_ref
            idxs = (obj._oidx, obj._vidx)
            parent_aligned_mapping = getattr(parent_anndata, self.name)
            return parent_aligned_mapping._view(
                obj, tuple(idxs[ax] for ax in parent_aligned_mapping.axes)
            )
        return self.cls(obj, self.axis, getattr(obj, f"_{self.name}"))

    def __set__(self, obj, value):
        """Validate ``value`` and store it as the backing mapping of ``obj``.

        If ``obj`` is a view it is first turned into an actual object via
        ``obj._init_as_actual(obj.copy())``.
        """
        value = convert_to_dict(value)
        # Constructed only for validation; the wrapper is discarded.
        # Any exception raised here propagates before ``obj`` is modified.
        _ = self.cls(obj, self.axis, value)
        if obj.is_view:
            obj._init_as_actual(obj.copy())
        setattr(obj, f"_{self.name}", value)

    def __delete__(self, obj):
        """Reset the mapping to empty by assigning ``{}`` via ``__set__``."""
        # Goes through the public name so the normal validation and
        # view-to-actual handling in ``__set__`` still apply.
        setattr(obj, self.name, {})
160 changes: 13 additions & 147 deletions anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,9 @@
from .access import ElementRef
from .aligned_mapping import (
AxisArrays,
AxisArraysView,
PairwiseArrays,
PairwiseArraysView,
Layers,
LayersView,
AlignedMappingProperty,
)
from .views import (
ArrayView,
Expand All @@ -47,7 +45,7 @@
)
from .sparse_dataset import SparseDataset
from .. import utils
from ..utils import convert_to_dict, ensure_df_homogeneous, dim_len
from ..utils import ensure_df_homogeneous, dim_len
from ..logging import anndata_logger as logger
from ..compat import (
ZarrArray,
Expand Down Expand Up @@ -343,11 +341,6 @@ def _init_as_view(self, adata_ref: "AnnData", oidx: Index, vidx: Index):
# views on attributes of adata_ref
obs_sub = adata_ref.obs.iloc[oidx]
var_sub = adata_ref.var.iloc[vidx]
self._obsm = adata_ref.obsm._view(self, (oidx,))
self._varm = adata_ref.varm._view(self, (vidx,))
self._layers = adata_ref.layers._view(self, (oidx, vidx))
self._obsp = adata_ref.obsp._view(self, oidx)
self._varp = adata_ref.varp._view(self, vidx)
# fix categories
uns = copy(adata_ref._uns)
self._remove_unused_categories(adata_ref.obs, obs_sub, uns)
Expand Down Expand Up @@ -506,12 +499,11 @@ def _init_as_actual(
# unstructured annotations
self.uns = uns or OrderedDict()

# TODO: Think about consequences of making obsm a group in hdf
self._obsm = AxisArrays(self, 0, vals=convert_to_dict(obsm))
self._varm = AxisArrays(self, 1, vals=convert_to_dict(varm))
self.obsm = obsm
self.varm = varm

self._obsp = PairwiseArrays(self, 0, vals=convert_to_dict(obsp))
self._varp = PairwiseArrays(self, 1, vals=convert_to_dict(varp))
self.obsp = obsp
self.varp = varp

# Backwards compat for connectivities matrices in uns["neighbors"]
_move_adj_mtx({"uns": self._uns, "obsp": self._obsp})
Expand All @@ -536,7 +528,7 @@ def _init_as_actual(
self._clean_up_old_format(uns)

# layers
self._layers = Layers(self, layers)
self.layers = layers

def __sizeof__(self, show_stratified=None) -> int:
def get_size(X):
Expand Down Expand Up @@ -696,45 +688,11 @@ def X(self, value: Optional[Union[np.ndarray, sparse.spmatrix]]):
def X(self):
self.X = None

@property
def layers(self) -> Union[Layers, LayersView]:
"""\
Dictionary-like object with values of the same dimensions as :attr:`X`.
Layers in AnnData are inspired by loompy’s :ref:`loomlayers`.
Return the layer named `"unspliced"`::
adata.layers["unspliced"]
Create or replace the `"spliced"` layer::
adata.layers["spliced"] = ...
Assign the 10th column of layer `"spliced"` to the variable a::
a = adata.layers["spliced"][:, 10]
Delete the `"spliced"` layer::
del adata.layers["spliced"]
Return layers’ names::
adata.layers.keys()
"""
return self._layers

@layers.setter
def layers(self, value):
layers = Layers(self, vals=convert_to_dict(value))
if self.is_view:
self._init_as_actual(self.copy())
self._layers = layers

@layers.deleter
def layers(self):
self.layers = dict()
obsm = AlignedMappingProperty("obsm", AxisArrays, 0)
varm = AlignedMappingProperty("varm", AxisArrays, 1)
layers = AlignedMappingProperty("layers", Layers, (0, 1))
obsp = AlignedMappingProperty("obsp", PairwiseArrays, 0)
varp = AlignedMappingProperty("varp", PairwiseArrays, 1)

@property
def raw(self) -> Raw:
Expand Down Expand Up @@ -845,7 +803,7 @@ def _set_dim_index(self, value: pd.Index, attr: str):
if self.is_view:
self._init_as_actual(self.copy())
getattr(self, attr).index = value
for v in getattr(self, f"{attr}m").values():
for v in getattr(self, f"_{attr}m").values():
if isinstance(v, pd.DataFrame):
v.index = value

Expand Down Expand Up @@ -919,98 +877,6 @@ def uns(self, value: MutableMapping):
def uns(self):
self.uns = OrderedDict()

@property
def obsm(self) -> Union[AxisArrays, AxisArraysView]:
"""\
Multi-dimensional annotation of observations
(mutable structured :class:`~numpy.ndarray`).
Stores for each key a two or higher-dimensional :class:`~numpy.ndarray`
of length `n_obs`.
Is sliced with `data` and `obs` but behaves otherwise like a :term:`mapping`.
"""
return self._obsm

@obsm.setter
def obsm(self, value):
obsm = AxisArrays(self, 0, vals=convert_to_dict(value))
if self.is_view:
self._init_as_actual(self.copy())
self._obsm = obsm

@obsm.deleter
def obsm(self):
self.obsm = dict()

@property
def varm(self) -> Union[AxisArrays, AxisArraysView]:
"""\
Multi-dimensional annotation of variables/features
(mutable structured :class:`~numpy.ndarray`).
Stores for each key a two or higher-dimensional :class:`~numpy.ndarray`
of length `n_vars`.
Is sliced with `data` and `var` but behaves otherwise like a :term:`mapping`.
"""
return self._varm

@varm.setter
def varm(self, value):
varm = AxisArrays(self, 1, vals=convert_to_dict(value))
if self.is_view:
self._init_as_actual(self.copy())
self._varm = varm

@varm.deleter
def varm(self):
self.varm = dict()

@property
def obsp(self) -> Union[PairwiseArrays, PairwiseArraysView]:
"""\
Pairwise annotation of observations,
a mutable mapping with array-like values.
Stores for each key a two or higher-dimensional :class:`~numpy.ndarray`
whose first two dimensions are of length `n_obs`.
Is sliced with `data` and `obs` but behaves otherwise like a :term:`mapping`.
"""
return self._obsp

@obsp.setter
def obsp(self, value):
obsp = PairwiseArrays(self, 0, vals=convert_to_dict(value))
if self.is_view:
self._init_as_actual(self.copy())
self._obsp = obsp

@obsp.deleter
def obsp(self):
self.obsp = dict()

@property
def varp(self) -> Union[PairwiseArrays, PairwiseArraysView]:
"""\
Pairwise annotation of variables/features,
a mutable mapping with array-like values.
Stores for each key a two or higher-dimensional :class:`~numpy.ndarray`
whose first two dimensions are of length `n_var`.
Is sliced with `data` and `var` but behaves otherwise like a :term:`mapping`.
"""
return self._varp

@varp.setter
def varp(self, value):
varp = PairwiseArrays(self, 1, vals=convert_to_dict(value))
if self.is_view:
self._init_as_actual(self.copy())
self._varp = varp

@varp.deleter
def varp(self):
self.varp = dict()

def obs_keys(self) -> List[str]:
"""List keys of observation annotation :attr:`obs`."""
return self._obs.keys().tolist()
Expand Down
20 changes: 12 additions & 8 deletions anndata/_core/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from . import anndata
from .index import _normalize_index, _subset, unpack_index, get_vector
from .aligned_mapping import AxisArrays
from .aligned_mapping import AxisArrays, AlignedMappingProperty
from .sparse_dataset import SparseDataset

from ..compat import CupyArray, CupySparseMatrix
Expand All @@ -35,15 +35,17 @@ def __init__(
else:
self._X = X
self._var = _gen_dataframe(var, self.X.shape[1], ["var_names"])
self._varm = AxisArrays(self, 1, varm)
self.varm = varm
# self._varm = AxisArrays(self, 1, varm)
elif X is None: # construct from adata
# Move from GPU to CPU since it's large and not always used
if isinstance(adata.X, (CupyArray, CupySparseMatrix)):
self._X = adata.X.get()
else:
self._X = adata.X.copy()
self._var = adata.var.copy()
self._varm = AxisArrays(self, 1, adata.varm.copy())
self.varm = adata.varm.copy()
# self._varm = AxisArrays(self, 1, adata.varm.copy())
elif adata.isbacked:
raise ValueError("Cannot specify X if adata is backed")

Expand Down Expand Up @@ -95,9 +97,7 @@ def n_vars(self):
def n_obs(self):
return self._n_obs

@property
def varm(self):
return self._varm
varm = AlignedMappingProperty("varm", AxisArrays, 1)

@property
def var_names(self):
Expand All @@ -123,11 +123,15 @@ def __getitem__(self, index):

var = self._var.iloc[vidx]
new = Raw(self._adata, X=X, var=var)
if self._varm is not None:
if self.varm is not None:
# Since there is no view of raws
new._varm = self._varm._view(_RawViewHack(self, vidx), (vidx,)).copy()
new.varm = self.varm._view(_RawViewHack(self, vidx), (vidx,)).copy()
return new

@property
def is_view(self):
    """Always ``False``; there is no view of raws.

    ``varm`` is declared through ``AlignedMappingProperty``, which reads
    ``is_view`` on the owning object when getting and setting the mapping.
    """
    return False

def __str__(self):
descr = f"Raw AnnData with n_obs × n_vars = {self.n_obs} × {self.n_vars}"
for attr in ["var", "varm"]:
Expand Down
Loading

0 comments on commit a9a10af

Please sign in to comment.