diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..77b3943 --- /dev/null +++ b/environment.yml @@ -0,0 +1,12 @@ +name: icenet +channels: + - conda-forge + - nodefaults +dependencies: + - eccodes + - ffmpeg + - netcdf4<1.6.1 + - pip + - python=3.11 + - pip: + - -e . diff --git a/icenet/data/sic/mask.py b/icenet/data/sic/mask.py index f73e71e..2ab6285 100644 --- a/icenet/data/sic/mask.py +++ b/icenet/data/sic/mask.py @@ -37,6 +37,8 @@ def __init__(self, polarhole_radii: object = POLARHOLE_RADII, data_shape: object = (432, 432), dtype: object = np.float32, + longitudes = None, + latitudes = None, **kwargs): """Initialises Masks across specified hemispheres. @@ -52,7 +54,10 @@ def __init__(self, self._polarhole_radii = polarhole_radii self._dtype = dtype self._shape = data_shape + self.longitudes = longitudes + self.latitudes = latitudes self._region = (slice(None, None), slice(None, None)) + self._region_geo_mask = None self.init_params() @@ -199,6 +204,36 @@ def generate(self, logging.info("Saving polarhole {}".format(polarhole_path)) np.save(polarhole_path, polarhole) + def get_region_data(self, data): + """ + Get either a lat/lon region or a pixel bounded region via slicing. + + If setting region via lat/lon, coordinates must be passed by calling + `self.set_region_by_lonlat` method first. + """ + if self._region_geo_mask is not None: + if self.longitudes is None or self.latitudes is None: + raise ValueError(f"Call {self.__name__}.set_region_by_lonlat first," + + "to pass in latitude and longitude coordinates.") + array = xr.DataArray( + data, + dims=('yc', 'xc'), + coords={ + 'yc': self.yc, + 'xc': self.xc, + }) + array["lon"] = (("yc", "xc"), self.longitudes.data) + array["lat"] = (("yc", "xc"), self.latitudes.data) + array = array.where(self._region_geo_mask.compute(), drop=True).values + + # When used as weights for xarray.DataArray.weighted(), it shouldn't have + # nan's in grid (i.e., outside of lat/lon bounds), so set these areas to 0. + array = np.nan_to_num(array) + + return array + else: + return data[self._region] + def get_active_cell_mask(self, month: object) -> object: """Loads an active grid cell mask from numpy file. @@ -221,9 +256,11 @@ def get_active_cell_mask(self, month: object) -> object: raise RuntimeError("Active cell masks have not been generated, " "this is not done automatically so you might " "want to address this!") - # logging.debug("Loading active cell mask {}".format(mask_path)) - return np.load(mask_path)[self._region] + data = np.load(mask_path) + + return self.get_region_data(data) + def get_active_cell_da(self, src_da: object) -> object: """Generate an xarray.DataArray object containing the active cell masks @@ -237,11 +274,13 @@ def get_active_cell_da(self, src_da: object) -> object: An xarray.DataArray containing active cell masks for each time in source DataArray. """ - return xr.DataArray( - [ + active_cell_mask = [ self.get_active_cell_mask(pd.to_datetime(date).month) for date in src_da.time.values - ], + ] + + active_cell_mask_da = xr.DataArray( + active_cell_mask, dims=('time', 'yc', 'xc'), coords={ 'time': src_da.time.values, @@ -249,6 +288,8 @@ def get_active_cell_da(self, src_da: object) -> object: 'xc': src_da.xc.values, }) + return active_cell_mask_da + def get_land_mask(self, land_mask_filename: str = LAND_MASK_FILENAME) -> object: """Generate an xarray.DataArray object containing the active cell masks @@ -271,7 +312,8 @@ def get_land_mask(self, "address this!") # logging.debug("Loading land mask {}".format(mask_path)) - return np.load(mask_path)[self._region] + data = np.load(mask_path) + return self.get_region_data(data) def get_polarhole_mask(self, date: object) -> object: """Get mask of polar hole region. @@ -289,7 +331,8 @@ def get_polarhole_mask(self, date: object) -> object: self.get_data_var_folder("masks"), "polarhole{}_mask.npy".format(i + 1)) # logging.debug("Loading polarhole {}".format(polarhole_path)) - return np.load(polarhole_path)[self._region] + data = np.load(polarhole_path) + return self.get_region_data(data) return None def get_blank_mask(self) -> object: @@ -300,7 +343,49 @@ def get_blank_mask(self) -> object: of shape `self._shape` (the `data_shape` instance initialisation value). """ - return np.full(self._shape, False)[self._region] + data = np.full(self._shape, False) + return self.get_region_data(data) + + def set_region_by_lonlat(self, xc, yc, lon, lat, region): + """ + Sets the region based on longitude and latitude bounds by converting + them into index slices based on the provided xarray DataArray. + + Alternative to __getitem__ if not using slicing. + + Args: + xc: + yc: + lon: + lat: + region: lat/lon region bounds to get masks for, [lon_min, lat_min, lon_max, lat_max] + src_da: An xarray.DataArray that contains longitude and latitude coordinates. + """ + self.xc = xc + self.yc = yc + self.longitudes = lon + self.latitudes = lat + self.region_geographic = region + + lon_min, lat_min, lon_max, lat_max = region + + lon_mask = (lon >= lon_min) & (lon <= lon_max) + lat_mask = (lat >= lat_min) & (lat <= lat_max) + + lat_lon_mask = lat_mask & lon_mask + + rows, cols = np.where(lat_lon_mask) + + row_min, row_max = rows.min(), rows.max() + col_min, col_max = cols.min(), cols.max() + + # Specify min/max latlon bounds via slicing + # (sideffect of slicing rectangular area aligned with `xc, yc`, + # instead of actual lon/lat) + # self._region = (slice(row_min, row_max+1), slice(col_min, col_max+1)) + + # Instead, when this is set, masks based on actual lon/lat bounds. + self._region_geo_mask = lat_lon_mask def __getitem__(self, item): """Sets slice of region wanted for masking, and allows method chaining. @@ -318,6 +403,7 @@ def reset_region(self): """Resets the mask region and logs a message indicating that the whole mask will be returned.""" logging.info("Mask region reset, whole mask will be returned") self._region = (slice(None, None), slice(None, None)) + self._region_geo_mask = None def main(): diff --git a/icenet/plotting/forecast.py b/icenet/plotting/forecast.py index 16b273f..c7f1fea 100644 --- a/icenet/plotting/forecast.py +++ b/icenet/plotting/forecast.py @@ -6,12 +6,16 @@ from datetime import timedelta +import cartopy.crs as ccrs +import cartopy.feature as cfeature import matplotlib as mpl import matplotlib.cm as cm import matplotlib.pyplot as plt import matplotlib.dates as mdates +import matplotlib.ticker as mticker from matplotlib.animation import FuncAnimation from matplotlib.backends.backend_pdf import PdfPages +from mpl_toolkits.axes_grid1 import make_axes_locatable import seaborn as sns @@ -23,11 +27,14 @@ from icenet import __version__ as icenet_version from icenet.data.cli import date_arg from icenet.data.sic.mask import Masks -from icenet.plotting.utils import (filter_ds_by_obs, get_forecast_ds, +from icenet.plotting.utils import (calculate_extents, filter_ds_by_obs, + get_forecast_ds, get_obs_da, get_seas_forecast_da, - get_seas_forecast_init_dates, show_img, - get_plot_axes, process_probes, - process_regions) + get_seas_forecast_init_dates, + geographic_box, show_img, + get_plot_axes, set_plot_geoaxes, process_probes, + process_region, get_custom_cmap, + get_crs) from icenet.plotting.video import xarray_to_video @@ -521,9 +528,10 @@ def compute_metrics_leadtime_avg(metric: str, data_path: str, bias_correct: bool = False, region: tuple = None, + region_geographic: tuple = None, **kwargs) -> object: """ - Given forecast file, for each initialisation date in the xarrray.DataArray + Given forecast file, for each initialisation date in the xarray.DataArray we compute the metric for each leadtime and store the results in a pandas dataframe with columns 'date' (specifying the initialisation date), 'leadtime' and the metric name. This pandas dataframe can then be used @@ -542,13 +550,17 @@ def compute_metrics_leadtime_avg(metric: str, :param bias_correct: bool to indicate whether or not to perform a bias correction on SEAS forecast, by default False. Ignored if ecmwf=False - :param region: region to zoom in to + :param region: region to zoom in to defined by pixel coordinates + :param region_geographic: region to zoom in to defined by geographic, i.e., + (lon_min, lat_min, lon_max, lat_max) coordinates :param kwargs: any keyword arguments that are required for the computation of the metric, e.g. 'threshold' for SIE error and binary accuracy metrics, or 'grid_area_size' for SIE error metric :return: pandas dataframe with columns 'date', 'leadtime' and the metric name. """ + pole = 1 if hemisphere == "north" else -1 + # open forecast file fc_ds = xr.open_dataset(forecast_file) @@ -588,8 +600,18 @@ def compute_metrics_leadtime_avg(metric: str, seas = None if region is not None: - seas, fc, obs, masks = process_regions(region, - [seas, fc, obs, masks]) + seas, fc, obs, masks = process_region(region, + [seas, fc, obs, masks], + pole, + region_definition="pixel", + ) + elif region_geographic is not None: + seas, fc, obs, masks = process_region(region_geographic, + [seas, fc, obs, masks], + pole, + src_da=obs, + region_definition="geographic", + ) # compute metrics fc_metrics_list.append( @@ -817,6 +839,7 @@ def plot_metrics_leadtime_avg(metric: str, target_date_avg: bool = False, bias_correct: bool = False, region: tuple = None, + region_geographic: tuple = None, **kwargs) -> object: """ Plots leadtime averaged metrics either using all the forecasts @@ -847,7 +870,9 @@ def plot_metrics_leadtime_avg(metric: str, :param bias_correct: bool to indicate whether or not to perform a bias correction on SEAS forecast, by default False. Ignored if ecmwf=False - :param region: region to zoom in to + :param region: region to zoom in to defined by pixel coordinates + :param region_geographic: region to zoom in to defined by geographic, i.e., + (lon_min, lat_min, lon_max, lat_max) coordinates :param kwargs: any keyword arguments that are required for the computation of the metric, e.g. 'threshold' for SIE error and binary accuracy metrics, or 'grid_area_size' for SIE error metric @@ -901,6 +926,7 @@ def plot_metrics_leadtime_avg(metric: str, data_path=data_path, bias_correct=bias_correct, region=region, + region_geographic=region_geographic, **kwargs) fc_metric_df = metric_df[metric_df["forecast_name"] == "IceNet"] @@ -1384,6 +1410,11 @@ def __init__(self, *args, forecast_date: bool = True, **kwargs): default=None, type=region_arg, help="Region specified x1, y1, x2, y2") + self.add_argument("-z", + "--region-geographic", + default=None, + type=region_arg, + help="Geographic region specified as lon and lat min/max: lon_min, lat_min, lon_max, lat_max") def allow_ecmwf(self): self.add_argument("-b", @@ -1459,6 +1490,8 @@ def binary_accuracy(): ap = (ForecastPlotArgParser().allow_ecmwf().allow_threshold()) args = ap.parse_args() + pole = 1 if args.hemisphere == "north" else -1 + masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") @@ -1483,9 +1516,20 @@ def binary_accuracy(): else: seas = None - if args.region: - seas, fc, obs, masks = process_regions(args.region, - [seas, fc, obs, masks]) + if args.region is not None: + seas, fc, obs, masks = process_region(args.region, + [seas, fc, obs, masks], + pole, + region_definition="pixel" + ) + elif args.region_geographic is not None: + seas, fc, obs, masks = process_region(args.region_geographic, + [seas, fc, obs, masks], + pole, + src_da=obs, + region_definition="geographic" + ) + plot_binary_accuracy(masks=masks, fc_da=fc, @@ -1502,6 +1546,8 @@ def sie_error(): ap = (ForecastPlotArgParser().allow_ecmwf().allow_threshold().allow_sie()) args = ap.parse_args() + pole = 1 if args.hemisphere == "north" else -1 + masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") @@ -1526,9 +1572,19 @@ def sie_error(): else: seas = None - if args.region: - seas, fc, obs, masks = process_regions(args.region, - [seas, fc, obs, masks]) + if args.region is not None: + seas, fc, obs, masks = process_region(args.region, + [seas, fc, obs, masks], + pole, + region_definition="pixel" + ) + elif args.region_geographic is not None: + seas, fc, obs, masks = process_region(args.region_geographic, + [seas, fc, obs, masks], + pole, + src_da=obs, + region_definition="geographic" + ) plot_sea_ice_extent_error(masks=masks, fc_da=fc, @@ -1565,6 +1621,11 @@ def plot_forecast(): help="Format to output in", choices=("mp4", "png", "svg", "tiff"), default="png") + ap.add_argument("-g", + "--gridlines", + help="Turn on gridlines for plots", + action="store_true", + default=False) ap.add_argument("-n", "--cmap-name", help="Color map name if not wanting to use default", @@ -1574,6 +1635,17 @@ def plot_forecast(): help="Plot the standard deviation from the ensemble", action="store_true", default=False) + ap.add_argument("--crs", + default=None, + help="Coordinate Reference System to use for plotting") + ap.add_argument("--clip-region", + action="store_true", + default=False, + help="Whether to clip the data to the region specified by lon/lat,"\ + " When enabled, this crops forecast plot to the bounds, can cause"\ + " empty pixels across image edges due to lon/lat curvature."\ + " Default is False." + ) args = ap.parse_args() fc = get_forecast_ds(args.forecast_file, @@ -1594,7 +1666,13 @@ def plot_forecast(): os.path.splitext(os.path.basename(args.forecast_file))[0], args.forecast_date) - cmap_name = "BuPu_r" if args.stddev else "Blues_r" + if args.stddev: + cmap_name = "BuPu_r" + colorbar_label = "Sea-ice concentration fraction (standard deviation)" + else: + cmap_name = "Blues_r" + colorbar_label = "Sea-ice concentration fraction" + if args.cmap_name is not None: cmap_name = args.cmap_name @@ -1602,8 +1680,49 @@ def plot_forecast(): cmap = plt.get_cmap(cmap_name) cmap.set_bad("dimgrey") + pole = 1 if args.hemisphere == "north" else -1 + + # Define Coordinate Reference System (CRS) for plotting if reprojecting + reproject = True if args.crs else False + # `xc` and `yc` CRS + data_crs_proj = ccrs.LambertAzimuthalEqualArea(central_latitude=pole*90, central_longitude=0) + # `lon` and `lat` CRS + data_crs_geo = ccrs.PlateCarree() + + # Whether to reproject, and if so, what CRS to reproject to + if args.crs: + target_crs = get_crs(args.crs) + else: + target_crs = data_crs_proj + + # Whether subregion is defined via bounds using pixel coords or lon/lat + region_args = None + region_definition = "pixel" + bound_args = dict(north=args.hemisphere == "north", + south=args.hemisphere == "south") if args.region is not None: - fc = process_regions(args.region, [fc])[0] + region_args = args.region + bound_args.update(x1=args.region[0], + x2=args.region[2], + y1=args.region[1], + y2=args.region[3]) + elif args.region_geographic is not None: + region_args = args.region_geographic + region_definition = "geographic" + bound_args.update(x1=args.region_geographic[0], + x2=args.region_geographic[2], + y1=args.region_geographic[1], + y2=args.region_geographic[3]) + + # Clip the actual data to the requested region if necessary + if not args.clip_region: + region_args = None + + fc = process_region(region_args, + [fc], + pole, + region_definition=region_definition, + )[0] vmax = 1. @@ -1615,6 +1734,15 @@ def plot_forecast(): if args.leadtimes is not None \ else list(range(1, int(max(fc.leadtime.values)) + 1)) + if args.region is not None or args.region_geographic is not None: + extent = (bound_args["x1"], bound_args["x2"], bound_args["y1"], bound_args["y2"]) + else: + extent = None + + coastlines = not args.no_coastlines + + custom_cmap = get_custom_cmap(cmap) + if args.format == "mp4": pred_da = fc.isel(time=0).sel(leadtime=leadtimes) @@ -1629,10 +1757,7 @@ def plot_forecast(): pred_da = pred_da.drop("time").drop("leadtime").\ rename(leadtime="time", forecast_date="time").set_index(time="time") - anim_args = dict(figsize=5) - if not args.no_coastlines: - logging.warning("Coastlines will not work with the current " - "implementation of xarray_to_video") + anim_args = dict(figsize=(10, 8)) output_filename = os.path.join( output_path, @@ -1640,41 +1765,100 @@ def plot_forecast(): args.forecast_date.strftime("%Y%m%d"), "" if not args.stddev else "stddev.", args.format)) + xarray_to_video(pred_da, fps=1, cmap=cmap, - imshow_kwargs=dict(vmin=0., vmax=vmax) - if not args.stddev else None, + imshow_kwargs={}, video_path=output_filename, + reproject=reproject, + extent=extent, + region_definition=region_definition, + coastlines=coastlines, + gridlines=args.gridlines, + target_crs=target_crs, + transform_crs=data_crs_geo, + north=bound_args["north"], + south=bound_args["south"], + clim=(0, vmax), + colorbar_label=colorbar_label, **anim_args) else: - for leadtime in leadtimes: + fig, ax = get_plot_axes(**bound_args, + geoaxes=True, + target_crs=target_crs, + ) + ax = set_plot_geoaxes(ax, + extent=extent, + region_definition=region_definition, + coastlines=coastlines, + gridlines=args.gridlines, + north=bound_args["north"], + south=bound_args["south"], + ) + # Convert from km to m + fc = fc.assign_coords(xc=fc.xc.data * 1000, yc=fc.yc.data * 1000) + + cbar = None + for i, leadtime in enumerate(leadtimes): pred_da = fc.sel(leadtime=leadtime).isel(time=0) - bound_args = dict(north=args.hemisphere == "north", - south=args.hemisphere == "south") - - if args.region is not None: - bound_args.update(x1=args.region[0], - x2=args.region[2], - y1=args.region[1], - y2=args.region[3]) - - ax = get_plot_axes(**bound_args, - do_coastlines=not args.no_coastlines) - bound_args.update(cmap=cmap) + im = pred_da.plot.pcolormesh("xc", + "yc", + ax=ax, + transform=data_crs_proj, + vmin=0, + vmax=vmax, + add_colorbar=False, + cmap=custom_cmap, + # shading="gouraud", + # rasterized=True, + ) + + if args.region_geographic: + # Special case, when using geographic (lon/lat) region clipping + # Highlights sub-region being plotted in a reference plot of the globe + stored_extent = ax.get_extent() + + # Output a reference image showing bounds of clipped region + if i == 0: + box_lon, box_lat = geographic_box((bound_args["x1"], bound_args["x2"]), (bound_args["y1"], bound_args["y2"]), segments=10) + + region_plot = ax.plot(box_lon, box_lat, transform=data_crs_geo, color="red", zorder=999) + ax.set_global() + + output_filename = os.path.join( + output_path, "{}.{}_reference.{}{}".format( + forecast_name, + (args.forecast_date + dt.timedelta(days=leadtime)).strftime("%Y%m%d"), + "" if not args.stddev else "stddev.", "jpg")) + plt.savefig(output_filename, dpi=600) + + for handle in region_plot: + handle.remove() + + extent = [bound_args["x1"], bound_args["x2"], bound_args["y1"], bound_args["y2"]] + # With some projections like Mercator, it doesn't like having exact boundary longitude + if bound_args["x1"] == -180: + extent[0] = -179.99 + logging.debug("Forecast plot extent:", extent) + ax.set_extent(extent, crs=data_crs_geo) + + if not cbar: + divider = make_axes_locatable(ax) + # Pass axes_class to set correct colourbar height with cartopy + cax = divider.append_axes("right", size="5%", pad=0.05, axes_class=plt.Axes) + cbar = plt.colorbar(im, ax=ax, cax=cax) + if colorbar_label: + cbar.set_label(colorbar_label) - im = show_img(ax, - pred_da, - **bound_args, - vmax=vmax, - do_coastlines=not args.no_coastlines) - - plt.colorbar(im, ax=ax) plot_date = args.forecast_date + dt.timedelta(leadtime) ax.set_title("{:04d}/{:02d}/{:02d}".format(plot_date.year, plot_date.month, - plot_date.day)) + plot_date.day), + fontsize="large", + ) + plt.subplots_adjust(right=0.9) output_filename = os.path.join( output_path, "{}.{}.{}{}".format( forecast_name, @@ -1684,7 +1868,9 @@ def plot_forecast(): logging.info("Saving to {}".format(output_filename)) plt.savefig(output_filename) - plt.clf() + im.remove() + + plt.close() def parse_metrics_arg(argument: str) -> object: @@ -1707,6 +1893,8 @@ def metric_plots(): ap = (ForecastPlotArgParser().allow_ecmwf().allow_metrics()) args = ap.parse_args() + pole = 1 if args.hemisphere == "north" else -1 + masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") @@ -1733,9 +1921,19 @@ def metric_plots(): else: seas = None - if args.region: - seas, fc, obs, masks = process_regions(args.region, - [seas, fc, obs, masks]) + if args.region is not None: + seas, fc, obs, masks = process_region(args.region, + [seas, fc, obs, masks], + pole, + region_definition="pixel" + ) + elif args.region_geographic is not None: + seas, fc, obs, masks = process_region(args.region_geographic, + [seas, fc, obs, masks], + pole, + src_da=obs, + region_definition="geographic" + ) plot_metrics(metrics=metrics, masks=masks, @@ -1799,6 +1997,7 @@ def leadtime_avg_plots(): target_date_avg=args.target_date_average, bias_correct=args.bias_correct, region=args.region, + region_geographic=args.region_geographic, threshold=args.threshold, grid_area_size=args.grid_area) @@ -1810,6 +2009,8 @@ def sic_error(): ap = ForecastPlotArgParser() args = ap.parse_args() + pole = 1 if args.hemisphere == "north" else -1 + masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") @@ -1821,8 +2022,19 @@ def sic_error(): timedelta(days=int(fc.leadtime.max()))) fc = filter_ds_by_obs(fc, obs, args.forecast_date) - if args.region: - fc, obs, masks = process_regions(args.region, [fc, obs, masks]) + if args.region is not None: + fc, obs, masks = process_region(args.region, + [fc, obs, masks], + pole, + region_definition="pixel" + ) + elif args.region_geographic is not None: + fc, obs, masks = process_region(args.region_geographic, + [fc, obs, masks], + pole, + src_da=obs, + region_definition="geographic" + ) sic_error_video(fc_da=fc, obs_da=obs, diff --git a/icenet/plotting/utils.py b/icenet/plotting/utils.py index f08a2fc..7f463b5 100644 --- a/icenet/plotting/utils.py +++ b/icenet/plotting/utils.py @@ -5,13 +5,26 @@ import re import cartopy.crs as ccrs +import cartopy.feature as cfeature +import dask.array as da +import matplotlib as mpl +import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np import pandas as pd +import rioxarray import xarray as xr +from cartopy.feature import ShapelyFeature, NaturalEarthFeature +from cartopy.feature import AdaptiveScaler +from functools import cache from ibicus.debias import LinearScaling +from matplotlib.path import Path +from pyproj import CRS, Transformer +from rasterio.enums import Resampling +from shapely.geometry import Polygon +from icenet.data.sic.mask import Masks def broadcast_forecast(start_date: object, end_date: object, @@ -295,6 +308,34 @@ def get_obs_da( return obs_ds.ice_conc +def get_crs(crs_str: str): + """Get Coordinate Reference System (CRS) from string input argument + + Args: + crs_str: A CRS given as EPSG code (e.g. `EPSG:3347` for North Canada) + or, a pre-defined Cartopy CRS call (e.g. "PlateCarree") + """ + if crs_str.casefold().startswith("epsg"): + crs = ccrs.epsg(int(crs_str.split(":")[1])) + elif crs_str == "Mercator.GOOGLE": + crs = ccrs.Mercator.GOOGLE + else: + try: + crs = getattr(ccrs, crs_str)() + except AttributeError: + get_crs_options = [crs_option for crs_option in dir(ccrs) + if isinstance(getattr(ccrs, crs_option), type) + and issubclass(getattr(ccrs, crs_option), ccrs.CRS) + ] + ["Mercator.GOOGLE"] + get_crs_options.sort() + get_crs_options = ", ".join(get_crs_options) + raise AttributeError("Unsupported CRS defined, supported options are:",\ + f"{get_crs_options}" + ) + + return crs + + def calculate_extents(x1: int, x2: int, y1: int, y2: int): """ @@ -317,39 +358,178 @@ def calculate_extents(x1: int, x2: int, y1: int, y2: int): return extents +def pixel_to_projection(pixel_x_min, pixel_x_max, + pixel_y_min, pixel_y_max, + x_min_proj: float=-5387500, x_max_proj: float=5387500, + y_min_proj: float=-5387500, y_max_proj: float=5387500, + image_width: int=432, image_height: int=432, + ): + """Converts pixel coordinates to CRS projection coordinates""" + proj_x_min = (pixel_x_min / image_width ) * (x_max_proj - x_min_proj) + x_min_proj + proj_x_max = (pixel_x_max / image_width ) * (x_max_proj - x_min_proj) + x_min_proj + proj_y_min = (pixel_y_min / image_height) * (y_max_proj - y_min_proj) + y_min_proj + proj_y_max = (pixel_y_max / image_height) * (y_max_proj - y_min_proj) + y_min_proj + + return proj_x_min, proj_x_max, proj_y_min, proj_y_max + + +def get_bounds(proj=None, pole=1): + """Get min/max bounds for a given CRS projection""" + if proj is None or isinstance(proj, ccrs.LambertAzimuthalEqualArea): + proj = ccrs.LambertAzimuthalEqualArea(0, pole * 90) + x_min_proj, x_max_proj = [-5387500, 5387500] + y_min_proj, y_max_proj = [-5387500, 5387500] + else: + x_min_proj, x_max_proj = proj.x_limits + y_min_proj, y_max_proj = proj.y_limits + logging.debug(f"Projection bounds: {proj.x_limits}, {proj.y_limits}") + return proj, x_min_proj, x_max_proj, y_min_proj, y_max_proj + + def get_plot_axes(x1: int = 0, x2: int = 432, y1: int = 0, y2: int = 432, - do_coastlines: bool = True, north: bool = True, - south: bool = False): + south: bool = False, + geoaxes: bool = True, + target_crs: object = None, + figsize: int = (10, 8), + dpi: int = 150, + ): """ :param x1: :param x2: :param y1: :param y2: - :param do_coastlines: - :param north: - :param south: + :param geoaxes: :return: """ - assert north ^ south, "One hemisphere only must be selected" + assert north ^ south, "Only one hemisphere must be selected" - fig = plt.figure(figsize=(10, 8), dpi=150, layout='tight') + fig = plt.figure(figsize=figsize, dpi=dpi, layout="tight") - if do_coastlines: + if geoaxes: + # pole = 1 if north else -1 + # target_crs, x_min_proj, x_max_proj, y_min_proj, y_max_proj = get_bounds(target_crs, pole) pole = 1 if north else -1 - proj = ccrs.LambertAzimuthalEqualArea(0, pole * 90) + proj = ccrs.LambertAzimuthalEqualArea(central_longitude=0, central_latitude=pole*90) if target_crs is None else target_crs + ax = fig.add_subplot(1, 1, 1, projection=proj) - extents = calculate_extents(x1, x2, y1, y2) - ax.set_extent(extents, crs=proj) else: ax = fig.add_subplot(1, 1, 1) + return fig, ax + + +def set_plot_geoaxes(ax, + region_definition: str = None, + extent: list = None, + coastlines: str = None, + gridlines: bool = False, + north: bool = True, + south: bool = False, + ): + plt.tight_layout(pad=4.0) + + # Set colour for areas outside of `process_region()` - i.e., no data here. + ax.set_facecolor("dimgrey") + + pole = 1 if north else -1 + proj = ccrs.LambertAzimuthalEqualArea(0, pole * 90) + + if extent: + if region_definition == "pixel": + extents = calculate_extents(*extent) + ax.set_extent(extents, crs=proj) + elif region_definition == "geographic": + lon_min, lon_max, lat_min, lat_max = extent + # With some projections like Mercator, it doesn't like having exact boundary longitude + if lon_min == -180: + lon_min = -179.99 + elif lon_max == 180: + lon_max = 179.99 + ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree()) + clipping_polygon = Polygon(get_geoextent_polygon(extent)) + path = Path(np.array(clipping_polygon.exterior.coords)) + + if coastlines: + auto_scaler = AdaptiveScaler("110m", (("50m", 150), ("10m", 50))) + land = NaturalEarthFeature("physical", "land", scale="10m", facecolor="dimgrey") + if extent and region_definition == "geographic": + clipped_land = ShapelyFeature([clipping_polygon.intersection(geom) + for geom in land.geometries()], + ccrs.PlateCarree(), facecolor="dimgrey") + ax.add_feature(clipped_land) + # Draw coastlines explicitly within the clipping region + ax.add_geometries([clipping_polygon], ccrs.PlateCarree(), edgecolor="red", facecolor="none", linewidth=0.75, linestyle="dashed", zorder=100) + else: + ax.add_feature(land) + + # Add OSMnx GeoDataFrame of coastlines + #gdf = ox.features_from_place("Antarctica", tags={"natural": "coastline"}) + #gdf.plot(ax=ax, facecolor='none', edgecolor='black', linewidth=0.5) + ax.coastlines(resolution=auto_scaler) + + if gridlines: + gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True) + + # Prevent generating labels beneath the colourbar + gl.top_labels = False + gl.right_labels = False + return ax +def get_geoextent_polygon(extent, crs=ccrs.PlateCarree(), n_points=100): + """Create a high-resolution polygon for the boundary. + + Increase the number of points to approximate the curved edges + Define the number of interpolation points for the curves + """ + lon_min, lon_max, lat_min, lat_max = extent + + # Create arrays for the curved sections + lon_values_bottom = np.linspace(lon_min, lon_max, n_points) + lat_values_left = np.linspace(lat_min, lat_max, n_points) + + # Create a polygon by defining more points along the edges + polygon = [] + + # Bottom edge (lat_min) + for lon in lon_values_bottom: + polygon.append([lon, lat_min]) + + # Right edge (lon_max) + for lat in lat_values_left: + polygon.append([lon_max, lat]) + + # Top edge (lat_max) + for lon in lon_values_bottom[::-1]: + polygon.append([lon, lat_max]) + + # Left edge (lon_min) + for lat in lat_values_left[::-1]: + polygon.append([lon_min, lat]) + + return polygon + +def set_plot_geoextent(ax, extent, crs=ccrs.PlateCarree(), n_points=100): + """Create a high-resolution polygon for the boundary + """ + ax.set_extent(extent, crs=crs) + + # Create polygon and convert it to a matplotlib Path + polygon = Path(get_geoextent_polygon(extent), crs=crs, n_points=n_points) + + # Show polygon patch in plot + patch = patches.PathPatch(polygon, facecolor='orange', lw=2, transform=ccrs.PlateCarree()) + #ax.add_patch(patch) + + # Sets custom boundary, buggy with small lat/lon bounds + # Coastlines, land, and gridlines spill outside of boundary + ax.set_boundary(polygon, transform=ccrs.PlateCarree()) + def show_img(ax, arr, @@ -358,11 +538,14 @@ def show_img(ax, y1: int = 0, y2: int = 432, cmap: object = None, - do_coastlines: bool = True, + geoaxes: bool = True, vmin: float = 0., vmax: float = 1., north: bool = True, - south: bool = False): + south: bool = False, + crs: object = None, + extents: list = None + ): """ :param ax: @@ -372,7 +555,7 @@ def show_img(ax, :param y1: :param y2: :param cmap: - :param do_coastlines: + :param geoaxes: :param vmin: :param vmax: :param north: @@ -382,7 +565,7 @@ def show_img(ax, assert north ^ south, "One hemisphere only must be selected" - if do_coastlines: + if geoaxes: pole = 1 if north else -1 data_crs = ccrs.LambertAzimuthalEqualArea(0, pole * 90) extents = calculate_extents(x1, x2, y1, y2) @@ -422,20 +605,243 @@ def process_probes(probes, data) -> tuple: return data -def process_regions(region: tuple, data: tuple) -> tuple: +def reproject_array(array, target_crs): + return array.rio.reproject(target_crs.proj4_init, + # resampling=Resampling.bilinear, + nodata=np.nan + ) + +def process_block(block, target_crs): + # dataarray = xr.DataArray(block, dims=["leadtime", "y", "x"]) + dataarray = block + reprojected = reproject_array(dataarray, target_crs) + return reprojected.drop_vars(["time"]) + + +def reproject_projected_coords(data: object, + target_crs: object, + pole: int=1, + ) -> object: + """ + Reprojects an xarray Dataset from LambertAzimuthalEqualArea to `target_crs`. + + The Dataset is expected to have dims of (xc, yc). + + Args: + data: xarray dataset with dims (xc, yc), and also coords of lon and lat. + target_crs: Cartopy CRS to project to (e.g. `ccrs.Mercator()`) + pole: Whether north (`1`) or south pole (`-1`). + + Returns: + Reprojected data as an xarray dataset. + + Examples: + + >>> reprojected_data = reproject_projected_coords(arr, # doctest: +SKIP + >>> target_crs=target_crs, + >>> pole=pole, + >>> ) """ + # Eastings/Northings projection + data_crs_proj = ccrs.LambertAzimuthalEqualArea(0, pole*90) + # geographic projection + data_crs_geo = ccrs.PlateCarree() + + data_reproject = data.copy() + data_reproject = data_reproject.assign_coords({"xc": data_reproject.xc.data*1000, + "yc": data_reproject.yc.data*1000 + }) + + # Need to use correctly scaled xc and yc to get coastlines working even if not reprojecting. + # So, just return scaled DataArray back and not reproject if don't need to. + if target_crs == data_crs_proj: + return data_reproject + + data_reproject = data_reproject.drop_vars(["Lambert_Azimuthal_Grid", "lon", "lat"]) + + # Set xc, yc (eastings and northings) projection details + data_reproject = data_reproject.rename({"xc": "x", "yc": "y"}) + data_reproject.rio.write_crs(data_crs_proj.proj4_init, inplace=True) + data_reproject.rio.write_nodata(np.nan, inplace=True) + + times = len(data_reproject.time) + leadtimes = len(data_reproject.leadtime) + + # Create a sample image block for use as template for Dask + sample_block = data_reproject.isel(time=0, leadtime=0) + sample_reprojected = reproject_array(sample_block, target_crs) + + # Create a template DataArray based on the reprojected sample block + template_shape = (data_reproject.sizes['leadtime'], sample_reprojected.sizes['y'], sample_reprojected.sizes['x']) + template_data = da.zeros(template_shape, chunks=(1, -1, -1)) + template = xr.DataArray(template_data, dims=['leadtime', 'y', 'x'], + coords={'leadtime': data_reproject.coords['leadtime'], + 'y': sample_reprojected.coords['y'], + 'x': sample_reprojected.coords['x'], + } + ) + + reprojected_data = [] + for time in range(times): + leadtime_data = xr.map_blocks(process_block, data_reproject.isel(time=time), template=template, kwargs={"target_crs": target_crs}) + reprojected_data.append(leadtime_data) + + # TODO: Add projection info into DataArray, like the `Lambert_Azimuthal_Grid` dropped above + reprojected_data = xr.concat(reprojected_data, dim="time") + reprojected_data.coords["time"] = data_reproject.time.data + + # Set attributes + reprojected_data.rio.write_crs(target_crs.proj4_init, inplace=True) + reprojected_data.rio.write_nodata(np.nan, inplace=True) + + # Compute geographic for reprojected image + transformer = Transformer.from_crs(target_crs.proj4_init, data_crs_geo.proj4_init) + x = reprojected_data.x.values + y = reprojected_data.y.values + + X, Y = np.meshgrid(x, y) + lon_grid, lat_grid = transformer.transform(X, Y) + + reprojected_data["lon"] = (("y", "x"), lon_grid) + reprojected_data["lat"] = (("y", "x"), lat_grid) + + # Rename back to 'xc' and 'yc', although, these are now in metres rather than 1000 metres + reprojected_data = reprojected_data.rename({"x": "xc", "y": "yc"}) + + return reprojected_data + + +def projection_to_geographic_coords(data, target_crs): + # Compute geographic for reprojected image + transform_crs=ccrs.PlateCarree() + transformer = Transformer.from_crs(target_crs.proj4_init, transform_crs.proj4_init) + x = data.xc.values*1000 + y = data.yc.values*1000 + + X, Y = np.meshgrid(x, y) + lon_grid, lat_grid = transformer.transform(X, Y) + + data["lon"] = (("yc", "xc"), lon_grid) + data["lat"] = (("yc", "xc"), lat_grid) + + return data + + +def process_region(region: tuple=None, + data: tuple=None, + pole: int=1, + src_da: object=None, + region_definition: str = "pixel", + ) -> tuple: + """Extract subset of pan-Arctic/Antarctic region based on region bounds. - :param region: - :param data: + :param region: Either image pixel bounds, or geographic bounds. + :param data: Contains list of xarray DataArrays. + :param region_definition: Whether providing pixel coordinates or geographic (i.e. lon/lat). :return: """ - assert len(region) == 4, "Region needs to be a list of four integers" - x1, y1, x2, y2 = region - assert x2 > x1 and y2 > y1, "Region is not valid" + if region is not None: + assert len(region) == 4, "Region needs to be a list of four integers" + x1, y1, x2, y2 = region + assert x2 > x1 and y2 > y1, "Region is not valid" + if region_definition == "geographic": + assert x1 >= -180 and x2 <= 180, "Expect longitude range to be `-180<=longitude>=180`" for idx, arr in enumerate(data): - if arr is not None: - data[idx] = arr[..., (432 - y2):(432 - y1), x1:x2] + if arr is not None and region is not None: + logging.debug(f"Clipping data to specified bounds: {region}") + # Case when not an array, but an IceNet Masks class + if isinstance(arr, Masks): + if region_definition.casefold() == "geographic": + masks = arr + xc, yc = src_da.xc, src_da.yc + lon, lat = src_da.lon, src_da.lat + # Edge cases, where the time dimension is passed in, + # seems to be with "./data/osisaf/north/siconca/2020.nc" + # and, possibly newer. + if "time" in lon.dims: + lon = lon.isel(time=0) + if "time" in lat.dims: + lat = lat.isel(time=0) + masks.set_region_by_lonlat(xc, yc, lon,lat, region) + data[idx] = masks + elif region_definition.casefold() == "pixel": + data[idx] = arr[..., (432 - y2):(432 - y1), x1:x2] + else: + # If array only contains "xc" and "yc", but not "lon" and "lat". + # Reproject using pyproj to get it. + if "lon" not in arr.coords and "lat" not in arr.coords: + target_crs = ccrs.LambertAzimuthalEqualArea(0, pole*90) + arr = projection_to_geographic_coords(arr, target_crs) + + lon, lat = arr.lon, arr.lat + + if region_definition.casefold() == "geographic": + # Limit to lon/lat region, within a given tolerance + tolerance = 0 + # Create mask where data is within geographic (lon/lat) region + mask = (lon >= x1-tolerance) & (lon <= x2+tolerance) & \ + (lat >= y1-tolerance) & (lat <= y2+tolerance) + + # Extract subset within region using where() + data[idx] = arr.where(mask.compute(), drop=True) + elif region_definition.casefold() == "pixel": + x_max, y_max = arr.xc.shape[0], arr.yc.shape[0] + + # Clip the data array to specified pixel region + data[idx] = arr[..., (y_max - y2):(y_max - y1), x1:x2] + else: + raise NotImplementedError("Only region_definition='pixel' or 'geographic' bounds are supported") + return data + + +@cache +def geographic_box(lon_bounds: np.array, lat_bounds: np.array, segments: int=1): + """Rectangular boundary coordinates in lon/lat coordinates. + + Args: + lon_bounds: (min, max) lon values + lat_bounds: (min, max) lat values + segments: Number of segments per edge + + Returns: + (lats, lons) for rectangular boundary region + """ + + segments += 1 + rectangular_sides = 4 + + lons = np.empty((segments*rectangular_sides)) + lats = np.empty((segments*rectangular_sides)) + + bounds = [ + [0, 0], + [-1, 0], + [-1, -1], + [0, -1], + ] + + for i, (lat_min, lat_max) in enumerate(bounds): + lats[i*segments:(i+1)*segments] = np.linspace(lat_bounds[lat_min], lat_bounds[lat_max], num=segments) + + bounds.reverse() + + for i, (lon_min, lon_max) in enumerate(bounds): + lons[i*segments:(i+1)*segments] = np.linspace(lon_bounds[lon_min], lon_bounds[lon_max], num=segments) + + return lons, lats + +def get_custom_cmap(cmap): + """Creates a new colormap, but with nan set to <0. + + Hack since cartopy needs transparency for nan regions to wraparound + correctly with pcolormesh. + """ + colors = cmap(np.linspace(0, 1, cmap.N)) + custom_cmap = mpl.colors.ListedColormap(colors) + custom_cmap.set_bad("dimgrey", alpha=0) + custom_cmap.set_under("dimgrey") + return custom_cmap diff --git a/icenet/plotting/video.py b/icenet/plotting/video.py index 7df86e7..342dbf9 100644 --- a/icenet/plotting/video.py +++ b/icenet/plotting/video.py @@ -6,6 +6,8 @@ from concurrent.futures import as_completed, ProcessPoolExecutor +import cartopy.crs as ccrs +import cartopy.feature as cfeature import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -16,7 +18,7 @@ from icenet.process.predict import get_refcube from icenet.utils import setup_logging - +from icenet.plotting.utils import get_plot_axes, set_plot_geoaxes, set_plot_geoextent, get_custom_cmap # TODO: This can be a plotting or analysis util function elsewhere def get_dataarray_from_files(files: object, numpy: bool = False) -> object: @@ -72,14 +74,23 @@ def xarray_to_video( da: object, fps: int, video_path: object = None, + reproject: bool = False, + north: bool = True, + south: bool = False, + extent: tuple = None, + region_definition: str = "pixel", + coastlines: str = "default", + gridlines: bool = False, + target_crs: object = None, + transform_crs: object = None, mask: object = None, mask_type: str = 'contour', clim: object = None, crop: object = None, data_type: str = 'abs', video_dates: object = None, - cmap: object = "viridis", - figsize: int = 12, + cmap: object = plt.get_cmap("viridis"), + figsize: tuple = (10, 8), dpi: int = 150, imshow_kwargs: dict = None, ax_init: object = None, @@ -112,10 +123,21 @@ def xarray_to_video( :param ax_init: pre-initialised axes object for display :param ax_extra: Extra method called with axes for additional plotting """ + assert north ^ south, "Only one hemisphere must be selected" + pole = 1 if north else -1 + + target_crs = ccrs.LambertAzimuthalEqualArea(central_latitude=pole*90, central_longitude=0) if target_crs is None else target_crs + transform_crs = ccrs.PlateCarree() if transform_crs is None else transform_crs + + # Hack since cartopy needs transparency for nan regions to wraparound + # correctly with pcolormesh, set nan areas as under range. + if reproject: + da = da.where(~np.isnan(da), -9999, drop=False) def update(date): logging.debug("Plotting {}".format(date.strftime("%D"))) - image.set_data(da.sel(time=date)) + data = da.sel(time=date) + image.set_array(data) image_title.set_text("{:04d}/{:02d}/{:02d}".format( date.year, date.month, date.day)) @@ -154,51 +176,99 @@ def update(date): logging.info("Initialising plot") if ax_init is None: - fig, ax = plt.subplots(figsize=(figsize, figsize)) - fig.set_dpi(dpi) + fig, ax = get_plot_axes( + geoaxes=True, + north=north, + south=south, + target_crs=target_crs, + figsize=figsize, + dpi=dpi, + ) + ax = set_plot_geoaxes(ax, + region_definition=region_definition, + extent=extent, + coastlines=coastlines, + gridlines=gridlines, + north=north, + south=south, + ) else: ax = ax_init fig = ax.get_figure() - if mask is not None: - if mask_type == 'contour': - ax.contour(mask, levels=[.5, 1], colors='k', zorder=3) - elif mask_type == 'contourf': - ax.contourf(mask, levels=[.5, 1], colors='k', zorder=3) - ax.axes.xaxis.set_visible(False) ax.axes.yaxis.set_visible(False) if ax_extra is not None: ax_extra(ax) + #if extent and region_definition == "geographic": + # # ax.set_extent(extent, crs=transform_crs) + # set_plot_geoextent(ax, extent) + date = pd.Timestamp(da.time.values[0]).to_pydatetime() - image = ax.imshow(da.sel(time=date), - cmap=cmap, - clim=(n_min, n_max), - animated=True, - zorder=1, - **imshow_kwargs if imshow_kwargs is not None else {}) + + data = da.sel(time=date) + + if mask is not None: + if mask_type == 'contour': + image = ax.contour(data.xc.data, data.yc.data, mask, + levels=[.5, 1], + colors='k', + transform=target_crs, + zorder=3, + ) + elif mask_type == 'contourf': + image = ax.contourf(data.xc.data, data.yc.data, mask, + levels=[.5, 1], + colors='k', + transform=target_crs, + zorder=3, + ) + + # TODO: Tidy up, and cover all argument options + # Hack since cartopy needs transparency for nan regions to wraparound + # correctly with pcolormesh. + custom_cmap = get_custom_cmap(cmap) + + image = data.plot.pcolormesh("lon", + "lat", + ax=ax, + transform=transform_crs, + animated=True, + zorder=1, + add_colorbar=False, + cmap=custom_cmap, + vmin=n_min, + vmax=n_max, + **imshow_kwargs if imshow_kwargs is not None else {} + ) image_title = ax.set_title("{:04d}/{:02d}/{:02d}".format( date.year, date.month, date.day), - fontsize="medium", + fontsize="large", zorder=2) try: divider = make_axes_locatable(ax) - cax = divider.append_axes('right', size='5%', pad=0.05, zorder=2) - cbar = plt.colorbar(image, cax) + cax = divider.append_axes("right", size="5%", pad=0.05, zorder=2, axes_class=plt.Axes) + cbar = plt.colorbar(image, ax=ax, cax=cax) if colorbar_label: cbar.set_label(colorbar_label) - fig.subplots_adjust(right=0.85) + plt.subplots_adjust(right=0.9) except KeyError as ex: logging.warning("Could not configure locatable colorbar: {}".format(ex)) logging.info("Animating") # Investigated blitting, but it causes a few problems with masks/titles. - animation = FuncAnimation(fig, update, video_dates, interval=1000 / fps) + animation = FuncAnimation(fig, + func=update, + frames=video_dates, + interval=1000 / fps, + repeat=False, + blit=True, + ) plt.close() diff --git a/icenet/process/predict.py b/icenet/process/predict.py index 985f0c3..43b4b12 100644 --- a/icenet/process/predict.py +++ b/icenet/process/predict.py @@ -60,13 +60,24 @@ def get_refcube(north: bool = True, south: bool = False) -> object: return cube -def get_prediction_data(root: object, name: object, date: object) -> tuple: +def get_prediction_data(root: str, name: str, date: dt, + return_ensemble_data: bool = False) -> tuple: """ - - :param root: - :param name: - :param date: - :return: + Get prediction data from ensemble numpy files for specified date. + + Args: + root: Root directory path to pipeline results. + name: Name of the prediction. + date: Forecast date to get prediction data for. + return_ensemble_data (optional): Whether to also return full ensemble data + array, or just the mean. Defaults to False. + + Returns: + tuple: + - If `return_ensemble_data` is True: + Returns (data_mean, full_data_ensemble, number_of_ensemble_members) + - If `return_ensemble_data` is False: + Returns (data_mean, number_of_ensemble_members) """ logging.info("Post-processing {}".format(date)) @@ -85,8 +96,13 @@ def get_prediction_data(root: object, name: object, date: object) -> tuple: logging.debug("Data read from disk: {} from: {}".format( data.shape, np_files)) - return np.stack([data.mean(axis=0), data.std(axis=0)], - axis=-1).squeeze(), ens_members + data_mean = np.stack([data.mean(axis=0), data.std(axis=0)], + axis=-1).squeeze() + + if return_ensemble_data: + return data_mean, data, ens_members + else: + return data_mean, ens_members def date_arg(string: str) -> object: