Skip to content

Commit

Permalink
Add ying yang grid
Browse files Browse the repository at this point in the history
This PR adds a new projection based grid, the ying yang grid. It
restructures some of the lambert conformal logic a bit, so Simon should
take a look.
  • Loading branch information
nbren12 committed Dec 13, 2024
1 parent c9cb58d commit 81b8a47
Show file tree
Hide file tree
Showing 11 changed files with 403 additions and 75 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,14 @@ test_grid_visualize.png
*.png
*.jpg
*.jpeg
*.gif
public/

a.out
*.o

# editor backup files
# helix
\#*\#
# emacs
*~
10 changes: 9 additions & 1 deletion earth2grid/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
from typing import Dict, Sequence

import einops
import netCDF4 as nc
import torch

try:
import netCDF4 as nc
except ImportError:
nc = None

from scipy import spatial

from earth2grid.spatial import ang2vec, haversine_distance
Expand Down Expand Up @@ -59,6 +64,9 @@ def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder":
class TempestRegridder(torch.nn.Module):
def __init__(self, file_path):
super().__init__()
if nc is None:
raise ImportError("netCDF4 not imported. Please install for this feature.")

dataset = nc.Dataset(file_path)
self.lat = dataset["latc_b"][:]
self.lon = dataset["lonc_b"][:]
Expand Down
22 changes: 17 additions & 5 deletions earth2grid/latlon.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@


class LatLonGrid(base.Grid):
def __init__(self, lat: list[float], lon: list[float]):
def __init__(self, lat: list[float], lon: list[float], cylinder: bool = True):
"""
Args:
lat: center of lat cells
lon: center of lon cells
cylinder: if true, then lon is considered a periodic coordinate
on cylinder so that interpolation wraps around the edge.
Otherwise, it is assumed to be a finite plane.
"""
self._lat = lat
self._lon = lon
self.cylinder = cylinder

@property
def lat(self):
Expand All @@ -48,7 +52,7 @@ def shape(self):

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""
return _RegridFromLatLon(self, lat, lon)
return _RegridFromLatLon(self, lat, lon, cylinder=self.cylinder)

def _lonb(self):
edges = (self.lon[1:] + self.lon[:-1]) / 2
Expand Down Expand Up @@ -78,15 +82,22 @@ def to_pyvista(self):
class _RegridFromLatLon(torch.nn.Module):
"""Regrid from lat-lon to unstructured grid with bilinear interpolation"""

def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray):
def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray, cylinder: bool = True):
"""
Args:
cylinder: if True than lon is assumed to be periodic
"""
super().__init__()
self.cylinder = cylinder

lat, lon = np.broadcast_arrays(lat, lon)
self.shape = lat.shape

# TODO add device switching logic (maybe use torch registers for this
# info)
long = np.concatenate([src.lon.ravel(), [360]], axis=-1)
long = src.lon.ravel()
if self.cylinder:
long = np.concatenate([long, [360]], axis=-1)
long_t = torch.from_numpy(long)

# flip the order latg since bilinear only works with increasing coordinate values
Expand All @@ -104,7 +115,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# pad z in lon direction
# only works for a global grid
# TODO generalize this to local grids and add options for padding
x = torch.cat([x, x[..., 0:1]], axis=-1)
if self.cylinder:
x = torch.cat([x, x[..., 0:1]], axis=-1)
out = self._bilinear(x)
return out.view(out.shape[:-1] + self.shape)

Expand Down
74 changes: 6 additions & 68 deletions earth2grid/lcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch

from earth2grid import base
from earth2grid._regrid import BilinearInterpolator
from earth2grid import projections

try:
import pyvista as pv
Expand All @@ -31,7 +29,10 @@
]


class LambertConformalConicProjection:
LambertConformalConicGrid = projections.Grid


class LambertConformalConicProjection(projections.Projection):
def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
"""
Expand Down Expand Up @@ -108,69 +109,6 @@ def inverse_project(self, x, y):
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0)


class LambertConformalConicGrid(base.Grid):
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
def __init__(self, projection: LambertConformalConicProjection, x, y):
"""
Args:
projection: LambertConformalConicProjection object
x: range of x values
y: range of y values
"""
self.projection = projection

self.x = np.array(x)
self.y = np.array(y)

@property
def lat_lon(self):
mesh_x, mesh_y = np.meshgrid(self.x, self.y)
return self.projection.inverse_project(mesh_x, mesh_y)

@property
def lat(self):
return self.lat_lon[0]

@property
def lon(self):
return self.lat_lon[1]

@property
def shape(self):
return (len(self.y), len(self.x))

def __getitem__(self, idxs):
yidxs, xidxs = idxs
return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs])

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""

x, y = self.projection.project(lat, lon)

return BilinearInterpolator(
x_coords=torch.from_numpy(self.x),
y_coords=torch.from_numpy(self.y),
x_query=torch.from_numpy(x),
y_query=torch.from_numpy(y),
)

def visualize(self, data):
raise NotImplementedError()

def to_pyvista(self):
if pv is None:
raise ImportError("Need to install pyvista")

lat, lon = self.lat_lon
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
z = np.sin(np.deg2rad(lat))
grid = pv.StructuredGrid(x, y, z)
return grid


def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
# coordinates of point in top-left corner
lat0 = 21.138123
Expand All @@ -183,7 +121,7 @@ def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
x = [x0 + i * scale for i in range(ix0, ix0 + nx)]
y = [y0 + i * scale for i in range(iy0, iy0 + ny)]

return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)
return projections.Grid(HRRR_CONUS_PROJECTION, x, y)


# Grid used by HRRR CONUS (Continental US) data
Expand Down
100 changes: 100 additions & 0 deletions earth2grid/projections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc

import numpy as np
import torch

from earth2grid import base
from earth2grid._regrid import BilinearInterpolator

try:
import pyvista as pv
except ImportError:
pv = None


class Projection(abc.ABC):
@abc.abstractmethod
def project(self, lat: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the projected x,y from lat,lon.
"""
pass

@abc.abstractmethod
def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the lat,lon from the projected x,y.
"""
pass


class Grid(base.Grid):
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
def __init__(self, projection: Projection, x, y):
"""
Args:
x: range of x values
y: range of y values
"""
self.projection = projection

self.x = np.array(x)
self.y = np.array(y)

@property
def lat_lon(self):
mesh_x, mesh_y = np.meshgrid(self.x, self.y, indexing='ij')
return self.projection.inverse_project(mesh_x, mesh_y)

@property
def lat(self):
return self.lat_lon[0]

@property
def lon(self):
return self.lat_lon[1]

@property
def shape(self):
return (len(self.x), len(self.y))

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""

x, y = self.projection.project(lat, lon)

return BilinearInterpolator(
x_coords=torch.from_numpy(self.x),
y_coords=torch.from_numpy(self.y),
x_query=torch.from_numpy(x),
y_query=torch.from_numpy(y),
)

def visualize(self, data):
raise NotImplementedError()

def to_pyvista(self):
if pv is None:
raise ImportError("Need to install pyvista")

lat, lon = self.lat_lon
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
z = np.sin(np.deg2rad(lat))
grid = pv.StructuredGrid(x, y, z)
return grid
7 changes: 7 additions & 0 deletions earth2grid/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,10 @@ def ang2vec(lon, lat):
y = torch.cos(lat) * torch.sin(lon)
z = torch.sin(lat)
return (x, y, z)


def vec2ang(x, y, z):
"""convert lon,lat in radians to cartesian coordinates"""
lat = torch.asin(z)
lon = torch.atan2(y, x)
return lon, lat
Loading

0 comments on commit 81b8a47

Please sign in to comment.