Skip to content

Commit

Permalink
Adding code for model validation (#49)
Browse files Browse the repository at this point in the history
* add pair programming results

* start tests

* add validation dataset

* add model files + validation changes

* full valid loop

* draft validation test

* spurious changes

* update validation suite, remove ssim for now

Co-authored-by: James Fulton <[email protected]>
Co-authored-by: Isabel Fenton <[email protected]>

* pre-commit fixes

* Validation wandb logging (#15)

* convert VariableHorizonModel to use history_steps rather than history_mins

* add function to upload horizon metric plot to wandb

* add skeleton higher level validation function

* update function docs

* bug fix, linting, and types

* update notebook

* fix notebook for ruff

* Adding wandb as a dependency

* Adding mean metric calculation and bar chart

* Adding mean metric plotting loop

* set wandb entity so all runs go to correct team

* remove scratch notebook

* add validation progress bar (#30)

* add tests for validation (#32)

* add tests for validation

* fix types

* linting

* add data to pyproject.toml

* add tqdm stubs

* move tqdm-stubs

* try to fix package data

* try to fix package data

* add video logging (#42)

* add pair programming results

* start tests

* add validation dataset

* add model files + validation changes

* full valid loop

* draft validation test

* spurious changes

* update validation suite, remove ssim for now

Co-authored-by: James Fulton <[email protected]>
Co-authored-by: Isabel Fenton <[email protected]>

* pre-commit fixes

* Validation wandb logging (#15)

* convert VariableHorizonModel to use history_steps rather than history_mins

* add function to upload horizon metric plot to wandb

* add skeleton higher level validation function

* update function docs

* bug fix, linting, and types

* update notebook

* fix notebook for ruff

* Adding wandb as a dependency

* Adding mean metric calculation and bar chart

* Adding mean metric plotting loop

* set wandb entity so all runs go to correct team

* remove scratch notebook

* add validation progress bar (#30)

* add tests for validation (#32)

* add tests for validation

* fix types

* linting

* add data to pyproject.toml

* add tqdm stubs

* move tqdm-stubs

* try to fix package data

* try to fix package data

* add script to calculate 2022 t0 val times

* linting

* linting

* update for new 2022 t0 test times

* add video logging (#42)

* update to new validation names

* add wandb video deps

* add lock file for uv

* Update name from valid to test

* align test_2022 set names

* silence deprecation warning and limit numpy

* filter  keyword warning

* macOS specific dependency

* macOS dep

* update ci for macos

* try reducing macos deps

* add validation demo notebook

* make limited batch validation more representative

* do not check notebooks with ruff

* fix validation to work when nan_to_num=True

* add macOS install instruction

* linting

* Adding details about how to access the test_2022 dataset

* Clarifying a couple of bits of text

* keep formatting for notebooks

* Testing fix for windows Overflow error

* Clarifying test vs verify datasets for validation

* Remove unused import

Co-authored-by: James Robinson <[email protected]>

* tidy and encapsulate functions

* minor tidy

* more explicit type naming

* more explicit memory management in batch compile func

* skip tests

* add constants

* linting

* bump version number and update README

---------

Co-authored-by: Nathan Simpson <[email protected]>
Co-authored-by: James Fulton <[email protected]>
Co-authored-by: James Fulton <[email protected]>
Co-authored-by: James Robinson <[email protected]>
  • Loading branch information
5 people authored Sep 2, 2024
1 parent c1cdd2e commit b652aca
Show file tree
Hide file tree
Showing 22 changed files with 3,855 additions and 89 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ jobs:
python-version: ${{ matrix.python-version }}
allow-prereleases: true

- name: Install ffmpeg on macOS
if: runner.os == 'macOS'
run: |
brew install ffmpeg
- name: Install package
run: python -m pip install .[dev]

Expand Down
7 changes: 3 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.5.2"
rev: "v0.5.6"
hooks:
# first, lint + autofix
- id: ruff
Expand All @@ -30,7 +30,7 @@ repos:
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.10.1"
rev: "v1.11.1"
hooks:
- id: mypy
args: []
Expand All @@ -45,5 +45,4 @@ repos:
- lightning
- torch
- jaxtyping
- scikit-image
- numpy
- types-tqdm
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ Tooling and infrastructure to enable cloud nowcasting.
## Installation

From source (development mode):

On macOS you first need to install `ffmpeg` with the following command. On other platforms this is
not necessary.

```bash
brew install ffmpeg
```

Clone and install the repo.

```bash
git clone https://github.com/climetrend/cloudcasting
cd cloudcasting
Expand All @@ -28,7 +38,7 @@ For making changes, see the [guidance on development](https://github.com/alan-tu
Example:

```bash
cloudcasting download "2020-06-01 00:00" "2020-06-30 23:55" "path/to/my/dir/data.zarr"
cloudcasting download "2020-06-01 00:00" "2020-06-30 23:55" "path/to/data/save/dir"
```

Full options:
Expand Down
63 changes: 63 additions & 0 deletions examples/find_test_2022_t0_times.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""This script finds the 2022 test set t0 times and saves them to the cloudcasting package."""

import importlib.util
import os

import numpy as np
import pandas as pd
import xarray as xr

from cloudcasting.constants import DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES
from cloudcasting.dataset import find_valid_t0_times
from cloudcasting.download import _get_sat_public_dataset_path

# Set a max history length which we will support in the validation process
# We will not be able to fairly score models which require a longer history than this
# But by setting this too long, we will reduce the samples we have to score on

# The current FORECAST_HORIZON_MINUTES is 3 hours so we'll set this conservatively to 6 hours
MAX_HISTORY_MINUTES = 6 * 60

# Open the 2022 dataset
ds = xr.open_zarr(_get_sat_public_dataset_path(2022, is_hrv=False))

# Filter to defined time frequency
mask = np.mod(ds.time.dt.minute, DATA_INTERVAL_SPACING_MINUTES) == 0
ds = ds.sel(time=mask)

# Mask to the odd fortnights - i.e. the 2022 test set
mask = np.mod(ds.time.dt.dayofyear // 14, 2) == 1
ds = ds.sel(time=mask)


# Find the valid t0 times
valid_t0_times = find_valid_t0_times(
datetimes=pd.DatetimeIndex(ds.time),
history_mins=MAX_HISTORY_MINUTES,
forecast_mins=FORECAST_HORIZON_MINUTES,
sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES,
)

# Print the valid t0 times to sanity check
print(f"Number of available t0 times: {len(valid_t0_times)}")
print(f"Actual available t0 times: {valid_t0_times}")


# Find the path of the cloudcasting package so we can save the valid times into it
spec = importlib.util.find_spec("cloudcasting")
if spec and spec.origin:
package_path = os.path.dirname(spec.origin)
else:
msg = "Path of package `cloudcasting` can not be found"
raise ModuleNotFoundError(msg)

# Save the valid t0 times
filename = "test_2022_t0_times.csv"
df = pd.DataFrame(valid_t0_times, columns=["t0_time"]).set_index("t0_time")
df.to_csv(
f"{package_path}/data/{filename}.zip",
compression={
"method": "zip",
"archive_name": filename,
},
)
458 changes: 458 additions & 0 deletions notebooks/03-score_model_demo.ipynb

Large diffs are not rendered by default.

29 changes: 22 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "cloudcasting"
version = "0.1.0"
version = "0.2.0"
authors = [
{ name = "cloudcasting Maintainers", email = "[email protected]" },
]
Expand Down Expand Up @@ -37,20 +37,27 @@ dependencies = [
"ocf-blosc2>=0.0.10", # for no-import codec register
"typer",
"lightning",
"torch>=2.3.0", # needed for numpy 2.0
"torch>=2.3.0", # needed for numpy 2.0
"scikit-image",
"jaxtyping",
"numpy",
"matplotlib",
"wandb",
"tqdm",
"moviepy>=1.0.3",
"imageio>=2.35.1",
"numpy <2.1.0", # https://github.com/wandb/wandb/issues/8166
]
[project.optional-dependencies]
dev = [
"pytest >=6",
"pytest-cov >=3",
"pre-commit",
"scipy",
"pytest-mock",
]

[tool.setuptools.package-data]
"cloudcasting" = ["data/*.zip"]

[project.scripts]
cloudcasting = "cloudcasting.cli:app"

Expand All @@ -62,13 +69,20 @@ Changelog = "https://github.com/climetrend/cloudcasting/releases"

[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
addopts = [
"-ra",
"--showlocals",
"--strict-markers",
"--strict-config"
]
xfail_strict = true
filterwarnings = [
"error",
"ignore:pkg_resources:DeprecationWarning", # lightning
"ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning", # lightning
"ignore:ast.Str is deprecated:DeprecationWarning", # jaxtyping
"ignore:ast.Str is deprecated:DeprecationWarning", # jaxtyping
"ignore:`newshape` keyword argument is deprecated:DeprecationWarning", # wandb using numpy 2.1.0
"ignore:The keyword `fps` is no longer supported:DeprecationWarning", # wandb.Video
]
log_cli_level = "INFO"
testpaths = [
Expand Down Expand Up @@ -114,13 +128,14 @@ disallow_untyped_calls = false

[tool.ruff]
src = ["src"]
exclude = []
exclude = ["notebooks/*.ipynb"]
line-length = 100 # how long you want lines to be

[tool.ruff.format]
docstring-code-format = true # code snippets in docstrings will be formatted

[tool.ruff.lint]
exclude = ["notebooks/*.ipynb"]
select = [
"E", "F", "W", # flake8
"B", # flake8-bugbear
Expand Down
11 changes: 9 additions & 2 deletions src/cloudcasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@
# dataclasses, meaning that they will be type-checked
# (and therefore shape-checked via jaxtyping) at runtime.
with install_import_hook("cloudcasting", "typeguard.typechecked"):
from cloudcasting import metrics
from cloudcasting import metrics, models

from cloudcasting import cli, dataset, download

__all__ = ("__version__", "download", "cli", "dataset", "metrics")
__all__ = (
"__version__",
"download",
"cli",
"dataset",
"metrics",
"models",
)
__version__ = version(__name__)
16 changes: 16 additions & 0 deletions src/cloudcasting/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
__all__ = (
"FORECAST_HORIZON_MINUTES",
"DATA_INTERVAL_SPACING_MINUTES",
"NUM_FORECAST_STEPS",
"NUM_CHANNELS",
)

# These constants were locked as part of the project specification
# 3 hour horecast horizon
FORECAST_HORIZON_MINUTES = 180
# at 15 minute intervals
DATA_INTERVAL_SPACING_MINUTES = 15
# gives 12 forecast steps
NUM_FORECAST_STEPS = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES
# for all 11 low resolution channels
NUM_CHANNELS = 11
Binary file added src/cloudcasting/data/test_2022_t0_times.csv.zip
Binary file not shown.
100 changes: 98 additions & 2 deletions src/cloudcasting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
__all__ = (
"SatelliteDataModule",
"SatelliteDataset",
"ValidationSatelliteDataset",
)

import io
import pkgutil
from datetime import datetime, timedelta
from typing import TypedDict

Expand All @@ -15,6 +18,7 @@
from numpy.typing import NDArray
from torch.utils.data import DataLoader, Dataset

from cloudcasting.constants import DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES
from cloudcasting.utils import find_contiguous_t0_time_periods, find_contiguous_time_periods


Expand Down Expand Up @@ -101,7 +105,7 @@ def __init__(
"""A torch Dataset for loading past and future satellite data
Args:
zarr_path: Path the satellite data. Can be a string or list
zarr_path: Path to the satellite data. Can be a string or list
start_time: The satellite data is filtered to exclude timestamps before this
end_time: The satellite data is filtered to exclude timestamps after this
history_mins: How many minutes of history will be used as input features
Expand All @@ -128,7 +132,7 @@ def __init__(

# Find the valid t0 times for the available data. This avoids trying to take samples where
# there would be a missing timestamp in the sat data required for the sample
self.t0_times = find_valid_t0_times(
self.t0_times = self._find_t0_times(
pd.DatetimeIndex(self.ds.time), history_mins, forecast_mins, sample_freq_mins
)

Expand All @@ -140,6 +144,12 @@ def __init__(
self.sample_freq_mins = sample_freq_mins
self.nan_to_num = nan_to_num

@staticmethod
def _find_t0_times(
date_range: pd.DatetimeIndex, history_mins: int, forecast_mins: int, sample_freq_mins: int
) -> pd.DatetimeIndex:
return find_valid_t0_times(date_range, history_mins, forecast_mins, sample_freq_mins)

def __len__(self) -> int:
return len(self.t0_times)

Expand Down Expand Up @@ -183,6 +193,92 @@ def __getitem__(self, key: DataIndex) -> tuple[NDArray[np.float32], NDArray[np.f
return self._get_datetime(t0)


class ValidationSatelliteDataset(SatelliteDataset):
def __init__(
self,
zarr_path: list[str] | str,
history_mins: int,
forecast_mins: int = FORECAST_HORIZON_MINUTES,
sample_freq_mins: int = DATA_INTERVAL_SPACING_MINUTES,
nan_to_num: bool = False,
):
"""A torch Dataset used only in the validation proceedure.
Args:
zarr_path: Path to the satellite data for validation. Can be a string or list
history_mins: How many minutes of history will be used as input features
forecast_mins: How many minutes of future will be used as target features
sample_freq_mins: The sample frequency to use for the satellite data
nan_to_num: Whether to convert NaNs to -1.
"""

super().__init__(
zarr_path=zarr_path,
start_time=None,
end_time=None,
history_mins=history_mins,
forecast_mins=forecast_mins,
sample_freq_mins=sample_freq_mins,
preshuffle=True,
nan_to_num=nan_to_num,
)

@staticmethod
def _find_t0_times(
date_range: pd.DatetimeIndex, history_mins: int, forecast_mins: int, sample_freq_mins: int
) -> pd.DatetimeIndex:
# Find the valid t0 times for the available data. This avoids trying to take samples where
# there would be a missing timestamp in the sat data required for the sample
available_t0_times = find_valid_t0_times(
date_range, history_mins, forecast_mins, sample_freq_mins
)

# Get the required 2022 test dataset t0 times
val_t0_times_from_csv = ValidationSatelliteDataset._get_test_2022_t0_times()

# Find the intersection of the available t0 times and the required validation t0 times
val_time_available = val_t0_times_from_csv.isin(available_t0_times)

# Make sure all of the required validation times are available in the data
if not val_time_available.all():
msg = (
"The following validation t0 times are not available in the satellite data: \n"
f"{val_t0_times_from_csv[~val_time_available]}\n"
"The validation proceedure requires these t0 times to be available."
)
raise ValueError(msg)

return val_t0_times_from_csv

@staticmethod
def _get_t0_times(path: str) -> pd.DatetimeIndex:
"""Load the required validation t0 times from library path"""

# Load the zipped csv file as a byte stream
data = pkgutil.get_data("cloudcasting", path)
if data is not None:
byte_stream = io.BytesIO(data)
else:
# Handle the case where data is None
msg = f"No data found for path: {path}"
raise ValueError(msg)

# Load the times into pandas
df = pd.read_csv(byte_stream, encoding="utf8", compression="zip")

return pd.DatetimeIndex(df.t0_time)

@staticmethod
def _get_test_2022_t0_times() -> pd.DatetimeIndex:
"""Load the required 2022 test dataset t0 times from their location in the library"""
return ValidationSatelliteDataset._get_t0_times("data/test_2022_t0_times.csv.zip")

@staticmethod
def _get_verify_2023_t0_times() -> pd.DatetimeIndex:
msg = "The required 2023 verification dataset t0 times are not available"
raise NotImplementedError(msg)


class SatelliteDataModule(LightningDataModule):
"""A lightning DataModule for loading past and future satellite data"""

Expand Down
Loading

0 comments on commit b652aca

Please sign in to comment.