diff --git a/conftest.py b/conftest.py
index 0be0a89c..bc1e6d47 100644
--- a/conftest.py
+++ b/conftest.py
@@ -48,6 +48,19 @@ def netcdf4_file(tmpdir):
     return filepath
 
 
+@pytest.fixture
+def chunked_netcdf4_file(tmpdir):
+    # Set up example xarray dataset
+    ds = xr.tutorial.open_dataset("air_temperature")
+
+    # Save it to disk as netCDF (in temporary directory)
+    filepath = f"{tmpdir}/air.nc"
+    ds.chunk(time=1460).to_netcdf(filepath, format="NETCDF4")
+    ds.close()
+
+    return filepath
+
+
 @pytest.fixture
 def netcdf4_file_with_data_in_multiple_groups(tmpdir):
     filepath = str(tmpdir / "test.nc")
diff --git a/virtualizarr/backend.py b/virtualizarr/backend.py
index cefa21ea..d6fc337b 100644
--- a/virtualizarr/backend.py
+++ b/virtualizarr/backend.py
@@ -255,7 +255,7 @@ def open_virtual_mfdataset(
     combine_attrs: "CombineAttrsOptions" = "override",
     **kwargs,
 ) -> Dataset:
-    """Open multiple files as a single virtual dataset
+    """Open multiple files as a single virtual dataset.
 
     If combine='by_coords' then the function ``combine_by_coords`` is used to combine
     the datasets into one before returning the result, and if combine='nested' then
@@ -307,7 +307,7 @@
 
     # TODO this is practically all just copied from xarray.open_mfdataset - an argument for writing a virtualizarr engine for xarray?
 
-    # TODO add options passed to open_virtual_dataset explicitly?
+    # TODO list kwargs passed to open_virtual_dataset explicitly?
 
     paths = _find_absolute_paths(paths)
 
diff --git a/virtualizarr/tests/test_backend.py b/virtualizarr/tests/test_backend.py
index 0c054a3b..ee1832ce 100644
--- a/virtualizarr/tests/test_backend.py
+++ b/virtualizarr/tests/test_backend.py
@@ -8,7 +8,7 @@
 from xarray import Dataset, open_dataset
 from xarray.core.indexes import Index
 
-from virtualizarr import open_virtual_dataset
+from virtualizarr import open_virtual_dataset, open_virtual_mfdataset
 from virtualizarr.backend import FileType, automatically_determine_filetype
 from virtualizarr.manifests import ManifestArray
 from virtualizarr.readers import HDF5VirtualBackend
@@ -440,6 +440,40 @@ def test_open_dataset_with_scalar(self, hdf5_scalar, tmpdir, hdf_backend):
         assert vds.scalar.attrs == {"scalar": "true"}
 
 
+class TestOpenVirtualMFDataset:
+    def test_serial(self, netcdf4_files_factory, chunked_netcdf4_file):
+        filepath1, filepath2 = netcdf4_files_factory()
+
+        combined_vds = open_virtual_mfdataset(
+            [filepath1, filepath2],
+            combine="nested",
+            concat_dim="time",
+            coords="minimal",
+            compat="override",
+            indexes={},
+        )
+        expected_vds = open_virtual_dataset(chunked_netcdf4_file, indexes={})
+        xrt.assert_identical(combined_vds, expected_vds)
+
+        combined_vds = open_virtual_mfdataset(
+            [filepath1, filepath2], combine="by_coords"
+        )
+        expected_vds = open_virtual_dataset(chunked_netcdf4_file)
+        xrt.assert_identical(combined_vds, expected_vds)
+
+        # build a glob pattern over the files in the same temporary directory
+        file_glob = str(filepath1.parent / "air*.nc")
+        combined_vds = open_virtual_mfdataset(file_glob, combine="by_coords")
+        expected_vds = open_virtual_dataset(chunked_netcdf4_file)
+        xrt.assert_identical(combined_vds, expected_vds)
+
+    # @requires_dask
+    def test_dask(self, netcdf4_files_factory): ...
+
+    # @requires_lithops
+    def test_lithops(self, netcdf4_files_factory): ...
+
+
 @requires_kerchunk
 @pytest.mark.parametrize(
     "reference_format",
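
For reference, a minimal usage sketch of the open_virtual_mfdataset entry point that the new test exercises. The file names below are illustrative placeholders, not part of the patch; the keyword arguments mirror the test_serial case above, and glob-string support follows from the _find_absolute_paths call shown in virtualizarr/backend.py:

    # Hypothetical usage sketch; "air1.nc", "air2.nc", and "air*.nc" are made-up paths.
    from virtualizarr import open_virtual_mfdataset

    # combine="nested" concatenates the per-file virtual datasets along an
    # explicitly named dimension; indexes={} skips building in-memory indexes,
    # keeping the combined result fully virtual.
    combined_vds = open_virtual_mfdataset(
        ["air1.nc", "air2.nc"],
        combine="nested",
        concat_dim="time",
        coords="minimal",
        compat="override",
        indexes={},
    )

    # combine="by_coords" instead infers the concatenation order from coordinate
    # values; a glob string is also accepted, resolved via _find_absolute_paths.
    combined_vds = open_virtual_mfdataset("air*.nc", combine="by_coords")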