From fc536331c582b20fc351d20231f768198fb87621 Mon Sep 17 00:00:00 2001 From: ayushnag <35325113+ayushnag@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:26:37 -0700 Subject: [PATCH] add 2d coords test --- conftest.py | 9 +++++++++ virtualizarr/readers/common.py | 3 ++- virtualizarr/tests/test_backend.py | 17 +++++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index 810fd833..55c07823 100644 --- a/conftest.py +++ b/conftest.py @@ -35,6 +35,15 @@ def netcdf4_file(tmpdir): return filepath +@pytest.fixture +def netcdf4_file_with_2d_coords(tmpdir): + ds = xr.tutorial.open_dataset("ROMS_example") + filepath = f"{tmpdir}/ROMS_example.nc" + ds.to_netcdf(filepath, format="NETCDF4") + ds.close() + return filepath + + @pytest.fixture def netcdf4_virtual_dataset(netcdf4_file): from virtualizarr import open_virtual_dataset diff --git a/virtualizarr/readers/common.py b/virtualizarr/readers/common.py index ff616a04..0d03cf97 100644 --- a/virtualizarr/readers/common.py +++ b/virtualizarr/readers/common.py @@ -153,9 +153,10 @@ def separate_coords( ] = {} found_coord_names: set[str] = set() # Search through variable attributes for coordinate names - for name, var in vars.items(): + for var in vars.values(): if "coordinates" in var.attrs: found_coord_names.update(var.attrs["coordinates"].split(" ")) + for name, var in vars.items(): if name in coord_names or var.dims == (name,) or name in found_coord_names: # use workaround to avoid creating IndexVariables described here https://github.com/pydata/xarray/pull/8107#discussion_r1311214263 if len(var.dims) == 1: diff --git a/virtualizarr/tests/test_backend.py b/virtualizarr/tests/test_backend.py index 43a6bbd8..61f03bab 100644 --- a/virtualizarr/tests/test_backend.py +++ b/virtualizarr/tests/test_backend.py @@ -156,6 +156,23 @@ def test_coordinate_variable_attrs_preserved(self, netcdf4_file): } +class TestDetermineCoords: + def test_determine_all_coords(self, netcdf4_file_with_2d_coords): + vds = open_virtual_dataset(netcdf4_file_with_2d_coords, indexes={}) + + expected_dimension_coords = ["ocean_time", "s_rho"] + expected_2d_coords = ["lon_rho", "lat_rho", "h"] + expected_1d_non_dimension_coords = ["Cs_r"] + expected_scalar_coords = ["hc", "Vtransform"] + expected_coords = ( + expected_dimension_coords + + expected_2d_coords + + expected_1d_non_dimension_coords + + expected_scalar_coords + ) + assert set(vds.coords) == set(expected_coords) + + @network @requires_s3fs class TestReadFromS3: