Skip to content

Commit

Permalink
Merge pull request #160 from zmoon/lon-norm
Browse files Browse the repository at this point in the history
Longitude normalization
  • Loading branch information
zmoon authored Nov 21, 2024
2 parents 987a1fe + 1dc5a8b commit 92f9959
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 86 deletions.
157 changes: 78 additions & 79 deletions monet/monet_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def _monet_to_latlon(da):
return dset


def _dataset_to_monet(dset, lat_name="latitude", lon_name="longitude", latlon2d=False):
def _dataset_to_monet(
dset,
lat_name="latitude",
lon_name="longitude",
latlon2d=None,
lon180=None,
):
"""Rename xarray DataArray or Dataset coordinate variables for use with monet functions,
returning a new xarray object.
Expand All @@ -74,73 +80,68 @@ def _dataset_to_monet(dset, lat_name="latitude", lon_name="longitude", latlon2d=
Name of the latitude array.
lon_name : str
Name of the longitude array.
latlon2d : bool
latlon2d : bool, optional
Whether the latitude and longitude data is two-dimensional.
If unset (``None``), guess based on dim count.
lon180 : bool, optional
Whether the longitude values are in the range [-180, 180) already.
If true, longitude wrapping/normalization,
which can introduce small floating point errors, will be skipped.
If unset (``None``), compute min/max to determine.
"""
if "grid_xt" in dset.dims:
# GFS v16 file
try:
if isinstance(dset, xr.DataArray):
dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt")
elif isinstance(dset, xr.Dataset):
dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt")
else:
raise ValueError
except ValueError:
print("dset must be an xarray.DataArray or xarray.Dataset")
if not isinstance(dset, (xr.DataArray, xr.Dataset)):
raise TypeError("dset must be an xarray.DataArray or xarray.Dataset")

if "grid_xt" in dset.dims: # GFS v16 file
if isinstance(dset, xr.DataArray):
dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt")
elif isinstance(dset, xr.Dataset):
dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt")

if "south_north" in dset.dims: # WRF WPS file
dset = dset.rename(dict(south_north="y", west_east="x"))
try:
if isinstance(dset, xr.Dataset):
if "XLAT_M" in dset.data_vars:
dset["XLAT_M"] = dset.XLAT_M.squeeze()
dset["XLONG_M"] = dset.XLONG_M.squeeze()
dset = dset.set_coords(["XLAT_M", "XLONG_M"])
elif "XLAT" in dset.data_vars:
dset["XLAT"] = dset.XLAT.squeeze()
dset["XLONG"] = dset.XLONG.squeeze()
dset = dset.set_coords(["XLAT", "XLONG"])
elif isinstance(dset, xr.DataArray):
if "XLAT_M" in dset.coords:
dset["XLAT_M"] = dset.XLAT_M.squeeze()
dset["XLONG_M"] = dset.XLONG_M.squeeze()
elif "XLAT" in dset.coords:
dset["XLAT"] = dset.XLAT.squeeze()
dset["XLONG"] = dset.XLONG.squeeze()
else:
raise ValueError
except ValueError:
print("dset must be an Xarray.DataArray or Xarray.Dataset")
if isinstance(dset, xr.Dataset):
if "XLAT_M" in dset.data_vars:
dset["XLAT_M"] = dset.XLAT_M.squeeze()
dset["XLONG_M"] = dset.XLONG_M.squeeze()
dset = dset.set_coords(["XLAT_M", "XLONG_M"])
elif "XLAT" in dset.data_vars:
dset["XLAT"] = dset.XLAT.squeeze()
dset["XLONG"] = dset.XLONG.squeeze()
dset = dset.set_coords(["XLAT", "XLONG"])
elif isinstance(dset, xr.DataArray):
if "XLAT_M" in dset.coords:
dset["XLAT_M"] = dset.XLAT_M.squeeze()
dset["XLONG_M"] = dset.XLONG_M.squeeze()
elif "XLAT" in dset.coords:
dset["XLAT"] = dset.XLAT.squeeze()
dset["XLONG"] = dset.XLONG.squeeze()

# Rename lat/lon coordinates to 'latitude'/'longitude'
dset = _rename_to_monet_latlon(dset) # common cases
if (isinstance(dset, xr.Dataset) and not {"latitude", "longitude"} <= set(dset.variables)) or (
isinstance(dset, xr.DataArray) and not {"latitude", "longitude"} <= set(dset.coords)
):
dset = dset.rename({lat_name: "latitude", lon_name: "longitude"})

