Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert glrt SHP method to use JAX #444

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/dolphin/shp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,10 @@ def estimate_neighbors(
- `method` not a valid `ShpMethod`

"""
import numba

if strides is None:
strides = {"x": 1, "y": 1}
logger.debug(f"NUMBA THREADS: {numba.get_num_threads()}")
strides = (1, 1)
if prune_disconnected:
logger.warning("`prune_disconnected` is deprecated: ignoring")

if method == ShpMethod.RECT:
# No estimation needed
Expand All @@ -96,7 +95,6 @@ def estimate_neighbors(
strides=strides,
nslc=nslc,
alpha=alpha,
prune_disconnected=prune_disconnected,
)
elif method.lower() == ShpMethod.KS:
if amp_stack is None:
Expand Down
219 changes: 64 additions & 155 deletions src/dolphin/shp/_glrt.py
Original file line number Diff line number Diff line change
@@ -1,100 +1,90 @@
from __future__ import annotations

import csv
from functools import lru_cache
from math import log
from functools import lru_cache, partial
from pathlib import Path
from typing import Optional

import numba
import numpy as np
import jax.numpy as jnp
from jax import Array, jit, lax, vmap
from numpy.typing import ArrayLike

from dolphin._types import Strides
from dolphin.utils import _get_slices, compute_out_shape
from dolphin.utils import compute_out_shape

from ._common import remove_unconnected

_get_slices = numba.njit(_get_slices)
@lru_cache
def _read_cutoff_csv():
filename = Path(__file__).parent / "glrt_cutoffs.csv"

result = {}
with open(filename) as file:
reader = csv.DictReader(file)
for row in reader:
n = int(row["N"])
alpha = float(row["alpha"])
cutoff = float(row["cutoff"])
result[(n, alpha)] = cutoff

return result


@partial(jit, static_argnames=["halfwin_rowcol", "strides", "nslc", "alpha"])
def estimate_neighbors(
mean: ArrayLike,
var: ArrayLike,
halfwin_rowcol: tuple[int, int],
nslc: int,
strides: Optional[dict] = None,
alpha: float = 0.05,
prune_disconnected: bool = False,
strides: tuple[int, int] = (1, 1),
alpha: float = 0.001,
):
"""Estimate the number of neighbors based on the GLRT.
"""Estimate the number of neighbors based on the GLRT."""
# Convert mean/var to the Rayleigh scale parameter
rows, cols = mean.shape
half_row, half_col = halfwin_rowcol
row_strides, col_strides = strides
# window_size = rsize * csize

Based on the method described in [@Parizzi2011AdaptiveInSARStack].
Assumes Rayleigh distributed amplitudes ([@Siddiqui1962ProblemsConnectedRayleigh])
in_r_start = row_strides // 2
in_c_start = col_strides // 2
out_rows, out_cols = compute_out_shape((rows, cols), strides)

Parameters
----------
mean : ArrayLike, 2D
Mean amplitude of each pixel.
var: ArrayLike, 2D
Variance of each pixel's amplitude.
halfwin_rowcol : tuple[int, int]
Half the size of the block in (row, col) dimensions
nslc : int
Number of images in the stack used to compute `mean` and `var`.
Used to compute the degrees of freedom for the t- and F-tests to
determine the critical values.
strides: dict, optional
The (x, y) strides (in pixels) to use for the sliding window.
By default {"x": 1, "y": 1}
alpha : float, default=0.05
Significance level at which to reject the null hypothesis.
Rejecting means declaring a neighbor is not a SHP.
prune_disconnected : bool, default=False
If True, keeps only SHPs that are 8-connected to the current pixel.
Otherwise, any pixel within the window may be considered an SHP, even
if it is not directly connected.


Notes
-----
When `strides` is not (1, 1), the output first two dimensions
are smaller than `mean` and `var` by a factor of `strides`. This
will match the downstream shape of the strided phase linking results.
scale_squared = (var + mean**2) / 2
threshold = get_cutoff_jax(alpha=alpha, N=nslc)

Returns
-------
is_shp : np.ndarray, 4D
Boolean array marking which neighbors are SHPs for each pixel in the block.
Shape is (out_rows, out_cols, window_rows, window_cols), where
`out_rows` and `out_cols` are computed by
`[dolphin.io.compute_out_shape][]`
`window_rows = 2 * halfwin_rowcol[0] + 1`
`window_cols = 2 * halfwin_rowcol[1] + 1`
def _get_window(arr, r: int, c: int, half_row: int, half_col: int) -> Array:
r0 = r - half_row
c0 = c - half_col
start_indices = (r0, c0)

"""
if strides is None:
strides = {"x": 1, "y": 1}
half_row, half_col = halfwin_rowcol
rows, cols = mean.shape
rsize = 2 * half_row + 1
csize = 2 * half_col + 1
slice_sizes = (rsize, csize)

threshold = get_cutoff(alpha=alpha, N=nslc)
return lax.dynamic_slice(arr, start_indices, slice_sizes)

strides_rowcol = (strides["y"], strides["x"])
out_rows, out_cols = compute_out_shape((rows, cols), Strides(*strides_rowcol))
is_shp = np.zeros(
(out_rows, out_cols, 2 * half_row + 1, 2 * half_col + 1), dtype=np.bool_
)
return _loop_over_pixels(
mean,
var,
halfwin_rowcol,
strides_rowcol,
threshold,
prune_disconnected,
is_shp,
def _process_row_col(out_r, out_c):
in_r = in_r_start + out_r * row_strides
in_c = in_c_start + out_c * col_strides

scale_1 = scale_squared[in_r, in_c] # One pixel
# and one window for scale 2, will broadcast
scale_2 = _get_window(scale_squared, in_r, in_c, half_row, half_col)
# Compute the GLRT test statistic.
scale_pooled = (scale_1 + scale_2) / 2
test_stat = 2 * jnp.log(scale_pooled) - jnp.log(scale_1) - jnp.log(scale_2)

return threshold > test_stat

# Now make a 2D grid of indices to access all output pixels
out_r_indices, out_c_indices = jnp.meshgrid(
jnp.arange(out_rows), jnp.arange(out_cols), indexing="ij"
)

# Create the vectorized function in 2d
_process_2d = vmap(_process_row_col)
# Then in 3d
_process_3d = vmap(_process_2d)
return _process_3d(out_r_indices, out_c_indices)


def get_cutoff(alpha: float, N: int) -> float:
r"""Compute the upper cutoff for the GLRT test statistic.
Expand All @@ -119,88 +109,7 @@ def get_cutoff(alpha: float, N: int) -> float:

"""
n_alpha_to_cutoff = _read_cutoff_csv()
try:
return n_alpha_to_cutoff[(N, alpha)]
except KeyError as e:
msg = f"Not implemented for {N = }, {alpha = }"
raise NotImplementedError(msg) from e

return n_alpha_to_cutoff[(max(N, 50), alpha)]

@numba.njit(nogil=True)
def _compute_glrt_test_stat(scale_1, scale_2):
"""Compute the GLRT test statistic."""
scale_pooled = (scale_1 + scale_2) / 2
return 2 * log(scale_pooled) - log(scale_1) - log(scale_2)


@numba.njit(nogil=True, parallel=True)
def _loop_over_pixels(
mean: ArrayLike,
var: ArrayLike,
halfwin_rowcol: tuple[int, int],
strides_rowcol: tuple[int, int],
threshold: float,
prune_disconnected: bool,
is_shp: np.ndarray,
) -> np.ndarray:
"""Loop common to SHP tests using only mean and variance."""
half_row, half_col = halfwin_rowcol
row_strides, col_strides = strides_rowcol
# location to start counting from in the larger input
r0, c0 = row_strides // 2, col_strides // 2
in_rows, in_cols = mean.shape
out_rows, out_cols = is_shp.shape[:2]

# Convert mean/var to the Rayleigh scale parameter
scale_squared = (var + mean**2) / 2

for out_r in numba.prange(out_rows):
for out_c in range(out_cols):
in_r = r0 + out_r * row_strides
in_c = c0 + out_c * col_strides

scale_1 = scale_squared[in_r, in_c]
# Clamp the window to the image bounds
(r_start, r_end), (c_start, c_end) = _get_slices(
half_row, half_col, in_r, in_c, in_rows, in_cols
)
if mean[in_r, in_c] == 0:
# Skip nodata pixels
continue

for in_r2 in range(r_start, r_end):
for in_c2 in range(c_start, c_end):
# window offsets for dims 3,4 of `is_shp`
r_off = in_r2 - r_start
c_off = in_c2 - c_start

# Don't count itself as a neighbor
if in_r2 == in_r and in_c2 == in_c:
is_shp[out_r, out_c, r_off, c_off] = False
continue
scale_2 = scale_squared[in_r2, in_c2]

T = _compute_glrt_test_stat(scale_1, scale_2)

is_shp[out_r, out_c, r_off, c_off] = threshold > T
if prune_disconnected:
# For this pixel, prune the groups not connected to the center
remove_unconnected(is_shp[out_r, out_c], inplace=True)

return is_shp


@lru_cache
def _read_cutoff_csv():
filename = Path(__file__).parent / "glrt_cutoffs.csv"

result = {}
with open(filename) as file:
reader = csv.DictReader(file)
for row in reader:
n = int(row["N"])
alpha = float(row["alpha"])
cutoff = float(row["cutoff"])
result[(n, alpha)] = cutoff

return result
get_cutoff_jax = jit(get_cutoff, static_argnames=["alpha", "N"])
10 changes: 3 additions & 7 deletions src/dolphin/shp/_ks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
from math import exp, sqrt
from typing import Optional

import numba
import numpy as np
Expand All @@ -24,7 +23,7 @@ def estimate_neighbors(
amp_stack: ArrayLike,
halfwin_rowcol: tuple[int, int],
alpha: float,
strides: Optional[dict[str, int]] = None,
strides: tuple[int, int] = (1, 1),
is_sorted: bool = False,
prune_disconnected: bool = False,
):
Expand All @@ -37,16 +36,13 @@ def estimate_neighbors(
# neighbor_arrays,
# )

if strides is None:
strides = {"x": 1, "y": 1}
sorted_amp_stack = amp_stack if is_sorted else np.sort(amp_stack, axis=0)

num_slc, rows, cols = sorted_amp_stack.shape
ecdf_dist_cutoff = _get_ecdf_critical_distance(num_slc, alpha)
logger.debug(f"ecdf_dist_cutoff: {ecdf_dist_cutoff}")

strides_rowcol = strides["y"], strides["x"]
out_rows, out_cols = compute_out_shape((rows, cols), Strides(*strides_rowcol))
out_rows, out_cols = compute_out_shape((rows, cols), Strides(*strides))
half_row, half_col = halfwin_rowcol
is_shp = np.zeros(
(out_rows, out_cols, 2 * half_row + 1, 2 * half_col + 1), dtype=np.bool_
Expand All @@ -55,7 +51,7 @@ def estimate_neighbors(
_loop_over_neighbors(
sorted_amp_stack,
halfwin_rowcol,
strides_rowcol,
strides,
ecdf_dist_cutoff,
prune_disconnected,
is_shp,
Expand Down
4 changes: 3 additions & 1 deletion src/dolphin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,9 @@ def prepare_geometry(
return stitched_geo_list


def compute_out_shape(shape: tuple[int, int], strides: Strides) -> tuple[int, int]:
def compute_out_shape(
shape: tuple[int, int], strides: Strides | tuple[int, int]
) -> tuple[int, int]:
"""Calculate the output size for an input `shape` and row/col `strides`.

Parameters
Expand Down
Loading
Loading