diff --git a/conftest.py b/conftest.py index 810fd833..55c07823 100644 --- a/conftest.py +++ b/conftest.py @@ -35,6 +35,15 @@ def netcdf4_file(tmpdir): return filepath +@pytest.fixture +def netcdf4_file_with_2d_coords(tmpdir): + ds = xr.tutorial.open_dataset("ROMS_example") + filepath = f"{tmpdir}/ROMS_example.nc" + ds.to_netcdf(filepath, format="NETCDF4") + ds.close() + return filepath + + @pytest.fixture def netcdf4_virtual_dataset(netcdf4_file): from virtualizarr import open_virtual_dataset diff --git a/virtualizarr/readers/common.py b/virtualizarr/readers/common.py index f6f5dff4..646d26ca 100644 --- a/virtualizarr/readers/common.py +++ b/virtualizarr/readers/common.py @@ -144,8 +144,13 @@ def separate_coords( coord_vars: dict[ str, tuple[Hashable, Any, dict[Any, Any], dict[Any, Any]] | Variable ] = {} + found_coord_names: set[str] = set() + # Search through variable attributes for coordinate names + for var in vars.values(): + if "coordinates" in var.attrs: + found_coord_names.update(var.attrs["coordinates"].split(" ")) for name, var in vars.items(): - if name in coord_names or var.dims == (name,): + if name in coord_names or var.dims == (name,) or name in found_coord_names: # use workaround to avoid creating IndexVariables described here https://github.com/pydata/xarray/pull/8107#discussion_r1311214263 if len(var.dims) == 1: dim1d, *_ = var.dims diff --git a/virtualizarr/tests/test_backend.py b/virtualizarr/tests/test_backend.py index 43a6bbd8..e9b60814 100644 --- a/virtualizarr/tests/test_backend.py +++ b/virtualizarr/tests/test_backend.py @@ -156,6 +156,28 @@ def test_coordinate_variable_attrs_preserved(self, netcdf4_file): } +@requires_kerchunk +class TestDetermineCoords: + def test_infer_one_dimensional_coords(self, netcdf4_file): + vds = open_virtual_dataset(netcdf4_file, indexes={}) + assert set(vds.coords) == {"time", "lat", "lon"} + + def test_var_attr_coords(self, netcdf4_file_with_2d_coords): + vds = open_virtual_dataset(netcdf4_file_with_2d_coords, indexes={}) + + expected_dimension_coords = ["ocean_time", "s_rho"] + expected_2d_coords = ["lon_rho", "lat_rho", "h"] + expected_1d_non_dimension_coords = ["Cs_r"] + expected_scalar_coords = ["hc", "Vtransform"] + expected_coords = ( + expected_dimension_coords + + expected_2d_coords + + expected_1d_non_dimension_coords + + expected_scalar_coords + ) + assert set(vds.coords) == set(expected_coords) + + @network @requires_s3fs class TestReadFromS3: