Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nearest neighbor interpolator #12

Merged
merged 14 commits into from
Sep 16, 2024
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Changelog

## latest

- `earth2grid.latlon.BilinearInterpolator` moved to `earth2grid.BilinearInterpolator`

## 2024.8.1

Expand Down
30 changes: 28 additions & 2 deletions earth2grid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from earth2grid import base, healpix, latlon
from earth2grid._regrid import get_regridder
from earth2grid._regrid import BilinearInterpolator, Identity, Regridder, S2NearestNeighborInterpolator

__all__ = [
"base",
"healpix",
"latlon",
"get_regridder",
"BilinearInterpolator",
"S2NearestNeighborInterpolator",
"Regridder",
]


def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module:
"""Get a regridder from `src` to `dest`"""
if src == dest:
return Identity()
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, latlon.LatLonGrid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(dest, healpix.Grid):
return src.get_healpix_regridder(dest) # type: ignore

__all__ = ["base", "healpix", "latlon", "get_regridder"]
raise ValueError(src, dest, "not supported.")
197 changes: 179 additions & 18 deletions earth2grid/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,48 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Sequence

import einops
import netCDF4 as nc
import torch
from scipy import spatial

from earth2grid.spatial import ang2vec, haversine_distance


class Regridder(torch.nn.Module):
"""Regridder to n points, with p nonzero maps weights

Forward:
(*, m) -> (*,) + shape
"""

def __init__(self, shape: Sequence[int], p: int):
super().__init__()
self.register_buffer("index", torch.empty(*shape, p, dtype=torch.long))
self.register_buffer("weight", torch.ones(*shape, p))

def forward(self, z):
*shape, x = z.shape
zrs = z.view(-1, x).T

*output_shape, p = self.index.shape
index = self.index.view(-1, p)
weight = self.weight.view(-1, p)

from earth2grid import base, healpix
from earth2grid.latlon import LatLonGrid
# using embedding bag is 2x faster on cpu and 4x on gpu.
output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode='sum')
output = output.T.view(*shape, -1)
return output.reshape(list(shape) + output_shape)

@staticmethod
nbren12 marked this conversation as resolved.
Show resolved Hide resolved
def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder":
n, p = d["index"].shape
regridder = Regridder((n,), p)
regridder.load_state_dict(d)
return regridder


class TempestRegridder(torch.nn.Module):
Expand Down Expand Up @@ -48,22 +84,147 @@ def forward(self, x):
return y


class BilinearInterpolator(torch.nn.Module):
"""Bilinear interpolation for a non-uniform grid"""

def __init__(
self,
x_coords: torch.Tensor,
y_coords: torch.Tensor,
x_query: torch.Tensor,
y_query: torch.Tensor,
fill_value=math.nan,
) -> None:
"""

Args:
x_coords (Tensor): X-coordinates of the input grid, shape [W]. Must be in increasing sorted order.
y_coords (Tensor): Y-coordinates of the input grid, shape [H]. Must be in increasing sorted order.
x_query (Tensor): X-coordinates for query points, shape [N].
y_query (Tensor): Y-coordinates for query points, shape [N].
"""
super().__init__()
self.fill_value = fill_value

# Ensure input coordinates are float for interpolation
x_coords, y_coords = x_coords.double(), y_coords.double()
x_query = x_query.double()
y_query = y_query.double()

if torch.any(x_coords[1:] < x_coords[:-1]):
raise ValueError("x_coords must be in non-decreasing order.")

if torch.any(y_coords[1:] < y_coords[:-1]):
raise ValueError("y_coords must be in non-decreasing order.")

# Find indices for the closest lower and upper bounds in x and y directions
x_l_idx = torch.searchsorted(x_coords, x_query, right=True) - 1
x_u_idx = x_l_idx + 1
y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1
y_u_idx = y_l_idx + 1

# fill in nan outside mask
def isin(x, a, b):
return (x <= b) & (x >= a)

mask = (
isin(x_l_idx, 0, x_coords.size(0) - 2)
& isin(x_u_idx, 1, x_coords.size(0) - 1)
& isin(y_l_idx, 0, y_coords.size(0) - 2)
& isin(y_u_idx, 1, y_coords.size(0) - 1)
)
x_u_idx = x_u_idx[mask]
x_l_idx = x_l_idx[mask]
y_u_idx = y_u_idx[mask]
y_l_idx = y_l_idx[mask]
x_query = x_query[mask]
y_query = y_query[mask]

