diff --git a/.gitignore b/.gitignore index e1f3e35..bb66842 100644 --- a/.gitignore +++ b/.gitignore @@ -116,7 +116,14 @@ test_grid_visualize.png *.png *.jpg *.jpeg +*.gif public/ a.out *.o + +# editor backup files +# helix +\#*\# +# emacs +*~ diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index daba8bb..48c6cd6 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -16,8 +16,13 @@ from typing import Dict, Sequence import einops -import netCDF4 as nc import torch + +try: + import netCDF4 as nc +except ImportError: + nc = None + from scipy import spatial from earth2grid.spatial import ang2vec, haversine_distance @@ -59,6 +64,9 @@ def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder": class TempestRegridder(torch.nn.Module): def __init__(self, file_path): super().__init__() + if nc is None: + raise ImportError("netCDF4 not imported. Please install for this feature.") + dataset = nc.Dataset(file_path) self.lat = dataset["latc_b"][:] self.lon = dataset["lonc_b"][:] diff --git a/earth2grid/latlon.py b/earth2grid/latlon.py index 4b52669..c6abcbe 100644 --- a/earth2grid/latlon.py +++ b/earth2grid/latlon.py @@ -25,14 +25,18 @@ class LatLonGrid(base.Grid): - def __init__(self, lat: list[float], lon: list[float]): + def __init__(self, lat: list[float], lon: list[float], cylinder: bool = True): """ Args: lat: center of lat cells lon: center of lon cells + cylinder: if true, then lon is considered a periodic coordinate + on cylinder so that interpolation wraps around the edge. + Otherwise, it is assumed to be a finite plane. """ self._lat = lat self._lon = lon + self.cylinder = cylinder @property def lat(self): @@ -48,7 +52,7 @@ def shape(self): def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray): """Get regridder to the specified lat and lon points""" - return _RegridFromLatLon(self, lat, lon) + return _RegridFromLatLon(self, lat, lon, cylinder=self.cylinder) def _lonb(self): edges = (self.lon[1:] + self.lon[:-1]) / 2 @@ -78,15 +82,22 @@ def to_pyvista(self): class _RegridFromLatLon(torch.nn.Module): """Regrid from lat-lon to unstructured grid with bilinear interpolation""" - def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray): + def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray, cylinder: bool = True): + """ + Args: + cylinder: if True than lon is assumed to be periodic + """ super().__init__() + self.cylinder = cylinder lat, lon = np.broadcast_arrays(lat, lon) self.shape = lat.shape # TODO add device switching logic (maybe use torch registers for this # info) - long = np.concatenate([src.lon.ravel(), [360]], axis=-1) + long = src.lon.ravel() + if self.cylinder: + long = np.concatenate([long, [360]], axis=-1) long_t = torch.from_numpy(long) # flip the order latg since bilinear only works with increasing coordinate values @@ -104,7 +115,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # pad z in lon direction # only works for a global grid # TODO generalize this to local grids and add options for padding - x = torch.cat([x, x[..., 0:1]], axis=-1) + if self.cylinder: + x = torch.cat([x, x[..., 0:1]], axis=-1) out = self._bilinear(x) return out.view(out.shape[:-1] + self.shape) diff --git a/earth2grid/lcc.py b/earth2grid/lcc.py index 55146d0..7db6368 100644 --- a/earth2grid/lcc.py +++ b/earth2grid/lcc.py @@ -13,10 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np -import torch -from earth2grid import base -from earth2grid._regrid import BilinearInterpolator +from earth2grid import projections try: import pyvista as pv @@ -31,7 +29,10 @@ ] -class LambertConformalConicProjection: +LambertConformalConicGrid = projections.Grid + + +class LambertConformalConicProjection(projections.Projection): def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float): """ @@ -108,69 +109,6 @@ def inverse_project(self, x, y): HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0) -class LambertConformalConicGrid(base.Grid): - # nothing here is specific to the projection, so could be shared by any projected rectilinear grid - def __init__(self, projection: LambertConformalConicProjection, x, y): - """ - Args: - projection: LambertConformalConicProjection object - x: range of x values - y: range of y values - - """ - self.projection = projection - - self.x = np.array(x) - self.y = np.array(y) - - @property - def lat_lon(self): - mesh_x, mesh_y = np.meshgrid(self.x, self.y) - return self.projection.inverse_project(mesh_x, mesh_y) - - @property - def lat(self): - return self.lat_lon[0] - - @property - def lon(self): - return self.lat_lon[1] - - @property - def shape(self): - return (len(self.y), len(self.x)) - - def __getitem__(self, idxs): - yidxs, xidxs = idxs - return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs]) - - def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray): - """Get regridder to the specified lat and lon points""" - - x, y = self.projection.project(lat, lon) - - return BilinearInterpolator( - x_coords=torch.from_numpy(self.x), - y_coords=torch.from_numpy(self.y), - x_query=torch.from_numpy(x), - y_query=torch.from_numpy(y), - ) - - def visualize(self, data): - raise NotImplementedError() - - def to_pyvista(self): - if pv is None: - raise ImportError("Need to install pyvista") - - lat, lon = self.lat_lon - y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon)) - x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon)) - z = np.sin(np.deg2rad(lat)) - grid = pv.StructuredGrid(x, y, z) - return grid - - def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059): # coordinates of point in top-left corner lat0 = 21.138123 @@ -183,7 +121,7 @@ def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059): x = [x0 + i * scale for i in range(ix0, ix0 + nx)] y = [y0 + i * scale for i in range(iy0, iy0 + ny)] - return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y) + return projections.Grid(HRRR_CONUS_PROJECTION, x, y) # Grid used by HRRR CONUS (Continental US) data diff --git a/earth2grid/projections.py b/earth2grid/projections.py new file mode 100644 index 0000000..4a71dfc --- /dev/null +++ b/earth2grid/projections.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc + +import numpy as np +import torch + +from earth2grid import base +from earth2grid._regrid import BilinearInterpolator + +try: + import pyvista as pv +except ImportError: + pv = None + + +class Projection(abc.ABC): + @abc.abstractmethod + def project(self, lat: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Compute the projected x,y from lat,lon. + """ + pass + + @abc.abstractmethod + def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Compute the lat,lon from the projected x,y. + """ + pass + + +class Grid(base.Grid): + # nothing here is specific to the projection, so could be shared by any projected rectilinear grid + def __init__(self, projection: Projection, x, y): + """ + Args: + x: range of x values + y: range of y values + + """ + self.projection = projection + + self.x = np.array(x) + self.y = np.array(y) + + @property + def lat_lon(self): + mesh_x, mesh_y = np.meshgrid(self.x, self.y, indexing='ij') + return self.projection.inverse_project(mesh_x, mesh_y) + + @property + def lat(self): + return self.lat_lon[0] + + @property + def lon(self): + return self.lat_lon[1] + + @property + def shape(self): + return (len(self.x), len(self.y)) + + def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray): + """Get regridder to the specified lat and lon points""" + + x, y = self.projection.project(lat, lon) + + return BilinearInterpolator( + x_coords=torch.from_numpy(self.x), + y_coords=torch.from_numpy(self.y), + x_query=torch.from_numpy(x), + y_query=torch.from_numpy(y), + ) + + def visualize(self, data): + raise NotImplementedError() + + def to_pyvista(self): + if pv is None: + raise ImportError("Need to install pyvista") + + lat, lon = self.lat_lon + y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon)) + x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon)) + z = np.sin(np.deg2rad(lat)) + grid = pv.StructuredGrid(x, y, z) + return grid diff --git a/earth2grid/spatial.py b/earth2grid/spatial.py index 974108e..87a161b 100644 --- a/earth2grid/spatial.py +++ b/earth2grid/spatial.py @@ -44,3 +44,10 @@ def ang2vec(lon, lat): y = torch.cos(lat) * torch.sin(lon) z = torch.sin(lat) return (x, y, z) + + +def vec2ang(x, y, z): + """convert lon,lat in radians to cartesian coordinates""" + lat = torch.asin(z) + lon = torch.atan2(y, x) + return lon, lat diff --git a/earth2grid/yinyang.py b/earth2grid/yinyang.py new file mode 100644 index 0000000..258d13a --- /dev/null +++ b/earth2grid/yinyang.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Yin Yang + +the ying yang grid is an overset grid for the sphere containing two faces +- Yin: a normal lat lon grid for 2/3 of lon, and 2/3 of lat +- Yang: Yin but with pole along x + + +Key facts + +ying +lon: [-3 pi /4 - delta, 3 pi / 4 + delta ] +lat: [-pi / 4 - delta, pi / 4 + delta] + +ying to yang transformation: alpha = 0, beta = 90, gamma = 180 + +(x, y, z) - > (-x, z, y) + +""" +import math + +import numpy as np +import torch + +from earth2grid import latlon, projections, spatial + + +def Ying(nlat: int, nlon: int, delta: int): + """The ying grid + + nlat, and nlon are as in the latlon.equiangular_latlon_grid and + refer to full sphere. + + ``nlat`` includes the poles [90, -90], and ``nlon`` is [0, 2 pi). + + ``delta`` is the amount of overlap in terms of number of grid points. + + """ + # TODO test that min(lat) = -max(lat), and for lon too + + dlat = 180 / (nlat - 1) + dlon = 360 / nlon + + n = math.ceil(3 * nlon / 8) + lon = np.arange(-n - delta, n + delta + 1) * dlon + lat = np.arange(-(nlat - 1) // 4 - delta, (nlat + 1) // 4 + delta + 1) * dlat + + return latlon.LatLonGrid(lat.tolist(), lon.tolist(), cylinder=False) + + +class YangProjection(projections.Projection): + def project(self, lat: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Compute the projected x,y from lat,lon. + """ + lat = torch.from_numpy(lat) + lon = torch.from_numpy(lon) + + lat = torch.deg2rad(lat) + lon = torch.deg2rad(lon) + + x, y, z = spatial.ang2vec(lat=lat, lon=lon) + x, y, z = -x, z, y + lon, lat = spatial.vec2ang(x, y, z) + + lat = torch.rad2deg(lat) + lon = torch.rad2deg(lon) + + return lat.numpy(), lon.numpy() + + def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Compute the lat,lon from the projected x,y. + """ + # ying-yang is its own inverse + return self.project(x, y) + + +def Yang(nlat, nlon, delta): + ying = Ying(nlat, nlon, delta) + return projections.Grid(YangProjection(), ying.lat, ying.lon) diff --git a/examples/yinyang.py b/examples/yinyang.py new file mode 100644 index 0000000..a03500d --- /dev/null +++ b/examples/yinyang.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.from earth2grid.yinyang import Ying, Yang, YangProjection +import matplotlib.pyplot as plt +import numpy as np +import pyvista as pv +import torch + +from earth2grid.yinyang import Yang, Ying + +nlat = 721 +nlon = 1440 +delta = 64 + +nlat = 37 +nlon = 72 +delta = 0 + +ying = Ying(nlat, nlon, delta) +yang = Yang(nlat, nlon, delta) + + +def structured_grid(lon, lat): + y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon)) + x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon)) + z = np.sin(np.deg2rad(lat)) + grid = pv.StructuredGrid(x, y, z) + return grid + + +lon, lat = np.meshgrid(ying.lon, ying.lat) +ying_g = structured_grid(lon, lat) +yang_g = structured_grid(yang.lon, yang.lat) + +pl = pv.Plotter() +pl.add_mesh(ying_g, show_edges=True) +# scale slightly so yang is on top +pl.add_mesh(yang_g.scale(1.002), show_edges=True, color="red", opacity=0.5) +pl.show() + + +y2y = ying.get_bilinear_regridder_to(yang.lat, yang.lon) +y2y.float() + +x = torch.ones(ying.shape) +y = y2y(x) +y = y.reshape(yang.shape) +print("mask", torch.isnan(y).sum() / y.numel()) + +plt.figure() +# TODO fix yang.shape, it is the opposite it should be +plt.imshow(y.reshape(*ying.shape)) +plt.colorbar() +plt.show() diff --git a/pyproject.toml b/pyproject.toml index 0b54520..0b4892a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ classifiers = [ dependencies = [ "einops>=0.7.0", - "netCDF4>=1.6.5", "numpy>=1.23.3", "torch>=2.0.1", "scipy" @@ -35,6 +34,9 @@ dependencies = [ [project.optional-dependencies] +all = [ + "netCDF4>=1.6.5", +] viz = [ "pyvista>=0.43.2", "matplotlib", diff --git a/tests/test_spatial.py b/tests/test_spatial.py new file mode 100644 index 0000000..ffb772f --- /dev/null +++ b/tests/test_spatial.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import pytest +import torch + +from earth2grid import spatial + + +def test_vec2ang2vec(): + vec = torch.randn(3) + vec /= torch.norm(vec) + x, y, z = vec + + lon, lat = spatial.vec2ang(x, y, z) + x1, y1, z1 = spatial.ang2vec(lon, lat) + assert torch.allclose(torch.stack([x1, y1, z1]), torch.stack([x, y, z])) + + +def test_vec2ang(): + lon, lat = spatial.vec2ang(torch.tensor(0), torch.tensor(0), torch.tensor(1)) + assert lat == pytest.approx(math.pi / 2) + + lon, _ = spatial.vec2ang(torch.tensor(1), torch.tensor(0), torch.tensor(0)) + assert lon == pytest.approx(0) + + lon, _ = spatial.vec2ang(torch.tensor(0), torch.tensor(1), torch.tensor(0)) + assert lon == pytest.approx(math.pi / 2) diff --git a/tests/test_yingyang.py b/tests/test_yingyang.py new file mode 100644 index 0000000..cdae764 --- /dev/null +++ b/tests/test_yingyang.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.from earth2grid.yinyang import Ying, Yang, YangProjection +import numpy as np +import pytest +import torch + +from earth2grid.yinyang import Yang, Ying + + +def test_yingyang(): + nlat = 721 + nlon = 1440 + delta = 64 + + nlat = 37 + nlon = 72 + delta = 0 + + ying = Ying(nlat, nlon, delta) + yang = Yang(nlat, nlon, delta) + + assert ying.lat.min() == pytest.approx(-45) + assert ying.lat.max() == pytest.approx(45) + assert ying.lat.min() == -ying.lat.max() + assert ying.lon.min() == -ying.lon.max() + y2y = ying.get_bilinear_regridder_to(yang.lat, yang.lon) + y2y.float() + + x = torch.ones(ying.shape) + y = y2y(x) + mask = ~torch.isnan(y) + # this is a regression check. will need to verify and change for different res + fraction_missing = 1 - mask.sum().item() / mask.numel() + assert fraction_missing == pytest.approx(0.8038, abs=0.01) + assert torch.allclose(y[mask], torch.tensor(1).float()) + + # more complex check + lat, lon = np.meshgrid(ying.lat, ying.lon, indexing='ij') + x = torch.as_tensor(lat, dtype=torch.float).deg2rad().cos() + y = y2y(x) + expected = torch.as_tensor(yang.lat).deg2rad().cos().float() + assert torch.allclose(y[mask], expected[mask], atol=0.01)