Skip to content

Commit

Permalink
Merge pull request #78 from alan-turing-institute/validation-cropping
Browse files Browse the repository at this point in the history
Restricting the validation area to GB
  • Loading branch information
IFenton authored Dec 9, 2024
2 parents 40b6c0a + ad59439 commit b6f0c97
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 4 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ dev = [
"scipy",
"pytest-mock",
"scikit-image",
"typeguard",
]

[tool.setuptools.package-data]
Expand Down
17 changes: 17 additions & 0 deletions src/cloudcasting/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
"DATA_INTERVAL_SPACING_MINUTES",
"NUM_FORECAST_STEPS",
"NUM_CHANNELS",
"CUTOUT_MASK",
)

from cloudcasting.utils import create_cutout_mask

# These constants were locked as part of the project specification
# 3 hour forecast horizon
FORECAST_HORIZON_MINUTES = 180
Expand All @@ -14,5 +17,19 @@
NUM_FORECAST_STEPS = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES
# for all 11 low resolution channels
NUM_CHANNELS = 11

# Constants for the larger (original) image
# Image size (height, width)
IMAGE_SIZE_TUPLE = (372, 614)
# Cutout mask (min x, max x, min y, max y)
CUTOUT_MASK_BOUNDARY = (166, 336, 107, 289)
# Create cutout mask
CUTOUT_MASK = create_cutout_mask(CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE)

# Constants for the smaller (cropped) image
# Cropped image size (height, width)
CROPPED_IMAGE_SIZE_TUPLE = (278, 385)
# Cropped cutout mask (min x, max x, min y, max y)
CROPPED_CUTOUT_MASK_BOUNDARY = (109, 279, 62, 244)
# Create cropped cutout mask
CROPPED_CUTOUT_MASK = create_cutout_mask(CROPPED_CUTOUT_MASK_BOUNDARY, CROPPED_IMAGE_SIZE_TUPLE)
25 changes: 25 additions & 0 deletions src/cloudcasting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pyproj
import pyresample
import xarray as xr
from numpy.typing import NDArray