# Compute weights
x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx])
x_u_weight = (x_query - x_coords[x_l_idx]) / (x_coords[x_u_idx] - x_coords[x_l_idx])
y_l_weight = (y_coords[y_u_idx] - y_query) / (y_coords[y_u_idx] - y_coords[y_l_idx])
y_u_weight = (y_query - y_coords[y_l_idx]) / (y_coords[y_u_idx] - y_coords[y_l_idx])
weights = torch.stack(
[x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1
)

stride = x_coords.size(-1)
index = torch.stack(
[
x_l_idx + stride * y_l_idx,
x_u_idx + stride * y_l_idx,
x_l_idx + stride * y_u_idx,
x_u_idx + stride * y_u_idx,
],
dim=-1,
)
self.register_buffer("weights", weights)
self.register_buffer("mask", mask)
self.register_buffer("index", index)

def forward(self, z: torch.Tensor):
"""
Interpolate the field

Args:
z: shape [Y, X]
"""
*shape, y, x = z.shape
zrs = z.view(-1, y * x).T
# using embedding bag is 2x faster on cpu and 4x on gpu.
output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum')
interpolated = torch.full(
[self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device
)
interpolated.masked_scatter_(self.mask.unsqueeze(-1), output)
interpolated = interpolated.T.view(*shape, self.mask.numel())
return interpolated


def S2NearestNeighborInterpolator(
nbren12 marked this conversation as resolved.
Show resolved Hide resolved
src_lon: torch.Tensor,
src_lat: torch.Tensor,
dest_lon: torch.Tensor,
nbren12 marked this conversation as resolved.
Show resolved Hide resolved
dest_lat: torch.Tensor,
k: int = 1,
eps=1e-7,
) -> Regridder:
"""K-nearest neighbor interpolator with inverse distance weighting

Args:
src_lon: (m,) source longitude in degrees E
src_lat: (m,) source latitude in degrees N
dest_lon: (n,) output longitude in degrees E
dest_lat: (n,) output latitude in degrees N
k: number of neighbors, default: 1
eps: regularization factor for inverse distance weighting. Only used if
k > 1.

"""
src_lon = torch.deg2rad(src_lon.cpu())
src_lat = torch.deg2rad(src_lat.cpu())

dest_lon = torch.deg2rad(dest_lon.cpu())
dest_lat = torch.deg2rad(dest_lat.cpu())

vec = torch.stack(ang2vec(src_lon, src_lat), -1)

# havesign distance and euclidean are monotone for points on S2 so can use 3d lookups.
tree = spatial.KDTree(vec)
vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1)
_, neighbors = tree.query(vec, k=k)
regridder = Regridder(dest_lon.shape, k)
regridder.index.copy_(torch.as_tensor(neighbors).view(-1, k))
if k > 1:
d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors])
lam = 1 / (d + eps)
lam = lam / lam.sum(-1, keepdim=True)
regridder.weight.copy_(lam)

return regridder


class Identity(torch.nn.Module):
def forward(self, x):
return x


def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module:
"""Get a regridder from `src` to `dest`"""
if src == dest:
return Identity()
elif isinstance(src, LatLonGrid) and isinstance(dest, LatLonGrid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, LatLonGrid) and isinstance(dest, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(dest, healpix.Grid):
return src.get_healpix_regridder(dest) # type: ignore

raise ValueError(src, dest, "not supported.")
32 changes: 11 additions & 21 deletions earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import torch

from earth2grid import healpix_bare
from earth2grid._regrid import Regridder

try:
import pyvista as pv
Expand Down Expand Up @@ -230,26 +231,6 @@ def _convert_xyindex(nside: int, src: XY, dest: XY, i):
return i


class ApplyWeights(torch.nn.Module):
def __init__(self, pix: torch.Tensor, weight: torch.Tensor):
super().__init__()

# the first dim is the 4 point stencil
n, *self.shape = pix.shape

pix = pix.view(n, -1).T
weight = weight.view(n, -1).T

self.register_buffer("index", pix)
self.register_buffer("weight", weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
*shape, npix = x.shape
x = x.view(-1, npix).T
interpolated = torch.nn.functional.embedding_bag(self.index, x, per_sample_weights=self.weight, mode="sum").T
return interpolated.view(shape + self.shape)


@dataclass
class Grid(base.Grid):
"""A Healpix Grid
Expand Down Expand Up @@ -345,7 +326,16 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
i_ring, weights = healpix_bare.get_interp_weights(self._nside(), torch.tensor(lon), torch.tensor(lat))
i_nest = healpix_bare.ring2nest(self._nside(), i_ring.ravel())
i_me = self._nest2me(i_nest).reshape(i_ring.shape)
return ApplyWeights(i_me, weights)

# reshape to (*, p)
weights = weights.movedim(0, -1)
index = i_me.movedim(0, -1)

regridder = Regridder(weights.shape[:-1], p=weights.shape[-1])
regridder.to(weights)
regridder.index.copy_(index)
regridder.weight.copy_(weights)
return regridder

def approximate_grid_length_meters(self):
return approx_grid_length_meters(self._nside())
Expand Down
Loading
Loading