Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ying yang grid #20

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,14 @@ test_grid_visualize.png
*.png
*.jpg
*.jpeg
*.gif
public/

a.out
*.o

# editor backup files
# helix
\#*\#
# emacs
*~
10 changes: 9 additions & 1 deletion earth2grid/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
from typing import Dict, Sequence

import einops
import netCDF4 as nc
import torch

try:
import netCDF4 as nc
except ImportError:
nc = None

from scipy import spatial

from earth2grid.spatial import ang2vec, haversine_distance
Expand Down Expand Up @@ -59,6 +64,9 @@ def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder":
class TempestRegridder(torch.nn.Module):
def __init__(self, file_path):
super().__init__()
if nc is None:
raise ImportError("netCDF4 not imported. Please install for this feature.")

dataset = nc.Dataset(file_path)
self.lat = dataset["latc_b"][:]
self.lon = dataset["lonc_b"][:]
Expand Down
22 changes: 17 additions & 5 deletions earth2grid/latlon.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@


class LatLonGrid(base.Grid):
def __init__(self, lat: list[float], lon: list[float]):
def __init__(self, lat: list[float], lon: list[float], cylinder: bool = True):
"""
Args:
lat: center of lat cells
lon: center of lon cells
cylinder: if true, then lon is considered a periodic coordinate
on cylinder so that interpolation wraps around the edge.
Otherwise, it is assumed to be a finite plane.
"""
self._lat = lat
self._lon = lon
self.cylinder = cylinder

@property
def lat(self):
Expand All @@ -48,7 +52,7 @@ def shape(self):

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""
return _RegridFromLatLon(self, lat, lon)
return _RegridFromLatLon(self, lat, lon, cylinder=self.cylinder)

def _lonb(self):
edges = (self.lon[1:] + self.lon[:-1]) / 2
Expand Down Expand Up @@ -78,15 +82,22 @@ def to_pyvista(self):
class _RegridFromLatLon(torch.nn.Module):
"""Regrid from lat-lon to unstructured grid with bilinear interpolation"""

def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray):
def __init__(self, src: LatLonGrid, lat: np.ndarray, lon: np.ndarray, cylinder: bool = True):
"""
Args:
cylinder: if True than lon is assumed to be periodic
"""
super().__init__()
self.cylinder = cylinder

lat, lon = np.broadcast_arrays(lat, lon)
self.shape = lat.shape

# TODO add device switching logic (maybe use torch registers for this
# info)
long = np.concatenate([src.lon.ravel(), [360]], axis=-1)
long = src.lon.ravel()
if self.cylinder:
long = np.concatenate([long, [360]], axis=-1)
long_t = torch.from_numpy(long)

# flip the order latg since bilinear only works with increasing coordinate values
Expand All @@ -104,7 +115,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# pad z in lon direction
# only works for a global grid
# TODO generalize this to local grids and add options for padding
x = torch.cat([x, x[..., 0:1]], axis=-1)
if self.cylinder:
x = torch.cat([x, x[..., 0:1]], axis=-1)
out = self._bilinear(x)
return out.view(out.shape[:-1] + self.shape)

Expand Down
74 changes: 6 additions & 68 deletions earth2grid/lcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch

from earth2grid import base
from earth2grid._regrid import BilinearInterpolator
from earth2grid import projections

try:
import pyvista as pv
Expand All @@ -31,7 +29,10 @@
]


class LambertConformalConicProjection:
LambertConformalConicGrid = projections.Grid


class LambertConformalConicProjection(projections.Projection):
def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
"""

Expand Down Expand Up @@ -108,69 +109,6 @@ def inverse_project(self, x, y):
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0)


class LambertConformalConicGrid(base.Grid):
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
def __init__(self, projection: LambertConformalConicProjection, x, y):
"""
Args:
projection: LambertConformalConicProjection object
x: range of x values
y: range of y values

