From 0c682a5ba670c4e51255a82a1ec3ebe56fd70ede Mon Sep 17 00:00:00 2001
From: Scott Staniewicz
Date: Tue, 8 Oct 2024 12:51:53 -0700
Subject: [PATCH 1/3] Convert `glrt` SHP method to use JAX

---
 src/dolphin/shp/_glrt.py | 222 ++++++++++++---------------------------
 src/dolphin/utils.py     |   4 +-
 2 files changed, 69 insertions(+), 157 deletions(-)

diff --git a/src/dolphin/shp/_glrt.py b/src/dolphin/shp/_glrt.py
index 13dc6d6f..bbfac5f0 100644
--- a/src/dolphin/shp/_glrt.py
+++ b/src/dolphin/shp/_glrt.py
@@ -1,100 +1,91 @@
 from __future__ import annotations
 
 import csv
-from functools import lru_cache
-from math import log
+from functools import lru_cache, partial
 from pathlib import Path
-from typing import Optional
 
-import numba
-import numpy as np
+import jax.numpy as jnp
+from jax import Array, jit, lax, vmap
 from numpy.typing import ArrayLike
 
-from dolphin._types import Strides
-from dolphin.utils import _get_slices, compute_out_shape
+from dolphin._types import HalfWindow
+from dolphin.utils import compute_out_shape
 
-from ._common import remove_unconnected
-
-_get_slices = numba.njit(_get_slices)
+
+@lru_cache
+def _read_cutoff_csv():
+    filename = Path(__file__).parent / "glrt_cutoffs.csv"
+    result = {}
+    with open(filename) as file:
+        reader = csv.DictReader(file)
+        for row in reader:
+            n = int(row["N"])
+            alpha = float(row["alpha"])
+            cutoff = float(row["cutoff"])
+            result[(n, alpha)] = cutoff
+    return result
+
+
+@partial(jit, static_argnames=["half_window", "strides", "nslc", "alpha"])
 def estimate_neighbors(
     mean: ArrayLike,
     var: ArrayLike,
-    halfwin_rowcol: tuple[int, int],
+    half_window: HalfWindow,
     nslc: int,
-    strides: Optional[dict] = None,
-    alpha: float = 0.05,
-    prune_disconnected: bool = False,
+    strides: tuple[int, int] = (1, 1),
+    alpha: float = 0.001,
 ):
-    """Estimate the number of neighbors based on the GLRT.
+    """Estimate the number of neighbors based on the GLRT."""
+    # Convert mean/var to the Rayleigh scale parameter
+    rows, cols = mean.shape
+    half_row, half_col = half_window
+    row_strides, col_strides = strides
 
-    Based on the method described in [@Parizzi2011AdaptiveInSARStack].
-    Assumes Rayleigh distributed amplitudes ([@Siddiqui1962ProblemsConnectedRayleigh])
+    in_r_start = row_strides // 2
+    in_c_start = col_strides // 2
+    out_rows, out_cols = compute_out_shape((rows, cols), strides)
 
-    Parameters
-    ----------
-    mean : ArrayLike, 2D
-        Mean amplitude of each pixel.
-    var: ArrayLike, 2D
-        Variance of each pixel's amplitude.
-    halfwin_rowcol : tuple[int, int]
-        Half the size of the block in (row, col) dimensions
-    nslc : int
-        Number of images in the stack used to compute `mean` and `var`.
-        Used to compute the degrees of freedom for the t- and F-tests to
-        determine the critical values.
-    strides: dict, optional
-        The (x, y) strides (in pixels) to use for the sliding window.
-        By default {"x": 1, "y": 1}
-    alpha : float, default=0.05
-        Significance level at which to reject the null hypothesis.
-        Rejecting means declaring a neighbor is not a SHP.
-    prune_disconnected : bool, default=False
-        If True, keeps only SHPs that are 8-connected to the current pixel.
-        Otherwise, any pixel within the window may be considered an SHP, even
-        if it is not directly connected.
-
-
-    Notes
-    -----
-    When `strides` is not (1, 1), the output first two dimensions
-    are smaller than `mean` and `var` by a factor of `strides`. This
-    will match the downstream shape of the strided phase linking results.
+    scale_squared = (var + mean**2) / 2
+    threshold = get_cutoff_jax(alpha=alpha, N=nslc)
 
-    Returns
-    -------
-    is_shp : np.ndarray, 4D
-        Boolean array marking which neighbors are SHPs for each pixel in the block.
-        Shape is (out_rows, out_cols, window_rows, window_cols), where
-            `out_rows` and `out_cols` are computed by
-            `[dolphin.io.compute_out_shape][]`
-            `window_rows = 2 * halfwin_rowcol[0] + 1`
-            `window_cols = 2 * halfwin_rowcol[1] + 1`
+    def _get_window(arr, r: int, c: int, half_row: int, half_col: int) -> Array:
+        r0 = r - half_row
+        c0 = c - half_col
+        start_indices = (r0, c0)
 
-    """
-    if strides is None:
-        strides = {"x": 1, "y": 1}
-    half_row, half_col = halfwin_rowcol
-    rows, cols = mean.shape
+        rsize = 2 * half_row + 1
+        csize = 2 * half_col + 1
+        slice_sizes = (rsize, csize)
 
