diff --git a/.github/workflows/test-build-push.yml b/.github/workflows/test-build-push.yml
index c8a581f8..cbe49916 100644
--- a/.github/workflows/test-build-push.yml
+++ b/.github/workflows/test-build-push.yml
@@ -51,7 +51,7 @@ jobs:
             "opera-utils>=0.4.1" \
             git+https://github.com/isce-framework/tophu@main \
             git+https://github.com/isce-framework/whirlwind@40defb38d2d6deca2819934788ebbc57e418e32d
-          python -m pip install git+https://github.com/scottstanie/spurt@use-mp-spawn
+          python -m pip install git+https://github.com/isce-framework/spurt@5bdf88089351d09bba17cfbb5cacaab1441e9b78
           python -m pip install --no-deps .
       - name: Install test dependencies
         run: |
diff --git a/src/dolphin/io/_utils.py b/src/dolphin/io/_utils.py
index fe7f0e8e..cd7ee13f 100644
--- a/src/dolphin/io/_utils.py
+++ b/src/dolphin/io/_utils.py
@@ -116,6 +116,7 @@ def repack_raster(
     raster_path: Path,
     output_dir: Path | None = None,
     keep_bits: int | None = None,
+    block_shape: int | tuple[int, int] = (1024, 1024),
     **output_options,
 ) -> Path:
     """Repack a single raster file with GDAL Translate using provided options.
@@ -129,6 +130,8 @@ def repack_raster(
     keep_bits : int, optional
         Number of bits to preserve in mantissa. Defaults to None.
         Lower numbers will truncate the mantissa more and enable more compression.
+    block_shape : int | tuple[int, int], optional
+        Block size (rows, cols) to process at a time; an int means square blocks.
     **output_options
         Keyword args passed to `get_gtiff_options`
 
@@ -140,6 +143,12 @@ def repack_raster(
 
     """
     import rasterio as rio
