From ecc2ee0d003691008610660471c33cffc5ff1c11 Mon Sep 17 00:00:00 2001
From: James Byrne
Date: Tue, 19 Sep 2023 20:42:27 +0100
Subject: [PATCH 01/61] Fixes #185: using -po for parallel-opens, this will be better addressed via download-toolbox in 0.3.0

---
 icenet/data/cli.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/icenet/data/cli.py b/icenet/data/cli.py
index 279a77c2..64b2c09a 100644
--- a/icenet/data/cli.py
+++ b/icenet/data/cli.py
@@ -167,7 +167,7 @@ def process_args(dates: bool = True,
     ap.add_argument("-l", "--lag", type=int, default=2)
     ap.add_argument("-f", "--forecast", type=int, default=93)
 
-    ap.add_argument("-p", "--parallel-opens",
+    ap.add_argument("-po", "--parallel-opens",
                     default=False, action="store_true",
                     help="Allow xarray mfdataset to work with parallel opens")
 

From 63489c32ead0a65166c0d98f83c27226a6c317db Mon Sep 17 00:00:00 2001
From: James Byrne
Date: Tue, 19 Sep 2023 21:05:41 +0100
Subject: [PATCH 02/61] Closes #187: additional arguments for masking

---
 icenet/process/predict.py | 42 ++++++++++++++++++++++++---------------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/icenet/process/predict.py b/icenet/process/predict.py
index 5b92bc72..f98eef6d 100644
--- a/icenet/process/predict.py
+++ b/icenet/process/predict.py
@@ -115,6 +115,14 @@ def get_args():
     ap.add_argument("datefile", type=argparse.FileType("r"))
 
     ap.add_argument("-m", "--mask", default=False, action="store_true")
+
+    ap.add_argument("--nan", help="Apply nans, not zeroes, to land mask",
+                    default=False, action="store_true")
+    ap.add_argument("--no-acgm", help="No active grid cell masking",
+                    default=True, action="store_false", dest="acgm")
+    ap.add_argument("--no-land", help="No land while masking",
+                    default=True, action="store_false", dest="land")
+
     ap.add_argument("-o", "--output-dir", default=".")
     ap.add_argument("-r", "--root", type=str, default=".")
 
@@ -152,26 +160,28 @@ def create_cf_output():
 
     if args.mask:
        mask_gen = Masks(north=ds.north, south=ds.south)
-        logging.info("Land masking the forecast output")
-        land_mask = mask_gen.get_land_mask()
-        mask = land_mask[np.newaxis, ..., np.newaxis]
-        mask = np.repeat(mask, sic_mean.shape[-1], axis=-1)
-        mask = np.repeat(mask, sic_mean.shape[0], axis=0)
+        if args.land:
+            logging.info("Land masking the forecast output")
+            land_mask = mask_gen.get_land_mask()
+            mask = land_mask[np.newaxis, ..., np.newaxis]
+            mask = np.repeat(mask, sic_mean.shape[-1], axis=-1)
+            mask = np.repeat(mask, sic_mean.shape[0], axis=0)
 
-        sic_mean[mask] = 0
-        sic_stddev[mask] = 0
+            sic_mean[mask] = 0 if not args.nan else np.nan
+            sic_stddev[mask] = 0 if not args.nan else np.nan
 
-        logging.info("Applying active grid cell masks")
+        if args.agcm:
+            logging.info("Applying active grid cell masks")
 
-        for idx, forecast_date in enumerate(dates):
-            for lead_idx in np.arange(0, arr.shape[3], 1):
-                lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + 1)
-                logging.debug("Active grid cell mask start {} forecast date {}".
-                              format(forecast_date, lead_dt))
+            for idx, forecast_date in enumerate(dates):
+                for lead_idx in np.arange(0, arr.shape[3], 1):
+                    lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + 1)
+                    logging.debug("Active grid cell mask start {} forecast date {}".
+                                  format(forecast_date, lead_dt))
 
-                grid_cell_mask = mask_gen.get_active_cell_mask(lead_dt.month)
-                sic_mean[idx, ~grid_cell_mask, lead_idx] = 0
-                sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0
+                    grid_cell_mask = mask_gen.get_active_cell_mask(lead_dt.month)
+                    sic_mean[idx, ~grid_cell_mask, lead_idx] = 0
+                    sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0
 
     xarr = xr.Dataset(
         data_vars=dict(

From 17a8001c5bc83bb2b98167c7912239f931939631 Mon Sep 17 00:00:00 2001
From: James Byrne
Date: Tue, 19 Sep 2023 21:08:24 +0100
Subject: [PATCH 03/61] Alpha versioning

---
 icenet/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/icenet/__init__.py b/icenet/__init__.py
index 63f59f2c..01932438 100644
--- a/icenet/__init__.py
+++ b/icenet/__init__.py
@@ -4,4 +4,4 @@
 __copyright__ = "British Antarctic Survey"
 __email__ = "jambyr@bas.ac.uk"
 __license__ = "MIT"
-__version__ = "0.2.6"
+__version__ = "0.2.7a0"

From 0cf6305feb4f2a1fb61581dff7408b630d20b331 Mon Sep 17 00:00:00 2001
From: James Byrne
Date: Tue, 19 Sep 2023 21:13:55 +0100
Subject: [PATCH 04/61] Update #187: fixing incorrect acgm ref, should be agcm

---
 icenet/process/predict.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/icenet/process/predict.py b/icenet/process/predict.py
index f98eef6d..a854d1b9 100644
--- a/icenet/process/predict.py
+++ b/icenet/process/predict.py
@@ -118,8 +118,8 @@ def get_args():
 
     ap.add_argument("--nan", help="Apply nans, not zeroes, to land mask",
                     default=False, action="store_true")
-    ap.add_argument("--no-acgm", help="No active grid cell masking",
-                    default=True, action="store_false", dest="acgm")
+    ap.add_argument("--no-agcm", help="No active grid cell masking",
+                    default=True, action="store_false", dest="agcm")
     ap.add_argument("--no-land", help="No land while masking",
                     default=True, action="store_false", dest="land")
 

From 6ea2b7b4e8e9e10ac7c982c32fac9919d594692a Mon Sep 17 00:00:00 2001
From: James Byrne
Date: Tue, 19 Sep 2023 22:54:54 +0100
Subject: [PATCH 05/61] Closes #189: basic implementation of isoformat for time span attributes

---
 icenet/process/predict.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/icenet/process/predict.py b/icenet/process/predict.py
index a854d1b9..1ebce8e5 100644
--- a/icenet/process/predict.py
+++ b/icenet/process/predict.py
@@ -183,6 +183,12 @@ def create_cf_output():
                     sic_mean[idx, ~grid_cell_mask, lead_idx] = 0
                     sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0
 
+    lists_of_fcast_dates = [
+        [pd.Timestamp(date + dt.timedelta(days=int(lead_idx)))
+         for lead_idx in np.arange(1, arr.shape[3] + 1, 1)]
+        for date in dates
+    ]
+
     xarr = xr.Dataset(
         data_vars=dict(
             Lambert_Azimuthal_Grid=ref_sic.Lambert_Azimuthal_Grid,
             sic_mean=(["time", "yc", "xc", "leadtime"], sic_mean),
             sic_stddev=(["time", "yc", "xc", "leadtime"], sic_stddev),
         ),
         coords=dict(
             time=[pd.Timestamp(d) for d in dates],
             leadtime=np.arange(1, arr.shape[3] + 1, 1),
-            forecast_date=(("time", "leadtime"), [
-                [pd.Timestamp(date + dt.timedelta(days=int(lead_idx)))
-                 for lead_idx in np.arange(1, arr.shape[3] + 1, 1)]
-                for date in dates
-            ]),
+            forecast_date=(("time", "leadtime"), lists_of_fcast_dates),
             xc=ref_cube.coord("projection_x_coordinate").points,
             yc=ref_cube.coord("projection_y_coordinate").points,
             lat=(("yc", "xc"), ref_cube.coord("latitude").points),
@@ -257,8 +259,8 @@ def create_cf_output():
             """,
             # Use ISO 8601:2004 duration format, preferably the extended format
             # as recommended in the Attribute Content Guidance section.
- time_coverage_start="", - time_coverage_end="", + time_coverage_start=min(set([item for row in lists_of_fcast_dates for item in row])).isoformat(), + time_coverage_end=max(set([item for row in lists_of_fcast_dates for item in row])).isoformat(), time_coverage_duration="P1D", time_coverage_resolution="P1D", title="Sea Ice Concentration Prediction", From 8bb40eba4d68b6d57bcbb4d79a82d8813c98f0e1 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Tue, 19 Sep 2023 23:48:33 +0100 Subject: [PATCH 06/61] Additional logging --- icenet/process/predict.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/icenet/process/predict.py b/icenet/process/predict.py index 1ebce8e5..a7be20e2 100644 --- a/icenet/process/predict.py +++ b/icenet/process/predict.py @@ -167,8 +167,14 @@ def create_cf_output(): mask = np.repeat(mask, sic_mean.shape[-1], axis=-1) mask = np.repeat(mask, sic_mean.shape[0], axis=0) - sic_mean[mask] = 0 if not args.nan else np.nan - sic_stddev[mask] = 0 if not args.nan else np.nan + if not args.nan: + logging.info("Applying nans to land mask") + sic_mean[mask] = np.nan + sic_stddev[mask] = np.nan + else: + logging.info("Applying zeros to land mask") + sic_mean[mask] = 0 + sic_stddev[mask] = 0 if args.agcm: logging.info("Applying active grid cell masks") From 2b49b766112c3d39a93a93d471fd3a6f8b70fd03 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Tue, 19 Sep 2023 23:50:19 +0100 Subject: [PATCH 07/61] Update #187: nans stopped working --- icenet/process/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icenet/process/predict.py b/icenet/process/predict.py index a7be20e2..c9cc0709 100644 --- a/icenet/process/predict.py +++ b/icenet/process/predict.py @@ -167,7 +167,7 @@ def create_cf_output(): mask = np.repeat(mask, sic_mean.shape[-1], axis=-1) mask = np.repeat(mask, sic_mean.shape[0], axis=0) - if not args.nan: + if args.nan: logging.info("Applying nans to land mask") sic_mean[mask] = np.nan sic_stddev[mask] = np.nan From c573671b2ec8759d633b35d2987318e3552d50c1 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Thu, 21 Sep 2023 10:41:39 +0100 Subject: [PATCH 08/61] Fixes #185: missed po for download_args, only applied to process_args --- icenet/data/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icenet/data/cli.py b/icenet/data/cli.py index 64b2c09a..8784fe24 100644 --- a/icenet/data/cli.py +++ b/icenet/data/cli.py @@ -120,7 +120,7 @@ def download_args(choices: object = None, if workers: ap.add_argument("-w", "--workers", default=8, type=int) - ap.add_argument("-p", "--parallel-opens", + ap.add_argument("-po", "--parallel-opens", default=False, action="store_true", help="Allow xarray mfdataset to work with parallel opens") From 2606519b5922eb9338a7e18c40d4be29f67ead7e Mon Sep 17 00:00:00 2001 From: James Byrne Date: Fri, 22 Sep 2023 12:13:17 +0100 Subject: [PATCH 09/61] Fixes #187 and #191: better masking, ensembler numbers per prediction added in --- icenet/__init__.py | 2 +- icenet/process/predict.py | 46 +++++++++++++++++++++++---------------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/icenet/__init__.py b/icenet/__init__.py index 01932438..dbd51d00 100644 --- a/icenet/__init__.py +++ b/icenet/__init__.py @@ -4,4 +4,4 @@ __copyright__ = "British Antarctic Survey" __email__ = "jambyr@bas.ac.uk" __license__ = "MIT" -__version__ = "0.2.7a0" +__version__ = "0.2.7a1" diff --git a/icenet/process/predict.py b/icenet/process/predict.py index c9cc0709..62a74d95 100644 --- 
a/icenet/process/predict.py +++ b/icenet/process/predict.py @@ -61,7 +61,7 @@ def get_refcube(north: bool = True, south: bool = False) -> object: def get_prediction_data(root: object, name: object, - date: object) -> object: + date: object) -> tuple: """ :param root: @@ -85,12 +85,13 @@ def get_prediction_data(root: object, data = [np.load(f) for f in np_files] data = np.array(data) + ens_members = data.shape[0] logging.debug("Data read from disk: {} from: {}".format(data.shape, np_files)) return np.stack( [data.mean(axis=0), data.std(axis=0)], - axis=-1).squeeze() + axis=-1).squeeze(), ens_members def date_arg(string: str) -> object: @@ -148,9 +149,9 @@ def create_cf_output(): for s in args.datefile.read().split()] args.datefile.close() - arr = np.array( - [get_prediction_data(args.root, args.name, date) - for date in dates]) + arr, ens_members = zip(*[get_prediction_data(args.root, args.name, date) for date in dates]) + ens_members = list(ens_members) + arr = np.array(arr) logging.info("Dataset arr shape: {}".format(arr.shape)) @@ -160,6 +161,19 @@ def create_cf_output(): if args.mask: mask_gen = Masks(north=ds.north, south=ds.south) + if args.agcm: + logging.info("Applying active grid cell masks") + + for idx, forecast_date in enumerate(dates): + for lead_idx in np.arange(0, arr.shape[3], 1): + lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + 1) + logging.debug("Active grid cell mask start {} forecast date {}". + format(forecast_date, lead_dt)) + + grid_cell_mask = mask_gen.get_active_cell_mask(lead_dt.month) + sic_mean[idx, ~grid_cell_mask, lead_idx] = 0 + sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0 + if args.land: logging.info("Land masking the forecast output") land_mask = mask_gen.get_land_mask() @@ -176,19 +190,6 @@ def create_cf_output(): sic_mean[mask] = 0 sic_stddev[mask] = 0 - if args.agcm: - logging.info("Applying active grid cell masks") - - for idx, forecast_date in enumerate(dates): - for lead_idx in np.arange(0, arr.shape[3], 1): - lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + 1) - logging.debug("Active grid cell mask start {} forecast date {}". - format(forecast_date, lead_dt)) - - grid_cell_mask = mask_gen.get_active_cell_mask(lead_dt.month) - sic_mean[idx, ~grid_cell_mask, lead_idx] = 0 - sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0 - lists_of_fcast_dates = [ [pd.Timestamp(date + dt.timedelta(days=int(lead_idx))) for lead_idx in np.arange(1, arr.shape[3] + 1, 1)] @@ -200,6 +201,7 @@ def create_cf_output(): Lambert_Azimuthal_Grid=ref_sic.Lambert_Azimuthal_Grid, sic_mean=(["time", "yc", "xc", "leadtime"], sic_mean), sic_stddev=(["time", "yc", "xc", "leadtime"], sic_stddev), + ensemble_members=(["time"], ens_members), ), coords=dict( time=[pd.Timestamp(d) for d in dates], @@ -260,7 +262,7 @@ def create_cf_output(): standard_name_vocabulary="CF Standard Name Table v27", summary=""" This is an output of sea ice concentration predictions from the - IceNet UNet run in an ensemble, with postprocessing to determine + IceNet run in an ensemble, with postprocessing to determine the mean and standard deviation across the runs. 
""", # Use ISO 8601:2004 duration format, preferably the extended format @@ -332,6 +334,12 @@ def create_cf_output(): units="1", ) + xarr.ensemble_members.attrs = dict( + long_name="number of ensemble members used to create this prediction", + short_name="ensemble_members", + # units="1", + ) + # TODO: split into daily files output_path = os.path.join(args.output_dir, "{}.nc".format(args.name)) logging.info("Saving to {}".format(output_path)) From f36c7774c1f7697c94918a73b1b4f33465f19a23 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Tue, 7 Nov 2023 18:06:18 +0000 Subject: [PATCH 10/61] Dev #193: Initial pre-commit set-up --- .pre-commit-config.yaml | 39 +++++++++++++++++++++++++++++++++++++++ requirements_dev.txt | 2 ++ 2 files changed, 41 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..91a4cd97 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +repos: + # General pre-commit hooks + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + language_version: python3 + + # black - Formatting + - repo: https://github.com/psf/black + rev: 23.10.1 + hooks: + - id: black + args: [] + + # isort - Sorting imports + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + + # ruff - Linting + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.4 + hooks: + - id: ruff + args: ["--exclude", "setup.py"] + + - repo: local + hooks: + # Run pytest + - id: pytest + name: Run pytest + entry: pytest + language: system + pass_filenames: false diff --git a/requirements_dev.txt b/requirements_dev.txt index 39c41c3f..c256644d 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -7,3 +7,5 @@ pytest black build importlib_metadata +ruff +pre-commit From ce27d59ec763b5571d449b8522301132426f4131 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Tue, 7 Nov 2023 18:43:14 +0000 Subject: [PATCH 11/61] Dev 198: Ignore E712 lint for lines --- icenet/data/processors/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icenet/data/processors/utils.py b/icenet/data/processors/utils.py index fa973364..9ab4dd0c 100644 --- a/icenet/data/processors/utils.py +++ b/icenet/data/processors/utils.py @@ -71,12 +71,12 @@ def sic_interpolate(da: object, nan_neighbour_arr[-1, :] = False if np.sum(nan_neighbour_arr) == 1: - res = np.where(np.array(nan_neighbour_arr) == True) + res = np.where(np.array(nan_neighbour_arr) == True) # noqa: E712 logging.warning("Not enough nans for interpolation, extending {}".format(res)) x_idx, y_idx = res[0][0], res[1][0] nan_neighbour_arr[x_idx-1:x_idx+2, y_idx] = True nan_neighbour_arr[x_idx, y_idx-1:y_idx+2] = True - logging.debug(np.where(np.array(nan_neighbour_arr) == True)) + logging.debug(np.where(np.array(nan_neighbour_arr) == True)) # noqa: E712 # Perform bilinear interpolation x_valid = xx[nan_neighbour_arr] From 9c10041b5c4705320385e207e911c741465620ff Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Tue, 7 Nov 2023 18:49:47 +0000 Subject: [PATCH 12/61] Dev 198: Remove unused imports for F401 lint --- icenet/model/train.py | 3 --- icenet/plotting/data.py | 1 - 2 files changed, 4 deletions(-) diff --git a/icenet/model/train.py b/icenet/model/train.py index 2203f08a..cc3236d2 100644 --- a/icenet/model/train.py +++ b/icenet/model/train.py @@ -3,12 +3,9 @@ import json import logging import os -import pkg_resources import random import time -from pprint import pformat - import numpy as np import pandas as pd import tensorflow as tf diff --git a/icenet/plotting/data.py b/icenet/plotting/data.py index b9a57be4..86b08a74 100644 --- a/icenet/plotting/data.py +++ b/icenet/plotting/data.py @@ -14,7 +14,6 @@ from icenet.data.dataset import IceNetDataSet from icenet.utils import setup_logging -import matplotlib as mpl import matplotlib.pyplot as plt from mpl_toolkits.axes_grid1 import make_axes_locatable From 607524fad3343987a2f6107e002c7335aabaea2c Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Tue, 7 Nov 2023 19:40:34 +0000 Subject: [PATCH 13/61] Dev 198: Change f-string for F541 lint --- icenet/plotting/forecast.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/icenet/plotting/forecast.py b/icenet/plotting/forecast.py index c8451c29..0da3285c 100644 --- a/icenet/plotting/forecast.py +++ b/icenet/plotting/forecast.py @@ -737,9 +737,9 @@ def standard_deviation_heatmap(metric: str, if metric in ["mae", "mse", "rmse"]: ylabel = f"SIC {metric.upper()} (%)" elif metric == "binacc": - ylabel = f"Binary accuracy (%)" + ylabel = "Binary accuracy (%)" elif metric == "sie": - ylabel = f"SIE error (km)" + ylabel = "SIE error (km)" # plot heatmap of standard deviation for IceNet fig, ax = plt.subplots(figsize=(12, 6)) @@ -903,9 +903,9 @@ def plot_metrics_leadtime_avg(metric: str, if metric in ["mae", "mse", "rmse"]: ylabel = f"SIC {metric.upper()} (%)" elif metric == "binacc": - ylabel = f"Binary accuracy (%)" + ylabel = "Binary accuracy (%)" elif metric == "sie": - ylabel = f"SIE error (km)" + ylabel = "SIE error (km)" if average_over == "all": # averaging metric over leadtime for all forecasts From 02443531f31e1aee1f83fd9389a61cd553741833 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Tue, 7 Nov 2023 20:03:58 +0000 Subject: [PATCH 14/61] Dev 198: Add missing import for F821 lint --- icenet/plotting/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/icenet/plotting/utils.py b/icenet/plotting/utils.py index df8c89a3..4dcccf80 100644 --- a/icenet/plotting/utils.py +++ b/icenet/plotting/utils.py @@ -6,6 +6,7 @@ import cartopy.crs as ccrs import matplotlib.pyplot as plt +import numpy as np import pandas as pd import xarray as xr From 15e090263348212fcf974eaa5abc11a75c82c9e9 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Tue, 7 Nov 2023 23:06:09 +0000 Subject: [PATCH 15/61] Fixes #192 --- icenet/data/sic/osisaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icenet/data/sic/osisaf.py b/icenet/data/sic/osisaf.py index e88d6bbd..2158303e 100644 --- a/icenet/data/sic/osisaf.py +++ b/icenet/data/sic/osisaf.py @@ -322,7 +322,7 @@ def download(self): cache = {} osi430b_start = dt.date(2016, 1, 1) - osi430a_start = dt.date(2018, 11, 18) + osi430a_start = dt.date(2021, 1, 1) dt_arr = list(reversed(sorted(copy.copy(self._dates)))) From a87a345c087172a973eb79acf15459a42d36c90a Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Mon, 20 Nov 2023 11:12:13 +0000 Subject: [PATCH 16/61] Dev 193: Switch to yapf google style profile --- .pre-commit-config.yaml | 18 +++++++++--------- .ruff.toml | 4 ++++ setup.cfg | 6 ++++++ 3 files changed, 19 insertions(+), 9 deletions(-) create mode 100644 .ruff.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 91a4cd97..9e30dba8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,27 +8,27 @@ repos: - id: trailing-whitespace language_version: python3 - # black - Formatting - - repo: https://github.com/psf/black - rev: 23.10.1 + - repo: https://github.com/google/yapf + rev: v0.40.2 hooks: - - id: black - args: [] + - id: yapf + name: "yapf" + args: ["--in-place", "--parallel"] # isort - Sorting imports - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort - args: ["--profile", "black", "--filter-files"] + args: ["--filter-files"] # ruff - Linting - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.4 + rev: v0.1.6 hooks: - id: ruff - args: ["--exclude", "setup.py"] - + args: [] + - repo: local hooks: # Run pytest diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 00000000..b5b52017 --- /dev/null +++ b/.ruff.toml @@ -0,0 +1,4 @@ +[lint] +select = ["E", "F"] +ignore = ["E721"] +exclude = ["setup.py"] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 9b768cf0..b29f2b65 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,9 +1,15 @@ [bdist_wheel] universal = 1 +[isort] +profile=google + [flake8] exclude = docs +[yapf] +based_on_style = google + [tool:pytest] collect_ignore = ['setup.py'] From 7f797800cc828a460c7595f8f1dc5b291ad6eb82 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Mon, 20 Nov 2023 11:24:32 +0000 Subject: [PATCH 17/61] Dev 193: Add toml check --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9e30dba8..af026dee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,6 +3,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: + - id: check-toml - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace From 1a2325d2d0aee0fc592bf17170b349e42c37d009 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Mon, 20 Nov 2023 11:39:28 +0000 Subject: [PATCH 18/61] Dev 193: Update pre-commit --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index af026dee..796a4fe9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,4 @@ +exclude: (LICENSE|README.md) repos: # General pre-commit hooks - repo: https://github.com/pre-commit/pre-commit-hooks From 69ef28212b4fc8fb1f8897d90e91154803cee35c Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Mon, 20 Nov 2023 11:41:08 +0000 Subject: [PATCH 19/61] Dev 193: Run end-of-file-fixer --- .ruff.toml | 2 +- docs/conf.py | 3 --- docs/requirements.txt | 1 - icenet/data/interfaces/cmems.py | 1 - icenet/data/interfaces/downloader.py | 1 - icenet/data/interfaces/utils.py | 1 - icenet/data/loaders/base.py | 2 -- icenet/data/loaders/dask.py | 3 --- icenet/data/processors/hres.py | 1 - icenet/model/metrics.py | 1 - icenet/model/predict.py | 1 - icenet/plotting/data.py | 2 +- icenet/plotting/utils.py | 2 +- icenet/plotting/video.py | 2 -- icenet/process/azure.py | 1 - tox.ini | 1 - 16 files changed, 3 insertions(+), 22 deletions(-) diff --git a/.ruff.toml b/.ruff.toml index b5b52017..af4b09be 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,4 +1,4 @@ [lint] select = ["E", "F"] ignore = ["E721"] -exclude = ["setup.py"] \ No newline at end of file +exclude = ["setup.py"] diff --git a/docs/conf.py b/docs/conf.py index 45e0514e..abf5a4cd 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -162,6 +162,3 @@ 'One line description of project.', 'Miscellaneous'), ] - - - diff --git a/docs/requirements.txt b/docs/requirements.txt index 217b17e5..5cff4a3a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,2 @@ jinja2==3.0.3 Sphinx==1.8.5 - diff --git a/icenet/data/interfaces/cmems.py b/icenet/data/interfaces/cmems.py index 4777d203..977315bf 100644 --- a/icenet/data/interfaces/cmems.py +++ b/icenet/data/interfaces/cmems.py @@ -191,4 +191,3 @@ def main(): ) oras5.download() oras5.regrid() - diff --git a/icenet/data/interfaces/downloader.py b/icenet/data/interfaces/downloader.py index d1e97de9..c9de444b 100644 --- a/icenet/data/interfaces/downloader.py +++ b/icenet/data/interfaces/downloader.py @@ -647,4 +647,3 @@ def pregrid_prefix(self): @property def var_names(self): return self._var_names - diff --git a/icenet/data/interfaces/utils.py b/icenet/data/interfaces/utils.py index 60e3e02b..96090c25 100644 --- a/icenet/data/interfaces/utils.py +++ b/icenet/data/interfaces/utils.py @@ -245,4 +245,3 @@ def reprocess_main(): output_base=args.output, dry=args.dry, var_names=args.vars) - diff --git a/icenet/data/loaders/base.py b/icenet/data/loaders/base.py index c1fb7766..e50aa78c 100644 --- a/icenet/data/loaders/base.py +++ b/icenet/data/loaders/base.py @@ -343,5 +343,3 @@ def pickup(self): @property def workers(self): return self._workers - - diff --git a/icenet/data/loaders/dask.py b/icenet/data/loaders/dask.py index a39ba128..2da50a94 100644 --- a/icenet/data/loaders/dask.py +++ b/icenet/data/loaders/dask.py @@ -507,6 +507,3 @@ def generate_sample(forecast_date: object, v1 += channels[var_name] return x, y, sample_weights - - - diff --git a/icenet/data/processors/hres.py b/icenet/data/processors/hres.py index 5e9cdb1a..f953802b 100644 --- a/icenet/data/processors/hres.py +++ b/icenet/data/processors/hres.py @@ -41,4 +41,3 @@ def main(): lag_days=args.lag, ) hres.process() - diff --git a/icenet/model/metrics.py b/icenet/model/metrics.py index 1f8c8eb6..988b99e5 100644 --- a/icenet/model/metrics.py +++ b/icenet/model/metrics.py @@ -311,4 +311,3 @@ def result(self): :return: """ return 100 * super().result() - diff --git a/icenet/model/predict.py b/icenet/model/predict.py index fa06a7b4..1c2c6c96 100644 --- a/icenet/model/predict.py +++ b/icenet/model/predict.py @@ -211,4 +211,3 @@ def main(): seed=args.seed, start_dates=dates, test_set=args.testset) - diff --git a/icenet/plotting/data.py b/icenet/plotting/data.py index 86b08a74..ccbcd6c6 100644 --- a/icenet/plotting/data.py +++ 
b/icenet/plotting/data.py @@ -183,4 +183,4 @@ def plot_channel_data(data: object, fig.colorbar(im1, cax=cax1, orientation='vertical') plt.savefig(output_path) - plt.close() \ No newline at end of file + plt.close() diff --git a/icenet/plotting/utils.py b/icenet/plotting/utils.py index 4dcccf80..6e385baa 100644 --- a/icenet/plotting/utils.py +++ b/icenet/plotting/utils.py @@ -459,4 +459,4 @@ def process_regions(region: tuple, for idx, arr in enumerate(data): if arr is not None: data[idx] = arr[..., (432 - y2):(432 - y1), x1:x2] - return data \ No newline at end of file + return data diff --git a/icenet/plotting/video.py b/icenet/plotting/video.py index 063bd48b..396c4f9c 100644 --- a/icenet/plotting/video.py +++ b/icenet/plotting/video.py @@ -372,5 +372,3 @@ def data_cli(): logging.info("Produced {}".format(res)) except Exception as e: logging.error(e) - - diff --git a/icenet/process/azure.py b/icenet/process/azure.py index cc48c81a..db85e58d 100644 --- a/icenet/process/azure.py +++ b/icenet/process/azure.py @@ -82,4 +82,3 @@ def upload(): if args.date and not args.leave: logging.info("Removing {}".format(tmpdir)) shutil.rmtree(tmpdir) - diff --git a/tox.ini b/tox.ini index 8a5592a6..253c2caa 100644 --- a/tox.ini +++ b/tox.ini @@ -22,4 +22,3 @@ deps = commands = pip install -U pip pytest --basetemp={envtmpdir} - From 4756b0c53a41520e64c6694a9be288f77a9307c5 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Mon, 20 Nov 2023 11:42:44 +0000 Subject: [PATCH 20/61] Dev 193: Run trailing-whitespace --- icenet/data/cli.py | 12 +-- icenet/data/dataset.py | 2 +- icenet/data/interfaces/cds.py | 4 +- icenet/data/interfaces/esgf.py | 2 +- icenet/data/loaders/dask.py | 8 +- icenet/data/loaders/stdlib.py | 2 +- icenet/data/process.py | 14 +-- icenet/data/producers.py | 16 +-- icenet/model/metrics.py | 2 +- icenet/model/predict.py | 2 +- icenet/plotting/forecast.py | 156 +++++++++++++++--------------- icenet/plotting/utils.py | 6 +- icenet/plotting/video.py | 2 +- icenet/process/predict.py | 4 +- icenet/tests/test_entry_points.py | 2 +- 15 files changed, 117 insertions(+), 117 deletions(-) diff --git a/icenet/data/cli.py b/icenet/data/cli.py index 8784fe24..21730394 100644 --- a/icenet/data/cli.py +++ b/icenet/data/cli.py @@ -16,8 +16,8 @@ def date_arg(string: str) -> object: """ - :param string: - :return: + :param string: + :return: """ date_match = re.search(r"(\d{4})-(\d{1,2})-(\d{1,2})", string) return dt.date(*[int(s) for s in date_match.groups()]) @@ -26,8 +26,8 @@ def date_arg(string: str) -> object: def dates_arg(string: str) -> object: """ - :param string: - :return: + :param string: + :return: """ if string == "none": return [] @@ -48,7 +48,7 @@ def csv_arg(string: str) -> list: """ csv_items = [] string = re.sub(r'^\'(.*)\'$', r'\1', string) - + for el in string.split(","): if len(el) == 0: csv_items.append(None) @@ -119,7 +119,7 @@ def download_args(choices: object = None, if workers: ap.add_argument("-w", "--workers", default=8, type=int) - + ap.add_argument("-po", "--parallel-opens", default=False, action="store_true", help="Allow xarray mfdataset to work with parallel opens") diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index 596ba243..35ca7b28 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -189,7 +189,7 @@ def _load_configurations(self, paths: object): north=False, south=False ) - + for path in paths: if os.path.exists(path): logging.info("Loading configuration {}".format(path)) diff --git a/icenet/data/interfaces/cds.py 
b/icenet/data/interfaces/cds.py index 410c93c8..b934b69e 100644 --- a/icenet/data/interfaces/cds.py +++ b/icenet/data/interfaces/cds.py @@ -129,7 +129,7 @@ def _single_toolbox_download(self, try: logging.info("Downloading data for {}...".format(var)) logging.debug("Result: {}".format(result)) - + location = result[0]['location'] res = requests.get(location, stream=True) @@ -246,7 +246,7 @@ def additional_regrid_processing(self, """ (datafile_path, datafile_name) = os.path.split(datafile) var_name = datafile_path.split(os.sep)[self._var_name_idx] - + if var_name == 'tos': # Overwrite maksed values with zeros logging.debug("ERA5 regrid postprocess: {}".format(var_name)) diff --git a/icenet/data/interfaces/esgf.py b/icenet/data/interfaces/esgf.py index 65e36a33..8cf6f051 100644 --- a/icenet/data/interfaces/esgf.py +++ b/icenet/data/interfaces/esgf.py @@ -43,7 +43,7 @@ class CMIP6Downloader(ClimateDownloader): "EC-Earth3", "r14i1p1f1", "gr" """ - + TABLE_MAP = { 'siconca': 'SIday', 'tas': 'day', diff --git a/icenet/data/loaders/dask.py b/icenet/data/loaders/dask.py index 2da50a94..87fd84f1 100644 --- a/icenet/data/loaders/dask.py +++ b/icenet/data/loaders/dask.py @@ -22,7 +22,7 @@ """ Dask implementations for icenet data loading -Still WIP to re-introduce alternate implementations that might work better in +Still WIP to re-introduce alternate implementations that might work better in certain deployments """ @@ -266,7 +266,7 @@ def generate_sample(self, [v for k, v in var_files.items() if k.endswith("linear_trend")] trend_ds = None - + if len(trend_files) > 0: trend_ds = xr.open_mfdataset( trend_files, @@ -420,8 +420,8 @@ def generate_sample(forecast_date: object, y = da.zeros((*shape, n_forecast_days, 1), dtype=dtype) sample_weights = da.zeros((*shape, n_forecast_days, 1), dtype=dtype) - - + + if not prediction: try: sample_output = var_ds.siconca_abs.sel(time=forecast_dts) diff --git a/icenet/data/loaders/stdlib.py b/icenet/data/loaders/stdlib.py index 547400b1..9a9b6c9b 100644 --- a/icenet/data/loaders/stdlib.py +++ b/icenet/data/loaders/stdlib.py @@ -3,7 +3,7 @@ """ Python Standard Library implementations for icenet data loading -Still WIP to re-introduce alternate implementations that might work better in +Still WIP to re-introduce alternate implementations that might work better in certain deployments """ diff --git a/icenet/data/process.py b/icenet/data/process.py index 6d53fdd3..543149b8 100644 --- a/icenet/data/process.py +++ b/icenet/data/process.py @@ -25,25 +25,25 @@ class IceNetPreProcessor(Processor): :param name: :param train_dates: :param val_dates: - :param test_dates: + :param test_dates: :param *args: - :param data_shape: + :param data_shape: :param dtype: - :param exclude_vars: + :param exclude_vars: :param file_filters: :param identifier: - :param linear_trends: - :param linear_trend_days: + :param linear_trends: + :param linear_trend_days: :param meta_vars: :param missing_dates: :param minmax: - :param no_normalise: + :param no_normalise: :param path: :param parallel_opens: :param ref_procdir: :param source_data: :param update_key: - :param update_loader: + :param update_loader: """ DATE_FORMAT = "%Y_%m_%d" diff --git a/icenet/data/producers.py b/icenet/data/producers.py index 5238b01b..24818976 100644 --- a/icenet/data/producers.py +++ b/icenet/data/producers.py @@ -130,7 +130,7 @@ def download(self): class Generator(DataProducer): """ - + """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -146,13 +146,13 @@ def generate(self): class 
Processor(DataProducer): """ - - :param identifier: - :param source_data: - :param *args: - :param file_filters: - :param test_dates: - :param train_dates: + + :param identifier: + :param source_data: + :param *args: + :param file_filters: + :param test_dates: + :param train_dates: :param val_dates: """ def __init__(self, diff --git a/icenet/model/metrics.py b/icenet/model/metrics.py index 988b99e5..db58f508 100644 --- a/icenet/model/metrics.py +++ b/icenet/model/metrics.py @@ -166,7 +166,7 @@ def __init__(self, self._leadtime_idx = leadtime_idx super().__init__(name=name, **kwargs) - + def update_state(self, y_true: object, y_pred: object, diff --git a/icenet/model/predict.py b/icenet/model/predict.py index 1c2c6c96..a531305c 100644 --- a/icenet/model/predict.py +++ b/icenet/model/predict.py @@ -99,7 +99,7 @@ def predict_forecast( missing = set(start_dates).difference(test_dates) if len(missing) > 0: raise RuntimeError("{} are not in the test set". - format(", ".join([str(pd.to_datetime(el).date()) + format(", ".join([str(pd.to_datetime(el).date()) for el in missing]))) data_iter = test_inputs.as_numpy_iterator() diff --git a/icenet/plotting/forecast.py b/icenet/plotting/forecast.py index 0da3285c..534b25a7 100644 --- a/icenet/plotting/forecast.py +++ b/icenet/plotting/forecast.py @@ -122,11 +122,11 @@ def plot_binary_accuracy(masks: object, where we consider a binary class prediction of ice with SIC > 15%. In particular, we compute the mean percentage of correct classifications over the active grid cell area. - + :param masks: an icenet Masks object - :param fc_da: the forecasts given as an xarray.DataArray object + :param fc_da: the forecasts given as an xarray.DataArray object with time, xc, yc coordinates - :param cmp_da: a comparison forecast / sea ice data given as an + :param cmp_da: a comparison forecast / sea ice data given as an xarray.DataArray object with time, xc, yc coordinates. If None, will ignore plotting a comparison forecast :param obs_da: the "ground truth" given as an xarray.DataArray object @@ -134,8 +134,8 @@ def plot_binary_accuracy(masks: object, :param output_path: string specifying the path to store the plot :param threshold: the SIC threshold of interest (in percentage as a fraction), i.e. threshold is between 0 and 1 - - :return: tuple of (binary accuracy for forecast (fc_da), + + :return: tuple of (binary accuracy for forecast (fc_da), binary accuracy for comparison (cmp_da)) """ binacc_fc = compute_binary_accuracy(masks=masks, @@ -181,7 +181,7 @@ def compute_sea_ice_extent_error(masks: object, defined as the total area covered by grid cells with SIC > (threshold*100)%. :param masks: an icenet Masks object - :param fc_da: the forecasts given as an xarray.DataArray object + :param fc_da: the forecasts given as an xarray.DataArray object with time, xc, yc coordinates :param obs_da: the "ground truth" given as an xarray.DataArray object with time, xc, yc coordinates @@ -196,10 +196,10 @@ def compute_sea_ice_extent_error(masks: object, threshold = 0.15 if threshold is None else threshold if (threshold < 0) or (threshold > 1): raise ValueError("threshold must be a float between 0 and 1") - + # obtain mask agcm = masks.get_active_cell_da(obs_da) - + # binary for observed (i.e. 
truth) binary_obs_da = obs_da > threshold binary_obs_weighted_da = binary_obs_da.astype(int).weighted(agcm) @@ -207,13 +207,13 @@ def compute_sea_ice_extent_error(masks: object, # binary for forecast binary_fc_da = fc_da > threshold binary_fc_weighted_da = binary_fc_da.astype(int).weighted(agcm) - + # sie error forecast_sie_error = ( binary_fc_weighted_da.sum(['xc', 'yc']) - binary_obs_weighted_da.sum(['xc', 'yc']) ) * (grid_area_size**2) - + return forecast_sie_error @@ -227,11 +227,11 @@ def plot_sea_ice_extent_error(masks: object, """ Compute and plot sea ice extent (SIE) error of a forecast, where SIE error is defined as the total area covered by grid cells with SIC > (threshold*100)%. - + :param masks: an icenet Masks object - :param fc_da: the forecasts given as an xarray.DataArray object + :param fc_da: the forecasts given as an xarray.DataArray object with time, xc, yc coordinates - :param cmp_da: a comparison forecast / sea ice data given as an + :param cmp_da: a comparison forecast / sea ice data given as an xarray.DataArray object with time, xc, yc coordinates. If None, will ignore plotting a comparison forecast :param obs_da: the "ground truth" given as an xarray.DataArray object @@ -241,7 +241,7 @@ def plot_sea_ice_extent_error(masks: object, by default set to 25 (so area of grid is 25*25) :param threshold: the SIC threshold of interest (in percentage as a fraction), i.e. threshold is between 0 and 1 - + :return: tuple of (SIE error for forecast (fc_da), SIE error for comparison (cmp_da)) """ forecast_sie_error = compute_sea_ice_extent_error(masks=masks, @@ -249,7 +249,7 @@ def plot_sea_ice_extent_error(masks: object, obs_da=obs_da, grid_area_size=grid_area_size, threshold=threshold) - + fig, ax = plt.subplots(figsize=(12, 6)) ax.set_title(f"SIE error comparison ({grid_area_size} km grid resolution) " f"(threshold SIC = {threshold*100}%)") @@ -294,8 +294,8 @@ def compute_metrics(metrics: object, :param masks: an icenet Masks object :param fc_da: an xarray.DataArray object with time, xc, yc coordinates :param obs_da: an xarray.DataArray object with time, xc, yc coordinates - - :return: dictionary with keys as metric names and values as + + :return: dictionary with keys as metric names and values as xarray.DataArray's storing the computed metrics for each forecast """ # check requested metrics have been implemented @@ -304,10 +304,10 @@ def compute_metrics(metrics: object, if metric not in implemented_metrics: raise NotImplementedError(f"{metric} metric has not been implemented. 
" f"Please only choose out of {implemented_metrics}.") - + # obtain mask mask_da = masks.get_active_cell_da(obs_da) - + metric_dict = {} # compute raw error err_da = (fc_da-obs_da)*100 @@ -319,7 +319,7 @@ def compute_metrics(metrics: object, # compute squared SIC errors square_err_da = err_da**2 square_weighted_da = square_err_da.weighted(mask_da) - + for metric in metrics: if metric == "mae": metric_dict[metric] = abs_weighted_da.mean(dim=['yc', 'xc']) @@ -335,8 +335,8 @@ def compute_metrics(metrics: object, # only return metrics requested (might've computed MSE when computing RMSE) return {k: metric_dict[k] for k in metrics} - - + + def plot_metrics(metrics: object, masks: object, fc_da: object, @@ -353,7 +353,7 @@ def plot_metrics(metrics: object, :param metrics: a list of strings :param masks: an icenet Masks object :param fc_da: an xarray.DataArray object with time, xc, yc coordinates - :param cmp_da: a comparison forecast / sea ice data given as an + :param cmp_da: a comparison forecast / sea ice data given as an xarray.DataArray object with time, xc, yc coordinates. If None, will ignore plotting a comparison forecast :param obs_da: an xarray.DataArray object with time, xc, yc coordinates @@ -361,8 +361,8 @@ def plot_metrics(metrics: object, If separate=True, this should be a directory :param separate: bool specifying whether there is a plot created for each metric (True) or not (False), default is False - - :return: dictionary with keys as metric names and values as + + :return: dictionary with keys as metric names and values as xarray.DataArray's storing the computed metrics for each forecast """ # compute metrics @@ -377,7 +377,7 @@ def plot_metrics(metrics: object, obs_da=obs_da) else: cmp_metric_dict = None - + if separate: # produce separate plots for each metric for metric in metrics: @@ -390,13 +390,13 @@ def plot_metrics(metrics: object, ax.plot(cmp_metric_dict[metric].time, cmp_metric_dict[metric].values, label="SEAS") - + ax.xaxis.set_major_formatter( mdates.ConciseDateFormatter(ax.xaxis.get_major_locator())) ax.xaxis.set_major_locator(mdates.MonthLocator()) ax.xaxis.set_minor_locator(mdates.DayLocator()) ax.legend(loc='lower right') - + outpath = os.path.join("plot", f"{metric}.png") \ if not output_path else os.path.join(output_path, f"{metric}.png") logging.info(f"Saving to {outpath}") @@ -414,7 +414,7 @@ def plot_metrics(metrics: object, cmp_metric_dict[metric].values, label=f"SEAS {metric.upper()}", linestyle="dotted") - + ax.set_ylabel("SIC (%)") ax.xaxis.set_major_formatter( mdates.ConciseDateFormatter(ax.xaxis.get_major_locator())) @@ -422,12 +422,12 @@ def plot_metrics(metrics: object, ax.xaxis.set_minor_locator(mdates.DayLocator()) ax.set_xlabel("Date") ax.legend(loc='lower right') - + output_path = os.path.join("plot", "metrics.png") \ if not output_path else output_path logging.info(f"Saving to {output_path}") plt.savefig(output_path) - + return fc_metric_dict, cmp_metric_dict @@ -441,7 +441,7 @@ def compute_metric_as_dataframe(metric: object, Computes a metric for each leadtime in a forecast and stores the results in a pandas dataframe with columns 'date' (which is the initialisation date passed in), 'leadtime' and the metric name(s). 
- + :param metric: string, or list of strings, specifying which metric(s) to compute :param masks: an icenet Masks object :param init_date: forecast initialisation date which gets @@ -451,7 +451,7 @@ def compute_metric_as_dataframe(metric: object, :param kwargs: any keyword arguments that are required for the computation of the metric, e.g. 'threshold' for SIE error and binary accuracy metrics, or 'grid_area_size' for SIE error metric - + :return: computed metric in a pandas dataframe with columns 'date', 'leadtime' and 'met' for each metric, met, in metric """ @@ -483,10 +483,10 @@ def compute_metric_as_dataframe(metric: object, threshold=kwargs["threshold"]).values else: raise NotImplementedError(f"{met} is not implemented") - + # create dataframe from metric_dict metric_df = pd.DataFrame(metric_dict) - + init_date = pd.to_datetime(init_date) # compute day of year after first converting year to a non-leap year # avoids issue where 2016-03-31 is different to 2015-03-31 @@ -532,7 +532,7 @@ def compute_metrics_leadtime_avg(metric: str, in a pandas dataframe with columns 'date' (specifying the initialisation date), 'leadtime' and the metric name. This pandas dataframe can then be used to average over leadtime to obtain leadtime averaged metrics. - + :param metric: string specifying which metric to compute :param masks: an icenet Masks object :param hemisphere: string, typically either 'north' or 'south' @@ -550,12 +550,12 @@ def compute_metrics_leadtime_avg(metric: str, :param kwargs: any keyword arguments that are required for the computation of the metric, e.g. 'threshold' for SIE error and binary accuracy metrics, or 'grid_area_size' for SIE error metric - + :return: pandas dataframe with columns 'date', 'leadtime' and the metric name. """ # open forecast file fc_ds = xr.open_dataset(forecast_file) - + if ecmwf: # find out what dates cross over with the SEAS5 predictions (fc_start_date, fc_end_date) = (fc_ds.time.values.min(), fc_ds.time.values.max()) @@ -563,10 +563,10 @@ def compute_metrics_leadtime_avg(metric: str, dates = dates[(dates > fc_start_date) & (dates <= fc_end_date)] times = [x for x in fc_ds.time.values if x in dates] fc_ds = fc_ds.sel(time=times) - + logging.info(f"Computing {metric} for {len(fc_ds.time.values)} forecasts") # obtain metric for each leadtime at each initialised date in the forecast file - + fc_metrics_list = [] if ecmwf: seas_metrics_list = [] @@ -577,7 +577,7 @@ def compute_metrics_leadtime_avg(metric: str, start_date=pd.to_datetime(time) + timedelta(days=1), end_date=pd.to_datetime(time) + timedelta(days=int(fc.leadtime.max()))) fc = filter_ds_by_obs(fc, obs, time) - + if ecmwf: # obtain SEAS forecast seas = get_seas_forecast_da(hemisphere=hemisphere, @@ -588,11 +588,11 @@ def compute_metrics_leadtime_avg(metric: str, seas = seas.isel(time=slice(1, None)) else: seas = None - + if region is not None: seas, fc, obs, masks = process_regions(region, [seas, fc, obs, masks]) - + # compute metrics fc_metrics_list.append(compute_metric_as_dataframe(metric=metric, masks=masks, @@ -607,7 +607,7 @@ def compute_metrics_leadtime_avg(metric: str, fc_da=seas, obs_da=obs, **kwargs)) - + # groupby the leadtime and compute the mean average of the metric fc_metric_df = pd.concat(fc_metrics_list) fc_metric_df["forecast_name"] = "IceNet" @@ -615,7 +615,7 @@ def compute_metrics_leadtime_avg(metric: str, seas_metric_df = pd.concat(seas_metrics_list) seas_metric_df["forecast_name"] = "SEAS" fc_metric_df = pd.concat([fc_metric_df, seas_metric_df]) - + if data_path is not None: 
logging.info(f"Saving the metric dataframe in {data_path}") try: @@ -623,7 +623,7 @@ def compute_metrics_leadtime_avg(metric: str, except OSError: # don't break if not successful, still return dataframe logging.info("Save not successful! Make sure the data_path directory exists") - + return fc_metric_df.reset_index(drop=True) @@ -633,11 +633,11 @@ def _parse_day_of_year(dayofyear: int, leapyear: bool = False) -> int: the integer day of year. Useful for ensuring consistency over leap years, as dates after March could have different day of years due to leap years. For example, 01/03/00 is 60th day in 2000 but 01/03/01 is the 59th day in 2001. - + :param dayofyear: integer as int or float type :param leapyear: bool to indicate if we want to convert a leapyear dayofyear to non-leapyear - + :return: int dayofyear """ if leapyear: @@ -662,7 +662,7 @@ def _heatmap_ylabels(metrics_df: pd.DataFrame, average_over: str, groupby_col: s or "target_dayofyear". If average_over="month" or "day", this is typically "month" or "target_month". - + :return: list of labels for the y-axis """ if average_over == "day": @@ -714,7 +714,7 @@ def standard_deviation_heatmap(metric: str, If None, this will be computed here :param vmax: float specifying maximum to anchor the colourmap. If None, this is inferred from the data - + :return: dataframe of the standard deviation of the metric over leadtime """ logging.info(f"Creating standard deviation over leadtime plot for " @@ -725,14 +725,14 @@ def standard_deviation_heatmap(metric: str, groupby_col = "month" if target_date_avg: groupby_col = "target_" + groupby_col - + if fc_std_metric is None: # compute standard deviation of metric fc_std_metric = metrics_df.groupby([groupby_col, "leadtime"]).std(numeric_only=True).\ reset_index().pivot(index=groupby_col, columns="leadtime", values=metric).\ sort_values(groupby_col, ascending=True) n_forecast_days = fc_std_metric.shape[1] - + # set ylabel (if average_over == "all"), or legend label (otherwise) if metric in ["mae", "mse", "rmse"]: ylabel = f"SIC {metric.upper()} (%)" @@ -740,7 +740,7 @@ def standard_deviation_heatmap(metric: str, ylabel = "Binary accuracy (%)" elif metric == "sie": ylabel = "SIE error (km)" - + # plot heatmap of standard deviation for IceNet fig, ax = plt.subplots(figsize=(12, 6)) sns.heatmap(data=fc_std_metric, @@ -749,7 +749,7 @@ def standard_deviation_heatmap(metric: str, vmax=vmax, cmap="mako_r", cbar_kws=dict(label=f"Standard deviation of {ylabel}")) - + # y-axis labels = _heatmap_ylabels(metrics_df=metrics_df, average_over=average_over, @@ -761,13 +761,13 @@ def standard_deviation_heatmap(metric: str, ax.set_ylabel("Target date of forecast") else: ax.set_ylabel("Initialisation date of forecast") - + # x-axis ax.set_xticks(np.arange(30, n_forecast_days, 30)) ax.set_xticklabels(np.arange(30, n_forecast_days, 30)) plt.xticks(rotation=0) ax.set_xlabel("Lead time (days)") - + # add plot title (start_date, end_date) = (metrics_df["date"].min().strftime('%d/%m/%Y'), metrics_df["date"].max().strftime('%d/%m/%Y')) @@ -790,7 +790,7 @@ def standard_deviation_heatmap(metric: str, if not output_path else output_path logging.info(f"Saving to {output_path}") plt.savefig(output_path) - + return fc_std_metric @@ -811,7 +811,7 @@ def plot_metrics_leadtime_avg(metric: str, """ Plots leadtime averaged metrics either using all the forecasts in the forecast file, or averaging them over by month or day. 
- + :param metric: string specifying which metric to compute :param masks: an icenet Masks object :param hemisphere: string, typically either 'north' or 'south' @@ -841,7 +841,7 @@ def plot_metrics_leadtime_avg(metric: str, :param kwargs: any keyword arguments that are required for the computation of the metric, e.g. 'threshold' for SIE error and binary accuracy metrics, or 'grid_area_size' for SIE error metric - + :return: pandas dataframe with columns 'date', 'leadtime' and the metric name """ implemented_metrics = ["binacc", "sie", "mae", "mse", "rmse"] @@ -864,7 +864,7 @@ def plot_metrics_leadtime_avg(metric: str, del kwargs["grid_area_size"] if "threshold" in kwargs.keys(): del kwargs["threshold"] - + do_compute_metrics = True if data_path is not None: # loading in precomputed dataframes for the metrics @@ -898,7 +898,7 @@ def plot_metrics_leadtime_avg(metric: str, fig, ax = plt.subplots(figsize=(12, 6)) (start_date, end_date) = (fc_metric_df["date"].min().strftime('%d/%m/%Y'), fc_metric_df["date"].max().strftime('%d/%m/%Y')) - + # set ylabel (if average_over == "all"), or legend label (otherwise) if metric in ["mae", "mse", "rmse"]: ylabel = f"SIC {metric.upper()} (%)" @@ -906,13 +906,13 @@ def plot_metrics_leadtime_avg(metric: str, ylabel = "Binary accuracy (%)" elif metric == "sie": ylabel = "SIE error (km)" - + if average_over == "all": # averaging metric over leadtime for all forecasts fc_avg_metric = fc_metric_df.groupby("leadtime").mean(metric).\ sort_values("leadtime", ascending=True)[metric] n_forecast_days = fc_avg_metric.index.max() - + # plot leadtime averaged metrics ax.plot(fc_avg_metric.index, fc_avg_metric, label="IceNet", color="blue") if plot_std: @@ -954,13 +954,13 @@ def plot_metrics_leadtime_avg(metric: str, groupby_col = "month" if target_date_avg: groupby_col = "target_" + groupby_col - + # compute metric by first grouping the dataframe by groupby_col and leadtime fc_avg_metric = fc_metric_df.groupby([groupby_col, "leadtime"]).mean(metric).\ reset_index().pivot(index=groupby_col, columns="leadtime", values=metric).\ sort_values(groupby_col, ascending=True) n_forecast_days = fc_avg_metric.shape[1] - + if seas_metric_df is not None: # compute the difference in leadtime average to SEAS forecast seas_avg_metric = seas_metric_df.groupby([groupby_col, "leadtime"]).mean(metric).\ @@ -968,7 +968,7 @@ def plot_metrics_leadtime_avg(metric: str, sort_values(groupby_col, ascending=True) heatmap_df_diff = fc_avg_metric - seas_avg_metric max = np.nanmax(np.abs(heatmap_df_diff.values)) - + # plot heatmap of the difference between IceNet and SEAS sns.heatmap(data=heatmap_df_diff, ax=ax, @@ -979,7 +979,7 @@ def plot_metrics_leadtime_avg(metric: str, else: # plot heatmap of the leadtime averaged metric when grouped by groupby_col - sns.heatmap(data=fc_avg_metric, + sns.heatmap(data=fc_avg_metric, ax=ax, cmap="inferno" if metric in ["binacc", "sie"] else "inferno_r", cbar_kws=dict(label=ylabel)) @@ -995,7 +995,7 @@ def plot_metrics_leadtime_avg(metric: str, groupby_col=groupby_col) ax.set_yticks(np.arange(len(fc_metric_df[groupby_col].unique()))+0.5) ax.set_yticklabels(labels) - + plt.yticks(rotation=0) if target_date_avg: ax.set_ylabel("Target date of forecast") @@ -1003,7 +1003,7 @@ def plot_metrics_leadtime_avg(metric: str, ax.set_ylabel("Initialisation date of forecast") else: raise NotImplementedError(f"averaging over {average_over} not a valid option.") - + # add plot title if metric in ["mae", "mse", "rmse"]: title = f"{metric.upper()} comparison" @@ -1013,13 +1013,13 @@ 
def plot_metrics_leadtime_avg(metric: str, title = f"SIE error comparison ({kwargs['grid_area_size']} km grid resolution, " +\ f"threshold SIC = {kwargs['threshold'] * 100}%)" ax.set_title(title + time_coverage) - + # x-axis ax.set_xticks(np.arange(30, n_forecast_days, 30)) ax.set_xticklabels(np.arange(30, n_forecast_days, 30)) plt.xticks(rotation=0) ax.set_xlabel("Lead time (days)") - + # save plot targ = "target" if target_date_avg and average_over != "all" else "init" filename = f"leadtime_averaged_{targ}_{average_over}_{metric}" + \ @@ -1028,7 +1028,7 @@ def plot_metrics_leadtime_avg(metric: str, if not output_path else output_path logging.info(f"Saving to {output_path}") plt.savefig(output_path) - + if plot_std and average_over in ["day", "month"]: # create heapmap for the standard deviation if seas_metric_df is not None: @@ -1079,7 +1079,7 @@ def sic_error_video(fc_da: object, :param obs_da: :param land_mask: :param output_path: - + :return: matplotlib animation """ @@ -1393,7 +1393,7 @@ def allow_sie(self): return self def allow_metrics(self): - self.add_argument("-m", + self.add_argument("-m", "--metrics", help="Which metrics to compute and plot", type=str, @@ -1465,7 +1465,7 @@ def binary_accuracy(): if args.region: seas, fc, obs, masks = process_regions(args.region, [seas, fc, obs, masks]) - + plot_binary_accuracy(masks=masks, fc_da=fc, cmp_da=seas, @@ -1666,9 +1666,9 @@ def parse_metrics_arg(argument: str) -> object: Splits a string into a list by separating on commas. Will remove any whitespace and removes duplicates. Used to parsing metrics argument in metric_plots. - + :param argument: string - + :return: list of metrics to compute """ return list(set([s.replace(" ", "") for s in argument.split(",")])) @@ -1767,7 +1767,7 @@ def leadtime_avg_plots(): args = ap.parse_args() masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") - + plot_metrics_leadtime_avg(metric=args.metric, masks=masks, hemisphere=args.hemisphere, diff --git a/icenet/plotting/utils.py b/icenet/plotting/utils.py index 6e385baa..74dd40a1 100644 --- a/icenet/plotting/utils.py +++ b/icenet/plotting/utils.py @@ -97,11 +97,11 @@ def get_seas_forecast_init_dates(hemisphere: str, :param hemisphere: string, typically either 'north' or 'south' :param source_path: path where north and south SEAS forecasts are stored - + :return: list of dates """ # list the files in the path where SEAS forecasts are stored - filenames = os.listdir(os.path.join(source_path, + filenames = os.listdir(os.path.join(source_path, hemisphere, "siconca")) # obtain the dates from files with YYYYMMDD.nc format @@ -215,7 +215,7 @@ def strip_overlapping_time(ds): pd.to_datetime(max(seas_da.time.values)).strftime("%Y-%m-%d"), len(seas_da.time) )) - + return seas_da diff --git a/icenet/plotting/video.py b/icenet/plotting/video.py index 396c4f9c..8367a239 100644 --- a/icenet/plotting/video.py +++ b/icenet/plotting/video.py @@ -235,7 +235,7 @@ def recurse_data_folders(base_path: object, and (re.match(r'^\d{4}\.nc$', f) or re.search(r'(abs|anom|linear_trend)\.nc$', f))]) - + logging.debug("Files found: {}".format(", ".join(files))) if not len(files): return None diff --git a/icenet/process/predict.py b/icenet/process/predict.py index 62a74d95..6e1ec678 100644 --- a/icenet/process/predict.py +++ b/icenet/process/predict.py @@ -261,8 +261,8 @@ def create_cf_output(): # comply with CF standard_name_vocabulary="CF Standard Name Table v27", summary=""" - This is an output of sea ice concentration predictions from the - IceNet run 
in an ensemble, with postprocessing to determine + This is an output of sea ice concentration predictions from the + IceNet run in an ensemble, with postprocessing to determine the mean and standard deviation across the runs. """, # Use ISO 8601:2004 duration format, preferably the extended format diff --git a/icenet/tests/test_entry_points.py b/icenet/tests/test_entry_points.py index 0720b189..09e1d257 100644 --- a/icenet/tests/test_entry_points.py +++ b/icenet/tests/test_entry_points.py @@ -15,7 +15,7 @@ def test_have_entry_points(): tests passing vacuously if these are moved) """ assert len(icenet_entry_points) > 0 - + @pytest.mark.parametrize("entry_point", icenet_entry_points) def test_entry_point_exists(entry_point): From 84b6f681292a525196bcbcdac85eea1915958c75 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Mon, 20 Nov 2023 12:44:53 +0000 Subject: [PATCH 21/61] Dev 193: Run yapf formatter --- .pre-commit-config.yaml | 1 + icenet/data/cli.py | 108 +++-- icenet/data/dataset.py | 47 +- icenet/data/datasets/utils.py | 46 +- icenet/data/interfaces/cds.py | 107 ++--- icenet/data/interfaces/cmems.py | 59 +-- icenet/data/interfaces/downloader.py | 152 +++---- icenet/data/interfaces/esgf.py | 102 ++--- icenet/data/interfaces/mars.py | 127 +++--- icenet/data/interfaces/utils.py | 48 ++- icenet/data/loader.py | 83 ++-- icenet/data/loaders/__init__.py | 9 +- icenet/data/loaders/base.py | 153 ++++--- icenet/data/loaders/dask.py | 191 ++++----- icenet/data/loaders/stdlib.py | 10 +- icenet/data/loaders/utils.py | 26 +- icenet/data/process.py | 243 ++++++----- icenet/data/processors/cmip.py | 25 +- icenet/data/processors/era5.py | 5 +- icenet/data/processors/hres.py | 5 +- icenet/data/processors/meta.py | 23 +- icenet/data/processors/oras5.py | 4 +- icenet/data/processors/osi.py | 54 +-- icenet/data/processors/utils.py | 49 ++- icenet/data/producers.py | 81 ++-- icenet/data/sic/mask.py | 84 ++-- icenet/data/sic/osisaf.py | 373 +++++++++------- icenet/data/sic/utils.py | 5 +- icenet/data/utils.py | 16 +- icenet/model/callbacks.py | 16 +- icenet/model/losses.py | 7 +- icenet/model/metrics.py | 23 +- icenet/model/models.py | 208 ++++++--- icenet/model/predict.py | 83 ++-- icenet/model/train.py | 204 ++++----- icenet/model/utils.py | 18 +- icenet/plotting/data.py | 50 ++- icenet/plotting/forecast.py | 617 +++++++++++++-------------- icenet/plotting/utils.py | 161 +++---- icenet/plotting/video.py | 126 +++--- icenet/process/azure.py | 8 +- icenet/process/forecasts.py | 54 +-- icenet/process/local.py | 10 +- icenet/process/predict.py | 91 ++-- icenet/process/utils.py | 17 +- icenet/results/threshold.py | 9 +- icenet/tests/test_entry_points.py | 1 - icenet/tests/test_mod.py | 1 - icenet/utils.py | 6 +- setup.py | 24 +- 50 files changed, 2033 insertions(+), 1937 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 796a4fe9..cad224ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,6 +16,7 @@ repos: - id: yapf name: "yapf" args: ["--in-place", "--parallel"] + exclude: "docs/" # isort - Sorting imports - repo: https://github.com/pycqa/isort diff --git a/icenet/data/cli.py b/icenet/data/cli.py index 21730394..a3b39c53 100644 --- a/icenet/data/cli.py +++ b/icenet/data/cli.py @@ -7,7 +7,6 @@ import pandas as pd from icenet.utils import setup_logging - """ """ @@ -35,8 +34,8 @@ def dates_arg(string: str) -> object: date_match = re.findall(r"(\d{4})-(\d{1,2})-(\d{1,2})", string) if len(date_match) < 1: - raise argparse.ArgumentError("No dates found for 
supplied argument {}". - format(string)) + raise argparse.ArgumentError( + "No dates found for supplied argument {}".format(string)) return [dt.date(*[int(s) for s in date_tuple]) for date_tuple in date_match] @@ -120,12 +119,17 @@ def download_args(choices: object = None, if workers: ap.add_argument("-w", "--workers", default=8, type=int) - ap.add_argument("-po", "--parallel-opens", - default=False, action="store_true", + ap.add_argument("-po", + "--parallel-opens", + default=False, + action="store_true", help="Allow xarray mfdataset to work with parallel opens") - ap.add_argument("-d", "--dont-delete", dest="delete", - action="store_false", default=True) + ap.add_argument("-d", + "--dont-delete", + dest="delete", + action="store_false", + default=True) ap.add_argument("-v", "--verbose", action="store_true", default=False) if var_specs: @@ -133,12 +137,13 @@ def download_args(choices: object = None, help="Comma separated list of vars", type=csv_arg, default=[]) - ap.add_argument("--levels", - help="Comma separated list of pressures/depths as needed, " - "use zero length string if None (e.g. ',,500,,,') and " - "pipes for multiple per var (e.g. ',,250|500,,'", - type=csv_of_csv_arg, - default=[]) + ap.add_argument( + "--levels", + help="Comma separated list of pressures/depths as needed, " + "use zero length string if None (e.g. ',,500,,,') and " + "pipes for multiple per var (e.g. ',,250|500,,'", + type=csv_of_csv_arg, + default=[]) for arg in extra_args: ap.add_argument(*arg[0], **arg[1]) @@ -167,8 +172,10 @@ def process_args(dates: bool = True, ap.add_argument("-l", "--lag", type=int, default=2) ap.add_argument("-f", "--forecast", type=int, default=93) - ap.add_argument("-po", "--parallel-opens", - default=False, action="store_true", + ap.add_argument("-po", + "--parallel-opens", + default=False, + action="store_true", help="Allow xarray mfdataset to work with parallel opens") ap.add_argument("--abs", @@ -192,16 +199,20 @@ def process_args(dates: bool = True, ap.add_argument(*arg[0], **arg[1]) if ref_option: - ap.add_argument("-r", "--ref", + ap.add_argument("-r", + "--ref", help="Reference loader for normalisations etc", - default=None, type=str) + default=None, + type=str) ap.add_argument("-v", "--verbose", action="store_true", default=False) - ap.add_argument("-u", "--update-key", - default=None, - help="Add update key to processor to avoid overwriting default" - "entries in the loader configuration", - type=str) + ap.add_argument( + "-u", + "--update-key", + default=None, + help="Add update key to processor to avoid overwriting default" + "entries in the loader configuration", + type=str) args = ap.parse_args() return args @@ -212,18 +223,37 @@ def add_date_args(arg_parser: object): :param arg_parser: """ - arg_parser.add_argument("-ns", "--train_start", - type=dates_arg, required=False, default=[]) - arg_parser.add_argument("-ne", "--train_end", - type=dates_arg, required=False, default=[]) - arg_parser.add_argument("-vs", "--val_start", - type=dates_arg, required=False, default=[]) - arg_parser.add_argument("-ve", "--val_end", - type=dates_arg, required=False, default=[]) - arg_parser.add_argument("-ts", "--test-start", - type=dates_arg, required=False, default=[]) - arg_parser.add_argument("-te", "--test-end", dest="test_end", - type=dates_arg, required=False, default=[]) + arg_parser.add_argument("-ns", + "--train_start", + type=dates_arg, + required=False, + default=[]) + arg_parser.add_argument("-ne", + "--train_end", + type=dates_arg, + required=False, + default=[]) + 
arg_parser.add_argument("-vs", + "--val_start", + type=dates_arg, + required=False, + default=[]) + arg_parser.add_argument("-ve", + "--val_end", + type=dates_arg, + required=False, + default=[]) + arg_parser.add_argument("-ts", + "--test-start", + type=dates_arg, + required=False, + default=[]) + arg_parser.add_argument("-te", + "--test-end", + dest="test_end", + type=dates_arg, + required=False, + default=[]) def process_date_args(args: object) -> dict: @@ -240,11 +270,11 @@ def process_date_args(args: object) -> dict: for i, period_start in \ enumerate(getattr(args, "{}_start".format(dataset))): period_end = getattr(args, "{}_end".format(dataset))[i] - dataset_dates += [pd.to_datetime(date).date() for date in - pd.date_range(period_start, - period_end, freq="D")] - logging.info("Got {} dates for {}".format(len(dataset_dates), - dataset)) + dataset_dates += [ + pd.to_datetime(date).date() + for date in pd.date_range(period_start, period_end, freq="D") + ] + logging.info("Got {} dates for {}".format(len(dataset_dates), dataset)) dates[dataset] = sorted(list(dataset_dates)) return dates diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index 35ca7b28..0c29ae7e 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -9,7 +9,6 @@ from icenet.data.loader import IceNetDataLoaderFactory from icenet.data.producers import DataCollection from icenet.utils import setup_logging - """ @@ -85,7 +84,7 @@ def _load_configuration(self, path: str): else: raise OSError("{} not found".format(path)) - def get_data_loader(self, n_forecast_days = None, generate_workers = None): + def get_data_loader(self, n_forecast_days=None, generate_workers=None): """ :return: @@ -101,16 +100,12 @@ def get_data_loader(self, n_forecast_days = None, generate_workers = None): self._config["var_lag"], n_forecast_days=n_forecast_days, generate_workers=generate_workers, - dataset_config_path=os.path.dirname( - self._configuration_path), - loss_weight_days=self._config[ - "loss_weight_days"], + dataset_config_path=os.path.dirname(self._configuration_path), + loss_weight_days=self._config["loss_weight_days"], north=self.north, - output_batch_size=self._config[ - "output_batch_size"], + output_batch_size=self._config["output_batch_size"], south=self.south, - var_lag_override=self._config[ - "var_lag_override"], + var_lag_override=self._config["var_lag_override"], ) return loader @@ -148,8 +143,8 @@ def __init__(self, if type(configuration_paths) != list else configuration_paths self._load_configurations(configuration_paths) - identifier = ".".join([loader.identifier - for loader in self._config["loaders"]]) + identifier = ".".join( + [loader.identifier for loader in self._config["loaders"]]) super().__init__(*args, identifier=identifier, @@ -183,12 +178,10 @@ def _load_configurations(self, paths: object): :param paths: """ - self._config = dict( - loader_paths=[], - loaders=[], - north=False, - south=False - ) + self._config = dict(loader_paths=[], + loaders=[], + north=False, + south=False) for path in paths: if os.path.exists(path): @@ -231,11 +224,14 @@ def _merge_configurations(self, path: str, other: object): self._config["counts"] = other["counts"].copy() else: for dataset, count in other["counts"].items(): - logging.info("Merging {} samples from {}".format(count, dataset)) + logging.info("Merging {} samples from {}".format( + count, dataset)) self._config["counts"][dataset] += count - general_attrs = ["channels", "dtype", "n_forecast_days", - "num_channels", "output_batch_size", "shape"] + general_attrs 
= [ + "channels", "dtype", "n_forecast_days", "num_channels", + "output_batch_size", "shape" + ] for attr in general_attrs: if attr not in self._config: @@ -259,8 +255,7 @@ def get_data_loader(self): ) return self._config["loader"][0] - def check_dataset(self, - split: str = "train"): + def check_dataset(self, split: str = "train"): """ :param split: @@ -281,8 +276,10 @@ def counts(self): def get_args(): ap = argparse.ArgumentParser() ap.add_argument("dataset") - ap.add_argument("-s", "--split", - choices=["train", "val", "test"], default="train") + ap.add_argument("-s", + "--split", + choices=["train", "val", "test"], + default="train") ap.add_argument("-v", "--verbose", action="store_true", default=False) args = ap.parse_args() return args diff --git a/icenet/data/datasets/utils.py b/icenet/data/datasets/utils.py index 99e7a508..6beee2ed 100644 --- a/icenet/data/datasets/utils.py +++ b/icenet/data/datasets/utils.py @@ -20,12 +20,11 @@ def get_decoder(shape: object, :param dtype: :return: """ - xf = tf.io.FixedLenFeature( - [*shape, channels], getattr(tf, dtype)) - yf = tf.io.FixedLenFeature( - [*shape, forecasts, num_vars], getattr(tf, dtype)) - sf = tf.io.FixedLenFeature( - [*shape, forecasts, num_vars], getattr(tf, dtype)) + xf = tf.io.FixedLenFeature([*shape, channels], getattr(tf, dtype)) + yf = tf.io.FixedLenFeature([*shape, forecasts, num_vars], + getattr(tf, dtype)) + sf = tf.io.FixedLenFeature([*shape, forecasts, num_vars], + getattr(tf, dtype)) @tf.function def decode_item(proto): @@ -108,8 +107,9 @@ def get_split_datasets(self, ratio: object = None): if test_idx > 0: self.test_fns = self.test_fns[:test_idx] - logging.info("Reduced: {} train, {} val and {} test filenames".format( - len(self.train_fns), len(self.val_fns), len(self.test_fns))) + logging.info( + "Reduced: {} train, {} val and {} test filenames".format( + len(self.train_fns), len(self.val_fns), len(self.test_fns))) train_ds, val_ds, test_ds = \ tf.data.TFRecordDataset(self.train_fns, @@ -151,8 +151,7 @@ def get_split_datasets(self, ratio: object = None): val_ds.prefetch(tf.data.AUTOTUNE), \ test_ds.prefetch(tf.data.AUTOTUNE) - def check_dataset(self, - split: str = "train"): + def check_dataset(self, split: str = "train"): logging.debug("Checking dataset {}".format(split)) decoder = get_decoder(self.shape, @@ -171,12 +170,9 @@ def check_dataset(self, y = y.numpy() sw = sw.numpy() - logging.debug("Got record {}:{} with x {} y {} sw {}". - format(df, - i, - x.shape, - y.shape, - sw.shape)) + logging.debug( + "Got record {}:{} with x {} y {} sw {}".format( + df, i, x.shape, y.shape, sw.shape)) input_nans = np.isnan(x).sum() output_nans = np.isnan(y[sw > 0.]).sum() @@ -187,19 +183,19 @@ def check_dataset(self, sw_min = np.min(x) sw_max = np.max(x) - logging.debug("Bounds: Input {}:{} Output {}:{} SW {}:{}". - format(input_min, input_max, - output_min, output_max, - sw_min, sw_max)) + logging.debug( + "Bounds: Input {}:{} Output {}:{} SW {}:{}".format( + input_min, input_max, output_min, output_max, + sw_min, sw_max)) if input_nans > 0: - logging.warning("Input NaNs detected in {}:{}". - format(df, i)) + logging.warning("Input NaNs detected in {}:{}".format( + df, i)) if output_nans > 0: - logging.warning("Output NaNs detected in {}:{}, not " - "accounted for by sample weighting". 
- format(df, i)) + logging.warning( + "Output NaNs detected in {}:{}, not " + "accounted for by sample weighting".format(df, i)) except tf.errors.DataLossError as e: logging.warning("{}: data loss error {}".format(df, e.message)) except tf.errors.OpError as e: diff --git a/icenet/data/interfaces/cds.py b/icenet/data/interfaces/cds.py index b934b69e..ebcaaf97 100644 --- a/icenet/data/interfaces/cds.py +++ b/icenet/data/interfaces/cds.py @@ -13,7 +13,6 @@ from icenet.data.cli import download_args from icenet.data.interfaces.downloader import ClimateDownloader - """ Module to download hourly ERA5 reanalysis latitude-longitude maps, compute daily averages, regrid them to the same EASE grid as the OSI-SAF sea @@ -67,15 +66,11 @@ def __init__(self, logging.info("Upping connection limit for max_threads > 10") adapter = requests.adapters.HTTPAdapter( pool_connections=self._max_threads, - pool_maxsize=self._max_threads - ) + pool_maxsize=self._max_threads) self.client.session.mount("https://", adapter) - def _single_toolbox_download(self, - var: object, - level: object, - req_dates: object, - download_path: object): + def _single_toolbox_download(self, var: object, level: object, + req_dates: object, download_path: object): """Implements a single download from CDS Toolbox API :param var: @@ -88,9 +83,9 @@ def _single_toolbox_download(self, var_prefix = var[0:-(len(str(level)))] if level else var params_dict = { - "realm": "c3s", - "project": "app-c3s-daily-era5-statistics", - "version": "master", + "realm": "c3s", + "project": "app-c3s-daily-era5-statistics", + "version": "master", "workflow_name": "application", "kwargs": { "dataset": "reanalysis-era5-single-levels", @@ -104,14 +99,14 @@ def _single_toolbox_download(self, "time_zone": "UTC+00:00", "grid": "0.25/0.25", "area": { - "lat": [min([self.hemisphere_loc[0], - self.hemisphere_loc[2]]), - max([self.hemisphere_loc[0], - self.hemisphere_loc[2]])], - "lon": [min([self.hemisphere_loc[1], - self.hemisphere_loc[3]]), - max([self.hemisphere_loc[1], - self.hemisphere_loc[3]])], + "lat": [ + min([self.hemisphere_loc[0], self.hemisphere_loc[2]]), + max([self.hemisphere_loc[0], self.hemisphere_loc[2]]) + ], + "lon": [ + min([self.hemisphere_loc[1], self.hemisphere_loc[3]]), + max([self.hemisphere_loc[1], self.hemisphere_loc[3]]) + ], }, }, } @@ -122,9 +117,8 @@ def _single_toolbox_download(self, params_dict["kwargs"]["pressure_level"] = level logging.debug("params_dict: {}".format(pformat(params_dict))) - result = self.client.service( - "tool.toolbox.orchestrator.workflow", - params=params_dict) + result = self.client.service("tool.toolbox.orchestrator.workflow", + params=params_dict) try: logging.info("Downloading data for {}...".format(var)) @@ -146,10 +140,7 @@ def _single_toolbox_download(self, "problem".format(download_path)) raise RuntimeError(e) - def _single_api_download(self, - var: str, - level: object, - req_dates: object, + def _single_api_download(self, var: str, level: object, req_dates: object, download_path: object): """Implements a single download from CDS API @@ -163,15 +154,22 @@ def _single_api_download(self, var_prefix = var[0:-(len(str(level)))] if level else var retrieve_dict = { - "product_type": "reanalysis", - "variable": self._cdi_map[var_prefix], - "year": req_dates[0].year, - "month": list(set(["{:02d}".format(rd.month) - for rd in sorted(req_dates)])), + "product_type": + "reanalysis", + "variable": + self._cdi_map[var_prefix], + "year": + req_dates[0].year, + "month": + list( + set(["{:02d}".format(rd.month) for rd in 
sorted(req_dates)]) + ), "day": ["{:02d}".format(d) for d in range(1, 32)], "time": ["{:02d}:00".format(h) for h in range(0, 24)], - "format": "netcdf", - "area": self.hemisphere_loc, + "format": + "netcdf", + "area": + self.hemisphere_loc, } dataset = "reanalysis-era5-single-levels" @@ -190,9 +188,7 @@ def _single_api_download(self, "problem".format(download_path)) raise RuntimeError(e) - def postprocess(self, - var: str, - download_path: object): + def postprocess(self, var: str, download_path: object): """Processing of CDS downloaded files If we've not used the toolbox to download the files, we have a lot of @@ -221,9 +217,12 @@ def postprocess(self, # FIXME: This will cause issues for already processed latlon data if len(doy_counts[doy_counts < 24]) > 0: strip_dates_before = min([ - dt.datetime.strptime("{}-{}".format( - d, pd.to_datetime(da.time.values[0]).year), "%j-%Y") - for d in doy_counts[doy_counts < 24].dayofyear.values]) + dt.datetime.strptime( + "{}-{}".format(d, + pd.to_datetime(da.time.values[0]).year), + "%j-%Y") + for d in doy_counts[doy_counts < 24].dayofyear.values + ]) da = da.where(da.time < pd.Timestamp(strip_dates_before), drop=True) if 'expver' in da.coords: @@ -236,9 +235,7 @@ def postprocess(self, da = da.sortby("time").resample(time='1D').mean() da.to_netcdf(download_path) - def additional_regrid_processing(self, - datafile: str, - cube_ease: object): + def additional_regrid_processing(self, datafile: str, cube_ease: object): """ :param datafile: @@ -252,7 +249,8 @@ def additional_regrid_processing(self, logging.debug("ERA5 regrid postprocess: {}".format(var_name)) cube_ease.data[cube_ease.data.mask] = 0. cube_ease.data = cube_ease.data.data - cube_ease.data = np.where(np.isnan(cube_ease.data), 0., cube_ease.data) + cube_ease.data = np.where(np.isnan(cube_ease.data), 0., + cube_ease.data) elif var_name in ['zg500', 'zg250']: # Convert from geopotential to geopotential height logging.debug("ERA5 additional regrid: {}".format(var_name)) @@ -261,17 +259,23 @@ def additional_regrid_processing(self, def main(): args = download_args(choices=["cdsapi", "toolbox"], - workers=True, extra_args=( - (("-n", "--do-not-download"), - dict(dest="download", action="store_false", default=True)), - (("-p", "--do-not-postprocess"), - dict(dest="postprocess", action="store_false", default=True)))) + workers=True, + extra_args=((("-n", "--do-not-download"), + dict(dest="download", + action="store_false", + default=True)), + (("-p", "--do-not-postprocess"), + dict(dest="postprocess", + action="store_false", + default=True)))) logging.info("ERA5 Data Downloading") era5 = ERA5Downloader( var_names=args.vars, - dates=[pd.to_datetime(date).date() for date in - pd.date_range(args.start_date, args.end_date, freq="D")], + dates=[ + pd.to_datetime(date).date() + for date in pd.date_range(args.start_date, args.end_date, freq="D") + ], delete_tempfiles=args.delete, download=args.download, levels=args.levels, @@ -279,7 +283,6 @@ def main(): postprocess=args.postprocess, north=args.hemisphere == "north", south=args.hemisphere == "south", - use_toolbox=args.choice == "toolbox" - ) + use_toolbox=args.choice == "toolbox") era5.download() era5.regrid() diff --git a/icenet/data/interfaces/cmems.py b/icenet/data/interfaces/cmems.py index 977315bf..6b8a8874 100644 --- a/icenet/data/interfaces/cmems.py +++ b/icenet/data/interfaces/cmems.py @@ -10,7 +10,6 @@ from icenet.data.cli import download_args from icenet.data.interfaces.downloader import ClimateDownloader from icenet.utils import run_command - """ 
DATASET: global-reanalysis-phy-001-031-grepv2-daily FTP ENDPOINT: ftp://my.cmems-du.eu/Core/GLOBAL_REANALYSIS_PHY_001_031/global-reanalysis-phy-001-031-grepv2-daily/1993/01/ @@ -28,18 +27,19 @@ class ORAS5Downloader(ClimateDownloader): """ ENDPOINTS = { # TODO: See #49 - not yet used - "cas": "https://cmems-cas.cls.fr/cas/login", - "dap": "https://my.cmems-du.eu/thredds/dodsC/{dataset}", + "cas": "https://cmems-cas.cls.fr/cas/login", + "dap": "https://my.cmems-du.eu/thredds/dodsC/{dataset}", "motu": "https://my.cmems-du.eu/motu-web/Motu", } VAR_MAP = { - "thetao": "thetao_oras", # sea_water_potential_temperature - "so": "so_oras", # sea_water_salinity - "uo": "uo_oras", # eastward_sea_water_velocity - "vo": "vo_oras", # northward_sea_water_velocity - "zos": "zos_oras", # sea_surface_height_above_geoid - "mlotst": "mlotst_oras", # ocean_mixed_layer_thickness_defined_by_sigma_theta + "thetao": "thetao_oras", # sea_water_potential_temperature + "so": "so_oras", # sea_water_salinity + "uo": "uo_oras", # eastward_sea_water_velocity + "vo": "vo_oras", # northward_sea_water_velocity + "zos": "zos_oras", # sea_surface_height_above_geoid + "mlotst": + "mlotst_oras", # ocean_mixed_layer_thickness_defined_by_sigma_theta } def __init__(self, @@ -73,9 +73,7 @@ def __init__(self, self.download_method = self._single_motu_download - def postprocess(self, - var: str, - download_path: object): + def postprocess(self, var: str, download_path: object): """ :param var: @@ -88,10 +86,7 @@ def postprocess(self, da = da.mean("depth").compute() da.to_netcdf(download_path) - def _single_motu_download(self, - var: str, - level: object, - req_dates: int, + def _single_motu_download(self, var: str, level: object, req_dates: int, download_path: object): """Implements a single download from ... server :param var: @@ -142,8 +137,9 @@ def _single_motu_download(self, if ret.returncode != 0 or not os.path.exists(download_path): attempts += 1 if attempts > self._max_failures: - logging.error("Couldn't download {} between {} and {}". - format(var, req_dates[0], req_dates[-1])) + logging.error( + "Couldn't download {} between {} and {}".format( + var, req_dates[0], req_dates[-1])) break time.sleep(30) else: @@ -151,12 +147,11 @@ def _single_motu_download(self, if success: dur = time.time() - tic - logging.debug("Done in {}m:{:.0f}s. ".format(np.floor(dur / 60), - dur % 60)) + logging.debug("Done in {}m:{:.0f}s. 
".format( + np.floor(dur / 60), dur % 60)) return success - def additional_regrid_processing(self, - datafile: object, + def additional_regrid_processing(self, datafile: object, cube_ease: object) -> object: """ @@ -169,18 +164,24 @@ def additional_regrid_processing(self, def main(): - args = download_args(workers=True, extra_args=( - (("-n", "--do-not-download"), - dict(dest="download", action="store_false", default=True)), - (("-p", "--do-not-postprocess"), - dict(dest="postprocess", action="store_false", default=True)))) + args = download_args(workers=True, + extra_args=((("-n", "--do-not-download"), + dict(dest="download", + action="store_false", + default=True)), + (("-p", "--do-not-postprocess"), + dict(dest="postprocess", + action="store_false", + default=True)))) logging.info("ORAS5 Data Downloading") oras5 = ORAS5Downloader( var_names=args.vars, # TODO: currently hardcoded - dates=[pd.to_datetime(date).date() for date in - pd.date_range(args.start_date, args.end_date, freq="D")], + dates=[ + pd.to_datetime(date).date() + for date in pd.date_range(args.start_date, args.end_date, freq="D") + ], delete_tempfiles=args.delete, download=args.delete, levels=[None for _ in args.vars], diff --git a/icenet/data/interfaces/downloader.py b/icenet/data/interfaces/downloader.py index c9de444b..aa0d53f7 100644 --- a/icenet/data/interfaces/downloader.py +++ b/icenet/data/interfaces/downloader.py @@ -26,7 +26,6 @@ import numpy as np import pandas as pd import xarray as xr - """ """ @@ -64,29 +63,27 @@ def filter_dates_on_data(latlon_path: str, # meaning we can naively open and interrogate the dates if check_latlon and os.path.exists(latlon_path): try: - latlon_dates = xr.open_dataset( - latlon_path, - drop_variables=drop_vars).time.values + latlon_dates = xr.open_dataset(latlon_path, + drop_variables=drop_vars).time.values logging.debug("{} latlon dates already available in {}".format( - len(latlon_dates), latlon_path - )) + len(latlon_dates), latlon_path)) except ValueError: logging.warning("Latlon {} dates not readable, ignoring file") if check_regridded and os.path.exists(regridded_name): - regridded_dates = xr.open_dataset( - regridded_name, - drop_variables=drop_vars).time.values + regridded_dates = xr.open_dataset(regridded_name, + drop_variables=drop_vars).time.values logging.debug("{} regridded dates already available in {}".format( - len(regridded_dates), regridded_name - )) + len(regridded_dates), regridded_name)) exclude_dates = list(set(latlon_dates).union(set(regridded_dates))) logging.debug("Excluding {} dates already existing from {} dates " "requested.".format(len(exclude_dates), len(req_dates))) - return sorted(list(pd.to_datetime(req_dates). - difference(pd.to_datetime(exclude_dates)))) + return sorted( + list( + pd.to_datetime(req_dates).difference( + pd.to_datetime(exclude_dates)))) def merge_files(new_datafile: str, @@ -108,17 +105,14 @@ def merge_files(new_datafile: str, d1 = xr.open_dataarray(moved_new_datafile, drop_variables=drop_variables) - logging.info("Concatenating with previous data {}".format( - other_datafile - )) - d2 = xr.open_dataarray(other_datafile, - drop_variables=drop_variables) + logging.info( + "Concatenating with previous data {}".format(other_datafile)) + d2 = xr.open_dataarray(other_datafile, drop_variables=drop_variables) new_ds = xr.concat([d1, d2], dim="time").\ sortby("time").\ drop_duplicates("time", keep="first") - logging.info("Saving merged data to {}... ". - format(new_datafile)) + logging.info("Saving merged data to {}... 
".format(new_datafile)) new_ds.to_netcdf(new_datafile) os.unlink(other_datafile) os.unlink(moved_new_datafile) @@ -139,7 +133,8 @@ class ClimateDownloader(Downloader): :param var_names: """ - def __init__(self, *args, + def __init__(self, + *args, dates: object = (), delete_tempfiles: bool = True, download: bool = True, @@ -222,10 +217,8 @@ def download(self): futures = [] for var_prefix, level, req_date in requests: - future = executor.submit(self._single_download, - var_prefix, - level, - req_date) + future = executor.submit(self._single_download, var_prefix, + level, req_date) futures.append(future) for future in concurrent.futures.as_completed(futures): @@ -234,12 +227,10 @@ def download(self): except Exception as e: logging.exception("Thread failure: {}".format(e)) - logging.info("{} daily files downloaded". - format(len(self._files_downloaded))) + logging.info("{} daily files downloaded".format( + len(self._files_downloaded))) - def _single_download(self, - var_prefix: str, - level: object, + def _single_download(self, var_prefix: str, level: object, req_dates: object): """Implements a single download based on configured download_method @@ -252,8 +243,9 @@ def _single_download(self, :param req_dates: the request date """ - logging.info("Processing single download for {} @ {} with {} dates". - format(var_prefix, level, len(req_dates))) + logging.info( + "Processing single download for {} @ {} with {} dates".format( + var_prefix, level, len(req_dates))) var = var_prefix if not level else \ "{}{}".format(var_prefix, level) var_folder = self.get_data_var_folder(var) @@ -269,23 +261,22 @@ def _single_download(self, if len(req_dates): if self._download: with tempfile.TemporaryDirectory() as tmpdir: - tmp_latlon_path = os.path.join(tmpdir, os.path.basename("{}.download".format(latlon_path))) + tmp_latlon_path = os.path.join( + tmpdir, + os.path.basename("{}.download".format(latlon_path))) - self.download_method(var, - level, - req_dates, - tmp_latlon_path) + self.download_method(var, level, req_dates, tmp_latlon_path) if os.path.exists(latlon_path): (ll_path, ll_file) = os.path.split(latlon_path) rename_latlon_path = os.path.join( - ll_path, "{}_old{}".format( - *os.path.splitext(ll_file))) + ll_path, + "{}_old{}".format(*os.path.splitext(ll_file))) os.rename(latlon_path, rename_latlon_path) - old_da = xr.open_dataarray(rename_latlon_path, - drop_variables=self._drop_vars) - tmp_da = xr.open_dataarray(tmp_latlon_path, - drop_variables=self._drop_vars) + old_da = xr.open_dataarray( + rename_latlon_path, drop_variables=self._drop_vars) + tmp_da = xr.open_dataarray( + tmp_latlon_path, drop_variables=self._drop_vars) logging.debug("Input (old): \n{}".format(old_da)) logging.debug("Input (dl): \n{}".format(tmp_da)) @@ -302,8 +293,8 @@ def _single_download(self, logging.info("Downloaded to {}".format(latlon_path)) else: - logging.info("Skipping actual download to {}". - format(latlon_path)) + logging.info( + "Skipping actual download to {}".format(latlon_path)) else: logging.info("No requested dates remain, likely already present") @@ -314,12 +305,10 @@ def _single_download(self, self._files_downloaded.append(latlon_path) def postprocess(self, var, download_path): - logging.debug("No postprocessing in place for {}: {}". 
- format(var, download_path)) + logging.debug("No postprocessing in place for {}: {}".format( + var, download_path)) - def save_temporal_files(self, var, da, - date_format=None, - freq=None): + def save_temporal_files(self, var, da, date_format=None, freq=None): """ :param var: @@ -378,9 +367,7 @@ def sic_ease_cube(self): 'projection_y_coordinate').convert_units('meters') return self._sic_ease_cubes[self._hemisphere] - def regrid(self, - files: object = None, - rotate_wind: bool = True): + def regrid(self, files: object = None, rotate_wind: bool = True): """ :param files: @@ -405,8 +392,9 @@ def regrid(self, fut_results = future.result() for res in fut_results: - logging.debug("Future result -> regrid_results: {}". - format(res)) + logging.debug( + "Future result -> regrid_results: {}".format( + res)) regrid_results.append(res) except Exception as e: logging.exception("Thread failure: {}".format(e)) @@ -420,8 +408,7 @@ def regrid(self, for new_datafile, moved_datafile in regrid_results: merge_files(new_datafile, moved_datafile, self._drop_vars) - def _batch_regrid(self, - files: object): + def _batch_regrid(self, files: object): """ :param files: @@ -431,8 +418,8 @@ def _batch_regrid(self, for datafile in files: (datafile_path, datafile_name) = os.path.split(datafile) - new_filename = re.sub(r'^{}'.format( - self.pregrid_prefix), '', datafile_name) + new_filename = re.sub(r'^{}'.format(self.pregrid_prefix), '', + datafile_name) new_datafile = os.path.join(datafile_path, new_filename) moved_datafile = None @@ -442,8 +429,8 @@ def _batch_regrid(self, moved_datafile = os.path.join(datafile_path, moved_filename) os.rename(new_datafile, moved_datafile) - logging.info("{} already existed, moved to {}". - format(new_filename, moved_filename)) + logging.info("{} already existed, moved to {}".format( + new_filename, moved_filename)) logging.debug("Regridding {}".format(datafile)) @@ -451,15 +438,15 @@ def _batch_regrid(self, cube = iris.load_cube(datafile) cube = self.convert_cube(cube) - cube_ease = cube.regrid( - self.sic_ease_cube, iris.analysis.Linear()) + cube_ease = cube.regrid(self.sic_ease_cube, + iris.analysis.Linear()) except iris.exceptions.CoordinateNotFoundError: - logging.warning("{} has no coordinates...". - format(datafile_name)) + logging.warning( + "{} has no coordinates...".format(datafile_name)) if self.delete: - logging.debug("Deleting failed file {}...". 
- format(datafile_name)) + logging.debug( + "Deleting failed file {}...".format(datafile_name)) os.unlink(datafile) continue @@ -486,9 +473,7 @@ def convert_cube(self, cube: object): return cube @abstractmethod - def additional_regrid_processing(self, - datafile: str, - cube_ease: object): + def additional_regrid_processing(self, datafile: str, cube_ease: object): """ :param datafile: @@ -511,8 +496,8 @@ def rotate_wind_data(self, angles = gridcell_angles_from_dim_coords(self.sic_ease_cube) invert_gridcell_angles(angles) - logging.info("Rotating wind data in {}".format( - " ".join([self.get_data_var_folder(v) for v in apply_to]))) + logging.info("Rotating wind data in {}".format(" ".join( + [self.get_data_var_folder(v) for v in apply_to]))) wind_files = {} @@ -526,11 +511,12 @@ def rotate_wind_data(self, wind_files[var] = sorted([ re.sub(r'{}'.format(self.pregrid_prefix), '', df) for df in latlon_files - if os.path.dirname(df).split(os.sep) - [self._var_name_idx] == var], - key=lambda x: int(re.search(r'^(?:\w+_)?(\d+).nc', - os.path.basename(x)).group(1)) - ) + if os.path.dirname(df).split(os.sep)[self._var_name_idx] == var + ], + key=lambda x: int( + re.search(r'^(?:\w+_)?(\d+).nc', + os.path.basename(x)).group(1) + )) logging.info("{} files for {}".format(len(wind_files[var]), var)) # NOTE: we're relying on apply_to having equal datasets @@ -578,9 +564,9 @@ def rotate_wind_data(self, for i, name in enumerate([wind_file_0, wind_file_1]): # NOTE: implementation with tempfile caused problems on NFS # mounted filesystem, so avoiding in place of letting iris do it - temp_name = os.path.join(os.path.split(name)[0], - "temp.{}".format( - os.path.basename(name))) + temp_name = os.path.join( + os.path.split(name)[0], + "temp.{}".format(os.path.basename(name))) logging.debug("Writing {}".format(temp_name)) iris.save(wind_cubes_r[apply_to[i]], temp_name) @@ -604,12 +590,10 @@ def get_req_filenames(self, latlon_path = os.path.join( var_folder, "{}{}.nc".format(self.pregrid_prefix, filename_date)) - regridded_name = os.path.join( - var_folder, "{}.nc".format(filename_date)) + regridded_name = os.path.join(var_folder, "{}.nc".format(filename_date)) logging.debug("Got {} filenames: {} and {}".format( - self._group_dates_by, latlon_path, regridded_name - )) + self._group_dates_by, latlon_path, regridded_name)) return latlon_path, regridded_name diff --git a/icenet/data/interfaces/esgf.py b/icenet/data/interfaces/esgf.py index 8cf6f051..dd9a3f7e 100644 --- a/icenet/data/interfaces/esgf.py +++ b/icenet/data/interfaces/esgf.py @@ -9,7 +9,6 @@ from icenet.data.interfaces.downloader import ClimateDownloader from icenet.data.cli import download_args from icenet.data.utils import esgf_search - """ """ @@ -68,8 +67,8 @@ class CMIP6Downloader(ClimateDownloader): 'hus': 'gn', 'psl': 'gn', 'rlds': 'gn', - 'rsus': 'gn', # Surface Upwelling Shortwave Radiation - 'rsds': 'gn', # Surface Downwelling Shortwave Radiation + 'rsus': 'gn', # Surface Upwelling Shortwave Radiation + 'rsds': 'gn', # Surface Downwelling Shortwave Radiation 'zg': 'gn', 'uas': 'gn', 'vas': 'gn', @@ -79,15 +78,17 @@ class CMIP6Downloader(ClimateDownloader): # Prioritise European first, US last, avoiding unnecessary queries # against nodes further afield (all traffic has a cost, and the coverage # of local nodes is more than enough) - ESGF_NODES = ("esgf.ceda.ac.uk", - "esg1.umr-cnrm.fr", - "vesg.ipsl.upmc.fr", - "esgf3.dkrz.de", - "esgf.bsc.es", - "esgf-data.csc.fi", - "noresg.nird.sigma2.no", - "esgf-data.ucar.edu", - "esgf-data2.diasjp.net",) 
+ ESGF_NODES = ( + "esgf.ceda.ac.uk", + "esg1.umr-cnrm.fr", + "vesg.ipsl.upmc.fr", + "esgf3.dkrz.de", + "esgf.bsc.es", + "esgf-data.csc.fi", + "noresg.nird.sigma2.no", + "esgf-data.ucar.edu", + "esgf-data2.diasjp.net", + ) def __init__(self, *args, @@ -117,9 +118,7 @@ def __init__(self, self._grid_map = grid_map if grid_map else CMIP6Downloader.GRID_MAP self._grid_map_override = grid_override - def _single_download(self, - var_prefix: str, - level: object, + def _single_download(self, var_prefix: str, level: object, req_dates: object): """Overridden CMIP implementation for downloading from DAP server @@ -133,14 +132,19 @@ def _single_download(self, """ query = { - 'source_id': self._source, - 'member_id': self._member, - 'frequency': self._frequency, - 'variable_id': var_prefix, - 'table_id': self._table_map[var_prefix], - 'grid_label': self._grid_map_override - if self._grid_map_override - else self._grid_map[var_prefix], + 'source_id': + self._source, + 'member_id': + self._member, + 'frequency': + self._frequency, + 'variable_id': + var_prefix, + 'table_id': + self._table_map[var_prefix], + 'grid_label': + self._grid_map_override + if self._grid_map_override else self._grid_map[var_prefix], } var = var_prefix if not level else "{}{}".format(var_prefix, level) @@ -166,34 +170,29 @@ def _single_download(self, results.extend(node_results) break - logging.info("Found {} {} results from ESGF search". - format(len(results), var_prefix)) + logging.info("Found {} {} results from ESGF search".format( + len(results), var_prefix)) try: # http://xarray.pydata.org/en/stable/user-guide/io.html?highlight=opendap#opendap # Avoid 500MB DAP request limit cmip6_da = xr.open_mfdataset(results, combine='by_coords', - chunks={'time': '499MB'} - )[var_prefix] + chunks={'time': '499MB'})[var_prefix] - cmip6_da = cmip6_da.sel(time=slice(req_dates[0], - req_dates[-1])) + cmip6_da = cmip6_da.sel(time=slice(req_dates[0], req_dates[-1])) # TODO: possibly other attributes, especially with ocean vars if level: cmip6_da = cmip6_da.sel(plev=int(level) * 100) - cmip6_da = cmip6_da.sel(lat=slice(self.hemisphere_loc[2], - self.hemisphere_loc[0])) + cmip6_da = cmip6_da.sel( + lat=slice(self.hemisphere_loc[2], self.hemisphere_loc[0])) self.save_temporal_files(var, cmip6_da) except OSError as e: - logging.exception("Error encountered: {}".format(e), - exc_info=False) + logging.exception("Error encountered: {}".format(e), exc_info=False) - def additional_regrid_processing(self, - datafile: str, - cube_ease: object): + def additional_regrid_processing(self, datafile: str, cube_ease: object): """ :param datafile: @@ -217,8 +216,8 @@ def additional_regrid_processing(self, cube_ease.data = cube_ease.data.data if cube_ease.data.dtype != np.float32: - logging.info("Regrid processing, data type not float: {}". 
- format(cube_ease.data.dtype)) + logging.info("Regrid processing, data type not float: {}".format( + cube_ease.data.dtype)) cube_ease.data = cube_ease.data.astype(np.float32) def convert_cube(self, cube: object) -> object: @@ -236,17 +235,16 @@ def convert_cube(self, cube: object) -> object: def main(): - args = download_args( - dates=True, - extra_args=[ - (["source"], dict(type=str)), - (["member"], dict(type=str)), - (("-xs", "--exclude-server"), - dict(default=[], nargs="*")), - (("-o", "--override"), dict(required=None, type=str)), - ], - workers=True - ) + args = download_args(dates=True, + extra_args=[ + (["source"], dict(type=str)), + (["member"], dict(type=str)), + (("-xs", "--exclude-server"), + dict(default=[], nargs="*")), + (("-o", "--override"), dict(required=None, + type=str)), + ], + workers=True) logging.info("CMIP6 Data Downloading") @@ -254,8 +252,10 @@ def main(): source=args.source, member=args.member, var_names=args.vars, - dates=[pd.to_datetime(date).date() for date in - pd.date_range(args.start_date, args.end_date, freq="D")], + dates=[ + pd.to_datetime(date).date() + for date in pd.date_range(args.start_date, args.end_date, freq="D") + ], delete_tempfiles=args.delete, grid_override=args.override, levels=args.levels, diff --git a/icenet/data/interfaces/mars.py b/icenet/data/interfaces/mars.py index a3fcb5c4..74a353b1 100644 --- a/icenet/data/interfaces/mars.py +++ b/icenet/data/interfaces/mars.py @@ -12,7 +12,6 @@ from icenet.data.cli import download_args from icenet.data.interfaces.downloader import ClimateDownloader from icenet.data.interfaces.utils import batch_requested_dates - """ """ @@ -31,16 +30,16 @@ class HRESDownloader(ClimateDownloader): # https://confluence.ecmwf.int/pages/viewpage.action?pageId=85402030 # https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#ERA5:datadocumentation-Dateandtimespecification HRES_PARAMS = { - "siconca": (31, "siconc"), # sea_ice_area_fraction - "tos": (34, "sst"), # sea surface temperature (actually - # sst?) - "zg": (129, "z"), # geopotential - "ta": (130, "t"), # air_temperature (t) - "hus": (133, "q"), # specific_humidity - "psl": (134, "sp"), # surface_pressure - "uas": (165, "u10"), # 10m_u_component_of_wind - "vas": (166, "v10"), # 10m_v_component_of_wind - "tas": (167, "t2m"), # 2m_temperature (t2m) + "siconca": (31, "siconc"), # sea_ice_area_fraction + "tos": (34, "sst"), # sea surface temperature (actually + # sst?) + "zg": (129, "z"), # geopotential + "ta": (130, "t"), # air_temperature (t) + "hus": (133, "q"), # specific_humidity + "psl": (134, "sp"), # surface_pressure + "uas": (165, "u10"), # 10m_u_component_of_wind + "vas": (166, "v10"), # 10m_v_component_of_wind + "tas": (167, "t2m"), # 2m_temperature (t2m) # https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#ERA5:datadocumentation-Meanrates/fluxesandaccumulations # https://apps.ecmwf.int/codes/grib/param-db/?id=175 # https://confluence.ecmwf.int/pages/viewpage.action?pageId=197702790 @@ -50,8 +49,8 @@ class HRESDownloader(ClimateDownloader): # Table 3 for surface and single levels), except they are expressed as # temporal means, over the same processing periods, and so have units # of "per second". 
- "rlds": (175, "strd"), - "rsds": (169, "ssrd"), + "rlds": (175, "strd"), + "rsds": (169, "ssrd"), # plev 129.128 / 130.128 / 133.128 # sfc 31.128 / 34.128 / 134.128 / @@ -83,19 +82,12 @@ class HRESDownloader(ClimateDownloader): format=netcdf """ - def __init__(self, - *args, - identifier: str = "mars.hres", - **kwargs): - super().__init__(*args, - identifier=identifier, - **kwargs) + def __init__(self, *args, identifier: str = "mars.hres", **kwargs): + super().__init__(*args, identifier=identifier, **kwargs) self._server = ecmwfapi.ECMWFService("mars") - def _single_download(self, - var_names: object, - pressures: object, + def _single_download(self, var_names: object, pressures: object, req_dates: object): """ @@ -124,8 +116,7 @@ def _single_download(self, req_batch = req_batch[:-1] request_target = os.path.join( - self.base_path, - self.hemisphere_str[0], + self.base_path, self.hemisphere_str[0], "{}.{}.nc".format(levtype, request_month)) os.makedirs(os.path.dirname(request_target), exist_ok=True) @@ -134,12 +125,12 @@ def _single_download(self, area="/".join([str(s) for s in self.hemisphere_loc]), date="/".join([el.strftime("%Y%m%d") for el in req_batch]), levtype=levtype, - levlist="levelist={},\n ".format(pressures) if pressures else "", - params="/".join( - ["{}.{}".format( - self.params[v][0], - self.param_table) - for v in var_names]), + levlist="levelist={},\n ".format(pressures) + if pressures else "", + params="/".join([ + "{}.{}".format(self.params[v][0], self.param_table) + for v in var_names + ]), target=request_target, # We are only allowed date prior to -24 hours ago, dynamically # retrieve if date is today @@ -165,13 +156,13 @@ def _single_download(self, ds = xr.open_mfdataset(downloads) ds = ds.resample(time='1D', keep_attrs=True).mean(keep_attrs=True) - for var_name, pressure in product(var_names, pressures.split('/') - if pressures else [None]): + for var_name, pressure in product( + var_names, + pressures.split('/') if pressures else [None]): var = var_name if not pressure else \ "{}{}".format(var_name, pressure) - da = getattr(ds, - self.params[var_name][1]) + da = getattr(ds, self.params[var_name][1]) if pressure: da = da.sel(level=int(pressure)) @@ -193,12 +184,17 @@ def download(self): logging.info("Building request(s), downloading and daily averaging " "from {} API".format(self.identifier.upper())) - sfc_vars = [var for idx, var in enumerate(self.var_names) - if not self.levels[idx]] - level_vars = [var for idx, var in enumerate(self.var_names) - if self.levels[idx]] - levels = "/".join([str(s) for s in sorted(set( - [p for ps in self.levels if ps for p in ps]))]) + sfc_vars = [ + var for idx, var in enumerate(self.var_names) + if not self.levels[idx] + ] + level_vars = [ + var for idx, var in enumerate(self.var_names) if self.levels[idx] + ] + levels = "/".join([ + str(s) + for s in sorted(set([p for ps in self.levels if ps for p in ps])) + ]) # req_dates = self.filter_dates_on_data() @@ -213,12 +209,10 @@ def download(self): if len(level_vars) > 0: self._single_download(level_vars, levels, req_batch) - logging.info("{} daily files downloaded". 
- format(len(self._files_downloaded))) + logging.info("{} daily files downloaded".format( + len(self._files_downloaded))) - def additional_regrid_processing(self, - datafile: str, - cube_ease: object): + def additional_regrid_processing(self, datafile: str, cube_ease: object): """ :param datafile: @@ -289,9 +283,7 @@ class SEASDownloader(HRESDownloader): grid=0.25/0.25, area={area}""" - def _single_download(self, - var_names: object, - pressures: object, + def _single_download(self, var_names: object, pressures: object, req_dates: object): """ @@ -313,8 +305,7 @@ def _single_download(self, logging.info("Downloading daily file {}".format(request_day)) request_target = os.path.join( - self.base_path, - self.hemisphere_str[0], + self.base_path, self.hemisphere_str[0], "{}.{}.nc".format(levtype, request_day)) os.makedirs(os.path.dirname(request_target), exist_ok=True) @@ -322,12 +313,12 @@ def _single_download(self, area="/".join([str(s) for s in self.hemisphere_loc]), date=req_date.strftime("%Y-%m-%d"), levtype=levtype, - levlist="levelist={},\n ".format(pressures) if pressures else "", - params="/".join( - ["{}.{}".format( - self.params[v][0], - self.param_table) - for v in var_names]), + levlist="levelist={},\n ".format(pressures) + if pressures else "", + params="/".join([ + "{}.{}".format(self.params[v][0], self.param_table) + for v in var_names + ]), target=request_target, ) @@ -351,8 +342,9 @@ def _single_download(self, ds = xr.open_dataset(download_filename) ds = ds.mean("number") - for var_name, pressure in product(var_names, pressures.split('/') - if pressures else [None]): + for var_name, pressure in product( + var_names, + pressures.split('/') if pressures else [None]): var = var_name if not pressure else \ "{}{}".format(var_name, pressure) @@ -371,9 +363,7 @@ def _single_download(self, logging.info("Removing {}".format(downloaded_file)) os.unlink(downloaded_file) - def save_temporal_files(self, var, da, - date_format=None, - freq=None): + def save_temporal_files(self, var, da, date_format=None, freq=None): """ :param var: @@ -409,22 +399,21 @@ def main(identifier, extra_kwargs=None): instance = cls( identifier="mars.{}".format(identifier.lower()), var_names=args.vars, - dates=[pd.to_datetime(date).date() for date in - pd.date_range(args.start_date, args.end_date, freq="D")], + dates=[ + pd.to_datetime(date).date() + for date in pd.date_range(args.start_date, args.end_date, freq="D") + ], delete_tempfiles=args.delete, levels=args.levels, north=args.hemisphere == "north", south=args.hemisphere == "south", - **extra_kwargs - ) + **extra_kwargs) instance.download() instance.regrid() def seas_main(): - main("SEAS", dict( - group_dates_by="day", - )) + main("SEAS", dict(group_dates_by="day",)) def hres_main(): diff --git a/icenet/data/interfaces/utils.py b/icenet/data/interfaces/utils.py index 96090c25..050f9b11 100644 --- a/icenet/data/interfaces/utils.py +++ b/icenet/data/interfaces/utils.py @@ -10,8 +10,7 @@ from icenet.utils import setup_logging -def batch_requested_dates(dates: object, - attribute: str = "month") -> object: +def batch_requested_dates(dates: object, attribute: str = "month") -> object: """ TODO: should be using Pandas DatetimeIndexes / Periods for this, but the @@ -81,26 +80,26 @@ def reprocess_monthlies(source: str, logging.warning("Cannot derive year from {}".format(year)) continue - destination = os.path.join(output_base, - identifier, - hemisphere, - var_name, - str(year)) + destination = os.path.join(output_base, identifier, hemisphere, + var_name, str(year)) 
if not os.path.exists(destination): os.makedirs(destination, exist_ok=True) - logging.info("Processing {} from {} to {}". - format(var_name, year, destination)) + logging.info("Processing {} from {} to {}".format( + var_name, year, destination)) ds = xr.open_dataset(file) - var_names = [name for name in list(ds.data_vars.keys()) - if not name.startswith("lambert_")] + var_names = [ + name for name in list(ds.data_vars.keys()) + if not name.startswith("lambert_") + ] var_names = set(var_names) - logging.debug("Files have var names {} which will be renamed to {}". - format(", ".join(var_names), var_name)) + logging.debug( + "Files have var names {} which will be renamed to {}".format( + ", ".join(var_names), var_name)) ds = ds.rename({k: var_name for k in var_names}) da = getattr(ds, var_name) @@ -144,9 +143,9 @@ def add_time_dim(source: str, for path, filename in [os.path.split(el) for el in file_list]: if filename.startswith("{}_".format(var_name)): - raise RuntimeError("{} starts with var name, we only want " - "correctly named files to convert". - format(filename)) + raise RuntimeError( + "{} starts with var name, we only want " + "correctly named files to convert".format(filename)) year = str(path.split(os.sep)[-1]) if year not in files[var_name]: @@ -177,14 +176,15 @@ def add_time_dim(source: str, if "siconca" in year_files[0]: ds = ds.rename_vars({"siconca": "ice_conc"}) ds = ds.sortby("time") - ds['time'] = [pd.Timestamp(el) for el in - ds.indexes['time'].normalize()] + ds['time'] = [ + pd.Timestamp(el) for el in ds.indexes['time'].normalize() + ] for d in ds.time.values: dt = pd.to_datetime(d) date_str = dt.strftime("%Y_%m_%d") - fpath = os.path.join(os.path.split(year_files[0])[0], - "{}.nc".format(date_str)) + fpath = os.path.join( + os.path.split(year_files[0])[0], "{}.nc".format(date_str)) if not os.path.exists(fpath): dw = ds.sel(time=slice(dt, dt)) @@ -231,7 +231,9 @@ def add_time_dim_main(): raise RuntimeError("output is not used for this command: {}".format( args.output)) - add_time_dim(args.source, args.hemisphere, args.identifier, + add_time_dim(args.source, + args.hemisphere, + args.identifier, dry=args.dry, var_names=args.vars) @@ -241,7 +243,9 @@ def reprocess_main(): """ args = get_args() logging.info("Temporary solution for reprocessing monthly files") - reprocess_monthlies(args.source, args.hemisphere, args.identifier, + reprocess_monthlies(args.source, + args.hemisphere, + args.identifier, output_base=args.output, dry=args.dry, var_names=args.vars) diff --git a/icenet/data/loader.py b/icenet/data/loader.py index 1d128d68..b962c889 100644 --- a/icenet/data/loader.py +++ b/icenet/data/loader.py @@ -7,7 +7,6 @@ from icenet.data.loaders import IceNetDataLoaderFactory from icenet.data.cli import add_date_args, process_date_args from icenet.utils import setup_logging - """ """ @@ -25,38 +24,68 @@ def create_get_args(): ap.add_argument("name", type=str) ap.add_argument("hemisphere", choices=("north", "south")) - ap.add_argument("-c", "--cfg-only", help="Do not generate data, " - "only config", default=False, - action="store_true", dest="cfg") - ap.add_argument("-d", "--dry", + ap.add_argument("-c", + "--cfg-only", + help="Do not generate data, " + "only config", + default=False, + action="store_true", + dest="cfg") + ap.add_argument("-d", + "--dry", help="Don't output files, just generate data", - default=False, action="store_true") + default=False, + action="store_true") ap.add_argument("-dt", "--dask-timeouts", type=int, default=120) ap.add_argument("-dp", "--dask-port", 
type=int, default=8888) - ap.add_argument("-f", "--futures-per-worker", type=float, default=2., + ap.add_argument("-f", + "--futures-per-worker", + type=float, + default=2., dest="futures") - ap.add_argument("-fn", "--forecast-name", dest="forecast_name", - default=None, type=str) - ap.add_argument("-fd", "--forecast-days", dest="forecast_days", - default=93, type=int) - - ap.add_argument("-i", "--implementation", type=str, + ap.add_argument("-fn", + "--forecast-name", + dest="forecast_name", + default=None, + type=str) + ap.add_argument("-fd", + "--forecast-days", + dest="forecast_days", + default=93, + type=int) + + ap.add_argument("-i", + "--implementation", + type=str, choices=implementations, default=implementations[0]) ap.add_argument("-l", "--lag", type=int, default=2) - ap.add_argument("-ob", "--output-batch-size", dest="batch_size", type=int, + ap.add_argument("-ob", + "--output-batch-size", + dest="batch_size", + type=int, default=8) - ap.add_argument("-p", "--pickup", help="Skip existing tfrecords", - default=False, action="store_true") - ap.add_argument("-t", "--tmp-dir", help="Temporary directory", - default="/local/tmp", dest="tmp_dir", type=str) + ap.add_argument("-p", + "--pickup", + help="Skip existing tfrecords", + default=False, + action="store_true") + ap.add_argument("-t", + "--tmp-dir", + help="Temporary directory", + default="/local/tmp", + dest="tmp_dir", + type=str) ap.add_argument("-v", "--verbose", action="store_true", default=False) - ap.add_argument("-w", "--workers", help="Number of workers to use " - "generating sets", - type=int, default=2) + ap.add_argument("-w", + "--workers", + help="Number of workers to use " + "generating sets", + type=int, + default=2) add_date_args(ap) args = ap.parse_args() @@ -93,9 +122,7 @@ def create(): dl.generate() -def save_sample(output_folder: str, - date: object, - sample: tuple): +def save_sample(output_folder: str, date: object, sample: tuple): """ :param output_folder: @@ -108,14 +135,14 @@ def save_sample(output_folder: str, logging.warning("{} output already exists".format(output_folder)) os.makedirs(output_folder, exist_ok=output_folder) - for date, output, directory in ((date, net_input, "input"), - (date, net_output, "outputs"), + for date, output, directory in ((date, net_input, + "input"), (date, net_output, "outputs"), (date, sample_weights, "weights")): output_directory = os.path.join(output_folder, "loader", directory) os.makedirs(output_directory, exist_ok=True) loader_output_path = os.path.join(output_directory, date.strftime("%Y_%m_%d.npy")) - logging.info("Saving {} - generated {} {}". - format(date, directory, output.shape)) + logging.info("Saving {} - generated {} {}".format( + date, directory, output.shape)) np.save(loader_output_path, output) diff --git a/icenet/data/loaders/__init__.py b/icenet/data/loaders/__init__.py index 4e165e30..9fff6c8b 100644 --- a/icenet/data/loaders/__init__.py +++ b/icenet/data/loaders/__init__.py @@ -10,6 +10,7 @@ class IceNetDataLoaderFactory: """ """ + def __init__(self): self._loader_map = dict( dask=icenet.data.loaders.dask.DaskMultiWorkerLoader, @@ -28,11 +29,11 @@ def add_data_loader(self, loader_name: str, loader_impl: object): self._loader_map[loader_name] = loader_impl else: raise RuntimeError("{} is not descended from " - "IceNetBaseDataLoader". - format(loader_impl.__name__)) + "IceNetBaseDataLoader".format( + loader_impl.__name__)) else: - raise RuntimeError("Cannot add {} as already in loader map". 
- format(loader_name)) + raise RuntimeError( + "Cannot add {} as already in loader map".format(loader_name)) def create_data_loader(self, loader_name, *args, **kwargs): """ diff --git a/icenet/data/loaders/base.py b/icenet/data/loaders/base.py index e50aa78c..09887d38 100644 --- a/icenet/data/loaders/base.py +++ b/icenet/data/loaders/base.py @@ -10,7 +10,6 @@ from icenet.data.process import IceNetPreProcessor from icenet.data.producers import Generator - """ """ @@ -47,10 +46,7 @@ def __init__(self, pickup: bool = False, var_lag_override: object = None, **kwargs): - super().__init__(*args, - identifier=identifier, - path=path, - **kwargs) + super().__init__(*args, identifier=identifier, path=path, **kwargs) self._channels = dict() self._channel_files = dict() @@ -81,7 +77,8 @@ def __init__(self, self._missing_dates = [ dt.datetime.strptime(s, IceNetPreProcessor.DATE_FORMAT) - for s in self._config["missing_dates"]] + for s in self._config["missing_dates"] + ] def write_dataset_config_only(self): """ @@ -95,14 +92,15 @@ def write_dataset_config_only(self): # FIXME: cloned mechanism from generate() - do we need to treat these as # sets that might have missing data for fringe cases? for dataset in splits: - forecast_dates = sorted(list(set( - [dt.datetime.strptime(s, - IceNetPreProcessor.DATE_FORMAT).date() - for identity in - self._config["sources"].keys() - for s in - self._config["sources"][identity] - ["dates"][dataset]]))) + forecast_dates = sorted( + list( + set([ + dt.datetime.strptime( + s, IceNetPreProcessor.DATE_FORMAT).date() + for identity in self._config["sources"].keys() + for s in self._config["sources"][identity]["dates"] + [dataset] + ]))) logging.info("{} {} dates in total, NOT generating cache " "data.".format(len(forecast_dates), dataset)) @@ -111,9 +109,7 @@ def write_dataset_config_only(self): self._write_dataset_config(counts, network_dataset=False) @abstractmethod - def generate_sample(self, - date: object, - prediction: bool = False): + def generate_sample(self, date: object, prediction: bool = False): """ :param date: @@ -142,16 +138,12 @@ def get_sample_files(self) -> object: if var_name not in var_files: var_files[var_name] = var_file elif var_file != var_files[var_name]: - raise RuntimeError("Differing files? {} {} vs {}". - format(var_name, - var_file, - var_files[var_name])) + raise RuntimeError("Differing files? 
{} {} vs {}".format( + var_name, var_file, var_files[var_name])) return var_files - def _add_channel_files(self, - var_name: str, - filelist: object): + def _add_channel_files(self, var_name: str, filelist: object): """ :param var_name: @@ -173,51 +165,48 @@ def _construct_channels(self): """ # As of Python 3.7 dict guarantees the order of keys based on # original insertion order, which is great for this method - lag_vars = [(identity, var, data_format) - for data_format in ("abs", "anom") - for identity in - sorted(self._config["sources"].keys()) - for var in - sorted(self._config["sources"][identity][data_format])] + lag_vars = [ + (identity, var, data_format) + for data_format in ("abs", "anom") + for identity in sorted(self._config["sources"].keys()) + for var in sorted(self._config["sources"][identity][data_format]) + ] for identity, var_name, data_format in lag_vars: var_prefix = "{}_{}".format(var_name, data_format) - var_lag = (self._var_lag - if var_name not in self._var_lag_override + var_lag = (self._var_lag if var_name not in self._var_lag_override else self._var_lag_override[var_name]) self._channels[var_prefix] = int(var_lag) - self._add_channel_files( - var_prefix, - [el for el in - self._config["sources"][identity]["var_files"][var_name] - if var_prefix in os.path.split(el)[1]]) + self._add_channel_files(var_prefix, [ + el for el in self._config["sources"][identity]["var_files"] + [var_name] if var_prefix in os.path.split(el)[1] + ]) trend_names = [(identity, var, self._config["sources"][identity]["linear_trend_steps"]) - for identity in - sorted(self._config["sources"].keys()) - for var in - sorted( - self._config["sources"][identity]["linear_trends"])] + for identity in sorted(self._config["sources"].keys()) + for var in sorted(self._config["sources"][identity] + ["linear_trends"])] for identity, var_name, trend_steps in trend_names: var_prefix = "{}_linear_trend".format(var_name) self._channels[var_prefix] = len(trend_steps) self._trend_steps[var_prefix] = trend_steps - filelist = [el for el in - self._config["sources"][identity]["var_files"][var_name] - if "linear_trend" in os.path.split(el)[1]] + filelist = [ + el for el in self._config["sources"][identity]["var_files"] + [var_name] if "linear_trend" in os.path.split(el)[1] + ] self._add_channel_files(var_prefix, filelist) # Metadata input variables that don't span time - meta_names = [(identity, var) - for identity in - sorted(self._config["sources"].keys()) - for var in - sorted(self._config["sources"][identity]["meta"])] + meta_names = [ + (identity, var) + for identity in sorted(self._config["sources"].keys()) + for var in sorted(self._config["sources"][identity]["meta"]) + ] for identity, var_name in meta_names: self._meta_channels.append(var_name) @@ -226,8 +215,9 @@ def _construct_channels(self): var_name, self._config["sources"][identity]["var_files"][var_name]) - logging.debug("Channel quantities deduced:\n{}\n\nTotal channels: {}". - format(pformat(self._channels), self.num_channels)) + logging.debug( + "Channel quantities deduced:\n{}\n\nTotal channels: {}".format( + pformat(self._channels), self.num_channels)) def _get_var_file(self, var_name: str): """ @@ -240,8 +230,9 @@ def _get_var_file(self, var_name: str): files = self._channel_files[var_name] if len(self._channel_files[var_name]) > 1: - logging.warning("Multiple files found for {}, only returning {}". 
- format(filename, files[0])) + logging.warning( + "Multiple files found for {}, only returning {}".format( + filename, files[0])) elif not len(files): logging.warning("No files in channel list for {}".format(filename)) return None @@ -271,6 +262,7 @@ def _write_dataset_config(self, :param network_dataset: :return: """ + # TODO: move to utils for this and process def _serialize(x): if x is dt.date: @@ -278,40 +270,41 @@ def _serialize(x): return str(x) configuration = { - "identifier": self.identifier, - "implementation": self.__class__.__name__, + "identifier": self.identifier, + "implementation": self.__class__.__name__, # This is only for convenience ;) - "channels": [ + "channels": [ "{}_{}".format(channel, i) - for channel, s in - self._channels.items() - for i in range(1, s + 1)], - "counts": counts, - "dtype": self._dtype.__name__, - "loader_config": os.path.abspath(self._configuration_path), - "missing_dates": [date.strftime( - IceNetPreProcessor.DATE_FORMAT) for date in - self._missing_dates], - "n_forecast_days": self._n_forecast_days, - "north": self.north, - "num_channels": self.num_channels, + for channel, s in self._channels.items() + for i in range(1, s + 1) + ], + "counts": counts, + "dtype": self._dtype.__name__, + "loader_config": os.path.abspath(self._configuration_path), + "missing_dates": [ + date.strftime(IceNetPreProcessor.DATE_FORMAT) + for date in self._missing_dates + ], + "n_forecast_days": self._n_forecast_days, + "north": self.north, + "num_channels": self.num_channels, # FIXME: this naming is inconsistent, sort it out!!! ;) - "shape": list(self._shape), - "south": self.south, + "shape": list(self._shape), + "south": self.south, # For recreating this dataloader # "dataset_config_path = ".", - "dataset_path": self._path if network_dataset else False, + "dataset_path": self._path if network_dataset else False, "generate_workers": self.workers, "loss_weight_days": self._loss_weight_days, "output_batch_size": self._output_batch_size, - "var_lag": self._var_lag, + "var_lag": self._var_lag, "var_lag_override": self._var_lag_override, } - output_path = os.path.join(self._dataset_config_path, - "dataset_config.{}.json".format( - self.identifier)) + output_path = os.path.join( + self._dataset_config_path, + "dataset_config.{}.json".format(self.identifier)) logging.info("Writing configuration to {}".format(output_path)) @@ -320,9 +313,11 @@ def _serialize(x): @property def channel_names(self): - return ["{}_{}".format(nom, idx) if idx_qty > 1 else nom - for nom, idx_qty in self._channels.items() - for idx in range(1, idx_qty + 1)] + return [ + "{}_{}".format(nom, idx) if idx_qty > 1 else nom + for nom, idx_qty in self._channels.items() + for idx in range(1, idx_qty + 1) + ] @property def config(self): diff --git a/icenet/data/loaders/dask.py b/icenet/data/loaders/dask.py index 87fd84f1..b12394bd 100644 --- a/icenet/data/loaders/dask.py +++ b/icenet/data/loaders/dask.py @@ -17,8 +17,6 @@ from icenet.data.loaders.base import IceNetBaseDataLoader from icenet.data.loaders.utils import IceNetDataWarning, write_tfrecord from icenet.data.sic.mask import Masks - - """ Dask implementations for icenet data loading @@ -29,6 +27,7 @@ class DaskBaseDataLoader(IceNetBaseDataLoader): + def __init__(self, *args, dask_port: int = 8888, @@ -48,9 +47,9 @@ def generate(self): dashboard = "localhost:{}".format(self._dashboard_port) with dask.config.set({ - "temporary_directory": self._tmp_dir, - "distributed.comm.timeouts.connect": self._timeout, - "distributed.comm.timeouts.tcp": 
self._timeout, + "temporary_directory": self._tmp_dir, + "distributed.comm.timeouts.connect": self._timeout, + "distributed.comm.timeouts.tcp": self._timeout, }): cluster = LocalCluster( dashboard_address=dashboard, @@ -80,9 +79,8 @@ def client_generate(self, class DaskMultiSharingWorkerLoader(DaskBaseDataLoader): - def __init__(self, - *args, - **kwargs): + + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: https://github.com/icenet-ai/icenet/blob/83fdbf4b23ccf6ac221e77809b47d407b70b707f/icenet2/data/loader.py raise NotImplementedError("Not yet adapted from old implementation") @@ -99,9 +97,7 @@ def client_generate(self, """ pass - def generate_sample(self, - date: object, - prediction: bool = False): + def generate_sample(self, date: object, prediction: bool = False): """ :param date: @@ -111,15 +107,13 @@ def generate_sample(self, class DaskMultiWorkerLoader(DaskBaseDataLoader): - def __init__(self, - *args, - futures_per_worker: int = 2, - **kwargs): + + def __init__(self, *args, futures_per_worker: int = 2, **kwargs): super().__init__(*args, **kwargs) masks = Masks(north=self.north, south=self.south) - self._masks = da.array([ - masks.get_active_cell_mask(month) for month in range(1, 13)]) + self._masks = da.array( + [masks.get_active_cell_mask(month) for month in range(1, 13)]) self._futures = futures_per_worker @@ -160,17 +154,15 @@ def batch(batch_dates, num): batch_number = 0 futures = [] - forecast_dates = set([dt.datetime.strptime(s, - IceNetPreProcessor.DATE_FORMAT).date() - for identity in - self._config["sources"].keys() - for s in - self._config["sources"][identity] - ["dates"][dataset]]) + forecast_dates = set([ + dt.datetime.strptime(s, IceNetPreProcessor.DATE_FORMAT).date() + for identity in self._config["sources"].keys() + for s in self._config["sources"][identity]["dates"][dataset] + ]) if dates_override: - logging.info("{} available {} dates". - format(len(forecast_dates), dataset)) + logging.info("{} available {} dates".format( + len(forecast_dates), dataset)) forecast_dates = forecast_dates.intersection( dates_override[dataset]) forecast_dates = sorted(list(forecast_dates)) @@ -186,17 +178,10 @@ def batch(batch_dates, num): (pickup and not os.path.exists(tf_path.format(batch_number))): args = [ - self._channels, - self._dtype, - self._loss_weight_days, - self._meta_channels, - self._missing_dates, - self._n_forecast_days, - self.num_channels, - self._shape, - self._trend_steps, - masks, - False + self._channels, self._dtype, self._loss_weight_days, + self._meta_channels, self._missing_dates, + self._n_forecast_days, self.num_channels, self._shape, + self._trend_steps, masks, False ] fut = client.submit(generate_and_write, @@ -221,8 +206,8 @@ def batch(batch_dates, num): # tf_path.format(batch_number), args, dry=self._dry) else: counts[dataset] += len(dates) - logging.warning("Skipping {} on pickup run". - format(tf_path.format(batch_number))) + logging.warning("Skipping {} on pickup run".format( + tf_path.format(batch_number))) batch_number += 1 @@ -234,13 +219,11 @@ def batch(batch_dates, num): exec_times += gen_times if len(exec_times) > 0: - logging.info("Average sample generation time: {}". 
- format(np.average(exec_times))) + logging.info("Average sample generation time: {}".format( + np.average(exec_times))) self._write_dataset_config(counts) - def generate_sample(self, - date: object, - prediction: bool = False): + def generate_sample(self, date: object, prediction: bool = False): """ :param date: @@ -255,11 +238,10 @@ def generate_sample(self, ) var_files = self.get_sample_files() - var_ds = xr.open_mfdataset( - [v for k, v in var_files.items() - if k not in self._meta_channels - and not k.endswith("linear_trend")], - **ds_kwargs) + var_ds = xr.open_mfdataset([ + v for k, v in var_files.items() + if k not in self._meta_channels and not k.endswith("linear_trend") + ], **ds_kwargs) var_ds = var_ds.transpose("yc", "xc", "time") trend_files = \ @@ -268,31 +250,18 @@ def generate_sample(self, trend_ds = None if len(trend_files) > 0: - trend_ds = xr.open_mfdataset( - trend_files, - **ds_kwargs) + trend_ds = xr.open_mfdataset(trend_files, **ds_kwargs) trend_ds = trend_ds.transpose("yc", "xc", "time") args = [ - self._channels, - self._dtype, - self._loss_weight_days, - self._meta_channels, - self._missing_dates, - self._n_forecast_days, - self.num_channels, - self._shape, - self._trend_steps, - self._masks, + self._channels, self._dtype, self._loss_weight_days, + self._meta_channels, self._missing_dates, self._n_forecast_days, + self.num_channels, self._shape, self._trend_steps, self._masks, prediction ] - x, y, sw = generate_sample(date, - var_ds, - var_files, - trend_ds, - *args) + x, y, sw = generate_sample(date, var_ds, var_files, trend_ds, *args) return x.compute(), y.compute(), sw.compute() @@ -315,16 +284,8 @@ def generate_and_write(path: str, # TODO: refactor, this is very smelly - with new data throughput args # will always be the same - (channels, - dtype, - loss_weight_days, - meta_channels, - missing_dates, - n_forecast_days, - num_channels, - shape, - trend_steps, - masks, + (channels, dtype, loss_weight_days, meta_channels, missing_dates, + n_forecast_days, num_channels, shape, trend_steps, masks, prediction) = args ds_kwargs = dict( @@ -333,20 +294,19 @@ def generate_and_write(path: str, parallel=True, ) - var_ds = xr.open_mfdataset( - [v for k, v in var_files.items() - if k not in meta_channels and not k.endswith("linear_trend")], - **ds_kwargs) + var_ds = xr.open_mfdataset([ + v for k, v in var_files.items() + if k not in meta_channels and not k.endswith("linear_trend") + ], **ds_kwargs) var_ds = var_ds.transpose("yc", "xc", "time") - trend_files = [v for k, v in var_files.items() - if k.endswith("linear_trend")] + trend_files = [ + v for k, v in var_files.items() if k.endswith("linear_trend") + ] trend_ds = None if len(trend_files): - trend_ds = xr.open_mfdataset( - trend_files, - **ds_kwargs) + trend_ds = xr.open_mfdataset(trend_files, **ds_kwargs) trend_ds = trend_ds.transpose("yc", "xc", "time") with tf.io.TFRecordWriter(path) as writer: @@ -354,26 +314,24 @@ def generate_and_write(path: str, start = time.time() try: - x, y, sample_weights = generate_sample(date, - var_ds, - var_files, - trend_ds, - *args) + x, y, sample_weights = generate_sample(date, var_ds, var_files, + trend_ds, *args) if not dry: x[da.isnan(x)] = 0. 
- x, y, sample_weights = dask.compute(x, y, sample_weights, + x, y, sample_weights = dask.compute(x, + y, + sample_weights, optimize_graph=True) - write_tfrecord(writer, - x, y, sample_weights) + write_tfrecord(writer, x, y, sample_weights) count += 1 except IceNetDataWarning: continue end = time.time() times.append(end - start) - logging.debug("Time taken to produce {}: {}". - format(date, times[-1])) + logging.debug("Time taken to produce {}: {}".format( + date, times[-1])) return path, count, times @@ -415,20 +373,21 @@ def generate_sample(forecast_date: object, # Prepare data sample # To become array of shape (*raw_data_shape, n_forecast_days) - forecast_dts = [forecast_date + dt.timedelta(days=n) - for n in range(n_forecast_days)] + forecast_dts = [ + forecast_date + dt.timedelta(days=n) for n in range(n_forecast_days) + ] y = da.zeros((*shape, n_forecast_days, 1), dtype=dtype) sample_weights = da.zeros((*shape, n_forecast_days, 1), dtype=dtype) - if not prediction: try: sample_output = var_ds.siconca_abs.sel(time=forecast_dts) except KeyError as sic_ex: - logging.exception("Issue selecting data for non-prediction sample, " - "please review siconca ground-truth: dates {}". - format(forecast_dts)) + logging.exception( + "Issue selecting data for non-prediction sample, " + "please review siconca ground-truth: dates {}".format( + forecast_dts)) raise RuntimeError(sic_ex) y[:, :, :, 0] = sample_output @@ -436,8 +395,8 @@ def generate_sample(forecast_date: object, for leadtime_idx in range(n_forecast_days): forecast_day = forecast_date + dt.timedelta(days=leadtime_idx) - if any([forecast_day == missing_date - for missing_date in missing_dates]): + if any([forecast_day == missing_date for missing_date in missing_dates + ]): sample_weight = da.zeros(shape, dtype) else: # Zero loss outside of 'active grid cells' @@ -467,23 +426,27 @@ def generate_sample(forecast_date: object, if var_name.endswith("linear_trend"): channel_ds = trend_ds if type(trend_steps) == list: - channel_dates = [pd.Timestamp(forecast_date + - dt.timedelta(days=int(n))) - for n in trend_steps] + channel_dates = [ + pd.Timestamp(forecast_date + dt.timedelta(days=int(n))) + for n in trend_steps + ] else: - channel_dates = [pd.Timestamp(forecast_date + - dt.timedelta(days=n)) - for n in range(num_channels)] + channel_dates = [ + pd.Timestamp(forecast_date + dt.timedelta(days=n)) + for n in range(num_channels) + ] else: channel_ds = var_ds - channel_dates = [pd.Timestamp(forecast_date - dt.timedelta(days=n)) - for n in range(num_channels)] + channel_dates = [ + pd.Timestamp(forecast_date - dt.timedelta(days=n)) + for n in range(num_channels) + ] channel_data = [] for cdate in channel_dates: try: - channel_data.append(getattr(channel_ds, var_name). 
- sel(time=cdate)) + channel_data.append( + getattr(channel_ds, var_name).sel(time=cdate)) except KeyError: channel_data.append(da.zeros(shape)) diff --git a/icenet/data/loaders/stdlib.py b/icenet/data/loaders/stdlib.py index 9a9b6c9b..089eee8d 100644 --- a/icenet/data/loaders/stdlib.py +++ b/icenet/data/loaders/stdlib.py @@ -1,5 +1,4 @@ from icenet.data.loaders.base import IceNetBaseDataLoader - """ Python Standard Library implementations for icenet data loading @@ -10,9 +9,8 @@ class IceNetDataLoader(IceNetBaseDataLoader): - def __init__(self, - *args, - **kwargs): + + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: https://github.com/icenet-ai/icenet/blob/cb68e5dec31d4c62d72411cbca4c6d3a0276e0f9/icenet2/data/loader.py raise NotImplementedError("Not yet adapted from old implementation") @@ -23,9 +21,7 @@ def generate(self): """ pass - def generate_sample(self, - date: object, - prediction: bool = False): + def generate_sample(self, date: object, prediction: bool = False): """ :param date: diff --git a/icenet/data/loaders/utils.py b/icenet/data/loaders/utils.py index 28472064..da34da1c 100644 --- a/icenet/data/loaders/utils.py +++ b/icenet/data/loaders/utils.py @@ -1,6 +1,4 @@ import tensorflow as tf - - """ """ @@ -10,9 +8,7 @@ class IceNetDataWarning(RuntimeWarning): pass -def write_tfrecord(writer: object, - x: object, - y: object, +def write_tfrecord(writer: object, x: object, y: object, sample_weights: object): """ @@ -41,13 +37,17 @@ def write_tfrecord(writer: object, # if data_check and x_nans > 0: - record_data = tf.train.Example(features=tf.train.Features(feature={ - "x": tf.train.Feature( - float_list=tf.train.FloatList(value=x.reshape(-1))), - "y": tf.train.Feature( - float_list=tf.train.FloatList(value=y.reshape(-1))), - "sample_weights": tf.train.Feature( - float_list=tf.train.FloatList(value=sample_weights.reshape(-1))), - })).SerializeToString() + record_data = tf.train.Example(features=tf.train.Features( + feature={ + "x": + tf.train.Feature(float_list=tf.train.FloatList( + value=x.reshape(-1))), + "y": + tf.train.Feature(float_list=tf.train.FloatList( + value=y.reshape(-1))), + "sample_weights": + tf.train.Feature(float_list=tf.train.FloatList( + value=sample_weights.reshape(-1))), + })).SerializeToString() writer.write(record_data) diff --git a/icenet/data/process.py b/icenet/data/process.py index 543149b8..6221a066 100644 --- a/icenet/data/process.py +++ b/icenet/data/process.py @@ -11,7 +11,6 @@ from icenet.data.producers import Processor from icenet.data.sic.mask import Masks from icenet.model.models import linear_trend_forecast - """ """ @@ -48,36 +47,37 @@ class IceNetPreProcessor(Processor): DATE_FORMAT = "%Y_%m_%d" - def __init__(self, - abs_vars, - anom_vars, - name, - # FIXME: the preprocessors don't need to have the concept of - # train, test, val: they only need to output daily files - # that either are, or are not, part of normalisation / - # climatology calculations. 
Not a problem, just fix - train_dates, - val_dates, - test_dates, - *args, - data_shape=(432, 432), - dtype=np.float32, - exclude_vars=(), - file_filters=tuple(["latlon_"]), - identifier=None, - linear_trends=tuple(["siconca"]), - linear_trend_steps=7, - meta_vars=tuple(), - missing_dates=tuple(), - minmax=True, - no_normalise=tuple(["siconca"]), - path=os.path.join(".", "processed"), - parallel_opens=False, - ref_procdir=None, - source_data=os.path.join(".", "data"), - update_key=None, - update_loader=True, - **kwargs): + def __init__( + self, + abs_vars, + anom_vars, + name, + # FIXME: the preprocessors don't need to have the concept of + # train, test, val: they only need to output daily files + # that either are, or are not, part of normalisation / + # climatology calculations. Not a problem, just fix + train_dates, + val_dates, + test_dates, + *args, + data_shape=(432, 432), + dtype=np.float32, + exclude_vars=(), + file_filters=tuple(["latlon_"]), + identifier=None, + linear_trends=tuple(["siconca"]), + linear_trend_steps=7, + meta_vars=tuple(), + missing_dates=tuple(), + minmax=True, + no_normalise=tuple(["siconca"]), + path=os.path.join(".", "processed"), + parallel_opens=False, + ref_procdir=None, + source_data=os.path.join(".", "data"), + update_key=None, + update_loader=True, + **kwargs): super().__init__(identifier, source_data, *args, @@ -111,8 +111,9 @@ def __init__(self, if update_loader else None if type(linear_trend_steps) == int: - logging.debug("Setting range for linear trend steps based on {}". - format(linear_trend_steps)) + logging.debug( + "Setting range for linear trend steps based on {}".format( + linear_trend_steps)) self._linear_trend_steps = list(range(1, linear_trend_steps + 1)) else: self._linear_trend_steps = [int(el) for el in linear_trend_steps] @@ -140,8 +141,8 @@ def pre_normalisation(self, var_name: str, da: object): :param da: :return: """ - logging.debug("No pre normalisation implemented for {}". - format(var_name)) + logging.debug( + "No pre normalisation implemented for {}".format(var_name)) return da def post_normalisation(self, var_name: str, da: object): @@ -151,8 +152,8 @@ def post_normalisation(self, var_name: str, da: object): :param da: :return: """ - logging.debug("No post normalisation implemented for {}". 
- format(var_name)) + logging.debug( + "No post normalisation implemented for {}".format(var_name)) return da # TODO: update this to store parameters, if appropriate @@ -161,6 +162,7 @@ def update_loader_config(self): :return: """ + def _serialize(x): if x is dt.date: return x.strftime(IceNetPreProcessor.DATE_FORMAT) @@ -169,30 +171,34 @@ def _serialize(x): # We have to be explicit with "dates" as the properties will not be # caught by _serialize source = { - "name": self._name, - "implementation": self.__class__.__name__, - "anom": self._anom_vars, - "abs": self._abs_vars, - "dates": { - "train": [d.strftime(IceNetPreProcessor.DATE_FORMAT) - for d in self._dates.train], - "val": [d.strftime(IceNetPreProcessor.DATE_FORMAT) - for d in self._dates.val], - "test": [d.strftime(IceNetPreProcessor.DATE_FORMAT) - for d in self._dates.test], + "name": self._name, + "implementation": self.__class__.__name__, + "anom": self._anom_vars, + "abs": self._abs_vars, + "dates": { + "train": [ + d.strftime(IceNetPreProcessor.DATE_FORMAT) + for d in self._dates.train + ], + "val": [ + d.strftime(IceNetPreProcessor.DATE_FORMAT) + for d in self._dates.val + ], + "test": [ + d.strftime(IceNetPreProcessor.DATE_FORMAT) + for d in self._dates.test + ], }, - "linear_trends": self._linear_trends, + "linear_trends": self._linear_trends, "linear_trend_steps": self._linear_trend_steps, - "meta": self._meta_vars, + "meta": self._meta_vars, # TODO: intention should perhaps be to strip these from # other date sets, this is just an indicative placeholder # for the mo - "var_files": self._processed_files, + "var_files": self._processed_files, } - configuration = { - "sources": {} - } + configuration = {"sources": {}} if os.path.exists(self._update_loader): logging.info("Loading configuration {}".format(self._update_loader)) @@ -245,8 +251,7 @@ def _save_variable(self, var_name: str, var_suffix: str): if self._refdir: logging.info("Loading climatology from alternate " "directory: {}".format(self._refdir)) - clim_path = os.path.join(self._refdir, - "params", + clim_path = os.path.join(self._refdir, "params", "climatology.{}".format(var_name)) else: clim_path = os.path.join(self.get_data_var_folder("params"), @@ -262,9 +267,9 @@ def _save_variable(self, var_name: str, var_suffix: str): climatology.to_netcdf(clim_path) else: - raise RuntimeError("{} does not exist and no " - "training data is supplied". - format(clim_path)) + raise RuntimeError( + "{} does not exist and no " + "training data is supplied".format(clim_path)) else: logging.info("Reusing climatology {}".format(clim_path)) climatology = xr.open_dataarray(clim_path) @@ -274,11 +279,11 @@ def _save_variable(self, var_name: str, var_suffix: str): logging.warning( "We don't have a full climatology ({}) " "compared with data ({})".format( - ",".join([str(i) - for i in climatology.month.values]), - ",".join([str(i) - for i in da.groupby("time.month"). - all().month.values]))) + ",".join([str(i) for i in climatology.month.values + ]), ",".join([ + str(i) for i in da.groupby( + "time.month").all().month.values + ]))) da = da - climatology.mean() else: da = da.groupby("time.month") - climatology @@ -309,9 +314,9 @@ def _save_variable(self, var_name: str, var_suffix: str): elif var_name in self._linear_trends \ and var_name not in self._abs_vars: - raise NotImplementedError("You've asked for linear trend " - "without an absolute value var: {}". 
- format(var_name)) + raise NotImplementedError( + "You've asked for linear trend " + "without an absolute value var: {}".format(var_name)) if var_name in self._no_normalise: logging.info("No normalisation for {}".format(var_name)) @@ -321,13 +326,11 @@ def _save_variable(self, var_name: str, var_suffix: str): da = self.post_normalisation(var_name, da) - self.save_processed_file(var_name, - "{}_{}.nc".format(var_name, var_suffix), - da.rename( - "_".join([var_name, var_suffix]))) + self.save_processed_file( + var_name, "{}_{}.nc".format(var_name, var_suffix), + da.rename("_".join([var_name, var_suffix]))) def _open_dataarray_from_files(self, var_name: str): - """ Open the yearly xarray files, accounting for some ERA5 variables that have erroneous 'unknown' NetCDF variable names which prevents @@ -339,24 +342,28 @@ def _open_dataarray_from_files(self, var_name: str): logging.info("Opening files for {}".format(var_name)) logging.debug("Files: {}".format(self._var_files[var_name])) - ds = xr.open_mfdataset(self._var_files[var_name], - # Solves issue with inheriting files without - # time dimension (only having coordinate) - combine="nested", - concat_dim="time", - coords="minimal", - compat="override", - drop_variables=("lat", "lon"), - parallel=self._parallel) + ds = xr.open_mfdataset( + self._var_files[var_name], + # Solves issue with inheriting files without + # time dimension (only having coordinate) + combine="nested", + concat_dim="time", + coords="minimal", + compat="override", + drop_variables=("lat", "lon"), + parallel=self._parallel) # For processing one file, we're going to assume a single non-lambert # variable exists at the start and rename all of them - var_names = [name for name in list(ds.data_vars.keys()) - if not name.startswith("lambert_")] + var_names = [ + name for name in list(ds.data_vars.keys()) + if not name.startswith("lambert_") + ] var_names = set(var_names) - logging.debug("Files have var names {} which will be renamed to {}". - format(", ".join(var_names), var_name)) + logging.debug( + "Files have var names {} which will be renamed to {}".format( + ", ".join(var_names), var_name)) ds = ds.rename({k: var_name for k in var_names}) da = getattr(ds, var_name) @@ -405,13 +412,12 @@ def mean_and_std(array: object): mean = np.nanmean(array) std = np.nanstd(array) - logging.info("Mean: {:.3f}, std: {:.3f}". - format(mean.item(), std.item())) + logging.info("Mean: {:.3f}, std: {:.3f}".format(mean.item(), + std.item())) return mean, std def _normalise_array_mean(self, var_name: str, da: object): - """ Using the *training* data only, compute the mean and standard deviation of the input raw satellite DataArray (`da`) @@ -437,10 +443,11 @@ def _normalise_array_mean(self, var_name: str, da: object): mean_path = os.path.join(proc_dir, "{}".format(var_name)) if os.path.exists(mean_path): - logging.debug("Loading norm-average mean-std from {}". 
- format(mean_path)) - mean, std = tuple([self._dtype(el) for el in - open(mean_path, "r").read().split(",")]) + logging.debug( + "Loading norm-average mean-std from {}".format(mean_path)) + mean, std = tuple([ + self._dtype(el) for el in open(mean_path, "r").read().split(",") + ]) elif self._dates.train: logging.debug("Generating norm-average mean-std from {} training " "dates".format(len(self._dates.train))) @@ -455,8 +462,7 @@ def _normalise_array_mean(self, var_name: str, da: object): new_da = (da - mean) / std if not self._refdir: - open(mean_path, "w").write(",".join([str(f) for f in - [mean, std]])) + open(mean_path, "w").write(",".join([str(f) for f in [mean, std]])) return new_da def _normalise_array_scaling(self, var_name: str, da: object): @@ -476,10 +482,12 @@ def _normalise_array_scaling(self, var_name: str, da: object): scale_path = os.path.join(proc_dir, "{}".format(var_name)) if os.path.exists(scale_path): - logging.debug("Loading norm-scaling min-max from {}". - format(scale_path)) - minimum, maximum = tuple([self._dtype(el) for el in - open(scale_path, "r").read().split(",")]) + logging.debug( + "Loading norm-scaling min-max from {}".format(scale_path)) + minimum, maximum = tuple([ + self._dtype(el) + for el in open(scale_path, "r").read().split(",") + ]) elif self._dates.train: logging.debug("Generating norm-scaling min-max from {} training " "dates".format(len(self._dates.train))) @@ -494,8 +502,8 @@ def _normalise_array_scaling(self, var_name: str, da: object): new_da = (da - minimum) / (maximum - minimum) if not self._refdir: - open(scale_path, "w").write(",".join([str(f) for f in - [minimum, maximum]])) + open(scale_path, + "w").write(",".join([str(f) for f in [minimum, maximum]])) return new_da def _build_linear_trend_da(self, @@ -516,18 +524,20 @@ def _build_linear_trend_da(self, if ref_da is None: ref_da = input_da - data_dates = sorted([pd.Timestamp(date) - for date in input_da.time.values]) + data_dates = sorted( + [pd.Timestamp(date) for date in input_da.time.values]) trend_dates = set() trend_steps = max(self._linear_trend_steps) - logging.info("Generating trend data up to {} steps ahead for {} dates". - format(trend_steps, len(data_dates))) + logging.info( + "Generating trend data up to {} steps ahead for {} dates".format( + trend_steps, len(data_dates))) for dat_date in data_dates: - trend_dates = trend_dates.union( - [dat_date + pd.DateOffset(days=d) - for d in self._linear_trend_steps]) + trend_dates = trend_dates.union([ + dat_date + pd.DateOffset(days=d) + for d in self._linear_trend_steps + ]) trend_dates = list(sorted(trend_dates)) logging.info("Generating {} trend dates".format(len(trend_dates))) @@ -544,21 +554,17 @@ def _build_linear_trend_da(self, # Could use shelve, but more likely we'll run into concurrency issues # pickleshare might be an option but a little over-engineery - trend_cache_path = os.path.join( - self.get_data_var_folder(var_name), - "{}_linear_trend.nc".format(var_name) - ) + trend_cache_path = os.path.join(self.get_data_var_folder(var_name), + "{}_linear_trend.nc".format(var_name)) trend_cache = linear_trend_da.copy() trend_cache.data = np.full_like(linear_trend_da.data, np.nan) if os.path.exists(trend_cache_path): trend_cache = xr.open_dataarray(trend_cache_path) - logging.info("Loaded {} entries from {}". 
- format(len(trend_cache.time), trend_cache_path)) + logging.info("Loaded {} entries from {}".format( + len(trend_cache.time), trend_cache_path)) - def data_selector(da, - processing_date, - missing_dates=tuple()): + def data_selector(da, processing_date, missing_dates=tuple()): target_date = pd.to_datetime(processing_date) date_da = da[(da.time['time.month'] == target_date.month) & @@ -573,7 +579,10 @@ def data_selector(da, output_map = trend_cache.sel(time=forecast_date) else: output_map = linear_trend_forecast( - data_selector, forecast_date, ref_da, land_mask, + data_selector, + forecast_date, + ref_da, + land_mask, missing_dates=self._missing_dates, shape=self._data_shape) diff --git a/icenet/data/processors/cmip.py b/icenet/data/processors/cmip.py index d64f571e..1326d42e 100644 --- a/icenet/data/processors/cmip.py +++ b/icenet/data/processors/cmip.py @@ -2,7 +2,6 @@ from icenet.data.process import IceNetPreProcessor from icenet.data.sic.mask import Masks from icenet.data.processors.utils import sic_interpolate - """ """ @@ -14,18 +13,14 @@ class IceNetCMIPPreProcessor(IceNetPreProcessor): :param source: :param member: """ - def __init__(self, - source: str, - member: str, - *args, **kwargs): + + def __init__(self, source: str, member: str, *args, **kwargs): cmip_source = "{}.{}".format(source, member) super().__init__(*args, identifier="cmip6.{}".format(cmip_source), **kwargs) - def pre_normalisation(self, - var_name: str, - da: object): + def pre_normalisation(self, var_name: str, da: object): """ :param var_name: @@ -40,12 +35,10 @@ def pre_normalisation(self, def main(): - args = process_args( - extra_args=[ - (["source"], dict(type=str)), - (["member"], dict(type=str)), - ], - ) + args = process_args(extra_args=[ + (["source"], dict(type=str)), + (["member"], dict(type=str)), + ],) dates = process_date_args(args) cmip = IceNetCMIPPreProcessor( @@ -65,7 +58,5 @@ def main(): south=args.hemisphere == "south", update_key=args.update_key, ) - cmip.init_source_data( - lag_days=args.lag, - ) + cmip.init_source_data(lag_days=args.lag,) cmip.process() diff --git a/icenet/data/processors/era5.py b/icenet/data/processors/era5.py index 3fcc2093..0f45810a 100644 --- a/icenet/data/processors/era5.py +++ b/icenet/data/processors/era5.py @@ -1,6 +1,5 @@ from icenet.data.cli import process_args, process_date_args from icenet.data.process import IceNetPreProcessor - """ """ @@ -37,7 +36,5 @@ def main(): south=args.hemisphere == "south", update_key=args.update_key, ) - era5.init_source_data( - lag_days=args.lag, - ) + era5.init_source_data(lag_days=args.lag,) era5.process() diff --git a/icenet/data/processors/hres.py b/icenet/data/processors/hres.py index f953802b..ee7f1db1 100644 --- a/icenet/data/processors/hres.py +++ b/icenet/data/processors/hres.py @@ -1,6 +1,5 @@ from icenet.data.cli import process_args, process_date_args from icenet.data.process import IceNetPreProcessor - """ """ @@ -37,7 +36,5 @@ def main(): south=args.hemisphere == "south", update_key=args.update_key, ) - hres.init_source_data( - lag_days=args.lag, - ) + hres.init_source_data(lag_days=args.lag,) hres.process() diff --git a/icenet/data/processors/meta.py b/icenet/data/processors/meta.py index 2f424dda..5d1b6f46 100644 --- a/icenet/data/processors/meta.py +++ b/icenet/data/processors/meta.py @@ -5,7 +5,6 @@ from icenet.data.cli import process_args from icenet.data.process import IceNetPreProcessor from icenet.data.sic.mask import Masks - """ """ @@ -73,11 +72,9 @@ def _save_land(self): if "land" not in self._meta_vars: 
self._meta_vars.append("land") - da = xr.DataArray( - data=land_map, - dims=["yc", "xc"], - attrs=dict(description="IceNet land mask metadata") - ) + da = xr.DataArray(data=land_map, + dims=["yc", "xc"], + attrs=dict(description="IceNet land mask metadata")) land_path = self.save_processed_file("land", "land.nc", da) return land_path @@ -106,11 +103,9 @@ def _save_circday(self): data=eval(var_name), dims=["time"], coords=dict( - time=pd.date_range(start='2012-1-1', end='2012-12-31') - ), + time=pd.date_range(start='2012-1-1', end='2012-12-31')), attrs=dict( - description="IceNet {} mask metadata".format(var_name)) - ) + description="IceNet {} mask metadata".format(var_name))) paths.append( self.save_processed_file(var_name, "{}.nc".format(var_name), da)) @@ -120,8 +115,6 @@ def _save_circday(self): def main(): args = process_args(dates=False, ref_option=False) - IceNetMetaPreProcessor( - args.name, - north=args.hemisphere == "north", - south=args.hemisphere == "south" - ).process() + IceNetMetaPreProcessor(args.name, + north=args.hemisphere == "north", + south=args.hemisphere == "south").process() diff --git a/icenet/data/processors/oras5.py b/icenet/data/processors/oras5.py index 6c5db8ce..6f6a15a2 100644 --- a/icenet/data/processors/oras5.py +++ b/icenet/data/processors/oras5.py @@ -33,7 +33,5 @@ def main(): south=args.hemisphere == "south", update_key=args.update_key, ) - oras5.init_source_data( - lag_days=args.lag, - ) + oras5.init_source_data(lag_days=args.lag,) oras5.process() diff --git a/icenet/data/processors/osi.py b/icenet/data/processors/osi.py index 086c74ad..4be1d433 100644 --- a/icenet/data/processors/osi.py +++ b/icenet/data/processors/osi.py @@ -5,7 +5,6 @@ from icenet.data.process import IceNetPreProcessor from icenet.data.sic.mask import Masks from icenet.data.processors.utils import sic_interpolate - """ """ @@ -16,28 +15,25 @@ class IceNetOSIPreProcessor(IceNetPreProcessor): :param missing_dates: """ - def __init__(self, *args, - missing_dates: object = None, - **kwargs): + + def __init__(self, *args, missing_dates: object = None, **kwargs): super().__init__(*args, identifier="osisaf", **kwargs) - missing_dates_path = os.path.join( - self._source_data, - "siconca", - "missing_days.csv") + missing_dates_path = os.path.join(self._source_data, "siconca", + "missing_days.csv") missing_dates = [] if missing_dates is None else missing_dates assert type(missing_dates) is list with open(missing_dates_path, "r") as fh: - missing_dates += [dt.date(*[int(s) - for s in line.strip().split(",")]) - for line in fh.readlines()] + missing_dates += [ + dt.date(*[int(s) + for s in line.strip().split(",")]) + for line in fh.readlines() + ] self.missing_dates = list(set(missing_dates)) - def pre_normalisation(self, - var_name: str, - da: object): + def pre_normalisation(self, var_name: str, da: object): """ :param var_name: @@ -56,21 +52,17 @@ def main(): args = process_args() dates = process_date_args(args) - osi = IceNetOSIPreProcessor( - args.abs, - args.anom, - args.name, - dates["train"], - dates["val"], - dates["test"], - linear_trends=args.trends, - linear_trend_steps=args.trend_lead, - north=args.hemisphere == "north", - parallel_opens=args.parallel_opens, - ref_procdir=args.ref, - south=args.hemisphere == "south" - ) - osi.init_source_data( - lag_days=args.lag, - ) + osi = IceNetOSIPreProcessor(args.abs, + args.anom, + args.name, + dates["train"], + dates["val"], + dates["test"], + linear_trends=args.trends, + linear_trend_steps=args.trend_lead, + north=args.hemisphere == "north", 
+ parallel_opens=args.parallel_opens, + ref_procdir=args.ref, + south=args.hemisphere == "south") + osi.init_source_data(lag_days=args.lag,) osi.process() diff --git a/icenet/data/processors/utils.py b/icenet/data/processors/utils.py index 9ab4dd0c..9ae36064 100644 --- a/icenet/data/processors/utils.py +++ b/icenet/data/processors/utils.py @@ -12,14 +12,12 @@ from scipy import interpolate from scipy.spatial.qhull import QhullError - """ """ -def sic_interpolate(da: object, - masks: object) -> object: +def sic_interpolate(da: object, masks: object) -> object: """ :param da: @@ -27,8 +25,7 @@ def sic_interpolate(da: object, :return: """ for date in da.time.values: - polarhole_mask = masks.get_polarhole_mask( - pd.to_datetime(date).date()) + polarhole_mask = masks.get_polarhole_mask(pd.to_datetime(date).date()) da_day = da.sel(time=date) xx, yy = np.meshgrid(np.arange(432), np.arange(432)) @@ -71,12 +68,16 @@ def sic_interpolate(da: object, nan_neighbour_arr[-1, :] = False if np.sum(nan_neighbour_arr) == 1: - res = np.where(np.array(nan_neighbour_arr) == True) # noqa: E712 - logging.warning("Not enough nans for interpolation, extending {}".format(res)) + res = np.where( + np.array(nan_neighbour_arr) == True) # noqa: E712 + logging.warning( + "Not enough nans for interpolation, extending {}".format( + res)) x_idx, y_idx = res[0][0], res[1][0] - nan_neighbour_arr[x_idx-1:x_idx+2, y_idx] = True - nan_neighbour_arr[x_idx, y_idx-1:y_idx+2] = True - logging.debug(np.where(np.array(nan_neighbour_arr) == True)) # noqa: E712 + nan_neighbour_arr[x_idx - 1:x_idx + 2, y_idx] = True + nan_neighbour_arr[x_idx, y_idx - 1:y_idx + 2] = True + logging.debug( + np.where(np.array(nan_neighbour_arr) == True)) # noqa: E712 # Perform bilinear interpolation x_valid = xx[nan_neighbour_arr] @@ -97,7 +98,8 @@ def sic_interpolate(da: object, logging.warning("No valid values to interpolate with on " "{}".format(date)) except QhullError: - logging.exception("Geometrical degeneracy from QHull, interpolation failed") + logging.exception( + "Geometrical degeneracy from QHull, interpolation failed") return da @@ -115,9 +117,7 @@ def condense_main(): condense_data(args.identifier, args.hemisphere, args.variable) -def condense_data(identifier: str, - hemisphere: str, - variable: str): +def condense_data(identifier: str, hemisphere: str, variable: str): """Takes existing daily files and creates yearly files Previous early versions of the pipeline were storing files day by day, which @@ -142,19 +142,22 @@ def condense_data(identifier: str, dfs = glob.glob(os.path.join(data_path, "**", "*.nc")) def year_batch(filenames): - df_years = set([os.path.split(os.path.dirname(f_year))[-1] - for f_year in filenames]) + df_years = set([ + os.path.split(os.path.dirname(f_year))[-1] for f_year in filenames + ]) for year_el in df_years: - year_dfs = [el for el in filenames - if os.path.split(os.path.dirname(el))[-1] == year_el - and not os.path.split(el)[1].startswith("latlon")] + year_dfs = [ + el for el in filenames + if os.path.split(os.path.dirname(el))[-1] == year_el and + not os.path.split(el)[1].startswith("latlon") + ] logging.debug("{} has {} files".format(year_el, len(year_dfs))) yield year_el, year_dfs if len(dfs): - logging.debug("Got {} files, collecting to {}...".format(len(dfs), - data_path)) + logging.debug("Got {} files, collecting to {}...".format( + len(dfs), data_path)) for year, year_files in year_batch(dfs): year_path = os.path.join(data_path, "{}.nc".format(year)) @@ -164,8 +167,8 @@ def year_batch(filenames): ds = 
xr.open_mfdataset(year_files, parallel=True) years, datasets = zip(*ds.groupby("time.year")) if len(years) > 1: - raise RuntimeError("Too many years in one file {}". - format(years)) + raise RuntimeError( + "Too many years in one file {}".format(years)) logging.info("Saving to {}".format(year_path)) xr.save_mfdataset(datasets, [year_path]) else: diff --git a/icenet/data/producers.py b/icenet/data/producers.py index 24818976..59d1e1e4 100644 --- a/icenet/data/producers.py +++ b/icenet/data/producers.py @@ -22,7 +22,8 @@ class DataCollection(HemisphereMixin, metaclass=ABCMeta): """ @abstractmethod - def __init__(self, *args, + def __init__(self, + *args, identifier: object = None, north: bool = True, south: bool = False, @@ -54,7 +55,9 @@ class DataProducer(DataCollection): :param dry: :param overwrite: """ - def __init__(self, *args, + + def __init__(self, + *args, dry: bool = False, overwrite: bool = False, **kwargs): @@ -98,9 +101,8 @@ def get_data_var_folder(self, # to a single hemisphere hemisphere = self.hemisphere_str[0] - data_var_path = os.path.join( - self.base_path, *[hemisphere, var, *append] - ) + data_var_path = os.path.join(self.base_path, + *[hemisphere, var, *append]) if not os.path.exists(data_var_path): if not missing_error: @@ -116,6 +118,7 @@ class Downloader(DataProducer): """ """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -124,14 +127,15 @@ def download(self): """Abstract download method for this downloader """ - raise NotImplementedError("{}.download is abstract". - format(__class__.__name__)) + raise NotImplementedError("{}.download is abstract".format( + __class__.__name__)) class Generator(DataProducer): """ """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -140,8 +144,8 @@ def generate(self): """ """ - raise NotImplementedError("{}.generate is abstract". - format(__class__.__name__)) + raise NotImplementedError("{}.generate is abstract".format( + __class__.__name__)) class Processor(DataProducer): @@ -155,6 +159,7 @@ class Processor(DataProducer): :param train_dates: :param val_dates: """ + def __init__(self, identifier: str, source_data: object, @@ -165,14 +170,11 @@ def __init__(self, train_dates: object = (), val_dates: object = (), **kwargs): - super().__init__(*args, - identifier=identifier, - **kwargs) + super().__init__(*args, identifier=identifier, **kwargs) self._file_filters = list(file_filters) self._lead_time = lead_time - self._source_data = os.path.join(source_data, - identifier, + self._source_data = os.path.join(source_data, identifier, self.hemisphere_str[0]) self._var_files = dict() self._processed_files = dict() @@ -183,16 +185,15 @@ def __init__(self, val=list(val_dates), test=list(test_dates)) - def init_source_data(self, - lag_days: object = None): + def init_source_data(self, lag_days: object = None): """ :param lag_days: """ if not os.path.exists(self.source_data): - raise OSError("Source data directory {} does not exist". - format(self.source_data)) + raise OSError("Source data directory {} does not exist".format( + self.source_data)) var_files = {} @@ -200,11 +201,11 @@ def init_source_data(self, dates = sorted(getattr(self._dates, date_category)) if dates: - logging.info("Processing {} dates for {} category". - format(len(dates), date_category)) + logging.info("Processing {} dates for {} category".format( + len(dates), date_category)) else: - logging.info("No {} dates for this processor". 
- format(date_category)) + logging.info( + "No {} dates for this processor".format(date_category)) continue # TODO: ProcessPool for this (avoid the GIL for globbing) @@ -225,7 +226,8 @@ def init_source_data(self, # training with OSISAF data, but are we exploiting the # convenient usage of this data for linear trends? if self._lead_time: - logging.info("Including lead of {} days".format(self._lead_time)) + logging.info("Including lead of {} days".format( + self._lead_time)) additional_lead_dates = [] @@ -243,8 +245,9 @@ def init_source_data(self, logging.debug("Globbed {} files".format(len(dfs))) # FIXME: using hyphens broadly no? - data_dates = [df.split(os.sep)[-1][:-3].replace("_", "-") - for df in dfs] + data_dates = [ + df.split(os.sep)[-1][:-3].replace("_", "-") for df in dfs + ] dt_series = pd.Series(dfs, index=data_dates) logging.debug("Create structure of {} files".format(len(dt_series))) @@ -262,8 +265,10 @@ def init_source_data(self, match_dfs = [] for df in match_dfs: - if any([flt in os.path.split(df)[1] - for flt in self._file_filters]): + if any([ + flt in os.path.split(df)[1] + for flt in self._file_filters + ]): continue path_comps = str(os.path.split(df)[0]).split(os.sep) @@ -285,21 +290,19 @@ def init_source_data(self, var: var_files[var] for var in sorted(var_files.keys()) } for var in self._var_files.keys(): - logging.info("Got {} files for {}".format( - len(self._var_files[var]), var)) + logging.info("Got {} files for {}".format(len(self._var_files[var]), + var)) @abstractmethod def process(self): """ """ - raise NotImplementedError("{}.process is abstract". - format(__class__.__name__)) + raise NotImplementedError("{}.process is abstract".format( + __class__.__name__)) - def save_processed_file(self, - var_name: str, - name: str, - data: object, **kwargs): + def save_processed_file(self, var_name: str, name: str, data: object, + **kwargs): """ :param var_name: @@ -308,8 +311,8 @@ def save_processed_file(self, :param kwargs: :return: """ - file_path = os.path.join( - self.get_data_var_folder(var_name, **kwargs), name) + file_path = os.path.join(self.get_data_var_folder(var_name, **kwargs), + name) data.to_netcdf(file_path) if var_name not in self._processed_files.keys(): @@ -319,8 +322,8 @@ def save_processed_file(self, logging.debug("Adding {} file: {}".format(var_name, file_path)) self._processed_files[var_name].append(file_path) else: - logging.warning("{} already exists in {} processed list". 
- format(file_path, var_name)) + logging.warning("{} already exists in {} processed list".format( + file_path, var_name)) return file_path @property diff --git a/icenet/data/sic/mask.py b/icenet/data/sic/mask.py index 197780e3..145c0f9b 100644 --- a/icenet/data/sic/mask.py +++ b/icenet/data/sic/mask.py @@ -11,7 +11,6 @@ from icenet.data.producers import Generator from icenet.utils import run_command from icenet.data.sic.utils import SIC_HEMI_STR - """Sea Ice Masks """ @@ -35,7 +34,8 @@ class Masks(Generator): dt.date(2015, 12, 1), ) - def __init__(self, *args, + def __init__(self, + *args, polarhole_dates: object = POLARHOLE_DATES, polarhole_radii: object = POLARHOLE_RADII, data_shape: object = (432, 432), @@ -55,24 +55,25 @@ def init_params(self): """ """ - params_path = os.path.join( - self.get_data_var_folder("masks"), - "masks.params" - ) + params_path = os.path.join(self.get_data_var_folder("masks"), + "masks.params") if not os.path.exists(params_path): with open(params_path, "w") as fh: for i, polarhole in enumerate(self._polarhole_radii): - fh.write("{}\n".format( - ",".join([str(polarhole), - self._polarhole_dates[i].strftime("%Y%m%d")] - ))) + fh.write("{}\n".format(",".join([ + str(polarhole), + self._polarhole_dates[i].strftime("%Y%m%d") + ]))) else: - lines = [el.strip().split(",") - for el in open(params_path, "r").readlines()] + lines = [ + el.strip().split(",") + for el in open(params_path, "r").readlines() + ] radii, dates = zip(*lines) - self._polarhole_dates = [dt.datetime.strptime(el, "%Y%m%d").date() - for el in dates] + self._polarhole_dates = [ + dt.datetime.strptime(el, "%Y%m%d").date() for el in dates + ] self._polarhole_radii = [int(r) for r in radii] def generate(self, @@ -110,11 +111,12 @@ def generate(self, month_path = os.path.join(month_folder, filename_osi450) if not os.path.exists(month_path): - run_command(retrieve_cmd_template_osi450.format( - siconca_folder, year, month, filename_osi450)) + run_command( + retrieve_cmd_template_osi450.format(siconca_folder, year, + month, filename_osi450)) else: - logging.info("siconca {} already exists". - format(filename_osi450)) + logging.info( + "siconca {} already exists".format(filename_osi450)) with xr.open_dataset(month_path) as ds: status_flag = ds['status_flag'] @@ -125,17 +127,17 @@ def generate(self, reshape(*self._shape, 8) # Mask out: land, lake, and 'outside max climatology' (open sea) - max_extent_mask = np.sum( - binary[:, :, [7, 6, 0]], axis=2).reshape(*self._shape) >= 1 + max_extent_mask = np.sum(binary[:, :, [7, 6, 0]], + axis=2).reshape(*self._shape) >= 1 max_extent_mask = ~max_extent_mask # FIXME: Remove Caspian and Black seas - should we do this sh? if self.north: max_extent_mask[325:386, 317:380] = False - mask_path = os.path.join(self.get_data_var_folder("masks"), - "active_grid_cell_mask_{:02d}.npy". - format(month)) + mask_path = os.path.join( + self.get_data_var_folder("masks"), + "active_grid_cell_mask_{:02d}.npy".format(month)) logging.info("Saving {}".format(mask_path)) np.save(mask_path, max_extent_mask) @@ -171,22 +173,21 @@ def generate(self, polarhole = np.full(self._shape, False) polarhole[squaresum < radius**2] = True - polarhole_path = os.path.join(self.get_data_var_folder("masks"), - "polarhole{}_mask.npy". 
- format(i+1)) + polarhole_path = os.path.join( + self.get_data_var_folder("masks"), + "polarhole{}_mask.npy".format(i + 1)) logging.info("Saving polarhole {}".format(polarhole_path)) np.save(polarhole_path, polarhole) - def get_active_cell_mask(self, - month: object) -> object: + def get_active_cell_mask(self, month: object) -> object: """ :param month: :return: """ - mask_path = os.path.join(self.get_data_var_folder("masks"), - "active_grid_cell_mask_{:02d}.npy". - format(month)) + mask_path = os.path.join( + self.get_data_var_folder("masks"), + "active_grid_cell_mask_{:02d}.npy".format(month)) if not os.path.exists(mask_path): raise RuntimeError("Active cell masks have not been generated, " @@ -196,23 +197,23 @@ def get_active_cell_mask(self, # logging.debug("Loading active cell mask {}".format(mask_path)) return np.load(mask_path)[self._region] - def get_active_cell_da(self, - src_da: object) -> object: + def get_active_cell_da(self, src_da: object) -> object: """ :param src_da: """ return xr.DataArray( - [self.get_active_cell_mask(pd.to_datetime(date).month) - for date in src_da.time.values], + [ + self.get_active_cell_mask(pd.to_datetime(date).month) + for date in src_da.time.values + ], dims=('time', 'yc', 'xc'), coords={ 'time': src_da.time.values, 'yc': src_da.yc.values, 'xc': src_da.xc.values, - } - ) + }) def get_land_mask(self, land_mask_filename: str = LAND_MASK_FILENAME) -> object: @@ -232,8 +233,7 @@ def get_land_mask(self, # logging.debug("Loading land mask {}".format(mask_path)) return np.load(mask_path)[self._region] - def get_polarhole_mask(self, - date: object) -> object: + def get_polarhole_mask(self, date: object) -> object: """ :param date: @@ -244,9 +244,9 @@ def get_polarhole_mask(self, for i, r in enumerate(self._polarhole_radii): if date <= self._polarhole_dates[i]: - polarhole_path = os.path.join(self.get_data_var_folder("masks"), - "polarhole{}_mask.npy". 
- format(i + 1)) + polarhole_path = os.path.join( + self.get_data_var_folder("masks"), + "polarhole{}_mask.npy".format(i + 1)) # logging.debug("Loading polarhole {}".format(polarhole_path)) return np.load(polarhole_path)[self._region] return None diff --git a/icenet/data/sic/osisaf.py b/icenet/data/sic/osisaf.py index e88d6bbd..c3ffcbe6 100644 --- a/icenet/data/sic/osisaf.py +++ b/icenet/data/sic/osisaf.py @@ -18,54 +18,89 @@ from icenet.data.sic.mask import Masks from icenet.utils import Hemisphere, run_command from icenet.data.sic.utils import SIC_HEMI_STR - """ """ invalid_sic_days = { Hemisphere.NORTH: [ - *[d.date() for d in - pd.date_range(dt.date(1979, 5, 21), dt.date(1979, 6, 4))], - *[d.date() for d in - pd.date_range(dt.date(1979, 6, 10), dt.date(1979, 6, 26))], + *[ + d.date() + for d in pd.date_range(dt.date(1979, 5, 21), dt.date(1979, 6, 4)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1979, 6, 10), dt.date(1979, 6, 26)) + ], dt.date(1979, 7, 1), - *[d.date() for d in - pd.date_range(dt.date(1979, 7, 24), dt.date(1979, 7, 28))], - *[d.date() for d in - pd.date_range(dt.date(1980, 1, 4), dt.date(1980, 1, 10))], - *[d.date() for d in - pd.date_range(dt.date(1980, 2, 27), dt.date(1980, 3, 4))], - *[d.date() for d in - pd.date_range(dt.date(1980, 3, 16), dt.date(1980, 3, 22))], - *[d.date() for d in - pd.date_range(dt.date(1980, 4, 9), dt.date(1980, 4, 15))], - *[d.date() for d in - pd.date_range(dt.date(1981, 2, 27), dt.date(1981, 3, 5))], - *[d.date() for d in - pd.date_range(dt.date(1984, 8, 12), dt.date(1984, 8, 24))], + *[ + d.date() + for d in pd.date_range(dt.date(1979, 7, 24), dt.date(1979, 7, 28)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1980, 1, 4), dt.date(1980, 1, 10)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1980, 2, 27), dt.date(1980, 3, 4)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1980, 3, 16), dt.date(1980, 3, 22)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1980, 4, 9), dt.date(1980, 4, 15)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1981, 2, 27), dt.date(1981, 3, 5)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1984, 8, 12), dt.date(1984, 8, 24)) + ], dt.date(1984, 9, 14), - *[d.date() for d in - pd.date_range(dt.date(1985, 9, 22), dt.date(1985, 9, 28))], - *[d.date() for d in - pd.date_range(dt.date(1986, 3, 29), dt.date(1986, 7, 1))], - *[d.date() for d in - pd.date_range(dt.date(1987, 1, 3), dt.date(1987, 1, 19))], - *[d.date() for d in - pd.date_range(dt.date(1987, 1, 29), dt.date(1987, 2, 2))], + *[ + d.date() + for d in pd.date_range(dt.date(1985, 9, 22), dt.date(1985, 9, 28)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1986, 3, 29), dt.date(1986, 7, 1)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1987, 1, 3), dt.date(1987, 1, 19)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1987, 1, 29), dt.date(1987, 2, 2)) + ], dt.date(1987, 2, 23), - *[d.date() for d in - pd.date_range(dt.date(1987, 2, 26), dt.date(1987, 3, 2))], + *[ + d.date() + for d in pd.date_range(dt.date(1987, 2, 26), dt.date(1987, 3, 2)) + ], dt.date(1987, 3, 13), - *[d.date() for d in - pd.date_range(dt.date(1987, 3, 22), dt.date(1987, 3, 26))], - *[d.date() for d in - pd.date_range(dt.date(1987, 4, 3), dt.date(1987, 4, 17))], - *[d.date() for d in - pd.date_range(dt.date(1987, 12, 1), dt.date(1988, 1, 12))], + *[ + d.date() + for d in pd.date_range(dt.date(1987, 3, 22), dt.date(1987, 3, 26)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1987, 4, 3), dt.date(1987, 4, 17)) + ], + 
*[ + d.date() + for d in pd.date_range(dt.date(1987, 12, 1), dt.date(1988, 1, 12)) + ], dt.date(1989, 1, 3), - *[d.date() for d in - pd.date_range(dt.date(1990, 12, 21), dt.date(1990, 12, 26))], + *[ + d.date() + for d in pd.date_range(dt.date(1990, 12, 21), dt.date(1990, 12, 26)) + ], dt.date(1979, 5, 28), dt.date(1979, 5, 30), dt.date(1979, 6, 1), @@ -106,59 +141,97 @@ dt.date(1979, 2, 5), dt.date(1979, 2, 25), dt.date(1979, 3, 23), - *[d.date() for d in - pd.date_range(dt.date(1979, 3, 26), dt.date(1979, 3, 30))], + *[ + d.date() + for d in pd.date_range(dt.date(1979, 3, 26), dt.date(1979, 3, 30)) + ], dt.date(1979, 4, 12), dt.date(1979, 5, 16), - *[d.date() for d in - pd.date_range(dt.date(1979, 5, 21), dt.date(1979, 5, 27))], - *[d.date() for d in - pd.date_range(dt.date(1979, 7, 10), dt.date(1979, 7, 18))], + *[ + d.date() + for d in pd.date_range(dt.date(1979, 5, 21), dt.date(1979, 5, 27)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1979, 7, 10), dt.date(1979, 7, 18)) + ], dt.date(1979, 8, 10), dt.date(1979, 9, 3), - *[d.date() for d in - pd.date_range(dt.date(1980, 1, 4), dt.date(1980, 1, 10))], + *[ + d.date() + for d in pd.date_range(dt.date(1980, 1, 4), dt.date(1980, 1, 10)) + ], dt.date(1980, 2, 16), - *[d.date() for d in - pd.date_range(dt.date(1980, 2, 27), dt.date(1980, 3, 4))], - *[d.date() for d in - pd.date_range(dt.date(1980, 3, 14), dt.date(1980, 3, 22))], + *[ + d.date() + for d in pd.date_range(dt.date(1980, 2, 27), dt.date(1980, 3, 4)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1980, 3, 14), dt.date(1980, 3, 22)) + ], dt.date(1980, 3, 31), - *[d.date() for d in - pd.date_range(dt.date(1980, 4, 9), dt.date(1980, 4, 15))], + *[ + d.date() + for d in pd.date_range(dt.date(1980, 4, 9), dt.date(1980, 4, 15)) + ], dt.date(1980, 4, 22), - *[d.date() for d in - pd.date_range(dt.date(1981, 2, 27), dt.date(1981, 3, 5))], + *[ + d.date() + for d in pd.date_range(dt.date(1981, 2, 27), dt.date(1981, 3, 5)) + ], dt.date(1981, 6, 10), - *[d.date() for d in - pd.date_range(dt.date(1981, 8, 3), dt.date(1982, 8, 9))], + *[ + d.date() + for d in pd.date_range(dt.date(1981, 8, 3), dt.date(1982, 8, 9)) + ], dt.date(1982, 8, 6), - *[d.date() for d in - pd.date_range(dt.date(1983, 7, 7), dt.date(1983, 7, 11))], + *[ + d.date() + for d in pd.date_range(dt.date(1983, 7, 7), dt.date(1983, 7, 11)) + ], dt.date(1983, 7, 22), dt.date(1984, 6, 12), - *[d.date() for d in - pd.date_range(dt.date(1984, 8, 12), dt.date(1984, 8, 24))], - *[d.date() for d in - pd.date_range(dt.date(1984, 9, 13), dt.date(1984, 9, 17))], - *[d.date() for d in - pd.date_range(dt.date(1984, 10, 3), dt.date(1984, 10, 9))], - *[d.date() for d in - pd.date_range(dt.date(1984, 11, 18), dt.date(1984, 11, 22))], + *[ + d.date() + for d in pd.date_range(dt.date(1984, 8, 12), dt.date(1984, 8, 24)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1984, 9, 13), dt.date(1984, 9, 17)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1984, 10, 3), dt.date(1984, 10, 9)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1984, 11, 18), dt.date(1984, 11, 22)) + ], dt.date(1985, 7, 23), - *[d.date() for d in - pd.date_range(dt.date(1985, 9, 22), dt.date(1985, 9, 28))], - *[d.date() for d in - pd.date_range(dt.date(1986, 3, 29), dt.date(1986, 11, 2))], - *[d.date() for d in - pd.date_range(dt.date(1987, 1, 3), dt.date(1987, 1, 15))], - *[d.date() for d in - pd.date_range(dt.date(1987, 12, 1), dt.date(1988, 1, 12))], + *[ + d.date() + for d in pd.date_range(dt.date(1985, 9, 22), dt.date(1985, 9, 28)) + ], + 
*[ + d.date() + for d in pd.date_range(dt.date(1986, 3, 29), dt.date(1986, 11, 2)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1987, 1, 3), dt.date(1987, 1, 15)) + ], + *[ + d.date() + for d in pd.date_range(dt.date(1987, 12, 1), dt.date(1988, 1, 12)) + ], dt.date(1990, 8, 14), dt.date(1990, 8, 15), dt.date(1990, 8, 24), - *[d.date() for d in - pd.date_range(dt.date(1990, 12, 22), dt.date(1990, 12, 26))], + *[ + d.date() + for d in pd.date_range(dt.date(1990, 12, 22), dt.date(1990, 12, 26)) + ], dt.date(1979, 2, 5), dt.date(1979, 2, 25), dt.date(1979, 3, 23), @@ -190,8 +263,10 @@ dt.date(1984, 11, 19), dt.date(1984, 11, 21), dt.date(1985, 7, 23), - *[d.date() for d in - pd.date_range(dt.date(1986, 7, 2), dt.date(1986, 11, 1))], + *[ + d.date() + for d in pd.date_range(dt.date(1986, 7, 2), dt.date(1986, 11, 1)) + ], dt.date(1990, 8, 14), dt.date(1990, 8, 15), dt.date(1990, 8, 24), @@ -199,9 +274,11 @@ ] } -var_remove_list = ['time_bnds', 'raw_ice_conc_values', 'total_standard_error', - 'smearing_standard_error', 'algorithm_standard_error', - 'status_flag', 'Lambert_Azimuthal_Grid'] +var_remove_list = [ + 'time_bnds', 'raw_ice_conc_values', 'total_standard_error', + 'smearing_standard_error', 'algorithm_standard_error', 'status_flag', + 'Lambert_Azimuthal_Grid' +] # This is adapted from the data/loaders implementations @@ -225,10 +302,7 @@ def __init__(self, self._tmp_dir = dask_tmp_dir self._workers = workers - def dask_process(self, - *args, - method: callable, - **kwargs): + def dask_process(self, *args, method: callable, **kwargs): """ :param method: @@ -236,9 +310,9 @@ def dask_process(self, dashboard = "localhost:{}".format(self._dashboard_port) with dask.config.set({ - "temporary_directory": self._tmp_dir, - "distributed.comm.timeouts.connect": self._timeout, - "distributed.comm.timeouts.tcp": self._timeout, + "temporary_directory": self._tmp_dir, + "distributed.comm.timeouts.connect": self._timeout, + "distributed.comm.timeouts.tcp": self._timeout, }): cluster = LocalCluster( dashboard_address=dashboard, @@ -275,6 +349,7 @@ class SICDownloader(Downloader): :param download: :param dtype: """ + def __init__(self, *args, additional_invalid_dates: object = (), @@ -303,7 +378,7 @@ def __init__(self, self._mask_dict = { month: self._masks.get_active_cell_mask(month) - for month in np.arange(1, 12+1) + for month in np.arange(1, 12 + 1) } def download(self): @@ -315,10 +390,9 @@ def download(self): ftp = None var = "siconca" - logging.info( - "Not downloading SIC files, (re)processing NC files in " - "existence already" if not self._download else - "Downloading SIC datafiles to .temp intermediates...") + logging.info("Not downloading SIC files, (re)processing NC files in " + "existence already" if not self._download else + "Downloading SIC datafiles to .temp intermediates...") cache = {} osi430b_start = dt.date(2016, 1, 1) @@ -367,12 +441,11 @@ def download(self): if not self._download: if os.path.exists(nc_path): reproc_path = os.path.join( - self.get_data_var_folder(var, - append=[str(el.year)]), + self.get_data_var_folder(var, append=[str(el.year)]), "{}.reproc.nc".format(date_str)) - logging.debug("{} exists, becoming {}". 
- format(nc_path, reproc_path)) + logging.debug("{} exists, becoming {}".format( + nc_path, reproc_path)) os.rename(nc_path, reproc_path) data_files.append(reproc_path) else: @@ -411,15 +484,18 @@ def download(self): cache_match = "ice_conc_{}_ease*_{:04d}{:02d}{:02d}*.nc".\ format(hs, el.year, el.month, el.day) - ftp_files = [el for el in cache[chdir_path] - if fnmatch.fnmatch(el, cache_match)] + ftp_files = [ + el for el in cache[chdir_path] + if fnmatch.fnmatch(el, cache_match) + ] if len(ftp_files) > 1: - raise ValueError("More than a single file found: {}". - format(ftp_files)) + raise ValueError( + "More than a single file found: {}".format( + ftp_files)) elif not len(ftp_files): - logging.warning("File is not available: {}". - format(cache_match)) + logging.warning( + "File is not available: {}".format(cache_match)) continue except ftplib.error_perm: logging.warning("FTP error, possibly missing month chdir " @@ -459,13 +535,13 @@ def download(self): if coord not in da.coords: logging.warning("Adding {} vals to coords, as missing in " "this the combined dataset".format(coord)) - da.coords[coord] = self._get_missing_coordinates(var, - hs, - coord) + da.coords[coord] = self._get_missing_coordinates( + var, hs, coord) # In experimenting, I don't think this is actually required for month, mask in self._mask_dict.items(): - da.loc[dict(time=(da['time.month'] == month))].values[:, ~mask] = 0. + da.loc[dict( + time=(da['time.month'] == month))].values[:, ~mask] = 0. for date in da.time.values: day_da = da.sel(time=slice(date, date)) @@ -487,8 +563,9 @@ def download(self): var_folder, "old.{}.nc".format(getattr(req_date, "year"))) if os.path.exists(year_path): - logging.info("Existing file needs concatenating: {} -> {}". - format(year_path, old_year_path)) + logging.info( + "Existing file needs concatenating: {} -> {}".format( + year_path, old_year_path)) os.rename(year_path, old_year_path) old_da = xr.open_dataarray(old_year_path) year_da = year_da.drop_sel(time=old_da.time, @@ -513,17 +590,18 @@ def missing_dates(self): :return: """ - filenames = set([os.path.join( - self.get_data_var_folder("siconca"), - "{}.nc".format(el.strftime("%Y"))) - for el in self._dates]) + filenames = set([ + os.path.join(self.get_data_var_folder("siconca"), + "{}.nc".format(el.strftime("%Y"))) + for el in self._dates + ]) filenames = [f for f in filenames if os.path.exists(f)] logging.info("Opening for interpolation: {}".format(filenames)) ds = xr.open_mfdataset(filenames, combine="nested", concat_dim="time", - chunks=dict(time=self._chunk_size, ), + chunks=dict(time=self._chunk_size,), parallel=self._parallel_opens) return self._missing_dates(ds.ice_conc) @@ -538,39 +616,41 @@ def _missing_dates(self, da: object) -> object: and pd.Timestamp(1979, 1, 1) not in da.time.values: da_1979_01_01 = da.sel( time=[pd.Timestamp(1979, 1, 2)]).copy().assign_coords( - {'time': [pd.Timestamp(1979, 1, 1)]}) + {'time': [pd.Timestamp(1979, 1, 1)]}) da = xr.concat([da, da_1979_01_01], dim='time') da = da.sortby('time') dates_obs = [pd.to_datetime(date).date() for date in da.time.values] - dates_all = [pd.to_datetime(date).date() for date in - pd.date_range(min(self._dates), max(self._dates))] + dates_all = [ + pd.to_datetime(date).date() + for date in pd.date_range(min(self._dates), max(self._dates)) + ] # Weirdly, we were getting future warnings for timestamps, but unsure # where from invalid_dates = [pd.to_datetime(d).date() for d in self._invalid_dates] - missing_dates = [date for date in dates_all - if date not in dates_obs 
- or date in invalid_dates] + missing_dates = [ + date for date in dates_all + if date not in dates_obs or date in invalid_dates + ] logging.info("Processing {} missing dates".format(len(missing_dates))) - missing_dates_path = os.path.join( - self.get_data_var_folder("siconca"), "missing_days.csv") + missing_dates_path = os.path.join(self.get_data_var_folder("siconca"), + "missing_days.csv") with open(missing_dates_path, "a") as fh: for date in missing_dates: # FIXME: slightly unusual format for Ymd dates fh.write(date.strftime("%Y,%m,%d\n")) - logging.debug("Interpolating {} missing dates". - format(len(missing_dates))) + logging.debug("Interpolating {} missing dates".format( + len(missing_dates))) for date in missing_dates: if pd.Timestamp(date) not in da.time.values: logging.info("Interpolating {}".format(date)) - da = xr.concat([da, - da.interp(time=pd.to_datetime(date))], + da = xr.concat([da, da.interp(time=pd.to_datetime(date))], dim='time') logging.debug("Finished interpolation") @@ -603,8 +683,8 @@ def _get_missing_coordinates(self, var, hs, coord): :param hs: :param coord: """ - missing_coord_file = os.path.join( - self.get_data_var_folder(var), "missing_coord_data.nc") + missing_coord_file = os.path.join(self.get_data_var_folder(var), + "missing_coord_data.nc") if not os.path.exists(missing_coord_file): ftp_source_path = self._ftp_osi450.format(2000, 1) @@ -615,11 +695,13 @@ def _get_missing_coordinates(self, var, hs, coord): filename_osi450 = \ "ice_conc_{}_ease2-250_cdr-v2p0_200001011200.nc".format(hs) - run_command(retrieve_cmd_template_osi450.format( - missing_coord_file, ftp_source_path, filename_osi450)) + run_command( + retrieve_cmd_template_osi450.format(missing_coord_file, + ftp_source_path, + filename_osi450)) else: - logging.info("Coordinate path {} already exists". - format(missing_coord_file)) + logging.info( + "Coordinate path {} already exists".format(missing_coord_file)) ds = xr.open_dataset(missing_coord_file, drop_variables=var_remove_list, @@ -627,8 +709,9 @@ def _get_missing_coordinates(self, var, hs, coord): try: coord_data = getattr(ds, coord) except AttributeError as e: - logging.exception("{} does not exist in coord reference file {}". 
- format(coord, missing_coord_file)) + logging.exception( + "{} does not exist in coord reference file {}".format( + coord, missing_coord_file)) raise RuntimeError(e) return coord_data @@ -636,22 +719,22 @@ def _get_missing_coordinates(self, var, hs, coord): def main(): args = download_args(var_specs=False, workers=True, - extra_args=[ - (("-u", "--use-dask"), - dict(action="store_true", default=False)), - (("-c", "--sic-chunking-size"), - dict(type=int, default=10)), - (("-dt", "--dask-timeouts"), - dict(type=int, default=120)), - (("-dp", "--dask-port"), - dict(type=int, default=8888)) - ]) + extra_args=[(("-u", "--use-dask"), + dict(action="store_true", default=False)), + (("-c", "--sic-chunking-size"), + dict(type=int, default=10)), + (("-dt", "--dask-timeouts"), + dict(type=int, default=120)), + (("-dp", "--dask-port"), + dict(type=int, default=8888))]) logging.info("OSASIF-SIC Data Downloading") sic = SICDownloader( chunk_size=args.sic_chunking_size, - dates=[pd.to_datetime(date).date() for date in - pd.date_range(args.start_date, args.end_date, freq="D")], + dates=[ + pd.to_datetime(date).date() + for date in pd.date_range(args.start_date, args.end_date, freq="D") + ], delete_tempfiles=args.delete, north=args.hemisphere == "north", south=args.hemisphere == "south", diff --git a/icenet/data/sic/utils.py b/icenet/data/sic/utils.py index d1da1acb..ec75ec3f 100644 --- a/icenet/data/sic/utils.py +++ b/icenet/data/sic/utils.py @@ -2,7 +2,4 @@ """ -SIC_HEMI_STR = dict( - north="nh", - south="sh" -) +SIC_HEMI_STR = dict(north="nh", south="sh") diff --git a/icenet/data/utils.py b/icenet/data/utils.py index 3fe356ef..fd60ea5e 100644 --- a/icenet/data/utils.py +++ b/icenet/data/utils.py @@ -28,9 +28,7 @@ def assign_lat_lon_coord_system(cube: object): return cube -def rotate_grid_vectors(u_cube: object, - v_cube: object, - angles: object): +def rotate_grid_vectors(u_cube: object, v_cube: object, angles: object): """ Author: Tony Phillips (BAS) @@ -49,10 +47,14 @@ def rotate_grid_vectors(u_cube: object, v_r_all = iris.cube.CubeList() # get the X and Y dimension coordinates for each source cube - u_xy_coords = [u_cube.coord(axis='x', dim_coords=True), - u_cube.coord(axis='y', dim_coords=True)] - v_xy_coords = [v_cube.coord(axis='x', dim_coords=True), - v_cube.coord(axis='y', dim_coords=True)] + u_xy_coords = [ + u_cube.coord(axis='x', dim_coords=True), + u_cube.coord(axis='y', dim_coords=True) + ] + v_xy_coords = [ + v_cube.coord(axis='x', dim_coords=True), + v_cube.coord(axis='y', dim_coords=True) + ] # iterate over X, Y slices of the source cubes, rotating each in turn for u, v in zip(u_cube.slices(u_xy_coords, ordered=False), diff --git a/icenet/model/callbacks.py b/icenet/model/callbacks.py index 7d65ee75..2020f4d8 100644 --- a/icenet/model/callbacks.py +++ b/icenet/model/callbacks.py @@ -36,9 +36,7 @@ def __init__(self, self.val_dataloader = val_dataloader self.sample_at_zero = sample_at_zero - def on_train_batch_end(self, - batch: object, - logs: object = None): + def on_train_batch_end(self, batch: object, logs: object = None): """ :param batch: @@ -89,9 +87,7 @@ def __init__(self, elif self.mode == 'min': self.best = np.Inf - def on_train_batch_end(self, - batch: object, - logs: object = None): + def on_train_batch_end(self, batch: object, logs: object = None): """ :param batch: @@ -110,11 +106,9 @@ def on_train_batch_end(self, if save: tf.print('\n{} improved from {:.3f} to {:.3f}. ' - 'Saving model to {}.\n'. 
- format(self.monitor, - self.best, - logs[self.monitor], - self.model_path)) + 'Saving model to {}.\n'.format(self.monitor, self.best, + logs[self.monitor], + self.model_path)) self.best = logs[self.monitor] diff --git a/icenet/model/losses.py b/icenet/model/losses.py index 8eb3e3aa..91c6049c 100644 --- a/icenet/model/losses.py +++ b/icenet/model/losses.py @@ -7,8 +7,7 @@ class WeightedMSE(tf.keras.losses.MeanSquaredError): :param name: """ - def __init__(self, - name: str = 'mse', **kwargs): + def __init__(self, name: str = 'mse', **kwargs): super().__init__(name=name, **kwargs) def __call__(self, @@ -29,4 +28,6 @@ def __call__(self, # if sample_weight is not None: # sample_weight = tf.expand_dims(sample_weight, axis=-1) - return super().__call__(100*y_true, 100*y_pred, sample_weight=sample_weight) + return super().__call__(100 * y_true, + 100 * y_pred, + sample_weight=sample_weight) diff --git a/icenet/model/metrics.py b/icenet/model/metrics.py index db58f508..2b71d0ec 100644 --- a/icenet/model/metrics.py +++ b/icenet/model/metrics.py @@ -1,5 +1,4 @@ import tensorflow as tf - """ TensorFlow metrics. """ @@ -46,15 +45,16 @@ def update_state(self, if sample_weight is not None: sample_weight = tf.transpose(sample_weight, [0, 1, 2, 4, 3]) - super().update_state( - y_true, y_pred, sample_weight=sample_weight) + super().update_state(y_true, y_pred, sample_weight=sample_weight) elif not self.use_all_forecast_months: super().update_state( y_true[..., self.single_forecast_leadtime_idx], y_pred[..., self.single_forecast_leadtime_idx], - sample_weight=sample_weight[..., self.single_forecast_leadtime_idx] > 0) + sample_weight=sample_weight[..., + self.single_forecast_leadtime_idx] + > 0) def result(self): """ @@ -88,13 +88,12 @@ class WeightedBinaryAccuracy(tf.keras.metrics.BinaryAccuracy): :param leadtime_idx: """ - def __init__(self, - leadtime_idx=None, **kwargs): + def __init__(self, leadtime_idx=None, **kwargs): name = 'binacc' # Leadtime to compute metric over - leave as None to use all lead times if leadtime_idx is not None: - name += str(leadtime_idx+1) + name += str(leadtime_idx + 1) self._leadtime_idx = leadtime_idx super().__init__(name=name, **kwargs) @@ -162,7 +161,7 @@ def __init__(self, **kwargs): # Leadtime to compute metric over - leave as None to use all lead times if leadtime_idx is not None: - name += str(leadtime_idx+1) + name += str(leadtime_idx + 1) self._leadtime_idx = leadtime_idx super().__init__(name=name, **kwargs) @@ -217,7 +216,7 @@ def __init__(self, **kwargs): # Leadtime to compute metric over - leave as None to use all lead times if leadtime_idx is not None: - name += str(leadtime_idx+1) + name += str(leadtime_idx + 1) self._leadtime_idx = leadtime_idx super().__init__(name=name, **kwargs) @@ -265,13 +264,11 @@ class WeightedMSE(tf.keras.metrics.MeanSquaredError): :param name: """ - def __init__(self, - leadtime_idx: object = None, - **kwargs): + def __init__(self, leadtime_idx: object = None, **kwargs): name = 'mse' # Leadtime to compute metric over - leave as None to use all lead times if leadtime_idx is not None: - name += str(leadtime_idx+1) + name += str(leadtime_idx + 1) self._leadtime_idx = leadtime_idx super().__init__(name=name, **kwargs) diff --git a/icenet/model/models.py b/icenet/model/models.py index 993c3aa3..a8e9fa63 100644 --- a/icenet/model/models.py +++ b/icenet/model/models.py @@ -4,7 +4,6 @@ from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, \ concatenate, MaxPooling2D, Input from tensorflow.keras.optimizers import Adam 
- """ Defines the Python-based sea ice forecasting models, such as the IceNet architecture and the linear trend extrapolation model. @@ -18,10 +17,13 @@ class TemperatureScale(tf.keras.layers.Layer): Implements the temperature scaling layer for probability calibration, as introduced in Guo 2017 (http://proceedings.mlr.press/v70/guo17a.html). """ + def __init__(self, **kwargs): super(TemperatureScale, self).__init__(**kwargs) - self.temp = tf.Variable(initial_value=1.0, trainable=False, - dtype=tf.float32, name='temp') + self.temp = tf.Variable(initial_value=1.0, + trainable=False, + dtype=tf.float32, + name='temp') def call(self, inputs: object, **kwargs): """ Divide the input logits by the T value. @@ -43,6 +45,7 @@ def get_config(self): ### Network architectures: # -------------------------------------------------------------------- + def unet_batchnorm(input_shape: object, loss: object, metrics: object, @@ -63,86 +66,155 @@ def unet_batchnorm(input_shape: object, """ inputs = Input(shape=input_shape) - conv1 = Conv2D(int(64*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(inputs) - conv1 = Conv2D(int(64*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv1) + conv1 = Conv2D(int(64 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(inputs) + conv1 = Conv2D(int(64 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv1) bn1 = BatchNormalization(axis=-1)(conv1) pool1 = MaxPooling2D(pool_size=(2, 2))(bn1) - conv2 = Conv2D(int(128*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(pool1) - conv2 = Conv2D(int(128*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv2) + conv2 = Conv2D(int(128 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool1) + conv2 = Conv2D(int(128 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv2) bn2 = BatchNormalization(axis=-1)(conv2) pool2 = MaxPooling2D(pool_size=(2, 2))(bn2) - conv3 = Conv2D(int(256*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(pool2) - conv3 = Conv2D(int(256*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv3) + conv3 = Conv2D(int(256 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool2) + conv3 = Conv2D(int(256 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv3) bn3 = BatchNormalization(axis=-1)(conv3) pool3 = MaxPooling2D(pool_size=(2, 2))(bn3) - conv4 = Conv2D(int(256*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(pool3) - conv4 = Conv2D(int(256*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv4) + conv4 = Conv2D(int(256 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool3) + conv4 = Conv2D(int(256 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv4) bn4 = BatchNormalization(axis=-1)(conv4) pool4 = MaxPooling2D(pool_size=(2, 2))(bn4) - 
conv5 = Conv2D(int(512*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(pool4) - conv5 = Conv2D(int(512*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv5) + conv5 = Conv2D(int(512 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(pool4) + conv5 = Conv2D(int(512 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv5) bn5 = BatchNormalization(axis=-1)(conv5) - up6 = Conv2D(int(256*n_filters_factor), 2, activation='relu', - padding='same', kernel_initializer='he_normal')( - UpSampling2D(size=(2, 2), interpolation='nearest')(bn5)) + up6 = Conv2D(int(256 * n_filters_factor), + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(UpSampling2D( + size=(2, 2), interpolation='nearest')(bn5)) merge6 = concatenate([bn4, up6], axis=3) - conv6 = Conv2D(int(256*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(merge6) - conv6 = Conv2D(int(256*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv6) + conv6 = Conv2D(int(256 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge6) + conv6 = Conv2D(int(256 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv6) bn6 = BatchNormalization(axis=-1)(conv6) - up7 = Conv2D(int(256*n_filters_factor), 2, activation='relu', - padding='same', kernel_initializer='he_normal')( - UpSampling2D(size=(2, 2), interpolation='nearest')(bn6)) + up7 = Conv2D(int(256 * n_filters_factor), + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(UpSampling2D( + size=(2, 2), interpolation='nearest')(bn6)) merge7 = concatenate([bn3, up7], axis=3) - conv7 = Conv2D(int(256*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(merge7) - conv7 = Conv2D(int(256*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv7) + conv7 = Conv2D(int(256 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge7) + conv7 = Conv2D(int(256 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv7) bn7 = BatchNormalization(axis=-1)(conv7) - up8 = Conv2D(int(128*n_filters_factor), 2, activation='relu', - padding='same', kernel_initializer='he_normal')( - UpSampling2D(size=(2, 2), interpolation='nearest')(bn7)) + up8 = Conv2D(int(128 * n_filters_factor), + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(UpSampling2D( + size=(2, 2), interpolation='nearest')(bn7)) merge8 = concatenate([bn2, up8], axis=3) - conv8 = Conv2D(int(128*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(merge8) - conv8 = Conv2D(int(128*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv8) + conv8 = Conv2D(int(128 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge8) + conv8 = Conv2D(int(128 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv8) bn8 = BatchNormalization(axis=-1)(conv8) - 
up9 = Conv2D(int(64*n_filters_factor), 2, activation='relu', - padding='same', kernel_initializer='he_normal')( - UpSampling2D(size=(2, 2), interpolation='nearest')(bn8)) + up9 = Conv2D(int(64 * n_filters_factor), + 2, + activation='relu', + padding='same', + kernel_initializer='he_normal')(UpSampling2D( + size=(2, 2), interpolation='nearest')(bn8)) merge9 = concatenate([conv1, up9], axis=3) - conv9 = Conv2D(int(64*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(merge9) - conv9 = Conv2D(int(64*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv9) - conv9 = Conv2D(int(64*n_filters_factor), filter_size, activation='relu', - padding='same', kernel_initializer='he_normal')(conv9) - - final_layer = Conv2D(n_forecast_days, - kernel_size=1, activation='sigmoid')(conv9) + conv9 = Conv2D(int(64 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(merge9) + conv9 = Conv2D(int(64 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv9) + conv9 = Conv2D(int(64 * n_filters_factor), + filter_size, + activation='relu', + padding='same', + kernel_initializer='he_normal')(conv9) + + final_layer = Conv2D(n_forecast_days, kernel_size=1, + activation='sigmoid')(conv9) # Keras graph mode needs y_pred and y_true to have the same shape, so we # we must pad an extra dimension onto the model output to train with @@ -151,20 +223,20 @@ def unet_batchnorm(input_shape: object, model = Model(inputs, final_layer) - model.compile( - optimizer=Adam(learning_rate=learning_rate), - loss=loss, - weighted_metrics=metrics) + model.compile(optimizer=Adam(learning_rate=learning_rate), + loss=loss, + weighted_metrics=metrics) return model -def linear_trend_forecast(usable_selector: object, - forecast_date: object, - da: object, - mask: object, - missing_dates: object = (), - shape: object = (432, 432)) -> object: +def linear_trend_forecast( + usable_selector: object, + forecast_date: object, + da: object, + mask: object, + missing_dates: object = (), + shape: object = (432, 432)) -> object: """ :param usable_selector: @@ -185,8 +257,8 @@ def linear_trend_forecast(usable_selector: object, src = np.c_[x, np.ones_like(x)] r = np.linalg.lstsq(src, y, rcond=None)[0] - output_map = np.matmul( - np.array([len(usable_data.time), 1]), r).reshape(*shape) + output_map = np.matmul(np.array([len(usable_data.time), 1]), + r).reshape(*shape) output_map[mask] = 0. output_map[output_map < 0] = 0. output_map[output_map > 1] = 1. 
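
For reference, the `linear_trend_forecast` routine reformatted above fits a per-grid-cell least-squares line through the usable SIC fields and evaluates it one step beyond the last time index, zeroing the masked cells and clamping the result to the [0, 1] SIC range. A minimal standalone sketch of that least-squares extrapolation, using synthetic data and illustrative variable names rather than the IceNet data structures, might look like:

import numpy as np

# Synthetic stack of daily SIC fields: (n_days, height, width), values in [0, 1].
n_days, shape = 5, (4, 4)
fields = np.clip(np.random.rand(n_days, *shape).cumsum(axis=0) / n_days, 0.0, 1.0)

# Flatten the spatial dims so each column is one grid cell's time series.
y = fields.reshape(n_days, -1)
x = np.arange(n_days)
src = np.c_[x, np.ones_like(x)]             # design matrix [t, 1]

# Per-cell slope and intercept in one call, then evaluate at t = n_days
# (one step beyond the last observed index).
r = np.linalg.lstsq(src, y, rcond=None)[0]  # shape (2, n_cells)
forecast = (np.array([n_days, 1]) @ r).reshape(*shape)
forecast = np.clip(forecast, 0.0, 1.0)      # keep SIC within physical bounds
print(forecast)

Using a [t, 1] design matrix with np.linalg.lstsq yields slope and intercept for every cell in a single vectorised solve, which is the same trick the reformatted function relies on.
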
diff --git a/icenet/model/predict.py b/icenet/model/predict.py index a531305c..f76264da 100644 --- a/icenet/model/predict.py +++ b/icenet/model/predict.py @@ -13,7 +13,6 @@ from icenet.data.loader import save_sample from icenet.data.dataset import IceNetDataSet from icenet.utils import setup_logging - """ """ @@ -55,20 +54,15 @@ def predict_forecast( network_folder = os.path.join(".", "results", "networks", network_name) dataset_name = dataset_name if dataset_name else ds.identifier - network_path = os.path.join(network_folder, - "{}.network_{}.{}.h5".format(network_name, - dataset_name, - seed)) + network_path = os.path.join( + network_folder, "{}.network_{}.{}.h5".format(network_name, dataset_name, + seed)) logging.info("Loading model from {}...".format(network_path)) - network = model_func( - (*ds.shape, dl.num_channels), - [], - [], - n_filters_factor=n_filters_factor, - n_forecast_days=ds.n_forecast_days - ) + network = model_func((*ds.shape, dl.num_channels), [], [], + n_filters_factor=n_filters_factor, + n_forecast_days=ds.n_forecast_days) network.load_weights(network_path) if not test_set: @@ -89,8 +83,11 @@ def predict_forecast( source_key = [k for k in dl.config['sources'].keys() if k != "meta"][0] # FIXME: should be using date format from class - test_dates = [dt.date(*[int(v) for v in d.split("_")]) for d in - dl.config["sources"][source_key]["dates"]["test"]] + test_dates = [ + dt.date(*[int(v) + for v in d.split("_")]) + for d in dl.config["sources"][source_key]["dates"]["test"] + ] if len(test_dates) == 0: raise RuntimeError("No processed files were produced for the test " @@ -98,9 +95,8 @@ def predict_forecast( missing = set(start_dates).difference(test_dates) if len(missing) > 0: - raise RuntimeError("{} are not in the test set". - format(", ".join([str(pd.to_datetime(el).date()) - for el in missing]))) + raise RuntimeError("{} are not in the test set".format(", ".join( + [str(pd.to_datetime(el).date()) for el in missing]))) data_iter = test_inputs.as_numpy_iterator() # FIXME: this is broken, this entry never gets added to the set? @@ -120,17 +116,13 @@ def predict_forecast( run_prediction(network=network, date=test_dates[idx], output_folder=output_folder, - data_sample=(x[arr_idx, ...], - y[arr_idx, ...], - sw[arr_idx, ...]), + data_sample=(x[arr_idx, ...], y[arr_idx, + ...], sw[arr_idx, + ...]), save_args=save_args) -def run_prediction(network, - date, - output_folder, - data_sample, - save_args): +def run_prediction(network, date, output_folder, data_sample, save_args): net_input, net_output, sample_weights = data_sample logging.info("Running prediction {}".format(date)) @@ -174,8 +166,12 @@ def get_args(): ap.add_argument("seed", type=int, default=42) ap.add_argument("datefile", type=argparse.FileType("r")) - ap.add_argument("-i", "--train-identifier", dest="ident", - help="Train dataset identifier", type=str, default=None) + ap.add_argument("-i", + "--train-identifier", + dest="ident", + help="Train dataset identifier", + type=str, + default=None) ap.add_argument("-n", "--n-filters-factor", type=float, default=1.) 
ap.add_argument("-t", "--testset", action="store_true", default=False) ap.add_argument("-v", "--verbose", action="store_true", default=False) @@ -191,23 +187,24 @@ def main(): os.path.join(".", "dataset_config.{}.json".format(args.dataset)) date_content = args.datefile.read() - dates = [dt.date(*[int(v) for v in s.split("-")]) - for s in date_content.split()] + dates = [ + dt.date(*[int(v) for v in s.split("-")]) for s in date_content.split() + ] args.datefile.close() - output_folder = os.path.join(".", "results", "predict", - args.output_name, + output_folder = os.path.join(".", "results", "predict", args.output_name, "{}.{}".format(args.network_name, args.seed)) - predict_forecast(dataset_config, - args.network_name, - # FIXME: this is turning into a mapping mess, - # do we need to retain the train SD name in the - # network? - dataset_name=args.ident if args.ident else args.dataset, - n_filters_factor=args.n_filters_factor, - output_folder=output_folder, - save_args=args.save_args, - seed=args.seed, - start_dates=dates, - test_set=args.testset) + predict_forecast( + dataset_config, + args.network_name, + # FIXME: this is turning into a mapping mess, + # do we need to retain the train SD name in the + # network? + dataset_name=args.ident if args.ident else args.dataset, + n_filters_factor=args.n_filters_factor, + output_folder=output_folder, + save_args=args.save_args, + seed=args.seed, + start_dates=dates, + test_set=args.testset) diff --git a/icenet/model/train.py b/icenet/model/train.py index cc3236d2..7a543522 100644 --- a/icenet/model/train.py +++ b/icenet/model/train.py @@ -30,34 +30,33 @@ pass -def train_model( - run_name: object, - dataset: object, - callback_objects: list = [], - checkpoint_monitor: str = 'val_rmse', - checkpoint_mode: str = 'min', - dataset_ratio: float = 1.0, - early_stopping_patience: int = 30, - epochs: int = 2, - filter_size: float = 3, - learning_rate: float = 1e-4, - lr_10e_decay_fac: float = 1.0, - lr_decay_start: float = 10, - lr_decay_end: float = 30, - max_queue_size: int = 3, - model_func: object = models.unet_batchnorm, - n_filters_factor: float = 2, - network_folder: object = None, - network_save: bool = True, - pickup_weights: bool = False, - pre_load_network: bool = False, - pre_load_path: object = None, - seed: int = 42, - strategy: object = tf.distribute.get_strategy(), - training_verbosity: int = 1, - workers: int = 5, - use_multiprocessing: bool = True, - use_tensorboard: bool = True) -> object: +def train_model(run_name: object, + dataset: object, + callback_objects: list = [], + checkpoint_monitor: str = 'val_rmse', + checkpoint_mode: str = 'min', + dataset_ratio: float = 1.0, + early_stopping_patience: int = 30, + epochs: int = 2, + filter_size: float = 3, + learning_rate: float = 1e-4, + lr_10e_decay_fac: float = 1.0, + lr_decay_start: float = 10, + lr_decay_end: float = 30, + max_queue_size: int = 3, + model_func: object = models.unet_batchnorm, + n_filters_factor: float = 2, + network_folder: object = None, + network_save: bool = True, + pickup_weights: bool = False, + pre_load_network: bool = False, + pre_load_path: object = None, + seed: int = 42, + strategy: object = tf.distribute.get_strategy(), + training_verbosity: int = 1, + workers: int = 5, + use_multiprocessing: bool = True, + use_tensorboard: bool = True) -> object: """ :param run_name: @@ -105,14 +104,12 @@ def train_model( logging.info("Creating network folder: {}".format(network_folder)) os.makedirs(network_folder, exist_ok=True) - weights_path = 
os.path.join(network_folder, - "{}.network_{}.{}.h5".format(run_name, - dataset.identifier, - seed)) - model_path = os.path.join(network_folder, - "{}.model_{}.{}".format(run_name, - dataset.identifier, - seed)) + weights_path = os.path.join( + network_folder, "{}.network_{}.{}.h5".format(run_name, + dataset.identifier, seed)) + model_path = os.path.join( + network_folder, "{}.model_{}.{}".format(run_name, dataset.identifier, + seed)) history_path = os.path.join(network_folder, "{}_{}_history.json".format(run_name, seed)) @@ -122,23 +119,19 @@ def train_model( # Checkpoint the model weights when a validation metric is improved callbacks_list.append( - ModelCheckpoint( - filepath=weights_path, - monitor=checkpoint_monitor, - verbose=1, - mode=checkpoint_mode, - save_best_only=True - )) + ModelCheckpoint(filepath=weights_path, + monitor=checkpoint_monitor, + verbose=1, + mode=checkpoint_mode, + save_best_only=True)) # Abort training when validation performance stops improving callbacks_list.append( - EarlyStopping( - monitor=checkpoint_monitor, - mode=checkpoint_mode, - verbose=1, - patience=early_stopping_patience, - baseline=prev_best - )) + EarlyStopping(monitor=checkpoint_monitor, + mode=checkpoint_mode, + verbose=1, + patience=early_stopping_patience, + baseline=prev_best)) callbacks_list.append( LearningRateScheduler( @@ -151,8 +144,8 @@ def train_model( if use_tensorboard: logging.info("Adding tensorboard callback") log_dir = "logs/" + dt.datetime.now().strftime("%d-%m-%y-%H%M%S") - callbacks_list.append(tf.keras.callbacks.TensorBoard(log_dir=log_dir, - histogram_freq=1)) + callbacks_list.append( + tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)) ############################################################################ # TRAINING MODEL @@ -182,8 +175,8 @@ def train_model( logging.info("Loading network weights from {}".format(pre_load_path)) network.load_weights(pre_load_path) elif pickup_weights and os.path.exists(weights_path): - logging.warning("Automagically loading network weights from {}". - format(weights_path)) + logging.warning("Automagically loading network weights from {}".format( + weights_path)) network.load_weights(weights_path) network.summary() @@ -200,8 +193,7 @@ def train_model( max_queue_size=max_queue_size, # not useful for tf.data usage according to docs, but useful in dev workers=workers, - use_multiprocessing=use_multiprocessing - ) + use_multiprocessing=use_multiprocessing) if network_save: logging.info("Saving network to: {}".format(weights_path)) @@ -243,17 +235,17 @@ def evaluate_model(model_path: object, "than test set") lead_times = list(range(1, dataset.n_forecast_days + 1)) - logging.info("Metric creation for lead time of {} days". 
- format(len(lead_times))) + logging.info("Metric creation for lead time of {} days".format( + len(lead_times))) metric_names = ["binacc", "mae", "rmse"] metrics_classes = [ metrics.WeightedBinaryAccuracy, metrics.WeightedMAE, metrics.WeightedRMSE, ] - metrics_list = [cls(leadtime_idx=lt - 1) - for lt in lead_times - for cls in metrics_classes] + metrics_list = [ + cls(leadtime_idx=lt - 1) for lt in lead_times for cls in metrics_classes + ] network.compile(weighted_metrics=metrics_list) @@ -292,40 +284,60 @@ def get_args(): ap.add_argument("-b", "--batch-size", type=int, default=4) ap.add_argument("-ca", "--checkpoint-mode", default="min", type=str) ap.add_argument("-cm", "--checkpoint-monitor", default="val_rmse", type=str) - ap.add_argument("-ds", "--additional-dataset", - dest="additional", nargs="*", default=[]) + ap.add_argument("-ds", + "--additional-dataset", + dest="additional", + nargs="*", + default=[]) ap.add_argument("-e", "--epochs", type=int, default=4) ap.add_argument("-f", "--filter-size", type=int, default=3) ap.add_argument("--early-stopping", type=int, default=50) - ap.add_argument("-m", "--multiprocessing", - action="store_true", default=False) + ap.add_argument("-m", + "--multiprocessing", + action="store_true", + default=False) ap.add_argument("-n", "--n-filters-factor", type=float, default=1.) ap.add_argument("-p", "--preload", type=str) - ap.add_argument("-pw", "--pickup-weights", - action="store_true", default=False) + ap.add_argument("-pw", + "--pickup-weights", + action="store_true", + default=False) ap.add_argument("-qs", "--max-queue-size", default=10, type=int) ap.add_argument("-r", "--ratio", default=1.0, type=float) - ap.add_argument("-s", "--strategy", default="default", + ap.add_argument("-s", + "--strategy", + default="default", choices=("default", "mirrored", "central")) - ap.add_argument("--shuffle-train", default=False, - action="store_true", help="Shuffle the training set") + ap.add_argument("--shuffle-train", + default=False, + action="store_true", + help="Shuffle the training set") ap.add_argument("--gpus", default=None) ap.add_argument("-v", "--verbose", action="store_true", default=False) ap.add_argument("-w", "--workers", type=int, default=4) # WandB additional arguments ap.add_argument("-nw", "--no-wandb", default=False, action="store_true") - ap.add_argument("-wo", "--wandb-offline", default=False, action="store_true") - ap.add_argument("-wp", "--wandb-project", - default=os.environ.get("ICENET_ENVIRONMENT"), type=str) - ap.add_argument("-wu", "--wandb-user", - default=os.environ.get("USER"), type=str) + ap.add_argument("-wo", + "--wandb-offline", + default=False, + action="store_true") + ap.add_argument("-wp", + "--wandb-project", + default=os.environ.get("ICENET_ENVIRONMENT"), + type=str) + ap.add_argument("-wu", + "--wandb-user", + default=os.environ.get("USER"), + type=str) ap.add_argument("--lr", default=1e-4, type=float) - ap.add_argument("--lr_10e_decay_fac", default=1.0, type=float, + ap.add_argument("--lr_10e_decay_fac", + default=1.0, + type=float, help="Factor by which LR is multiplied by every 10 epochs " - "using exponential decay. E.g. 1 -> no decay (default)" - ", 0.5 -> halve every 10 epochs.") + "using exponential decay. E.g. 
1 -> no decay (default)" + ", 0.5 -> halve every 10 epochs.") ap.add_argument('--lr_decay_start', default=10, type=int) ap.add_argument('--lr_decay_end', default=30, type=int) @@ -335,8 +347,9 @@ def get_args(): def main(): args = get_args() - logging.warning("Setting seed for best attempt at determinism, value {}". - format(args.seed)) + logging.warning( + "Setting seed for best attempt at determinism, value {}".format( + args.seed)) # determinism is not guaranteed across different versions of TensorFlow. # determinism is not guaranteed across different hardware. os.environ['PYTHONHASHSEED'] = str(args.seed) @@ -355,12 +368,11 @@ def main(): shuffling=args.shuffle_train) else: dataset = MergedIceNetDataSet([ - "dataset_config.{}.json".format(el) for el in [ - args.dataset, *args.additional - ] + "dataset_config.{}.json".format(el) + for el in [args.dataset, *args.additional] ], - batch_size=args.batch_size, - shuffling=args.shuffle_train) + batch_size=args.batch_size, + shuffling=args.shuffle_train) strategy = tf.distribute.MirroredStrategy() \ if args.strategy == "mirrored" \ @@ -380,11 +392,11 @@ def main(): run = wandb.init( project=args.wandb_project, name="{}.{}".format(args.run_name, args.seed), - notes="{}: run at {}{}".format(args.run_name, - dt.datetime.now().strftime("%D %T"), - "" if - not args.preload is not None else - " preload {}".format(args.preload)), + notes="{}: run at {}{}".format( + args.run_name, + dt.datetime.now().strftime("%D %T"), + "" if not args.preload is not None else " preload {}".format( + args.preload)), entity=args.wandb_user, config=dict( seed=args.seed, @@ -397,8 +409,8 @@ def main(): batch_size=args.batch_size, ), settings=wandb.Settings( - # start_method="fork", - # _disable_stats=True, + # start_method="fork", + # _disable_stats=True, ), allow_val_change=True, mode='offline' if args.wandb_offline else 'online', @@ -453,10 +465,12 @@ def main(): if using_wandb: logging.info("Updating wandb run with evaluation metrics") - metric_vals = [[results[f'{name}{lt}'] - for lt in leads] for name in metric_names] + metric_vals = [ + [results[f'{name}{lt}'] for lt in leads] for name in metric_names + ] table_data = list(zip(leads, *metric_vals)) - table = wandb.Table(data=table_data, columns=['leadtime', *metric_names]) + table = wandb.Table(data=table_data, + columns=['leadtime', *metric_names]) # Log each metric vs. 
leadtime as a plot to wandb for name in metric_names: diff --git a/icenet/model/utils.py b/icenet/model/utils.py index 5d0103aa..49bf916f 100644 --- a/icenet/model/utils.py +++ b/icenet/model/utils.py @@ -55,8 +55,10 @@ def compute_heatmap(results_df: object, :return: """ - month_names = np.array(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sept', 'Oct', 'Nov', 'Dec']) + month_names = np.array([ + 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sept', 'Oct', + 'Nov', 'Dec' + ]) # Mean over calendar month mean_df = results_df.loc[model, seed].reset_index().\ @@ -69,11 +71,8 @@ def compute_heatmap(results_df: object, return heatmap_df -def arr_to_ice_edge_arr(arr: object, - thresh: object, - land_mask: object, +def arr_to_ice_edge_arr(arr: object, thresh: object, land_mask: object, region_mask: object) -> object: - """ Compute a boolean mask with True over ice edge contour grid cells using matplotlib.pyplot.contour and an input threshold to define the ice edge @@ -108,11 +107,8 @@ def arr_to_ice_edge_arr(arr: object, return ice_edge_arr -def arr_to_ice_edge_rgba_arr(arr: object, - thresh: object, - land_mask: object, - region_mask: object, - rgb: object) -> object: +def arr_to_ice_edge_rgba_arr(arr: object, thresh: object, land_mask: object, + region_mask: object, rgb: object) -> object: """ :param arr: diff --git a/icenet/plotting/data.py b/icenet/plotting/data.py index ccbcd6c6..ec978426 100644 --- a/icenet/plotting/data.py +++ b/icenet/plotting/data.py @@ -41,8 +41,7 @@ def plot_tfrecord(): config = json.load(args.configuration) args.configuration.close() - decoder = get_decoder(tuple(config['shape']), - config['num_channels'], + decoder = get_decoder(tuple(config['shape']), config['num_channels'], config['n_forecast_days']) ds = ds.map(decoder).batch(1) @@ -58,12 +57,13 @@ def plot_tfrecord(): output_dir = os.path.join(args.output, "plot_set") os.makedirs(output_dir, exist_ok=True) - subprocess.run("rm -v {}/{}.*.png".format( - output_dir, config["identifier"]), shell=True) + subprocess.run("rm -v {}/{}.*.png".format(output_dir, config["identifier"]), + shell=True) for i, channel in enumerate(config['channels']): - output_path = os.path.join(output_dir, "{}.{:03d}_{}.png". - format(config["identifier"], i, channel)) + output_path = os.path.join( + output_dir, "{}.{:03d}_{}.png".format(config["identifier"], i, + channel)) logging.info("Producing {}".format(output_path)) fig, ax = plt.subplots() @@ -72,8 +72,8 @@ def plot_tfrecord(): plt.close() for i in range(config['n_forecast_days']): - output_path = os.path.join(output_dir, "{}.y.{:03d}.png". - format(config["identifier"], i + 1)) + output_path = os.path.join( + output_dir, "{}.y.{:03d}.png".format(config["identifier"], i + 1)) y_out = y[0, ..., i, 0] logging.info("Producing {}".format(output_path)) @@ -98,7 +98,10 @@ def get_sample_get_args(): ap.add_argument("date", type=date_arg) ap.add_argument("output_path", type=str, default="test.png") - ap.add_argument("-c", "--cols", type=int, default=8, + ap.add_argument("-c", + "--cols", + type=int, + default=8, help="Plotting data over this number of columns") data_type = ap.add_mutually_exclusive_group(required=False) @@ -128,22 +131,24 @@ def plot_sample_cli(): if args.weights: channel_data = net_weight.squeeze() - channel_labels = ["weights{}".format(i) - for i in range(channel_data.shape[-1])] - logging.info("Plotting {} weights from sample". 
- format(len(channel_labels))) + channel_labels = [ + "weights{}".format(i) for i in range(channel_data.shape[-1]) + ] + logging.info("Plotting {} weights from sample".format( + len(channel_labels))) elif args.outputs: channel_data = net_output.squeeze() - channel_labels = ["outputs{}".format(i) - for i in range(channel_data.shape[-1])] - logging.info("Plotting {} outputs from sample". - format(len(channel_labels))) + channel_labels = [ + "outputs{}".format(i) for i in range(channel_data.shape[-1]) + ] + logging.info("Plotting {} outputs from sample".format( + len(channel_labels))) else: logging.info("Plotting inputs from sample") channel_data = net_input channel_labels = dl.channel_names - logging.info("Plotting {} inputs from sample". - format(len(channel_labels))) + logging.info("Plotting {} inputs from sample".format( + len(channel_labels))) plot_channel_data(channel_data, channel_labels, @@ -168,10 +173,11 @@ def plot_channel_data(data: object, num_rows = int(len(var_names) / cols) + \ ceil(len(var_names) / cols - int(len(var_names) / cols)) - logging.debug("Plot Rows {} Cols {} Channels {}". - format(num_rows, cols, len(var_names))) + logging.debug("Plot Rows {} Cols {} Channels {}".format( + num_rows, cols, len(var_names))) fig = plt.figure(figsize=(cols * square_size, num_rows * square_size), - layout="tight", dpi=150) + layout="tight", + dpi=150) for i, var in enumerate(var_names): ax1 = fig.add_subplot(num_rows, cols, i + 1) diff --git a/icenet/plotting/forecast.py b/icenet/plotting/forecast.py index 534b25a7..c5f26888 100644 --- a/icenet/plotting/forecast.py +++ b/icenet/plotting/forecast.py @@ -23,17 +23,11 @@ from icenet import __version__ as icenet_version from icenet.data.cli import date_arg from icenet.data.sic.mask import Masks -from icenet.plotting.utils import ( - filter_ds_by_obs, - get_forecast_ds, - get_obs_da, - get_seas_forecast_da, - get_seas_forecast_init_dates, - show_img, - get_plot_axes, - process_probes, - process_regions -) +from icenet.plotting.utils import (filter_ds_by_obs, get_forecast_ds, + get_obs_da, get_seas_forecast_da, + get_seas_forecast_init_dates, show_img, + get_plot_axes, process_probes, + process_regions) from icenet.plotting.video import xarray_to_video @@ -49,8 +43,7 @@ def location_arg(argument: str): return (x, y) except ValueError: argparse.ArgumentTypeError( - "Expected a location (pair of integers separated by a comma)" - ) + "Expected a location (pair of integers separated by a comma)") def region_arg(argument: str): @@ -70,9 +63,7 @@ def region_arg(argument: str): "Region argument must be list of four integers") -def compute_binary_accuracy(masks: object, - fc_da: object, - obs_da: object, +def compute_binary_accuracy(masks: object, fc_da: object, obs_da: object, threshold: float) -> object: """ Compute the binary class accuracy of a forecast, @@ -143,7 +134,8 @@ def plot_binary_accuracy(masks: object, obs_da=obs_da, threshold=threshold) fig, ax = plt.subplots(figsize=(12, 6)) - ax.set_title(f"Binary accuracy comparison (threshold SIC = {threshold*100}%)") + ax.set_title( + f"Binary accuracy comparison (threshold SIC = {threshold*100}%)") ax.plot(binacc_fc.time, binacc_fc.values, label="IceNet") if cmp_da is not None: @@ -171,9 +163,7 @@ def plot_binary_accuracy(masks: object, return binacc_fc, binacc_cmp -def compute_sea_ice_extent_error(masks: object, - fc_da: object, - obs_da: object, +def compute_sea_ice_extent_error(masks: object, fc_da: object, obs_da: object, grid_area_size: int, threshold: float) -> object: """ @@ -209,10 
+199,9 @@ def compute_sea_ice_extent_error(masks: object, binary_fc_weighted_da = binary_fc_da.astype(int).weighted(agcm) # sie error - forecast_sie_error = ( - binary_fc_weighted_da.sum(['xc', 'yc']) - - binary_obs_weighted_da.sum(['xc', 'yc']) - ) * (grid_area_size**2) + forecast_sie_error = (binary_fc_weighted_da.sum(['xc', 'yc']) - + binary_obs_weighted_da.sum(['xc', 'yc'])) * ( + grid_area_size**2) return forecast_sie_error @@ -244,11 +233,12 @@ def plot_sea_ice_extent_error(masks: object, :return: tuple of (SIE error for forecast (fc_da), SIE error for comparison (cmp_da)) """ - forecast_sie_error = compute_sea_ice_extent_error(masks=masks, - fc_da=fc_da, - obs_da=obs_da, - grid_area_size=grid_area_size, - threshold=threshold) + forecast_sie_error = compute_sea_ice_extent_error( + masks=masks, + fc_da=fc_da, + obs_da=obs_da, + grid_area_size=grid_area_size, + threshold=threshold) fig, ax = plt.subplots(figsize=(12, 6)) ax.set_title(f"SIE error comparison ({grid_area_size} km grid resolution) " @@ -256,11 +246,12 @@ def plot_sea_ice_extent_error(masks: object, ax.plot(forecast_sie_error.time, forecast_sie_error.values, label="IceNet") if cmp_da is not None: - cmp_sie_error = compute_sea_ice_extent_error(masks=masks, - fc_da=cmp_da, - obs_da=obs_da, - grid_area_size=grid_area_size, - threshold=threshold) + cmp_sie_error = compute_sea_ice_extent_error( + masks=masks, + fc_da=cmp_da, + obs_da=obs_da, + grid_area_size=grid_area_size, + threshold=threshold) ax.plot(cmp_sie_error.time, cmp_sie_error.values, label="SEAS") else: cmp_sie_error = None @@ -281,9 +272,7 @@ def plot_sea_ice_extent_error(masks: object, return forecast_sie_error, cmp_sie_error -def compute_metrics(metrics: object, - masks: object, - fc_da: object, +def compute_metrics(metrics: object, masks: object, fc_da: object, obs_da: object) -> object: """ Computes metrics based on SIC error which are passed in as a list of strings. @@ -302,15 +291,16 @@ def compute_metrics(metrics: object, implemented_metrics = ["mae", "mse", "rmse"] for metric in metrics: if metric not in implemented_metrics: - raise NotImplementedError(f"{metric} metric has not been implemented. " - f"Please only choose out of {implemented_metrics}.") + raise NotImplementedError( + f"{metric} metric has not been implemented. 
" + f"Please only choose out of {implemented_metrics}.") # obtain mask mask_da = masks.get_active_cell_da(obs_da) metric_dict = {} # compute raw error - err_da = (fc_da-obs_da)*100 + err_da = (fc_da - obs_da) * 100 if "mae" in metrics: # compute absolute SIC errors abs_err_da = da.fabs(err_da) @@ -431,12 +421,9 @@ def plot_metrics(metrics: object, return fc_metric_dict, cmp_metric_dict -def compute_metric_as_dataframe(metric: object, - masks: object, - init_date: object, - fc_da: object, - obs_da: object, - **kwargs) -> pd.DataFrame: +def compute_metric_as_dataframe(metric: object, masks: object, + init_date: object, fc_da: object, + obs_da: object, **kwargs) -> pd.DataFrame: """ Computes a metric for each leadtime in a forecast and stores the results in a pandas dataframe with columns 'date' (which is the @@ -466,21 +453,26 @@ def compute_metric_as_dataframe(metric: object, obs_da=obs_da)[met].values elif met == "binacc": if "threshold" not in kwargs.keys(): - raise KeyError("if met = 'binacc', must pass in argument for threshold") - metric_dict[met] = compute_binary_accuracy(masks=masks, - fc_da=fc_da, - obs_da=obs_da, - threshold=kwargs["threshold"]).values + raise KeyError( + "if met = 'binacc', must pass in argument for threshold") + metric_dict[met] = compute_binary_accuracy( + masks=masks, + fc_da=fc_da, + obs_da=obs_da, + threshold=kwargs["threshold"]).values elif met == "sie": if "grid_area_size" not in kwargs.keys(): - raise KeyError("if met = 'sie', must pass in argument for grid_area_size") + raise KeyError( + "if met = 'sie', must pass in argument for grid_area_size") if "threshold" not in kwargs.keys(): - raise KeyError("if met = 'sie', must pass in argument for threshold") - metric_dict[met] = compute_sea_ice_extent_error(masks=masks, - fc_da=fc_da, - obs_da=obs_da, - grid_area_size=kwargs["grid_area_size"], - threshold=kwargs["threshold"]).values + raise KeyError( + "if met = 'sie', must pass in argument for threshold") + metric_dict[met] = compute_sea_ice_extent_error( + masks=masks, + fc_da=fc_da, + obs_da=obs_da, + grid_area_size=kwargs["grid_area_size"], + threshold=kwargs["threshold"]).values else: raise NotImplementedError(f"{met} is not implemented") @@ -498,23 +490,27 @@ def compute_metric_as_dataframe(metric: object, dayofyear = init_date.replace(year=2001).dayofyear month = init_date.month # get target dates - leadtime = list(range(1, len(metric_df.index)+1, 1)) + leadtime = list(range(1, len(metric_df.index) + 1, 1)) target_date = pd.Series([init_date + timedelta(days=d) for d in leadtime]) target_dayofyear = target_date.dt.dayofyear # obtain day of year using same method above to avoid any leap-year issues - target_dayofyear = pd.Series([59 if d.strftime("%m-%d")=="02-29" - else d.replace(year=2001).dayofyear - for d in target_date]) + target_dayofyear = pd.Series([ + 59 if d.strftime("%m-%d") == "02-29" else d.replace(year=2001).dayofyear + for d in target_date + ]) target_month = target_date.dt.month - return pd.concat([pd.DataFrame({"date": init_date, - "dayofyear": dayofyear, - "month": month, - "target_date": target_date, - "target_dayofyear": target_dayofyear, - "target_month": target_month, - "leadtime": leadtime}), - metric_df], - axis=1) + return pd.concat([ + pd.DataFrame({ + "date": init_date, + "dayofyear": dayofyear, + "month": month, + "target_date": target_date, + "target_dayofyear": target_dayofyear, + "target_month": target_month, + "leadtime": leadtime + }), metric_df + ], + axis=1) def compute_metrics_leadtime_avg(metric: str, @@ -558,7 +554,8 
@@ def compute_metrics_leadtime_avg(metric: str, if ecmwf: # find out what dates cross over with the SEAS5 predictions - (fc_start_date, fc_end_date) = (fc_ds.time.values.min(), fc_ds.time.values.max()) + (fc_start_date, fc_end_date) = (fc_ds.time.values.min(), + fc_ds.time.values.max()) dates = get_seas_forecast_init_dates(hemisphere) dates = dates[(dates > fc_start_date) & (dates <= fc_end_date)] times = [x for x in fc_ds.time.values if x in dates] @@ -575,7 +572,8 @@ def compute_metrics_leadtime_avg(metric: str, fc = fc_ds.sel(time=slice(time, time))["sic_mean"] obs = get_obs_da(hemisphere=hemisphere, start_date=pd.to_datetime(time) + timedelta(days=1), - end_date=pd.to_datetime(time) + timedelta(days=int(fc.leadtime.max()))) + end_date=pd.to_datetime(time) + + timedelta(days=int(fc.leadtime.max()))) fc = filter_ds_by_obs(fc, obs, time) if ecmwf: @@ -594,19 +592,21 @@ def compute_metrics_leadtime_avg(metric: str, [seas, fc, obs, masks]) # compute metrics - fc_metrics_list.append(compute_metric_as_dataframe(metric=metric, - masks=masks, - init_date=time, - fc_da=fc, - obs_da=obs, - **kwargs)) + fc_metrics_list.append( + compute_metric_as_dataframe(metric=metric, + masks=masks, + init_date=time, + fc_da=fc, + obs_da=obs, + **kwargs)) if seas is not None: - seas_metrics_list.append(compute_metric_as_dataframe(metric=metric, - masks=masks, - init_date=time, - fc_da=seas, - obs_da=obs, - **kwargs)) + seas_metrics_list.append( + compute_metric_as_dataframe(metric=metric, + masks=masks, + init_date=time, + fc_da=seas, + obs_da=obs, + **kwargs)) # groupby the leadtime and compute the mean average of the metric fc_metric_df = pd.concat(fc_metrics_list) @@ -622,7 +622,8 @@ def compute_metrics_leadtime_avg(metric: str, fc_metric_df.to_csv(data_path) except OSError: # don't break if not successful, still return dataframe - logging.info("Save not successful! Make sure the data_path directory exists") + logging.info( + "Save not successful! Make sure the data_path directory exists") return fc_metric_df.reset_index(drop=True) @@ -641,12 +642,15 @@ def _parse_day_of_year(dayofyear: int, leapyear: bool = False) -> int: :return: int dayofyear """ if leapyear: - return (pd.Timestamp("2000-01-01") + timedelta(days=int(dayofyear) - 1)).strftime("%m-%d") + return (pd.Timestamp("2000-01-01") + + timedelta(days=int(dayofyear) - 1)).strftime("%m-%d") else: - return (pd.Timestamp("2001-01-01") + timedelta(days=int(dayofyear) - 1)).strftime("%m-%d") + return (pd.Timestamp("2001-01-01") + + timedelta(days=int(dayofyear) - 1)).strftime("%m-%d") -def _heatmap_ylabels(metrics_df: pd.DataFrame, average_over: str, groupby_col: str) -> object: +def _heatmap_ylabels(metrics_df: pd.DataFrame, average_over: str, + groupby_col: str) -> object: """ Private function to return the labels for the y-axis in heatmap plots. 
@@ -668,19 +672,24 @@ def _heatmap_ylabels(metrics_df: pd.DataFrame, average_over: str, groupby_col: s if average_over == "day": # only add labels to the start, end dates # and any days that represent the start of months - days_of_interest = np.array([metrics_df[groupby_col].min(), - 1, 32, 60, 91, 121, 152, - 182, 213, 244, 274, 305, 335, - metrics_df[groupby_col].max()]) - labels = [_parse_day_of_year(day) - if day in days_of_interest else "" - for day in sorted(metrics_df[groupby_col].unique())] + days_of_interest = np.array([ + metrics_df[groupby_col].min(), 1, 32, 60, 91, 121, 152, 182, 213, + 244, 274, 305, 335, metrics_df[groupby_col].max() + ]) + labels = [ + _parse_day_of_year(day) if day in days_of_interest else "" + for day in sorted(metrics_df[groupby_col].unique()) + ] else: # find out what months have been plotted and add their names - month_names = np.array(["Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sept", "Oct", "Nov", "Dec"]) - labels = [month_names[month-1] - for month in sorted(metrics_df[groupby_col].unique())] + month_names = np.array([ + "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sept", + "Oct", "Nov", "Dec" + ]) + labels = [ + month_names[month - 1] + for month in sorted(metrics_df[groupby_col].unique()) + ] return labels @@ -754,7 +763,7 @@ def standard_deviation_heatmap(metric: str, labels = _heatmap_ylabels(metrics_df=metrics_df, average_over=average_over, groupby_col=groupby_col) - ax.set_yticks(np.arange(len(metrics_df[groupby_col].unique()))+0.5) + ax.set_yticks(np.arange(len(metrics_df[groupby_col].unique())) + 0.5) ax.set_yticklabels(labels) plt.yticks(rotation=0) if target_date_avg: @@ -846,8 +855,9 @@ def plot_metrics_leadtime_avg(metric: str, """ implemented_metrics = ["binacc", "sie", "mae", "mse", "rmse"] if metric not in implemented_metrics: - raise NotImplementedError(f"{metric} metric has not been implemented. " - f"Please only choose out of {implemented_metrics}.") + raise NotImplementedError( + f"{metric} metric has not been implemented. 
" + f"Please only choose out of {implemented_metrics}.") if metric == "binacc": # add default kwargs if "threshold" not in kwargs.keys(): @@ -868,14 +878,16 @@ def plot_metrics_leadtime_avg(metric: str, do_compute_metrics = True if data_path is not None: # loading in precomputed dataframes for the metrics - logging.info(f"Attempting to read in metrics dataframe from {data_path}") + logging.info( + f"Attempting to read in metrics dataframe from {data_path}") try: metric_df = pd.read_csv(data_path) metric_df["date"] = pd.to_datetime(metric_df["date"]) do_compute_metrics = False except OSError: - logging.info(f"Couldn't load in dataframe from {data_path}, " - f"will compute metric dataframe and try save to {data_path}") + logging.info( + f"Couldn't load in dataframe from {data_path}, " + f"will compute metric dataframe and try save to {data_path}") if do_compute_metrics: # computing the dataframes for the metrics @@ -892,7 +904,8 @@ def plot_metrics_leadtime_avg(metric: str, fc_metric_df = metric_df[metric_df["forecast_name"] == "IceNet"] seas_metric_df = metric_df[metric_df["forecast_name"] == "SEAS"] - seas_metric_df = seas_metric_df if (len(seas_metric_df) != 0) and ecmwf else None + seas_metric_df = seas_metric_df if (len(seas_metric_df) + != 0) and ecmwf else None logging.info(f"Creating leadtime averaged plot for {metric} metric") fig, ax = plt.subplots(figsize=(12, 6)) @@ -914,7 +927,10 @@ def plot_metrics_leadtime_avg(metric: str, n_forecast_days = fc_avg_metric.index.max() # plot leadtime averaged metrics - ax.plot(fc_avg_metric.index, fc_avg_metric, label="IceNet", color="blue") + ax.plot(fc_avg_metric.index, + fc_avg_metric, + label="IceNet", + color="blue") if plot_std: # obtaining the standard deviation of the metric fc_std_metric = fc_metric_df.groupby("leadtime")[metric].std().\ @@ -928,7 +944,10 @@ def plot_metrics_leadtime_avg(metric: str, if seas_metric_df is not None: seas_avg_metric = seas_metric_df.groupby("leadtime").mean(metric).\ sort_values("leadtime", ascending=True)[metric] - ax.plot(seas_avg_metric.index, seas_avg_metric, label="SEAS", color="darkorange") + ax.plot(seas_avg_metric.index, + seas_avg_metric, + label="SEAS", + color="darkorange") if plot_std: # obtaining the standard deviation of the metric seas_std_metric = seas_metric_df.groupby("leadtime")[metric].std().\ @@ -970,19 +989,22 @@ def plot_metrics_leadtime_avg(metric: str, max = np.nanmax(np.abs(heatmap_df_diff.values)) # plot heatmap of the difference between IceNet and SEAS - sns.heatmap(data=heatmap_df_diff, - ax=ax, - vmax=max, - vmin=-max, - cmap="seismic_r" if metric in ["binacc", "sie"] else "seismic", - cbar_kws=dict(label=f"{ylabel} difference between IceNet and SEAS")) + sns.heatmap( + data=heatmap_df_diff, + ax=ax, + vmax=max, + vmin=-max, + cmap="seismic_r" if metric in ["binacc", "sie"] else "seismic", + cbar_kws=dict( + label=f"{ylabel} difference between IceNet and SEAS")) else: # plot heatmap of the leadtime averaged metric when grouped by groupby_col - sns.heatmap(data=fc_avg_metric, - ax=ax, - cmap="inferno" if metric in ["binacc", "sie"] else "inferno_r", - cbar_kws=dict(label=ylabel)) + sns.heatmap( + data=fc_avg_metric, + ax=ax, + cmap="inferno" if metric in ["binacc", "sie"] else "inferno_r", + cbar_kws=dict(label=ylabel)) # string to add in plot title time_coverage = "\nAveraged over a minimum of " + \ @@ -993,7 +1015,7 @@ def plot_metrics_leadtime_avg(metric: str, labels = _heatmap_ylabels(metrics_df=fc_metric_df, average_over=average_over, groupby_col=groupby_col) - 
ax.set_yticks(np.arange(len(fc_metric_df[groupby_col].unique()))+0.5) + ax.set_yticks(np.arange(len(fc_metric_df[groupby_col].unique())) + 0.5) ax.set_yticklabels(labels) plt.yticks(rotation=0) @@ -1002,7 +1024,8 @@ def plot_metrics_leadtime_avg(metric: str, else: ax.set_ylabel("Initialisation date of forecast") else: - raise NotImplementedError(f"averaging over {average_over} not a valid option.") + raise NotImplementedError( + f"averaging over {average_over} not a valid option.") # add plot title if metric in ["mae", "mse", "rmse"]: @@ -1040,13 +1063,17 @@ def plot_metrics_leadtime_avg(metric: str, reset_index().pivot(index=groupby_col, columns="leadtime", values=metric).\ sort_values(groupby_col, ascending=True) # compute the maximum standard deviation to obtain a common scale - vmax = np.nanmax([np.nanmax(fc_std_metric.values), np.nanmax(seas_std_metric.values)]) + vmax = np.nanmax([ + np.nanmax(fc_std_metric.values), + np.nanmax(seas_std_metric.values) + ]) # create heapmap for the standard deviation standard_deviation_heatmap(metric=metric, model_name="SEAS", metrics_df=seas_metric_df, average_over=average_over, - output_path=output_path.replace(".png", "_SEAS_std.png"), + output_path=output_path.replace( + ".png", "_SEAS_std.png"), target_date_avg=target_date_avg, fc_std_metric=seas_std_metric, vmax=vmax, @@ -1060,7 +1087,8 @@ def plot_metrics_leadtime_avg(metric: str, model_name="IceNet", metrics_df=fc_metric_df, average_over=average_over, - output_path=output_path.replace(".png", "_IceNet_std.png"), + output_path=output_path.replace( + ".png", "_IceNet_std.png"), target_date_avg=target_date_avg, fc_std_metric=fc_std_metric, vmax=vmax, @@ -1069,9 +1097,7 @@ def plot_metrics_leadtime_avg(metric: str, return fc_metric_df, seas_metric_df -def sic_error_video(fc_da: object, - obs_da: object, - land_mask: object, +def sic_error_video(fc_da: object, obs_da: object, land_mask: object, output_path: object) -> object: """ @@ -1084,10 +1110,7 @@ def sic_error_video(fc_da: object, """ diff = fc_da - obs_da - fig, maps = plt.subplots(nrows=1, - ncols=3, - figsize=(16, 6), - layout="tight") + fig, maps = plt.subplots(nrows=1, ncols=3, figsize=(16, 6), layout="tight") fig.set_dpi(150) leadtime = 0 @@ -1101,11 +1124,7 @@ def sic_error_video(fc_da: object, logging.debug("Bounds of differences: {} - {}".format(diff_vmin, diff_vmax)) sic_cmap = mpl.cm.get_cmap("Blues_r", 20) - contour_kwargs = dict( - vmin=0, - vmax=1, - cmap=sic_cmap - ) + contour_kwargs = dict(vmin=0, vmax=1, cmap=sic_cmap) diff_cmap = mpl.cm.get_cmap("RdBu_r", 20) im1 = maps[0].imshow(fc_plot, **contour_kwargs) @@ -1115,20 +1134,24 @@ def sic_error_video(fc_da: object, vmax=diff_vmax, cmap=diff_cmap) - tic = maps[0].set_title("IceNet " - f"{pd.to_datetime(fc_da.isel(time=leadtime).time.values).strftime('%d/%m/%Y')}") - tio = maps[1].set_title("OSISAF Obs " - f"{pd.to_datetime(obs_da.isel(time=leadtime).time.values).strftime('%d/%m/%Y')}") + tic = maps[0].set_title( + "IceNet " + f"{pd.to_datetime(fc_da.isel(time=leadtime).time.values).strftime('%d/%m/%Y')}" + ) + tio = maps[1].set_title( + "OSISAF Obs " + f"{pd.to_datetime(obs_da.isel(time=leadtime).time.values).strftime('%d/%m/%Y')}" + ) maps[2].set_title("Diff") p0 = maps[0].get_position().get_points().flatten() p1 = maps[1].get_position().get_points().flatten() p2 = maps[2].get_position().get_points().flatten() - ax_cbar = fig.add_axes([p0[0]-0.05, 0.04, p1[2]-p0[0], 0.02]) + ax_cbar = fig.add_axes([p0[0] - 0.05, 0.04, p1[2] - p0[0], 0.02]) plt.colorbar(im1, 
orientation='horizontal', cax=ax_cbar) - ax_cbar1 = fig.add_axes([p2[0]+0.05, 0.04, p2[2]-p2[0], 0.02]) + ax_cbar1 = fig.add_axes([p2[0] + 0.05, 0.04, p2[2] - p2[0], 0.02]) plt.colorbar(im3, orientation='horizontal', cax=ax_cbar1) for m_ax in maps[0:3]: @@ -1149,9 +1172,11 @@ def update(date): diff_plot = diff.isel(time=date).to_numpy() tic.set_text("IceNet {}".format( - pd.to_datetime(fc_da.isel(time=date).time.values).strftime("%d/%m/%Y"))) + pd.to_datetime( + fc_da.isel(time=date).time.values).strftime("%d/%m/%Y"))) tio.set_text("OSISAF Obs {}".format( - pd.to_datetime(obs_da.isel(time=date).time.values).strftime("%d/%m/%Y"))) + pd.to_datetime( + obs_da.isel(time=date).time.values).strftime("%d/%m/%Y"))) im1.set_data(fc_plot) im2.set_data(obs_plot) @@ -1169,9 +1194,7 @@ def update(date): output_path = os.path.join("plot", "sic_error.mp4") \ if not output_path else output_path logging.info(f"Saving to {output_path}") - animation.save(output_path, - fps=10, - extra_args=['-vcodec', 'libx264']) + animation.save(output_path, fps=10, extra_args=['-vcodec', 'libx264']) return animation @@ -1179,25 +1202,16 @@ def sic_error_local_header_data(da: xr.DataArray): n_probe = len(da.probe) return { "probe array index": { - i_probe: ( - f"{da.xi.values[i_probe]}," - f"{da.yi.values[i_probe]}" - ) - for i_probe in range(n_probe) + i_probe: (f"{da.xi.values[i_probe]}," + f"{da.yi.values[i_probe]}") for i_probe in range(n_probe) }, "probe location (EASE)": { - i_probe: ( - f"{da.xc.values[i_probe]}," - f"{da.yc.values[i_probe]}" - ) - for i_probe in range(n_probe) + i_probe: (f"{da.xc.values[i_probe]}," + f"{da.yc.values[i_probe]}") for i_probe in range(n_probe) }, "probe location (lat, lon)": { - i_probe: ( - f"{da.lat.values[i_probe]}," - f"{da.lon.values[i_probe]}" - ) - for i_probe in range(n_probe) + i_probe: (f"{da.lat.values[i_probe]}," + f"{da.lon.values[i_probe]}") for i_probe in range(n_probe) }, "obs_kind": { 0: "forecast", @@ -1208,8 +1222,7 @@ def sic_error_local_header_data(da: xr.DataArray): } -def sic_error_local_write_fig(combined_da: xr.DataArray, - output_prefix: str): +def sic_error_local_write_fig(combined_da: xr.DataArray, output_prefix: str): """A helper function for `sic_error_local_plots`: plot error and forecast/observation data. 
@@ -1246,10 +1259,8 @@ def sic_error_local_write_fig(combined_da: xr.DataArray,
         fig, ax = plt.subplots()
         all_figs.append(fig)

-        ax.set_title(
-            f"Sea ice concentration at location {i_probe + 1}\n"
-            f"{lat:.3f}° {lat_h}, {lon:.3f}° {lon_h}"
-        )
+        ax.set_title(f"Sea ice concentration at location {i_probe + 1}\n"
+                     f"{lat:.3f}° {lat_h}, {lon:.3f}° {lon_h}")
         ax.set_xlabel("Date")
         ax.set_ylabel("SIC (%)")

@@ -1268,8 +1279,7 @@ def sic_error_local_write_fig(combined_da: xr.DataArray,

         ax2.set_title(
             f"Sea ice concentration error at location {i_probe + 1}\n"
-            f"{lat:.3f}° {lat_h}, {lon:.3f}° {lon_h}"
-        )
+            f"{lat:.3f}° {lat_h}, {lon:.3f}° {lon_h}")
         ax2.set_xlabel("Date")
         ax2.set_ylabel("SIC error (%)")

@@ -1287,30 +1297,26 @@ def sic_error_local_plots(fc_da: object,
                           obs_da: object,
                           output_path: object,
                           as_command: bool = False):
-
     """
     :param fc_da: a DataArray with dims ('time', 'probe')
     :param obs_da: a DataArray with dims ('time', 'probe')
     """

     # convert SIC to percentages (ranging from 0% to 100%) rather than a fraction
-    fc_da = fc_da*100
-    obs_da = obs_da*100
-    err_da = (fc_da-obs_da)
-    combined_da = xr.concat(
-        [fc_da, obs_da, err_da],
-        dim="obs_kind", coords="minimal"
-    )
+    fc_da = fc_da * 100
+    obs_da = obs_da * 100
+    err_da = (fc_da - obs_da)
+    combined_da = xr.concat([fc_da, obs_da, err_da],
+                            dim="obs_kind",
+                            coords="minimal")

     # convert to a dataframe for csv output
     df = (
-        combined_da
-        .to_dataframe(name="SIC")
+        combined_da.to_dataframe(name="SIC")
         # drop unneeded coords (lat, lon, xc, yc)
         .loc[:, "SIC"]
         # Convert mult-indices
-        .unstack(2).unstack(0)
-    )
+        .unstack(2).unstack(0))

     if output_path is None:
         output_path = "sic_error_local.csv"
@@ -1343,7 +1349,6 @@ def sic_error_local_plots(fc_da: object,


 class ForecastPlotArgParser(argparse.ArgumentParser):
-
     """An ArgumentParser specialised to support forecast plot arguments

     Additional argument enabled by allow_ecmwf() etc.
@@ -1353,9 +1358,7 @@ class ForecastPlotArgParser(argparse.ArgumentParser): :param forecast_date: allows this positional argument to be disabled """ - def __init__(self, *args, - forecast_date: bool = True, - **kwargs): + def __init__(self, *args, forecast_date: bool = True, **kwargs): super().__init__(*args, **kwargs) self.add_argument("hemisphere", choices=("north", "south")) @@ -1365,11 +1368,15 @@ def __init__(self, *args, self.add_argument("-o", "--output-path", type=str, default=None) self.add_argument("-v", "--verbose", action="store_true", default=False) - self.add_argument("-r", "--region", default=None, type=region_arg, + self.add_argument("-r", + "--region", + default=None, + type=region_arg, help="Region specified x1, y1, x2, y2") def allow_ecmwf(self): - self.add_argument("-b", "--bias-correct", + self.add_argument("-b", + "--bias-correct", help="Bias correct SEAS forecast array", action="store_true", default=False) @@ -1385,11 +1392,12 @@ def allow_threshold(self): return self def allow_sie(self): - self.add_argument("-ga", - "--grid-area", - help="The length of the sides of the grid used (in km)", - type=int, - default=25) + self.add_argument( + "-ga", + "--grid-area", + help="The length of the sides of the grid used (in km)", + type=int, + default=25) return self def allow_metrics(self): @@ -1398,17 +1406,22 @@ def allow_metrics(self): help="Which metrics to compute and plot", type=str, default="mae,mse,rmse") - self.add_argument("-s", - "--separate", - help="Whether or not to produce separate plots for each metric", - action="store_true", - default=False) + self.add_argument( + "-s", + "--separate", + help="Whether or not to produce separate plots for each metric", + action="store_true", + default=False) return self def allow_probes(self): self.add_argument( - "-p", "--probe", action="append", dest="probes", - type=location_arg, metavar="LOCATION", + "-p", + "--probe", + action="append", + dest="probes", + type=location_arg, + metavar="LOCATION", help="Sample at LOCATION", ) return self @@ -1416,11 +1429,13 @@ def allow_probes(self): def parse_args(self, *args, **kwargs): args = super().parse_args(*args, **kwargs) - logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO) logging.getLogger("matplotlib").setLevel(logging.WARNING) return args + ## # CLI endpoints # @@ -1430,23 +1445,18 @@ def binary_accuracy(): """ Produces plot of the binary classification accuracy of forecasts. """ - ap = ( - ForecastPlotArgParser() - .allow_ecmwf() - .allow_threshold() - ) + ap = (ForecastPlotArgParser().allow_ecmwf().allow_threshold()) args = ap.parse_args() masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") - fc = get_forecast_ds(args.forecast_file, - args.forecast_date) - obs = get_obs_da(args.hemisphere, - pd.to_datetime(args.forecast_date) + - timedelta(days=1), - pd.to_datetime(args.forecast_date) + - timedelta(days=int(fc.leadtime.max()))) + fc = get_forecast_ds(args.forecast_file, args.forecast_date) + obs = get_obs_da( + args.hemisphere, + pd.to_datetime(args.forecast_date) + timedelta(days=1), + pd.to_datetime(args.forecast_date) + + timedelta(days=int(fc.leadtime.max()))) fc = filter_ds_by_obs(fc, obs, args.forecast_date) if args.ecmwf: @@ -1478,24 +1488,18 @@ def sie_error(): """ Produces plot of the sea ice extent (SIE) error of forecasts. 
""" - ap = ( - ForecastPlotArgParser() - .allow_ecmwf() - .allow_threshold() - .allow_sie() - ) + ap = (ForecastPlotArgParser().allow_ecmwf().allow_threshold().allow_sie()) args = ap.parse_args() masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") - fc = get_forecast_ds(args.forecast_file, - args.forecast_date) - obs = get_obs_da(args.hemisphere, - pd.to_datetime(args.forecast_date) + - timedelta(days=1), - pd.to_datetime(args.forecast_date) + - timedelta(days=int(fc.leadtime.max()))) + fc = get_forecast_ds(args.forecast_file, args.forecast_date) + obs = get_obs_da( + args.hemisphere, + pd.to_datetime(args.forecast_date) + timedelta(days=1), + pd.to_datetime(args.forecast_date) + + timedelta(days=int(fc.leadtime.max()))) fc = filter_ds_by_obs(fc, obs, args.forecast_date) if args.ecmwf: @@ -1530,24 +1534,32 @@ def plot_forecast(): :return: """ ap = ForecastPlotArgParser() - ap.add_argument("-l", "--leadtimes", + ap.add_argument("-l", + "--leadtimes", help="Leadtimes to output, multiple as CSV, range as n..n", - type=lambda s: [int(i) for i in - list(s.split(",") if "," in s else - range(int(s.split("..")[0]), - int(s.split("..")[1]) + 1) if ".." in s else - [s])]) - ap.add_argument("-c", "--no-coastlines", + type=lambda s: [ + int(i) for i in list( + s.split(",") + if "," in s else range(int(s.split("..")[0]), + int(s.split("..")[1]) + 1) + if ".." in s else [s]) + ]) + ap.add_argument("-c", + "--no-coastlines", help="Turn off cartopy integration", - action="store_true", default=False) - ap.add_argument("-f", "--format", + action="store_true", + default=False) + ap.add_argument("-f", + "--format", help="Format to output in", choices=("mp4", "png", "svg", "tiff"), default="png") - ap.add_argument("-n", "--cmap-name", + ap.add_argument("-n", + "--cmap-name", help="Color map name if not wanting to use default", default=None) - ap.add_argument("-s", "--stddev", + ap.add_argument("-s", + "--stddev", help="Plot the standard deviation from the ensemble", action="store_true", default=False) @@ -1564,8 +1576,8 @@ def plot_forecast(): logging.warning("No directory at: {}".format(output_path)) os.makedirs(output_path) elif os.path.isfile(output_path): - raise RuntimeError("{} should be a directory and not existent...". 
- format(output_path)) + raise RuntimeError( + "{} should be a directory and not existent...".format(output_path)) forecast_name = "{}.{}".format( os.path.splitext(os.path.basename(args.forecast_file))[0], @@ -1598,31 +1610,32 @@ def plot_forecast(): if "forecast_date" not in pred_da: forecast_dates = [ pd.Timestamp(args.forecast_date) + dt.timedelta(lt) - for lt in args.leadtimes] - pred_da = pred_da.assign_coords( - forecast_date=("leadtime", forecast_dates)) + for lt in args.leadtimes + ] + pred_da = pred_da.assign_coords(forecast_date=("leadtime", + forecast_dates)) pred_da = pred_da.drop("time").drop("leadtime").\ rename(leadtime="time", forecast_date="time").set_index(time="time") - anim_args = dict( - figsize=5 - ) + anim_args = dict(figsize=5) if not args.no_coastlines: logging.warning("Coastlines will not work with the current " "implementation of xarray_to_video") - output_filename = os.path.join(output_path, "{}.{}.{}{}".format( - forecast_name, - args.forecast_date.strftime("%Y%m%d"), - "" if not args.stddev else "stddev.", - args.format - )) - xarray_to_video(pred_da, fps=1, cmap=cmap, - imshow_kwargs=dict(vmin=0., vmax=vmax) - if not args.stddev else None, - video_path=output_filename, - **anim_args) + output_filename = os.path.join( + output_path, + "{}.{}.{}{}".format(forecast_name, + args.forecast_date.strftime("%Y%m%d"), + "" if not args.stddev else "stddev.", + args.format)) + xarray_to_video( + pred_da, + fps=1, + cmap=cmap, + imshow_kwargs=dict(vmin=0., vmax=vmax) if not args.stddev else None, + video_path=output_filename, + **anim_args) else: for leadtime in leadtimes: pred_da = fc.sel(leadtime=leadtime).isel(time=0) @@ -1640,7 +1653,10 @@ def plot_forecast(): bound_args.update(cmap=cmap) - im = show_img(ax, pred_da, **bound_args, vmax=vmax, + im = show_img(ax, + pred_da, + **bound_args, + vmax=vmax, do_coastlines=not args.no_coastlines) plt.colorbar(im, ax=ax) @@ -1648,13 +1664,12 @@ def plot_forecast(): ax.set_title("{:04d}/{:02d}/{:02d}".format(plot_date.year, plot_date.month, plot_date.day)) - output_filename = os.path.join(output_path, "{}.{}.{}{}".format( - forecast_name, - (args.forecast_date + dt.timedelta( - days=leadtime)).strftime("%Y%m%d"), - "" if not args.stddev else "stddev.", - args.format - )) + output_filename = os.path.join( + output_path, "{}.{}.{}{}".format( + forecast_name, + (args.forecast_date + + dt.timedelta(days=leadtime)).strftime("%Y%m%d"), + "" if not args.stddev else "stddev.", args.format)) logging.info("Saving to {}".format(output_filename)) plt.savefig(output_filename) @@ -1678,23 +1693,18 @@ def metric_plots(): """ Produces plot of requested metrics for forecasts. 
""" - ap = ( - ForecastPlotArgParser() - .allow_ecmwf() - .allow_metrics() - ) + ap = (ForecastPlotArgParser().allow_ecmwf().allow_metrics()) args = ap.parse_args() masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") - fc = get_forecast_ds(args.forecast_file, - args.forecast_date) - obs = get_obs_da(args.hemisphere, - pd.to_datetime(args.forecast_date) + - timedelta(days=1), - pd.to_datetime(args.forecast_date) + - timedelta(days=int(fc.leadtime.max()))) + fc = get_forecast_ds(args.forecast_file, args.forecast_date) + obs = get_obs_da( + args.hemisphere, + pd.to_datetime(args.forecast_date) + timedelta(days=1), + pd.to_datetime(args.forecast_date) + + timedelta(days=int(fc.leadtime.max()))) fc = filter_ds_by_obs(fc, obs, args.forecast_date) metrics = parse_metrics_arg(args.metrics) @@ -1729,12 +1739,8 @@ def leadtime_avg_plots(): """ Produces plot of leadtime averaged metrics for forecasts. """ - ap = ( - ForecastPlotArgParser(forecast_date=False) - .allow_ecmwf() - .allow_threshold() - .allow_sie() - ) + ap = (ForecastPlotArgParser( + forecast_date=False).allow_ecmwf().allow_threshold().allow_sie()) ap.add_argument("-m", "--metric", help="Which metric to compute and plot", @@ -1759,11 +1765,12 @@ def leadtime_avg_plots(): help="What multiple of the standard deviation to plot", type=float, default=1.0) - ap.add_argument("-td", - "--target_date_average", - help="Averages metric over target date instead of init date", - action="store_true", - default=False) + ap.add_argument( + "-td", + "--target_date_average", + help="Averages metric over target date instead of init date", + action="store_true", + default=False) args = ap.parse_args() masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") @@ -1795,13 +1802,12 @@ def sic_error(): masks = Masks(north=args.hemisphere == "north", south=args.hemisphere == "south") - fc = get_forecast_ds(args.forecast_file, - args.forecast_date) - obs = get_obs_da(args.hemisphere, - pd.to_datetime(args.forecast_date) + - timedelta(days=1), - pd.to_datetime(args.forecast_date) + - timedelta(days=int(fc.leadtime.max()))) + fc = get_forecast_ds(args.forecast_file, args.forecast_date) + obs = get_obs_da( + args.hemisphere, + pd.to_datetime(args.forecast_date) + timedelta(days=1), + pd.to_datetime(args.forecast_date) + + timedelta(days=int(fc.leadtime.max()))) fc = filter_ds_by_obs(fc, obs, args.forecast_date) if args.region: @@ -1818,24 +1824,17 @@ def sic_error_local(): Entry point for the icenet_plot_sic_error_local command """ - ap = ( - ForecastPlotArgParser() - .allow_probes() - ) + ap = (ForecastPlotArgParser().allow_probes()) args = ap.parse_args() - fc = get_forecast_ds(args.forecast_file, - args.forecast_date) - obs = get_obs_da(args.hemisphere, - pd.to_datetime(args.forecast_date) + - timedelta(days=1), - pd.to_datetime(args.forecast_date) + - timedelta(days=int(fc.leadtime.max()))) + fc = get_forecast_ds(args.forecast_file, args.forecast_date) + obs = get_obs_da( + args.hemisphere, + pd.to_datetime(args.forecast_date) + timedelta(days=1), + pd.to_datetime(args.forecast_date) + + timedelta(days=int(fc.leadtime.max()))) fc = filter_ds_by_obs(fc, obs, args.forecast_date) fc, obs = process_probes(args.probes, [fc, obs]) - sic_error_local_plots(fc, - obs, - args.output_path, - as_command=True) + sic_error_local_plots(fc, obs, args.output_path, as_command=True) diff --git a/icenet/plotting/utils.py b/icenet/plotting/utils.py index 74dd40a1..f08a2fc5 100644 --- a/icenet/plotting/utils.py +++ 
b/icenet/plotting/utils.py @@ -10,7 +10,6 @@ import pandas as pd import xarray as xr - from ibicus.debias import LinearScaling @@ -33,8 +32,8 @@ def broadcast_forecast(start_date: object, "Only one of datafiles and dataset can be set" if datafiles: - logging.info("Using {} to generate forecast through {} to {}". - format(", ".join(datafiles), start_date, end_date)) + logging.info("Using {} to generate forecast through {} to {}".format( + ", ".join(datafiles), start_date, end_date)) dataset = xr.open_mfdataset(datafiles, engine="netcdf4") dates = pd.date_range(start_date, end_date) @@ -46,8 +45,8 @@ def broadcast_forecast(start_date: object, while dataset.time.values[i + 1] < dates[0]: i += 1 - logging.info("Starting index will be {} for {} - {}". - format(i, dates[0], dates[-1])) + logging.info("Starting index will be {} for {} - {}".format( + i, dates[0], dates[-1])) dt_arr = [] for d in dates: @@ -66,10 +65,9 @@ def broadcast_forecast(start_date: object, i += 1 continue - logging.debug("Selecting date {} and lead {}". - format(pd.to_datetime( - dataset.time.values[i]).strftime("%D"), - d_lead)) + logging.debug("Selecting date {} and lead {}".format( + pd.to_datetime(dataset.time.values[i]).strftime("%D"), + d_lead)) arr = dataset.sel(time=dataset.time.values[i], leadtime=d_lead).\ @@ -90,8 +88,10 @@ def broadcast_forecast(start_date: object, return target_ds -def get_seas_forecast_init_dates(hemisphere: str, - source_path: object = os.path.join(".", "data", "mars.seas")) -> object: +def get_seas_forecast_init_dates( + hemisphere: str, + source_path: object = os.path.join(".", "data", "mars.seas") +) -> object: """ Obtains list of dates for which we have SEAS forecasts we have. @@ -101,20 +101,18 @@ def get_seas_forecast_init_dates(hemisphere: str, :return: list of dates """ # list the files in the path where SEAS forecasts are stored - filenames = os.listdir(os.path.join(source_path, - hemisphere, - "siconca")) + filenames = os.listdir(os.path.join(source_path, hemisphere, "siconca")) # obtain the dates from files with YYYYMMDD.nc format - return pd.to_datetime([x.split('.')[0] - for x in filenames - if re.search(r'^\d{8}\.nc$', x)]) + return pd.to_datetime( + [x.split('.')[0] for x in filenames if re.search(r'^\d{8}\.nc$', x)]) -def get_seas_forecast_da(hemisphere: str, - date: str, - bias_correct: bool = True, - source_path: object = os.path.join(".", "data", "mars.seas"), - ) -> tuple: +def get_seas_forecast_da( + hemisphere: str, + date: str, + bias_correct: bool = True, + source_path: object = os.path.join(".", "data", "mars.seas"), +) -> tuple: """ Atmospheric model Ensemble 15-day forecast (Set III - ENS) @@ -130,9 +128,7 @@ def get_seas_forecast_da(hemisphere: str, """ seas_file = os.path.join( - source_path, - hemisphere, - "siconca", + source_path, hemisphere, "siconca", "{}.nc".format(date.replace(day=1).strftime("%Y%m%d"))) if os.path.exists(seas_file): @@ -143,22 +139,18 @@ def get_seas_forecast_da(hemisphere: str, if bias_correct: # Let's have some maximum, though it's quite high - (start_date, end_date) = ( - date - dt.timedelta(days=10 * 365), - date + dt.timedelta(days=10 * 365) - ) + (start_date, end_date) = (date - dt.timedelta(days=10 * 365), + date + dt.timedelta(days=10 * 365)) obs_da = get_obs_da(hemisphere, start_date, end_date) - seas_hist_files = dict(sorted({os.path.abspath(el): - dt.datetime.strptime( - os.path.basename(el)[0:8], "%Y%m%d") - for el in - glob.glob(os.path.join(source_path, - hemisphere, - "siconca", - "*.nc")) - if re.search(r'^\d{8}\.nc$', - 
os.path.basename(el)) - and el != seas_file}.items())) + seas_hist_files = dict( + sorted({ + os.path.abspath(el): + dt.datetime.strptime(os.path.basename(el)[0:8], "%Y%m%d") + for el in glob.glob( + os.path.join(source_path, hemisphere, "siconca", "*.nc")) + if re.search(r'^\d{8}\.nc$', os.path.basename(el)) and + el != seas_file + }.items())) def strip_overlapping_time(ds): data_file = os.path.abspath(ds.encoding["source"]) @@ -166,8 +158,8 @@ def strip_overlapping_time(ds): try: idx = list(seas_hist_files.keys()).index(data_file) except ValueError: - logging.exception("\n{} not in \n\n{}".format(data_file, - seas_hist_files)) + logging.exception("\n{} not in \n\n{}".format( + data_file, seas_hist_files)) return None if idx < len(seas_hist_files) - 1: @@ -187,17 +179,16 @@ def strip_overlapping_time(ds): reasonable_physical_range=[0., 1.]) logging.info("Debiaser input ranges: obs {:.2f} - {:.2f}, " - "hist {:.2f} - {:.2f}, fut {:.2f} - {:.2f}". - format(float(obs_da.min()), float(obs_da.max()), - float(hist_da.min()), float(hist_da.max()), - float(seas_da.min()), float(seas_da.max()))) + "hist {:.2f} - {:.2f}, fut {:.2f} - {:.2f}".format( + float(obs_da.min()), float(obs_da.max()), + float(hist_da.min()), float(hist_da.max()), + float(seas_da.min()), float(seas_da.max()))) - seas_array = debiaser.apply(obs_da.values, - hist_da.values, + seas_array = debiaser.apply(obs_da.values, hist_da.values, seas_da.values) seas_da.values = seas_array - logging.info("Debiaser output range: {:.2f} - {:.2f}". - format(float(seas_da.min()), float(seas_da.max()))) + logging.info("Debiaser output range: {:.2f} - {:.2f}".format( + float(seas_da.min()), float(seas_da.max()))) logging.info("Returning SEAS data from {} from {}".format(seas_file, date)) @@ -206,23 +197,21 @@ def strip_overlapping_time(ds): date_location = list(seas_da.time.values).index(pd.Timestamp(date)) if date_location > 0: logging.warning("SEAS forecast started {} day before the requested " - "date {}, make sure you account for this!". 
- format(date_location, date)) + "date {}, make sure you account for this!".format( + date_location, date)) seas_da = seas_da.sel(time=slice(date, None)) logging.debug("SEAS data range: {} - {}, {} dates".format( pd.to_datetime(min(seas_da.time.values)).strftime("%Y-%m-%d"), pd.to_datetime(max(seas_da.time.values)).strftime("%Y-%m-%d"), - len(seas_da.time) - )) + len(seas_da.time))) return seas_da def get_forecast_ds(forecast_file: object, forecast_date: str, - stddev: bool = False - ) -> object: + stddev: bool = False) -> object: """ :param forecast_file: a path to a .nc file @@ -236,15 +225,12 @@ def get_forecast_ds(forecast_file: object, get_key = "sic_mean" if not stddev else "sic_stddev" forecast_ds = getattr( - forecast_ds.sel(time=slice(forecast_date, forecast_date)), - get_key) + forecast_ds.sel(time=slice(forecast_date, forecast_date)), get_key) return forecast_ds -def filter_ds_by_obs(ds: object, - obs_da: object, - forecast_date: str) -> object: +def filter_ds_by_obs(ds: object, obs_da: object, forecast_date: str) -> object: """ :param ds: @@ -253,10 +239,9 @@ def filter_ds_by_obs(ds: object, :return: """ forecast_date = pd.to_datetime(forecast_date) - (start_date, end_date) = ( - forecast_date + dt.timedelta(days=int(ds.leadtime.min())), - forecast_date + dt.timedelta(days=int(ds.leadtime.max())) - ) + (start_date, + end_date) = (forecast_date + dt.timedelta(days=int(ds.leadtime.min())), + forecast_date + dt.timedelta(days=int(ds.leadtime.max()))) if len(obs_da.time) < len(ds.leadtime): if len(obs_da.time) < 1: @@ -266,14 +251,11 @@ def filter_ds_by_obs(ds: object, logging.warning("Observational data not available for full range of " "forecast lead times: {}-{} vs {}-{}".format( - obs_da.time.to_series()[0].strftime("%D"), - obs_da.time.to_series()[-1].strftime("%D"), - start_date.strftime("%D"), - end_date.strftime("%D"))) - (start_date, end_date) = ( - obs_da.time.to_series()[0], - obs_da.time.to_series()[-1] - ) + obs_da.time.to_series()[0].strftime("%D"), + obs_da.time.to_series()[-1].strftime("%D"), + start_date.strftime("%D"), end_date.strftime("%D"))) + (start_date, end_date) = (obs_da.time.to_series()[0], + obs_da.time.to_series()[-1]) # We broadcast to get a nicely compatible dataset for plotting return broadcast_forecast(start_date=start_date, @@ -281,12 +263,12 @@ def filter_ds_by_obs(ds: object, dataset=ds) -def get_obs_da(hemisphere: str, - start_date: str, - end_date: str, - obs_source: object = - os.path.join(".", "data", "osisaf"), - ) -> object: +def get_obs_da( + hemisphere: str, + start_date: str, + end_date: str, + obs_source: object = os.path.join(".", "data", "osisaf"), +) -> object: """ :param hemisphere: string, typically either 'north' or 'south' @@ -296,14 +278,15 @@ def get_obs_da(hemisphere: str, :return: """ obs_years = pd.Series(pd.date_range(start_date, end_date)).dt.year.unique() - obs_dfs = [el for yr in obs_years for el in - glob.glob(os.path.join(obs_source, - hemisphere, - "siconca", "{}.nc".format(yr)))] + obs_dfs = [ + el for yr in obs_years for el in glob.glob( + os.path.join(obs_source, hemisphere, "siconca", "{}.nc".format(yr))) + ] if len(obs_dfs) < len(obs_years): - logging.warning("Cannot find all obs source files for {} - {} in {}". 
- format(start_date, end_date, obs_source)) + logging.warning( + "Cannot find all obs source files for {} - {} in {}".format( + start_date, end_date, obs_source)) logging.info("Got files: {}".format(obs_dfs)) obs_ds = xr.open_mfdataset(obs_dfs) @@ -312,10 +295,7 @@ def get_obs_da(hemisphere: str, return obs_ds.ice_conc -def calculate_extents(x1: int, - x2: int, - y1: int, - y2: int): +def calculate_extents(x1: int, x2: int, y1: int, y2: int): """ :param x1: @@ -442,8 +422,7 @@ def process_probes(probes, data) -> tuple: return data -def process_regions(region: tuple, - data: tuple) -> tuple: +def process_regions(region: tuple, data: tuple) -> tuple: """ :param region: diff --git a/icenet/plotting/video.py b/icenet/plotting/video.py index 8367a239..53f61b20 100644 --- a/icenet/plotting/video.py +++ b/icenet/plotting/video.py @@ -19,8 +19,7 @@ # TODO: This can be a plotting or analysis util function elsewhere -def get_dataarray_from_files(files: object, - numpy: bool = False) -> object: +def get_dataarray_from_files(files: object, numpy: bool = False) -> object: """ :param files: @@ -47,8 +46,8 @@ def get_dataarray_from_files(files: object, # TODO: error handling date_match = re.search(r"(\d{4})_(\d{1,2})_(\d{1,2})", nom) - dates.append(pd.to_datetime( - dt.date(*[int(s) for s in date_match.groups()]))) + dates.append( + pd.to_datetime(dt.date(*[int(s) for s in date_match.groups()]))) # FIXME: naive implementations abound path_comps = os.path.dirname(files[0]).split(os.sep) @@ -69,23 +68,23 @@ def get_dataarray_from_files(files: object, return da -def xarray_to_video(da: object, - fps: int, - video_path: object = None, - mask: object = None, - mask_type: str = 'contour', - clim: object = None, - crop: object = None, - data_type: str = 'abs', - video_dates: object = None, - cmap: object = "viridis", - figsize: int = 12, - dpi: int = 150, - imshow_kwargs: dict = None, - ax_init: object = None, - ax_extra: callable = None, - ) -> object: - +def xarray_to_video( + da: object, + fps: int, + video_path: object = None, + mask: object = None, + mask_type: str = 'contour', + clim: object = None, + crop: object = None, + data_type: str = 'abs', + video_dates: object = None, + cmap: object = "viridis", + figsize: int = 12, + dpi: int = 150, + imshow_kwargs: dict = None, + ax_init: object = None, + ax_extra: callable = None, +) -> object: """ Generate video of an xarray.DataArray. Optionally input a list of `video_dates` to show, otherwise the full set of time coordiantes @@ -138,8 +137,9 @@ def update(date): n_max = -n_min if video_dates is None: - video_dates = [pd.Timestamp(date).to_pydatetime() - for date in da.time.values] + video_dates = [ + pd.Timestamp(date).to_pydatetime() for date in da.time.values + ] if crop is not None: a = crop[0][0] @@ -179,9 +179,10 @@ def update(date): zorder=1, **imshow_kwargs if imshow_kwargs is not None else {}) - image_title = ax.set_title("{:04d}/{:02d}/{:02d}". - format(date.year, date.month, date.day), - fontsize="medium", zorder=2) + image_title = ax.set_title("{:04d}/{:02d}/{:02d}".format( + date.year, date.month, date.day), + fontsize="medium", + zorder=2) try: divider = make_axes_locatable(ax) @@ -193,10 +194,7 @@ def update(date): logging.info("Animating") # Investigated blitting, but it causes a few problems with masks/titles. 
- animation = FuncAnimation(fig, - update, - video_dates, - interval=1000/fps) + animation = FuncAnimation(fig, update, video_dates, interval=1000 / fps) plt.close() @@ -204,9 +202,7 @@ def update(date): logging.info("Not saving plot, will return animation") else: logging.info("Saving plot to {}".format(video_path)) - animation.save(video_path, - fps=fps, - extra_args=['-vcodec', 'libx264']) + animation.save(video_path, fps=fps, extra_args=['-vcodec', 'libx264']) return animation @@ -229,12 +225,13 @@ def recurse_data_folders(base_path: object, # TODO: should ideally use scandir for performance # TODO: naive hardcoded filtering of files logging.debug("CHILDREN: {} or LOOKUPS: {}".format(children, lookups)) - files = sorted( - [os.path.join(base_path, f) for f in os.listdir(base_path) - if os.path.splitext(f)[1] == ".{}".format(filetype) - and (re.match(r'^\d{4}\.nc$', f) - or - re.search(r'(abs|anom|linear_trend)\.nc$', f))]) + files = sorted([ + os.path.join(base_path, f) + for f in os.listdir(base_path) + if os.path.splitext(f)[1] == ".{}".format(filetype) and + (re.match(r'^\d{4}\.nc$', f) or + re.search(r'(abs|anom|linear_trend)\.nc$', f)) + ]) logging.debug("Files found: {}".format(", ".join(files))) if not len(files): @@ -250,22 +247,17 @@ def recurse_data_folders(base_path: object, if not len(lookups) or \ (len(lookups) and subdir in [str(s) for s in lookups]): subdir_files = recurse_data_folders( - new_path, - children[0] - if children is not None and len(children) > 0 else None, - children[1:] + new_path, children[0] if children is not None and + len(children) > 0 else None, children[1:] if children is not None and len(children) > 1 else None, - filetype - ) + filetype) if subdir_files: files.append(subdir_files) return files -def video_process(files: object, - numpy: object, - output_dir: object, +def video_process(files: object, numpy: object, output_dir: object, fps: int) -> object: """ @@ -302,7 +294,10 @@ def cli_args(): args.add_argument("-f", "--fps", default=15, type=int) args.add_argument("-n", "--numpy", action="store_true", default=False) - args.add_argument("-o", "--output-dir", dest="output_dir", type=str, + args.add_argument("-o", + "--output-dir", + dest="output_dir", + type=str, default="plot") args.add_argument("-p", "--path", default="data", type=str) args.add_argument("-w", "--workers", default=8, type=int) @@ -310,8 +305,10 @@ def cli_args(): args.add_argument("-v", "--verbose", action="store_true", default=False) args.add_argument("data", type=lambda s: s.split(",")) - args.add_argument("hemisphere", default=[], - choices=["north", "south"], nargs="?") + args.add_argument("hemisphere", + default=[], + choices=["north", "south"], + nargs="?") args.add_argument("--vars", default=[], type=lambda s: s.split(",")) args.add_argument("--years", default=[], type=lambda s: s.split(",")) @@ -329,24 +326,23 @@ def data_cli(): logging.info("Looking into {}".format(args.path)) path_children = [hemis, args.vars] - video_batches = recurse_data_folders(args.path, - args.data, - path_children, - filetype="nc" - if not args.numpy else "npy") + video_batches = recurse_data_folders( + args.path, + args.data, + path_children, + filetype="nc" if not args.numpy else "npy") logging.debug("Batches: {}".format(video_batches)) video_batches = [ - v_el for h_list in video_batches - for v_list in h_list - for v_el in v_list + v_el for h_list in video_batches for v_list in h_list for v_el in v_list ] if len(args.years) > 0: new_batches = [] for batch in video_batches: - batch = [el for el 
in batch - if os.path.basename(el)[0:4] in args.years] + batch = [ + el for el in batch if os.path.basename(el)[0:4] in args.years + ] if len(batch): new_batches.append(batch) video_batches = new_batches @@ -358,11 +354,9 @@ def data_cli(): futures = [] for batch in video_batches: - futures.append(executor.submit(video_process, - batch, - args.numpy, - args.output_dir, - args.fps)) + futures.append( + executor.submit(video_process, batch, args.numpy, + args.output_dir, args.fps)) for future in as_completed(futures): try: diff --git a/icenet/process/azure.py b/icenet/process/azure.py index db85e58d..b63bea56 100644 --- a/icenet/process/azure.py +++ b/icenet/process/azure.py @@ -60,8 +60,7 @@ def upload(): if len(ds.time) < 1: raise ValueError("No elements in {} for {}".format( - args.filename, args.date - )) + args.filename, args.date)) filename = destination_filename(tmpdir, args.filename, args.date) ds.to_netcdf(filename) @@ -76,8 +75,9 @@ def upload(): container_client = \ ContainerClient.\ from_connection_string(url, container_name=args.container) - container_client.upload_blob( - os.path.basename(filename), data, overwrite=args.overwrite) + container_client.upload_blob(os.path.basename(filename), + data, + overwrite=args.overwrite) finally: if args.date and not args.leave: logging.info("Removing {}".format(tmpdir)) diff --git a/icenet/process/forecasts.py b/icenet/process/forecasts.py index 820261e6..311c691d 100644 --- a/icenet/process/forecasts.py +++ b/icenet/process/forecasts.py @@ -16,8 +16,7 @@ from icenet.plotting.utils import broadcast_forecast, get_forecast_ds -def reproject_output(forecast_file: object, - proj_file: object, +def reproject_output(forecast_file: object, proj_file: object, save_file: object) -> object: """ @@ -34,8 +33,7 @@ def reproject_output(forecast_file: object, forecast_cube.coord('projection_y_coordinate').convert_units('meters') forecast_cube.coord('projection_x_coordinate').convert_units('meters') - logging.info("Attempting to reproject and save to {}". - format(save_file)) + logging.info("Attempting to reproject and save to {}".format(save_file)) latlon_cube = forecast_cube.regrid(gp, iris.analysis.Linear()) iris.save(latlon_cube, save_file) @@ -59,9 +57,7 @@ def broadcast_main(): """ args = broadcast_args() - broadcast_forecast(args.start_date, - args.end_date, - args.datafiles, + broadcast_forecast(args.start_date, args.end_date, args.datafiles, args.target) @@ -93,7 +89,8 @@ def geotiff_args() -> argparse.Namespace: """ ap = argparse.ArgumentParser() ap.add_argument("-o", "--output-path", default=".") - ap.add_argument("-s", "--stddev", + ap.add_argument("-s", + "--stddev", help="Plot the standard deviation from the ensemble", action="store_true", default=False) @@ -102,11 +99,13 @@ def geotiff_args() -> argparse.Namespace: ap.add_argument("forecast_date") ap.add_argument("leadtimes", help="Leadtimes to output, multiple as CSV, range as n..n", - type=lambda s: [int(i) for i in - list(s.split(",") if "," in s else - range(int(s.split("..")[0]), - int(s.split("..")[1]) + 1) if ".." in s else - [s])]) + type=lambda s: [ + int(i) for i in list( + s.split(",") + if "," in s else range(int(s.split("..")[0]), + int(s.split("..")[1]) + 1) + if ".." in s else [s]) + ]) args = ap.parse_args() return args @@ -119,12 +118,13 @@ def create_geotiff_output(): args = geotiff_args() if not os.path.isdir(args.output_path): - logging.warning("No directory at: {}, creating". 
- format(args.output_path)) + logging.warning("No directory at: {}, creating".format( + args.output_path)) os.makedirs(args.output_path) elif os.path.isfile(args.output_path): - raise RuntimeError("{} should be a directory and not existent...". - format(args.output_path)) + raise RuntimeError( + "{} should be a directory and not existent...".format( + args.output_path)) ds = get_forecast_ds(args.forecast_file, args.forecast_date, @@ -158,19 +158,19 @@ def create_geotiff_output(): os.path.splitext(os.path.basename(args.forecast_file))[0], args.forecast_date) - logging.info("Selecting and outputting files from {} for {}". - format(args.forecast_file, args.forecast_date)) + logging.info("Selecting and outputting files from {} for {}".format( + args.forecast_file, args.forecast_date)) for leadtime in leadtimes: pred_da = ds.sel(leadtime=leadtime) - output_filename = os.path.join(args.output_path, "{}.{}.{}tiff".format( - forecast_name, - (pd.to_datetime(args.forecast_date) + dt.timedelta( - days=leadtime)).strftime("%Y-%m-%d"), - "" if not args.stddev else "stddev." - )) + output_filename = os.path.join( + args.output_path, "{}.{}.{}tiff".format( + forecast_name, + (pd.to_datetime(args.forecast_date) + + dt.timedelta(days=leadtime)).strftime("%Y-%m-%d"), + "" if not args.stddev else "stddev.")) - logging.debug("Outputting leadtime {} to {}". - format(leadtime, output_filename)) + logging.debug("Outputting leadtime {} to {}".format( + leadtime, output_filename)) pred_da.rio.to_raster(output_filename) diff --git a/icenet/process/local.py b/icenet/process/local.py index 0e2d1285..a583e630 100644 --- a/icenet/process/local.py +++ b/icenet/process/local.py @@ -34,8 +34,8 @@ def upload(): logging.info("Local upload facility") if not os.path.isdir(args.destination): - raise RuntimeError("Destination {} does not exist". 
- format(args.destination)) + raise RuntimeError("Destination {} does not exist".format( + args.destination)) if args.date: ds = xr.open_dataset(args.filename) @@ -43,11 +43,9 @@ def upload(): if len(ds.time) < 1: raise ValueError("No elements in {} for {}".format( - args.filename, args.date - )) + args.filename, args.date)) - filename = destination_filename(args.destination, - args.filename, + filename = destination_filename(args.destination, args.filename, args.date) ds.to_netcdf(filename) ds.close() diff --git a/icenet/process/predict.py b/icenet/process/predict.py index 6e1ec678..9e6c2734 100644 --- a/icenet/process/predict.py +++ b/icenet/process/predict.py @@ -59,9 +59,7 @@ def get_refcube(north: bool = True, south: bool = False) -> object: return cube -def get_prediction_data(root: object, - name: object, - date: object) -> tuple: +def get_prediction_data(root: object, name: object, date: object) -> tuple: """ :param root: @@ -71,11 +69,7 @@ def get_prediction_data(root: object, """ logging.info("Post-processing {}".format(date)) - glob_str = os.path.join(root, - "results", - "predict", - name, - "*", + glob_str = os.path.join(root, "results", "predict", name, "*", date.strftime("%Y_%m_%d.npy")) np_files = glob.glob(glob_str) @@ -87,11 +81,11 @@ def get_prediction_data(root: object, data = np.array(data) ens_members = data.shape[0] - logging.debug("Data read from disk: {} from: {}".format(data.shape, np_files)) + logging.debug("Data read from disk: {} from: {}".format( + data.shape, np_files)) - return np.stack( - [data.mean(axis=0), data.std(axis=0)], - axis=-1).squeeze(), ens_members + return np.stack([data.mean(axis=0), data.std(axis=0)], + axis=-1).squeeze(), ens_members def date_arg(string: str) -> object: @@ -117,12 +111,20 @@ def get_args(): ap.add_argument("-m", "--mask", default=False, action="store_true") - ap.add_argument("--nan", help="Apply nans, not zeroes, to land mask", - default=False, action="store_true") - ap.add_argument("--no-agcm", help="No active grid cell masking", - default=True, action="store_false", dest="agcm") - ap.add_argument("--no-land", help="No land while masking", - default=True, action="store_false", dest="land") + ap.add_argument("--nan", + help="Apply nans, not zeroes, to land mask", + default=False, + action="store_true") + ap.add_argument("--no-agcm", + help="No active grid cell masking", + default=True, + action="store_false", + dest="agcm") + ap.add_argument("--no-land", + help="No land while masking", + default=True, + action="store_false", + dest="land") ap.add_argument("-o", "--output-dir", default=".") ap.add_argument("-r", "--root", type=str, default=".") @@ -145,11 +147,15 @@ def create_cf_output(): ref_sic = xr.open_dataset(get_refsic(ds.north, ds.south)) ref_cube = get_refcube(ds.north, ds.south) - dates = [dt.date(*[int(v) for v in s.split("-")]) - for s in args.datefile.read().split()] + dates = [ + dt.date(*[int(v) + for v in s.split("-")]) + for s in args.datefile.read().split() + ] args.datefile.close() - arr, ens_members = zip(*[get_prediction_data(args.root, args.name, date) for date in dates]) + arr, ens_members = zip( + *[get_prediction_data(args.root, args.name, date) for date in dates]) ens_members = list(ens_members) arr = np.array(arr) @@ -166,11 +172,14 @@ def create_cf_output(): for idx, forecast_date in enumerate(dates): for lead_idx in np.arange(0, arr.shape[3], 1): - lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + 1) - logging.debug("Active grid cell mask start {} forecast date {}". 
- format(forecast_date, lead_dt)) - - grid_cell_mask = mask_gen.get_active_cell_mask(lead_dt.month) + lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + + 1) + logging.debug( + "Active grid cell mask start {} forecast date {}". + format(forecast_date, lead_dt)) + + grid_cell_mask = mask_gen.get_active_cell_mask( + lead_dt.month) sic_mean[idx, ~grid_cell_mask, lead_idx] = 0 sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0 @@ -190,11 +199,11 @@ def create_cf_output(): sic_mean[mask] = 0 sic_stddev[mask] = 0 - lists_of_fcast_dates = [ - [pd.Timestamp(date + dt.timedelta(days=int(lead_idx))) - for lead_idx in np.arange(1, arr.shape[3] + 1, 1)] - for date in dates + lists_of_fcast_dates = [[ + pd.Timestamp(date + dt.timedelta(days=int(lead_idx))) + for lead_idx in np.arange(1, arr.shape[3] + 1, 1) ] + for date in dates] xarr = xr.Dataset( data_vars=dict( @@ -232,12 +241,12 @@ def create_cf_output(): history="{} - creation".format(dt.datetime.now()), id="IceNet {}".format(icenet_version), institution="British Antarctic Survey", - keywords="""'Earth Science > Cryosphere > Sea Ice > Sea Ice Concentration + keywords= + """'Earth Science > Cryosphere > Sea Ice > Sea Ice Concentration Earth Science > Oceans > Sea Ice > Sea Ice Concentration Earth Science > Climate Indicators > Cryospheric Indicators > Sea Ice Geographic Region > {} Hemisphere""".format( - "Northern" if ds.north else "Southern" - ), + "Northern" if ds.north else "Southern"), # TODO: check we're valid keywords_vocabulary="GCMD Science Keywords", # TODO: Double check this is good with PDC @@ -267,13 +276,16 @@ def create_cf_output(): """, # Use ISO 8601:2004 duration format, preferably the extended format # as recommended in the Attribute Content Guidance section. - time_coverage_start=min(set([item for row in lists_of_fcast_dates for item in row])).isoformat(), - time_coverage_end=max(set([item for row in lists_of_fcast_dates for item in row])).isoformat(), + time_coverage_start=min( + set([item for row in lists_of_fcast_dates for item in row + ])).isoformat(), + time_coverage_end=max( + set([item for row in lists_of_fcast_dates for item in row + ])).isoformat(), time_coverage_duration="P1D", time_coverage_resolution="P1D", title="Sea Ice Concentration Prediction", - ) - ) + )) xarr.time.attrs = dict( long_name=ref_cube.coord("time").long_name, @@ -315,7 +327,7 @@ def create_cf_output(): xarr.sic_mean.attrs = dict( long_name="mean sea ice area fraction across ensemble runs of icenet " - "model", + "model", standard_name="sea_ice_area_fraction", short_name="sic", valid_min=0, @@ -326,7 +338,8 @@ def create_cf_output(): ) xarr.sic_stddev.attrs = dict( - long_name="total uncertainty (one standard deviation) of concentration of sea ice", + long_name= + "total uncertainty (one standard deviation) of concentration of sea ice", standard_name="sea_ice_area_fraction standard_error", valid_min=0, valid_max=1, diff --git a/icenet/process/utils.py b/icenet/process/utils.py index 00845594..d5d6cb07 100644 --- a/icenet/process/utils.py +++ b/icenet/process/utils.py @@ -15,8 +15,7 @@ def date_arg(string: str) -> object: return dt.date(*[int(s) for s in d_match]) -def destination_filename(destination: object, - filename: str, +def destination_filename(destination: object, filename: str, date: object) -> object: """ @@ -25,11 +24,9 @@ def destination_filename(destination: object, :param date: :return: """ - return os.path.join(destination, - "{}.{}{}".format( - os.path.splitext( - os.path.basename(filename))[0], - date.strftime("%Y-%m-%d"), - 
os.path.splitext( - os.path.basename(filename))[1], - )) + return os.path.join( + destination, "{}.{}{}".format( + os.path.splitext(os.path.basename(filename))[0], + date.strftime("%Y-%m-%d"), + os.path.splitext(os.path.basename(filename))[1], + )) diff --git a/icenet/results/threshold.py b/icenet/results/threshold.py index 84e870e6..1abd60bd 100644 --- a/icenet/results/threshold.py +++ b/icenet/results/threshold.py @@ -23,8 +23,8 @@ def threshold_exceeds(da: object, logging.info("Checking thresholds for forecast(s)") if dimensions: - logging.debug("Selecting within given dimensions: {}". - format(dimensions)) + logging.debug( + "Selecting within given dimensions: {}".format(dimensions)) da = da.sel(**dimensions) thresh_arr = da > sic_thresh @@ -62,5 +62,6 @@ def threshold_main(): np.save(fh, threshold_map) logging.info("Saved to {}".format(args.output_file)) else: - logging.info("No output file provided: {} cells breached threshold". - format(len(threshold_map))) + logging.info( + "No output file provided: {} cells breached threshold".format( + len(threshold_map))) diff --git a/icenet/tests/test_entry_points.py b/icenet/tests/test_entry_points.py index 09e1d257..c0d440fa 100644 --- a/icenet/tests/test_entry_points.py +++ b/icenet/tests/test_entry_points.py @@ -3,7 +3,6 @@ from importlib_metadata import entry_points import pytest - icenet_entry_points = [ ep for ep in entry_points(group="console_scripts") if ep.module.startswith('icenet') diff --git a/icenet/tests/test_mod.py b/icenet/tests/test_mod.py index a4741073..084b4cdd 100644 --- a/icenet/tests/test_mod.py +++ b/icenet/tests/test_mod.py @@ -1,5 +1,4 @@ #!/usr/bin/env python - """Tests for `icenet` package.""" import pytest diff --git a/icenet/utils.py b/icenet/utils.py index 8f442c52..f76f5b81 100644 --- a/icenet/utils.py +++ b/icenet/utils.py @@ -68,8 +68,8 @@ def run_command(command: str, dry: bool = False): ret = sp.run(command, shell=True) if ret.returncode < 0: - logging.warning("Child was terminated by signal: {}". - format(-ret.returncode)) + logging.warning( + "Child was terminated by signal: {}".format(-ret.returncode)) else: logging.info("Child returned: {}".format(-ret.returncode)) @@ -78,6 +78,7 @@ def run_command(command: str, dry: bool = False): def setup_logging(func, log_format="[%(asctime)-17s :%(levelname)-8s] - %(message)s"): + @wraps(func) def wrapper(*args, **kwargs): parsed_args = func(*args, **kwargs) @@ -100,4 +101,5 @@ def wrapper(*args, **kwargs): logging.getLogger("tensorflow").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) return parsed_args + return wrapper diff --git a/setup.py b/setup.py index 1d97928d..a5fb7732 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ from setuptools import setup, find_packages import icenet - """Setup module for icenet """ @@ -17,9 +16,8 @@ def get_content(filename): author=icenet.__author__, author_email=icenet.__email__, description="Library for operational IceNet forecasting", - long_description="""{}\n---\n""". 
- format(get_content("README.md"), - get_content("HISTORY.rst")), + long_description="""{}\n---\n""".format(get_content("README.md"), + get_content("HISTORY.rst")), long_description_content_type="text/markdown", url="https://github.com/icenet-ai", packages=find_packages(), @@ -36,44 +34,35 @@ def get_content(filename): "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", - ], + ], entry_points={ "console_scripts": [ "icenet_data_masks = icenet.data.sic.mask:main", - "icenet_data_cmip = icenet.data.interfaces.esgf:main", "icenet_data_era5 = icenet.data.interfaces.cds:main", "icenet_data_oras5 = icenet.data.interfaces.cmems:main", "icenet_data_hres = icenet.data.interfaces.mars:hres_main", "icenet_data_seas = icenet.data.interfaces.mars:seas_main", "icenet_data_sic = icenet.data.sic.osisaf:main", - "icenet_data_reproc_monthly = " "icenet.data.interfaces.utils:reprocess_main", "icenet_data_add_time_dim = " "icenet.data.interfaces.utils:add_time_dim_main", - "icenet_process_cmip = icenet.data.processors.cmip:main", "icenet_process_era5 = icenet.data.processors.era5:main", "icenet_process_oras5 = icenet.data.processors.oras5:main", "icenet_process_hres = icenet.data.processors.hres:main", "icenet_process_sic = icenet.data.processors.osi:main", - "icenet_process_metadata = icenet.data.processors.meta:main", - "icenet_process_condense = " "icenet.data.processors.utils:condense_main", - "icenet_dataset_check = icenet.data.dataset:check_dataset", "icenet_dataset_create = icenet.data.loader:create", - "icenet_train = icenet.model.train:main", "icenet_predict = icenet.model.predict:main", "icenet_upload_azure = icenet.process.azure:upload", "icenet_upload_local = icenet.process.local:upload", - "icenet_plot_record = icenet.plotting.data:plot_tfrecord", - "icenet_plot_forecast = icenet.plotting.forecast:plot_forecast", "icenet_plot_input = icenet.plotting.data:plot_sample_cli", "icenet_plot_sic_error = icenet.plotting.forecast:sic_error", @@ -83,9 +72,7 @@ def get_content(filename): "icenet_plot_sie_error = icenet.plotting.forecast:sie_error", "icenet_plot_metrics = icenet.plotting.forecast:metric_plots", "icenet_plot_leadtime_avg = icenet.plotting.forecast:leadtime_avg_plots", - "icenet_video_data = icenet.plotting.video:data_cli", - "icenet_output = icenet.process.predict:create_cf_output", "icenet_output_geotiff = " "icenet.process.forecasts:create_geotiff_output", @@ -93,18 +80,17 @@ def get_content(filename): "icenet.process.forecasts:broadcast_main", "icenet_output_reproject = " "icenet.process.forecasts:reproject_main", - "icenet_result_threshold = " "icenet.results.threshold:threshold_main" ], - }, + }, python_requires='>=3.7, <4', install_requires=get_content("requirements.txt"), include_package_data=True, extras_require={ "dev": get_content("requirements_dev.txt"), "docs": get_content("docs/requirements.txt"), - }, + }, test_suite='tests', tests_require=['pytest>=3'], zip_safe=False, From 783fe38dcc1e252465058a01a466a5b31cd74ea8 Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Mon, 20 Nov 2023 14:25:34 +0000 Subject: [PATCH 22/61] Dev 193: pre-commit config updates --- .pre-commit-config.yaml | 15 ++++++++------- setup.cfg | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cad224ca..4911deff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,6 +10,14 @@ repos: - id: trailing-whitespace language_version: python3 + # isort - Sorting imports + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--filter-files"] + + # yapf - Formatting - repo: https://github.com/google/yapf rev: v0.40.2 hooks: @@ -18,13 +26,6 @@ repos: args: ["--in-place", "--parallel"] exclude: "docs/" - # isort - Sorting imports - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--filter-files"] - # ruff - Linting - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.1.6 diff --git a/setup.cfg b/setup.cfg index b29f2b65..f35971a0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ universal = 1 [isort] -profile=google +profile=black [flake8] exclude = docs From 2f06140da6a15c6265204a645ceb2b6c8ebda02f Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Mon, 20 Nov 2023 16:49:46 +0000 Subject: [PATCH 23/61] Dev 193: Set ruff line-length --- .ruff.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.ruff.toml b/.ruff.toml index af4b09be..7b4c49eb 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,3 +1,5 @@ +line-length = 120 + [lint] select = ["E", "F"] ignore = ["E721"] From d5133b8e2765510488dafd3a5e6335dd9b98ce6f Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Mon, 20 Nov 2023 16:50:19 +0000 Subject: [PATCH 24/61] Dev 193: isort options update --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index f35971a0..9809e2aa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,9 @@ universal = 1 [isort] profile=black +line_length = 120 +atomic = True +lines_after_imports = 0 [flake8] exclude = docs From b9909856c52d2cc2ee2e983e2f374d247f82582c Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Mon, 20 Nov 2023 17:21:55 +0000 Subject: [PATCH 25/61] Dev 193: Config updates --- .pre-commit-config.yaml | 12 ++++++------ requirements_dev.txt | 1 + setup.cfg | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4911deff..a37dc202 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,12 +10,12 @@ repos: - id: trailing-whitespace language_version: python3 - # isort - Sorting imports - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--filter-files"] + # # isort - Sorting imports + # - repo: https://github.com/pycqa/isort + # rev: 5.12.0 + # hooks: + # - id: isort + # args: ["--filter-files"] # yapf - Formatting - repo: https://github.com/google/yapf diff --git a/requirements_dev.txt b/requirements_dev.txt index c256644d..52d55ce0 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -7,5 +7,6 @@ pytest black build importlib_metadata +yapf ruff pre-commit diff --git a/setup.cfg b/setup.cfg index 9809e2aa..34616f49 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,7 +5,7 @@ universal = 1 profile=black line_length = 120 atomic = True -lines_after_imports = 0 +# lines_after_imports = 1 [flake8] exclude = docs From 0429bdc120cc6748f483aad6295941574275e76b Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Thu, 23 Nov 2023 17:40:26 +0000 Subject: [PATCH 26/61] Unpin h5py --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ec9085fb..82df92c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ dask distributed eccodes ecmwf-api-client -h5py==2.10.0 +h5py ibicus matplotlib motuclient From c5480473160c2b63b43a91755dee86d5165aee6b Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Thu, 23 Nov 2023 18:39:22 +0000 Subject: [PATCH 27/61] Dev 193: Update CONTRIBUTING.rst --- CONTRIBUTING.rst | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 9df94d9a..4fb769de 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -76,22 +76,27 @@ Ready to contribute? Here's how to set up `icenet` for local development. Now you can make your changes locally. -5. When you're done making changes, check that your changes pass flake8 and the - tests, including testing other Python versions with tox:: +5. Install development packages + + $ pip install -r requirements.txt + +6. Set up pre-commit hooks to run automatically. This will run through linting checks, formatting, and pytest. It will by format new code using yapf and prevent code commit that does not pass linting or testing checks until fixed. + + $ pre-commit install + +7. Run through tox (currently omitted from pre-commit hook) to test other Python versions (Optionally, can replace with tox-conda, and run same command): - $ flake8 icenet tests - $ python setup.py test or pytest $ tox - To get flake8 and tox, just pip install them into your virtualenv. + To get tox, just pip install them into your virtualenv (or tox-conda for conda environment). -6. Commit your changes and push your branch to GitHub:: +8. Commit your changes and push your branch to GitHub:: $ git add . $ git commit -m "Your detailed description of your changes." $ git push origin name-of-your-bugfix-or-feature -7. Submit a pull request through the GitHub website. +9. Submit a pull request through the GitHub website. Pull Request Guidelines ----------------------- From 71850dcac16da06902d39bbe85b040d7f647fa8f Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Thu, 23 Nov 2023 18:48:25 +0000 Subject: [PATCH 28/61] Dev 193: Update CONTRIBUTING.rst #2 --- CONTRIBUTING.rst | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 4fb769de..7d7bf0fd 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -76,15 +76,15 @@ Ready to contribute? Here's how to set up `icenet` for local development. Now you can make your changes locally. -5. Install development packages +5. Install development packages:: $ pip install -r requirements.txt -6. Set up pre-commit hooks to run automatically. This will run through linting checks, formatting, and pytest. It will by format new code using yapf and prevent code commit that does not pass linting or testing checks until fixed. +6. Set up pre-commit hooks to run automatically. This will run through linting checks, formatting, and pytest. It will format new code using yapf and prevent code committing that does not pass linting or testing checks until fixed:: $ pre-commit install -7. Run through tox (currently omitted from pre-commit hook) to test other Python versions (Optionally, can replace with tox-conda, and run same command): +7. 
Run through tox (currently omitted from pre-commit hook) to test other Python versions (Optionally, can replace with tox-conda, and run same command):: $ tox @@ -96,6 +96,8 @@ Ready to contribute? Here's how to set up `icenet` for local development. $ git commit -m "Your detailed description of your changes." $ git push origin name-of-your-bugfix-or-feature + Note: When committing, if pre-commit is installed, there might be formatting changes made by yapf and the commit prevented. In this case, add the file(s) modified by the formatter to the staging area and commit again. + 9. Submit a pull request through the GitHub website. Pull Request Guidelines From 3dd501ef8f505e17ab357536305a53fee3fa2797 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Thu, 23 Nov 2023 21:43:27 +0000 Subject: [PATCH 29/61] Dev #193: Omits setup.py in pre-commit runs --- .pre-commit-config.yaml | 2 +- setup.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a37dc202..43172b40 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: (LICENSE|README.md) +exclude: (LICENSE|README.md|setup.py) repos: # General pre-commit hooks - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/setup.py b/setup.py index a5fb7732..333a4b30 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ from setuptools import setup, find_packages import icenet + """Setup module for icenet """ @@ -34,35 +35,44 @@ def get_content(filename): "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", - ], + ], entry_points={ "console_scripts": [ "icenet_data_masks = icenet.data.sic.mask:main", + "icenet_data_cmip = icenet.data.interfaces.esgf:main", "icenet_data_era5 = icenet.data.interfaces.cds:main", "icenet_data_oras5 = icenet.data.interfaces.cmems:main", "icenet_data_hres = icenet.data.interfaces.mars:hres_main", "icenet_data_seas = icenet.data.interfaces.mars:seas_main", "icenet_data_sic = icenet.data.sic.osisaf:main", + "icenet_data_reproc_monthly = " "icenet.data.interfaces.utils:reprocess_main", "icenet_data_add_time_dim = " "icenet.data.interfaces.utils:add_time_dim_main", + "icenet_process_cmip = icenet.data.processors.cmip:main", "icenet_process_era5 = icenet.data.processors.era5:main", "icenet_process_oras5 = icenet.data.processors.oras5:main", "icenet_process_hres = icenet.data.processors.hres:main", "icenet_process_sic = icenet.data.processors.osi:main", + "icenet_process_metadata = icenet.data.processors.meta:main", + "icenet_process_condense = " "icenet.data.processors.utils:condense_main", + "icenet_dataset_check = icenet.data.dataset:check_dataset", "icenet_dataset_create = icenet.data.loader:create", + "icenet_train = icenet.model.train:main", "icenet_predict = icenet.model.predict:main", "icenet_upload_azure = icenet.process.azure:upload", "icenet_upload_local = icenet.process.local:upload", + "icenet_plot_record = icenet.plotting.data:plot_tfrecord", + "icenet_plot_forecast = icenet.plotting.forecast:plot_forecast", "icenet_plot_input = icenet.plotting.data:plot_sample_cli", "icenet_plot_sic_error = icenet.plotting.forecast:sic_error", @@ -72,7 +82,9 @@ def get_content(filename): "icenet_plot_sie_error = icenet.plotting.forecast:sie_error", "icenet_plot_metrics = icenet.plotting.forecast:metric_plots", "icenet_plot_leadtime_avg = icenet.plotting.forecast:leadtime_avg_plots", + "icenet_video_data = 
icenet.plotting.video:data_cli", + "icenet_output = icenet.process.predict:create_cf_output", "icenet_output_geotiff = " "icenet.process.forecasts:create_geotiff_output", @@ -80,17 +92,18 @@ def get_content(filename): "icenet.process.forecasts:broadcast_main", "icenet_output_reproject = " "icenet.process.forecasts:reproject_main", + "icenet_result_threshold = " "icenet.results.threshold:threshold_main" ], - }, + }, python_requires='>=3.7, <4', install_requires=get_content("requirements.txt"), include_package_data=True, extras_require={ "dev": get_content("requirements_dev.txt"), "docs": get_content("docs/requirements.txt"), - }, + }, test_suite='tests', tests_require=['pytest>=3'], zip_safe=False, From 907eb95b56bde7ac745aa3fae3a78824d18baa09 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Sun, 26 Nov 2023 18:44:06 +0000 Subject: [PATCH 30/61] Dev #20: Add myst_parser to parse md in rst docs --- docs/conf.py | 3 ++- docs/icenet.rst | 26 +++++++++++++------------- docs/readme.rst | 1 + docs/requirements.txt | 6 +++--- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 45e0514e..0d6ce49b 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,7 +32,8 @@ extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode'] + 'sphinx.ext.viewcode', + 'myst_parser'] # Standardising on napoleon_numpy_docstring = True diff --git a/docs/icenet.rst b/docs/icenet.rst index dc166bba..a9a19b31 100644 --- a/docs/icenet.rst +++ b/docs/icenet.rst @@ -5,13 +5,14 @@ Subpackages ----------- .. toctree:: + :maxdepth: 4 - icenet.data - icenet.model - icenet.plotting - icenet.process - icenet.results - icenet.tests + icenet.data + icenet.model + icenet.plotting + icenet.process + icenet.results + icenet.tests Submodules ---------- @@ -20,15 +21,14 @@ icenet.utils module ------------------- .. automodule:: icenet.utils - :members: - :undoc-members: - :show-inheritance: - + :members: + :undoc-members: + :show-inheritance: Module contents --------------- .. automodule:: icenet - :members: - :undoc-members: - :show-inheritance: + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/readme.rst b/docs/readme.rst index d3bab802..f7d909e8 100644 --- a/docs/readme.rst +++ b/docs/readme.rst @@ -2,3 +2,4 @@ IceNet README ============= .. include:: ../README.md + :parser: myst_parser.sphinx_ diff --git a/docs/requirements.txt b/docs/requirements.txt index 217b17e5..13c79e32 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,3 @@ -jinja2==3.0.3 -Sphinx==1.8.5 - +jinja2 +Sphinx +myst_parser From bf6903e3349addf6164e5c70506e2fad5bccafce Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Sun, 26 Nov 2023 20:55:32 +0000 Subject: [PATCH 31/61] Dev #20: UML diag gen work --- docs/api.rst | 13 +++++++++++++ docs/conf.py | 8 +++++++- docs/index.rst | 1 + 3 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 docs/api.rst diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 00000000..4705441d --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,13 @@ +Class UML Diagram +================= + +.. uml:: icenet + :classes: + +.. .. mermaid:: ../classes_icenet.mmd +.. :zoom: + +.. .. plantuml:: ../classes_icenet.puml +.. :format: svg +.. :caption: Class UML Diagram +.. 
:align: center \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 0d6ce49b..bdb89c44 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -33,7 +33,11 @@ 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', - 'myst_parser'] + 'myst_parser', + # 'sphinx_pyreverse', + # 'sphinxcontrib.mermaid' + # 'sphinxcontrib.plantuml', + ] # Standardising on napoleon_numpy_docstring = True @@ -165,4 +169,6 @@ ] +sphinx_pyreverse_output = "png" +# mermaid_output_format = "svg" \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 33d9e07a..c25f06f9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,6 +9,7 @@ Welcome to IceNet's documentation! installation usage modules + api contributing authors history From 517eb5633d6deae78b0eb174ab5e758dd4ddb68c Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Sun, 26 Nov 2023 20:57:10 +0000 Subject: [PATCH 32/61] Dev #20: UML diag gen work +1 --- docs/api.rst | 10 +++++----- docs/requirements.txt | 3 +++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 4705441d..ba70a136 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,13 +1,13 @@ Class UML Diagram ================= -.. uml:: icenet - :classes: +.. .. uml:: icenet +.. :classes: -.. .. mermaid:: ../classes_icenet.mmd +.. .. mermaid:: ./classes_icenet.mmd .. :zoom: -.. .. plantuml:: ../classes_icenet.puml +.. .. plantuml:: ./classes_icenet.puml .. :format: svg .. :caption: Class UML Diagram -.. :align: center \ No newline at end of file +.. :align: center diff --git a/docs/requirements.txt b/docs/requirements.txt index 13c79e32..aca8e5e0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,6 @@ jinja2 Sphinx myst_parser +sphinx-pyreverse +sphinxcontrib-mermaid +sphinx-plantuml From 2c6b3086ac26af97c1c621c509fbdd4507392a2d Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Sun, 26 Nov 2023 23:19:43 +0000 Subject: [PATCH 33/61] Dev #20: Add UML output to docs --- Makefile | 1 + docs/api.rst | 13 -- docs/classes_icenet.puml | 283 ++++++++++++++++++++++++++++++++++++++ docs/conf.py | 4 +- docs/index.rst | 2 +- docs/packages_icenet.puml | 228 ++++++++++++++++++++++++++++++ docs/requirements.txt | 5 +- docs/uml.rst | 4 + 8 files changed, 520 insertions(+), 20 deletions(-) delete mode 100644 docs/api.rst create mode 100644 docs/classes_icenet.puml create mode 100644 docs/packages_icenet.puml create mode 100644 docs/uml.rst diff --git a/Makefile b/Makefile index cb4376bc..978939bd 100644 --- a/Makefile +++ b/Makefile @@ -69,6 +69,7 @@ coverage: ## check code coverage quickly with the default Python docs: ## generate Sphinx HTML documentation, including API docs rm -f docs/icenet.rst rm -f docs/modules.rst + pyreverse --output puml --output-directory docs/ --project icenet icenet --colorized sphinx-apidoc -o docs/ icenet $(MAKE) -C docs clean $(MAKE) -C docs html diff --git a/docs/api.rst b/docs/api.rst deleted file mode 100644 index ba70a136..00000000 --- a/docs/api.rst +++ /dev/null @@ -1,13 +0,0 @@ -Class UML Diagram -================= - -.. .. uml:: icenet -.. :classes: - -.. .. mermaid:: ./classes_icenet.mmd -.. :zoom: - -.. .. plantuml:: ./classes_icenet.puml -.. :format: svg -.. :caption: Class UML Diagram -.. 
:align: center diff --git a/docs/classes_icenet.puml b/docs/classes_icenet.puml new file mode 100644 index 00000000..48136ef9 --- /dev/null +++ b/docs/classes_icenet.puml @@ -0,0 +1,283 @@ +@startuml classes_icenet +set namespaceSeparator none +class "BatchwiseModelCheckpoint" as icenet.model.callbacks.BatchwiseModelCheckpoint #44BB99 { + best + mode : object + model_path : object + monitor : object + sample_at_zero : object + save_frequency : object + on_train_batch_end(batch: object, logs: object) +} +class "CMIP6Downloader" as icenet.data.interfaces.esgf.CMIP6Downloader #99DDFF { + ESGF_NODES : tuple + GRID_MAP : dict + TABLE_MAP : dict + additional_regrid_processing(datafile: str, cube_ease: object) + convert_cube(cube: object) -> object +} +class "ClimateDownloader" as icenet.data.interfaces.downloader.ClimateDownloader #99DDFF { + dates + delete + download_method + group_dates_by + levels + pregrid_prefix + sic_ease_cube + var_names + {abstract}additional_regrid_processing(datafile: str, cube_ease: object) + convert_cube(cube: object) + download() + get_req_filenames(var_folder: str, req_date: object, date_format: str) + postprocess(var, download_path) + regrid(files: object, rotate_wind: bool) + rotate_wind_data(apply_to: object, manual_files: object) + save_temporal_files(var, da, date_format, freq) +} +class "ConstructLeadtimeAccuracy" as icenet.model.metrics.ConstructLeadtimeAccuracy #44BB99 { + single_forecast_leadtime_idx : Optional[object] + use_all_forecast_months : bool + from_config(config: object) + get_config() + result() + update_state(y_true: object, y_pred: object, sample_weight: object) +} +class "DaskBaseDataLoader" as icenet.data.loaders.dask.DaskBaseDataLoader #99DDFF { + {abstract}client_generate(client: object, dates_override: object, pickup: bool) + generate() +} +class "DaskMultiSharingWorkerLoader" as icenet.data.loaders.dask.DaskMultiSharingWorkerLoader #99DDFF { + {abstract}client_generate(client: object, dates_override: object, pickup: bool) + {abstract}generate_sample(date: object, prediction: bool) +} +class "DaskMultiWorkerLoader" as icenet.data.loaders.dask.DaskMultiWorkerLoader #99DDFF { + client_generate(client: object, dates_override: object, pickup: bool) + generate_sample(date: object, prediction: bool) +} +class "DaskWrapper" as icenet.data.sic.osisaf.DaskWrapper #99DDFF { + dask_process() +} +class "DataCollection" as icenet.data.producers.DataCollection #99DDFF { + base_path + identifier +} +class "DataProducer" as icenet.data.producers.DataProducer #99DDFF { + dry : bool + overwrite : bool + get_data_var_folder(var: str, append: object, hemisphere: object, missing_error: bool) -> str +} +class "Downloader" as icenet.data.producers.Downloader #99DDFF { + {abstract}download() +} +class "ERA5Downloader" as icenet.data.interfaces.cds.ERA5Downloader #99DDFF { + CDI_MAP : dict + client : Client + download_method + additional_regrid_processing(datafile: str, cube_ease: object) + postprocess(var: str, download_path: object) +} +class "ForecastPlotArgParser" as icenet.plotting.forecast.ForecastPlotArgParser #BBCC33 { + allow_ecmwf() + allow_metrics() + allow_probes() + allow_sie() + allow_threshold() + parse_args() +} +class "Generator" as icenet.data.producers.Generator #99DDFF { + {abstract}generate() +} +class "HRESDownloader" as icenet.data.interfaces.mars.HRESDownloader #99DDFF { + HRES_PARAMS : dict + MARS_TEMPLATE : str + PARAM_TABLE : int + mars_template + param_table + params + additional_regrid_processing(datafile: str, cube_ease: object) + 
download() +} +class "Hemisphere" as icenet.utils.Hemisphere #77AADD { + name +} +class "HemisphereMixin" as icenet.utils.HemisphereMixin #77AADD { + both + hemisphere + hemisphere_loc + hemisphere_str + north + south +} +class "IceNetBaseDataLoader" as icenet.data.loaders.base.IceNetBaseDataLoader #99DDFF { + channel_names + config + dates_override + num_channels + pickup + workers + {abstract}generate_sample(date: object, prediction: bool) + get_sample_files() -> object + write_dataset_config_only() +} +class "IceNetCMIPPreProcessor" as icenet.data.processors.cmip.IceNetCMIPPreProcessor #99DDFF { + pre_normalisation(var_name: str, da: object) +} +class "IceNetDataLoader" as icenet.data.loaders.stdlib.IceNetDataLoader #99DDFF { + {abstract}generate() + {abstract}generate_sample(date: object, prediction: bool) +} +class "IceNetDataLoaderFactory" as icenet.data.loaders.IceNetDataLoaderFactory #99DDFF { + loader_map + add_data_loader(loader_name: str, loader_impl: object) + create_data_loader(loader_name) +} +class "IceNetDataSet" as icenet.data.dataset.IceNetDataSet #99DDFF { + channels + counts + loader_config + get_data_loader(n_forecast_days, generate_workers) +} +class "IceNetDataWarning" as icenet.data.loaders.utils.IceNetDataWarning #99DDFF { +} +class "IceNetERA5PreProcessor" as icenet.data.processors.era5.IceNetERA5PreProcessor #99DDFF { +} +class "IceNetHRESPreProcessor" as icenet.data.processors.hres.IceNetHRESPreProcessor #99DDFF { +} +class "IceNetMetaPreProcessor" as icenet.data.processors.meta.IceNetMetaPreProcessor #99DDFF { + {abstract}init_source_data(lag_days: object, lead_days: object) + process() +} +class "IceNetORAS5PreProcessor" as icenet.data.processors.oras5.IceNetORAS5PreProcessor #99DDFF { +} +class "IceNetOSIPreProcessor" as icenet.data.processors.osi.IceNetOSIPreProcessor #99DDFF { + missing_dates : list + pre_normalisation(var_name: str, da: object) +} +class "IceNetPreProcessor" as icenet.data.process.IceNetPreProcessor #99DDFF { + DATE_FORMAT : str + missing_dates + mean_and_std(array: object) + post_normalisation(var_name: str, da: object) + pre_normalisation(var_name: str, da: object) + process() + update_loader_config() +} +class "IceNetPreTrainingEvaluator" as icenet.model.callbacks.IceNetPreTrainingEvaluator #44BB99 { + sample_at_zero : bool + val_dataloader + validation_frequency + on_train_batch_end(batch: object, logs: object) +} +class "Masks" as icenet.data.sic.mask.Masks #99DDFF { + LAND_MASK_FILENAME : str + POLARHOLE_DATES : tuple + POLARHOLE_RADII : tuple + generate(year: int, save_land_mask: bool, save_polarhole_masks: bool, remove_temp_files: bool) + get_active_cell_da(src_da: object) -> object + get_active_cell_mask(month: object) -> object + get_blank_mask() -> object + get_land_mask(land_mask_filename: str) -> object + get_polarhole_mask(date: object) -> object + init_params() + reset_region() +} +class "MergedIceNetDataSet" as icenet.data.dataset.MergedIceNetDataSet #99DDFF { + channels + counts + {abstract}check_dataset(split: str) + get_data_loader() +} +class "ORAS5Downloader" as icenet.data.interfaces.cmems.ORAS5Downloader #99DDFF { + ENDPOINTS : dict + VAR_MAP : dict + download_method + additional_regrid_processing(datafile: object, cube_ease: object) -> object + postprocess(var: str, download_path: object) +} +class "Processor" as icenet.data.producers.Processor #99DDFF { + dates + lead_time + processed_files + source_data + init_source_data(lag_days: object) + {abstract}process() + save_processed_file(var_name: str, name: str, data: 
object) +} +class "SEASDownloader" as icenet.data.interfaces.mars.SEASDownloader #99DDFF { + MARS_TEMPLATE : str + save_temporal_files(var, da, date_format, freq) +} +class "SICDownloader" as icenet.data.sic.osisaf.SICDownloader #99DDFF { + download() + missing_dates() +} +class "SplittingMixin" as icenet.data.datasets.utils.SplittingMixin #99DDFF { + batch_size + dtype + n_forecast_days + num_channels + shape + shuffling + test_fns : list + train_fns : list + val_fns : list + add_records(base_path: str, hemi: str) + check_dataset(split: str) + get_split_datasets(ratio: object) +} +class "TemperatureScale" as icenet.model.models.TemperatureScale #44BB99 { + temp + call(inputs: object) + get_config() +} +class "WeightedBinaryAccuracy" as icenet.model.metrics.WeightedBinaryAccuracy #44BB99 { + get_config() + result() + update_state(y_true: object, y_pred: object, sample_weight: object) +} +class "WeightedMAE" as icenet.model.metrics.WeightedMAE #44BB99 { + result() + update_state(y_true: object, y_pred: object, sample_weight: object) +} +class "WeightedMSE" as icenet.model.metrics.WeightedMSE #44BB99 { + result() + update_state(y_true: object, y_pred: object, sample_weight: object) +} +class "WeightedMSE" as icenet.model.losses.WeightedMSE #44BB99 { +} +class "WeightedRMSE" as icenet.model.metrics.WeightedRMSE #44BB99 { + result() + update_state(y_true: object, y_pred: object, sample_weight: object) +} +icenet.data.dataset.IceNetDataSet --|> icenet.data.datasets.utils.SplittingMixin +icenet.data.dataset.IceNetDataSet --|> icenet.data.producers.DataCollection +icenet.data.dataset.MergedIceNetDataSet --|> icenet.data.datasets.utils.SplittingMixin +icenet.data.dataset.MergedIceNetDataSet --|> icenet.data.producers.DataCollection +icenet.data.interfaces.cds.ERA5Downloader --|> icenet.data.interfaces.downloader.ClimateDownloader +icenet.data.interfaces.cmems.ORAS5Downloader --|> icenet.data.interfaces.downloader.ClimateDownloader +icenet.data.interfaces.downloader.ClimateDownloader --|> icenet.data.producers.Downloader +icenet.data.interfaces.esgf.CMIP6Downloader --|> icenet.data.interfaces.downloader.ClimateDownloader +icenet.data.interfaces.mars.HRESDownloader --|> icenet.data.interfaces.downloader.ClimateDownloader +icenet.data.interfaces.mars.SEASDownloader --|> icenet.data.interfaces.mars.HRESDownloader +icenet.data.loaders.base.IceNetBaseDataLoader --|> icenet.data.producers.Generator +icenet.data.loaders.dask.DaskBaseDataLoader --|> icenet.data.loaders.base.IceNetBaseDataLoader +icenet.data.loaders.dask.DaskMultiSharingWorkerLoader --|> icenet.data.loaders.dask.DaskBaseDataLoader +icenet.data.loaders.dask.DaskMultiWorkerLoader --|> icenet.data.loaders.dask.DaskBaseDataLoader +icenet.data.loaders.stdlib.IceNetDataLoader --|> icenet.data.loaders.base.IceNetBaseDataLoader +icenet.data.process.IceNetPreProcessor --|> icenet.data.producers.Processor +icenet.data.processors.cmip.IceNetCMIPPreProcessor --|> icenet.data.process.IceNetPreProcessor +icenet.data.processors.era5.IceNetERA5PreProcessor --|> icenet.data.process.IceNetPreProcessor +icenet.data.processors.hres.IceNetHRESPreProcessor --|> icenet.data.process.IceNetPreProcessor +icenet.data.processors.meta.IceNetMetaPreProcessor --|> icenet.data.process.IceNetPreProcessor +icenet.data.processors.oras5.IceNetORAS5PreProcessor --|> icenet.data.process.IceNetPreProcessor +icenet.data.processors.osi.IceNetOSIPreProcessor --|> icenet.data.process.IceNetPreProcessor +icenet.data.producers.DataCollection --|> icenet.utils.HemisphereMixin 
+icenet.data.producers.DataProducer --|> icenet.data.producers.DataCollection +icenet.data.producers.Downloader --|> icenet.data.producers.DataProducer +icenet.data.producers.Generator --|> icenet.data.producers.DataProducer +icenet.data.producers.Processor --|> icenet.data.producers.DataProducer +icenet.data.sic.mask.Masks --|> icenet.data.producers.Generator +icenet.data.sic.osisaf.SICDownloader --|> icenet.data.producers.Downloader +icenet.data.sic.mask.Masks --* icenet.data.interfaces.downloader.ClimateDownloader : _masks +icenet.data.sic.mask.Masks --* icenet.data.sic.osisaf.SICDownloader : _masks +@enduml diff --git a/docs/conf.py b/docs/conf.py index bdb89c44..2850f1f8 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,9 +34,7 @@ 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'myst_parser', - # 'sphinx_pyreverse', - # 'sphinxcontrib.mermaid' - # 'sphinxcontrib.plantuml', + 'sphinxcontrib.kroki' ] # Standardising on diff --git a/docs/index.rst b/docs/index.rst index c25f06f9..39816f5f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,7 +9,7 @@ Welcome to IceNet's documentation! installation usage modules - api + uml contributing authors history diff --git a/docs/packages_icenet.puml b/docs/packages_icenet.puml new file mode 100644 index 00000000..c783999b --- /dev/null +++ b/docs/packages_icenet.puml @@ -0,0 +1,228 @@ +@startuml packages_icenet +set namespaceSeparator none +package "icenet" as icenet #77AADD { +} +package "icenet.data" as icenet.data #99DDFF { +} +package "icenet.data.cli" as icenet.data.cli #99DDFF { +} +package "icenet.data.dataset" as icenet.data.dataset #99DDFF { +} +package "icenet.data.datasets" as icenet.data.datasets #99DDFF { +} +package "icenet.data.datasets.utils" as icenet.data.datasets.utils #99DDFF { +} +package "icenet.data.interfaces" as icenet.data.interfaces #99DDFF { +} +package "icenet.data.interfaces.cds" as icenet.data.interfaces.cds #99DDFF { +} +package "icenet.data.interfaces.cmems" as icenet.data.interfaces.cmems #99DDFF { +} +package "icenet.data.interfaces.downloader" as icenet.data.interfaces.downloader #99DDFF { +} +package "icenet.data.interfaces.esgf" as icenet.data.interfaces.esgf #99DDFF { +} +package "icenet.data.interfaces.mars" as icenet.data.interfaces.mars #99DDFF { +} +package "icenet.data.interfaces.utils" as icenet.data.interfaces.utils #99DDFF { +} +package "icenet.data.loader" as icenet.data.loader #99DDFF { +} +package "icenet.data.loaders" as icenet.data.loaders #99DDFF { +} +package "icenet.data.loaders.base" as icenet.data.loaders.base #99DDFF { +} +package "icenet.data.loaders.dask" as icenet.data.loaders.dask #99DDFF { +} +package "icenet.data.loaders.stdlib" as icenet.data.loaders.stdlib #99DDFF { +} +package "icenet.data.loaders.utils" as icenet.data.loaders.utils #99DDFF { +} +package "icenet.data.process" as icenet.data.process #99DDFF { +} +package "icenet.data.processors" as icenet.data.processors #99DDFF { +} +package "icenet.data.processors.cmip" as icenet.data.processors.cmip #99DDFF { +} +package "icenet.data.processors.era5" as icenet.data.processors.era5 #99DDFF { +} +package "icenet.data.processors.hres" as icenet.data.processors.hres #99DDFF { +} +package "icenet.data.processors.meta" as icenet.data.processors.meta #99DDFF { +} +package "icenet.data.processors.oras5" as icenet.data.processors.oras5 #99DDFF { +} +package "icenet.data.processors.osi" as icenet.data.processors.osi #99DDFF { +} +package "icenet.data.processors.utils" as icenet.data.processors.utils #99DDFF { +} +package 
"icenet.data.producers" as icenet.data.producers #99DDFF { +} +package "icenet.data.sic" as icenet.data.sic #99DDFF { +} +package "icenet.data.sic.mask" as icenet.data.sic.mask #99DDFF { +} +package "icenet.data.sic.osisaf" as icenet.data.sic.osisaf #99DDFF { +} +package "icenet.data.sic.utils" as icenet.data.sic.utils #99DDFF { +} +package "icenet.data.utils" as icenet.data.utils #99DDFF { +} +package "icenet.model" as icenet.model #44BB99 { +} +package "icenet.model.callbacks" as icenet.model.callbacks #44BB99 { +} +package "icenet.model.losses" as icenet.model.losses #44BB99 { +} +package "icenet.model.metrics" as icenet.model.metrics #44BB99 { +} +package "icenet.model.models" as icenet.model.models #44BB99 { +} +package "icenet.model.predict" as icenet.model.predict #44BB99 { +} +package "icenet.model.train" as icenet.model.train #44BB99 { +} +package "icenet.model.utils" as icenet.model.utils #44BB99 { +} +package "icenet.plotting" as icenet.plotting #BBCC33 { +} +package "icenet.plotting.data" as icenet.plotting.data #BBCC33 { +} +package "icenet.plotting.forecast" as icenet.plotting.forecast #BBCC33 { +} +package "icenet.plotting.trend" as icenet.plotting.trend #BBCC33 { +} +package "icenet.plotting.utils" as icenet.plotting.utils #BBCC33 { +} +package "icenet.plotting.video" as icenet.plotting.video #BBCC33 { +} +package "icenet.process" as icenet.process #AAAA00 { +} +package "icenet.process.azure" as icenet.process.azure #AAAA00 { +} +package "icenet.process.forecasts" as icenet.process.forecasts #AAAA00 { +} +package "icenet.process.local" as icenet.process.local #AAAA00 { +} +package "icenet.process.predict" as icenet.process.predict #AAAA00 { +} +package "icenet.process.train" as icenet.process.train #AAAA00 { +} +package "icenet.process.utils" as icenet.process.utils #AAAA00 { +} +package "icenet.results" as icenet.results #EEDD88 { +} +package "icenet.results.metrics" as icenet.results.metrics #EEDD88 { +} +package "icenet.results.threshold" as icenet.results.threshold #EEDD88 { +} +package "icenet.tests" as icenet.tests #EE8866 { +} +package "icenet.tests.test_entry_points" as icenet.tests.test_entry_points #EE8866 { +} +package "icenet.tests.test_mod" as icenet.tests.test_mod #EE8866 { +} +package "icenet.utils" as icenet.utils #77AADD { +} +icenet.data.cli --> icenet.utils +icenet.data.dataset --> icenet.data.datasets.utils +icenet.data.dataset --> icenet.data.loader +icenet.data.dataset --> icenet.data.producers +icenet.data.dataset --> icenet.utils +icenet.data.interfaces.cds --> icenet.data.cli +icenet.data.interfaces.cds --> icenet.data.interfaces.downloader +icenet.data.interfaces.cmems --> icenet.data.cli +icenet.data.interfaces.cmems --> icenet.data.interfaces.downloader +icenet.data.interfaces.cmems --> icenet.utils +icenet.data.interfaces.downloader --> icenet.data.interfaces.utils +icenet.data.interfaces.downloader --> icenet.data.producers +icenet.data.interfaces.downloader --> icenet.data.sic.mask +icenet.data.interfaces.downloader --> icenet.data.sic.utils +icenet.data.interfaces.downloader --> icenet.data.utils +icenet.data.interfaces.downloader --> icenet.utils +icenet.data.interfaces.esgf --> icenet.data.cli +icenet.data.interfaces.esgf --> icenet.data.interfaces.downloader +icenet.data.interfaces.esgf --> icenet.data.utils +icenet.data.interfaces.mars --> icenet.data.cli +icenet.data.interfaces.mars --> icenet.data.interfaces.downloader +icenet.data.interfaces.mars --> icenet.data.interfaces.utils +icenet.data.interfaces.utils --> icenet.utils 
+icenet.data.loader --> icenet.data.cli +icenet.data.loader --> icenet.data.loaders +icenet.data.loader --> icenet.utils +icenet.data.loaders --> icenet.data.loaders.base +icenet.data.loaders --> icenet.data.loaders.dask +icenet.data.loaders.base --> icenet.data.process +icenet.data.loaders.base --> icenet.data.producers +icenet.data.loaders.dask --> icenet.data.loaders.base +icenet.data.loaders.dask --> icenet.data.loaders.utils +icenet.data.loaders.dask --> icenet.data.process +icenet.data.loaders.dask --> icenet.data.sic.mask +icenet.data.loaders.stdlib --> icenet.data.loaders.base +icenet.data.process --> icenet.data.producers +icenet.data.process --> icenet.data.sic.mask +icenet.data.process --> icenet.model.models +icenet.data.processors.cmip --> icenet.data.cli +icenet.data.processors.cmip --> icenet.data.process +icenet.data.processors.cmip --> icenet.data.processors.utils +icenet.data.processors.cmip --> icenet.data.sic.mask +icenet.data.processors.era5 --> icenet.data.cli +icenet.data.processors.era5 --> icenet.data.process +icenet.data.processors.hres --> icenet.data.cli +icenet.data.processors.hres --> icenet.data.process +icenet.data.processors.meta --> icenet.data.cli +icenet.data.processors.meta --> icenet.data.process +icenet.data.processors.meta --> icenet.data.sic.mask +icenet.data.processors.oras5 --> icenet.data.cli +icenet.data.processors.oras5 --> icenet.data.process +icenet.data.processors.osi --> icenet.data.cli +icenet.data.processors.osi --> icenet.data.process +icenet.data.processors.osi --> icenet.data.processors.utils +icenet.data.processors.osi --> icenet.data.sic.mask +icenet.data.processors.utils --> icenet.data.producers +icenet.data.processors.utils --> icenet.utils +icenet.data.producers --> icenet.utils +icenet.data.sic.mask --> icenet.data.cli +icenet.data.sic.mask --> icenet.data.producers +icenet.data.sic.mask --> icenet.data.sic.utils +icenet.data.sic.mask --> icenet.utils +icenet.data.sic.osisaf --> icenet.data.cli +icenet.data.sic.osisaf --> icenet.data.producers +icenet.data.sic.osisaf --> icenet.data.sic.mask +icenet.data.sic.osisaf --> icenet.data.sic.utils +icenet.data.sic.osisaf --> icenet.utils +icenet.model.predict --> icenet.data.dataset +icenet.model.predict --> icenet.data.loader +icenet.model.predict --> icenet.model.models +icenet.model.predict --> icenet.utils +icenet.model.train --> icenet.data.dataset +icenet.model.train --> icenet.model.losses +icenet.model.train --> icenet.model.metrics +icenet.model.train --> icenet.model.models +icenet.model.train --> icenet.model.utils +icenet.model.train --> icenet.utils +icenet.plotting.data --> icenet.data.cli +icenet.plotting.data --> icenet.data.dataset +icenet.plotting.data --> icenet.data.datasets.utils +icenet.plotting.data --> icenet.utils +icenet.plotting.forecast --> icenet +icenet.plotting.forecast --> icenet.data.cli +icenet.plotting.forecast --> icenet.data.sic.mask +icenet.plotting.forecast --> icenet.plotting.utils +icenet.plotting.forecast --> icenet.plotting.video +icenet.plotting.video --> icenet.process.predict +icenet.plotting.video --> icenet.utils +icenet.process.azure --> icenet.process.utils +icenet.process.azure --> icenet.utils +icenet.process.forecasts --> icenet.plotting.utils +icenet.process.forecasts --> icenet.process.utils +icenet.process.forecasts --> icenet.utils +icenet.process.local --> icenet.process.utils +icenet.process.local --> icenet.utils +icenet.process.predict --> icenet +icenet.process.predict --> icenet.data.dataset +icenet.process.predict --> 
icenet.data.sic.mask +icenet.process.predict --> icenet.utils +icenet.results.threshold --> icenet.data.cli +icenet.results.threshold --> icenet.utils +@enduml diff --git a/docs/requirements.txt b/docs/requirements.txt index aca8e5e0..3d4d5d4d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,5 @@ jinja2 Sphinx myst_parser -sphinx-pyreverse -sphinxcontrib-mermaid -sphinx-plantuml +pylint +sphinxcontrib-kroki diff --git a/docs/uml.rst b/docs/uml.rst new file mode 100644 index 00000000..8a4bad6b --- /dev/null +++ b/docs/uml.rst @@ -0,0 +1,4 @@ +Class UML Diagram +================= + +.. kroki:: ./classes_icenet.puml svg From 7a1cf9a5c5828d8b1ea9547a673266f2425c1898 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Mon, 27 Nov 2023 11:08:08 +0000 Subject: [PATCH 34/61] Dev #20: Update 'docs/Makefile' for UML generation --- docs/Makefile | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/Makefile b/docs/Makefile index c3bf474f..be888215 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -12,7 +12,12 @@ BUILDDIR = _build help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -.PHONY: help Makefile +.PHONY: help html Makefile + +# When running "make html": Run pyreverse to generate UML diagram if outputting html docs. +html: + pyreverse --output puml --output-directory ./ --project icenet ../icenet --colorized + @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). From 3c47a16b7d3f8b78cc9c363805c8ee07e43261c7 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Wed, 29 Nov 2023 22:31:24 +0000 Subject: [PATCH 35/61] Dev #20: Docstrings for Hemisphere + HemisphereMixin --- icenet/utils.py | 69 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 14 deletions(-) diff --git a/icenet/utils.py b/icenet/utils.py index 8f442c52..fa5acbe6 100644 --- a/icenet/utils.py +++ b/icenet/utils.py @@ -6,8 +6,12 @@ class Hemisphere(Flag): - """ + """Representation of hemispheres & both with bitwise operations. + An enum.Flag derived class representing the different hemispheres + (north, south, or both), providing methods to check which hemisphere + is selected via bitwise operations: + & (AND), | (OR), ^ (XOR), and ~ (INVERT) """ NONE = 0 @@ -17,53 +21,90 @@ class Hemisphere(Flag): class HemisphereMixin: - """ + """A mixin relating to Hemisphere checking. + Attributes: + _hemisphere: Represents the bitmask value of the hemisphere. + Defaults to Hemisphere.NONE (i.e., 0). """ - _hemisphere = Hemisphere.NONE + _hemisphere: int = Hemisphere.NONE @property - def hemisphere(self): + def hemisphere(self) -> int: + """Get the bitmask value representing the hemisphere. + + Returns: + The bitmask value representing the hemisphere. + """ return self._hemisphere @property - def hemisphere_str(self): + def hemisphere_str(self) -> list: + """Get a list of strings representing the selected hemispheres. + + Returns: + A list of strings representing the hemisphere. + """ return ["north"] if self.north else \ ["south"] if self.south else \ ["north", "south"] @property - def hemisphere_loc(self): + def hemisphere_loc(self) -> list: + """Get a list of latitude and longitude extent representing the hemisphere's location. + + Returns: + A list of latitude and longitude extent representing the hemisphere's location. 
+ [north lat, west lon, south lat, east lon] + """ return [90, -180, 0, 180] if self.north else \ [0, -180, -90, 180] if self.south else \ [90, -180, -90, 180] @property - def north(self): + def north(self) -> bool: + """Get flag if `_hemisphere` is north. + + Returns: + True if the hemisphere is north, False otherwise. + """ return (self._hemisphere & Hemisphere.NORTH) == Hemisphere.NORTH @property - def south(self): + def south(self) -> bool: + """Check flag if `_hemisphere` is south. + + Returns: + True if the hemisphere is south, False otherwise. + """ return (self._hemisphere & Hemisphere.SOUTH) == Hemisphere.SOUTH @property - def both(self): + def both(self) -> int: + """Get the bitmask value representing both hemispheres. + + Returns: + The bitmask value representing both hemispheres. + """ return self._hemisphere & Hemisphere.BOTH -def run_command(command: str, dry: bool = False): +def run_command(command: str, dry: bool = False) -> object: """Run a shell command A wrapper in case we want some additional handling to go in here - :param command: - :param dry: - :return: + Args: + command: Command to run in shell. + dry (optional): Whether to do a dry run or to run actual command. + Default is False. + Returns: + subprocess.CompletedProcess return of the executed command. """ if dry: - logging.info("Skipping dry commaand: {}".format(command)) + logging.info("Skipping dry command: {}".format(command)) return 0 ret = sp.run(command, shell=True) From 2efafa6996f36941af845b16a8a4cc8c3f666a01 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Wed, 29 Nov 2023 22:43:38 +0000 Subject: [PATCH 36/61] Dev #20: Docstrings for producer classes --- icenet/data/producers.py | 211 ++++++++++++++++++++++++++++----------- 1 file changed, 153 insertions(+), 58 deletions(-) diff --git a/icenet/data/producers.py b/icenet/data/producers.py index 5238b01b..9349da5b 100644 --- a/icenet/data/producers.py +++ b/icenet/data/producers.py @@ -13,12 +13,12 @@ class DataCollection(HemisphereMixin, metaclass=ABCMeta): - """ + """An Abstract base class with common interface for data collection classes. - :param identifier: - :param north: - :param south: - :param path: + Attributes: + _identifier: The identifier of the data collection. + _path: The base path of the data collection. + _hemisphere: The hemisphere(s) of the data collection. """ @abstractmethod @@ -27,37 +27,75 @@ def __init__(self, *args, north: bool = True, south: bool = False, path: str = os.path.join(".", "data"), - **kwargs): - self._identifier = identifier - self._path = os.path.join(path, identifier) - self._hemisphere = (Hemisphere.NORTH if north else Hemisphere.NONE) | \ + **kwargs) -> None: + """Initialises DataCollection class. + + Args: + identifier: An identifier/label for the data collection. + Defaults to None. + north (optional): A flag indicating if the data collection is in the northern hemisphere. + Defaults to True. + south (optional): A flag indicating if the data collection is in the southern hemisphere. + Defaults to False. + path (optional): The base path of the data collection. + Defaults to `./data`. + + Raises: + AssertionError: Raised if identifier is not specified, or no hemispheres are selected. 
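To make the hemisphere bitmask behaviour concrete, a minimal sketch of how these flags combine is given below; it assumes only what the properties documented above imply, namely that Hemisphere.BOTH is the union of Hemisphere.NORTH and Hemisphere.SOUTH::

    from icenet.utils import Hemisphere

    # north=True, south=True combine the flags exactly as DataCollection.__init__ does
    selected = Hemisphere.NORTH | Hemisphere.SOUTH
    assert selected == Hemisphere.BOTH

    # the HemisphereMixin.north/.south properties reduce to bitwise AND checks like this
    assert (selected & Hemisphere.NORTH) == Hemisphere.NORTH

    # a north-only collection leaves the SOUTH bit unset
    north_only = Hemisphere.NORTH | Hemisphere.NONE
    assert (north_only & Hemisphere.SOUTH) != Hemisphere.SOUTH
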
+ """ + self._identifier: object = identifier + self._path: str = os.path.join(path, identifier) + self._hemisphere: Hemisphere = (Hemisphere.NORTH if north else Hemisphere.NONE) | \ (Hemisphere.SOUTH if south else Hemisphere.NONE) assert self._identifier, "No identifier supplied" assert self._hemisphere != Hemisphere.NONE, "No hemispheres selected" @property - def base_path(self): + def base_path(self) -> str: + """Get the base path of the data collection. + + Returns: + The base path of the data collection. + """ return self._path @base_path.setter - def base_path(self, path): + def base_path(self, path: str) -> None: self._path = path @property - def identifier(self): + def identifier(self) -> object: + """Get the identifier (label) for this data collection. + + Returns: + The identifier/label of the data collection. + """ return self._identifier class DataProducer(DataCollection): + """Manages the creation and organisation of data files. + + Attributes: + dry: Flag specifying whether the data producer should be in dry run mode or not. + overwrite: Flag specifying whether existing files should be overwritten or not. """ - :param dry: - :param overwrite: - """ + def __init__(self, *args, dry: bool = False, overwrite: bool = False, - **kwargs): + **kwargs) -> None: + """Initialises the DataProducer instance. + + Creates the base path of the data collection if it does not exist. + + Args: + dry (optional): Flag specifying whether the data producer should be in dry run mode or not. + Defaults to False + overwrite (optional): Flag specifying whether existing files should be overwritten or not. + Defaults to False + """ super(DataProducer, self).__init__(*args, **kwargs) self.dry = dry @@ -82,13 +120,19 @@ def get_data_var_folder(self, append: object = None, hemisphere: object = None, missing_error: bool = False) -> str: - """ + """Returns the path for a specific data variable. + + Appends additional folders to the path if specified in the `append` parameter. + + Args: + var: The data variable. + append (optional): Additional folders to append to the path. Defaults to None. + hemisphere (optional): The hemisphere. Defaults to None. + missing_error (optional): Flag to specify if missing directories should be treated as an error. + Defaults to False. - :param var: - :param append: - :param hemisphere: - :param missing_error: - :return: + Returns: + str: The path for the specific data variable. """ if not append: append = [] @@ -113,47 +157,46 @@ def get_data_var_folder(self, class Downloader(DataProducer): - """ - + """Abstract base class for a downloader. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @abstractmethod def download(self): - """Abstract download method for this downloader - + """Abstract download method for this downloader: Must be implemented by subclasses. """ raise NotImplementedError("{}.download is abstract". format(__class__.__name__)) class Generator(DataProducer): - """ - + """Abstract base class for a generator. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @abstractmethod def generate(self): - """ - + """Abstract generate method for this generator: Must be implemented by subclasses. """ raise NotImplementedError("{}.generate is abstract". format(__class__.__name__)) class Processor(DataProducer): - """ - - :param identifier: - :param source_data: - :param *args: - :param file_filters: - :param test_dates: - :param train_dates: - :param val_dates: + """An abstract base class for data processing classes. 
+ + Provides methods for initialising source data, processing the data, and + saving the processed data to standard netCDF files. + + Attributes: + _file_filters: List of file filters to exclude certain files during data processing. + _lead_time: Forecast/lead time used in the data processing. + source_data: Path to the source data directory. + _var_files: Dictionary storing variable files organised by variable name. + _processed_files: Dictionary storing the processed files organised by variable name. + _dates: Named tuple that stores the dates used for training, validation, and testing. """ def __init__(self, identifier: str, @@ -164,7 +207,22 @@ def __init__(self, test_dates: object = (), train_dates: object = (), val_dates: object = (), - **kwargs): + **kwargs) -> None: + """Initialise Processor class. + + Args: + identifier: The identifier for the processor. + source_data: The source data directory. + *args: Additional positional arguments. + file_filters (optional): List of file filters to exclude certain files + during data processing. Defaults to (). + lead_time (optional): The forecast/lead time used in the data processing. + Defaults to 93. + test_dates (optional): Dates used for testing. Defaults to (). + train_dates (optional): Dates used for training. Defaults to (). + val_dates (optional): Dates used for validation. Defaults to (). + **kargs: Additional keyword arguments. + """ super().__init__(*args, identifier=identifier, **kwargs) @@ -177,17 +235,29 @@ def __init__(self, self._var_files = dict() self._processed_files = dict() - # TODO: better as a mixin? + # TODO: better as a mixin? or maybe a Python data class instead? Dates = collections.namedtuple("Dates", ["train", "val", "test"]) self._dates = Dates(train=list(train_dates), val=list(val_dates), test=list(test_dates)) def init_source_data(self, - lag_days: object = None): - """ + lag_days: object = None) -> None: + """Initialises source data by globbing the files and organising based on date. + + Adds previous n days of `lag_days` if not already in `self._dates` + if lag_days>0. + Adds next n days of `self._lead_time` if not already in `self._dates` + if `self._lead_time`>0. - :param lag_days: + Args: + lag_days: The number of lag days to include in the data processing. + + Returns: + None. The method updates the `_var_files` attribute of the `Processor` object. + + Raises: + OSError: If the source data directory does not exist. """ if not os.path.exists(self.source_data): @@ -290,8 +360,7 @@ def init_source_data(self, @abstractmethod def process(self): - """ - + """Abstract method defining data processing: Must be implemented by subclasses. """ raise NotImplementedError("{}.process is abstract". format(__class__.__name__)) @@ -299,14 +368,18 @@ def process(self): def save_processed_file(self, var_name: str, name: str, - data: object, **kwargs): - """ - - :param var_name: - :param name: - :param data: - :param kwargs: - :return: + data: object, **kwargs) -> str: + """Save processed data to netCDF file. + + Args: + var_name: The name of the variable. + name: The name of the file. + data: The data to be saved. + **kwargs: Additional keyword arguments to be passed to the + `get_data_var_folder` method. + + Returns: + The path of the saved netCDF file. 
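Because Processor is abstract, a short hedged sketch of a subclass may help show how the documented hooks fit together; ToyPreProcessor, the "toy_processor" identifier, the "./source" path, the empty date lists and the "siconca_abs.nc" filename are all placeholders rather than names defined in icenet::

    from icenet.data.producers import Processor

    class ToyPreProcessor(Processor):
        def process(self):
            # a real subclass would transform each source variable here and then
            # persist it with the inherited helper, for example:
            #   self.save_processed_file("siconca", "siconca_abs.nc", some_dataarray)
            pass

    proc = ToyPreProcessor("toy_processor",
                           source_data="./source",
                           train_dates=[], val_dates=[], test_dates=[])

    # init_source_data() would then glob the per-variable files around those dates,
    # raising OSError if the source directory does not exist:
    # proc.init_source_data(lag_days=2)
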
""" file_path = os.path.join( self.get_data_var_folder(var_name, **kwargs), name) @@ -324,17 +397,39 @@ def save_processed_file(self, return file_path @property - def dates(self): + def dates(self) -> object: + """Get the dates used for training, validation, and testing in this class. + + Returns: + A named collections.tuple containing the dates used for training, + validation, and testing accessible as attributes. + E.g: self._dates.train, self._dates.val, self._dates.test. + """ return self._dates @property - def lead_time(self): + def lead_time(self) -> int: + """Get the lead time used in the data processing. + + Returns: + The lead time used in the data processing. + """ return self._lead_time @property - def processed_files(self): + def processed_files(self) -> dict: + """Get dictionary of processed files. + + Returns: + Dict with the processed files organized by variable name. + """ return self._processed_files @property - def source_data(self): + def source_data(self) -> str: + """Get the source data directory. + + Returns: + The source data directory as a string. + """ return self._source_data From 1b0dacd04f72cee0aa8f555811c38b25c800693d Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Wed, 29 Nov 2023 22:50:49 +0000 Subject: [PATCH 37/61] Dev #20: Docstrings for Masks class --- icenet/data/sic/mask.py | 99 ++++++++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 30 deletions(-) diff --git a/icenet/data/sic/mask.py b/icenet/data/sic/mask.py index 197780e3..d3ad6ec7 100644 --- a/icenet/data/sic/mask.py +++ b/icenet/data/sic/mask.py @@ -18,12 +18,9 @@ class Masks(Generator): - """ + """Masking of regions to include/omit in dataset. - :param polarhole_dates: - :param polarhole_radii: - :param data_shape: - :param dtype: + TODO: Add example usage. """ LAND_MASK_FILENAME = "land_mask.npy" @@ -41,6 +38,14 @@ def __init__(self, *args, data_shape: object = (432, 432), dtype: object = np.float32, **kwargs): + """Initialises Masks across specified hemispheres. + + Args: + polarhole_dates: Dates for polar hole (missing data) in data. + polarhole_radii: Radii of polar hole. + data_shape: Shape of input dataset. + dtype: Store mask as this type. + """ super().__init__(*args, identifier="masks", **kwargs) self._polarhole_dates = polarhole_dates @@ -52,9 +57,14 @@ def __init__(self, *args, self.init_params() def init_params(self): - """ + """Initialises the parameters of the Masks class. + This method will create a `masks.params` file if it does not exist. + And, stores the polar_radii and polar_dates instance variables into it. + If it already exists, it will read and store the values to the instance + variables """ + params_path = os.path.join( self.get_data_var_folder("masks"), "masks.params" @@ -80,12 +90,17 @@ def generate(self, save_land_mask: bool = True, save_polarhole_masks: bool = True, remove_temp_files: bool = False): - """Generate a set of data masks - - :param year: - :param save_land_mask: - :param save_polarhole_masks: - :param remove_temp_files: + """Generate a set of data masks. + + Args: + year (optional): Which year to use for generate masks from. + Defaults to 2000. + save_land_mask (optional): Whether to output land mask. + Defaults to True. + save_polarhole_masks (optional): Whether to output polar hole masks. + Defaults to True. + remove_temp_files (optional): Whether to remove temporary directory. + Defaults to False. 
""" siconca_folder = self.get_data_var_folder("siconca") @@ -179,10 +194,16 @@ def generate(self, def get_active_cell_mask(self, month: object) -> object: - """ + """Check if a mask file exists for input month, and raise an error if it does not. + + Args: + month: Month index representing the month for which the mask file is being checked. - :param month: - :return: + Returns: + Active cell mask boolean(s) for corresponding month and pre-defined self._region. + + Raises: + RuntimeError: If the mask file for the input month does not exist. """ mask_path = os.path.join(self.get_data_var_folder("masks"), "active_grid_cell_mask_{:02d}.npy". @@ -198,11 +219,17 @@ def get_active_cell_mask(self, def get_active_cell_da(self, src_da: object) -> object: - """ + """Generate an xarray.DataArray object containing the active cell masks + for each timestamp in a given source DataArray. - :param src_da: - """ + Args: + src_da: Source xarray.DataArray object containing time, xc, yc + coordinates. + Returns: + An xarray.DataArray containing active cell masks for each time + in source DataArray. + """ return xr.DataArray( [self.get_active_cell_mask(pd.to_datetime(date).month) for date in src_da.time.values], @@ -216,10 +243,16 @@ def get_active_cell_da(self, def get_land_mask(self, land_mask_filename: str = LAND_MASK_FILENAME) -> object: - """ + """Generate an xarray.DataArray object containing the active cell masks + for each timestamp in a given source DataArray. - :param land_mask_filename: - :return: + Args: + land_mask_filename (optional): Land mask output filename. + Defaults to Masks.LAND_MASK_FILENAME. + + Returns: + An numpy array of land mask flag(s) for corresponding month and + pre-defined `self._region`. """ mask_path = os.path.join(self.get_data_var_folder("masks"), land_mask_filename) @@ -234,10 +267,11 @@ def get_land_mask(self, def get_polarhole_mask(self, date: object) -> object: - """ + """Get mask of polar hole region. - :param date: - :return: + TODO: + Explain date literals as class instance for POLARHOLE_DATES + and POLARHOLE_RADII """ if self.south: return None @@ -252,32 +286,37 @@ def get_polarhole_mask(self, return None def get_blank_mask(self) -> object: - """ + """Returns an empty mask. - :return: + Returns: + A numpy array of flags set to false for pre-defined `self._region` + of shape `self._shape` (the `data_shape` instance initialisation + value). """ return np.full(self._shape, False)[self._region] def __getitem__(self, item): - """ + """Sets slice of region wanted for masking, and allows method chaining. This might be a semantically dodgy thing to do, but it works for the mo - :param item: + Args: + item: Index/slice to extract. """ logging.info("Mask region set to: {}".format(item)) self._region = item return self def reset_region(self): - """ - + """Resets the mask region and logs a message indicating that the whole mask will be returned. """ logging.info("Mask region reset, whole mask will be returned") self._region = (slice(None, None), slice(None, None)) def main(): + """Entry point of Masks class - used to create executable that calls it. + """ args = download_args(dates=False, var_specs=False) north = args.hemisphere == "north" From 421efcecffec01c0dfa34d45ea76c717c0c7032b Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Wed, 29 Nov 2023 22:57:37 +0000 Subject: [PATCH 38/61] Dev #20: Docstrings for SplittingMixin and get_decoder --- icenet/data/datasets/utils.py | 135 ++++++++++++++++++++++++++++------ 1 file changed, 111 insertions(+), 24 deletions(-) diff --git a/icenet/data/datasets/utils.py b/icenet/data/datasets/utils.py index 99e7a508..75b46326 100644 --- a/icenet/data/datasets/utils.py +++ b/icenet/data/datasets/utils.py @@ -11,14 +11,18 @@ def get_decoder(shape: object, forecasts: object, num_vars: int = 1, dtype: str = "float32") -> object: - """ - - :param shape: - :param channels: - :param forecasts: - :param num_vars: - :param dtype: - :return: + """Returns a decoder function used for parsing and decoding data from tfrecord protocol buffer. + + Args: + shape: The shape of the input data. + channels: The number of channels in the input data. + forecasts: The number of days to forecast in prediction + num_vars (optional): The number of variables in the input data. Defaults to 1. + dtype (optional): The data type of the input data. Defaults to "float32". + + Returns: + A function that can be used to parse and decode data. It takes in a protocol buffer + (tfrecord) as input and returns the parsed and decoded data. """ xf = tf.io.FixedLenFeature( [*shape, channels], getattr(tf, dtype)) @@ -44,10 +48,19 @@ def decode_item(proto): # TODO: define a decent interface and sort the inheritance architecture out, as # this will facilitate the new datasets in #35 class SplittingMixin: - """ + """Read train, val, test datasets from tfrecord protocol buffer files. - """ + Split and shuffle data if specified as well. + + Example: + This mixin is not to be used directly, but to give an idea of its use: + # Initialise SplittingMixin + split_dataset = SplittingMixin() + + # Add file paths to the train, validation, and test datasets + split_dataset.add_records(base_path="./network_datasets/notebook_data/", hemi="south") + """ _batch_size: int _dtype: object _num_channels: int @@ -59,11 +72,18 @@ class SplittingMixin: test_fns = [] val_fns = [] - def add_records(self, base_path: str, hemi: str): - """ + def add_records(self, base_path: str, hemi: str) -> None: + """Add list of paths to train, val, test *.tfrecord(s) to relevant instance attributes. + + Add sorted list of file paths to train, validation, and test datasets in SplittingMixin. - :param base_path: - :param hemi: + Args: + base_path (str): The base path where the datasets are located. + hemi (str): The hemisphere the datasets correspond to. + + Returns: + None. Updates `self.train_fns`, `self.val_fns`, `self.test_fns` with list + of *.tfrecord files. """ train_path = os.path.join(base_path, hemi, "train") val_path = os.path.join(base_path, hemi, "val") @@ -77,10 +97,22 @@ def add_records(self, base_path: str, hemi: str): self.test_fns += sorted(glob.glob("{}/*.tfrecord".format(test_path))) def get_split_datasets(self, ratio: object = None): - """ + """Retrieves train, val, and test datasets from corresponding attributes of SplittingMixin. + + Retrieves the train, validation, and test datasets from the file paths stored in the + `train_fns`, `val_fns`, and `test_fns` attributes of SplittingMixin. + + Args: + ratio (optional): A float representing the truncated list of datasets to be used. + If not specified, all datasets will be used. + Defaults to None. + + Returns: + tuple: A tuple containing the train, validation, and test datasets. 
- :param ratio: - :return: + Raises: + RuntimeError: If no files have been found in the train, validation, and test datasets. + RuntimeError: If the ratio is greater than 1. """ if not (len(self.train_fns) + len(self.val_fns) + len(self.test_fns)): raise RuntimeError("No files have been found, abandoning. This is " @@ -91,6 +123,7 @@ def get_split_datasets(self, ratio: object = None): logging.info("Datasets: {} train, {} val and {} test filenames".format( len(self.train_fns), len(self.val_fns), len(self.test_fns))) + # If ratio is specified, truncate file paths for train, val, test using the ratio. if ratio: if ratio > 1.0: raise RuntimeError("Ratio cannot be more than 1") @@ -111,6 +144,7 @@ def get_split_datasets(self, ratio: object = None): logging.info("Reduced: {} train, {} val and {} test filenames".format( len(self.train_fns), len(self.val_fns), len(self.test_fns))) + # Loads from files as bytes exactly as written. Must parse and decode it. train_ds, val_ds, test_ds = \ tf.data.TFRecordDataset(self.train_fns, num_parallel_reads=self.batch_size), \ @@ -135,6 +169,8 @@ def get_split_datasets(self, ratio: object = None): train_ds = train_ds.shuffle( min(int(len(self.train_fns) * self.batch_size), 366)) + # Since TFRecordDataset does not parse or decode the dataset from bytes, + # use custom decoder function with map to do so. train_ds = train_ds.\ map(decoder, num_parallel_calls=self.batch_size).\ batch(self.batch_size) @@ -152,7 +188,14 @@ def get_split_datasets(self, ratio: object = None): test_ds.prefetch(tf.data.AUTOTUNE) def check_dataset(self, - split: str = "train"): + split: str = "train") -> None: + """Check the dataset for NaN, log debugging info regarding dataset shape and bounds. + + Also logs a warning if any NaN are found. + + Args: + split: The split of the dataset to check. Default is "train". + """ logging.debug("Checking dataset {}".format(split)) decoder = get_decoder(self.shape, @@ -207,25 +250,69 @@ def check_dataset(self, # We don't except any non-tensorflow errors to prevent progression @property - def batch_size(self): + def batch_size(self) -> int: + """Get dataset's batch size. + + Set in subclass, not in SplittingMixin. + + Returns: + self._batch_size: Batch size set for dataset. + """ return self._batch_size @property - def dtype(self): + def dtype(self) -> str: + """Get dataset's data type. + + Set in subclass, not in SplittingMixin. + + Returns: + self._dtype: Data type of dataset. + """ return self._dtype @property - def n_forecast_days(self): + def n_forecast_days(self) -> int: + """Get number of days to forecast in prediction. + + Set in subclass, not in SplittingMixin. + + Returns: + self._n_forecast_days: Number of days to forecast. + """ return self._n_forecast_days @property - def num_channels(self): + def num_channels(self) -> int: + """Get number of channels in dataset. + + Corresponds to number of variables. + Set in subclass, not in SplittingMixin. + + Returns: + self._num_channels: Number of channels in dataset. + """ return self._num_channels @property - def shape(self): + def shape(self) -> object: + """Get shape of dataset. + + Set in subclass, not in SplittingMixin. + + Returns: + self._shape: Tuple/List of dataset shape. + """ return self._shape @property - def shuffling(self): + def shuffling(self) -> bool: + """Get flag for whether training dataset(s) are marked to be shuffled. + + Set in subclass, not in SplittingMixin. + + Returns: + self._shuffling: A flag if training dataset(s) marked to + be shuffled. 
+ """ return self._shuffling From dcd086a7c31d542659c3d26211cc3858a65bcf71 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Wed, 29 Nov 2023 23:03:30 +0000 Subject: [PATCH 39/61] Dev #20: Docstrings for IceNetDataLoaderFactory --- icenet/data/loaders/__init__.py | 56 +++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/icenet/data/loaders/__init__.py b/icenet/data/loaders/__init__.py index 4e165e30..ff7b9431 100644 --- a/icenet/data/loaders/__init__.py +++ b/icenet/data/loaders/__init__.py @@ -7,43 +7,65 @@ class IceNetDataLoaderFactory: - """ + """A factory class for managing a map of loader names and their corresponding implementation classes. + Attributes: + _loader_map: A dictionary holding loader names against their implementation classes. """ + def __init__(self): + """Initialises the IceNetDataLoaderFactory instance and sets up the initial loader map. + """ self._loader_map = dict( dask=icenet.data.loaders.dask.DaskMultiWorkerLoader, dask_shared=icenet.data.loaders.dask.DaskMultiSharingWorkerLoader, standard=icenet.data.loaders.stdlib.IceNetDataLoader, ) - def add_data_loader(self, loader_name: str, loader_impl: object): - """ + def add_data_loader(self, loader_name: str, loader_impl: object) -> None: + """Adds a new loader to the loader map with the given name and implementation class. + + Args: + loader_name: The name of the loader. + loader_impl: The implementation class of the loader. + + Returns: + None. Updates `_loader_map` attribute in IceNetDataLoaderFactory with specified + loader name and implementation. - :param loader_name: - :param loader_impl: + Raises: + RuntimeError: If the loader name already exists or if the implementation + class is not a descendant of IceNetBaseDataLoader. """ if loader_name not in self._loader_map: if IceNetBaseDataLoader in inspect.getmro(loader_impl): self._loader_map[loader_name] = loader_impl else: - raise RuntimeError("{} is not descended from " - "IceNetBaseDataLoader". - format(loader_impl.__name__)) + raise RuntimeError("{} is not descended from IceNetBaseDataLoader".format(loader_impl.__name__)) else: - raise RuntimeError("Cannot add {} as already in loader map". - format(loader_name)) + raise RuntimeError("Cannot add {} as already in loader map".format(loader_name)) - def create_data_loader(self, loader_name, *args, **kwargs): - """ + def create_data_loader(self, loader_name, *args, **kwargs) -> object: + """Creates an instance of a loader based on specified name from the `_loader_map` dict attribute. + + Args: + loader_name: The name of the loader. + *args: Additional positional arguments, is passed to the loader constructor. + **kwargs: Additional keyword arguments, is passed to the loader constructor. + + Returns: + An instance of the loader class. - :param loader_name: - :param args: - :param kwargs: - :return: + Raises: + KeyError: If the loader name does not exist in `_loader_map`. """ return self._loader_map[loader_name](*args, **kwargs) @property - def loader_map(self): + def loader_map(self) -> dict: + """Get the loader map dictionary. + + Returns: + The loader map. + """ return self._loader_map From 019e408683b8fdf13484229625cff95e79e82bea Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Wed, 29 Nov 2023 23:10:42 +0000 Subject: [PATCH 40/61] Dev #20: Docstrings for IceNetDataSet + more --- icenet/data/dataset.py | 91 ++++++++++++++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 20 deletions(-) diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index 596ba243..f57636c0 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -20,20 +20,26 @@ class IceNetDataSet(SplittingMixin, DataCollection): - """ - - :param configuration_path: - :param batch_size: - :param path: - """ - def __init__(self, configuration_path: str, *args, batch_size: int = 4, path: str = os.path.join(".", "network_datasets"), shuffling: bool = False, - **kwargs): + **kwargs) -> None: + """Initialises an instance of the IceNetDataSet class. + + Args: + configuration_path: The path to the JSON configuration file. + *args: Additional positional arguments. + batch_size (optional): The batch size for the data loader. Defaults to 4. + path (optional): The path to the directory where the processed tfrecord + protocol buffer files will be stored. Defaults to './network_datasets'. + shuffling (optional): Flag indicating whether to shuffle the data. + Defaults to False. + *args: Additional keyword arguments. + """ + self._config = dict() self._configuration_path = configuration_path self._load_configuration(configuration_path) @@ -62,6 +68,8 @@ def __init__(self, else: path_attr = "dataset_path" + # Check JSON config has attribute for path to tfrecord datasets, and + # that the path exists. if self._config[path_attr] and \ os.path.exists(self._config[path_attr]): hemi = self.hemisphere_str[0] @@ -70,10 +78,14 @@ def __init__(self, logging.warning("Running in configuration only mode, tfrecords " "were not generated for this dataset") - def _load_configuration(self, path: str): - """ + def _load_configuration(self, path: str) -> None: + """Load the JSON configuration file and update the `_config` attribute of `IceNetDataSet` class. - :param path: + Args: + path: The path to the JSON configuration file. + + Raises: + OSError: If the specified configuration file is not found. """ if os.path.exists(path): logging.info("Loading configuration {}".format(path)) @@ -85,17 +97,24 @@ def _load_configuration(self, path: str): else: raise OSError("{} not found".format(path)) - def get_data_loader(self, n_forecast_days = None, generate_workers = None): - """ + def get_data_loader(self, n_forecast_days = None, generate_workers = None) -> object: + """Create an instance of the IceNetDataLoader class. - :return: + Args: + n_forecast_days (optional): The number of forecast days to be used by the data loader. + If not provided, defaults to the value specified in the configuration file. + generate_workers (optional): A flag indicating whether to generate workers for parallel processing. + If not provided, defaults to the value specified in the configuration file. + + Returns: + An instance of the DaskMultiWorkerLoader class configured with the specified parameters. """ if n_forecast_days is None: n_forecast_days = self._config["n_forecast_days"] if generate_workers is None: generate_workers = self._config["generate_workers"] loader = IceNetDataLoaderFactory().create_data_loader( - "dask", + "dask", # This will load the `DaskMultiWorkerLoader` class. 
self.loader_config, self.identifier, self._config["var_lag"], @@ -115,15 +134,32 @@ def get_data_loader(self, n_forecast_days = None, generate_workers = None): return loader @property - def loader_config(self): + def loader_config(self) -> str: + """Get path to the JSON loader configuration file stored in the dataset config file. + + E.g. `/path/to/loader.{identifier}.json` + + Returns: + Path to the loader config file from the dataset config file. + """ return self._loader_config @property - def channels(self): + def channels(self) -> list: + """Get the list of channels (variable names) specified in the dataset config file. + + Returns: + List of channels (variable names) in the dataset config file. + """ return self._config["channels"] @property - def counts(self): + def counts(self) -> dict: + """Get a dict of no. of items in train, val, test. + + Returns: + Dict with number of elements in train, val, test in the config file. + """ return self._config["counts"] @@ -278,7 +314,20 @@ def counts(self): @setup_logging -def get_args(): +def get_args() -> object: + """Parse command line arguments using the argparse module. + + Returns: + An object containing the parsed command line arguments. + + Example: + Assuming CLI arguments provided. + + args = get_args() + print(args.dataset) + print(args.split) + print(args.verbose) + """ ap = argparse.ArgumentParser() ap.add_argument("dataset") ap.add_argument("-s", "--split", @@ -288,7 +337,9 @@ def get_args(): return args -def check_dataset(): +def check_dataset() -> None: + """Check the dataset for a specific split. + """ args = get_args() ds = IceNetDataSet(args.dataset) ds.check_dataset(args.split) From 4e5be81bfbf844a3e9c104e9cb3d01874756928c Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Wed, 29 Nov 2023 23:33:20 +0000 Subject: [PATCH 41/61] Dev #20: Fix duplicate running of pyreverse when make docs --- Makefile | 1 - 1 file changed, 1 deletion(-) diff --git a/Makefile b/Makefile index 978939bd..cb4376bc 100644 --- a/Makefile +++ b/Makefile @@ -69,7 +69,6 @@ coverage: ## check code coverage quickly with the default Python docs: ## generate Sphinx HTML documentation, including API docs rm -f docs/icenet.rst rm -f docs/modules.rst - pyreverse --output puml --output-directory docs/ --project icenet icenet --colorized sphinx-apidoc -o docs/ icenet $(MAKE) -C docs clean $(MAKE) -C docs html From 39be36da7bb22e6d2f1ffc8862fb2465031f6306 Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Wed, 29 Nov 2023 23:33:48 +0000 Subject: [PATCH 42/61] Dev #20: Update run make docs --- docs/classes_icenet.puml | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/classes_icenet.puml b/docs/classes_icenet.puml index 48136ef9..28029fea 100644 --- a/docs/classes_icenet.puml +++ b/docs/classes_icenet.puml @@ -128,14 +128,14 @@ class "IceNetDataLoader" as icenet.data.loaders.stdlib.IceNetDataLoader #99DDFF } class "IceNetDataLoaderFactory" as icenet.data.loaders.IceNetDataLoaderFactory #99DDFF { loader_map - add_data_loader(loader_name: str, loader_impl: object) - create_data_loader(loader_name) + add_data_loader(loader_name: str, loader_impl: object) -> None + create_data_loader(loader_name) -> object } class "IceNetDataSet" as icenet.data.dataset.IceNetDataSet #99DDFF { channels counts loader_config - get_data_loader(n_forecast_days, generate_workers) + get_data_loader(n_forecast_days, generate_workers) -> object } class "IceNetDataWarning" as icenet.data.loaders.utils.IceNetDataWarning #99DDFF { } @@ -199,9 +199,9 @@ class "Processor" as icenet.data.producers.Processor #99DDFF { lead_time processed_files source_data - init_source_data(lag_days: object) + init_source_data(lag_days: object) -> None {abstract}process() - save_processed_file(var_name: str, name: str, data: object) + save_processed_file(var_name: str, name: str, data: object) -> str } class "SEASDownloader" as icenet.data.interfaces.mars.SEASDownloader #99DDFF { MARS_TEMPLATE : str @@ -221,8 +221,8 @@ class "SplittingMixin" as icenet.data.datasets.utils.SplittingMixin #99DDFF { test_fns : list train_fns : list val_fns : list - add_records(base_path: str, hemi: str) - check_dataset(split: str) + add_records(base_path: str, hemi: str) -> None + check_dataset(split: str) -> None get_split_datasets(ratio: object) } class "TemperatureScale" as icenet.model.models.TemperatureScale #44BB99 { @@ -280,4 +280,5 @@ icenet.data.sic.mask.Masks --|> icenet.data.producers.Generator icenet.data.sic.osisaf.SICDownloader --|> icenet.data.producers.Downloader icenet.data.sic.mask.Masks --* icenet.data.interfaces.downloader.ClimateDownloader : _masks icenet.data.sic.mask.Masks --* icenet.data.sic.osisaf.SICDownloader : _masks +icenet.utils.Hemisphere --* icenet.data.producers.DataCollection : _hemisphere @enduml From 483c202058a44fb705ef31544d8251284a76bea7 Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Fri, 1 Dec 2023 19:06:21 +0000 Subject: [PATCH 43/61] Dev #20: Update pytest config to include doctests --- icenet/data/datasets/utils.py | 4 ++-- setup.cfg | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/icenet/data/datasets/utils.py b/icenet/data/datasets/utils.py index dacbed22..8c52a432 100644 --- a/icenet/data/datasets/utils.py +++ b/icenet/data/datasets/utils.py @@ -55,10 +55,10 @@ class SplittingMixin: This mixin is not to be used directly, but to give an idea of its use: # Initialise SplittingMixin - split_dataset = SplittingMixin() + >>> split_dataset = SplittingMixin() # Add file paths to the train, validation, and test datasets - split_dataset.add_records(base_path="./network_datasets/notebook_data/", hemi="south") + >>> split_dataset.add_records(base_path="./network_datasets/notebook_data/", hemi="south") """ _batch_size: int _dtype: object diff --git a/setup.cfg b/setup.cfg index 34616f49..5ff36522 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,6 +15,7 @@ based_on_style = google [tool:pytest] collect_ignore = ['setup.py'] +addopts = --doctest-modules [metadata] # This includes the license file(s) in the wheel. From 5e63d052f551530cfac81d6f7c2afe78ba3508c0 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Sat, 2 Dec 2023 15:21:27 +0000 Subject: [PATCH 44/61] Dev #20: Update docs for IceNetDataSet class --- icenet/data/dataset.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index 5c3ce9a1..694d9231 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -19,6 +19,24 @@ class IceNetDataSet(SplittingMixin, DataCollection): + """Initialises and configures a dataset. + + It loads a JSON configuration file, updates the `_config` attribute with the + result, creates a data loader, and methods to access the dataset. + + Attributes: + _config: A dict used to store configuration loaded from JSON file. + _configuration_path: The path to the JSON configuration file. + _batch_size: The batch size for the data loader. + _counts: A dict with number of elements in train, val, test. + _dtype: The type of the dataset. + _loader_config: The path to the data loader configuration file. + _generate_workers: An integer representing number of workers for parallel processing with Dask. + _n_forecast_days: An integer representing number of days to forecast for. + _num_channels: An integer representing number of channels (input variables) in the dataset. + _shape: The shape of the dataset. + _shuffling: A flag indicating whether to shuffle the data or not. + """ def __init__(self, configuration_path: str, @@ -98,15 +116,16 @@ def _load_configuration(self, path: str) -> None: raise OSError("{} not found".format(path)) def get_data_loader(self, - n_forecast_days=None, - generate_workers=None) -> object: + n_forecast_days: object = None, + generate_workers: object = None) -> object: """Create an instance of the IceNetDataLoader class. Args: n_forecast_days (optional): The number of forecast days to be used by the data loader. If not provided, defaults to the value specified in the configuration file. - generate_workers (optional): A flag indicating whether to generate workers for parallel processing. - If not provided, defaults to the value specified in the configuration file. + generate_workers (optional): An integer representing number of workers to use for + parallel processing with Dask. 
If not provided, defaults to the value specified in + the configuration file. Returns: An instance of the DaskMultiWorkerLoader class configured with the specified parameters. From 867468048583b3d766217f0e0deec67a25a62d09 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Sat, 2 Dec 2023 15:26:32 +0000 Subject: [PATCH 45/61] Dev #20: Update pyreverse uml output --- docs/classes_icenet.puml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/classes_icenet.puml b/docs/classes_icenet.puml index 28029fea..3a6f4be3 100644 --- a/docs/classes_icenet.puml +++ b/docs/classes_icenet.puml @@ -135,7 +135,7 @@ class "IceNetDataSet" as icenet.data.dataset.IceNetDataSet #99DDFF { channels counts loader_config - get_data_loader(n_forecast_days, generate_workers) -> object + get_data_loader(n_forecast_days: object, generate_workers: object) -> object } class "IceNetDataWarning" as icenet.data.loaders.utils.IceNetDataWarning #99DDFF { } From abcb2d72081e0194e271c3ccef2699f5fae21264 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Tue, 5 Dec 2023 14:49:21 +0000 Subject: [PATCH 46/61] Dev #20: Update docs - getters --- icenet/data/dataset.py | 23 ++++----------- icenet/data/datasets/utils.py | 50 ++++----------------------------- icenet/data/loaders/__init__.py | 6 +--- icenet/data/producers.py | 38 ++++--------------------- icenet/utils.py | 39 ++++++------------------- 5 files changed, 26 insertions(+), 130 deletions(-) diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index 694d9231..b0be5698 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -32,7 +32,7 @@ class IceNetDataSet(SplittingMixin, DataCollection): _dtype: The type of the dataset. _loader_config: The path to the data loader configuration file. _generate_workers: An integer representing number of workers for parallel processing with Dask. - _n_forecast_days: An integer representing number of days to forecast for. + _n_forecast_days: An integer representing number of days to predict for. _num_channels: An integer representing number of channels (input variables) in the dataset. _shape: The shape of the dataset. _shuffling: A flag indicating whether to shuffle the data or not. @@ -152,31 +152,18 @@ def get_data_loader(self, @property def loader_config(self) -> str: - """Get path to the JSON loader configuration file stored in the dataset config file. - - E.g. `/path/to/loader.{identifier}.json` - - Returns: - Path to the loader config file from the dataset config file. - """ + """The path to the JSON loader configuration file stored in the dataset config file.""" + # E.g. `/path/to/loader.{identifier}.json` return self._loader_config @property def channels(self) -> list: - """Get the list of channels (variable names) specified in the dataset config file. - - Returns: - List of channels (variable names) in the dataset config file. - """ + """The list of channels (variable names) specified in the dataset config file.""" return self._config["channels"] @property def counts(self) -> dict: - """Get a dict of no. of items in train, val, test. - - Returns: - Dict with number of elements in train, val, test in the config file. 
- """ + """A dict with number of elements in train, val, test in the config file.""" return self._config["counts"] diff --git a/icenet/data/datasets/utils.py b/icenet/data/datasets/utils.py index 8c52a432..6f8a799a 100644 --- a/icenet/data/datasets/utils.py +++ b/icenet/data/datasets/utils.py @@ -247,68 +247,30 @@ def check_dataset(self, split: str = "train") -> None: @property def batch_size(self) -> int: - """Get dataset's batch size. - - Set in subclass, not in SplittingMixin. - - Returns: - self._batch_size: Batch size set for dataset. - """ + """The dataset's batch size.""" return self._batch_size @property def dtype(self) -> str: - """Get dataset's data type. - - Set in subclass, not in SplittingMixin. - - Returns: - self._dtype: Data type of dataset. - """ + """The dataset's data type.""" return self._dtype @property def n_forecast_days(self) -> int: - """Get number of days to forecast in prediction. - - Set in subclass, not in SplittingMixin. - - Returns: - self._n_forecast_days: Number of days to forecast. - """ + """The number of days to forecast in prediction.""" return self._n_forecast_days @property def num_channels(self) -> int: - """Get number of channels in dataset. - - Corresponds to number of variables. - Set in subclass, not in SplittingMixin. - - Returns: - self._num_channels: Number of channels in dataset. - """ + """The number of channels in dataset.""" return self._num_channels @property def shape(self) -> object: - """Get shape of dataset. - - Set in subclass, not in SplittingMixin. - - Returns: - self._shape: Tuple/List of dataset shape. - """ + """The shape of dataset.""" return self._shape @property def shuffling(self) -> bool: - """Get flag for whether training dataset(s) are marked to be shuffled. - - Set in subclass, not in SplittingMixin. - - Returns: - self._shuffling: A flag if training dataset(s) marked to - be shuffled. - """ + """A flag for whether training dataset(s) are marked to be shuffled.""" return self._shuffling diff --git a/icenet/data/loaders/__init__.py b/icenet/data/loaders/__init__.py index 2daf2232..bf59a849 100644 --- a/icenet/data/loaders/__init__.py +++ b/icenet/data/loaders/__init__.py @@ -65,9 +65,5 @@ def create_data_loader(self, loader_name, *args, **kwargs) -> object: @property def loader_map(self) -> dict: - """Get the loader map dictionary. - - Returns: - The loader map. - """ + """The loader map dictionary.""" return self._loader_map diff --git a/icenet/data/producers.py b/icenet/data/producers.py index 460d6f73..defe9b29 100644 --- a/icenet/data/producers.py +++ b/icenet/data/producers.py @@ -54,11 +54,7 @@ def __init__(self, @property def base_path(self) -> str: - """Get the base path of the data collection. - - Returns: - The base path of the data collection. - """ + """The base path of the data collection.""" return self._path @base_path.setter @@ -67,11 +63,7 @@ def base_path(self, path: str) -> None: @property def identifier(self) -> object: - """Get the identifier (label) for this data collection. - - Returns: - The identifier/label of the data collection. - """ + """The identifier (label) for this data collection.""" return self._identifier @@ -395,38 +387,20 @@ def save_processed_file(self, var_name: str, name: str, data: object, @property def dates(self) -> object: - """Get the dates used for training, validation, and testing in this class. - - Returns: - A named collections.tuple containing the dates used for training, - validation, and testing accessible as attributes. 
- E.g: self._dates.train, self._dates.val, self._dates.test. - """ + """The dates used for training, validation, and testing in this class as a named collections.tuple.""" return self._dates @property def lead_time(self) -> int: - """Get the lead time used in the data processing. - - Returns: - The lead time used in the data processing. - """ + """The lead time used in the data processing.""" return self._lead_time @property def processed_files(self) -> dict: - """Get dictionary of processed files. - - Returns: - Dict with the processed files organized by variable name. - """ + """A dict with the processed files organised by variable name.""" return self._processed_files @property def source_data(self) -> str: - """Get the source data directory. - - Returns: - The source data directory as a string. - """ + """The source data directory as a string.""" return self._source_data diff --git a/icenet/utils.py b/icenet/utils.py index c0f31ec0..6223f7f1 100644 --- a/icenet/utils.py +++ b/icenet/utils.py @@ -32,61 +32,38 @@ class HemisphereMixin: @property def hemisphere(self) -> int: - """Get the bitmask value representing the hemisphere. - - Returns: - The bitmask value representing the hemisphere. - """ + """The bitmask value representing the hemisphere.""" return self._hemisphere @property def hemisphere_str(self) -> list: - """Get a list of strings representing the selected hemispheres. - - Returns: - A list of strings representing the hemisphere. - """ + """A list of strings representing the selected hemispheres.""" return ["north"] if self.north else \ ["south"] if self.south else \ ["north", "south"] @property def hemisphere_loc(self) -> list: - """Get a list of latitude and longitude extent representing the hemisphere's location. - - Returns: - A list of latitude and longitude extent representing the hemisphere's location. - [north lat, west lon, south lat, east lon] - """ + """Get a list of latitude and longitude extent representing the hemisphere's location.""" + # A list of latitude and longitude extent representing the hemisphere's location. + # [north lat, west lon, south lat, east lon] return [90, -180, 0, 180] if self.north else \ [0, -180, -90, 180] if self.south else \ [90, -180, -90, 180] @property def north(self) -> bool: - """Get flag if `_hemisphere` is north. - - Returns: - True if the hemisphere is north, False otherwise. - """ + """A flag indicating if `_hemisphere` is north. True if the hemisphere is north, False otherwise.""" return (self._hemisphere & Hemisphere.NORTH) == Hemisphere.NORTH @property def south(self) -> bool: - """Check flag if `_hemisphere` is south. - - Returns: - True if the hemisphere is south, False otherwise. - """ + """A flag indicating if `_hemisphere` is south. True if the hemisphere is south, False otherwise.""" return (self._hemisphere & Hemisphere.SOUTH) == Hemisphere.SOUTH @property def both(self) -> int: - """Get the bitmask value representing both hemispheres. - - Returns: - The bitmask value representing both hemispheres. - """ + """The bitmask value representing both hemispheres.""" return self._hemisphere & Hemisphere.BOTH From 6ab50a7e8442aeaab1fdc4150b718fb58e08ec52 Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Tue, 26 Dec 2023 23:36:49 +0000 Subject: [PATCH 47/61] Dev #20: Update docs for dask, mask --- icenet/data/loader.py | 10 +++++++--- icenet/data/loaders/dask.py | 38 ++++++++++++++++++++++++++++--------- icenet/data/sic/mask.py | 6 +++++- 3 files changed, 41 insertions(+), 13 deletions(-) diff --git a/icenet/data/loader.py b/icenet/data/loader.py index b962c889..a46a834a 100644 --- a/icenet/data/loader.py +++ b/icenet/data/loader.py @@ -13,10 +13,14 @@ @setup_logging -def create_get_args(): - """ +def create_get_args() -> object: + """Converts input data creation argument strings to objects, and assigns them as attributes to the namespace. + + The args added in this function relate to the dataloader creation process. - :return: + Returns: + An argparse.ArgumentParser object with all arguments added via `add_argument` accessible + as object attributes. """ implementations = list(IceNetDataLoaderFactory().loader_map) diff --git a/icenet/data/loaders/dask.py b/icenet/data/loaders/dask.py index b12394bd..723a9aaf 100644 --- a/icenet/data/loaders/dask.py +++ b/icenet/data/loaders/dask.py @@ -27,22 +27,37 @@ class DaskBaseDataLoader(IceNetBaseDataLoader): + """A subclass of IceNetBaseDataLoader that provides functionality for loading data using Dask. + + Attributes: + _dashboard_port: The port number for the Dask dashboard. + _timeout: The timeout value for Dask communication. + _tmp_dir: The temporary directory for Dask. + """ def __init__(self, *args, dask_port: int = 8888, dask_timeouts: int = 60, dask_tmp_dir: object = "/tmp", - **kwargs): + **kwargs) -> None: + """Initialises the DaskBaseDataLoader object with the specified port, timeouts, and temp directory. + + Args: + dask_port: The port number for the Dask dashboard. Defaults to 8888. + dask_timeouts: The timeout value for Dask communication. Defaults to 60. + dask_tmp_dir: The temporary directory for Dask. Defaults to `/tmp`. + """ super().__init__(*args, **kwargs) self._dashboard_port = dask_port self._timeout = dask_timeouts self._tmp_dir = dask_tmp_dir - def generate(self): + def generate(self) -> None: """ - + Generates data using Dask client by setting up a Dask cluster and client, + and calling client_generate method. """ dashboard = "localhost:{}".format(self._dashboard_port) @@ -68,12 +83,17 @@ def generate(self): def client_generate(self, client: object, dates_override: object = None, - pickup: bool = False): - """ + pickup: bool = False) -> None: + """Generates data using the Dask client. This method needs to be implemented in subclasses. - :param client: - :param dates_override: - :param pickup: + Args: + client: The Dask client. + dates_override (optional): A dict with keys `train`, `val`, `test`, each with a list of + continuous dates for that purpose. Defaults to None. + pickup (optional): TODO. Defaults to False. + + Raises: + NotImplementedError: If generate is called without being implemented as a subclass of DaskBaseDataLoader. 
""" raise NotImplementedError("generate called on non-implementation") @@ -108,7 +128,7 @@ def generate_sample(self, date: object, prediction: bool = False): class DaskMultiWorkerLoader(DaskBaseDataLoader): - def __init__(self, *args, futures_per_worker: int = 2, **kwargs): + def __init__(self, *args, futures_per_worker: int = 2, **kwargs) -> None: super().__init__(*args, **kwargs) masks = Masks(north=self.north, south=self.south) diff --git a/icenet/data/sic/mask.py b/icenet/data/sic/mask.py index f62ed069..7e1a23bf 100644 --- a/icenet/data/sic/mask.py +++ b/icenet/data/sic/mask.py @@ -141,6 +141,7 @@ def generate(self, binary = np.unpackbits(status_flag, axis=1).\ reshape(*self._shape, 8) + #TODO: Add source/explanation for these magic numbers (index slicing nos.). # Mask out: land, lake, and 'outside max climatology' (open sea) max_extent_mask = np.sum(binary[:, :, [7, 6, 0]], axis=2).reshape(*self._shape) >= 1 @@ -148,6 +149,7 @@ def generate(self, # FIXME: Remove Caspian and Black seas - should we do this sh? if self.north: + # TODO: Add source/explanation for these indices. max_extent_mask[325:386, 317:380] = False mask_path = os.path.join( @@ -195,7 +197,9 @@ def generate(self, np.save(polarhole_path, polarhole) def get_active_cell_mask(self, month: object) -> object: - """Check if a mask file exists for input month, and raise an error if it does not. + """Loads an active grid cell mask from numpy file. + + Also, checks if a mask file exists for input month, and raises an error if it does not. Args: month: Month index representing the month for which the mask file is being checked. From 9840f03a7cd83382cc665d035729ea686e618508 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Wed, 27 Dec 2023 22:04:32 +0000 Subject: [PATCH 48/61] Fixes #207: Make parallel read an optional arg Leave it as True for backward compatibility --- icenet/data/loaders/dask.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/icenet/data/loaders/dask.py b/icenet/data/loaders/dask.py index b12394bd..4fc81907 100644 --- a/icenet/data/loaders/dask.py +++ b/icenet/data/loaders/dask.py @@ -223,7 +223,7 @@ def batch(batch_dates, num): np.average(exec_times))) self._write_dataset_config(counts) - def generate_sample(self, date: object, prediction: bool = False): + def generate_sample(self, date: object, prediction: bool = False, parallel=True): """ :param date: @@ -234,7 +234,7 @@ def generate_sample(self, date: object, prediction: bool = False): ds_kwargs = dict( chunks=dict(time=1, yc=self._shape[0], xc=self._shape[1]), drop_variables=["month", "plev", "level", "realization"], - parallel=True, + parallel=parallel, ) var_files = self.get_sample_files() @@ -242,6 +242,7 @@ def generate_sample(self, date: object, prediction: bool = False): v for k, v in var_files.items() if k not in self._meta_channels and not k.endswith("linear_trend") ], **ds_kwargs) + var_ds = var_ds.transpose("yc", "xc", "time") trend_files = \ From 5549c038249cd10f6626414095eaf549482af78e Mon Sep 17 00:00:00 2001 From: "Bryn N. 
Ubald" Date: Sun, 31 Dec 2023 20:20:13 +0000 Subject: [PATCH 49/61] Fixes 209: Update channel size n_filters_factor calc for all vals --- icenet/model/models.py | 53 ++++++++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/icenet/model/models.py b/icenet/model/models.py index a8e9fa63..be014f52 100644 --- a/icenet/model/models.py +++ b/icenet/model/models.py @@ -66,12 +66,19 @@ def unet_batchnorm(input_shape: object, """ inputs = Input(shape=input_shape) - conv1 = Conv2D(int(64 * n_filters_factor), + start_out_channels = 64 + reduced_channels = int(start_out_channels*n_filters_factor) + channels = {} + for pow in range(4): + value = reduced_channels*2**pow + channels[start_out_channels*2**pow] = value + + conv1 = Conv2D(channels[64], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(inputs) - conv1 = Conv2D(int(64 * n_filters_factor), + conv1 = Conv2D(channels[64], filter_size, activation='relu', padding='same', @@ -79,12 +86,12 @@ def unet_batchnorm(input_shape: object, bn1 = BatchNormalization(axis=-1)(conv1) pool1 = MaxPooling2D(pool_size=(2, 2))(bn1) - conv2 = Conv2D(int(128 * n_filters_factor), + conv2 = Conv2D(channels[128], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(pool1) - conv2 = Conv2D(int(128 * n_filters_factor), + conv2 = Conv2D(channels[128], filter_size, activation='relu', padding='same', @@ -92,12 +99,12 @@ def unet_batchnorm(input_shape: object, bn2 = BatchNormalization(axis=-1)(conv2) pool2 = MaxPooling2D(pool_size=(2, 2))(bn2) - conv3 = Conv2D(int(256 * n_filters_factor), + conv3 = Conv2D(channels[256], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(pool2) - conv3 = Conv2D(int(256 * n_filters_factor), + conv3 = Conv2D(channels[256], filter_size, activation='relu', padding='same', @@ -105,12 +112,12 @@ def unet_batchnorm(input_shape: object, bn3 = BatchNormalization(axis=-1)(conv3) pool3 = MaxPooling2D(pool_size=(2, 2))(bn3) - conv4 = Conv2D(int(256 * n_filters_factor), + conv4 = Conv2D(channels[256], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(pool3) - conv4 = Conv2D(int(256 * n_filters_factor), + conv4 = Conv2D(channels[256], filter_size, activation='relu', padding='same', @@ -118,19 +125,19 @@ def unet_batchnorm(input_shape: object, bn4 = BatchNormalization(axis=-1)(conv4) pool4 = MaxPooling2D(pool_size=(2, 2))(bn4) - conv5 = Conv2D(int(512 * n_filters_factor), + conv5 = Conv2D(channels[512], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(pool4) - conv5 = Conv2D(int(512 * n_filters_factor), + conv5 = Conv2D(channels[512], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv5) bn5 = BatchNormalization(axis=-1)(conv5) - up6 = Conv2D(int(256 * n_filters_factor), + up6 = Conv2D(channels[256], 2, activation='relu', padding='same', @@ -138,57 +145,57 @@ def unet_batchnorm(input_shape: object, size=(2, 2), interpolation='nearest')(bn5)) merge6 = concatenate([bn4, up6], axis=3) - conv6 = Conv2D(int(256 * n_filters_factor), + conv6 = Conv2D(channels[256], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(merge6) - conv6 = Conv2D(int(256 * n_filters_factor), + conv6 = Conv2D(channels[256], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv6) bn6 = BatchNormalization(axis=-1)(conv6) - up7 = Conv2D(int(256 * n_filters_factor), + up7 = Conv2D(channels[256], 2, 
activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D( size=(2, 2), interpolation='nearest')(bn6)) merge7 = concatenate([bn3, up7], axis=3) - conv7 = Conv2D(int(256 * n_filters_factor), + conv7 = Conv2D(channels[256], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(merge7) - conv7 = Conv2D(int(256 * n_filters_factor), + conv7 = Conv2D(channels[256], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv7) bn7 = BatchNormalization(axis=-1)(conv7) - up8 = Conv2D(int(128 * n_filters_factor), + up8 = Conv2D(channels[128], 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D( size=(2, 2), interpolation='nearest')(bn7)) merge8 = concatenate([bn2, up8], axis=3) - conv8 = Conv2D(int(128 * n_filters_factor), + conv8 = Conv2D(channels[128], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(merge8) - conv8 = Conv2D(int(128 * n_filters_factor), + conv8 = Conv2D(channels[128], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv8) bn8 = BatchNormalization(axis=-1)(conv8) - up9 = Conv2D(int(64 * n_filters_factor), + up9 = Conv2D(channels[64], 2, activation='relu', padding='same', @@ -197,17 +204,17 @@ def unet_batchnorm(input_shape: object, merge9 = concatenate([conv1, up9], axis=3) - conv9 = Conv2D(int(64 * n_filters_factor), + conv9 = Conv2D(channels[64], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(merge9) - conv9 = Conv2D(int(64 * n_filters_factor), + conv9 = Conv2D(channels[64], filter_size, activation='relu', padding='same', kernel_initializer='he_normal')(conv9) - conv9 = Conv2D(int(64 * n_filters_factor), + conv9 = Conv2D(channels[64], filter_size, activation='relu', padding='same', From b821a4bd3d18a76a1358b5b21c0f248c396458b7 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Sun, 31 Dec 2023 20:31:21 +0000 Subject: [PATCH 50/61] Fixes 209: Refactor fix --- icenet/model/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/icenet/model/models.py b/icenet/model/models.py index be014f52..9447cd52 100644 --- a/icenet/model/models.py +++ b/icenet/model/models.py @@ -67,11 +67,11 @@ def unet_batchnorm(input_shape: object, inputs = Input(shape=input_shape) start_out_channels = 64 - reduced_channels = int(start_out_channels*n_filters_factor) - channels = {} - for pow in range(4): - value = reduced_channels*2**pow - channels[start_out_channels*2**pow] = value + reduced_channels = int(start_out_channels * n_filters_factor) + channels = { + start_out_channels * 2**pow: reduced_channels * 2**pow + for pow in range(4) + } conv1 = Conv2D(channels[64], filter_size, From a537f202739354a33b4784c210c80b0c23837a31 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Thu, 18 Jan 2024 23:45:14 +0000 Subject: [PATCH 51/61] Fixes #184: Invalid zero day SIC length OSI-SAF file ingress --- icenet/data/sic/osisaf.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/icenet/data/sic/osisaf.py b/icenet/data/sic/osisaf.py index 1a2196ac..a60cea34 100644 --- a/icenet/data/sic/osisaf.py +++ b/icenet/data/sic/osisaf.py @@ -502,8 +502,19 @@ def download(self): "for {}".format(date_str)) continue - with open(temp_path, "wb") as fh: - ftp.retrbinary("RETR {}".format(ftp_files[0]), fh.write) + # Check if remote file size is too small, if so, render date invalid + # and continue. 
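+        # (Presumably a heuristic threshold: a genuine daily SIC netCDF is far
+        #  larger than 100 bytes, so anything smaller is treated as a failed or
+        #  placeholder upload rather than real data.)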
+ file_size = ftp.size(ftp_files[0]) + + # Check remote file size in bytes + if file_size < 100: + logging.warning( + f"Date {el} is in invalid list, as file size too small") + self._invalid_dates.append(el) + continue + else: + with open(temp_path, "wb") as fh: + ftp.retrbinary("RETR {}".format(ftp_files[0]), fh.write) logging.debug("Downloaded {}".format(temp_path)) data_files.append(temp_path) From 1d1502171b442623a464301e857d86fbe79051ae Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Fri, 19 Jan 2024 00:44:45 +0000 Subject: [PATCH 52/61] Fixes #214: Force h5py update from prev pinned to newer version Make sure env updates h5py, so is newer than v2.10.0, if reusing previous icenet env --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 82df92c4..899b469e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ dask distributed eccodes ecmwf-api-client -h5py +h5py>2.10 ibicus matplotlib motuclient From 95809e61de3ad082664a8a12a863fd172cf9d2f9 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Fri, 19 Jan 2024 16:53:25 +0000 Subject: [PATCH 53/61] Fixes #217: Omits dates already processed from OSI-SAF download --- icenet/data/sic/osisaf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/icenet/data/sic/osisaf.py b/icenet/data/sic/osisaf.py index 1a2196ac..49610dde 100644 --- a/icenet/data/sic/osisaf.py +++ b/icenet/data/sic/osisaf.py @@ -411,7 +411,9 @@ def download(self): if len(extant_paths) > 0: extant_ds = xr.open_mfdataset(extant_paths) - exclude_dates = pd.to_datetime(extant_ds.time.values) + exclude_dates = [ + pd.to_datetime(date).date() for date in extant_ds.time.values + ] logging.info("Excluding {} dates already existing from {} dates " "requested.".format(len(exclude_dates), len(dt_arr))) From 683e1bea57e10c4529452d6b7265c9639505a0d9 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Fri, 19 Jan 2024 22:25:21 +0000 Subject: [PATCH 54/61] Fixes #184: Updated to recheck file sizes previously 0 byte --- icenet/data/sic/osisaf.py | 51 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/icenet/data/sic/osisaf.py b/icenet/data/sic/osisaf.py index a60cea34..8684802d 100644 --- a/icenet/data/sic/osisaf.py +++ b/icenet/data/sic/osisaf.py @@ -1,4 +1,5 @@ import copy +import csv import fnmatch import ftplib import logging @@ -381,6 +382,20 @@ def __init__(self, for month in np.arange(1, 12 + 1) } + # Load dates that previously had a file size of zero. + # To recheck they haven't been fixed since last download. 
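+        # (The CSV written out by zero_dates() below holds one date per row in
+        #  "year,month,day" form, hence the join with "-" for pd.to_datetime.)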
+ zero_dates_path = os.path.join(self.get_data_var_folder("siconca"), + "zero_size_days.csv") + + self._zero_dates_path = zero_dates_path + self._zero_dates = [] + if os.path.exists(zero_dates_path): + with open(zero_dates_path, "r") as fh: + self._zero_dates = [ + pd.to_datetime("-".join(date)).date() + for date in csv.reader(fh) + ] + def download(self): """ @@ -411,7 +426,13 @@ def download(self): if len(extant_paths) > 0: extant_ds = xr.open_mfdataset(extant_paths) - exclude_dates = pd.to_datetime(extant_ds.time.values) + exclude_dates = [ + pd.to_datetime(date).date() for date in extant_ds.time.values + ] + + # Do not exclude dates that previously had a file size of 0 + exclude_dates = set(exclude_dates).difference(self._zero_dates) + logging.info("Excluding {} dates already existing from {} dates " "requested.".format(len(exclude_dates), len(dt_arr))) @@ -510,15 +531,30 @@ def download(self): if file_size < 100: logging.warning( f"Date {el} is in invalid list, as file size too small") + self._zero_dates.append(el) self._invalid_dates.append(el) continue else: + # Removing missing date file if it was created for a file with zero size before + if el in self._zero_dates: + self._zero_dates.remove(el) + fpath = os.path.join( + self.get_data_var_folder( + "siconca", + append=[str(pd.to_datetime(el).year)]), + "missing.{}.nc".format(date_str)) + if os.path.exists(fpath): + os.unlink(fpath) + with open(temp_path, "wb") as fh: ftp.retrbinary("RETR {}".format(ftp_files[0]), fh.write) logging.debug("Downloaded {}".format(temp_path)) data_files.append(temp_path) + self._zero_dates = set(self._zero_dates) + self.zero_dates() + if ftp: ftp.quit() @@ -596,6 +632,19 @@ def download(self): for fpath in data_files: os.unlink(fpath) + def zero_dates(self): + """ + Write out any dates that have zero file size on the ftp server to csv + """ + if not self._zero_dates and os.path.exists(self._zero_dates_path): + os.unlink(self._zero_dates_path) + elif self._zero_dates: + logging.info(f"Processing {len(self._zero_dates)} zero dates") + with open(self._zero_dates_path, "w") as fh: + for date in self._zero_dates: + # FIXME: slightly unusual format for Ymd dates + fh.write(date.strftime("%Y,%m,%d\n")) + def missing_dates(self): """ From b3d208f620f779b0e45769c3db345ab819e4bbea Mon Sep 17 00:00:00 2001 From: James Byrne Date: Sat, 20 Jan 2024 11:08:05 +0000 Subject: [PATCH 55/61] Ensure dev requirements get installed for contributions --- CONTRIBUTING.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 7d7bf0fd..e2aa67af 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -78,7 +78,7 @@ Ready to contribute? Here's how to set up `icenet` for local development. 5. Install development packages:: - $ pip install -r requirements.txt + $ pip install -r requirements.txt -r requirements_dev.txt 6. Set up pre-commit hooks to run automatically. This will run through linting checks, formatting, and pytest. 
It will format new code using yapf and prevent code committing that does not pass linting or testing checks until fixed:: From d72495e92f94a19bc812ed1c0bb85bd10955b118 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Sat, 20 Jan 2024 11:09:05 +0000 Subject: [PATCH 56/61] Documentation dependency solving --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 3d4d5d4d..03d1fe8e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,4 @@ -jinja2 +jinja2>=3.1.3 Sphinx myst_parser pylint From 4f287cc7acf4fcb117506f26092b2f2de6c2ab3a Mon Sep 17 00:00:00 2001 From: James Byrne Date: Sat, 20 Jan 2024 11:34:32 +0000 Subject: [PATCH 57/61] Adding requirements for tox tests --- MANIFEST.in | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MANIFEST.in b/MANIFEST.in index 022b2de4..52fa1d96 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,8 @@ include CONTRIBUTING.rst include HISTORY.rst include LICENSE include README.md +include requirements*.txt +include docs/requirements.txt recursive-include tests * recursive-exclude * __pycache__ From 52d8c39142cbdfcf096010360df22f231d359c56 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Sat, 20 Jan 2024 12:03:54 +0000 Subject: [PATCH 58/61] Removing tox configuration for 0.2.* releases (present in 0.3 --- tox.ini | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 tox.ini diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 253c2caa..00000000 --- a/tox.ini +++ /dev/null @@ -1,24 +0,0 @@ -[tox] -envlist = py37, py38, flake8 - -[travis] -python = - 3.8: py38 - 3.7: py37 - -[testenv:flake8] -basepython = python -deps = flake8 -commands = flake8 icenet tests - -[testenv] -setenv = - PYTHONPATH = {toxinidir} -deps = - -r{toxinidir}/requirements_dev.txt -; If you want to make tox run the tests with the same versions, create a -; requirements.txt with the pinned versions and uncomment the following line: -; -r{toxinidir}/requirements.txt -commands = - pip install -U pip - pytest --basetemp={envtmpdir} From c4ef2da43405c6fc4f04db5efa1cad0d847d0410 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Sat, 20 Jan 2024 12:35:37 +0000 Subject: [PATCH 59/61] Dealing with linting changes and configuration --- .pre-commit-config.yaml | 17 +++++----- icenet/data/loaders/dask.py | 8 +++-- icenet/data/process.py | 36 +++++++++++--------- icenet/data/producers.py | 12 ++++--- icenet/data/sic/mask.py | 6 ++-- icenet/model/models.py | 3 +- icenet/model/train.py | 18 +++++----- icenet/plotting/forecast.py | 55 ++++++++++++++++++------------- icenet/process/predict.py | 17 +++++----- icenet/tests/test_entry_points.py | 2 +- setup.cfg | 2 +- 11 files changed, 100 insertions(+), 76 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43172b40..08a9e96a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,14 +17,15 @@ repos: # - id: isort # args: ["--filter-files"] - # yapf - Formatting - - repo: https://github.com/google/yapf - rev: v0.40.2 - hooks: - - id: yapf - name: "yapf" - args: ["--in-place", "--parallel"] - exclude: "docs/" + # yapf - is doing my head in with making modifications, so removing as it make non-pep8 + # compliant changes + #- repo: https://github.com/google/yapf + # rev: v0.40.2 + # hooks: + # - id: yapf + # name: "yapf" + # args: ["--in-place", "--parallel"] + # exclude: "docs/" # ruff - Linting - repo: https://github.com/astral-sh/ruff-pre-commit diff --git 
a/icenet/data/loaders/dask.py b/icenet/data/loaders/dask.py index e909e627..82ec57f0 100644 --- a/icenet/data/loaders/dask.py +++ b/icenet/data/loaders/dask.py @@ -243,7 +243,10 @@ def batch(batch_dates, num): np.average(exec_times))) self._write_dataset_config(counts) - def generate_sample(self, date: object, prediction: bool = False, parallel=True): + def generate_sample(self, + date: object, + prediction: bool = False, + parallel=True): """ :param date: @@ -416,8 +419,7 @@ def generate_sample(forecast_date: object, for leadtime_idx in range(n_forecast_days): forecast_day = forecast_date + dt.timedelta(days=leadtime_idx) - if any([forecast_day == missing_date for missing_date in missing_dates - ]): + if any([forecast_day == missing_date for missing_date in missing_dates]): sample_weight = da.zeros(shape, dtype) else: # Zero loss outside of 'active grid cells' diff --git a/icenet/data/process.py b/icenet/data/process.py index 6221a066..224a2fc9 100644 --- a/icenet/data/process.py +++ b/icenet/data/process.py @@ -201,7 +201,8 @@ def _serialize(x): configuration = {"sources": {}} if os.path.exists(self._update_loader): - logging.info("Loading configuration {}".format(self._update_loader)) + logging.info("Loading configuration {}".format( + self._update_loader)) with open(self._update_loader, "r") as fh: obj = json.load(fh) configuration.update(obj) @@ -254,8 +255,9 @@ def _save_variable(self, var_name: str, var_suffix: str): clim_path = os.path.join(self._refdir, "params", "climatology.{}".format(var_name)) else: - clim_path = os.path.join(self.get_data_var_folder("params"), - "climatology.{}".format(var_name)) + clim_path = os.path.join( + self.get_data_var_folder("params"), + "climatology.{}".format(var_name)) if not os.path.exists(clim_path): logging.info("Generating climatology {}".format(clim_path)) @@ -279,11 +281,12 @@ def _save_variable(self, var_name: str, var_suffix: str): logging.warning( "We don't have a full climatology ({}) " "compared with data ({})".format( - ",".join([str(i) for i in climatology.month.values - ]), ",".join([ - str(i) for i in da.groupby( - "time.month").all().month.values - ]))) + ",".join( + [str(i) for i in climatology.month.values]), + ",".join([ + str(i) for i in da.groupby( + "time.month").all().month.values + ]))) da = da - climatology.mean() else: da = da.groupby("time.month") - climatology @@ -303,9 +306,10 @@ def _save_variable(self, var_name: str, var_suffix: str): ref_da = None if self._refdir: - logging.info("We have a reference {}, so will load " - "and supply abs from that for linear trend of " - "{}".format(self._refdir, var_name)) + logging.info( + "We have a reference {}, so will load " + "and supply abs from that for linear trend of " + "{}".format(self._refdir, var_name)) ref_da = xr.open_dataarray( os.path.join(self._refdir, var_name, "{}_{}.nc".format(var_name, var_suffix))) @@ -389,7 +393,8 @@ def _open_dataarray_from_files(self, var_name: str): # transferring logging.warning("Data selection failed, likely not daily sampled " "data so will give that a try") - da = da.resample(time="1D").mean().sel(time=da_dates).sortby("time") + da = da.resample(time="1D").mean().sel( + time=da_dates).sortby("time") logging.info("Filtered to {} units long based on configuration " "requirements".format(len(da.time))) @@ -412,8 +417,8 @@ def mean_and_std(array: object): mean = np.nanmean(array) std = np.nanstd(array) - logging.info("Mean: {:.3f}, std: {:.3f}".format(mean.item(), - std.item())) + logging.info("Mean: {:.3f}, std: {:.3f}".format( + 
mean.item(), std.item())) return mean, std @@ -446,7 +451,8 @@ def _normalise_array_mean(self, var_name: str, da: object): logging.debug( "Loading norm-average mean-std from {}".format(mean_path)) mean, std = tuple([ - self._dtype(el) for el in open(mean_path, "r").read().split(",") + self._dtype(el) + for el in open(mean_path, "r").read().split(",") ]) elif self._dates.train: logging.debug("Generating norm-average mean-std from {} training " diff --git a/icenet/data/producers.py b/icenet/data/producers.py index defe9b29..6fd763d7 100644 --- a/icenet/data/producers.py +++ b/icenet/data/producers.py @@ -47,7 +47,7 @@ def __init__(self, self._identifier: object = identifier self._path: str = os.path.join(path, identifier) self._hemisphere: Hemisphere = (Hemisphere.NORTH if north else Hemisphere.NONE) | \ - (Hemisphere.SOUTH if south else Hemisphere.NONE) + (Hemisphere.SOUTH if south else Hemisphere.NONE) assert self._identifier, "No identifier supplied" assert self._hemisphere != Hemisphere.NONE, "No hemispheres selected" @@ -307,7 +307,8 @@ def init_source_data(self, lag_days: object = None) -> None: ] dt_series = pd.Series(dfs, index=data_dates) - logging.debug("Create structure of {} files".format(len(dt_series))) + logging.debug("Create structure of {} files".format( + len(dt_series))) # Ensure we're ordered, it has repercussions for xarray for date in sorted(dates): @@ -344,11 +345,12 @@ def init_source_data(self, lag_days: object = None) -> None: # TODO: allow option to ditch dates from train/val/test for missing # var files self._var_files = { - var: var_files[var] for var in sorted(var_files.keys()) + var: var_files[var] + for var in sorted(var_files.keys()) } for var in self._var_files.keys(): - logging.info("Got {} files for {}".format(len(self._var_files[var]), - var)) + logging.info("Got {} files for {}".format( + len(self._var_files[var]), var)) @abstractmethod def process(self): diff --git a/icenet/data/sic/mask.py b/icenet/data/sic/mask.py index 7e1a23bf..2014fb11 100644 --- a/icenet/data/sic/mask.py +++ b/icenet/data/sic/mask.py @@ -127,8 +127,8 @@ def generate(self, if not os.path.exists(month_path): run_command( - retrieve_cmd_template_osi450.format(siconca_folder, year, - month, filename_osi450)) + retrieve_cmd_template_osi450.format( + siconca_folder, year, month, filename_osi450)) else: logging.info( "siconca {} already exists".format(filename_osi450)) @@ -141,7 +141,7 @@ def generate(self, binary = np.unpackbits(status_flag, axis=1).\ reshape(*self._shape, 8) - #TODO: Add source/explanation for these magic numbers (index slicing nos.). + # TODO: Add source/explanation for these magic numbers (index slicing nos.). 
# Mask out: land, lake, and 'outside max climatology' (open sea) max_extent_mask = np.sum(binary[:, :, [7, 6, 0]], axis=2).reshape(*self._shape) >= 1 diff --git a/icenet/model/models.py b/icenet/model/models.py index 9447cd52..b4f50d16 100644 --- a/icenet/model/models.py +++ b/icenet/model/models.py @@ -243,7 +243,8 @@ def linear_trend_forecast( da: object, mask: object, missing_dates: object = (), - shape: object = (432, 432)) -> object: + shape: object = (432, 432) +) -> object: """ :param usable_selector: diff --git a/icenet/model/train.py b/icenet/model/train.py index 7a543522..57bfffe0 100644 --- a/icenet/model/train.py +++ b/icenet/model/train.py @@ -244,7 +244,8 @@ def evaluate_model(model_path: object, metrics.WeightedRMSE, ] metrics_list = [ - cls(leadtime_idx=lt - 1) for lt in lead_times for cls in metrics_classes + cls(leadtime_idx=lt - 1) for lt in lead_times + for cls in metrics_classes ] network.compile(weighted_metrics=metrics_list) @@ -283,7 +284,10 @@ def get_args(): ap.add_argument("-b", "--batch-size", type=int, default=4) ap.add_argument("-ca", "--checkpoint-mode", default="min", type=str) - ap.add_argument("-cm", "--checkpoint-monitor", default="val_rmse", type=str) + ap.add_argument("-cm", + "--checkpoint-monitor", + default="val_rmse", + type=str) ap.add_argument("-ds", "--additional-dataset", dest="additional", @@ -394,9 +398,8 @@ def main(): name="{}.{}".format(args.run_name, args.seed), notes="{}: run at {}{}".format( args.run_name, - dt.datetime.now().strftime("%D %T"), - "" if not args.preload is not None else " preload {}".format( - args.preload)), + dt.datetime.now().strftime("%D %T"), "" if args.preload is None + else " preload {}".format(args.preload)), entity=args.wandb_user, config=dict( seed=args.seed, @@ -465,9 +468,8 @@ def main(): if using_wandb: logging.info("Updating wandb run with evaluation metrics") - metric_vals = [ - [results[f'{name}{lt}'] for lt in leads] for name in metric_names - ] + metric_vals = [[results[f'{name}{lt}'] for lt in leads] + for name in metric_names] table_data = list(zip(leads, *metric_vals)) table = wandb.Table(data=table_data, columns=['leadtime', *metric_names]) diff --git a/icenet/plotting/forecast.py b/icenet/plotting/forecast.py index c5f26888..8c25c293 100644 --- a/icenet/plotting/forecast.py +++ b/icenet/plotting/forecast.py @@ -495,8 +495,8 @@ def compute_metric_as_dataframe(metric: object, masks: object, target_dayofyear = target_date.dt.dayofyear # obtain day of year using same method above to avoid any leap-year issues target_dayofyear = pd.Series([ - 59 if d.strftime("%m-%d") == "02-29" else d.replace(year=2001).dayofyear - for d in target_date + 59 if d.strftime("%m-%d") == "02-29" else d.replace( + year=2001).dayofyear for d in target_date ]) target_month = target_date.dt.month return pd.concat([ @@ -623,7 +623,8 @@ def compute_metrics_leadtime_avg(metric: str, except OSError: # don't break if not successful, still return dataframe logging.info( - "Save not successful! Make sure the data_path directory exists") + "Save not successful! 
Make sure the data_path directory exists" + ) return fc_metric_df.reset_index(drop=True) @@ -739,7 +740,7 @@ def standard_deviation_heatmap(metric: str, # compute standard deviation of metric fc_std_metric = metrics_df.groupby([groupby_col, "leadtime"]).std(numeric_only=True).\ reset_index().pivot(index=groupby_col, columns="leadtime", values=metric).\ - sort_values(groupby_col, ascending=True) + sort_values(groupby_col, ascending=True) n_forecast_days = fc_std_metric.shape[1] # set ylabel (if average_over == "all"), or legend label (otherwise) @@ -977,14 +978,14 @@ def plot_metrics_leadtime_avg(metric: str, # compute metric by first grouping the dataframe by groupby_col and leadtime fc_avg_metric = fc_metric_df.groupby([groupby_col, "leadtime"]).mean(metric).\ reset_index().pivot(index=groupby_col, columns="leadtime", values=metric).\ - sort_values(groupby_col, ascending=True) + sort_values(groupby_col, ascending=True) n_forecast_days = fc_avg_metric.shape[1] if seas_metric_df is not None: # compute the difference in leadtime average to SEAS forecast seas_avg_metric = seas_metric_df.groupby([groupby_col, "leadtime"]).mean(metric).\ reset_index().pivot(index=groupby_col, columns="leadtime", values=metric).\ - sort_values(groupby_col, ascending=True) + sort_values(groupby_col, ascending=True) heatmap_df_diff = fc_avg_metric - seas_avg_metric max = np.nanmax(np.abs(heatmap_df_diff.values)) @@ -1058,10 +1059,10 @@ def plot_metrics_leadtime_avg(metric: str, # compute the standard deviation of the metric for both the forecast and SEAS5 fc_std_metric = fc_metric_df.groupby([groupby_col, "leadtime"]).std(numeric_only=True).\ reset_index().pivot(index=groupby_col, columns="leadtime", values=metric).\ - sort_values(groupby_col, ascending=True) + sort_values(groupby_col, ascending=True) seas_std_metric = seas_metric_df.groupby([groupby_col, "leadtime"]).std(numeric_only=True).\ reset_index().pivot(index=groupby_col, columns="leadtime", values=metric).\ - sort_values(groupby_col, ascending=True) + sort_values(groupby_col, ascending=True) # compute the maximum standard deviation to obtain a common scale vmax = np.nanmax([ np.nanmax(fc_std_metric.values), @@ -1118,10 +1119,13 @@ def sic_error_video(fc_da: object, obs_da: object, land_mask: object, obs_plot = obs_da.isel(time=leadtime).to_numpy() diff_plot = diff.isel(time=leadtime).to_numpy() - upper_bound = np.max([np.abs(np.min(diff_plot)), np.abs(np.max(diff_plot))]) + upper_bound = np.max( + [np.abs(np.min(diff_plot)), + np.abs(np.max(diff_plot))]) diff_vmin = -upper_bound diff_vmax = upper_bound - logging.debug("Bounds of differences: {} - {}".format(diff_vmin, diff_vmax)) + logging.debug("Bounds of differences: {} - {}".format( + diff_vmin, diff_vmax)) sic_cmap = mpl.cm.get_cmap("Blues_r", 20) contour_kwargs = dict(vmin=0, vmax=1, cmap=sic_cmap) @@ -1203,15 +1207,18 @@ def sic_error_local_header_data(da: xr.DataArray): return { "probe array index": { i_probe: (f"{da.xi.values[i_probe]}," - f"{da.yi.values[i_probe]}") for i_probe in range(n_probe) + f"{da.yi.values[i_probe]}") + for i_probe in range(n_probe) }, "probe location (EASE)": { i_probe: (f"{da.xc.values[i_probe]}," - f"{da.yc.values[i_probe]}") for i_probe in range(n_probe) + f"{da.yc.values[i_probe]}") + for i_probe in range(n_probe) }, "probe location (lat, lon)": { i_probe: (f"{da.lat.values[i_probe]}," - f"{da.lon.values[i_probe]}") for i_probe in range(n_probe) + f"{da.lon.values[i_probe]}") + for i_probe in range(n_probe) }, "obs_kind": { 0: "forecast", @@ -1267,7 +1274,8 @@ def 
sic_error_local_write_fig(combined_da: xr.DataArray, output_prefix: str): # dims: (obs_kind, time, probe) ax.plot(plot_series.loc[OBS_KIND_FC, :, i_probe], label="IceNet") - ax.plot(plot_series.loc[OBS_KIND_OBS, :, i_probe], label="Observed") + ax.plot(plot_series.loc[OBS_KIND_OBS, :, i_probe], + label="Observed") ax.legend() plt.setp(ax.get_xticklabels(), rotation=45, ha='right') @@ -1367,7 +1375,10 @@ def __init__(self, *args, forecast_date: bool = True, **kwargs): self.add_argument("forecast_date", type=date_arg) self.add_argument("-o", "--output-path", type=str, default=None) - self.add_argument("-v", "--verbose", action="store_true", default=False) + self.add_argument("-v", + "--verbose", + action="store_true", + default=False) self.add_argument("-r", "--region", default=None, @@ -1629,13 +1640,13 @@ def plot_forecast(): args.forecast_date.strftime("%Y%m%d"), "" if not args.stddev else "stddev.", args.format)) - xarray_to_video( - pred_da, - fps=1, - cmap=cmap, - imshow_kwargs=dict(vmin=0., vmax=vmax) if not args.stddev else None, - video_path=output_filename, - **anim_args) + xarray_to_video(pred_da, + fps=1, + cmap=cmap, + imshow_kwargs=dict(vmin=0., vmax=vmax) + if not args.stddev else None, + video_path=output_filename, + **anim_args) else: for leadtime in leadtimes: pred_da = fc.sel(leadtime=leadtime).isel(time=0) diff --git a/icenet/process/predict.py b/icenet/process/predict.py index 9e6c2734..985f0c32 100644 --- a/icenet/process/predict.py +++ b/icenet/process/predict.py @@ -27,7 +27,8 @@ def get_refsic(north: bool = True, south: bool = False) -> object: str = "nh" if north else "sh" - sic_day_fname = 'ice_conc_{}_ease2-250_cdr-v2p0_197901021200.nc'.format(str) + sic_day_fname = 'ice_conc_{}_ease2-250_cdr-v2p0_197901021200.nc'.format( + str) sic_day_path = os.path.join(".", "_sicfile") if not os.path.exists(os.path.join(sic_day_path, sic_day_fname)): @@ -148,8 +149,7 @@ def create_cf_output(): ref_cube = get_refcube(ds.north, ds.south) dates = [ - dt.date(*[int(v) - for v in s.split("-")]) + dt.date(*[int(v) for v in s.split("-")]) for s in args.datefile.read().split() ] args.datefile.close() @@ -202,8 +202,7 @@ def create_cf_output(): lists_of_fcast_dates = [[ pd.Timestamp(date + dt.timedelta(days=int(lead_idx))) for lead_idx in np.arange(1, arr.shape[3] + 1, 1) - ] - for date in dates] + ] for date in dates] xarr = xr.Dataset( data_vars=dict( @@ -277,11 +276,11 @@ def create_cf_output(): # Use ISO 8601:2004 duration format, preferably the extended format # as recommended in the Attribute Content Guidance section. 
time_coverage_start=min( - set([item for row in lists_of_fcast_dates for item in row - ])).isoformat(), + set([item for row in lists_of_fcast_dates + for item in row])).isoformat(), time_coverage_end=max( - set([item for row in lists_of_fcast_dates for item in row - ])).isoformat(), + set([item for row in lists_of_fcast_dates + for item in row])).isoformat(), time_coverage_duration="P1D", time_coverage_resolution="P1D", title="Sea Ice Concentration Prediction", diff --git a/icenet/tests/test_entry_points.py b/icenet/tests/test_entry_points.py index c0d440fa..07224199 100644 --- a/icenet/tests/test_entry_points.py +++ b/icenet/tests/test_entry_points.py @@ -5,7 +5,7 @@ icenet_entry_points = [ ep for ep in entry_points(group="console_scripts") - if ep.module.startswith('icenet') + if ep.module.startswith("icenet") ] diff --git a/setup.cfg b/setup.cfg index 5ff36522..0327510d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,7 +11,7 @@ atomic = True exclude = docs [yapf] -based_on_style = google +based_on_style = pep8 [tool:pytest] collect_ignore = ['setup.py'] From 8cc48623567f9c6f7b2a2ac8b87f68d690690bd0 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Sat, 20 Jan 2024 12:47:10 +0000 Subject: [PATCH 60/61] Updating contribution, history and build commands --- CONTRIBUTING.rst | 27 +++++++++++++-------------- HISTORY.rst | 5 +++-- Makefile | 2 +- docs/classes_icenet.puml | 11 ++++++----- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index e2aa67af..135be7d4 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -68,7 +68,7 @@ Ready to contribute? Here's how to set up `icenet` for local development. $ mkvirtualenv icenet $ cd icenet/ - $ python setup.py develop + $ pip install -e . 4. Create a branch for local development:: @@ -78,17 +78,16 @@ Ready to contribute? Here's how to set up `icenet` for local development. 5. Install development packages:: - $ pip install -r requirements.txt -r requirements_dev.txt + $ pip install -r requirements_dev.txt 6. Set up pre-commit hooks to run automatically. This will run through linting checks, formatting, and pytest. It will format new code using yapf and prevent code committing that does not pass linting or testing checks until fixed:: $ pre-commit install -7. Run through tox (currently omitted from pre-commit hook) to test other Python versions (Optionally, can replace with tox-conda, and run same command):: +7. When you're done making changes, check that your changes pass flake8 and the tests:: - $ tox - - To get tox, just pip install them into your virtualenv (or tox-conda for conda environment). + $ make lint + $ pytest 8. Commit your changes and push your branch to GitHub:: @@ -96,7 +95,7 @@ Ready to contribute? Here's how to set up `icenet` for local development. $ git commit -m "Your detailed description of your changes." $ git push origin name-of-your-bugfix-or-feature - Note: When committing, if pre-commit is installed, there might be formatting changes made by yapf and the commit prevented. In this case, add the file(s) modified by the formatter to the staging area and commit again. + Note: When committing, if pre-commit is installed, the commit might be prevented if there are problems with formatting. In this case, deal with the file(s) and commit again. 9. Submit a pull request through the GitHub website. @@ -119,12 +118,12 @@ TODO Deploying --------- -A reminder for the maintainers on how to deploy. -Make sure all your changes are committed (including an entry in HISTORY.rst). 
-Then run:: +A reminder for the maintainers on how to deploy:: -# Update icenet/__init__.py -$ git push -$ git push --tags +$ make clean +$ make lint # Ignore black moaning at present +$ make docs +$ make install +$ make release -Travis will then deploy to PyPI if tests pass. +If anything looks really wrong, abandon and fix! diff --git a/HISTORY.rst b/HISTORY.rst index 94dd09be..2a295a09 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -2,13 +2,14 @@ History ======= -0.2.* (2022-11-22) +0.2.* (2022-2024) ------------------ * First releases to PyPI * Still under development but usable codebase for training and prediction * Use alongside icenet-pipeline and icenet-notebooks -* Proper releases will follow >=0.2.1, otherwise alpha versioned for the mo +* Alpha only, beta releases will follow 0.2 +* 0.2 is maintained for the sake of adoption, with 0.3 to be a beta version 0.1.* (2021-2022) ----------------- diff --git a/Makefile b/Makefile index cb4376bc..ffcb1642 100644 --- a/Makefile +++ b/Makefile @@ -86,4 +86,4 @@ dist: clean ## builds source and wheel package ls -l dist install: clean ## install the package to the active Python's site-packages - python setup.py install + pip install -e . diff --git a/docs/classes_icenet.puml b/docs/classes_icenet.puml index 3a6f4be3..07c57692 100644 --- a/docs/classes_icenet.puml +++ b/docs/classes_icenet.puml @@ -43,8 +43,8 @@ class "ConstructLeadtimeAccuracy" as icenet.model.metrics.ConstructLeadtimeAccur update_state(y_true: object, y_pred: object, sample_weight: object) } class "DaskBaseDataLoader" as icenet.data.loaders.dask.DaskBaseDataLoader #99DDFF { - {abstract}client_generate(client: object, dates_override: object, pickup: bool) - generate() + {abstract}client_generate(client: object, dates_override: object, pickup: bool) -> None + generate() -> None } class "DaskMultiSharingWorkerLoader" as icenet.data.loaders.dask.DaskMultiSharingWorkerLoader #99DDFF { {abstract}client_generate(client: object, dates_override: object, pickup: bool) @@ -52,7 +52,7 @@ class "DaskMultiSharingWorkerLoader" as icenet.data.loaders.dask.DaskMultiSharin } class "DaskMultiWorkerLoader" as icenet.data.loaders.dask.DaskMultiWorkerLoader #99DDFF { client_generate(client: object, dates_override: object, pickup: bool) - generate_sample(date: object, prediction: bool) + generate_sample(date: object, prediction: bool, parallel) } class "DaskWrapper" as icenet.data.sic.osisaf.DaskWrapper #99DDFF { dask_process() @@ -210,6 +210,7 @@ class "SEASDownloader" as icenet.data.interfaces.mars.SEASDownloader #99DDFF { class "SICDownloader" as icenet.data.sic.osisaf.SICDownloader #99DDFF { download() missing_dates() + zero_dates() } class "SplittingMixin" as icenet.data.datasets.utils.SplittingMixin #99DDFF { batch_size @@ -239,12 +240,12 @@ class "WeightedMAE" as icenet.model.metrics.WeightedMAE #44BB99 { result() update_state(y_true: object, y_pred: object, sample_weight: object) } +class "WeightedMSE" as icenet.model.losses.WeightedMSE #44BB99 { +} class "WeightedMSE" as icenet.model.metrics.WeightedMSE #44BB99 { result() update_state(y_true: object, y_pred: object, sample_weight: object) } -class "WeightedMSE" as icenet.model.losses.WeightedMSE #44BB99 { -} class "WeightedRMSE" as icenet.model.metrics.WeightedRMSE #44BB99 { result() update_state(y_true: object, y_pred: object, sample_weight: object) From 3d374f0d3e42f0a1e873ab4e8f6352a70415df42 Mon Sep 17 00:00:00 2001 From: James Byrne Date: Sat, 20 Jan 2024 12:48:05 +0000 Subject: [PATCH 61/61] 0.2.7 version release --- 
icenet/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icenet/__init__.py b/icenet/__init__.py index dbd51d00..53890a69 100644 --- a/icenet/__init__.py +++ b/icenet/__init__.py @@ -4,4 +4,4 @@ __copyright__ = "British Antarctic Survey" __email__ = "jambyr@bas.ac.uk" __license__ = "MIT" -__version__ = "0.2.7a1" +__version__ = "0.2.7"
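
As a quick sanity check of the ISO 8601 handling introduced for the time-coverage
attributes in icenet/process/predict.py earlier in this series, the following
standalone sketch mirrors the min/max/isoformat logic from that hunk; the dates and
lead-time count below are made up for illustration, whereas the real values come from
the datefile argument and the prediction array shape::

    # Standalone sketch of the time-coverage attribute logic in
    # icenet/process/predict.py -- the dates and n_leadtimes here are
    # illustrative only; the real code derives them from the datefile
    # and the forecast array.
    import datetime as dt

    import numpy as np
    import pandas as pd

    dates = [dt.date(2020, 1, 1), dt.date(2020, 1, 2)]
    n_leadtimes = 3

    # One list of forecast dates per initialisation date, as in the patch
    lists_of_fcast_dates = [[
        pd.Timestamp(date + dt.timedelta(days=int(lead_idx)))
        for lead_idx in np.arange(1, n_leadtimes + 1, 1)
    ] for date in dates]

    # Flatten across initialisation dates before taking min/max
    all_fcast_dates = set(item for row in lists_of_fcast_dates
                          for item in row)

    attrs = dict(
        time_coverage_start=min(all_fcast_dates).isoformat(),  # "2020-01-02T00:00:00"
        time_coverage_end=max(all_fcast_dates).isoformat(),    # "2020-01-05T00:00:00"
        time_coverage_duration="P1D",    # ISO 8601 duration strings
        time_coverage_resolution="P1D",
    )
    print(attrs)

Taking min()/max() over the flattened forecast-date lists is what makes
time_coverage_start and time_coverage_end bracket every lead time of every forecast,
rather than just the initialisation dates.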