From 4741c798c25ec28345777b576133268f58b9b440 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Thu, 14 Nov 2024 10:24:52 +0000 Subject: [PATCH 01/13] First pass at restricting validation area --- src/cloudcasting/constants.py | 3 +++ src/cloudcasting/validation.py | 36 +++++++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/cloudcasting/constants.py b/src/cloudcasting/constants.py index ff2dc3b..82ccaaf 100644 --- a/src/cloudcasting/constants.py +++ b/src/cloudcasting/constants.py @@ -3,6 +3,7 @@ "DATA_INTERVAL_SPACING_MINUTES", "NUM_FORECAST_STEPS", "NUM_CHANNELS", + "CUTOUT_COORDS", ) # These constants were locked as part of the project specification @@ -16,3 +17,5 @@ NUM_CHANNELS = 11 # Image size (height, width) IMAGE_SIZE_TUPLE = (372, 614) +# Cutout coords (min lat, max lat, min lon, max lon) +CUTOUT_COORDS = (49, 60, -6, 2) diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index fa3967a..c28b118 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -31,6 +31,7 @@ import cloudcasting from cloudcasting import metrics as dm_pix # for compatibility if our changes are upstreamed from cloudcasting.constants import ( + CUTOUT_COORDS, DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES, IMAGE_SIZE_TUPLE, @@ -45,7 +46,7 @@ SampleOutputArray, TimeArray, ) -from cloudcasting.utils import numpy_validation_collate_fn +from cloudcasting.utils import lon_lat_to_geostationary_area_coords, numpy_validation_collate_fn logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -218,6 +219,27 @@ def score_model_on_all_metrics( "(please make an issue on github if you see this!!!!)" ) + # calculate the cutout indices for the dataset + lat_min, lat_max, lon_min, lon_max = CUTOUT_COORDS + (x_min, x_max), (y_min, y_max) = lon_lat_to_geostationary_area_coords( + [lon_min, lon_max], + [lat_min, lat_max], + valid_dataset.ds.data, + ) + + y_vals = np.where( + np.logical_and( + valid_dataset.ds.coords["y_geostationary"] <= y_max, + valid_dataset.ds.coords["y_geostationary"] >= y_min, + ) + )[0] + x_vals = np.where( + np.logical_and( + valid_dataset.ds.coords["x_geostationary"] <= x_max, + valid_dataset.ds.coords["x_geostationary"] >= x_min, + ) + )[0] + valid_dataloader = DataLoader( valid_dataset, batch_size=batch_size, @@ -269,25 +291,29 @@ def get_pix_function( for i, (X, y) in tqdm(enumerate(valid_dataloader), total=loop_steps): y_hat = model(X) + # cutout the GB area + y_cutout = y[..., x_vals, y_vals] + y_hat = y_hat[..., x_vals, y_vals] + # assert shapes are the same - assert y.shape == y_hat.shape, f"{y.shape=} != {y_hat.shape=}" + assert y_cutout.shape == y_hat.shape, f"{y_cutout.shape=} != {y_hat.shape=}" # If nan_to_num is used in the dataset, the model will output -1 for NaNs. We need to # convert these back to NaNs for the metrics - y[y == -1] = np.nan + y_cutout[y_cutout == -1] = np.nan # pix accepts arrays of shape [batch, height, width, channels]. # our arrays are of shape [batch, channels, time, height, width]. # channel dim would be reduced; we add a new axis to satisfy the shape reqs. # we then reshape to squash batch, channels, and time into the leading axis, # where the vmap in metrics.py will broadcast over the leading dim. - y_jax = jnp.array(y).reshape(-1, *y.shape[-2:])[..., np.newaxis] + y_jax = jnp.array(y_cutout).reshape(-1, *y_cutout.shape[-2:])[..., np.newaxis] y_hat_jax = jnp.array(y_hat).reshape(-1, *y_hat.shape[-2:])[..., np.newaxis] for metric_name, metric_func in metric_funcs.items(): # we reshape the result back into [batch, channels, time], # then take the mean over the batch - metric_res = metric_func(y_hat_jax, y_jax).reshape(*y.shape[:3]) + metric_res = metric_func(y_hat_jax, y_jax).reshape(*y_cutout.shape[:3]) batch_reduced_metric = jnp.nanmean(metric_res, axis=0) metrics[metric_name].append(batch_reduced_metric) From f74a4dd844bc06e04ed3dc7efb2483044d99a053 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Fri, 15 Nov 2024 10:49:54 +0000 Subject: [PATCH 02/13] Fixing indexing issue --- src/cloudcasting/validation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index c28b118..0d07122 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -240,6 +240,8 @@ def score_model_on_all_metrics( ) )[0] + ix = np.ix_(x_vals, y_vals) + valid_dataloader = DataLoader( valid_dataset, batch_size=batch_size, @@ -292,8 +294,8 @@ def get_pix_function( y_hat = model(X) # cutout the GB area - y_cutout = y[..., x_vals, y_vals] - y_hat = y_hat[..., x_vals, y_vals] + y_cutout = y[..., ix[1], ix[0]] + y_hat = y_hat[..., ix[1], ix[0]] # assert shapes are the same assert y_cutout.shape == y_hat.shape, f"{y_cutout.shape=} != {y_hat.shape=}" From ca15eb87e311caa7b7037398c1b62ffe19137917 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Wed, 20 Nov 2024 11:26:42 +0000 Subject: [PATCH 03/13] Create a cutout mask based on pixel values --- src/cloudcasting/constants.py | 12 ++++++++--- src/cloudcasting/utils.py | 25 +++++++++++++++++++++++ src/cloudcasting/validation.py | 37 +++++++++------------------------- tests/test_validation.py | 6 ++++++ 4 files changed, 50 insertions(+), 30 deletions(-) diff --git a/src/cloudcasting/constants.py b/src/cloudcasting/constants.py index 82ccaaf..4112319 100644 --- a/src/cloudcasting/constants.py +++ b/src/cloudcasting/constants.py @@ -3,9 +3,11 @@ "DATA_INTERVAL_SPACING_MINUTES", "NUM_FORECAST_STEPS", "NUM_CHANNELS", - "CUTOUT_COORDS", + "CUTOUT_MASK", ) +from cloudcasting.utils import create_cutout_mask + # These constants were locked as part of the project specification # 3 hour horecast horizon FORECAST_HORIZON_MINUTES = 180 @@ -17,5 +19,9 @@ NUM_CHANNELS = 11 # Image size (height, width) IMAGE_SIZE_TUPLE = (372, 614) -# Cutout coords (min lat, max lat, min lon, max lon) -CUTOUT_COORDS = (49, 60, -6, 2) +# # Cutout coords (min lat, max lat, min lon, max lon) +# CUTOUT_COORDS = (49, 60, -6, 2) +# Cutout mask (min x, max x, min y, max y) +CUTOUT_MASK_BOUNDARY = (127, 394, 104, 290) +# Create cutout mask +CUTOUT_MASK = create_cutout_mask(CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE) diff --git a/src/cloudcasting/utils.py b/src/cloudcasting/utils.py index 19ce71f..8a51d05 100644 --- a/src/cloudcasting/utils.py +++ b/src/cloudcasting/utils.py @@ -13,6 +13,7 @@ import pyproj import pyresample import xarray as xr +from numpy.typing import NDArray from cloudcasting.types import ( BatchInputArray, @@ -152,3 +153,27 @@ def numpy_validation_collate_fn( X_all[i] = X y_all[i] = y return X_all, y_all + + +def create_cutout_mask( + mask_size: tuple[int, int, int, int], + image_size: tuple[int, int], +) -> NDArray[np.int8]: + """Create a mask with a cutout in the center. + Args: + x: x-coordinate of the center of the cutout + y: y-coordinate of the center of the cutout + width: Width of the mask + height: Height of the mask + mask_size: Size of the cutout + mask_value: Value to fill the mask with + Returns: + np.ndarray: The mask + """ + height, width = image_size + min_x, max_x, min_y, max_y = mask_size + + mask = np.empty((height, width), dtype=np.int8) + mask[:] = np.nan + mask[min_y:max_y, min_x:max_x] = 1 + return mask diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index 0d07122..fd4e2dd 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -18,6 +18,7 @@ from jax import tree from jaxtyping import Array, Float32 from matplotlib.colors import Normalize # type: ignore[import-not-found] +from numpy.typing import NDArray from torch.utils.data import DataLoader from tqdm import tqdm @@ -31,7 +32,7 @@ import cloudcasting from cloudcasting import metrics as dm_pix # for compatibility if our changes are upstreamed from cloudcasting.constants import ( - CUTOUT_COORDS, + CUTOUT_MASK, DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES, IMAGE_SIZE_TUPLE, @@ -46,7 +47,7 @@ SampleOutputArray, TimeArray, ) -from cloudcasting.utils import lon_lat_to_geostationary_area_coords, numpy_validation_collate_fn +from cloudcasting.utils import numpy_validation_collate_fn logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) @@ -191,6 +192,7 @@ def score_model_on_all_metrics( batch_limit: int | None = None, metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"), metric_kwargs: dict[str, dict[str, Any]] | None = None, + mask: NDArray[np.int8] = CUTOUT_MASK, ) -> tuple[dict[str, MetricArray], list[str]]: """Calculate the scoreboard metrics for the given model on the validation dataset. @@ -205,6 +207,7 @@ def score_model_on_all_metrics( in cloudcasting.metrics. Defaults to ("mae", "mse", "ssim"). metric_kwargs (dict[str, dict[str, Any]] | None, optional): kwargs to pass to functions in cloudcasting.metrics. Defaults to None. + mask (np.ndarray, optional): The mask to apply to the data. Defaults to CUTOUT_MASK. Returns: tuple[dict[str, MetricArray], list[str]]: @@ -219,29 +222,6 @@ def score_model_on_all_metrics( "(please make an issue on github if you see this!!!!)" ) - # calculate the cutout indices for the dataset - lat_min, lat_max, lon_min, lon_max = CUTOUT_COORDS - (x_min, x_max), (y_min, y_max) = lon_lat_to_geostationary_area_coords( - [lon_min, lon_max], - [lat_min, lat_max], - valid_dataset.ds.data, - ) - - y_vals = np.where( - np.logical_and( - valid_dataset.ds.coords["y_geostationary"] <= y_max, - valid_dataset.ds.coords["y_geostationary"] >= y_min, - ) - )[0] - x_vals = np.where( - np.logical_and( - valid_dataset.ds.coords["x_geostationary"] <= x_max, - valid_dataset.ds.coords["x_geostationary"] >= x_min, - ) - )[0] - - ix = np.ix_(x_vals, y_vals) - valid_dataloader = DataLoader( valid_dataset, batch_size=batch_size, @@ -294,8 +274,9 @@ def get_pix_function( y_hat = model(X) # cutout the GB area - y_cutout = y[..., ix[1], ix[0]] - y_hat = y_hat[..., ix[1], ix[0]] + mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :] + y_cutout = y * mask_full + y_hat = y_hat * mask_full # assert shapes are the same assert y_cutout.shape == y_hat.shape, f"{y_cutout.shape=} != {y_hat.shape=}" @@ -387,6 +368,7 @@ def validate( batch_size: int = 1, num_workers: int = 0, batch_limit: int | None = None, + mask: NDArray[np.int8] = CUTOUT_MASK, ) -> None: """Run the full validation procedure on the model and log the results to wandb. @@ -433,6 +415,7 @@ def validate( batch_size=batch_size, num_workers=num_workers, batch_limit=batch_limit, + mask=mask, ) # Calculate the mean of each metric reduced over forecast horizon and channels diff --git a/tests/test_validation.py b/tests/test_validation.py index 06793b3..66d3fa5 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -8,6 +8,7 @@ NUM_FORECAST_STEPS, ) from cloudcasting.dataset import ValidationSatelliteDataset +from cloudcasting.utils import create_cutout_mask from cloudcasting.validation import ( calc_mean_metrics, score_model_on_all_metrics, @@ -15,6 +16,8 @@ validate_from_config, ) +test_mask = create_cutout_mask((2, 6, 1, 7), (9, 8)) + @pytest.fixture() def model(): @@ -47,6 +50,7 @@ def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num): batch_limit=3, metric_names=metric_names, metric_kwargs=metric_kwargs, + mask=test_mask, ) # Check all the expected keys are there @@ -98,6 +102,7 @@ def test_validate(model, val_sat_zarr_path, mocker): batch_size=2, num_workers=0, batch_limit=4, + mask=test_mask, ) @@ -141,6 +146,7 @@ def hyperparameters_dict(self): batch_size: 2 num_workers: 0 batch_limit: 4 + mask: {test_mask} """ ) From 727b349df2c78adbc78581163ca1c7af22d77817 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Wed, 20 Nov 2024 11:34:44 +0000 Subject: [PATCH 04/13] Correcting type --- src/cloudcasting/utils.py | 4 ++-- src/cloudcasting/validation.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cloudcasting/utils.py b/src/cloudcasting/utils.py index 8a51d05..2abf46c 100644 --- a/src/cloudcasting/utils.py +++ b/src/cloudcasting/utils.py @@ -158,7 +158,7 @@ def numpy_validation_collate_fn( def create_cutout_mask( mask_size: tuple[int, int, int, int], image_size: tuple[int, int], -) -> NDArray[np.int8]: +) -> NDArray[np.float64]: """Create a mask with a cutout in the center. Args: x: x-coordinate of the center of the cutout @@ -173,7 +173,7 @@ def create_cutout_mask( height, width = image_size min_x, max_x, min_y, max_y = mask_size - mask = np.empty((height, width), dtype=np.int8) + mask = np.empty((height, width), dtype=np.float64) mask[:] = np.nan mask[min_y:max_y, min_x:max_x] = 1 return mask diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index fd4e2dd..b859b20 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -192,7 +192,7 @@ def score_model_on_all_metrics( batch_limit: int | None = None, metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"), metric_kwargs: dict[str, dict[str, Any]] | None = None, - mask: NDArray[np.int8] = CUTOUT_MASK, + mask: NDArray[np.float64] = CUTOUT_MASK, ) -> tuple[dict[str, MetricArray], list[str]]: """Calculate the scoreboard metrics for the given model on the validation dataset. @@ -368,7 +368,7 @@ def validate( batch_size: int = 1, num_workers: int = 0, batch_limit: int | None = None, - mask: NDArray[np.int8] = CUTOUT_MASK, + mask: NDArray[np.float64] = CUTOUT_MASK, ) -> None: """Run the full validation procedure on the model and log the results to wandb. From 86224bbd8365892e9e2503a86ed5c7de2d846b6a Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Thu, 21 Nov 2024 12:08:17 +0000 Subject: [PATCH 05/13] Working out how to handle test_mask in cli --- tests/test_validation.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/test_validation.py b/tests/test_validation.py index 66d3fa5..9342ecb 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -16,7 +16,10 @@ validate_from_config, ) -test_mask = create_cutout_mask((2, 6, 1, 7), (9, 8)) + +@pytest.fixture() +def test_mask(): + return create_cutout_mask((2, 6, 1, 7), (9, 8)) @pytest.fixture() @@ -25,7 +28,7 @@ def model(): @pytest.mark.parametrize("nan_to_num", [True, False]) -def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num): +def test_score_model_on_all_metrics(model, val_sat_zarr_path, test_mask, nan_to_num): # Create valid dataset valid_dataset = ValidationSatelliteDataset( zarr_path=val_sat_zarr_path, @@ -84,7 +87,7 @@ def test_calc_mean_metrics(): assert mean_metrics_dict["mse"] == 5 -def test_validate(model, val_sat_zarr_path, mocker): +def test_validate(model, val_sat_zarr_path, test_mask, mocker): # Mock the wandb functions so they aren't run in testing mocker.patch("wandb.login") mocker.patch("wandb.init") @@ -106,7 +109,7 @@ def test_validate(model, val_sat_zarr_path, mocker): ) -def test_validate_cli(val_sat_zarr_path, mocker): +def test_validate_cli(val_sat_zarr_path, test_mask, mocker): # write out an example model.py file with open("model.py", "w") as f: f.write( @@ -139,14 +142,14 @@ def hyperparameters_dict(self): history_steps: 1 sigma: 0.1 validation: - data_path: {val_sat_zarr_path} + data_path: "{val_sat_zarr_path}" wandb_project_name: cloudcasting-pytest wandb_run_name: test_validate nan_to_num: False batch_size: 2 num_workers: 0 batch_limit: 4 - mask: {test_mask} + mask: "{test_mask}" """ ) From f7c641374ea65c1f7a4dbcabcea9328b75fbd547 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Thu, 21 Nov 2024 12:26:44 +0000 Subject: [PATCH 06/13] Adding typeguard --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 73fa130..05e0358 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dev = [ "scipy", "pytest-mock", "scikit-image", + "typeguard", ] [tool.setuptools.package-data] From 2a01efae0de64612074a25bfa52f5042f63a99ee Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Mon, 2 Dec 2024 10:17:33 +0000 Subject: [PATCH 07/13] New attempt at a validation cutout --- src/cloudcasting/constants.py | 12 ++++++++++-- src/cloudcasting/validation.py | 15 ++++++++++----- tests/test_validation.py | 15 +++------------ 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/cloudcasting/constants.py b/src/cloudcasting/constants.py index 4112319..784ee9b 100644 --- a/src/cloudcasting/constants.py +++ b/src/cloudcasting/constants.py @@ -17,11 +17,19 @@ NUM_FORECAST_STEPS = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES # for all 11 low resolution channels NUM_CHANNELS = 11 + +# Constants for the larger (original) image # Image size (height, width) IMAGE_SIZE_TUPLE = (372, 614) -# # Cutout coords (min lat, max lat, min lon, max lon) -# CUTOUT_COORDS = (49, 60, -6, 2) # Cutout mask (min x, max x, min y, max y) CUTOUT_MASK_BOUNDARY = (127, 394, 104, 290) # Create cutout mask CUTOUT_MASK = create_cutout_mask(CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE) + +# Constants for the smaller (cropped) image +# Cropped image size (height, width) +CROPPED_IMAGE_SIZE_TUPLE = (278, 385) +# Cropped cutout mask (min x, max x, min y, max y) +CROPPED_CUTOUT_MASK_BOUNDARY = (70, 337, 59, 245) +# Create cropped cutout mask +CROPPED_CUTOUT_MASK = create_cutout_mask(CROPPED_CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE) diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index b859b20..e646a09 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -18,7 +18,6 @@ from jax import tree from jaxtyping import Array, Float32 from matplotlib.colors import Normalize # type: ignore[import-not-found] -from numpy.typing import NDArray from torch.utils.data import DataLoader from tqdm import tqdm @@ -32,6 +31,8 @@ import cloudcasting from cloudcasting import metrics as dm_pix # for compatibility if our changes are upstreamed from cloudcasting.constants import ( + CROPPED_CUTOUT_MASK, + CROPPED_IMAGE_SIZE_TUPLE, CUTOUT_MASK, DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES, @@ -192,7 +193,6 @@ def score_model_on_all_metrics( batch_limit: int | None = None, metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"), metric_kwargs: dict[str, dict[str, Any]] | None = None, - mask: NDArray[np.float64] = CUTOUT_MASK, ) -> tuple[dict[str, MetricArray], list[str]]: """Calculate the scoreboard metrics for the given model on the validation dataset. @@ -207,7 +207,6 @@ def score_model_on_all_metrics( in cloudcasting.metrics. Defaults to ("mae", "mse", "ssim"). metric_kwargs (dict[str, dict[str, Any]] | None, optional): kwargs to pass to functions in cloudcasting.metrics. Defaults to None. - mask (np.ndarray, optional): The mask to apply to the data. Defaults to CUTOUT_MASK. Returns: tuple[dict[str, MetricArray], list[str]]: @@ -273,6 +272,14 @@ def get_pix_function( for i, (X, y) in tqdm(enumerate(valid_dataloader), total=loop_steps): y_hat = model(X) + # identify the correct mask / create a mask if necessary + if X.shape[-2:] == IMAGE_SIZE_TUPLE: + mask = CUTOUT_MASK + elif X.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE: + mask = CROPPED_CUTOUT_MASK + else: + mask = np.ones(X.shape[-2:], dtype=np.float64) + # cutout the GB area mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :] y_cutout = y * mask_full @@ -368,7 +375,6 @@ def validate( batch_size: int = 1, num_workers: int = 0, batch_limit: int | None = None, - mask: NDArray[np.float64] = CUTOUT_MASK, ) -> None: """Run the full validation procedure on the model and log the results to wandb. @@ -415,7 +421,6 @@ def validate( batch_size=batch_size, num_workers=num_workers, batch_limit=batch_limit, - mask=mask, ) # Calculate the mean of each metric reduced over forecast horizon and channels diff --git a/tests/test_validation.py b/tests/test_validation.py index 9342ecb..5e44a68 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -8,7 +8,6 @@ NUM_FORECAST_STEPS, ) from cloudcasting.dataset import ValidationSatelliteDataset -from cloudcasting.utils import create_cutout_mask from cloudcasting.validation import ( calc_mean_metrics, score_model_on_all_metrics, @@ -17,18 +16,13 @@ ) -@pytest.fixture() -def test_mask(): - return create_cutout_mask((2, 6, 1, 7), (9, 8)) - - @pytest.fixture() def model(): return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS) @pytest.mark.parametrize("nan_to_num", [True, False]) -def test_score_model_on_all_metrics(model, val_sat_zarr_path, test_mask, nan_to_num): +def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num): # Create valid dataset valid_dataset = ValidationSatelliteDataset( zarr_path=val_sat_zarr_path, @@ -53,7 +47,6 @@ def test_score_model_on_all_metrics(model, val_sat_zarr_path, test_mask, nan_to_ batch_limit=3, metric_names=metric_names, metric_kwargs=metric_kwargs, - mask=test_mask, ) # Check all the expected keys are there @@ -87,7 +80,7 @@ def test_calc_mean_metrics(): assert mean_metrics_dict["mse"] == 5 -def test_validate(model, val_sat_zarr_path, test_mask, mocker): +def test_validate(model, val_sat_zarr_path, mocker): # Mock the wandb functions so they aren't run in testing mocker.patch("wandb.login") mocker.patch("wandb.init") @@ -105,11 +98,10 @@ def test_validate(model, val_sat_zarr_path, test_mask, mocker): batch_size=2, num_workers=0, batch_limit=4, - mask=test_mask, ) -def test_validate_cli(val_sat_zarr_path, test_mask, mocker): +def test_validate_cli(val_sat_zarr_path, mocker): # write out an example model.py file with open("model.py", "w") as f: f.write( @@ -149,7 +141,6 @@ def hyperparameters_dict(self): batch_size: 2 num_workers: 0 batch_limit: 4 - mask: "{test_mask}" """ ) From 799f3a1836da9da3d42a4fc64c1b0ec658bd5826 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Mon, 2 Dec 2024 10:33:21 +0000 Subject: [PATCH 08/13] Removing double quotes in test_validate_cli --- tests/test_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_validation.py b/tests/test_validation.py index 5e44a68..06793b3 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -134,7 +134,7 @@ def hyperparameters_dict(self): history_steps: 1 sigma: 0.1 validation: - data_path: "{val_sat_zarr_path}" + data_path: {val_sat_zarr_path} wandb_project_name: cloudcasting-pytest wandb_run_name: test_validate nan_to_num: False From 1d212f6b1bad539b01fcacad65a1f7f008ac9b10 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Tue, 3 Dec 2024 10:33:38 +0000 Subject: [PATCH 09/13] Correcting dimensions of CROPPED_CUTOUT_MASK --- src/cloudcasting/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudcasting/constants.py b/src/cloudcasting/constants.py index 784ee9b..c489a11 100644 --- a/src/cloudcasting/constants.py +++ b/src/cloudcasting/constants.py @@ -32,4 +32,4 @@ # Cropped cutout mask (min x, max x, min y, max y) CROPPED_CUTOUT_MASK_BOUNDARY = (70, 337, 59, 245) # Create cropped cutout mask -CROPPED_CUTOUT_MASK = create_cutout_mask(CROPPED_CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE) +CROPPED_CUTOUT_MASK = create_cutout_mask(CROPPED_CUTOUT_MASK_BOUNDARY, CROPPED_IMAGE_SIZE_TUPLE) From 8fba5a86c668b4d6fc2d5e8c8cad2d219de8b568 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Tue, 3 Dec 2024 12:24:12 +0000 Subject: [PATCH 10/13] Updating the cutout box boundaries --- src/cloudcasting/constants.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cloudcasting/constants.py b/src/cloudcasting/constants.py index c489a11..8891ad5 100644 --- a/src/cloudcasting/constants.py +++ b/src/cloudcasting/constants.py @@ -22,7 +22,7 @@ # Image size (height, width) IMAGE_SIZE_TUPLE = (372, 614) # Cutout mask (min x, max x, min y, max y) -CUTOUT_MASK_BOUNDARY = (127, 394, 104, 290) +CUTOUT_MASK_BOUNDARY = (166, 336, 107, 289) # Create cutout mask CUTOUT_MASK = create_cutout_mask(CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE) @@ -30,6 +30,6 @@ # Cropped image size (height, width) CROPPED_IMAGE_SIZE_TUPLE = (278, 385) # Cropped cutout mask (min x, max x, min y, max y) -CROPPED_CUTOUT_MASK_BOUNDARY = (70, 337, 59, 245) +CROPPED_CUTOUT_MASK_BOUNDARY = (109, 279, 62, 244) # Create cropped cutout mask CROPPED_CUTOUT_MASK = create_cutout_mask(CROPPED_CUTOUT_MASK_BOUNDARY, CROPPED_IMAGE_SIZE_TUPLE) From 6d8330a8891135eea64122a5bae973ad6aa4133e Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Wed, 4 Dec 2024 12:15:34 +0000 Subject: [PATCH 11/13] Adding a bounding box to the video --- src/cloudcasting/validation.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py index e646a09..1f9d8eb 100644 --- a/src/cloudcasting/validation.py +++ b/src/cloudcasting/validation.py @@ -32,8 +32,10 @@ from cloudcasting import metrics as dm_pix # for compatibility if our changes are upstreamed from cloudcasting.constants import ( CROPPED_CUTOUT_MASK, + CROPPED_CUTOUT_MASK_BOUNDARY, CROPPED_IMAGE_SIZE_TUPLE, CUTOUT_MASK, + CUTOUT_MASK_BOUNDARY, DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES, IMAGE_SIZE_TUPLE, @@ -148,6 +150,30 @@ def log_prediction_video_to_wandb( y[mask] = 0 y_hat[mask] = 0 + create_box = True + + # create a boundary box for the crop + if y.shape[-2:] == IMAGE_SIZE_TUPLE: + boxl, boxr, boxb, boxt = CUTOUT_MASK_BOUNDARY + bsize = IMAGE_SIZE_TUPLE + elif y.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE: + boxl, boxr, boxb, boxt = CROPPED_CUTOUT_MASK_BOUNDARY + bsize = CROPPED_IMAGE_SIZE_TUPLE + else: + create_box = False + + if create_box: + # box mask + maskb = np.ones(bsize, dtype=np.float64) + maskb[boxb : boxb + 2, boxl:boxr] = np.nan # Top edge + maskb[boxt - 2 : boxt, boxl:boxr] = np.nan # Bottom edge + maskb[boxb:boxt, boxl : boxl + 2] = np.nan # Left edge + maskb[boxb:boxt, boxr - 2 : boxr] = np.nan # Right edge + maskb = maskb[np.newaxis, np.newaxis, :, :] + + y = y * maskb + y_hat = y_hat * maskb + # Tranpose the arrays so time is the first dimension and select the channel # Then flip the frames so they are in the correct orientation for the video y_frames = y.transpose(1, 2, 3, 0)[:, ::-1, ::-1, channel_ind : channel_ind + 1] @@ -171,6 +197,14 @@ def log_prediction_video_to_wandb( # combine add difference to the video array video_array = np.concatenate([video_array, diff_ccmap], axis=2) + + # Set bounding box to a colour so it is visible + if create_box: + video_array[:, :, :, 0][np.isnan(video_array[:, :, :, 0])] = 250 + video_array[:, :, :, 1][np.isnan(video_array[:, :, :, 1])] = 40 + video_array[:, :, :, 2][np.isnan(video_array[:, :, :, 2])] = 10 + video_array[:, :, :, 3][np.isnan(video_array[:, :, :, 3])] = 255 + video_array = video_array.transpose(0, 3, 1, 2) video_array = video_array.astype(np.uint8) From 47fdb2ff67f66c8409b3d0e4c3749be8063fd247 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Thu, 5 Dec 2024 15:23:48 +0000 Subject: [PATCH 12/13] Temporary change of preshuffle to check code is working as expected --- src/cloudcasting/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudcasting/dataset.py b/src/cloudcasting/dataset.py index ad6fa68..5380f0d 100644 --- a/src/cloudcasting/dataset.py +++ b/src/cloudcasting/dataset.py @@ -220,7 +220,7 @@ def __init__( history_mins=history_mins, forecast_mins=forecast_mins, sample_freq_mins=sample_freq_mins, - preshuffle=True, + preshuffle=False, nan_to_num=nan_to_num, ) From ad5943989baca961ae5e88f95a928ccb894bd7c1 Mon Sep 17 00:00:00 2001 From: Isabel Fenton Date: Fri, 6 Dec 2024 12:30:49 +0000 Subject: [PATCH 13/13] Change of preshuffle back to True after test --- src/cloudcasting/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cloudcasting/dataset.py b/src/cloudcasting/dataset.py index 5380f0d..ad6fa68 100644 --- a/src/cloudcasting/dataset.py +++ b/src/cloudcasting/dataset.py @@ -220,7 +220,7 @@ def __init__( history_mins=history_mins, forecast_mins=forecast_mins, sample_freq_mins=sample_freq_mins, - preshuffle=False, + preshuffle=True, nan_to_num=nan_to_num, )