"""
self.projection = projection

self.x = np.array(x)
self.y = np.array(y)

@property
def lat_lon(self):
mesh_x, mesh_y = np.meshgrid(self.x, self.y)
return self.projection.inverse_project(mesh_x, mesh_y)

@property
def lat(self):
return self.lat_lon[0]

@property
def lon(self):
return self.lat_lon[1]

@property
def shape(self):
return (len(self.y), len(self.x))

def __getitem__(self, idxs):
yidxs, xidxs = idxs
return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs])

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""

x, y = self.projection.project(lat, lon)

return BilinearInterpolator(
x_coords=torch.from_numpy(self.x),
y_coords=torch.from_numpy(self.y),
x_query=torch.from_numpy(x),
y_query=torch.from_numpy(y),
)

def visualize(self, data):
raise NotImplementedError()

def to_pyvista(self):
if pv is None:
raise ImportError("Need to install pyvista")

lat, lon = self.lat_lon
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
z = np.sin(np.deg2rad(lat))
grid = pv.StructuredGrid(x, y, z)
return grid


def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
# coordinates of point in top-left corner
lat0 = 21.138123
Expand All @@ -183,7 +121,7 @@ def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
x = [x0 + i * scale for i in range(ix0, ix0 + nx)]
y = [y0 + i * scale for i in range(iy0, iy0 + ny)]

return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)
return projections.Grid(HRRR_CONUS_PROJECTION, x, y)


# Grid used by HRRR CONUS (Continental US) data
Expand Down
100 changes: 100 additions & 0 deletions earth2grid/projections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc

import numpy as np
import torch

from earth2grid import base
from earth2grid._regrid import BilinearInterpolator

try:
import pyvista as pv
except ImportError:
pv = None


class Projection(abc.ABC):
@abc.abstractmethod
def project(self, lat: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the projected x,y from lat,lon.
"""
pass

@abc.abstractmethod
def inverse_project(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the lat,lon from the projected x,y.
"""
pass


class Grid(base.Grid):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just renamed your LCC grid class @simonbyrne . Intended to be used like earth2grid.projections.Grid.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if you want to give it a more specific name (i.e. it assumes the underlying grid is rectilinear, not an unstructured mesh).

Copy link
Collaborator Author

@nbren12 nbren12 Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "projection" implies rectangle. This is just a style preference for projection.Grid over projection.ProjectionGrid. Same style as healpix.Grid.

# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
def __init__(self, projection: Projection, x, y):
"""
Args:
x: range of x values
y: range of y values

"""
self.projection = projection

self.x = np.array(x)
self.y = np.array(y)

@property
def lat_lon(self):
mesh_x, mesh_y = np.meshgrid(self.x, self.y, indexing='ij')
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i needed to change this to use indexing=ij here versus the original which was indexing="xy", since meshgrid has unusual behavior. @simonbyrne please review. Specifically, do we want the data to have the shaped [nx, ny] or [ny,nx]?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really have a preference, I think I just copied LatLongGrid which uses Lat (Y) first.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to update the LCC tests?

return self.projection.inverse_project(mesh_x, mesh_y)

@property
def lat(self):
return self.lat_lon[0]

@property
def lon(self):
return self.lat_lon[1]

@property
def shape(self):
return (len(self.x), len(self.y))

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""

x, y = self.projection.project(lat, lon)

return BilinearInterpolator(
x_coords=torch.from_numpy(self.x),
y_coords=torch.from_numpy(self.y),
x_query=torch.from_numpy(x),
y_query=torch.from_numpy(y),
)

def visualize(self, data):
raise NotImplementedError()

def to_pyvista(self):
if pv is None:
raise ImportError("Need to install pyvista")

lat, lon = self.lat_lon
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
z = np.sin(np.deg2rad(lat))
grid = pv.StructuredGrid(x, y, z)
return grid
7 changes: 7 additions & 0 deletions earth2grid/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,10 @@ def ang2vec(lon, lat):
y = torch.cos(lat) * torch.sin(lon)
z = torch.sin(lat)
return (x, y, z)


def vec2ang(x, y, z):
"""convert lon,lat in radians to cartesian coordinates"""
lat = torch.asin(z)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes your inputs are already normalized. If you want it to work with non-normalized inputs, you would need

Suggested change
lat = torch.asin(z)
lat = torch.atan2(z, torch.hypot(y, x))

lon = torch.atan2(y, x)
return lon, lat
Loading
Loading