Skip to content

Commit

Permalink
Merge pull request #85 from mraspaud/add-spline-interpolation
Browse files Browse the repository at this point in the history
Add a spline interpolator for 2d arrays
  • Loading branch information
mraspaud authored Oct 11, 2024
2 parents 5f44745 + ffc1628 commit 11ba92f
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 49 deletions.
60 changes: 36 additions & 24 deletions geotiepoints/geointerpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Geographical interpolation (lon/lats)."""

import numpy as np
from geotiepoints.interpolator import Interpolator, MultipleGridInterpolator
from geotiepoints.interpolator import Interpolator, MultipleGridInterpolator, MultipleSplineInterpolator


EARTH_RADIUS = 6370997.0
Expand Down Expand Up @@ -92,26 +92,38 @@ def xyz2lonlat(x__, y__, z__, radius=EARTH_RADIUS, thr=0.8, low_lat_z=True):
return lons, lats


class GeoGridInterpolator(MultipleGridInterpolator):
    """Regular-grid interpolator for geographical (lon/lat) tie points.

    The geographical input is converted to cartesian x, y, z components, each
    component is interpolated by the parent class, and the result is converted
    back to longitudes and latitudes.
    """

    def __init__(self, tie_points, *data, **kwargs):
        """Set up the interpolator.

        *data* is either two arrays (longitudes and latitudes) or a single
        object exposing a `get_cartesian_coords` method. *kwargs* are passed
        on to the parent class.

        Raises:
            ValueError: if *data* holds neither one nor two items.
        """
        n_inputs = len(data)
        if n_inputs == 2:
            cartesian = lonlat2xyz(*data)
        elif n_inputs == 1:
            coords = data[0].get_cartesian_coords()
            # Split the trailing axis into separate x, y and z arrays.
            cartesian = [coords[:, :, axis] for axis in range(3)]
        else:
            raise ValueError("Either pass lon/lats or a pyresample definition.")
        super().__init__(tie_points, *cartesian, **kwargs)

    def interpolate(self, fine_points, **kwargs):
        """Interpolate to *fine_points* and return longitudes and latitudes."""
        x__, y__, z__ = super().interpolate(fine_points, **kwargs)
        return xyz2lonlat(x__, y__, z__)

    def interpolate_to_shape(self, shape, **kwargs):
        """Interpolate to a full grid of the given *shape*, using indices 0..size-1 along each dimension."""
        return self.interpolate([np.arange(extent) for extent in shape], **kwargs)
def _work_with_lonlats(klass):
    """Adapt a MultipleInterpolator class to work with geographical coordinates.

    The returned subclass of *klass* converts its input data to cartesian
    components with `to_xyz` before initialising *klass*, and converts the
    interpolated components back with `xyz2lonlat`.
    """

    class GeoKlass(klass):
        """Version of the wrapped interpolator class that takes and returns geographical coordinates."""

        def __init__(self, tie_points, *data, **interpolator_init_kwargs):
            """Set up the interpolator.

            *data* is anything `to_xyz` accepts; the keyword arguments are
            passed on unchanged to the wrapped class.
            """
            cartesian = to_xyz(data)
            super().__init__(tie_points, *cartesian, **interpolator_init_kwargs)

        def interpolate(self, fine_points, **interpolator_call_kwargs):
            """Interpolate to *fine_points* and convert the result with `xyz2lonlat`."""
            components = super().interpolate(fine_points, **interpolator_call_kwargs)
            # Exactly three components (x, y, z) are expected from the wrapped class.
            x__, y__, z__ = components
            return xyz2lonlat(x__, y__, z__)

    return GeoKlass


def to_xyz(data):
    """Convert data to cartesian.

    Data can be a one-item sequence holding an object with a
    `get_cartesian_coords` method, or a two-item sequence of (lon, lat) arrays.

    Raises:
        ValueError: if *data* holds neither one nor two items.
    """
    count = len(data)
    if count == 2:
        return lonlat2xyz(*data)
    if count != 1:
        raise ValueError("Either pass lon/lats or a pyresample definition.")
    cartesian = data[0].get_cartesian_coords()
    # Split the trailing axis into separate x, y and z arrays.
    return [cartesian[:, :, axis] for axis in range(3)]


# Geographical (lon/lat) interpolator classes, generated by wrapping the generic
# multiple-array interpolators with `_work_with_lonlats`.
GeoGridInterpolator = _work_with_lonlats(MultipleGridInterpolator)
GeoSplineInterpolator = _work_with_lonlats(MultipleSplineInterpolator)
98 changes: 74 additions & 24 deletions geotiepoints/interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Generic interpolation routines."""