+    from rasterio.windows import Window
+
+    from ._blocks import iter_blocks
+
+    if isinstance(block_shape, int):
+        block_shape = (block_shape, block_shape)
 
     if output_dir is None:
         output_file = tempfile.NamedTemporaryFile(  # noqa: SIM115
@@ -151,20 +160,32 @@ def repack_raster(
         output_path = output_dir / raster_path.name
 
     options = get_gtiff_options(**output_options)
+
     with rio.open(raster_path) as src:
         profile = src.profile
         profile.update(**options)
+        # Materialize the blocks: a one-shot iterator would be exhausted after
+        # the first band, leaving any later bands unwritten
+        blocks = list(
+            iter_blocks(arr_shape=(src.height, src.width), block_shape=block_shape)
+        )
+
         with rio.open(output_path, "w", **profile) as dst:
             for i in range(1, src.count + 1):
-                data = src.read(i)
-                if keep_bits is not None:
-                    round_mantissa(data, keep_bits)
-                dst.write(data, i)
+                for row_slice, col_slice in blocks:
+                    window = Window.from_slices(rows=row_slice, cols=col_slice)
+                    data = src.read(i, window=window)
+
+                    if keep_bits is not None:
+                        round_mantissa(data, keep_bits)
+
+                    dst.write(data, i, window=window)
 
     if output_dir is None:
         # Overwrite the original
         shutil.move(output_path, raster_path)
         output_path = raster_path
+
     return output_path
 
 
@@ -173,6 +194,7 @@ def repack_rasters(
     output_dir: Path | None = None,
     num_threads: int = 4,
     keep_bits: int | None = None,
+    block_shape: int | tuple[int, int] = (1024, 1024),
     **output_options,
 ):
     """Recreate and compress a list of raster files.
@@ -191,6 +213,8 @@ def repack_rasters(
     keep_bits : int, optional
         Number of bits to preserve in mantissa. Defaults to None.
         Lower numbers will truncate the mantissa more and enable more compression.
+    block_shape : int | tuple[int, int], optional
+        Block size (rows, cols) to process at a time; an int means square blocks.
     **output_options
         Creation options to pass to `get_gtiff_options`
 
@@ -205,7 +229,11 @@ def repack_rasters(
 
     thread_map(
         lambda raster: repack_raster(
-            raster, output_dir, keep_bits=keep_bits, **output_options
+            raster,
+            output_dir,
+            keep_bits=keep_bits,
+            block_shape=block_shape,
+            **output_options,
         ),
         raster_files,
         max_workers=num_threads,
diff --git a/tests/test_io_utils.py b/tests/test_io_utils.py
new file mode 100644
index 00000000..9f93274c
--- /dev/null
+++ b/tests/test_io_utils.py
@@ -0,0 +1,104 @@
+import shutil
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import pytest
+import rasterio as rio
+
+from dolphin.io._utils import repack_raster, repack_rasters
+
+
+@pytest.fixture(scope="module")
+def temp_dir():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        yield Path(tmpdirname)
+
+
+@pytest.fixture(params=["float32", "complex64", "uint8"])
+def test_raster(request, temp_dir):
+    dtype = request.param
+    raster_path = temp_dir / f"test_raster_{dtype}.tif"
+
+    # Create a test raster
+    data = np.random.rand(100, 100).astype(dtype)
+    if np.dtype(dtype) == np.complex64:
+        data = (data + 1j * np.random.rand(100, 100)).astype(dtype)
+
+    profile = {
+        "driver": "GTiff",
+        "height": 100,
+        "width": 100,
+        "count": 1,
+        "dtype": str(dtype),
+        "crs": "EPSG:4326",
+        "transform": rio.transform.from_bounds(0, 0, 1, 1, 100, 100),
+    }
+
+    with rio.open(raster_path, "w", **profile) as dst:
+        dst.write(data, 1)
+
+    return raster_path
+
+
+def test_repack_raster(test_raster, temp_dir):
+    output_dir = temp_dir / "output"
+    keep_bits = 10
+    with rio.open(test_raster) as src:
+        dtype = src.dtypes[0]
+
+    if np.dtype(dtype) == np.uint8:
+        with pytest.raises(TypeError):
+            repack_raster(
+                Path(test_raster),
+                output_dir=output_dir,
+                keep_bits=keep_bits,
+                block_shape=(32, 32),
+            )
+        return
+    output_path = repack_raster(
+        Path(test_raster),
+        output_dir=output_dir,
+        keep_bits=keep_bits,
+        block_shape=(32, 32),
+    )
+
+    assert output_path.exists()
+    assert output_path.parent == output_dir
+
+    with rio.open(test_raster) as src, rio.open(output_path) as dst:
+        assert src.profile["dtype"] == dst.profile["dtype"]
+        old, new = src.read(), dst.read()
+        assert old.shape == new.shape
+        # Data lie in [0, 1), so mantissa rounding moves them by < 2**-keep_bits
+        tol = 2.0**-keep_bits
+        # Check if data is close but not exactly the same (due to keep_bits)
+        np.testing.assert_allclose(old, new, atol=tol)
+
+
+def test_repack_rasters(test_raster, temp_dir):
+    keep_bits = 10
+
+    # Add another to test the threaded version
+    new_raster = str(test_raster) + ".copy.tif"
+    shutil.copy(test_raster, new_raster)
+    raster_paths = [Path(test_raster), Path(new_raster)]
+
+    output_dir = temp_dir / "output_multiple"
+    with rio.open(raster_paths[0]) as src:
+        dtype = src.dtypes[0]
+    if np.dtype(dtype) == np.uint8:
+        with pytest.raises(TypeError):
+            repack_rasters(
+                raster_paths,
+                output_dir=output_dir,
+                keep_bits=keep_bits,
+                block_shape=(32, 32),
+            )
+        return
+    repack_rasters(
+        raster_paths,
+        output_dir=output_dir,
+        keep_bits=keep_bits,
+        block_shape=(32, 32),
+    )