# Unstructured Grid
# lat & lon are not coordinate variables in unstructured grid
if dset.attrs.get("mio_has_unstructured_grid", False):
# only call rename and wrap_longitudes
dset = _rename_to_monet_latlon(dset)
# Maybe wrap longitudes
if lon180 is None:
lon180 = dset["longitude"].min() >= -180 and dset["longitude"].max() < 180
if not lon180:
dset["longitude"] = wrap_longitudes(dset["longitude"])

else:
dset = _rename_to_monet_latlon(dset)
latlon2d = True
# print(len(dset[lat_name].shape))
# print(dset)
if len(dset[lat_name].shape) < 2:
# print(dset[lat_name].shape)
latlon2d = False
if latlon2d is False:
try:
if isinstance(dset, xr.DataArray):
dset = _dataarray_coards_to_netcdf(dset, lat_name=lat_name, lon_name=lon_name)
elif isinstance(dset, xr.Dataset):
dset = _coards_to_netcdf(dset, lat_name=lat_name, lon_name=lon_name)
else:
raise ValueError
except ValueError:
print("dset must be an Xarray.DataArray or Xarray.Dataset")
else:
dset = _rename_to_monet_latlon(dset)
dset["longitude"] = wrap_longitudes(dset["longitude"])
# lat & lon are not coordinate variables in unstructured grid, so we're done
if dset.attrs.get("mio_has_unstructured_grid", False):
return dset

# Maybe convert 1-D lat/lon coords to 2-D
if latlon2d is None:
latlon2d = dset["latitude"].ndim >= 2
if not latlon2d:
if isinstance(dset, xr.DataArray):
dset = _dataarray_coards_to_netcdf(dset, lat_name="latitude", lon_name="longitude")
elif isinstance(dset, xr.Dataset):
dset = _coards_to_netcdf(dset, lat_name="latitude", lon_name="longitude")

return dset

Expand Down Expand Up @@ -171,7 +172,7 @@ def _rename_to_monet_latlon(ds):
elif "XLAT" in check_list:
return ds.rename({"XLAT": "latitude", "XLONG": "longitude"})
else:
return ds
return ds.copy()


def _coards_to_netcdf(dset, lat_name="lat", lon_name="lon"):
Expand All @@ -189,7 +190,7 @@ def _coards_to_netcdf(dset, lat_name="lat", lon_name="lon"):
"""
from numpy import arange, meshgrid

lon = wrap_longitudes(dset[lon_name])
lon = dset[lon_name]
lat = dset[lat_name]
lons, lats = meshgrid(lon, lat)
x = arange(len(lon))
Expand Down Expand Up @@ -218,7 +219,7 @@ def _dataarray_coards_to_netcdf(dset, lat_name="lat", lon_name="lon"):
"""
from numpy import arange, meshgrid

lon = wrap_longitudes(dset[lon_name])
lon = dset[lon_name]
lat = dset[lat_name]
lons, lats = meshgrid(lon, lat)
x = arange(len(lon))
Expand Down Expand Up @@ -1191,7 +1192,7 @@ def _get_CoordinateDefinition(self, data=None):
g = geo.CoordinateDefinition(lats=self._obj.latitude, lons=self._obj.longitude)
return g

