
copy implementation from xarray
TomNicholas committed Dec 15, 2024
1 parent 1dbd119 commit a48e8a4
Showing 1 changed file with 203 additions and 1 deletion.
204 changes: 203 additions & 1 deletion virtualizarr/backend.py
@@ -1,13 +1,23 @@
import os
import warnings
from collections.abc import Iterable, Mapping
from enum import Enum, auto
from functools import partial
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Literal,
    Optional,
    Sequence,
    cast,
)

-from xarray import Dataset, Index
+from xarray import DataArray, Dataset, Index, combine_by_coords
from xarray.backends.api import _multi_file_closer
from xarray.backends.common import _find_absolute_paths
from xarray.core.combine import _infer_concat_order_from_positions, _nested_combine

from virtualizarr.manifests import ManifestArray
from virtualizarr.readers import (
@@ -22,6 +32,15 @@
from virtualizarr.readers.common import VirtualBackend
from virtualizarr.utils import _FsspecFSFromFilepath, check_for_collisions

if TYPE_CHECKING:
    from xarray.core.types import (
        CombineAttrsOptions,
        CompatOptions,
        JoinOptions,
        NestedSequence,
    )


# TODO add entrypoint to allow external libraries to add to this mapping
VIRTUAL_BACKENDS = {
"kerchunk": KerchunkVirtualBackend,
@@ -209,3 +228,186 @@ def open_virtual_dataset(
    )

    return vds


def open_virtual_mfdataset(
    paths: str | Sequence[str | os.PathLike] | NestedSequence[str | os.PathLike],
    concat_dim: (
        str
        | DataArray
        | Index
        | Sequence[str]
        | Sequence[DataArray]
        | Sequence[Index]
        | None
    ) = None,
    compat: CompatOptions = "no_conflicts",
    preprocess: Callable[[Dataset], Dataset] | None = None,
    data_vars: Literal["all", "minimal", "different"] | list[str] = "all",
    coords="different",
    combine: Literal["by_coords", "nested"] = "by_coords",
    parallel: Literal["lithops", "dask", False] = False,
    join: JoinOptions = "outer",
    attrs_file: str | os.PathLike | None = None,
    combine_attrs: CombineAttrsOptions = "override",
    **kwargs,
) -> Dataset:
"""Open multiple files as a single virtual dataset
If combine='by_coords' then the function ``combine_by_coords`` is used to combine
the datasets into one before returning the result, and if combine='nested' then
``combine_nested`` is used. The filepaths must be structured according to which
combining function is used, the details of which are given in the documentation for
``combine_by_coords`` and ``combine_nested``. By default ``combine='by_coords'``
will be used. Global attributes from the ``attrs_file`` are used
for the combined dataset.
Parameters
----------
paths
Same as in xarray.open_mfdataset
concat_dim
Same as in xarray.open_mfdataset
compat
Same as in xarray.open_mfdataset
preprocess
Same as in xarray.open_mfdataset
data_vars
Same as in xarray.open_mfdataset
coords
Same as in xarray.open_mfdataset
combine
Same as in xarray.open_mfdataset
parallel : 'dask', 'lithops', or False
Specify whether the open and preprocess steps of this function will be
performed in parallel using ``dask.delayed``, in parallel using ``lithops.map``, or in serial.
Default is False.
join
Same as in xarray.open_mfdataset
attrs_file
Same as in xarray.open_mfdataset
combine_attrs
Same as in xarray.open_mfdataset
**kwargs : optional
Additional arguments passed on to :py:func:`virtualizarr.open_virtual_dataset`. For an
overview of some of the possible options, see the documentation of
:py:func:`virtualizarr.open_virtual_dataset`.
Returns
-------
xarray.Dataset
Notes
-----
The results of opening each virtual dataset in parallel are sent back to the client process, so must not be too large.
"""

    # TODO this is practically all just copied from xarray.open_mfdataset - an argument for writing a virtualizarr engine for xarray?

    # TODO add options passed to open_virtual_dataset explicitly?

    paths = _find_absolute_paths(paths)

    if not paths:
        raise OSError("no files to open")

    paths1d: list[str]
    if combine == "nested":
        if isinstance(concat_dim, str | DataArray) or concat_dim is None:
            concat_dim = [concat_dim]  # type: ignore[assignment]

        # This creates a flat list which is easier to iterate over, whilst
        # encoding the originally-supplied structure as "ids".
        # The "ids" are not used at all if combine='by_coords'.
        combined_ids_paths = _infer_concat_order_from_positions(paths)
        ids, paths1d = (
            list(combined_ids_paths.keys()),
            list(combined_ids_paths.values()),
        )
    elif concat_dim is not None:
        raise ValueError(
            "When combine='by_coords', passing a value for `concat_dim` has no "
            "effect. To manually combine along a specific dimension you should "
            "instead specify combine='nested' along with a value for `concat_dim`.",
        )
    else:
        paths1d = paths  # type: ignore[assignment]

if parallel == "dask":
import dask

# wrap the open_dataset, getattr, and preprocess with delayed
open_ = dask.delayed(open_virtual_dataset)
getattr_ = dask.delayed(getattr)
if preprocess is not None:
preprocess = dask.delayed(preprocess)
elif parallel == "lithops":
raise NotImplementedError()
elif parallel is not False:
raise ValueError(
f"{parallel} is an invalid option for the keyword argument ``parallel``"
)
else:
open_ = open_virtual_dataset
getattr_ = getattr

datasets = [open_(p, **kwargs) for p in paths1d]
closers = [getattr_(ds, "_close") for ds in datasets]
if preprocess is not None:
datasets = [preprocess(ds) for ds in datasets]

if parallel == "dask":
# calling compute here will return the datasets/file_objs lists,
# the underlying datasets will still be stored as dask arrays
datasets, closers = dask.compute(datasets, closers)
elif parallel == "lithops":
raise NotImplementedError()

    # Combine all datasets, closing them in case of a ValueError
    try:
        if combine == "nested":
            # Combine the nested list by successive concat and merge operations
            # along each dimension, using structure given by "ids"
            combined = _nested_combine(
                datasets,
                concat_dims=concat_dim,
                compat=compat,
                data_vars=data_vars,
                coords=coords,
                ids=ids,
                join=join,
                combine_attrs=combine_attrs,
            )
        elif combine == "by_coords":
            # Redo ordering from coordinates, ignoring how they were ordered
            # previously
            combined = combine_by_coords(
                datasets,
                compat=compat,
                data_vars=data_vars,
                coords=coords,
                join=join,
                combine_attrs=combine_attrs,
            )
        else:
            raise ValueError(
                f"{combine} is an invalid option for the keyword argument"
                " ``combine``"
            )
    except ValueError:
        for ds in datasets:
            ds.close()
        raise

    combined.set_close(partial(_multi_file_closer, closers))

    # read global attributes from the attrs_file if one was provided
    if attrs_file is not None:
        if isinstance(attrs_file, os.PathLike):
            attrs_file = cast(str, os.fspath(attrs_file))
        combined.attrs = datasets[paths1d.index(attrs_file)].attrs

    # TODO should we just immediately close everything?
    # TODO We should have already read everything we're ever going to read into memory at this point

    return combined
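
A minimal usage sketch of the new function (not part of the commit itself). The file paths and the "time" dimension below are hypothetical, and the import assumes the function is reachable from virtualizarr.backend as defined in this diff:

from virtualizarr.backend import open_virtual_mfdataset

# Hypothetical NetCDF files tiling a single "time" dimension.
paths = ["day1.nc", "day2.nc", "day3.nc"]

# Default: combine_by_coords infers the order from coordinate values.
vds = open_virtual_mfdataset(paths, combine="by_coords")

# Explicit: combine_nested concatenates in list order along "time".
vds = open_virtual_mfdataset(paths, combine="nested", concat_dim="time")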

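With parallel="dask", each open_virtual_dataset call (and any preprocess) is wrapped in dask.delayed and the whole batch is evaluated by a single dask.compute; per the Notes in the docstring, only the lightweight virtual datasets travel back to the client. A sketch under the same hypothetical paths, with attrs_file shown as well:

from virtualizarr.backend import open_virtual_mfdataset

vds = open_virtual_mfdataset(
    ["day1.nc", "day2.nc"],
    combine="by_coords",
    parallel="dask",  # open + preprocess run as dask.delayed tasks
    attrs_file="day1.nc",  # global attrs taken from this file's dataset
)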

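For combine="nested", the helper _infer_concat_order_from_positions flattens an arbitrarily nested list of paths while recording each path's position as a tuple "tile id", which _nested_combine later uses to order the concatenation. A rough illustration of that mapping, based on xarray's internals rather than anything in this commit:

from xarray.core.combine import _infer_concat_order_from_positions

# A 2x2 nested grid of hypothetical file paths.
nested = [["x0y0.nc", "x0y1.nc"], ["x1y0.nc", "x1y1.nc"]]
combined_ids = _infer_concat_order_from_positions(nested)
# Expected: {(0, 0): "x0y0.nc", (0, 1): "x0y1.nc",
#            (1, 0): "x1y0.nc", (1, 1): "x1y1.nc"}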