-    threshold = get_cutoff(alpha=alpha, N=nslc)
+        return lax.dynamic_slice(arr, start_indices, slice_sizes)  # starts clamp in bounds at the borders
 
-    strides_rowcol = (strides["y"], strides["x"])
-    out_rows, out_cols = compute_out_shape((rows, cols), Strides(*strides_rowcol))
-    is_shp = np.zeros(
-        (out_rows, out_cols, 2 * half_row + 1, 2 * half_col + 1), dtype=np.bool_
-    )
-    return _loop_over_pixels(
-        mean,
-        var,
-        halfwin_rowcol,
-        strides_rowcol,
-        threshold,
-        prune_disconnected,
-        is_shp,
+    def _process_row_col(out_r, out_c):
+        in_r = in_r_start + out_r * row_strides
+        in_c = in_c_start + out_c * col_strides
+
+        scale_1 = scale_squared[in_r, in_c]  # One pixel
+        # and one window for scale 2, will broadcast
+        scale_2 = _get_window(scale_squared, in_r, in_c, half_row, half_col)
+
+        # Compute the GLRT test statistic.
+        scale_pooled = (scale_1 + scale_2) / 2
+        test_stat = 2 * jnp.log(scale_pooled) - jnp.log(scale_1) - jnp.log(scale_2)
+
+        return threshold > test_stat
+
+    # Now make a 2D grid of indices to access all output pixels
+    out_r_indices, out_c_indices = jnp.meshgrid(
+        jnp.arange(out_rows), jnp.arange(out_cols), indexing="ij"
     )
+    # vmap the per-pixel kernel over one axis of the output grid...
+    _process_2d = vmap(_process_row_col)
+    # ...then again over the second axis to cover the full 2D grid
+    _process_3d = vmap(_process_2d)
+    return _process_3d(out_r_indices, out_c_indices)
+
 
 def get_cutoff(alpha: float, N: int) -> float:
     r"""Compute the upper cutoff for the GLRT test statistic.
@@ -119,88 +110,7 @@ def get_cutoff(alpha: float, N: int) -> float:
     """
     n_alpha_to_cutoff = _read_cutoff_csv()
-    try:
-        return n_alpha_to_cutoff[(N, alpha)]
-    except KeyError as e:
-        msg = f"Not implemented for {N = }, {alpha = }"
-        raise NotImplementedError(msg) from e
-
+    return n_alpha_to_cutoff[(min(N, 50), alpha)]  # reuse the N=50 cutoff for larger stacks
 
-@numba.njit(nogil=True)
-def _compute_glrt_test_stat(scale_1, scale_2):
-    """Compute the GLRT test statistic."""
-    scale_pooled = (scale_1 + scale_2) / 2
-    return 2 * log(scale_pooled) - log(scale_1) - log(scale_2)
-
-
-@numba.njit(nogil=True, parallel=True)
-def _loop_over_pixels(
-    mean: ArrayLike,
-    var: ArrayLike,
-    halfwin_rowcol: tuple[int, int],
-    strides_rowcol: tuple[int, int],
-    threshold: float,
-    prune_disconnected: bool,
-    is_shp: np.ndarray,
-) -> np.ndarray:
-    """Loop common to SHP tests using only mean and variance."""
-    half_row, half_col = halfwin_rowcol
-    row_strides, col_strides = strides_rowcol
-    # location to start counting from in the larger input
-    r0, c0 = row_strides // 2, col_strides // 2
-    in_rows, in_cols = mean.shape
-    out_rows, out_cols = is_shp.shape[:2]
-
-    # Convert mean/var to the Rayleigh scale parameter
-    scale_squared = (var + mean**2) / 2
-
-    for out_r in numba.prange(out_rows):
-        for out_c in range(out_cols):
-            in_r = r0 + out_r * row_strides
-            in_c = c0 + out_c * col_strides
-
-            scale_1 = scale_squared[in_r, in_c]
-            # Clamp the window to the image bounds
-            (r_start, r_end), (c_start, c_end) = _get_slices(
-                half_row, half_col, in_r, in_c, in_rows, in_cols
-            )
-            if mean[in_r, in_c] == 0:
-                # Skip nodata pixels
-                continue
-
-            for in_r2 in range(r_start, r_end):
-                for in_c2 in range(c_start, c_end):
-                    # window offsets for dims 3,4 of `is_shp`
-                    r_off = in_r2 - r_start
-                    c_off = in_c2 - c_start
-
-                    # Don't count itself as a neighbor
-                    if in_r2 == in_r and in_c2 == in_c:
-                        is_shp[out_r, out_c, r_off, c_off] = False
-                        continue
-                    scale_2 = scale_squared[in_r2, in_c2]
-
-                    T = _compute_glrt_test_stat(scale_1, scale_2)
-
-                    is_shp[out_r, out_c, r_off, c_off] = threshold > T
-            if prune_disconnected:
-                # For this pixel, prune the groups not connected to the center
-                remove_unconnected(is_shp[out_r, out_c], inplace=True)
-
-    return is_shp
-
-
-@lru_cache
-def _read_cutoff_csv():
-    filename = Path(__file__).parent / "glrt_cutoffs.csv"
-
-    result = {}
-    with open(filename) as file:
-        reader = csv.DictReader(file)
-        for row in reader:
-            n = int(row["N"])
-            alpha = float(row["alpha"])
-            cutoff = float(row["cutoff"])
-            result[(n, alpha)] = cutoff
-
-    return result
+get_cutoff_jax = jit(get_cutoff, static_argnames=["alpha", "N"])
 
diff --git a/src/dolphin/utils.py b/src/dolphin/utils.py
index a5cb9789..9aa11f1a 100644
--- a/src/dolphin/utils.py
+++ b/src/dolphin/utils.py
@@ -634,7 +634,9 @@ def prepare_geometry(
     return stitched_geo_list
 
 
-def compute_out_shape(shape: tuple[int, int], strides: Strides) -> tuple[int, int]:
+def compute_out_shape(
+    shape: tuple[int, int], strides: Strides | tuple[int, int]
+) -> tuple[int, int]:
     """Calculate the output size for an input `shape` and row/col `strides`.
 
    Parameters

From b68c860e60c93ce4b79395280d242866d4c5933c Mon Sep 17 00:00:00 2001
From: Scott Staniewicz
Date: Thu, 10 Oct 2024 23:01:36 -0400
Subject: [PATCH 2/3] fix test

---
 tests/test_shp.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_shp.py b/tests/test_shp.py
index b3743459..0cd6db93 100644
--- a/tests/test_shp.py
+++ b/tests/test_shp.py
@@ -160,7 +160,7 @@ def test_shp_half_var_different(mean, var):
     assert not shps_mid_pixel[:5, :].any()
 
 
-@pytest.mark.parametrize("strides", [{"x": 1, "y": 1}, {"x": 2, "y": 2}])
+@pytest.mark.parametrize("strides", [(1, 1), (2, 2)])
 def test_shp_glrt_nodata_0(mean, var, strides):
     """Ensure"""
     method = "glrt"
@@ -179,13 +179,13 @@ def test_shp_glrt_nodata_0(mean, var, strides):
         alpha=0.005,
         method=method,
     )
-    out_col, out_row = 2 // strides["x"], 2 // strides["x"]
+    out_row, out_col = 2 // strides[0], 2 // strides[1]
     assert neighbors[:out_row, :out_col, :, :].sum() == 0
 
 
 @pytest.mark.parametrize("method", ["glrt", "ks"])
 @pytest.mark.parametrize("alpha", [0.01, 0.05])
-@pytest.mark.parametrize("strides", [{"x": 1, "y": 1}, {"x": 2, "y": 2}])
+@pytest.mark.parametrize("strides", [(1, 1), (2, 2)])
 def test_shp_statistics(method, alpha, strides):
     """Check that with repeated tries, the alpha is correct."""
 

From 0cb764212f4d4460e21f78de4c7fb93976f37dd6 Mon Sep 17 00:00:00 2001
From: Scott Staniewicz
Date: Thu, 10 Oct 2024 23:01:46 -0400
Subject: [PATCH 3/3] fix strides and tests, add edge for extra data

---
 src/dolphin/shp/__init__.py     |  8 +++-----
 src/dolphin/shp/_glrt.py        |  9 ++++-----
 src/dolphin/shp/_ks.py          | 10 +++-------
 src/dolphin/workflows/single.py | 29 ++++++++++++++++++++++++-----
 4 files changed, 34 insertions(+), 22 deletions(-)

diff --git a/src/dolphin/shp/__init__.py b/src/dolphin/shp/__init__.py
index 658d55bd..af0e6b6b 100644
--- a/src/dolphin/shp/__init__.py
+++ b/src/dolphin/shp/__init__.py
@@ -71,11 +71,10 @@ def estimate_neighbors(
         - `method` not a valid `ShpMethod`
 
     """
-    import numba
-
     if strides is None:
-        strides = {"x": 1, "y": 1}
-    logger.debug(f"NUMBA THREADS: {numba.get_num_threads()}")
+        strides = (1, 1)
+    if prune_disconnected:
+        logger.warning("`prune_disconnected` is deprecated: ignoring")
 
     if method == ShpMethod.RECT:
         # No estimation needed
@@ -96,7 +95,6 @@ def estimate_neighbors(
             strides=strides,
             nslc=nslc,
             alpha=alpha,
-            prune_disconnected=prune_disconnected,
         )
     elif method.lower() == ShpMethod.KS:
         if amp_stack is None:
diff --git a/src/dolphin/shp/_glrt.py b/src/dolphin/shp/_glrt.py
index bbfac5f0..546b1f43 100644
--- a/src/dolphin/shp/_glrt.py
+++ b/src/dolphin/shp/_glrt.py
@@ -8,7 +8,6 @@
 from jax import Array, jit, lax, vmap
 from numpy.typing import ArrayLike
 
-from dolphin._types import HalfWindow
 from dolphin.utils import compute_out_shape
 
 
@@ -28,11 +27,11 @@ def _read_cutoff_csv():
     return result
 
 
-@partial(jit, static_argnames=["half_window", "strides", "nslc", "alpha"])
+@partial(jit, static_argnames=["halfwin_rowcol", "strides", "nslc", "alpha"])
 def estimate_neighbors(
     mean: ArrayLike,
     var: ArrayLike,
-    half_window: HalfWindow,
+    halfwin_rowcol: tuple[int, int],
     nslc: int,
     strides: tuple[int, int] = (1, 1),
     alpha: float = 0.001,
 ):
     """Estimate the number of neighbors based on the GLRT."""
     # Convert mean/var to the Rayleigh scale parameter
     rows, cols = mean.shape
-    half_row, half_col = half_window
+    half_row, half_col = halfwin_rowcol
     row_strides, col_strides = strides
+    # each window covers (2 * half_row + 1) x (2 * half_col + 1) pixels
 
     in_r_start = row_strides // 2
     in_c_start = col_strides // 2
@@ -68,7 +68,6 @@ def _process_row_col(out_r, out_c):
         scale_1 = scale_squared[in_r, in_c]  # One pixel
         # and one window for scale 2, will broadcast
         scale_2 = _get_window(scale_squared, in_r, in_c, half_row, half_col)
-
         # Compute the GLRT test statistic.
         scale_pooled = (scale_1 + scale_2) / 2
         test_stat = 2 * jnp.log(scale_pooled) - jnp.log(scale_1) - jnp.log(scale_2)
diff --git a/src/dolphin/shp/_ks.py b/src/dolphin/shp/_ks.py
index 5df99d16..297f81c7 100644
--- a/src/dolphin/shp/_ks.py
+++ b/src/dolphin/shp/_ks.py
@@ -2,7 +2,6 @@
 
 import logging
 from math import exp, sqrt
-from typing import Optional
 
 import numba
 import numpy as np
@@ -24,7 +23,7 @@ def estimate_neighbors(
     amp_stack: ArrayLike,
     halfwin_rowcol: tuple[int, int],
     alpha: float,
-    strides: Optional[dict[str, int]] = None,
+    strides: tuple[int, int] = (1, 1),
     is_sorted: bool = False,
     prune_disconnected: bool = False,
 ):
@@ -37,16 +36,13 @@ def estimate_neighbors(
     #     neighbor_arrays,
     # )
 
-    if strides is None:
-        strides = {"x": 1, "y": 1}
     sorted_amp_stack = amp_stack if is_sorted else np.sort(amp_stack, axis=0)
     num_slc, rows, cols = sorted_amp_stack.shape
 
     ecdf_dist_cutoff = _get_ecdf_critical_distance(num_slc, alpha)
     logger.debug(f"ecdf_dist_cutoff: {ecdf_dist_cutoff}")
 
-    strides_rowcol = strides["y"], strides["x"]
-    out_rows, out_cols = compute_out_shape((rows, cols), Strides(*strides_rowcol))
+    out_rows, out_cols = compute_out_shape((rows, cols), Strides(*strides))
     half_row, half_col = halfwin_rowcol
     is_shp = np.zeros(
         (out_rows, out_cols, 2 * half_row + 1, 2 * half_col + 1), dtype=np.bool_
     )
     _loop_over_neighbors(
         sorted_amp_stack,
         halfwin_rowcol,
-        strides_rowcol,
+        strides,
         ecdf_dist_cutoff,
         prune_disconnected,
         is_shp,
     )
diff --git a/src/dolphin/workflows/single.py b/src/dolphin/workflows/single.py
index 42eb380c..890e41b9 100644
--- a/src/dolphin/workflows/single.py
+++ b/src/dolphin/workflows/single.py
@@ -33,6 +33,7 @@ class OutputFile:
     dtype: DTypeLike
     strides: Optional[dict[str, int]] = None
     nbands: int = 1
+    nodata: float = 0
 
 
 @atomic_output(output_arg="output_folder", is_dir=True)
@@ -124,7 +125,12 @@ def run_wrapped_phase_single(
         ),
         OutputFile(output_folder / f"shp_counts_{start_end}.tif", np.uint16, strides),
         OutputFile(output_folder / f"eigenvalues_{start_end}.tif", np.float32, strides),
-        OutputFile(output_folder / f"estimator_{start_end}.tif", np.int8, strides),
+        OutputFile(
+            output_folder / f"estimator_{start_end}.tif",
+            np.uint8,  # unsigned so the 255 nodata marker fits
+            strides,
+            nodata=255,
+        ),
         OutputFile(output_folder / f"avg_coh_{start_end}.tif", np.uint16, strides),
     ]
     for op in output_files:
@@ -135,7 +141,7 @@ def run_wrapped_phase_single(
             dtype=op.dtype,
             strides=op.strides,
             nbands=op.nbands,
-            nodata=0,
+            nodata=op.nodata,
         )
 
     # Iterate over the output grid
@@ -182,7 +188,7 @@ def run_wrapped_phase_single(
         neighbor_arrays = shp.estimate_neighbors(
             halfwin_rowcol=(yhalf, xhalf),
             alpha=shp_alpha,
-            strides=strides,
+            strides=tuple(strides_tup),
             mean=amp_mean[in_rows, in_cols] if amp_mean is not None else None,
             var=amp_variance[in_rows, in_cols] if amp_variance is not None else None,
             nslc=shp_nslc,
@@ -265,7 +271,7 @@ def run_wrapped_phase_single(
             )
 
         # All other outputs are strided (smaller in size)
-        out_datas = [
+        out_datas: list[np.ndarray | None] = [
            pl_output.temp_coh,
             pl_output.shp_counts,
             pl_output.eigenvalues,
@@ -275,8 +281,9 @@ def run_wrapped_phase_single(
         for data, output_file in zip(out_datas, output_files[1:]):
             if data is None:
                 # May choose to skip some outputs, e.g. "avg_coh"
"avg_coh" continue + trimmed_data = data[out_trim_rows, out_trim_cols] writer.queue_write( - data[out_trim_rows, out_trim_cols], + _erode_edge_pixels(trimmed_data, nodata=output_file.nodata), output_file.filename, out_rows.start, out_cols.start, @@ -410,3 +417,15 @@ def setup_output_folder( phase_linked_slc_files.append(output_path) return phase_linked_slc_files + + +def _erode_edge_pixels(arr: np.ndarray, nodata: float, n: int = 1) -> np.ndarray: + from scipy import ndimage + + mask = arr == nodata if not np.isnan(nodata) else np.isnan(arr) + mask_expanded = ndimage.binary_dilation( + mask, structure=np.ones((1 + 2 * n, 1 + 2 * n)) + ) + out = arr.copy() + out[mask_expanded] = nodata + return out