Skip to content

Commit

Permalink
Dev icenet-ai#279: Enable masking based on lat/lon + metrics based on…
Browse files Browse the repository at this point in the history
… this
  • Loading branch information
bnubald committed Oct 18, 2024
1 parent c7cd989 commit 9d9f615
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 40 deletions.
102 changes: 94 additions & 8 deletions icenet/data/sic/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(self,
polarhole_radii: object = POLARHOLE_RADII,
data_shape: object = (432, 432),
dtype: object = np.float32,
longitudes = None,
latitudes = None,
**kwargs):
"""Initialises Masks across specified hemispheres.
Expand All @@ -52,7 +54,10 @@ def __init__(self,
self._polarhole_radii = polarhole_radii
self._dtype = dtype
self._shape = data_shape
self.longitudes = longitudes
self.latitudes = latitudes
self._region = (slice(None, None), slice(None, None))
self._region_geo_mask = None

self.init_params()

Expand Down Expand Up @@ -199,6 +204,36 @@ def generate(self,
logging.info("Saving polarhole {}".format(polarhole_path))
np.save(polarhole_path, polarhole)

def get_region_data(self, data):
"""
Get either a lat/lon region or a pixel bounded region via slicing.
If setting region via lat/lon, coordinates must be passed by calling
`self.set_region_by_lonlat` method first.
"""
if self._region_geo_mask is not None:
if self.longitudes is None or self.latitudes is None:
raise ValueError(f"Call {self.__name__}.set_region_by_lonlat first," +
"to pass in latitude and longitude coordinates.")
array = xr.DataArray(
data,
dims=('yc', 'xc'),
coords={
'yc': self.yc,
'xc': self.xc,
})
array["lon"] = (("yc", "xc"), self.longitudes.data)
array["lat"] = (("yc", "xc"), self.latitudes.data)
array = array.where(self._region_geo_mask.compute(), drop=True).values

# When used as weights for xarray.DataArray.weighted(), it shouldn't have
# nan's in grid (i.e., outside of lat/lon bounds), so set these areas to 0.
array = np.nan_to_num(array)

return array
else:
return data[self._region]

def get_active_cell_mask(self, month: object) -> object:
"""Loads an active grid cell mask from numpy file.
Expand All @@ -221,9 +256,11 @@ def get_active_cell_mask(self, month: object) -> object:
raise RuntimeError("Active cell masks have not been generated, "
"this is not done automatically so you might "
"want to address this!")

# logging.debug("Loading active cell mask {}".format(mask_path))
return np.load(mask_path)[self._region]
data = np.load(mask_path)

return self.get_region_data(data)


def get_active_cell_da(self, src_da: object) -> object:
"""Generate an xarray.DataArray object containing the active cell masks
Expand All @@ -237,18 +274,22 @@ def get_active_cell_da(self, src_da: object) -> object:
An xarray.DataArray containing active cell masks for each time
in source DataArray.
"""
return xr.DataArray(
[
active_cell_mask = [
self.get_active_cell_mask(pd.to_datetime(date).month)
for date in src_da.time.values
],
]

active_cell_mask_da = xr.DataArray(
active_cell_mask,
dims=('time', 'yc', 'xc'),
coords={
'time': src_da.time.values,
'yc': src_da.yc.values,
'xc': src_da.xc.values,
})

return active_cell_mask_da

def get_land_mask(self,
land_mask_filename: str = LAND_MASK_FILENAME) -> object:
"""Generate an xarray.DataArray object containing the active cell masks
Expand All @@ -271,7 +312,8 @@ def get_land_mask(self,
"address this!")

# logging.debug("Loading land mask {}".format(mask_path))
return np.load(mask_path)[self._region]
data = np.load(mask_path)
return self.get_region_data(data)

def get_polarhole_mask(self, date: object) -> object:
"""Get mask of polar hole region.
Expand All @@ -289,7 +331,8 @@ def get_polarhole_mask(self, date: object) -> object:
self.get_data_var_folder("masks"),
"polarhole{}_mask.npy".format(i + 1))
# logging.debug("Loading polarhole {}".format(polarhole_path))
return np.load(polarhole_path)[self._region]
data = np.load(polarhole_path)
return self.get_region_data(data)
return None