def remap_nearest(self, data, **kwargs):
def remap_nearest(self, data, radius_of_influence=1e6, **kwargs):
"""Remap `data` from another grid to the current self grid using pyresample
nearest-neighbor interpolation.
Expand All @@ -1213,16 +1214,20 @@ def remap_nearest(self, data, **kwargs):

# from .grids import get_generic_projection_from_proj4
# check to see if grid is supplied

source_data = _dataset_to_monet(data)
target_data = _dataset_to_monet(self._obj)
source = self._get_CoordinateDefinition(data=source_data)
target = self._get_CoordinateDefinition(data=target_data)
r = kd_tree.XArrayResamplerNN(source, target, **kwargs)
source = self._get_CoordinateDefinition(source_data)
target = self._get_CoordinateDefinition(target_data)
r = kd_tree.XArrayResamplerNN(
source, target, radius_of_influence=radius_of_influence, **kwargs
)
r.get_neighbour_info()
if isinstance(source_data, xr.DataArray):
result = r.get_sample_from_neighbour_info(source_data)
result.name = source_data.name
result["latitude"] = target_data.latitude
result["longitude"] = target_data.longitude

elif isinstance(source_data, xr.Dataset):
results = {}
Expand Down Expand Up @@ -1504,7 +1509,7 @@ def _get_CoordinateDefinition(self, data=None):
g = geo.CoordinateDefinition(lats=self._obj.latitude, lons=self._obj.longitude)
return g

def remap_nearest(self, data, radius_of_influence=1e6):
def remap_nearest(self, data, radius_of_influence=1e6, **kwargs):
"""Remap `data` from another grid to the current self grid using pyresample
nearest-neighbor interpolation.
Expand All @@ -1525,26 +1530,20 @@ def remap_nearest(self, data, radius_of_influence=1e6):

# from .grids import get_generic_projection_from_proj4
# check to see if grid is supplied
try:
check_error = False
if isinstance(data, xr.DataArray) or isinstance(data, xr.Dataset):
check_error = False
else:
check_error = True
if check_error:
raise TypeError
except TypeError:
print("data must be either an Xarray.DataArray or Xarray.Dataset")

source_data = _dataset_to_monet(data)
target_data = _dataset_to_monet(self._obj)
source = self._get_CoordinateDefinition(source_data)
target = self._get_CoordinateDefinition(target_data)
r = kd_tree.XArrayResamplerNN(source, target, radius_of_influence=radius_of_influence)
r = kd_tree.XArrayResamplerNN(
source, target, radius_of_influence=radius_of_influence, **kwargs
)
r.get_neighbour_info()
if isinstance(source_data, xr.DataArray):
result = r.get_sample_from_neighbour_info(source_data)
result.name = source_data.name
result["latitude"] = target_data.latitude
result["longitude"] = target_data.longitude

elif isinstance(source_data, xr.Dataset):
results = {}
Expand Down
11 changes: 6 additions & 5 deletions monet/util/combinetool.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def combine_da_to_da(source, target, *, merge=True, interp_time=False, **kwargs)
----------
source : xarray.DataArray or xarray.Dataset
Gridded data.
target : xarray.DataArray
target : xarray.DataArray or xarray.Dataset
Point observations.
merge : bool
If false, only return the interpolated source data.
Expand All @@ -87,13 +87,14 @@ def combine_da_to_da(source, target, *, merge=True, interp_time=False, **kwargs)
"""
from ..monet_accessor import _dataset_to_monet

target_fixed = _dataset_to_monet(target)
source_fixed = _dataset_to_monet(source)
output = target_fixed.monet.remap_nearest(source_fixed, **kwargs)
output = target.monet.remap_nearest(source, **kwargs)

if interp_time:
output = output.interp(time=target.time)

if merge:
output = xr.merge([target_fixed, output])
output = xr.merge([_dataset_to_monet(target), output])

return output


Expand Down
11 changes: 9 additions & 2 deletions tests/test_remap.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def test_combine_da_da():
},
)

# Longitude normalization introduces floating point error
x_ = (x + 180) % 360 - 180
assert not (x_ == x).any()
assert np.abs(x_ - x).max() < 5e-14

# Combine (find closest model grid cell to each obs point)
# NOTE: to use `merge`, must have matching `level` dims
new = combine_da_to_da(model, obs, merge=False, interp_time=False)
Expand All @@ -100,8 +105,10 @@ def test_combine_da_da():
assert float(new.longitude.max()) == pytest.approx(0.9)
assert float(new.latitude.min()) == pytest.approx(0.1)
assert float(new.latitude.max()) == pytest.approx(0.9)
assert (new.latitude.isel(x=0).values == obs.latitude.values).all()
assert np.allclose(new.longitude.isel(y=0).values, obs.longitude.values)

assert (obs.longitude.values == x).all(), "preserved"
assert (new.latitude.isel(x=0).values == obs.latitude.values).all(), "same as target"
assert (new.longitude.isel(y=0).values == obs.longitude.values).all(), "same as target"

# Use orthogonal selection to get track
a = new.data.values[:, new.y, new.x]
Expand Down

0 comments on commit 92f9959

Please sign in to comment.