diff --git a/src/cloudcasting/constants.py b/src/cloudcasting/constants.py
index 4112319..784ee9b 100644
--- a/src/cloudcasting/constants.py
+++ b/src/cloudcasting/constants.py
@@ -17,11 +17,19 @@
 NUM_FORECAST_STEPS = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES
 # for all 11 low resolution channels
 NUM_CHANNELS = 11
+
+# Constants for the larger (original) image
 # Image size (height, width)
 IMAGE_SIZE_TUPLE = (372, 614)
-# # Cutout coords (min lat, max lat, min lon, max lon)
-# CUTOUT_COORDS = (49, 60, -6, 2)
 # Cutout mask (min x, max x, min y, max y)
 CUTOUT_MASK_BOUNDARY = (127, 394, 104, 290)
 # Create cutout mask
 CUTOUT_MASK = create_cutout_mask(CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE)
+
+# Constants for the smaller (cropped) image
+# Cropped image size (height, width)
+CROPPED_IMAGE_SIZE_TUPLE = (278, 385)
+# Cropped cutout mask (min x, max x, min y, max y)
+CROPPED_CUTOUT_MASK_BOUNDARY = (70, 337, 59, 245)
+# Create cropped cutout mask
+CROPPED_CUTOUT_MASK = create_cutout_mask(CROPPED_CUTOUT_MASK_BOUNDARY, CROPPED_IMAGE_SIZE_TUPLE)
diff --git a/src/cloudcasting/validation.py b/src/cloudcasting/validation.py
index b859b20..e646a09 100644
--- a/src/cloudcasting/validation.py
+++ b/src/cloudcasting/validation.py
@@ -18,7 +18,6 @@
 from jax import tree
 from jaxtyping import Array, Float32
 from matplotlib.colors import Normalize  # type: ignore[import-not-found]
-from numpy.typing import NDArray
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
@@ -32,6 +31,8 @@
 import cloudcasting
 from cloudcasting import metrics as dm_pix  # for compatibility if our changes are upstreamed
 from cloudcasting.constants import (
+    CROPPED_CUTOUT_MASK,
+    CROPPED_IMAGE_SIZE_TUPLE,
     CUTOUT_MASK,
     DATA_INTERVAL_SPACING_MINUTES,
     FORECAST_HORIZON_MINUTES,
@@ -192,7 +193,6 @@ def score_model_on_all_metrics(
     batch_limit: int | None = None,
     metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"),
     metric_kwargs: dict[str, dict[str, Any]] | None = None,
-    mask: NDArray[np.float64] = CUTOUT_MASK,
 ) -> tuple[dict[str, MetricArray], list[str]]:
     """Calculate the scoreboard metrics for the given model on the validation dataset.
 
@@ -207,7 +207,6 @@
             in cloudcasting.metrics. Defaults to ("mae", "mse", "ssim").
         metric_kwargs (dict[str, dict[str, Any]] | None, optional): kwargs to pass to functions
             in cloudcasting.metrics. Defaults to None.
-        mask (np.ndarray, optional): The mask to apply to the data. Defaults to CUTOUT_MASK.
 
     Returns:
         tuple[dict[str, MetricArray], list[str]]:
@@ -273,6 +272,14 @@ def get_pix_function(
     for i, (X, y) in tqdm(enumerate(valid_dataloader), total=loop_steps):
         y_hat = model(X)
 
+        # identify the correct mask / create a mask if necessary
+        if X.shape[-2:] == IMAGE_SIZE_TUPLE:
+            mask = CUTOUT_MASK
+        elif X.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE:
+            mask = CROPPED_CUTOUT_MASK
+        else:
+            mask = np.ones(X.shape[-2:], dtype=np.float64)
+
         # cutout the GB area
         mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :]
         y_cutout = y * mask_full
@@ -368,7 +375,6 @@ def validate(
     batch_size: int = 1,
     num_workers: int = 0,
     batch_limit: int | None = None,
-    mask: NDArray[np.float64] = CUTOUT_MASK,
 ) -> None:
     """Run the full validation procedure on the model and log the results to wandb.
 
@@ -415,7 +421,6 @@ def validate(
         batch_size=batch_size,
         num_workers=num_workers,
         batch_limit=batch_limit,
-        mask=mask,
     )
 
     # Calculate the mean of each metric reduced over forecast horizon and channels
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 9342ecb..5e44a68 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -8,7 +8,6 @@
     NUM_FORECAST_STEPS,
 )
 from cloudcasting.dataset import ValidationSatelliteDataset
-from cloudcasting.utils import create_cutout_mask
 from cloudcasting.validation import (
     calc_mean_metrics,
     score_model_on_all_metrics,
@@ -17,18 +16,13 @@
 )
 
 
-@pytest.fixture()
-def test_mask():
-    return create_cutout_mask((2, 6, 1, 7), (9, 8))
-
-
 @pytest.fixture()
 def model():
     return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS)
 
 
 @pytest.mark.parametrize("nan_to_num", [True, False])
-def test_score_model_on_all_metrics(model, val_sat_zarr_path, test_mask, nan_to_num):
+def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num):
     # Create valid dataset
     valid_dataset = ValidationSatelliteDataset(
         zarr_path=val_sat_zarr_path,
@@ -53,7 +47,6 @@ def test_score_model_on_all_metrics(model, val_sat_zarr_path, test_mask, nan_to_
         batch_limit=3,
         metric_names=metric_names,
         metric_kwargs=metric_kwargs,
-        mask=test_mask,
     )
 
     # Check all the expected keys are there
@@ -87,7 +80,7 @@ def test_calc_mean_metrics():
     assert mean_metrics_dict["mse"] == 5
 
 
-def test_validate(model, val_sat_zarr_path, test_mask, mocker):
+def test_validate(model, val_sat_zarr_path, mocker):
     # Mock the wandb functions so they aren't run in testing
     mocker.patch("wandb.login")
     mocker.patch("wandb.init")
@@ -105,11 +98,10 @@ def test_validate(model, val_sat_zarr_path, test_mask, mocker):
         batch_size=2,
         num_workers=0,
         batch_limit=4,
-        mask=test_mask,
     )
 
 
-def test_validate_cli(val_sat_zarr_path, test_mask, mocker):
+def test_validate_cli(val_sat_zarr_path, mocker):
     # write out an example model.py file
     with open("model.py", "w") as f:
         f.write(
@@ -149,7 +141,6 @@ def hyperparameters_dict(self):
             batch_size: 2
             num_workers: 0
             batch_limit: 4
-            mask: "{test_mask}"
             """
         )
 
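For context, a minimal standalone sketch of the shape-based mask selection this patch adds to the validation loop. The select_mask helper and the all-ones placeholder masks are illustrative only (the real masks come from create_cutout_mask with the boundaries defined in constants.py); it simply shows how the mask is now inferred from the trailing (height, width) of each batch rather than being passed in as an argument.

import numpy as np

IMAGE_SIZE_TUPLE = (372, 614)
CROPPED_IMAGE_SIZE_TUPLE = (278, 385)

# Stand-ins for the real masks built by create_cutout_mask; all-ones arrays are
# used here purely so the example is self-contained and runnable.
CUTOUT_MASK = np.ones(IMAGE_SIZE_TUPLE, dtype=np.float64)
CROPPED_CUTOUT_MASK = np.ones(CROPPED_IMAGE_SIZE_TUPLE, dtype=np.float64)


def select_mask(batch: np.ndarray) -> np.ndarray:
    """Pick the cutout mask matching the batch's spatial dims, or an all-ones fallback."""
    if batch.shape[-2:] == IMAGE_SIZE_TUPLE:
        return CUTOUT_MASK
    if batch.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE:
        return CROPPED_CUTOUT_MASK
    return np.ones(batch.shape[-2:], dtype=np.float64)


# Example: a (batch, channels, time, height, width) array at the cropped resolution.
X = np.zeros((2, 11, 6, *CROPPED_IMAGE_SIZE_TUPLE))
mask = select_mask(X)
masked = X * mask[np.newaxis, np.newaxis, np.newaxis, :, :]
assert masked.shape == X.shape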