Skip to content

Commit

Permalink
add 2d coords test
Browse files Browse the repository at this point in the history
  • Loading branch information
ayushnag committed Oct 30, 2024
1 parent 245bce0 commit fc53633
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
9 changes: 9 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def netcdf4_file(tmpdir):
return filepath


@pytest.fixture
def netcdf4_file_with_2d_coords(tmpdir):
ds = xr.tutorial.open_dataset("ROMS_example")
filepath = f"{tmpdir}/ROMS_example.nc"
ds.to_netcdf(filepath, format="NETCDF4")
ds.close()
return filepath


@pytest.fixture
def netcdf4_virtual_dataset(netcdf4_file):
from virtualizarr import open_virtual_dataset
Expand Down
3 changes: 2 additions & 1 deletion virtualizarr/readers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,10 @@ def separate_coords(
] = {}
found_coord_names: set[str] = set()
# Search through variable attributes for coordinate names
for name, var in vars.items():
for var in vars.values():
if "coordinates" in var.attrs:
found_coord_names.update(var.attrs["coordinates"].split(" "))
for name, var in vars.items():
if name in coord_names or var.dims == (name,) or name in found_coord_names:
# use workaround to avoid creating IndexVariables described here https://github.com/pydata/xarray/pull/8107#discussion_r1311214263
if len(var.dims) == 1:
Expand Down
17 changes: 17 additions & 0 deletions virtualizarr/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,23 @@ def test_coordinate_variable_attrs_preserved(self, netcdf4_file):
}


class TestDetermineCoords:
def test_determine_all_coords(self, netcdf4_file_with_2d_coords):
vds = open_virtual_dataset(netcdf4_file_with_2d_coords, indexes={})

expected_dimension_coords = ["ocean_time", "s_rho"]
expected_2d_coords = ["lon_rho", "lat_rho", "h"]
expected_1d_non_dimension_coords = ["Cs_r"]
expected_scalar_coords = ["hc", "Vtransform"]
expected_coords = (
expected_dimension_coords
+ expected_2d_coords
+ expected_1d_non_dimension_coords
+ expected_scalar_coords
)
assert set(vds.coords) == set(expected_coords)


@network
@requires_s3fs
class TestReadFromS3:
Expand Down

0 comments on commit fc53633

Please sign in to comment.