Add step to combine previous amplitudes/amplitude dispersions for cleaner PS/SHP estimation #179

Merged
7 commits merged on Oct 29, 2024
3 changes: 3 additions & 0 deletions src/disp_s1/_masking.py
@@ -150,6 +150,9 @@ def create_layover_shadow_masks(
f = files[0]
input_name = format_nc_filename(f, ds_name="/data/layover_shadow_mask")
out_file = output_path / f"layover_shadow_{burst_id}.tif"
if out_file.exists():
output_files.append(out_file)
continue

logger.info(f"Extracting layover shadow mask from {f} to {out_file}")
layover_data = load_gdal(input_name)
265 changes: 265 additions & 0 deletions src/disp_s1/_ps.py
@@ -0,0 +1,265 @@
from __future__ import annotations

import logging
import multiprocessing as mp
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import ProcessPoolExecutor
from enum import Enum
from pathlib import Path

import dolphin.ps
import numpy as np
from dolphin import io, masking
from dolphin._types import PathOrStr
from dolphin.utils import DummyProcessPoolExecutor
from dolphin.workflows._utils import _create_burst_cfg, _remove_dir_if_empty
from dolphin.workflows.config import DisplacementWorkflow
from dolphin.workflows.wrapped_phase import _get_mask
from opera_utils import group_by_burst

logger = logging.getLogger(__name__)


class WeightScheme(str, Enum):
"""Methods for weighted combination of old and current amplitudes."""

LINEAR = "linear"
EQUAL = "equal"
EXPONENTIAL = "exponential"


def precompute_ps(cfg: DisplacementWorkflow) -> tuple[list[Path], list[Path]]:
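"""Group the input SLCs by burst and produce combined amplitude dispersion/mean rasters per burst."""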
try:
grouped_slc_files = group_by_burst(cfg.cslc_file_list)
except ValueError as e:
# Make sure it's not some other ValueError
if "Could not parse burst id" not in str(e):
raise
# Otherwise, we have SLC files which are not OPERA burst files
grouped_slc_files = {"": cfg.cslc_file_list}

if cfg.amplitude_dispersion_files:
grouped_amp_dispersion_files = group_by_burst(cfg.amplitude_dispersion_files)
else:
grouped_amp_dispersion_files = defaultdict(list)
if cfg.amplitude_mean_files:
grouped_amp_mean_files = group_by_burst(cfg.amplitude_mean_files)
else:
grouped_amp_mean_files = defaultdict(list)
if cfg.layover_shadow_mask_files:
grouped_layover_shadow_mask_files = group_by_burst(
cfg.layover_shadow_mask_files
)
else:
grouped_layover_shadow_mask_files = defaultdict(list)

# ######################################
# 1. Burst-wise Wrapped phase estimation
# ######################################
if len(grouped_slc_files) > 1:
logger.info(f"Found SLC files from {len(grouped_slc_files)} bursts")
wrapped_phase_cfgs = [
(
burst, # Include the burst for logging purposes
_create_burst_cfg(
cfg,
burst,
grouped_slc_files,
grouped_amp_mean_files,
grouped_amp_dispersion_files,
grouped_layover_shadow_mask_files,
),
)
for burst in grouped_slc_files
]
for _, burst_cfg in wrapped_phase_cfgs:
burst_cfg.create_dir_tree()
# Remove the mid-level directories which will be empty due to re-grouping
_remove_dir_if_empty(cfg.phase_linking._directory)
_remove_dir_if_empty(cfg.ps_options._directory)

else:
# grab the only key (either a burst, or "") and use that
cfg.create_dir_tree()
b = next(iter(grouped_slc_files.keys()))
wrapped_phase_cfgs = [(b, cfg)]

combined_dispersion_files: list[Path] = []
combined_mean_files: list[Path] = []
num_parallel = min(cfg.worker_settings.n_parallel_bursts, len(grouped_slc_files))
Executor = ProcessPoolExecutor if num_parallel > 1 else DummyProcessPoolExecutor
ctx = mp.get_context("spawn")
with Executor(
max_workers=num_parallel,
mp_context=ctx,
) as exc:
fut_to_burst = {
exc.submit(run_burst_ps, burst_cfg): burst
for burst, burst_cfg in wrapped_phase_cfgs
}
for fut in fut_to_burst:
combined_dispersion_file, combined_mean_file = fut.result()
combined_dispersion_files.append(combined_dispersion_file)
combined_mean_files.append(combined_mean_file)

return combined_dispersion_files, combined_mean_files


def run_burst_ps(cfg: DisplacementWorkflow) -> tuple[Path, Path]:
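"""Estimate amplitude mean/dispersion for one burst, then merge with compressed-SLC statistics."""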
input_file_list = cfg.cslc_file_list
if not input_file_list:
msg = "No input files found"
raise ValueError(msg)

subdataset = cfg.input_options.subdataset
# Mark any files with "compressed" in the name as compressed SLCs
is_compressed = ["compressed" in str(f).lower() for f in input_file_list]

non_compressed_slcs = [
f for f, is_comp in zip(input_file_list, is_compressed) if not is_comp
]
vrt_stack = io.VRTStack(
non_compressed_slcs,
subdataset=subdataset,
outfile=cfg.work_directory / "non_compressed_slc_stack.vrt",
)

layover_shadow_mask = (
cfg.layover_shadow_mask_files[0] if cfg.layover_shadow_mask_files else None
)
mask_filename = _get_mask(
output_dir=cfg.work_directory,
output_bounds=cfg.output_options.bounds,
output_bounds_wkt=cfg.output_options.bounds_wkt,
output_bounds_epsg=cfg.output_options.bounds_epsg,
like_filename=vrt_stack.outfile,
layover_shadow_mask=layover_shadow_mask,
cslc_file_list=non_compressed_slcs,
)
nodata_mask = masking.load_mask_as_numpy(mask_filename) if mask_filename else None

output_file_list = [
cfg.ps_options._output_file,
cfg.ps_options._amp_mean_file,
cfg.ps_options._amp_dispersion_file,
]
ps_output = cfg.ps_options._output_file
if not all(f.exists() for f in output_file_list):
logger.info(f"Creating persistent scatterer file {ps_output}")
# Compute the amplitude mean/dispersion (plus an initial PS mask) from the real SLCs
dolphin.ps.create_ps(
reader=vrt_stack,
output_file=output_file_list[0],
output_amp_mean_file=output_file_list[1],
output_amp_dispersion_file=output_file_list[2],
like_filename=vrt_stack.outfile,
amp_dispersion_threshold=cfg.ps_options.amp_dispersion_threshold,
nodata_mask=nodata_mask,
block_shape=cfg.worker_settings.block_shape,
)
# Remove the PS mask itself, since it will be redone after combining the amplitude statistics
cfg.ps_options._output_file.unlink()

compressed_slc_files = [
f for f, is_comp in zip(input_file_list, is_compressed) if is_comp
]
logger.info(f"Combining existing means/dispersions from {compressed_slc_files}")
return run_combine(
cfg.ps_options._amp_mean_file,
cfg.ps_options._amp_dispersion_file,
compressed_slc_files,
num_slc=len(non_compressed_slcs),
subdataset=subdataset,
)


def run_combine(
cur_mean: Path,
cur_dispersion: Path,
compressed_slc_files: list[PathOrStr],
num_slc: int,
weight_scheme: WeightScheme = WeightScheme.EXPONENTIAL,
subdataset: str = "/data/VV",
) -> tuple[Path, Path]:
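"""Merge the current ministack's amplitude mean/dispersion with those stored in compressed SLCs."""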
out_dispersion = cur_dispersion.parent / "combined_dispersion.tif"
out_mean = cur_mean.parent / "combined_mean.tif"
if out_dispersion.exists() and out_mean.exists():
logger.info(f"{out_mean} and {out_dispersion} exist, skipping")
return out_dispersion, out_mean

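# Each compressed SLC's magnitude serves as its ministack's amplitude mean;
# the matching dispersion is stored in its /data/amplitude_dispersion dataset.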
reader_compslc = io.HDF5StackReader.from_file_list(
file_list=compressed_slc_files,
dset_names=subdataset,
nodata=np.nan,
)
reader_compslc_dispersion = io.HDF5StackReader.from_file_list(
file_list=compressed_slc_files,
dset_names="/data/amplitude_dispersion",
nodata=np.nan,
)
reader_mean = io.RasterReader.from_file(cur_mean, band=1)
reader_dispersion = io.RasterReader.from_file(cur_dispersion, band=1)

num_images = 1 + len(compressed_slc_files)
if weight_scheme == WeightScheme.LINEAR:
# Increase the weights from older to newer.
N = np.linspace(0, 1, num=num_images) * num_slc
elif weight_scheme == WeightScheme.EQUAL:
# Weight every image equally.
N = num_slc * np.ones((num_images,))
elif weight_scheme == WeightScheme.EXPONENTIAL:
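# Decay the weights exponentially: the newest image is normalized to 1, older images count less.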
alpha = 0.5
weights = np.exp(alpha * np.arange(num_images))
weights /= weights.max()
N = weights.round().astype(int)
else:
raise ValueError(f"Unrecognized {weight_scheme = }")

def read_and_combine(
readers: Sequence[io.StackReader], rows: slice, cols: slice
) -> tuple[np.ndarray, slice, slice]:
reader_compslc, reader_compslc_dispersion, reader_mean, reader_dispersion = (
readers
)
compslc_mean = np.abs(reader_compslc[:, rows, cols])
if compslc_mean.ndim == 2:
compslc_mean = compslc_mean[np.newaxis]
compslc_dispersion = reader_compslc_dispersion[:, rows, cols]
if compslc_dispersion.ndim == 2:
compslc_dispersion = compslc_dispersion[np.newaxis]

mean = reader_mean[rows, cols][np.newaxis]
dispersion = reader_dispersion[rows, cols][np.newaxis]

# Combine the stacked dispersions and means into single weighted estimates
dispersions = np.vstack([compslc_dispersion, dispersion])

means = np.vstack([compslc_mean, mean])
new_dispersion, new_mean = dolphin.ps.combine_amplitude_dispersions(
dispersions=dispersions,
means=means,
N=N,
)
return (
np.stack([np.nan_to_num(new_dispersion), np.nan_to_num(new_mean)]),
rows,
cols,
)

out_paths = (out_dispersion, out_mean)
readers = reader_compslc, reader_compslc_dispersion, reader_mean, reader_dispersion
writer = io.BackgroundStackWriter(out_paths, like_filename=cur_mean, debug=True)
io.process_blocks(
readers=readers,
writer=writer,
func=read_and_combine,
block_shape=(256, 256),
num_threads=1,
)

writer.notify_finished()
return out_paths
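
As a quick illustration of the weighting schemes above (not part of the diff): here is what each WeightScheme yields for N with, say, three compressed SLCs plus a current ministack of 15 real SLCs, and the array shapes that read_and_combine hands to dolphin.ps.combine_amplitude_dispersions. The numeric values are illustrative only, and the commented-out call assumes dolphin is installed.

import numpy as np

num_slc = 15     # real SLCs in the current ministack
num_images = 4   # e.g. 3 compressed SLCs + the current mean/dispersion pair

# LINEAR: ramp the effective SLC counts from oldest to newest
n_linear = np.linspace(0, 1, num=num_images) * num_slc   # [ 0.,  5., 10., 15.]

# EQUAL: every image counts as num_slc SLCs
n_equal = num_slc * np.ones((num_images,))                # [15., 15., 15., 15.]

# EXPONENTIAL (the default above): newest normalized to 1, older images decayed
alpha = 0.5
weights = np.exp(alpha * np.arange(num_images))
weights /= weights.max()
n_exp = weights.round().astype(int)                       # [0, 0, 1, 1]

# read_and_combine() stacks one 2-D layer per image before calling dolphin:
rows, cols = 256, 256
rng = np.random.default_rng(0)
dispersions = rng.uniform(0.1, 0.6, size=(num_images, rows, cols))
means = rng.uniform(100.0, 200.0, size=(num_images, rows, cols))
# new_dispersion, new_mean = dolphin.ps.combine_amplitude_dispersions(
#     dispersions=dispersions, means=means, N=n_exp
# )  # each result is (rows, cols); they become the two bands of the combined rasters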
20 changes: 18 additions & 2 deletions src/disp_s1/main.py
@@ -21,6 +21,7 @@

from disp_s1 import __version__, product
from disp_s1._masking import create_layover_shadow_masks, create_mask_from_distance
from disp_s1._ps import precompute_ps
from disp_s1.pge_runconfig import RunConfig

from ._reference import ReferencePoint, read_reference_point
@@ -33,7 +34,7 @@ def run(
cfg: DisplacementWorkflow,
pge_runconfig: RunConfig,
debug: bool = False,
):
) -> None:
"""Run the displacement workflow on a stack of SLCs.

Parameters
@@ -67,6 +68,18 @@
)
cfg.layover_shadow_mask_files = layover_binary_mask_files

if any("compressed" in f.name.lower() for f in cfg.cslc_file_list):
# If we are passed Compressed SLCs, combine the old amplitudes with the
# current real SLCs for a better estimate of amplitude dispersion for PS/SHPs
logger.info("Combining old amplitudes with current SLCs")
combined_dispersion_files, combined_mean_files = precompute_ps(cfg=cfg)
cfg.amplitude_dispersion_files = combined_dispersion_files
cfg.amplitude_mean_files = combined_mean_files
else:
# This is the first ministack: the amplitude dispersion estimate will be noisy.
# Drop the PS threshold to a conservative value to avoid false positives.
cfg.ps_options.amp_dispersion_threshold = 0.15

# Run dolphin's displacement workflow
out_paths = run_displacement(cfg=cfg, debug=debug)

@@ -81,8 +94,11 @@ def run(
ref_point = read_reference_point(out_paths.timeseries_paths[0].parent)

# Find the geometry files, if created
los_east_file: Path | None
los_north_file: Path | None
try:
los_east_file = next(cfg.work_directory.rglob("los_east.tif"))
assert los_east_file is not None
los_north_file = los_east_file.parent / "los_north.tif"
except StopIteration:
los_east_file = los_north_file = None
@@ -194,7 +210,7 @@ def run(


def _assert_dates_match(
disp_date_keys: set[datetime], test_paths: Iterable[Path], name: str
disp_date_keys: set[tuple[datetime, ...]], test_paths: Iterable[Path], name: str
) -> None:
"""Assert that the dates in `paths_to_check` match the reference dates.

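For context (not part of the diff), the branch added to run() keys off file names alone: a later ministack whose inputs include compressed CSLCs takes the precompute_ps path, while the very first ministack only tightens the PS threshold. A small sketch with hypothetical file names:

from pathlib import Path

# Hypothetical names; the real OPERA CSLC naming convention is not reproduced here.
first_ministack = [Path("t042_088905_iw1_20221119.h5"), Path("t042_088905_iw1_20221201.h5")]
later_ministack = [Path("COMPRESSED_t042_088905_iw1_20221119_20230101.h5"), *first_ministack]

def has_compressed(cslc_file_list: list[Path]) -> bool:
    # Same test used in run(): any input with "compressed" in its (lowercased) name
    return any("compressed" in f.name.lower() for f in cslc_file_list)

print(has_compressed(first_ministack))   # False -> amp_dispersion_threshold lowered to 0.15
print(has_compressed(later_ministack))   # True  -> precompute_ps() supplies combined rasters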