Commit 121df54

Backport PR scverse#1079: Dask Distributed Write Fix For Zarr
selmanozleyen authored and meeseeksmachine committed Aug 25, 2023
1 parent 2c0903a commit 121df54
Showing 3 changed files with 69 additions and 2 deletions.
29 changes: 28 additions & 1 deletion anndata/_io/specs/methods.py
@@ -37,6 +37,18 @@
H5File = h5py.File


####################
# Dask utils #
####################

try:
    from dask.utils import SerializableLock as Lock
except ImportError:
    from threading import Lock

# to fix https://github.com/dask/distributed/issues/780
GLOBAL_LOCK = Lock()

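The try/except fallback above matters because a plain threading.Lock cannot be pickled, while dask.utils.SerializableLock is designed to survive serialization to workers. A minimal sketch of the pattern the zarr writer below relies on, assuming dask and zarr are installed (the store path, shapes, and chunking are illustrative only):

import dask.array as da
import zarr
from dask.utils import SerializableLock

# Store a chunked dask array into a zarr dataset, guarding concurrent
# writes with a picklable lock (same idea as GLOBAL_LOCK above).
group = zarr.open_group("example.zarr", mode="w")
dset = group.require_dataset("x", shape=(1000, 100), dtype="float64")
arr = da.random.random((1000, 100), chunks=(100, 100))
da.store(arr, dset, lock=SerializableLock())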
####################
# Dispatch methods #
####################
@@ -301,9 +313,24 @@ def write_basic(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):


@_REGISTRY.register_write(ZarrGroup, DaskArray, IOSpec("array", "0.2.0"))
def write_basic_dask_zarr(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):
    import dask.array as da

    g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
    da.store(elem, g, lock=GLOBAL_LOCK)


# Adding this separately because h5py isn't serializable
# https://github.com/pydata/xarray/issues/4242
@_REGISTRY.register_write(H5Group, DaskArray, IOSpec("array", "0.2.0"))
-def write_basic_dask(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):
+def write_basic_dask_h5(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):
    import dask.array as da
    import dask.config as dc

    if dc.get("scheduler", None) == "dask.distributed":
        raise ValueError(
            "Cannot write dask arrays to hdf5 when using distributed scheduler"
        )

    g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
    da.store(elem, g)
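
The ValueError guard above hinges on how dask reports the active scheduler. A rough sketch of that detection, assuming the default set_as_default=True behaviour of dask.distributed.Client (cluster settings are arbitrary):

import dask.config as dc
from dask.distributed import Client, LocalCluster

print(dc.get("scheduler", None))  # no distributed client registered yet

with LocalCluster(n_workers=1, processes=False) as cluster, Client(cluster):
    # A default Client registers itself in dask's config, which is the
    # string the hdf5 writer above checks for before refusing to write.
    print(dc.get("scheduler", None))  # expected: "dask.distributed"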
40 changes: 40 additions & 0 deletions anndata/tests/test_dask.py
@@ -11,6 +11,8 @@
    gen_adata,
    assert_equal,
)
from anndata.experimental import write_elem, read_elem
from anndata.experimental.merge import as_group
from anndata.compat import DaskArray

pytest.importorskip("dask.array")
@@ -94,6 +96,44 @@ def test_dask_write(adata, tmp_path, diskfmt):
    assert isinstance(orig.varm["a"], DaskArray)


def test_dask_distributed_write(adata, tmp_path, diskfmt):
    import dask.array as da
    import dask.distributed as dd
    import numpy as np

    pth = tmp_path / f"test_write.{diskfmt}"
    g = as_group(pth, mode="w")

    with dd.LocalCluster(n_workers=1, threads_per_worker=1, processes=False) as cluster:
        with dd.Client(cluster):
            M, N = adata.X.shape
            adata.obsm["a"] = da.random.random((M, 10))
            adata.obsm["b"] = da.random.random((M, 10))
            adata.varm["a"] = da.random.random((N, 10))
            orig = adata
            if diskfmt == "h5ad":
                with pytest.raises(
                    ValueError, match="Cannot write dask arrays to hdf5"
                ):
                    write_elem(g, "", orig)
                return
            write_elem(g, "", orig)
            curr = read_elem(g)

    with pytest.raises(Exception):
        assert_equal(curr.obsm["a"], curr.obsm["b"])

    assert_equal(curr.varm["a"], orig.varm["a"])
    assert_equal(curr.obsm["a"], orig.obsm["a"])

    assert isinstance(curr.X, np.ndarray)
    assert isinstance(curr.obsm["a"], np.ndarray)
    assert isinstance(curr.varm["a"], np.ndarray)
    assert isinstance(orig.X, DaskArray)
    assert isinstance(orig.obsm["a"], DaskArray)
    assert isinstance(orig.varm["a"], DaskArray)


def test_dask_to_memory_check_array_types(adata, tmp_path, diskfmt):
    import dask.array as da
    import numpy as np
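For context, a condensed end-user version of the scenario the new test covers, written against the public AnnData API rather than write_elem; this assumes AnnData.write_zarr routes through the same registered dask writer, and the path and shapes are placeholders:

import anndata as ad
import dask.array as da
import numpy as np
from dask.distributed import Client, LocalCluster

adata = ad.AnnData(X=np.zeros((100, 50)))
adata.obsm["embedding"] = da.random.random((100, 10))

with LocalCluster(n_workers=1, processes=False) as cluster, Client(cluster):
    # The zarr writer stores the dask array via da.store(..., lock=GLOBAL_LOCK).
    adata.write_zarr("example_adata.zarr")

roundtrip = ad.read_zarr("example_adata.zarr")
print(type(roundtrip.obsm["embedding"]))  # read back as an in-memory array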
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -84,7 +84,7 @@ test = [
"joblib",
"boltons",
"scanpy",
"dask[array]",
"dask[array,distributed]",
"awkward>=2.3",
"pytest_memray",
]
