Skip to content

Commit

Permalink
Create a cutout mask based on pixel values
Browse files Browse the repository at this point in the history
  • Loading branch information
IFenton committed Nov 20, 2024
1 parent f74a4dd commit ca15eb8
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 30 deletions.
12 changes: 9 additions & 3 deletions src/cloudcasting/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"DATA_INTERVAL_SPACING_MINUTES",
"NUM_FORECAST_STEPS",
"NUM_CHANNELS",
"CUTOUT_COORDS",
"CUTOUT_MASK",
)

from cloudcasting.utils import create_cutout_mask

# These constants were locked as part of the project specification
# 3 hour forecast horizon
FORECAST_HORIZON_MINUTES = 180
Expand All @@ -17,5 +19,9 @@
NUM_CHANNELS = 11
# Image size (height, width)
IMAGE_SIZE_TUPLE = (372, 614)
# Cutout coords (min lat, max lat, min lon, max lon)
CUTOUT_COORDS = (49, 60, -6, 2)
# # Cutout coords (min lat, max lat, min lon, max lon)
# CUTOUT_COORDS = (49, 60, -6, 2)
# Cutout mask boundary in pixel indices (min x, max x, min y, max y)
CUTOUT_MASK_BOUNDARY = (127, 394, 104, 290)
# Create cutout mask
CUTOUT_MASK = create_cutout_mask(CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE)
25 changes: 25 additions & 0 deletions src/cloudcasting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pyproj
import pyresample
import xarray as xr
from numpy.typing import NDArray

