Skip to content

Commit

Permalink
Dev icenet-ai#279: Refactor region processing, optimise code
Browse files Browse the repository at this point in the history
  • Loading branch information
bnubald committed Jul 31, 2024
1 parent 1bd0792 commit 01b566b
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 38 deletions.
23 changes: 19 additions & 4 deletions icenet/plotting/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,11 +1673,13 @@ def plot_forecast():

# Define CRS for plotting
reproject = True if args.crs else False
data_crs = ccrs.LambertAzimuthalEqualArea(central_latitude=pole*90, central_longitude=0)
data_crs_proj = ccrs.LambertAzimuthalEqualArea(central_latitude=pole*90, central_longitude=0)
data_crs_geo = ccrs.PlateCarree()

if args.crs:
target_crs = get_crs(args.crs)
else:
target_crs = data_crs
target_crs = data_crs_proj
transform_crs = ccrs.PlateCarree()

region_args = None
Expand All @@ -1688,7 +1690,6 @@ def plot_forecast():
region_args = args.region_geographic
method = "geographic"


## Clip the actual data to the requested region.
## This can cause empty region at the borders if used with different CRS projections
## due to re-projection.
Expand All @@ -1700,7 +1701,18 @@ def plot_forecast():

# Reproject, and process regions if necessary
# TODO: Split this function to separate `reproject` and `process_regions`
fc = process_regions(region_args, [fc], method=method, proj=target_crs, pole=pole, no_clip_region=args.no_clip_region)[0]
if reproject:
projection = target_crs
else:
projection = None

fc = process_regions(region_args,
[fc],
method=method,
target_crs=projection,
pole=pole,
clip_geographic_region=not args.no_clip_region
)[0]

vmax = 1.

Expand Down Expand Up @@ -1731,6 +1743,9 @@ def plot_forecast():
coastlines = "default"
if args.region is not None or args.region_geographic is not None:
extent = (bound_args["x1"], bound_args["x2"], bound_args["y1"], bound_args["y2"])
# Note: Using GSHHS coastlines will slow png/mp4 output quite a bit!
# This is automatically activated when a sub region is specified with the
# '-r' or '-z' region flags.
coastlines = "gshhs"
else:
extent = None
Expand Down
70 changes: 36 additions & 34 deletions icenet/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,15 +596,16 @@ def reproject_projected_coords(data,
def process_regions(region: tuple=None,
data: tuple=None,
method: str = "pixel",
proj=None,
target_crs=None,
pole=1,
no_clip_region=True,
clip_geographic_region=True,
) -> tuple:
"""Extract subset of pan-Arctic/Antarctic region based on region bounds.
:param region: Either image pixel bounds, or geographic bounds.
:param data: Contains the full xarray DataArray.
:param method: Whether providing pixel coordinates or geographic (i.e. lon/lat).
:param clip_geographic_region: Whether to clip the data to the defined lon/lat region bounds.
:return:
"""
Expand All @@ -618,45 +619,46 @@ def process_regions(region: tuple=None,

for idx, arr in enumerate(data):
if arr is not None:
if not proj:
if target_crs is None:
reprojected_data = arr
elif (method == "geographic" and clip_geographic_region):
data[idx] = arr
else:
# Reproject only when target_crs is defined, and when region is bounded by lon/lat without the
# 'clip_geographic_region' flag
logging.info(f"Reprojecting data to specified CRS")
reprojected_data = reproject_projected_coords(arr,
target_crs=proj,
pole=pole,
)
target_crs=target_crs,
pole=pole,
)
data[idx] = reprojected_data


if region is not None:
if method.casefold() == "pixel":
logging.info(f"Clipping data to specified bounds: {region}")
if method.casefold() == "geographic":
if clip_geographic_region:
# Limit to lon/lat region, within a given tolerance
tolerance = 1E-1
# Create condition where data is within geographic (lon/lat) region
condition = (arr.lon >= x1-tolerance) & (arr.lon <= x2+tolerance) & \
(arr.lat >= y1-tolerance) & (arr.lat <= y2+tolerance)

# Extract subset within region using where()
clipped_data = arr.where(condition, drop=True)

# Reproject just the clipped region for speed
data[idx] = reproject_projected_coords(clipped_data,
target_crs=target_crs,
pole=pole,
)
elif method.casefold() == "pixel":
x_max, y_max = reprojected_data.xc.shape[0], reprojected_data.yc.shape[0]
max_x = min(x_max, x2)
max_y = min(y_max, y2)

# Clip the data array
clipped_data = reprojected_data[..., (y_max - y2):(y_max - y1), x1:x2]
elif method.casefold() == "geographic" and not no_clip_region:
arr = reprojected_data

# Limit to lon/lat region, within a given tolerance
tolerance = 1E-1
# Create condition where data is within geographic (lon/lat) region
condition = (arr.lon >= x1-tolerance) & (arr.lon <= x2+tolerance) & \
(arr.lat >= y1-tolerance) & (arr.lat <= y2+tolerance)

# Extract subset within region using where()
clipped_data = arr.where(condition, drop=True)
# clipped_data = reproject_projected_coords(clipped_data,
# target_crs=proj,
# pole=pole,
# )
elif method.casefold() == "geographic" and no_clip_region:
data[idx] = reprojected_data
continue

# Clip the data array to specified pixel region
data[idx] = reprojected_data[..., (y_max - y2):(y_max - y1), x1:x2]
else:
raise NotImplementedError
data[idx] = clipped_data
else:
data[idx] = reprojected_data
raise NotImplementedError("Only method='pixel' or 'geographic' bounds are supported")

return data

Expand Down

0 comments on commit 01b566b

Please sign in to comment.