from cloudcasting.types import (
BatchInputArray,
Expand Down Expand Up @@ -152,3 +153,27 @@ def numpy_validation_collate_fn(
X_all[i] = X
y_all[i] = y
return X_all, y_all


def create_cutout_mask(
    mask_size: tuple[int, int, int, int],
    image_size: tuple[int, int],
) -> NDArray[np.float64]:
    """Create a float mask that is NaN everywhere except a rectangular cutout.

    Args:
        mask_size: Cutout boundary as (min x, max x, min y, max y), in pixel
            coordinates (max bounds are exclusive, as in numpy slicing).
        image_size: Size of the full mask as (height, width).

    Returns:
        NDArray[np.float64]: Array of shape ``image_size`` filled with NaN,
        with 1.0 inside the cutout region.
    """
    height, width = image_size
    min_x, max_x, min_y, max_y = mask_size

    # Start from an all-NaN canvas and mark the cutout region with ones,
    # so multiplying by this mask keeps the cutout and NaNs-out the rest.
    mask = np.full((height, width), np.nan, dtype=np.float64)
    mask[min_y:max_y, min_x:max_x] = 1
    return mask
58 changes: 54 additions & 4 deletions src/cloudcasting/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
import cloudcasting
from cloudcasting import metrics as dm_pix # for compatibility if our changes are upstreamed
from cloudcasting.constants import (
CROPPED_CUTOUT_MASK,
CROPPED_CUTOUT_MASK_BOUNDARY,
CROPPED_IMAGE_SIZE_TUPLE,
CUTOUT_MASK,
CUTOUT_MASK_BOUNDARY,
DATA_INTERVAL_SPACING_MINUTES,
FORECAST_HORIZON_MINUTES,
IMAGE_SIZE_TUPLE,
Expand Down Expand Up @@ -145,6 +150,30 @@ def log_prediction_video_to_wandb(
y[mask] = 0
y_hat[mask] = 0

create_box = True

# create a boundary box for the crop
if y.shape[-2:] == IMAGE_SIZE_TUPLE:
boxl, boxr, boxb, boxt = CUTOUT_MASK_BOUNDARY
bsize = IMAGE_SIZE_TUPLE
elif y.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE:
boxl, boxr, boxb, boxt = CROPPED_CUTOUT_MASK_BOUNDARY
bsize = CROPPED_IMAGE_SIZE_TUPLE
else:
create_box = False

if create_box:
# box mask
maskb = np.ones(bsize, dtype=np.float64)
maskb[boxb : boxb + 2, boxl:boxr] = np.nan # Top edge
maskb[boxt - 2 : boxt, boxl:boxr] = np.nan # Bottom edge
maskb[boxb:boxt, boxl : boxl + 2] = np.nan # Left edge
maskb[boxb:boxt, boxr - 2 : boxr] = np.nan # Right edge
maskb = maskb[np.newaxis, np.newaxis, :, :]

y = y * maskb
y_hat = y_hat * maskb

# Transpose the arrays so time is the first dimension and select the channel
# Then flip the frames so they are in the correct orientation for the video
y_frames = y.transpose(1, 2, 3, 0)[:, ::-1, ::-1, channel_ind : channel_ind + 1]
Expand All @@ -168,6 +197,14 @@ def log_prediction_video_to_wandb(

# combine add difference to the video array
video_array = np.concatenate([video_array, diff_ccmap], axis=2)

# Set bounding box to a colour so it is visible
if create_box:
video_array[:, :, :, 0][np.isnan(video_array[:, :, :, 0])] = 250
video_array[:, :, :, 1][np.isnan(video_array[:, :, :, 1])] = 40
video_array[:, :, :, 2][np.isnan(video_array[:, :, :, 2])] = 10
video_array[:, :, :, 3][np.isnan(video_array[:, :, :, 3])] = 255

video_array = video_array.transpose(0, 3, 1, 2)
video_array = video_array.astype(np.uint8)

Expand Down Expand Up @@ -269,25 +306,38 @@ def get_pix_function(
for i, (X, y) in tqdm(enumerate(valid_dataloader), total=loop_steps):
y_hat = model(X)

# identify the correct mask / create a mask if necessary
if X.shape[-2:] == IMAGE_SIZE_TUPLE:
mask = CUTOUT_MASK
elif X.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE:
mask = CROPPED_CUTOUT_MASK
else:
mask = np.ones(X.shape[-2:], dtype=np.float64)

# cutout the GB area
mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :]
y_cutout = y * mask_full
y_hat = y_hat * mask_full

# assert shapes are the same
assert y.shape == y_hat.shape, f"{y.shape=} != {y_hat.shape=}"
assert y_cutout.shape == y_hat.shape, f"{y_cutout.shape=} != {y_hat.shape=}"

# If nan_to_num is used in the dataset, the model will output -1 for NaNs. We need to
# convert these back to NaNs for the metrics
y[y == -1] = np.nan
y_cutout[y_cutout == -1] = np.nan

# pix accepts arrays of shape [batch, height, width, channels].
# our arrays are of shape [batch, channels, time, height, width].
# channel dim would be reduced; we add a new axis to satisfy the shape reqs.
# we then reshape to squash batch, channels, and time into the leading axis,
# where the vmap in metrics.py will broadcast over the leading dim.
y_jax = jnp.array(y).reshape(-1, *y.shape[-2:])[..., np.newaxis]
y_jax = jnp.array(y_cutout).reshape(-1, *y_cutout.shape[-2:])[..., np.newaxis]
y_hat_jax = jnp.array(y_hat).reshape(-1, *y_hat.shape[-2:])[..., np.newaxis]

for metric_name, metric_func in metric_funcs.items():
# we reshape the result back into [batch, channels, time],
# then take the mean over the batch
metric_res = metric_func(y_hat_jax, y_jax).reshape(*y.shape[:3])
metric_res = metric_func(y_hat_jax, y_jax).reshape(*y_cutout.shape[:3])
batch_reduced_metric = jnp.nanmean(metric_res, axis=0)
metrics[metric_name].append(batch_reduced_metric)

Expand Down

0 comments on commit b6f0c97

Please sign in to comment.