Skip to content

Commit

Permalink
Refactor repack_raster to work in block for lower memory (#473)
Browse files Browse the repository at this point in the history
* Refactor `repack_raster` to work in block for lower memory

Also adds missing unit tests

* move import to top of function

* fix spurt version, force paths during test
  • Loading branch information
scottstanie authored Oct 30, 2024
1 parent 3aec7d0 commit b90c04c
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-build-push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
"opera-utils>=0.4.1" \
git+https://github.com/isce-framework/tophu@main \
git+https://github.com/isce-framework/whirlwind@40defb38d2d6deca2819934788ebbc57e418e32d
python -m pip install git+https://github.com/scottstanie/spurt@use-mp-spawn
python -m pip install git+https://github.com/isce-framework/spurt@5bdf88089351d09bba17cfbb5cacaab1441e9b78
python -m pip install --no-deps .
- name: Install test dependencies
run: |
Expand Down
38 changes: 33 additions & 5 deletions src/dolphin/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def repack_raster(
raster_path: Path,
output_dir: Path | None = None,
keep_bits: int | None = None,
block_shape: int | tuple[int, int] = (1024, 1024),
**output_options,
) -> Path:
"""Repack a single raster file with GDAL Translate using provided options.
Expand All @@ -129,6 +130,8 @@ def repack_raster(
keep_bits : int, optional
Number of bits to preserve in mantissa. Defaults to None.
Lower numbers will truncate the mantissa more and enable more compression.
block_shape: int | tuple[int, int]
Size of blocks to read in at one time.
**output_options
Keyword args passed to `get_gtiff_options`
Expand All @@ -140,6 +143,12 @@ def repack_raster(
"""
import rasterio as rio
from rasterio.windows import Window

from ._blocks import iter_blocks

if isinstance(block_shape, int):
block_shape = (block_shape, block_shape)

if output_dir is None:
output_file = tempfile.NamedTemporaryFile( # noqa: SIM115
Expand All @@ -151,20 +160,32 @@ def repack_raster(
output_path = output_dir / raster_path.name

options = get_gtiff_options(**output_options)

with rio.open(raster_path) as src:
profile = src.profile
profile.update(**options)
# Work in blocks on the input raster
blocks = iter_blocks(
arr_shape=(src.height, src.width),
block_shape=block_shape,
)

with rio.open(output_path, "w", **profile) as dst:
for i in range(1, src.count + 1):
data = src.read(i)
if keep_bits is not None:
round_mantissa(data, keep_bits)
dst.write(data, i)
for row_slice, col_slice in blocks:
window = Window.from_slices(rows=row_slice, cols=col_slice)
data = src.read(i, window=window)

if keep_bits is not None:
round_mantissa(data, keep_bits)

dst.write(data, i, window=window)

if output_dir is None:
# Overwrite the original
shutil.move(output_path, raster_path)
output_path = raster_path

return output_path


Expand All @@ -173,6 +194,7 @@ def repack_rasters(
output_dir: Path | None = None,
num_threads: int = 4,
keep_bits: int | None = None,
block_shape: int | tuple[int, int] = (1024, 1024),
**output_options,
):
"""Recreate and compress a list of raster files.
Expand All @@ -191,6 +213,8 @@ def repack_rasters(
keep_bits : int, optional
Number of bits to preserve in mantissa. Defaults to None.
Lower numbers will truncate the mantissa more and enable more compression.
block_shape: int | tuple[int, int]
Size of blocks to read in at one time.
**output_options
Creation options to pass to `get_gtiff_options`
Expand All @@ -205,7 +229,11 @@ def repack_rasters(

thread_map(
lambda raster: repack_raster(
raster, output_dir, keep_bits=keep_bits, **output_options
raster,
output_dir,
keep_bits=keep_bits,
block_shape=block_shape,
**output_options,
),
raster_files,
max_workers=num_threads,
Expand Down
104 changes: 104 additions & 0 deletions tests/test_io_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import shutil
import tempfile
from pathlib import Path

import numpy as np
import pytest
import rasterio as rio

from dolphin.io._utils import repack_raster, repack_rasters


@pytest.fixture(scope="module")
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)


@pytest.fixture(params=["float32", "complex64", "uint8"])
def test_raster(request, temp_dir):
dtype = request.param
raster_path = temp_dir / f"test_raster_{dtype}.tif"

# Create a test raster
data = np.random.rand(100, 100).astype(dtype)
if np.dtype(dtype) == np.complex64:
data = data + 1j * np.random.rand(100, 100)

profile = {
"driver": "GTiff",
"height": 100,
"width": 100,
"count": 1,
"dtype": str(dtype),
"crs": "EPSG:4326",
"transform": rio.transform.from_bounds(0, 0, 1, 1, 100, 100),
}

with rio.open(raster_path, "w", **profile) as dst:
dst.write(data, 1)

return raster_path


def test_repack_raster(test_raster, temp_dir):
output_dir = temp_dir / "output"
keep_bits = 10
with rio.open(test_raster) as src:
dtype = src.dtypes[0]

if np.dtype(dtype) == np.uint8:
with pytest.raises(TypeError):
repack_raster(
Path(test_raster),
output_dir=output_dir,
keep_bits=keep_bits,
block_shape=(32, 32),
)
return
output_path = repack_raster(
Path(test_raster),
output_dir=output_dir,
keep_bits=keep_bits,
block_shape=(32, 32),
)

assert output_path.exists()
assert output_path.parent == output_dir

with rio.open(test_raster) as src, rio.open(output_path) as dst:
assert src.profile["dtype"] == dst.profile["dtype"]
old, new = src.read(), dst.read()
assert old.shape == new.shape
tol = 2**keep_bits

# Check if data is close but not exactly the same (due to keep_bits)
np.testing.assert_allclose(old, new, atol=tol)


def test_repack_rasters(test_raster, temp_dir):
keep_bits = 10

# Add another to test the threaded version
new_raster = str(test_raster) + ".copy.tif"
shutil.copy(test_raster, new_raster)
raster_paths = [Path(test_raster), Path(new_raster)]

output_dir = temp_dir / "output_multiple"
with rio.open(raster_paths[0]) as src:
dtype = src.dtypes[0]
if np.dtype(dtype) == np.uint8:
with pytest.raises(TypeError):
repack_rasters(
raster_paths,
output_dir=output_dir,
keep_bits=keep_bits,
block_shape=(32, 32),
)
return
repack_rasters(
raster_paths,
output_dir=output_dir,
keep_bits=keep_bits,
block_shape=(32, 32),
)

0 comments on commit b90c04c

Please sign in to comment.