def get_blank_mask(self) -> object:
Expand All @@ -300,7 +343,49 @@ def get_blank_mask(self) -> object:
of shape `self._shape` (the `data_shape` instance initialisation
value).
"""
return np.full(self._shape, False)[self._region]
data = np.full(self._shape, False)
return self.get_region_data(data)

def set_region_by_lonlat(self, xc, yc, lon, lat, region):
"""
Sets the region based on longitude and latitude bounds by converting
them into index slices based on the provided xarray DataArray.
Alternative to __getitem__ if not using slicing.
Args:
xc:
yc:
lon:
lat:
region: lat/lon region bounds to get masks for, [lon_min, lat_min, lon_max, lat_max]
src_da: An xarray.DataArray that contains longitude and latitude coordinates.
"""
self.xc = xc
self.yc = yc
self.longitudes = lon
self.latitudes = lat
self.region_geographic = region

lon_min, lat_min, lon_max, lat_max = region

lon_mask = (lon >= lon_min) & (lon <= lon_max)
lat_mask = (lat >= lat_min) & (lat <= lat_max)

lat_lon_mask = lat_mask & lon_mask

rows, cols = np.where(lat_lon_mask)

row_min, row_max = rows.min(), rows.max()
col_min, col_max = cols.min(), cols.max()

# Specify min/max latlon bounds via slicing
# (sideffect of slicing rectangular area aligned with `xc, yc`,
# instead of actual lon/lat)
# self._region = (slice(row_min, row_max+1), slice(col_min, col_max+1))

# Instead, when this is set, masks based on actual lon/lat bounds.
self._region_geo_mask = lat_lon_mask

def __getitem__(self, item):
"""Sets slice of region wanted for masking, and allows method chaining.
Expand All @@ -318,6 +403,7 @@ def reset_region(self):
"""Resets the mask region and logs a message indicating that the whole mask will be returned."""
logging.info("Mask region reset, whole mask will be returned")
self._region = (slice(None, None), slice(None, None))
self._region_geo_mask = None