from cloudcasting.types import (
BatchInputArray,
Expand Down Expand Up @@ -152,3 +153,27 @@ def numpy_validation_collate_fn(
X_all[i] = X
y_all[i] = y
return X_all, y_all


def create_cutout_mask(
    mask_size: tuple[int, int, int, int],
    image_size: tuple[int, int],
) -> NDArray[np.float32]:
    """Create a 2D mask that is 1 inside a rectangular cutout and NaN elsewhere.

    Args:
        mask_size: Pixel bounds of the cutout as (min x, max x, min y, max y).
            The bounds are used as slice indices, so the minimums are
            inclusive and the maximums are exclusive.
        image_size: Size of the full image as (height, width).

    Returns:
        Array of shape (height, width) holding 1 inside the cutout and NaN
        everywhere else.
    """
    height, width = image_size
    min_x, max_x, min_y, max_y = mask_size

    # The mask must be floating point: NaN has no integer representation, and
    # assigning it into an integer array makes NumPy raise ValueError.
    mask = np.full((height, width), np.nan, dtype=np.float32)
    # Rows are indexed by y and columns by x.
    mask[min_y:max_y, min_x:max_x] = 1
    return mask
37 changes: 10 additions & 27 deletions src/cloudcasting/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from jax import tree
from jaxtyping import Array, Float32
from matplotlib.colors import Normalize # type: ignore[import-not-found]
from numpy.typing import NDArray
from torch.utils.data import DataLoader
from tqdm import tqdm

Expand All @@ -31,7 +32,7 @@
import cloudcasting
from cloudcasting import metrics as dm_pix # for compatibility if our changes are upstreamed
from cloudcasting.constants import (
CUTOUT_COORDS,
CUTOUT_MASK,
DATA_INTERVAL_SPACING_MINUTES,
FORECAST_HORIZON_MINUTES,
IMAGE_SIZE_TUPLE,
Expand All @@ -46,7 +47,7 @@
SampleOutputArray,
TimeArray,
)
from cloudcasting.utils import lon_lat_to_geostationary_area_coords, numpy_validation_collate_fn
from cloudcasting.utils import numpy_validation_collate_fn

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -191,6 +192,7 @@ def score_model_on_all_metrics(
batch_limit: int | None = None,
metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"),
metric_kwargs: dict[str, dict[str, Any]] | None = None,
mask: NDArray[np.int8] = CUTOUT_MASK,
) -> tuple[dict[str, MetricArray], list[str]]:
"""Calculate the scoreboard metrics for the given model on the validation dataset.
Expand All @@ -205,6 +207,7 @@ def score_model_on_all_metrics(
in cloudcasting.metrics. Defaults to ("mae", "mse", "ssim").
metric_kwargs (dict[str, dict[str, Any]] | None, optional): kwargs to pass to functions in
cloudcasting.metrics. Defaults to None.
mask (np.ndarray, optional): The mask to apply to the data. Defaults to CUTOUT_MASK.
Returns:
tuple[dict[str, MetricArray], list[str]]:
Expand All @@ -219,29 +222,6 @@ def score_model_on_all_metrics(
"(please make an issue on github if you see this!!!!)"
)

# calculate the cutout indices for the dataset
lat_min, lat_max, lon_min, lon_max = CUTOUT_COORDS
(x_min, x_max), (y_min, y_max) = lon_lat_to_geostationary_area_coords(
[lon_min, lon_max],
[lat_min, lat_max],
valid_dataset.ds.data,
)

y_vals = np.where(
np.logical_and(
valid_dataset.ds.coords["y_geostationary"] <= y_max,
valid_dataset.ds.coords["y_geostationary"] >= y_min,
)
)[0]
x_vals = np.where(
np.logical_and(
valid_dataset.ds.coords["x_geostationary"] <= x_max,
valid_dataset.ds.coords["x_geostationary"] >= x_min,
)
)[0]

ix = np.ix_(x_vals, y_vals)

valid_dataloader = DataLoader(
valid_dataset,
batch_size=batch_size,
Expand Down Expand Up @@ -294,8 +274,9 @@ def get_pix_function(
y_hat = model(X)

# cutout the GB area
y_cutout = y[..., ix[1], ix[0]]
y_hat = y_hat[..., ix[1], ix[0]]
mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :]
y_cutout = y * mask_full
y_hat = y_hat * mask_full

# assert shapes are the same
assert y_cutout.shape == y_hat.shape, f"{y_cutout.shape=} != {y_hat.shape=}"
Expand Down Expand Up @@ -387,6 +368,7 @@ def validate(
batch_size: int = 1,
num_workers: int = 0,
batch_limit: int | None = None,
mask: NDArray[np.int8] = CUTOUT_MASK,
) -> None:
"""Run the full validation procedure on the model and log the results to wandb.
Expand Down Expand Up @@ -433,6 +415,7 @@ def validate(
batch_size=batch_size,
num_workers=num_workers,
batch_limit=batch_limit,
mask=mask,
)

# Calculate the mean of each metric reduced over forecast horizon and channels
Expand Down
6 changes: 6 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
NUM_FORECAST_STEPS,
)
from cloudcasting.dataset import ValidationSatelliteDataset
from cloudcasting.utils import create_cutout_mask
from cloudcasting.validation import (
calc_mean_metrics,
score_model_on_all_metrics,
validate,
validate_from_config,
)

test_mask = create_cutout_mask((2, 6, 1, 7), (9, 8))


@pytest.fixture()
def model():
Expand Down Expand Up @@ -47,6 +50,7 @@ def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num):
batch_limit=3,
metric_names=metric_names,
metric_kwargs=metric_kwargs,
mask=test_mask,
)

# Check all the expected keys are there
Expand Down Expand Up @@ -98,6 +102,7 @@ def test_validate(model, val_sat_zarr_path, mocker):
batch_size=2,
num_workers=0,
batch_limit=4,
mask=test_mask,
)


Expand Down Expand Up @@ -141,6 +146,7 @@ def hyperparameters_dict(self):
batch_size: 2
num_workers: 0
batch_limit: 4
mask: {test_mask}
"""
)

Expand Down

0 comments on commit ca15eb8

Please sign in to comment.