Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix group kwarg #338

Merged
merged 21 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ def pytest_runtest_setup(item):
)


@pytest.fixture
def empty_netcdf4_file(tmpdir):
    """Write an empty xarray Dataset to a temporary NETCDF4 file and return its path."""
    filepath = f"{tmpdir}/empty.nc"
    # Persist a dataset with no variables so tests can exercise the empty-group case.
    dataset = xr.Dataset()
    dataset.to_netcdf(filepath, format="NETCDF4")
    dataset.close()
    return filepath


@pytest.fixture
def netcdf4_file(tmpdir):
# Set up example xarray dataset
Expand All @@ -37,6 +48,18 @@ def netcdf4_file(tmpdir):
return filepath


@pytest.fixture
def netcdf4_file_with_data_in_multiple_groups(tmpdir):
    """Create a netCDF4 file whose root group holds "foo" and whose "subgroup" holds "bar"."""
    filepath = str(tmpdir / "test.nc")

    # First write creates the file with the root-group variable ...
    root_ds = xr.DataArray([1, 2, 3], name="foo").to_dataset()
    root_ds.to_netcdf(filepath)
    # ... then append a second variable into a named subgroup of the same file.
    sub_ds = xr.DataArray([4, 5], name="bar").to_dataset()
    sub_ds.to_netcdf(filepath, group="subgroup", mode="a")

    return filepath


