
Commit

one round of formatting
phinate committed Jul 22, 2024
1 parent 4e3edf9 commit 662103f
Showing 6 changed files with 280 additions and 304 deletions.
29 changes: 14 additions & 15 deletions scripts/validate_persistence_model_demo.py
@@ -1,8 +1,6 @@
import numpy as np
from cloudcasting.validation import (
validate_model,
AbstractValidationModel
)

from cloudcasting.validation import AbstractValidationModel, validate_model

# -------------------------------------------------
# User settings
@@ -12,43 +10,44 @@
project = "sat_pred"
run_name = "persistence"

logged_params = {"persistence-method": "last input frame",}
logged_params = {
"persistence-method": "last input frame",
}


class PersistenceModel(AbstractValidationModel):
def __init__(self, forecast_frames: int):
self.forecast_frames = forecast_frames

def forward(self, X: np.ndarray):
"""Predict the latest frame of the input for all future steps
Args:
X: Either a batch or a sample of the most recent satelllite data. X can will be 4 or 5
dimensional. X has shape [(batch), channels, time, height, width]
Returns
np.ndarray: The models predictions of future satellite data
"""
latest_frame = X[..., -1:, :, :].copy()

# The NaN values in the input data are filled with -1. Clip these to zero
latest_frame = latest_frame.clip(0, 1)

y_hat = np.repeat(latest_frame, self.forecast_frames, axis=-3)
return y_hat


model = PersistenceModel(forecast_frames=forecast_mins//sample_freq_mins)
model = PersistenceModel(forecast_frames=forecast_mins // sample_freq_mins)


if __name__=="__main__":
if __name__ == "__main__":
validate_model(
model,
model,
project=project,
run_name=run_name,
batch_size=4,
batch_size=4,
num_workers=0,
val_zarr_path="path/to/test/2023/satellite.zarr",
fast_dev_run=True,
)
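
As a quick sanity check on this baseline, here is a minimal sketch of calling PersistenceModel on a dummy batch. The array shapes, channel count, and the 180-minute horizon at 15-minute sampling are illustrative assumptions, not values taken from this repository's configuration:

import numpy as np

# Illustrative settings only -- the real script defines forecast_mins and
# sample_freq_mins in its user-settings section, which is collapsed above.
forecast_mins = 180
sample_freq_mins = 15

model = PersistenceModel(forecast_frames=forecast_mins // sample_freq_mins)

# Dummy batch with shape (batch, channels, time, height, width)
X = np.random.rand(4, 11, 12, 64, 64).astype(np.float32)
y_hat = model.forward(X)

# The clipped last input frame is repeated once per forecast step
assert y_hat.shape == (4, 11, forecast_mins // sample_freq_mins, 64, 64)
assert np.allclose(y_hat[..., 0, :, :], X[..., -1, :, :].clip(0, 1))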

122 changes: 55 additions & 67 deletions src/cloudcasting/dataset.py
@@ -1,21 +1,19 @@
"""Dataset and DataModule for past and future satellite data"""

from typing import Union
from datetime import datetime, timedelta
from typing import Union

import numpy as np
import pandas as pd
import xarray as xr

import torch

from torch.utils.data import Dataset, DataLoader
import xarray as xr
from lightning.pytorch import LightningDataModule

from ocf_datapipes.load.satellite import _get_single_sat_data
from ocf_datapipes.select.find_contiguous_t0_time_periods import (
find_contiguous_time_periods, find_contiguous_t0_time_periods
find_contiguous_t0_time_periods,
find_contiguous_time_periods,
)
from torch.utils.data import DataLoader, Dataset


def minutes(m: int):
@@ -25,7 +23,7 @@ def minutes(m: int):

def load_satellite_zarrs(zarr_path):
"""Load the satellite data"""

if isinstance(zarr_path, (list, tuple)):
ds = xr.combine_nested(
[_get_single_sat_data(path) for path in zarr_path],
@@ -35,7 +33,7 @@ def load_satellite_zarrs(zarr_path):
)
else:
ds = _get_single_sat_data(zarr_path)

return ds


@@ -45,11 +43,11 @@ def find_valid_t0_times(ds, history_mins, forecast_mins, sample_freq_mins):
# Find periods where we have contiguous time steps
contiguous_time_periods = find_contiguous_time_periods(
datetimes=pd.DatetimeIndex(ds.time),
min_seq_length=int((history_mins + forecast_mins) / sample_freq_mins) + 1,
min_seq_length=int((history_mins + forecast_mins) / sample_freq_mins) + 1,
max_gap_duration=minutes(sample_freq_mins),
)

# Find periods of valid init-times
# Find periods of valid init-times
contiguous_t0_periods = find_contiguous_t0_time_periods(
contiguous_time_periods=contiguous_time_periods,
history_duration=minutes(history_mins),
@@ -59,31 +57,27 @@ def find_valid_t0_times(ds, history_mins, forecast_mins, sample_freq_mins):
valid_t0_times = []
for _, row in contiguous_t0_periods.iterrows():
valid_t0_times.append(
pd.date_range(
row["start_dt"],
row["end_dt"],
freq=f"{sample_freq_mins}min"
)
pd.date_range(row["start_dt"], row["end_dt"], freq=f"{sample_freq_mins}min")
)

valid_t0_times = pd.to_datetime(np.concatenate(valid_t0_times))

return valid_t0_times


class SatelliteDataset(Dataset):
def __init__(
self,
zarr_path: Union[list, str],
self,
zarr_path: list | str,
start_time: str,
end_time: str,
history_mins: int,
forecast_mins: int,
end_time: str,
history_mins: int,
forecast_mins: int,
sample_freq_mins: int,
preshuffle: bool = False
preshuffle: bool = False,
):
"""A torch Dataset for loading past and future satellite data
Args:
zarr_path: Path to the satellite data. Can be a string or list
start_time: The satellite data is filtered to exclude timestamps before this
@@ -93,57 +87,52 @@ def __init__(
sample_freq_mins: The sample frequency to use for the satellite data
preshuffle: Whether to shuffle the data - useful for validation
"""

# Load the sat zarr file or list of files and slice the data to the given period
self.ds = load_satellite_zarrs(zarr_path).sel(time=slice(start_time, end_time))

# Convert the satellite data to the given time frequency by selection
mask = np.mod(self.ds.time.dt.minute, sample_freq_mins)==0
mask = np.mod(self.ds.time.dt.minute, sample_freq_mins) == 0
self.ds = self.ds.sel(time=mask)
# Find the valid t0 times for the available data. This avoids trying to take samples where

# Find the valid t0 times for the available data. This avoids trying to take samples where
# there would be a missing timestamp in the sat data required for the sample
self.t0_times = find_valid_t0_times(self.ds, history_mins, forecast_mins, sample_freq_mins)

if preshuffle:
self.t0_times = pd.to_datetime(np.random.permutation(self.t0_times))

self.history_mins = history_mins
self.forecast_mins = forecast_mins
self.sample_freq_mins = sample_freq_mins



def __len__(self):
return len(self.t0_times)

def _get_datetime(self, t0: datetime):
ds_sel = self.ds.sel(
time=slice(
t0-minutes(self.history_mins),
t0+minutes(self.forecast_mins)
)
time=slice(t0 - minutes(self.history_mins), t0 + minutes(self.forecast_mins))
)
# Load the data eagerly so that the same chunks aren't loaded multiple times after we split

# Load the data eagerly so that the same chunks aren't loaded multiple times after we split
# further
ds_sel = ds_sel.compute(scheduler="single-threaded")

# Reshape to (channel, time, height, width)
ds_sel = ds_sel.transpose("variable", "time", "y_geostationary", "x_geostationary")

ds_input = ds_sel.sel(time=slice(None, t0))
ds_target = ds_sel.sel(time=slice(t0+minutes(self.sample_freq_mins), None))
ds_target = ds_sel.sel(time=slice(t0 + minutes(self.sample_freq_mins), None))

# Convert to arrays
X = ds_input.data.values
y = ds_target.data.values

X = np.nan_to_num(X, nan=-1)
y = np.nan_to_num(y, nan=-1)

return X.astype(np.float32), y.astype(np.float32)


def __getitem__(self, idx):
if isinstance(idx, (str)):
t0 = pd.Timestamp(idx)
@@ -153,50 +142,49 @@ def __getitem__(self, idx):
else:
raise ValueError(f"Unrecognised type {type(idx)}")
return self._get_datetime(t0)


class ValidationSatelliteDataset(SatelliteDataset):
def __init__(
self,
zarr_path: Union[list, str],
self,
zarr_path: list | str,
t0_times: list[datetime],
history_mins: int,
forecast_mins: int,
history_mins: int,
forecast_mins: int,
sample_freq_mins: int,
):
"""A torch Dataset for loading past and future satellite data
Args:
zarr_path: Path to the satellite data. Can be a string or list
t0_times: Array-like of the t0 times used for validation
history_mins: How many minutes of history will be used as input features
forecast_mins: How many minutes of future will be used as target features
sample_freq_mins: The sample frequency to use for the satellite data
"""

# Load the sat zarr file or list of files and slice the data to the given period
self.ds = load_satellite_zarrs(zarr_path)

# Convert the satellite data to the given time frequency by selection
mask = np.mod(self.ds.time.dt.minute, sample_freq_mins)==0
mask = np.mod(self.ds.time.dt.minute, sample_freq_mins) == 0
self.ds = self.ds.sel(time=mask)

self.t0_times = t0_times

self.history_mins = history_mins
self.forecast_mins = forecast_mins
self.sample_freq_mins = sample_freq_mins



class SatelliteDataModule(LightningDataModule):
"""A lightning DataModule for loading past and future satellite data"""

def __init__(
self,
zarr_path: Union[list, str],
history_mins: int,
forecast_mins: int,
zarr_path: list | str,
history_mins: int,
forecast_mins: int,
sample_freq_mins: int,
batch_size=16,
num_workers=0,
@@ -220,7 +208,7 @@ def __init__(
test_period: Date range filter for test dataloader.
"""
super().__init__()

self.zarr_path = zarr_path
self.history_mins = history_mins
self.forecast_mins = forecast_mins
@@ -230,7 +218,7 @@ def __init__(
self.test_period = test_period

self._common_dataloader_kwargs = dict(
batch_size=batch_size,
batch_size=batch_size,
sampler=None,
batch_sampler=None,
num_workers=num_workers,
@@ -247,13 +235,13 @@ def _make_dataset(self, start_date, end_date, preshuffle=False):
self.zarr_path,
start_date,
end_date,
self.history_mins,
self.forecast_mins,
self.history_mins,
self.forecast_mins,
self.sample_freq_mins,
preshuffle=preshuffle,
)
return dataset

def train_dataloader(self):
"""Construct train dataloader"""
dataset = self._make_dataset(*self.train_period)
@@ -268,4 +256,4 @@ def val_dataloader(self):
def test_dataloader(self):
"""Construct test dataloader"""
dataset = self._make_dataset(*self.test_period)
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
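
To show how the refactored module fits together, here is a hedged usage sketch for SatelliteDataModule. The zarr path and date ranges are placeholders, and the train/val period keyword names are assumed by analogy with the test_period argument documented above rather than confirmed by this hunk:

from cloudcasting.dataset import SatelliteDataModule

datamodule = SatelliteDataModule(
    zarr_path="path/to/training/satellite.zarr",  # placeholder path
    history_mins=60,
    forecast_mins=180,
    sample_freq_mins=15,
    batch_size=4,
    num_workers=0,
    train_period=("2022-01-01", "2022-10-31"),  # assumed kwarg, mirrors test_period
    val_period=("2022-11-01", "2022-11-30"),    # assumed kwarg
    test_period=("2022-12-01", "2022-12-31"),
)

train_dl = datamodule.train_dataloader()
X, y = next(iter(train_dl))
# X: (batch, channels, history_mins / sample_freq_mins + 1, height, width)
# y: (batch, channels, forecast_mins / sample_freq_mins, height, width)

Each item pairs the input window ending at t0 with the future frames, both returned as float32 arrays with NaNs filled as -1, as in SatelliteDataset._get_datetime above.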
(The remaining 4 changed files are not shown.)
