Skip to content

Commit

Permalink
New attempt at a validation cutout
Browse files Browse the repository at this point in the history
  • Loading branch information
IFenton committed Dec 2, 2024
1 parent 90fdfa5 commit 2a01efa
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 19 deletions.
12 changes: 10 additions & 2 deletions src/cloudcasting/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,19 @@
NUM_FORECAST_STEPS = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES
# for all 11 low resolution channels
NUM_CHANNELS = 11

# Constants for the larger (original) image
# Image size (height, width)
IMAGE_SIZE_TUPLE = (372, 614)
# # Cutout coords (min lat, max lat, min lon, max lon)
# CUTOUT_COORDS = (49, 60, -6, 2)
# Cutout mask (min x, max x, min y, max y)
CUTOUT_MASK_BOUNDARY = (127, 394, 104, 290)
# Create cutout mask
CUTOUT_MASK = create_cutout_mask(CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE)

# Constants for the smaller (cropped) image
# Cropped image size (height, width)
CROPPED_IMAGE_SIZE_TUPLE = (278, 385)
# Cropped cutout mask (min x, max x, min y, max y)
CROPPED_CUTOUT_MASK_BOUNDARY = (70, 337, 59, 245)
# Create cropped cutout mask
CROPPED_CUTOUT_MASK = create_cutout_mask(CROPPED_CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE)
15 changes: 10 additions & 5 deletions src/cloudcasting/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from jax import tree
from jaxtyping import Array, Float32
from matplotlib.colors import Normalize # type: ignore[import-not-found]
from numpy.typing import NDArray
from torch.utils.data import DataLoader
from tqdm import tqdm

Expand All @@ -32,6 +31,8 @@
import cloudcasting
from cloudcasting import metrics as dm_pix # for compatibility if our changes are upstreamed
from cloudcasting.constants import (
CROPPED_CUTOUT_MASK,
CROPPED_IMAGE_SIZE_TUPLE,
CUTOUT_MASK,
DATA_INTERVAL_SPACING_MINUTES,
FORECAST_HORIZON_MINUTES,
Expand Down Expand Up @@ -192,7 +193,6 @@ def score_model_on_all_metrics(
batch_limit: int | None = None,
metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"),
metric_kwargs: dict[str, dict[str, Any]] | None = None,
mask: NDArray[np.float64] = CUTOUT_MASK,
) -> tuple[dict[str, MetricArray], list[str]]:
"""Calculate the scoreboard metrics for the given model on the validation dataset.
Expand All @@ -207,7 +207,6 @@ def score_model_on_all_metrics(
in cloudcasting.metrics. Defaults to ("mae", "mse", "ssim").
metric_kwargs (dict[str, dict[str, Any]] | None, optional): kwargs to pass to functions in
cloudcasting.metrics. Defaults to None.
mask (np.ndarray, optional): The mask to apply to the data. Defaults to CUTOUT_MASK.
Returns:
tuple[dict[str, MetricArray], list[str]]:
Expand Down Expand Up @@ -273,6 +272,14 @@ def get_pix_function(
for i, (X, y) in tqdm(enumerate(valid_dataloader), total=loop_steps):
y_hat = model(X)

# identify the correct mask / create a mask if necessary
if X.shape[-2:] == IMAGE_SIZE_TUPLE:
mask = CUTOUT_MASK
elif X.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE:
mask = CROPPED_CUTOUT_MASK
else:
mask = np.ones(X.shape[-2:], dtype=np.float64)

# cutout the GB area
mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :]
y_cutout = y * mask_full
Expand Down Expand Up @@ -368,7 +375,6 @@ def validate(
batch_size: int = 1,
num_workers: int = 0,
batch_limit: int | None = None,
mask: NDArray[np.float64] = CUTOUT_MASK,
) -> None:
"""Run the full validation procedure on the model and log the results to wandb.
Expand Down Expand Up @@ -415,7 +421,6 @@ def validate(
batch_size=batch_size,
num_workers=num_workers,
batch_limit=batch_limit,
mask=mask,
)

# Calculate the mean of each metric reduced over forecast horizon and channels
Expand Down
15 changes: 3 additions & 12 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
NUM_FORECAST_STEPS,
)
from cloudcasting.dataset import ValidationSatelliteDataset
from cloudcasting.utils import create_cutout_mask
from cloudcasting.validation import (
calc_mean_metrics,
score_model_on_all_metrics,
Expand All @@ -17,18 +16,13 @@
)


@pytest.fixture()
def test_mask():
return create_cutout_mask((2, 6, 1, 7), (9, 8))


@pytest.fixture()
def model():
return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS)


@pytest.mark.parametrize("nan_to_num", [True, False])
def test_score_model_on_all_metrics(model, val_sat_zarr_path, test_mask, nan_to_num):
def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num):
# Create valid dataset
valid_dataset = ValidationSatelliteDataset(
zarr_path=val_sat_zarr_path,
Expand All @@ -53,7 +47,6 @@ def test_score_model_on_all_metrics(model, val_sat_zarr_path, test_mask, nan_to_
batch_limit=3,
metric_names=metric_names,
metric_kwargs=metric_kwargs,
mask=test_mask,
)

# Check all the expected keys are there
Expand Down Expand Up @@ -87,7 +80,7 @@ def test_calc_mean_metrics():
assert mean_metrics_dict["mse"] == 5


def test_validate(model, val_sat_zarr_path, test_mask, mocker):
def test_validate(model, val_sat_zarr_path, mocker):
# Mock the wandb functions so they aren't run in testing
mocker.patch("wandb.login")
mocker.patch("wandb.init")
Expand All @@ -105,11 +98,10 @@ def test_validate(model, val_sat_zarr_path, test_mask, mocker):
batch_size=2,
num_workers=0,
batch_limit=4,
mask=test_mask,
)


def test_validate_cli(val_sat_zarr_path, test_mask, mocker):
def test_validate_cli(val_sat_zarr_path, mocker):
# write out an example model.py file
with open("model.py", "w") as f:
f.write(
Expand Down Expand Up @@ -149,7 +141,6 @@ def hyperparameters_dict(self):
batch_size: 2
num_workers: 0
batch_limit: 4
mask: "{test_mask}"
"""
)

Expand Down

0 comments on commit 2a01efa

Please sign in to comment.