diff --git a/CHANGELOG.md b/CHANGELOG.md index a82817f..b52e7ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ # Changelog ## latest - +- `earth2grid.latlon.BilinearInterpolator` moved to `earth2grid.BilinearInterpolator` ## 2024.8.1 diff --git a/docs/api.rst b/docs/api.rst index a272fc2..d57f5e8 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -19,7 +19,12 @@ Regridding .. autofunction:: earth2grid.get_regridder +.. autofunction:: earth2grid.KNNS2Interpolator + +.. autofunction:: earth2grid.BilinearInterpolator + Other utilities --------------- +.. autofunction:: earth2grid.healpix.reorder .. autofunction:: earth2grid.healpix.pad diff --git a/earth2grid/__init__.py b/earth2grid/__init__.py index 3143b70..6d52f27 100644 --- a/earth2grid/__init__.py +++ b/earth2grid/__init__.py @@ -12,7 +12,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + from earth2grid import base, healpix, latlon -from earth2grid._regrid import get_regridder +from earth2grid._regrid import BilinearInterpolator, Identity, KNNS2Interpolator, Regridder + +__all__ = [ + "base", + "healpix", + "latlon", + "get_regridder", + "BilinearInterpolator", + "KNNS2Interpolator", + "Regridder", +] + + +def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module: + """Get a regridder from `src` to `dest`""" + if src == dest: + return Identity() + elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, latlon.LatLonGrid): + return src.get_bilinear_regridder_to(dest.lat, dest.lon) + elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, healpix.Grid): + return src.get_bilinear_regridder_to(dest.lat, dest.lon) + elif isinstance(src, healpix.Grid): + return src.get_bilinear_regridder_to(dest.lat, dest.lon) + elif isinstance(dest, healpix.Grid): + return src.get_healpix_regridder(dest) # type: ignore -__all__ = ["base", "healpix", 
"latlon", "get_regridder"] + raise ValueError(src, dest, "not supported.") diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 4230026..5e9e5e8 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -12,12 +12,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math +from typing import Dict, Sequence + import einops import netCDF4 as nc import torch +from scipy import spatial + +from earth2grid.spatial import ang2vec, haversine_distance + + +class Regridder(torch.nn.Module): + """Regridder to n points, with p nonzero maps weights + + Forward: + (*, m) -> (*,) + shape + """ + + def __init__(self, shape: Sequence[int], p: int): + super().__init__() + self.register_buffer("index", torch.empty(*shape, p, dtype=torch.long)) + self.register_buffer("weight", torch.ones(*shape, p)) + + def forward(self, z): + *shape, x = z.shape + zrs = z.view(-1, x).T + + *output_shape, p = self.index.shape + index = self.index.view(-1, p) + weight = self.weight.view(-1, p) -from earth2grid import base, healpix -from earth2grid.latlon import LatLonGrid + # using embedding bag is 2x faster on cpu and 4x on gpu. 
+ output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode='sum') + output = output.T.view(*shape, -1) + return output.reshape(list(shape) + output_shape) + + @staticmethod + def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder": + n, p = d["index"].shape + regridder = Regridder((n,), p) + regridder.load_state_dict(d) + return regridder class TempestRegridder(torch.nn.Module): @@ -48,22 +84,150 @@ def forward(self, x): return y +class BilinearInterpolator(torch.nn.Module): + """Bilinear interpolation for a non-uniform grid""" + + def __init__( + self, + x_coords: torch.Tensor, + y_coords: torch.Tensor, + x_query: torch.Tensor, + y_query: torch.Tensor, + fill_value=math.nan, + ) -> None: + """ + + Args: + x_coords (Tensor): X-coordinates of the input grid, shape [W]. Must be in increasing sorted order. + y_coords (Tensor): Y-coordinates of the input grid, shape [H]. Must be in increasing sorted order. + x_query (Tensor): X-coordinates for query points, shape [N]. + y_query (Tensor): Y-coordinates for query points, shape [N]. 
+ """ + super().__init__() + self.fill_value = fill_value + + # Ensure input coordinates are float for interpolation + x_coords, y_coords = x_coords.double(), y_coords.double() + x_query = x_query.double() + y_query = y_query.double() + + if torch.any(x_coords[1:] < x_coords[:-1]): + raise ValueError("x_coords must be in non-decreasing order.") + + if torch.any(y_coords[1:] < y_coords[:-1]): + raise ValueError("y_coords must be in non-decreasing order.") + + # Find indices for the closest lower and upper bounds in x and y directions + x_l_idx = torch.searchsorted(x_coords, x_query, right=True) - 1 + x_u_idx = x_l_idx + 1 + y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1 + y_u_idx = y_l_idx + 1 + + # fill in nan outside mask + def isin(x, a, b): + return (x <= b) & (x >= a) + + mask = ( + isin(x_l_idx, 0, x_coords.size(0) - 2) + & isin(x_u_idx, 1, x_coords.size(0) - 1) + & isin(y_l_idx, 0, y_coords.size(0) - 2) + & isin(y_u_idx, 1, y_coords.size(0) - 1) + ) + x_u_idx = x_u_idx[mask] + x_l_idx = x_l_idx[mask] + y_u_idx = y_u_idx[mask] + y_l_idx = y_l_idx[mask] + x_query = x_query[mask] + y_query = y_query[mask] + + # Compute weights + x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx]) + x_u_weight = (x_query - x_coords[x_l_idx]) / (x_coords[x_u_idx] - x_coords[x_l_idx]) + y_l_weight = (y_coords[y_u_idx] - y_query) / (y_coords[y_u_idx] - y_coords[y_l_idx]) + y_u_weight = (y_query - y_coords[y_l_idx]) / (y_coords[y_u_idx] - y_coords[y_l_idx]) + weights = torch.stack( + [x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1 + ) + + stride = x_coords.size(-1) + index = torch.stack( + [ + x_l_idx + stride * y_l_idx, + x_u_idx + stride * y_l_idx, + x_l_idx + stride * y_u_idx, + x_u_idx + stride * y_u_idx, + ], + dim=-1, + ) + self.register_buffer("weights", weights) + self.register_buffer("mask", mask) + self.register_buffer("index", index) + + def forward(self, z: 
torch.Tensor): + """ + Interpolate the field + + Args: + z: shape [Y, X] + """ + *shape, y, x = z.shape + zrs = z.view(-1, y * x).T + # using embedding bag is 2x faster on cpu and 4x on gpu. + output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum') + interpolated = torch.full( + [self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device + ) + interpolated.masked_scatter_(self.mask.unsqueeze(-1), output) + interpolated = interpolated.T.view(*shape, self.mask.numel()) + return interpolated + + +def KNNS2Interpolator( + src_lon: torch.Tensor, + src_lat: torch.Tensor, + dest_lon: torch.Tensor, + dest_lat: torch.Tensor, + k: int = 1, + eps=1e-7, +) -> Regridder: + """K-nearest neighbor interpolator with inverse distance weighting + + Args: + src_lon: (m,) source longitude in degrees E + src_lat: (m,) source latitude in degrees N + dest_lon: (n,) output longitude in degrees E + dest_lat: (n,) output latitude in degrees N + k: number of neighbors, default: 1 + eps: regularization factor for inverse distance weighting. Only used if + k > 1. + + """ + if (src_lat.ndim != 1) or (src_lon.ndim != 1) or (dest_lat.ndim != 1) or (dest_lon.ndim != 1): + raise ValueError("All input coordinates must be 1 dimensional.") + + src_lon = torch.deg2rad(src_lon.cpu()) + src_lat = torch.deg2rad(src_lat.cpu()) + + dest_lon = torch.deg2rad(dest_lon.cpu()) + dest_lat = torch.deg2rad(dest_lat.cpu()) + + vec = torch.stack(ang2vec(src_lon, src_lat), -1) + + # haversine distance and euclidean are monotone for points on S2 so can use 3d lookups.
+ tree = spatial.KDTree(vec) + vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1) + _, neighbors = tree.query(vec, k=k) + regridder = Regridder(dest_lon.shape, k) + regridder.index.copy_(torch.as_tensor(neighbors).view(-1, k)) + if k > 1: + d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors]) + lam = 1 / (d + eps) + lam = lam / lam.sum(-1, keepdim=True) + regridder.weight.copy_(lam) + + return regridder + + class Identity(torch.nn.Module): def forward(self, x): return x - - -def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module: - """Get a regridder from `src` to `dest`""" - if src == dest: - return Identity() - elif isinstance(src, LatLonGrid) and isinstance(dest, LatLonGrid): - return src.get_bilinear_regridder_to(dest.lat, dest.lon) - elif isinstance(src, LatLonGrid) and isinstance(dest, healpix.Grid): - return src.get_bilinear_regridder_to(dest.lat, dest.lon) - elif isinstance(src, healpix.Grid): - return src.get_bilinear_regridder_to(dest.lat, dest.lon) - elif isinstance(dest, healpix.Grid): - return src.get_healpix_regridder(dest) # type: ignore - - raise ValueError(src, dest, "not supported.") diff --git a/earth2grid/healpix.py b/earth2grid/healpix.py index 6b93176..fb8887c 100644 --- a/earth2grid/healpix.py +++ b/earth2grid/healpix.py @@ -43,6 +43,7 @@ import torch from earth2grid import healpix_bare +from earth2grid._regrid import Regridder try: import pyvista as pv @@ -230,26 +231,6 @@ def _convert_xyindex(nside: int, src: XY, dest: XY, i): return i -class ApplyWeights(torch.nn.Module): - def __init__(self, pix: torch.Tensor, weight: torch.Tensor): - super().__init__() - - # the first dim is the 4 point stencil - n, *self.shape = pix.shape - - pix = pix.view(n, -1).T - weight = weight.view(n, -1).T - - self.register_buffer("index", pix) - self.register_buffer("weight", weight) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - *shape, npix = x.shape - x = x.view(-1, npix).T 
- interpolated = torch.nn.functional.embedding_bag(self.index, x, per_sample_weights=self.weight, mode="sum").T - return interpolated.view(shape + self.shape) - - @dataclass class Grid(base.Grid): """A Healpix Grid @@ -345,7 +326,16 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray): i_ring, weights = healpix_bare.get_interp_weights(self._nside(), torch.tensor(lon), torch.tensor(lat)) i_nest = healpix_bare.ring2nest(self._nside(), i_ring.ravel()) i_me = self._nest2me(i_nest).reshape(i_ring.shape) - return ApplyWeights(i_me, weights) + + # reshape to (*, p) + weights = weights.movedim(0, -1) + index = i_me.movedim(0, -1) + + regridder = Regridder(weights.shape[:-1], p=weights.shape[-1]) + regridder.to(weights) + regridder.index.copy_(index) + regridder.weight.copy_(weights) + return regridder def approximate_grid_length_meters(self): return approx_grid_length_meters(self._nside()) diff --git a/earth2grid/latlon.py b/earth2grid/latlon.py index 5ccd45d..4b52669 100644 --- a/earth2grid/latlon.py +++ b/earth2grid/latlon.py @@ -12,12 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math - import numpy as np import torch from earth2grid import base +from earth2grid._regrid import BilinearInterpolator try: import pyvista as pv @@ -25,104 +24,6 @@ pv = None -class BilinearInterpolator(torch.nn.Module): - """Bilinear interpolation for a non-uniform grid""" - - def __init__( - self, - x_coords: torch.Tensor, - y_coords: torch.Tensor, - x_query: torch.Tensor, - y_query: torch.Tensor, - fill_value=math.nan, - ) -> None: - """ - - Args: - x_coords (Tensor): X-coordinates of the input grid, shape [W]. Must be in increasing sorted order. - y_coords (Tensor): Y-coordinates of the input grid, shape [H]. Must be in increasing sorted order. - x_query (Tensor): X-coordinates for query points, shape [N]. 
- y_query (Tensor): Y-coordinates for query points, shape [N]. - """ - super().__init__() - self.fill_value = fill_value - - # Ensure input coordinates are float for interpolation - x_coords, y_coords = x_coords.double(), y_coords.double() - x_query = x_query.double() - y_query = y_query.double() - - if torch.any(x_coords[1:] < x_coords[:-1]): - raise ValueError("x_coords must be in non-decreasing order.") - - if torch.any(y_coords[1:] < y_coords[:-1]): - raise ValueError("y_coords must be in non-decreasing order.") - - # Find indices for the closest lower and upper bounds in x and y directions - x_l_idx = torch.searchsorted(x_coords, x_query, right=True) - 1 - x_u_idx = x_l_idx + 1 - y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1 - y_u_idx = y_l_idx + 1 - - # fill in nan outside mask - def isin(x, a, b): - return (x <= b) & (x >= a) - - mask = ( - isin(x_l_idx, 0, x_coords.size(0) - 2) - & isin(x_u_idx, 1, x_coords.size(0) - 1) - & isin(y_l_idx, 0, y_coords.size(0) - 2) - & isin(y_u_idx, 1, y_coords.size(0) - 1) - ) - x_u_idx = x_u_idx[mask] - x_l_idx = x_l_idx[mask] - y_u_idx = y_u_idx[mask] - y_l_idx = y_l_idx[mask] - x_query = x_query[mask] - y_query = y_query[mask] - - # Compute weights - x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx]) - x_u_weight = (x_query - x_coords[x_l_idx]) / (x_coords[x_u_idx] - x_coords[x_l_idx]) - y_l_weight = (y_coords[y_u_idx] - y_query) / (y_coords[y_u_idx] - y_coords[y_l_idx]) - y_u_weight = (y_query - y_coords[y_l_idx]) / (y_coords[y_u_idx] - y_coords[y_l_idx]) - weights = torch.stack( - [x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1 - ) - - stride = x_coords.size(-1) - index = torch.stack( - [ - x_l_idx + stride * y_l_idx, - x_u_idx + stride * y_l_idx, - x_l_idx + stride * y_u_idx, - x_u_idx + stride * y_u_idx, - ], - dim=-1, - ) - self.register_buffer("weights", weights) - self.register_buffer("mask", mask) - 
self.register_buffer("index", index) - - def forward(self, z: torch.Tensor): - """ - Interpolate the field - - Args: - z: shape [Y, X] - """ - *shape, y, x = z.shape - zrs = z.view(-1, y * x).T - # using embedding bag is 2x faster on cpu and 4x on gpu. - output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum') - interpolated = torch.full( - [self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device - ) - interpolated.masked_scatter_(self.mask.unsqueeze(-1), output) - interpolated = interpolated.T.view(*shape, self.mask.numel()) - return interpolated - - class LatLonGrid(base.Grid): def __init__(self, lat: list[float], lon: list[float]): """ diff --git a/earth2grid/spatial.py b/earth2grid/spatial.py new file mode 100644 index 0000000..974108e --- /dev/null +++ b/earth2grid/spatial.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def haversine_distance(lon1, lat1, lon2, lat2): + """ + Calculate the Haversine distance between two points on unit sphere + + Args: + lon1 (float): Longitude of the first point in radians. + lat1 (float): Latitude of the first point in radians. + lon2 (float): Longitude of the second point in radians. + lat2 (float): Latitude of the second point in radians. 
+ + Returns: + float: Distance between the two points on the unit sphere, in radians. + """ + # Differences in coordinates + dlon = lon2 - lon1 + dlat = lat2 - lat1 + + # Haversine formula + a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2 + c = 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1 - a)) + return c + + +def ang2vec(lon, lat): + """convert lon,lat in radians to cartesian coordinates""" + x = torch.cos(lat) * torch.cos(lon) + y = torch.cos(lat) * torch.sin(lon) + z = torch.sin(lat) + return (x, y, z) diff --git a/pyproject.toml b/pyproject.toml index f8bc513..0b54520 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "netCDF4>=1.6.5", "numpy>=1.23.3", "torch>=2.0.1", + "scipy" ] [project.urls] diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 29bf264..73a6b9e 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -20,7 +20,7 @@ import torch import earth2grid -from earth2grid.latlon import BilinearInterpolator +from earth2grid import BilinearInterpolator @pytest.mark.parametrize("with_channels", [True, False]) @@ -195,3 +195,30 @@ def test_out_of_bounds(): output = regrid(data) assert torch.all(torch.isnan(output)) + + +@pytest.mark.parametrize("k", [1, 2, 3]) +def test_NearestNeighborInterpolator(k): + n = 10000 + m = 887 + torch.manual_seed(0) + lon = torch.rand(n) * 360 + lat = torch.rand(n) * 180 - 90 + + lond = torch.rand(m) * 360 + latd = torch.rand(m) * 180 - 90 + + interpolate = earth2grid.KNNS2Interpolator(lon, lat, lond, latd, k=k) + out = interpolate(torch.cos(torch.deg2rad(lon))) + expected = torch.cos(torch.deg2rad(lond)) + mae = torch.mean(torch.abs(out - expected)) + assert mae.item() < 0.02 + + # load-reload + earth2grid.Regridder.from_state_dict(interpolate.state_dict()) + + # try batched interpolation + x = torch.cos(torch.deg2rad(lon)) + x = x.unsqueeze(0) + out = interpolate(x) + assert out.shape == (1, m)