diff --git a/virtualizarr/backend.py b/virtualizarr/backend.py
index a8e3b66a..b743f202 100644
--- a/virtualizarr/backend.py
+++ b/virtualizarr/backend.py
@@ -1,13 +1,23 @@
+import os
 import warnings
 from collections.abc import Iterable, Mapping
 from enum import Enum, auto
+from functools import partial
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
+    Callable,
+    Literal,
     Optional,
+    Sequence,
+    cast,
 )
 
-from xarray import Dataset, Index
+from xarray import DataArray, Dataset, Index, combine_by_coords
+from xarray.backends.api import _multi_file_closer
+from xarray.backends.common import _find_absolute_paths
+from xarray.core.combine import _infer_concat_order_from_positions, _nested_combine
 
 from virtualizarr.manifests import ManifestArray
 from virtualizarr.readers import (
@@ -22,6 +32,15 @@ from virtualizarr.readers.common import VirtualBackend
 from virtualizarr.utils import _FsspecFSFromFilepath, check_for_collisions
 
 
+if TYPE_CHECKING:
+    from xarray.core.types import (
+        CombineAttrsOptions,
+        CompatOptions,
+        JoinOptions,
+        NestedSequence,
+    )
+
+
 # TODO add entrypoint to allow external libraries to add to this mapping
 VIRTUAL_BACKENDS = {
     "kerchunk": KerchunkVirtualBackend,
@@ -209,3 +228,186 @@ def open_virtual_dataset(
         )
 
     return vds
+
+
+def open_virtual_mfdataset(
+    paths: str | Sequence[str | os.PathLike] | NestedSequence[str | os.PathLike],
+    concat_dim: (
+        str
+        | DataArray
+        | Index
+        | Sequence[str]
+        | Sequence[DataArray]
+        | Sequence[Index]
+        | None
+    ) = None,
+    compat: CompatOptions = "no_conflicts",
+    preprocess: Callable[[Dataset], Dataset] | None = None,
+    data_vars: Literal["all", "minimal", "different"] | list[str] = "all",
+    coords="different",
+    combine: Literal["by_coords", "nested"] = "by_coords",
+    parallel: Literal["lithops", "dask", False] = False,
+    join: JoinOptions = "outer",
+    attrs_file: str | os.PathLike | None = None,
+    combine_attrs: CombineAttrsOptions = "override",
+    **kwargs,
+) -> Dataset:
+    """Open multiple files as a single virtual dataset.
+
+    If combine='by_coords' then the function ``combine_by_coords`` is used to combine
+    the datasets into one before returning the result, and if combine='nested' then
+    ``combine_nested`` is used. The filepaths must be structured according to which
+    combining function is used, the details of which are given in the documentation for
+    ``combine_by_coords`` and ``combine_nested``. By default ``combine='by_coords'``
+    will be used. Global attributes from the ``attrs_file`` are used
+    for the combined dataset.
+
+    Parameters
+    ----------
+    paths
+        Same as in xarray.open_mfdataset
+    concat_dim
+        Same as in xarray.open_mfdataset
+    compat
+        Same as in xarray.open_mfdataset
+    preprocess
+        Same as in xarray.open_mfdataset
+    data_vars
+        Same as in xarray.open_mfdataset
+    coords
+        Same as in xarray.open_mfdataset
+    combine
+        Same as in xarray.open_mfdataset
+    parallel : 'dask', 'lithops', or False
+        Specify whether the open and preprocess steps of this function will be
+        performed in parallel using ``dask.delayed``, in parallel using
+        ``lithops.map`` (not yet implemented), or in serial.
+        Default is False.
+    join
+        Same as in xarray.open_mfdataset
+    attrs_file
+        Same as in xarray.open_mfdataset
+    combine_attrs
+        Same as in xarray.open_mfdataset
+    **kwargs : optional
+        Additional arguments passed on to :py:func:`virtualizarr.open_virtual_dataset`. For an
+        overview of some of the possible options, see the documentation of
+        :py:func:`virtualizarr.open_virtual_dataset`.
+
+    Returns
+    -------
+    xarray.Dataset
+
+    Notes
+    -----
+    The results of opening each virtual dataset in parallel are sent back to the client process, so must not be too large.
+    """
+
+    # TODO this is practically all just copied from xarray.open_mfdataset - an argument for writing a virtualizarr engine for xarray?
+
+    # TODO add options passed to open_virtual_dataset explicitly?
+
+    paths = _find_absolute_paths(paths)
+
+    if not paths:
+        raise OSError("no files to open")
+
+    paths1d: list[str]
+    if combine == "nested":
+        if isinstance(concat_dim, str | DataArray) or concat_dim is None:
+            concat_dim = [concat_dim]  # type: ignore[assignment]
+
+        # This creates a flat list which is easier to iterate over, whilst
+        # encoding the originally-supplied structure as "ids".
+        # The "ids" are not used at all if combine='by_coords'.
+        combined_ids_paths = _infer_concat_order_from_positions(paths)
+        ids, paths1d = (
+            list(combined_ids_paths.keys()),
+            list(combined_ids_paths.values()),
+        )
+    elif concat_dim is not None:
+        raise ValueError(
+            "When combine='by_coords', passing a value for `concat_dim` has no "
+            "effect. To manually combine along a specific dimension you should "
+            "instead specify combine='nested' along with a value for `concat_dim`.",
+        )
+    else:
+        paths1d = paths  # type: ignore[assignment]
+
+    if parallel == "dask":
+        import dask
+
+        # wrap open_virtual_dataset, getattr, and preprocess with delayed
+        open_ = dask.delayed(open_virtual_dataset)
+        getattr_ = dask.delayed(getattr)
+        if preprocess is not None:
+            preprocess = dask.delayed(preprocess)
+    elif parallel == "lithops":
+        raise NotImplementedError()
+    elif parallel is not False:
+        raise ValueError(
+            f"{parallel} is an invalid option for the keyword argument ``parallel``"
+        )
+    else:
+        open_ = open_virtual_dataset
+        getattr_ = getattr
+
+    datasets = [open_(p, **kwargs) for p in paths1d]
+    closers = [getattr_(ds, "_close") for ds in datasets]
+    if preprocess is not None:
+        datasets = [preprocess(ds) for ds in datasets]
+
+    if parallel == "dask":
+        # calling compute here resolves the delayed open calls and returns the
+        # datasets/closers lists; the virtual datasets themselves are already in memory
+        datasets, closers = dask.compute(datasets, closers)
+    elif parallel == "lithops":
+        raise NotImplementedError()
+
+    # Combine all datasets, closing them in case of a ValueError
+    try:
+        if combine == "nested":
+            # Combine nested list by successive concat and merge operations
+            # along each dimension, using structure given by "ids"
+            combined = _nested_combine(
+                datasets,
+                concat_dims=concat_dim,
+                compat=compat,
+                data_vars=data_vars,
+                coords=coords,
+                ids=ids,
+                join=join,
+                combine_attrs=combine_attrs,
+            )
+        elif combine == "by_coords":
+            # Redo ordering from coordinates, ignoring how they were ordered
+            # previously
+            combined = combine_by_coords(
+                datasets,
+                compat=compat,
+                data_vars=data_vars,
+                coords=coords,
+                join=join,
+                combine_attrs=combine_attrs,
+            )
+        else:
+            raise ValueError(
+                f"{combine} is an invalid option for the keyword argument"
+                " ``combine``"
+            )
+    except ValueError:
+        for ds in datasets:
+            ds.close()
+        raise
+
+    combined.set_close(partial(_multi_file_closer, closers))
+
+    # read global attributes from the attrs_file or from the first dataset
+    if attrs_file is not None:
+        if isinstance(attrs_file, os.PathLike):
+            attrs_file = cast(str, os.fspath(attrs_file))
+        combined.attrs = datasets[paths1d.index(attrs_file)].attrs
+
+    # TODO should we just immediately close everything?
+    # TODO We should have already read everything we're ever going to read into memory at this point
+
+    return combined
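
For context, a minimal usage sketch of the new function. The file paths and keyword arguments below are hypothetical, and it assumes ``open_virtual_mfdataset`` is exported from the top-level ``virtualizarr`` namespace alongside ``open_virtual_dataset``; adjust to your data and installed version.

from virtualizarr import open_virtual_mfdataset

# Open several NetCDF files as one virtual dataset, concatenated along "time".
# combine="nested" keeps the order the paths are given in; parallel="dask"
# wraps each open_virtual_dataset call in dask.delayed (requires dask).
vds = open_virtual_mfdataset(
    ["2024-01.nc", "2024-02.nc", "2024-03.nc"],  # hypothetical paths
    combine="nested",
    concat_dim="time",
    parallel="dask",
    indexes={},  # forwarded via **kwargs to open_virtual_dataset
)

# `vds` is an ordinary xarray.Dataset whose variables wrap ManifestArrays,
# so it can be persisted with the usual virtualizarr writers.

Because the combined result is built from chunk manifests rather than loaded array data, the per-file open step is cheap; the note in the docstring about results being "sent back to the client process" refers to the manifest metadata, not the underlying chunks.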