from abc import ABC, abstractmethod
from functools import partial
import numpy as np
from scipy.interpolate import RectBivariateSpline, splev, splrep, RegularGridInterpolator

Expand Down Expand Up @@ -241,37 +243,36 @@ def interpolate(self):
return self.new_data


class SingleGridInterpolator:
"""An interpolator for a single 2d data array."""
class AbstractSingleInterpolator(ABC):
"""An abstract interpolator for a single 2d data array."""

def __init__(self, points, values, **kwargs):
def __init__(self, points, values, scipy_interpolator, **interpolator_init_kwargs):
"""Set up the interpolator.
*kwargs* are passed to the underlying RegularGridInterpolator instance.
*kwargs* are passed to the underlying scipy interpolator instance.
So for example, to allow extrapolation, the kwargs can be `bounds_error=False, fill_value=None`.
"""
self.interpolator = RegularGridInterpolator(points, values, **kwargs)
self.interpolator = scipy_interpolator(points, values, **interpolator_init_kwargs)
self.points = points
self.values = values

def interpolate(self, fine_points, method="linear", chunks=None):
def interpolate(self, fine_points, chunks=None, **interpolator_call_kwargs):
"""Interpolate the value points to the *fine_points* grid.
Args:
fine_points: the points on the target grid to use, as one dimensional vectors for each dimension.
method: the method to use for interpolation as described in RegularGridInterpolator's documentation.
Default is "linear".
chunks: If not None, a lazy (dask-based) interpolation will be performed using the chunk sizes specified.
The result will be a dask array in this case. Defaults to None.
interpolator_kwargs: The keyword arguments to pass to the underlying scipy interpolator.
"""
if chunks is not None:
res = self.interpolate_dask(fine_points, method=method, chunks=chunks)
res = self.interpolate_dask(fine_points, chunks=chunks, **interpolator_call_kwargs)
else:
res = self.interpolate_numpy(fine_points, method=method)
res = self.interpolate_numpy(fine_points, **interpolator_call_kwargs)

return res

def interpolate_dask(self, fine_points, method, chunks):
def interpolate_dask(self, fine_points, chunks, **interpolator_call_kwargs):
"""Interpolate (lazily) to a dask array."""
from dask.base import tokenize
import dask.array as da
Expand All @@ -281,25 +282,26 @@ def interpolate_dask(self, fine_points, method, chunks):

chunks = normalize_chunks(chunks, shape, dtype=self.values.dtype)

token = tokenize(chunks, self.points, self.values, fine_points, method)
token = tokenize(chunks, self.points, self.values, fine_points, interpolator_call_kwargs)
name = 'interpolate-' + token