@pytest.fixture
def netcdf4_files_factory(tmpdir) -> callable:
def create_netcdf4_files(
Expand Down
4 changes: 3 additions & 1 deletion virtualizarr/readers/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def open_virtual_dataset(
filepath, inline_threshold=0, **reader_options
).translate()

refs = extract_group(refs, group)
# both group=None and group='' mean to read root group
if group:
refs = extract_group(refs, group)

virtual_vars, attrs, coord_names = virtual_vars_and_metadata_from_kerchunk_refs(
refs,
Expand Down
41 changes: 38 additions & 3 deletions virtualizarr/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import xarray as xr
import xarray.testing as xrt
from xarray import open_dataset
from xarray import Dataset, open_dataset
from xarray.core.indexes import Index

from virtualizarr import open_virtual_dataset
Expand Down Expand Up @@ -309,6 +309,43 @@ def test_virtualizarr_vs_local_nisar(self, hdf_backend):
xrt.assert_equal(dsXR, dsV)


@requires_kerchunk
def test_open_empty_group(empty_netcdf4_file):
    """Opening a file whose root group has no variables yields an empty Dataset."""
    vds = open_virtual_dataset(empty_netcdf4_file, indexes={})
    assert isinstance(vds, xr.Dataset)
    xrt.assert_identical(vds, Dataset())


@requires_kerchunk
class TestOpenVirtualDatasetHDFGroup:
    """Behaviour of the ``group`` kwarg when a file contains multiple HDF groups."""

    @staticmethod
    def _assert_single_manifest_var(vds, name, shape):
        # Each group in the fixture file exposes exactly one manifest-backed variable.
        assert list(vds.variables) == [name]
        assert isinstance(vds[name].data, ManifestArray)
        assert vds[name].shape == shape

    def test_open_subgroup(self, netcdf4_file_with_data_in_multiple_groups):
        vds = open_virtual_dataset(
            netcdf4_file_with_data_in_multiple_groups, group="subgroup", indexes={}
        )
        self._assert_single_manifest_var(vds, "bar", (2,))

    def test_open_root_group_manually(self, netcdf4_file_with_data_in_multiple_groups):
        # group="" is an explicit request for the root group.
        vds = open_virtual_dataset(
            netcdf4_file_with_data_in_multiple_groups, group="", indexes={}
        )
        self._assert_single_manifest_var(vds, "foo", (3,))

    def test_open_root_group_by_default(
        self, netcdf4_file_with_data_in_multiple_groups
    ):
        # Omitting group= should behave the same as group="".
        vds = open_virtual_dataset(
            netcdf4_file_with_data_in_multiple_groups, indexes={}
        )
        self._assert_single_manifest_var(vds, "foo", (3,))


@requires_kerchunk
class TestLoadVirtualDataset:
@pytest.mark.parametrize("hdf_backend", [HDF5VirtualBackend, HDFVirtualBackend])
Expand Down Expand Up @@ -356,8 +393,6 @@ def test_group_kwarg(self, hdf5_groups_file, hdf_backend):
hdf5_groups_file, group="doesnt_exist", backend=hdf_backend
)
if hdf_backend == HDF5VirtualBackend:
with pytest.raises(ValueError, match="Multiple HDF Groups found"):
open_virtual_dataset(hdf5_groups_file)
with pytest.raises(ValueError, match="not found in"):
open_virtual_dataset(hdf5_groups_file, group="doesnt_exist")

Expand Down
7 changes: 0 additions & 7 deletions virtualizarr/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from virtualizarr.tests import requires_kerchunk
from virtualizarr.translators.kerchunk import (
dataset_from_kerchunk_refs,
find_var_names,
)
from virtualizarr.zarr import ZArray

Expand Down Expand Up @@ -49,12 +48,6 @@ def test_kerchunk_roundtrip_in_memory_no_concat():
xrt.assert_equal(roundtrip, ds)


def test_no_duplicates_find_var_names():
"""Verify that we get a deduplicated list of var names"""
ref_dict = {"refs": {"x/something": {}, "x/otherthing": {}}}
assert len(find_var_names(ref_dict)) == 1


@requires_kerchunk
@pytest.mark.parametrize(
"inline_threshold, vars_to_inline",
Expand Down
72 changes: 39 additions & 33 deletions virtualizarr/translators/kerchunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,42 +43,40 @@ def virtual_vars_and_metadata_from_kerchunk_refs(
return virtual_vars, ds_attrs, coord_names


def extract_group(vds_refs: KerchunkStoreRefs, group: str) -> KerchunkStoreRefs:
    """
    Extract only the part of the kerchunk reference dict that is relevant to a single HDF group.

    Parameters
    ----------
    vds_refs : KerchunkStoreRefs
        Kerchunk references for the whole store. NOTE: the "refs" mapping is
        replaced in place with the filtered group references.
    group : str
        Name of the HDF group to extract. Should be a non-empty string; a
        leading "/" is stripped and a trailing "/" is appended so the name
        matches the kerchunk key prefixes.

    Returns
    -------
    KerchunkStoreRefs
        References containing only the keys for ``group``, with the group
        prefix removed from each key.

    Raises
    ------
    ValueError
        If ``group`` is not one of the groups present in the references.
    """
    # Every group in the store is marked by a "<prefix>.zgroup" key.
    hdf_groups = [
        k.removesuffix(".zgroup") for k in vds_refs["refs"].keys() if ".zgroup" in k
    ]

    # Ensure supplied group kwarg is consistent with kerchunk keys
    if not group.endswith("/"):
        group += "/"
    if group.startswith("/"):
        group = group.removeprefix("/")

    if group not in hdf_groups:
        raise ValueError(f'Group "{group}" not found in {hdf_groups}')

    # Filter by group prefix and remove prefix from all keys
    groupdict = {
        k.removeprefix(group): v
        for k, v in vds_refs["refs"].items()
        if k.startswith(group)
    }
    # Also remove group prefix from _ARRAY_DIMENSIONS
    for k, v in groupdict.items():
        if isinstance(v, str):
            groupdict[k] = v.replace("\\/", "/").replace(group, "")

    vds_refs["refs"] = groupdict

    return KerchunkStoreRefs(vds_refs)


def virtual_vars_from_kerchunk_refs(
Expand Down Expand Up @@ -222,9 +220,17 @@ def find_var_names(ds_reference_dict: KerchunkStoreRefs) -> list[str]:
"""Find the names of zarr variables in this store/group."""

refs = ds_reference_dict["refs"]
found_var_names = {key.split("/")[0] for key in refs.keys() if "/" in key}

return list(found_var_names)
found_var_names = []
for key in refs.keys():
# has to capture "foo/.zarray", but ignore ".zgroup", ".zattrs", and "subgroup/bar/.zarray"
# TODO this might be a sign that we should introduce a KerchunkGroupRefs type and cut down the references before getting to this point...
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be cleaner, but is a refactor that can be done afterwards.

if key not in (".zgroup", ".zattrs", ".zmetadata"):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@norlandrhagen the only reason I needed the ".zmetadata" key in the check here is to make one kerchunk parquet test pass. Is it really the case that kerchunk parquet references use different keys than kerchunk json references? Or is that just a mistake in the fake kerchunk parquet data that we create and use in

with open(tmp_path / "refs" / ".zmetadata") as f:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I think we were relying on the parquet directory to have a .zmetadata file so that we could identify the directory containing parquet's as parquet. #278 (comment)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay let's merge then.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I have another bug to fix first...

first_part, second_part, *_ = key.split("/")
if second_part == ".zarray":
found_var_names.append(first_part)

return found_var_names


def extract_array_refs(
Expand Down
Loading