def main():
Expand Down
42 changes: 30 additions & 12 deletions icenet/plotting/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,8 @@ def compute_metrics_leadtime_avg(metric: str,
:return: pandas dataframe with columns 'date', 'leadtime' and the metric name.
"""
pole = 1 if hemisphere == "north" else -1

# open forecast file
fc_ds = xr.open_dataset(forecast_file)

Expand Down Expand Up @@ -600,14 +602,15 @@ def compute_metrics_leadtime_avg(metric: str,
if region is not None:
seas, fc, obs, masks = process_subregion(region,
[seas, fc, obs, masks],
region_definition="pixel"
pole,
region_definition="pixel",
)
elif region_geographic is not None:
raise NotImplementedError("Computing this metric with lon/lat region "
"bounds has not been implemented yet.")
seas, fc, obs, masks = process_subregion(region_geographic,
[seas, fc, obs, masks],
region_definition="geographic"
pole,
src_da=obs,
region_definition="geographic",
)

# compute metrics
Expand Down Expand Up @@ -923,6 +926,7 @@ def plot_metrics_leadtime_avg(metric: str,
data_path=data_path,
bias_correct=bias_correct,
region=region,
region_geographic=region_geographic,
**kwargs)

fc_metric_df = metric_df[metric_df["forecast_name"] == "IceNet"]
Expand Down Expand Up @@ -1486,6 +1490,8 @@ def binary_accuracy():
ap = (ForecastPlotArgParser().allow_ecmwf().allow_threshold())
args = ap.parse_args()

pole = 1 if args.hemisphere == "north" else -1

masks = Masks(north=args.hemisphere == "north",
south=args.hemisphere == "south")

Expand Down Expand Up @@ -1513,13 +1519,14 @@ def binary_accuracy():
if args.region is not None:
seas, fc, obs, masks = process_subregion(args.region,
[seas, fc, obs, masks],
pole,
region_definition="pixel"
)
elif args.region_geographic is not None:
raise NotImplementedError("Computing this metric with lon/lat region "
"bounds has not been implemented yet.")
seas, fc, obs, masks = process_subregion(args.region_geographic,
[seas, fc, obs, masks],
pole,
src_da=obs,
region_definition="geographic"
)

Expand All @@ -1539,6 +1546,8 @@ def sie_error():
ap = (ForecastPlotArgParser().allow_ecmwf().allow_threshold().allow_sie())
args = ap.parse_args()

pole = 1 if args.hemisphere == "north" else -1

masks = Masks(north=args.hemisphere == "north",
south=args.hemisphere == "south")

Expand Down Expand Up @@ -1566,13 +1575,14 @@ def sie_error():
if args.region is not None:
seas, fc, obs, masks = process_subregion(args.region,
[seas, fc, obs, masks],
pole,
region_definition="pixel"
)
elif args.region_geographic is not None:
raise NotImplementedError("Computing this metric with lon/lat region "
"bounds has not been implemented yet.")
seas, fc, obs, masks = process_subregion(args.region_geographic,
[seas, fc, obs, masks],
pole,
src_da=obs,
region_definition="geographic"
)

Expand Down Expand Up @@ -1706,6 +1716,7 @@ def plot_forecast():
# Clip the actual data to the requested region if necessary
fc = process_subregion(region_args,
[fc],
pole,
region_definition=region_definition,
)[0]

Expand Down Expand Up @@ -1880,6 +1891,8 @@ def metric_plots():
ap = (ForecastPlotArgParser().allow_ecmwf().allow_metrics())
args = ap.parse_args()

pole = 1 if args.hemisphere == "north" else -1

masks = Masks(north=args.hemisphere == "north",
south=args.hemisphere == "south")

Expand Down Expand Up @@ -1909,13 +1922,14 @@ def metric_plots():
if args.region is not None:
seas, fc, obs, masks = process_subregion(args.region,
[seas, fc, obs, masks],
pole,
region_definition="pixel"
)
elif args.region_geographic is not None:
raise NotImplementedError("Computing this metric with lon/lat region "
"bounds has not been implemented yet.")
seas, fc, obs, masks = process_subregion(args.region_geographic,
[seas, fc, obs, masks],
pole,
src_da=obs,
region_definition="geographic"
)

Expand Down Expand Up @@ -1981,6 +1995,7 @@ def leadtime_avg_plots():
target_date_avg=args.target_date_average,
bias_correct=args.bias_correct,
region=args.region,
region_geographic=args.region_geographic,
threshold=args.threshold,
grid_area_size=args.grid_area)

Expand All @@ -1992,6 +2007,8 @@ def sic_error():
ap = ForecastPlotArgParser()
args = ap.parse_args()

pole = 1 if args.hemisphere == "north" else -1

masks = Masks(north=args.hemisphere == "north",
south=args.hemisphere == "south")

Expand All @@ -2006,13 +2023,14 @@ def sic_error():
if args.region is not None:
fc, obs, masks = process_subregion(args.region,
[fc, obs, masks],
pole,
region_definition="pixel"
)
elif args.region_geographic is not None:
raise NotImplementedError("Computing this metric with lon/lat region "
"bounds has not been implemented yet.")
fc, obs, masks = process_subregion(args.region_geographic,
[fc, obs, masks],
pole,
src_da=obs,
region_definition="geographic"
)

Expand Down
Loading

0 comments on commit 9d9f615

Please sign in to comment.