dskx = {(name, ) + position: (self.interpolate_slices,
slices,
method)
interpolate_slices = partial(self.interpolate_slices, **interpolator_call_kwargs)

dskx = {(name, ) + position: (interpolate_slices,
slices)
for position, slices in _enumerate_chunk_slices(chunks)}

res = da.Array(dskx, name, shape=list(shape),
chunks=chunks,
dtype=self.values.dtype)
return res

def interpolate_numpy(self, fine_points, method="linear"):
@abstractmethod
def interpolate_numpy(self, fine_points, **interpolator_call_kwargs):
"""Interpolate to a numpy array."""
fine_x, fine_y = np.meshgrid(*fine_points, indexing='ij')
return self.interpolator((fine_x, fine_y), method=method).astype(self.values.dtype)
raise NotImplementedError

def interpolate_slices(self, fine_points, method="linear"):
def interpolate_slices(self, fine_points, **interpolator_call_kwargs):
"""Interpolate using slices.
*fine_points* are a tuple of slices for the y and x dimensions
Expand All @@ -309,7 +311,7 @@ def interpolate_slices(self, fine_points, method="linear"):
points_x = np.arange(slice_x.start, slice_x.stop)
fine_points = points_y, points_x

return self.interpolate_numpy(fine_points, method=method)
return self.interpolate_numpy(fine_points, **interpolator_call_kwargs)


def _enumerate_chunk_slices(chunks):
Expand All @@ -324,18 +326,66 @@ def _enumerate_chunk_slices(chunks):
yield (position, slices)


class MultipleGridInterpolator:
"""Interpolator that works on multiple data arrays."""
class SingleGridInterpolator(AbstractSingleInterpolator):
    """Interpolate a single 2d data array on a regular grid with scipy's RegularGridInterpolator."""

    def __init__(self, *args, **interpolator_init_kwargs):
        """Set up the grid interpolator.

        All arguments are handed to the parent initializer, which is told to
        use `RegularGridInterpolator` as the underlying scipy interpolator.
        """
        super().__init__(*args, scipy_interpolator=RegularGridInterpolator, **interpolator_init_kwargs)

    def interpolate_numpy(self, fine_points, **interpolator_call_kwargs):
        """Interpolate to a numpy array with the dtype of the original values.

        *fine_points* holds two one-dimensional coordinate vectors; they are
        expanded to a full grid before the interpolator is called with
        *interpolator_call_kwargs*.
        """
        grid_first, grid_second = np.meshgrid(*fine_points, indexing='ij')
        interpolated = self.interpolator((grid_first, grid_second), **interpolator_call_kwargs)
        return interpolated.astype(self.values.dtype)


class SingleSplineInterpolator(AbstractSingleInterpolator):
    """A spline interpolator for a single 2d data array, built on scipy's RectBivariateSpline."""

    def __init__(self, points, values, **interpolator_init_kwargs):
        """Set up the spline interpolator.

        *points* are unpacked as the leading positional arguments of
        `RectBivariateSpline`, followed by *values*; the keyword arguments are
        passed to it unchanged.
        """
        # The parent initializer is not used here: it calls the scipy class as
        # `(points, values)`, whereas RectBivariateSpline takes the points unpacked.
        self.points = points
        self.values = values
        self.interpolator = RectBivariateSpline(*points, values, **interpolator_init_kwargs)

    def interpolate_numpy(self, fine_points, **interpolator_call_kwargs):
        """Interpolate to a numpy array with the dtype of the original values.

        The spline is called with the unpacked *fine_points* and *interpolator_call_kwargs*.
        """
        evaluated = self.interpolator(*fine_points, **interpolator_call_kwargs)
        return evaluated.astype(self.values.dtype)


class AbstractMultipleInterpolator(ABC):  # noqa: B024
    """Abstract interpolator that works on multiple arrays.

    One single-array interpolator is created per data array; all of them share
    the same tie points and initialisation keyword arguments.
    """

    def __init__(self, interpolator, tie_points, *data, **interpolator_init_kwargs):
        """Set up the interpolator from the multiple `data` arrays.

        Args:
            interpolator: the single-array interpolator class to instantiate for each data array. It is called as
                `interpolator(tie_points, values, **interpolator_init_kwargs)`.
            tie_points: the tie points shared by all the data arrays.
            data: the data arrays to interpolate.
            interpolator_init_kwargs: keyword arguments passed to every single-array interpolator.
        """
        self.interpolators = [interpolator(tie_points, values, **interpolator_init_kwargs)
                              for values in data]

    def interpolate(self, fine_points, **interpolator_call_kwargs):
        """Interpolate the data.

        The keyword arguments will be passed on to the `interpolate` method of each single-array interpolator.

        Returns:
            A generator yielding one interpolated result per data array, in the order the arrays were passed at
            initialisation. Being a generator, it is evaluated lazily and can only be consumed once.
        """
        return (interpolator.interpolate(fine_points, **interpolator_call_kwargs)
                for interpolator in self.interpolators)

    def interpolate_to_shape(self, shape, **interpolator_call_kwargs):
        """Interpolate to a given *shape*, using the indices 0..size-1 along each dimension as fine points."""
        fine_points = [np.arange(size) for size in shape]
        return self.interpolate(fine_points, **interpolator_call_kwargs)


class MultipleGridInterpolator(AbstractMultipleInterpolator):
    """Regular-grid interpolator for several data arrays sharing the same tie points."""

    def __init__(self, tie_points, *data, **interpolator_init_kwargs):
        """Create one `SingleGridInterpolator` per array in `data`.

        The keyword arguments are passed on to each of them.
        """
        super().__init__(
            SingleGridInterpolator,
            tie_points,
            *data,
            **interpolator_init_kwargs,
        )


class MultipleSplineInterpolator(AbstractMultipleInterpolator):
    """Spline interpolator for several data arrays sharing the same tie points."""

    def __init__(self, tie_points, *data, **interpolator_init_kwargs):
        """Create one `SingleSplineInterpolator` per array in `data`.

        The keyword arguments are passed on to each of them.
        """
        super().__init__(
            SingleSplineInterpolator,
            tie_points,
            *data,
            **interpolator_init_kwargs,
        )
78 changes: 77 additions & 1 deletion geotiepoints/tests/test_geointerpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import pytest
from pyresample.geometry import SwathDefinition

from geotiepoints.geointerpolator import GeoInterpolator, GeoGridInterpolator
from geotiepoints.geointerpolator import GeoInterpolator, GeoGridInterpolator, GeoSplineInterpolator

TIES_EXP1 = np.array([[6384905.78040055, 6381081.08333225, 6371519.34066148,
6328950.00792935, 6253610.69157758, 6145946.19489936,
Expand Down Expand Up @@ -291,3 +291,79 @@ def test_geogrid_interpolation_can_extrapolate(self):
lons, lats = interpolator.interpolate_to_shape((16, 16), method="cubic")

assert lons.shape == (16, 16)


class TestGeoSplineInterpolator:
    """Test the GeoSplineInterpolator."""

    @pytest.mark.parametrize("args", ((TIE_LONS, TIE_LATS),
                                      [SwathDefinition(TIE_LONS, TIE_LATS)]
                                      ))
    def test_geospline_interpolation(self, args):
        """Test that the interpolator works with both explicit tie-point arrays and swath definition objects."""
        x_points = np.array([0, 1, 3, 7])
        y_points = np.array([0, 1, 3, 7, 15])

        # kx=1, ky=1 asks for a first-degree spline, so the expected values below are piecewise linear.
        interpolator = GeoSplineInterpolator((y_points, x_points), *args, kx=1, ky=1)

        fine_x_points = np.arange(8)
        fine_y_points = np.arange(16)

        lons, lats = interpolator.interpolate((fine_y_points, fine_x_points))

        lons_expected = np.array([1., 2., 2.5, 3., 3.25, 3.5, 3.75, 4.])
        lats_expected = np.array([1., 2., 2.5, 3., 3.25, 3.5, 3.75, 4., 4.125,
                                  4.25, 4.375, 4.5, 4.625, 4.75, 4.875, 5.])

        np.testing.assert_allclose(lons[0, :], lons_expected, rtol=5e-5)
        np.testing.assert_allclose(lats[:, 0], lats_expected, rtol=5e-5)

    def test_geospline_interpolation_to_shape(self):
        """Test that the interpolator can interpolate to a full grid given only its shape."""
        x_points = np.array([0, 1, 3, 7])
        y_points = np.array([0, 1, 3, 7, 15])

        interpolator = GeoSplineInterpolator((y_points, x_points), TIE_LONS, TIE_LATS, kx=1, ky=1)

        lons, lats = interpolator.interpolate_to_shape((16, 8))

        lons_expected = np.array([1., 2., 2.5, 3., 3.25, 3.5, 3.75, 4.])
        lats_expected = np.array([1., 2., 2.5, 3., 3.25, 3.5, 3.75, 4., 4.125,
                                  4.25, 4.375, 4.5, 4.625, 4.75, 4.875, 5.])

        np.testing.assert_allclose(lons[0, :], lons_expected, rtol=5e-5)
        np.testing.assert_allclose(lats[:, 0], lats_expected, rtol=5e-5)

    def test_geospline_interpolation_preserves_dtype(self):
        """Test that the spline interpolator returns data with the dtype of the tie-point arrays."""
        x_points = np.array([0, 1, 3, 7])
        y_points = np.array([0, 1, 3, 7, 15])

        # Use the spline interpolator here: this class is meant to cover it, not the grid one.
        interpolator = GeoSplineInterpolator((y_points, x_points),
                                             TIE_LONS.astype(np.float32), TIE_LATS.astype(np.float32),
                                             kx=1, ky=1)

        lons, lats = interpolator.interpolate_to_shape((16, 8))

        assert lons.dtype == np.float32
        assert lats.dtype == np.float32

    def test_chunked_geospline_interpolation(self):
        """Test that the spline interpolator can produce chunked (dask) results."""
        dask = pytest.importorskip("dask")

        x_points = np.array([0, 1, 3, 7])
        y_points = np.array([0, 1, 3, 7, 15])

        # Use the spline interpolator here: this class is meant to cover it, not the grid one.
        interpolator = GeoSplineInterpolator((y_points, x_points),
                                             TIE_LONS.astype(np.float32), TIE_LATS.astype(np.float32),
                                             kx=1, ky=1)

        lons, lats = interpolator.interpolate_to_shape((16, 8), chunks=4)

        assert lons.chunks == ((4, 4, 4, 4), (4, 4))
        assert lats.chunks == ((4, 4, 4, 4), (4, 4))

        with dask.config.set({"array.chunk-size": 64}):

            lons, lats = interpolator.interpolate_to_shape((16, 8), chunks="auto")
            assert lons.chunks == ((4, 4, 4, 4), (4, 4))
            assert lats.chunks == ((4, 4, 4, 4), (4, 4))

0 comments on commit 11ba92f

Please sign in to comment.