Skip to content

Commit

Permalink
apply linter
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbyrne committed Sep 30, 2024
1 parent bcf8642 commit 643cedc
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 43 deletions.
6 changes: 3 additions & 3 deletions earth2grid/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def forward(self, z):
weight = self.weight.view(-1, p)

# using embedding bag is 2x faster on cpu and 4x on gpu.
output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode='sum')
output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode="sum")
output = output.T.view(*shape, -1)
return output.reshape(list(shape) + output_shape)

Expand Down Expand Up @@ -173,11 +173,11 @@ def forward(self, z: torch.Tensor):
*shape, y, x = z.shape
zrs = z.view(-1, y * x).T
# using embedding bag is 2x faster on cpu and 4x on gpu.
output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum')
output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode="sum")
interpolated = torch.full(
[self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device
)
interpolated.masked_scatter_(self.mask.view(-1,1), output)
interpolated.masked_scatter_(self.mask.view(-1, 1), output)
interpolated = interpolated.T.view(*shape, *self.mask.shape)
return interpolated

Expand Down
46 changes: 21 additions & 25 deletions earth2grid/lcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class LambertConformalConicProjection:
def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
"""
Args:
lat0: latitude of origin (degrees)
lon0: longitude of origin (degrees)
Expand All @@ -50,16 +50,15 @@ def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: f
self.lat2 = lat2
self.radius = radius


c1 = np.cos(np.deg2rad(lat1))
c2 = np.cos(np.deg2rad(lat2))
t1 = np.tan(np.pi / 4 + np.deg2rad(lat1) / 2)
t2 = np.tan(np.pi / 4 + np.deg2rad(lat2) / 2)

if np.abs(lat1 - lat2) < 1e-8:
self.n = np.sin(np.deg2rad(lat1))
else:
self.n = np.log(c1/c2) / np.log(t2/t1)
else:
self.n = np.log(c1 / c2) / np.log(t2 / t1)

self.RF = radius * c1 * np.power(t1, self.n) / self.n
self.rho0 = self._rho(lat0)
Expand All @@ -78,17 +77,16 @@ def _theta(self, lon):
"""
# center about reference longitude
delta_lon = lon - self.lon0
delta_lon = delta_lon - np.round(delta_lon/360) * 360 # convert to [-180, 180]
delta_lon = delta_lon - np.round(delta_lon / 360) * 360 # convert to [-180, 180]
return self.n * np.deg2rad(delta_lon)


def project(self, lat, lon):
"""
Compute the projected x,y from lat,lon.
"""
rho = self._rho(lat)
theta = self._theta(lon)

x = rho * np.sin(theta)
y = self.rho0 - rho * np.cos(theta)
return x, y
Expand All @@ -99,26 +97,21 @@ def inverse_project(self, x, y):
"""
rho = np.hypot(x, self.rho0 - y)
theta = np.arctan2(x, self.rho0 - y)
lat = np.rad2deg(2 * np.arctan(np.power(self.RF/rho, 1/self.n))) - 90

lat = np.rad2deg(2 * np.arctan(np.power(self.RF / rho, 1 / self.n))) - 90
lon = self.lon0 + np.rad2deg(theta / self.n)
return lat, lon


# Projection used by HRRR CONUS (Continental US) data
# https://rapidrefresh.noaa.gov/hrrr/HRRR_conus.domain.txt
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(
lon0 = -97.5,
lat0 = 38.5,
lat1 = 38.5,
lat2 = 38.5,
radius = 6371229.0
)
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0)


class LambertConformalConicGrid(base.Grid):
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
def __init__(self, projection: LambertConformalConicProjection, x, y):
"""
"""
Args:
projection: LambertConformalConicProjection object
x: range of x values
Expand Down Expand Up @@ -155,12 +148,13 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""

x, y = self.projection.project(lat, lon)

return BilinearInterpolator(
x_coords = torch.from_numpy(self.x),
y_coords = torch.from_numpy(self.y),
x_query = torch.from_numpy(x),
y_query = torch.from_numpy(y))
x_coords=torch.from_numpy(self.x),
y_coords=torch.from_numpy(self.y),
x_query=torch.from_numpy(x),
y_query=torch.from_numpy(y),
)

def visualize(self, data):
raise NotImplementedError()
Expand All @@ -176,7 +170,8 @@ def to_pyvista(self):
grid = pv.StructuredGrid(x, y, z)
return grid

def hrrr_conus_grid(ix0 = 0, iy0 = 0, nx = 1799, ny = 1059):

def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
# coordinates of point in top-left corner
lat0 = 21.138123
lon0 = 237.280472
Expand All @@ -185,10 +180,11 @@ def hrrr_conus_grid(ix0 = 0, iy0 = 0, nx = 1799, ny = 1059):
# coordinates on projected space
x0, y0 = HRRR_CONUS_PROJECTION.project(lat0, lon0)

x = [x0 + i * scale for i in range(ix0, ix0+nx)]
y = [y0 + i * scale for i in range(iy0, iy0+ny)]
x = [x0 + i * scale for i in range(ix0, ix0 + nx)]
y = [y0 + i * scale for i in range(iy0, iy0 + ny)]

return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)


# Grid used by HRRR CONUS (Continental US) data
HRRR_CONUS_GRID = hrrr_conus_grid()
42 changes: 27 additions & 15 deletions tests/test_lcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#%%
from earth2grid.lcc import HRRR_CONUS_GRID
# %%
import numpy as np
import torch
import pytest
import torch

from earth2grid.lcc import HRRR_CONUS_GRID


def test_grid_shape():
assert HRRR_CONUS_GRID.lat.shape == HRRR_CONUS_GRID.shape
assert HRRR_CONUS_GRID.lat.shape == HRRR_CONUS_GRID.shape
assert HRRR_CONUS_GRID.lon.shape == HRRR_CONUS_GRID.shape

lats = np.array([

lats = np.array(
[
[21.138123, 21.801926, 22.393631, 22.911015],
[23.636763, 24.328228, 24.944668, 25.48374 ],
[23.636763, 24.328228, 24.944668, 25.48374],
[26.155672, 26.875362, 27.517046, 28.078257],
[28.69017 , 29.438608, 30.106009, 30.68978 ]])
[28.69017, 29.438608, 30.106009, 30.68978],
]
)

lons = np.array([
[-122.71953 , -120.03195 , -117.304596, -114.54146 ],
[-123.491356, -120.72898 , -117.92319 , -115.07828 ],
[-124.310524, -121.469505, -118.58098 , -115.649574],
[-125.181404, -122.25762 , -119.28173 , -116.25871 ]])
lons = np.array(
[
[-122.71953, -120.03195, -117.304596, -114.54146],
[-123.491356, -120.72898, -117.92319, -115.07828],
[-124.310524, -121.469505, -118.58098, -115.649574],
[-125.181404, -122.25762, -119.28173, -116.25871],
]
)


def test_grid_vals():
assert HRRR_CONUS_GRID.lat[0:400:100,0:400:100] == pytest.approx(lats)
assert HRRR_CONUS_GRID.lon[0:400:100,0:400:100] == pytest.approx(lons)
assert HRRR_CONUS_GRID.lat[0:400:100, 0:400:100] == pytest.approx(lats)
assert HRRR_CONUS_GRID.lon[0:400:100, 0:400:100] == pytest.approx(lons)


def test_grid_slice():
slice_grid = HRRR_CONUS_GRID[0:400:100,0:400:100]
slice_grid = HRRR_CONUS_GRID[0:400:100, 0:400:100]
assert slice_grid.lat == pytest.approx(lats)
assert slice_grid.lon == pytest.approx(lons)


def test_regrid_1d():
src = HRRR_CONUS_GRID
dest_lat = np.linspace(25.0, 33.0, 10)
Expand All @@ -55,6 +66,7 @@ def test_regrid_1d():

assert torch.allclose(out_lat, torch.tensor(dest_lat))


def test_regrid_2d():
src = HRRR_CONUS_GRID
dest_lat, dest_lon = np.meshgrid(np.linspace(25.0, 33.0, 10), np.linspace(-123, -98, 12))
Expand Down

0 comments on commit 643cedc

Please sign in to comment.