-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding code for model validation (#49)
* add pair programming results * start tests * add validation dataset * add model files + validation changes * full valid loop * draft validation test * spurious changes * update validation suite, remove ssim for now Co-authored-by: James Fulton <[email protected]> Co-authored-by: Isabel Fenton <[email protected]> * pre-commit fixes * Validation wandb logging (#15) * convert VariableHorizonModel to use history_steps rather than history_mins * add function to upload horizon metric plot to wandb * add skeleton higher level validation function * update function docs * bug fix, linting, and types * update notebook * fix notebook for ruff * Adding wandb as a dependency * Adding mean metric calculation and bar chart * Adding mean metric plotting loop * set wandb entity so all runs go to correct team * remove scratch notebook * add validation progress bar (#30) * add tests for validation (#32) * add tests for validation * fix types * linting * add data to pyproject.toml * add tqdm stubs * move tqdm-stubs * try to fix package data * try to fix package data * add video logging (#42) * add pair programming results * start tests * add validation dataset * add model files + validation changes * full valid loop * draft validation test * spurious changes * update validation suite, remove ssim for now Co-authored-by: James Fulton <[email protected]> Co-authored-by: Isabel Fenton <[email protected]> * pre-commit fixes * Validation wandb logging (#15) * convert VariableHorizonModel to use history_steps rather than history_mins * add function to upload horizon metric plot to wandb * add skeleton higher level validation function * update function docs * bug fix, linting, and types * update notebook * fix notebook for ruff * Adding wandb as a dependency * Adding mean metric calculation and bar chart * Adding mean metric plotting loop * set wandb entity so all runs go to correct team * remove scratch notebook * add validation progress bar (#30) * add tests for validation (#32) * add tests for validation * fix types * linting * add data to pyproject.toml * add tqdm stubs * move tqdm-stubs * try to fix package data * try to fix package data * add script to calculate 2022 t0 val times * linting * linting * update for new 2022 t0 test times * add video logging (#42) * update to new validation names * add wandb video deps * add lock file for uv * Update name from valid to test * align test_2022 set names * silence deprecation warning and limit numpy * filter keyword warning * macOS specific dependency * macOS dep * update ci for macos * try reducing macos deps * add validation demo notebook * make limited batch validation more representative * do not check notebooks with ruff * fix validation to work when nan_to_num=True * add macOS install instruction * linting * Adding details about how to access the test_2022 dataset * Clarifying a couple of bits of text * keep formatting for notebooks * Testing fix for windows Overflow error * Clarifying test vs verify datasets for validation * Remove unused import Co-authored-by: James Robinson <[email protected]> * tidy and encapsulate functions * minor tidy * more explicit type naming * more explicit memory management in batch compile func * skip tests * add constants * linting * bump version number and update README --------- Co-authored-by: Nathan Simpson <[email protected]> Co-authored-by: James Fulton <[email protected]> Co-authored-by: James Fulton <[email protected]> Co-authored-by: James Robinson <[email protected]>
- Loading branch information
1 parent
c1cdd2e
commit b652aca
Showing
22 changed files
with
3,855 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
"""This script finds the 2022 test set t0 times and saves them to the cloudcasting package.""" | ||
|
||
import importlib.util | ||
import os | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import xarray as xr | ||
|
||
from cloudcasting.constants import DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES | ||
from cloudcasting.dataset import find_valid_t0_times | ||
from cloudcasting.download import _get_sat_public_dataset_path | ||
|
||
# Set a max history length which we will support in the validation process | ||
# We will not be able to fairly score models which require a longer history than this | ||
# But by setting this too long, we will reduce the samples we have to score on | ||
|
||
# The current FORECAST_HORIZON_MINUTES is 3 hours so we'll set this conservatively to 6 hours | ||
MAX_HISTORY_MINUTES = 6 * 60 | ||
|
||
# Open the 2022 dataset | ||
ds = xr.open_zarr(_get_sat_public_dataset_path(2022, is_hrv=False)) | ||
|
||
# Filter to defined time frequency | ||
mask = np.mod(ds.time.dt.minute, DATA_INTERVAL_SPACING_MINUTES) == 0 | ||
ds = ds.sel(time=mask) | ||
|
||
# Mask to the odd fortnights - i.e. the 2022 test set | ||
mask = np.mod(ds.time.dt.dayofyear // 14, 2) == 1 | ||
ds = ds.sel(time=mask) | ||
|
||
|
||
# Find the valid t0 times | ||
valid_t0_times = find_valid_t0_times( | ||
datetimes=pd.DatetimeIndex(ds.time), | ||
history_mins=MAX_HISTORY_MINUTES, | ||
forecast_mins=FORECAST_HORIZON_MINUTES, | ||
sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES, | ||
) | ||
|
||
# Print the valid t0 times to sanity check | ||
print(f"Number of available t0 times: {len(valid_t0_times)}") | ||
print(f"Actual available t0 times: {valid_t0_times}") | ||
|
||
|
||
# Find the path of the cloudcasting package so we can save the valid times into it | ||
spec = importlib.util.find_spec("cloudcasting") | ||
if spec and spec.origin: | ||
package_path = os.path.dirname(spec.origin) | ||
else: | ||
msg = "Path of package `cloudcasting` can not be found" | ||
raise ModuleNotFoundError(msg) | ||
|
||
# Save the valid t0 times | ||
filename = "test_2022_t0_times.csv" | ||
df = pd.DataFrame(valid_t0_times, columns=["t0_time"]).set_index("t0_time") | ||
df.to_csv( | ||
f"{package_path}/data/{filename}.zip", | ||
compression={ | ||
"method": "zip", | ||
"archive_name": filename, | ||
}, | ||
) |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" | |
|
||
[project] | ||
name = "cloudcasting" | ||
version = "0.1.0" | ||
version = "0.2.0" | ||
authors = [ | ||
{ name = "cloudcasting Maintainers", email = "[email protected]" }, | ||
] | ||
|
@@ -37,20 +37,27 @@ dependencies = [ | |
"ocf-blosc2>=0.0.10", # for no-import codec register | ||
"typer", | ||
"lightning", | ||
"torch>=2.3.0", # needed for numpy 2.0 | ||
"torch>=2.3.0", # needed for numpy 2.0 | ||
"scikit-image", | ||
"jaxtyping", | ||
"numpy", | ||
"matplotlib", | ||
"wandb", | ||
"tqdm", | ||
"moviepy>=1.0.3", | ||
"imageio>=2.35.1", | ||
"numpy <2.1.0", # https://github.com/wandb/wandb/issues/8166 | ||
] | ||
[project.optional-dependencies] | ||
dev = [ | ||
"pytest >=6", | ||
"pytest-cov >=3", | ||
"pre-commit", | ||
"scipy", | ||
"pytest-mock", | ||
] | ||
|
||
[tool.setuptools.package-data] | ||
"cloudcasting" = ["data/*.zip"] | ||
|
||
[project.scripts] | ||
cloudcasting = "cloudcasting.cli:app" | ||
|
||
|
@@ -62,13 +69,20 @@ Changelog = "https://github.com/climetrend/cloudcasting/releases" | |
|
||
[tool.pytest.ini_options] | ||
minversion = "6.0" | ||
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] | ||
addopts = [ | ||
"-ra", | ||
"--showlocals", | ||
"--strict-markers", | ||
"--strict-config" | ||
] | ||
xfail_strict = true | ||
filterwarnings = [ | ||
"error", | ||
"ignore:pkg_resources:DeprecationWarning", # lightning | ||
"ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning", # lightning | ||
"ignore:ast.Str is deprecated:DeprecationWarning", # jaxtyping | ||
"ignore:ast.Str is deprecated:DeprecationWarning", # jaxtyping | ||
"ignore:`newshape` keyword argument is deprecated:DeprecationWarning", # wandb using numpy 2.1.0 | ||
"ignore:The keyword `fps` is no longer supported:DeprecationWarning", # wandb.Video | ||
] | ||
log_cli_level = "INFO" | ||
testpaths = [ | ||
|
@@ -114,13 +128,14 @@ disallow_untyped_calls = false | |
|
||
[tool.ruff] | ||
src = ["src"] | ||
exclude = [] | ||
exclude = ["notebooks/*.ipynb"] | ||
line-length = 100 # how long you want lines to be | ||
|
||
[tool.ruff.format] | ||
docstring-code-format = true # code snippets in docstrings will be formatted | ||
|
||
[tool.ruff.lint] | ||
exclude = ["notebooks/*.ipynb"] | ||
select = [ | ||
"E", "F", "W", # flake8 | ||
"B", # flake8-bugbear | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
__all__ = ( | ||
"FORECAST_HORIZON_MINUTES", | ||
"DATA_INTERVAL_SPACING_MINUTES", | ||
"NUM_FORECAST_STEPS", | ||
"NUM_CHANNELS", | ||
) | ||
|
||
# These constants were locked as part of the project specification | ||
# 3 hour horecast horizon | ||
FORECAST_HORIZON_MINUTES = 180 | ||
# at 15 minute intervals | ||
DATA_INTERVAL_SPACING_MINUTES = 15 | ||
# gives 12 forecast steps | ||
NUM_FORECAST_STEPS = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES | ||
# for all 11 low resolution channels | ||
NUM_CHANNELS = 11 |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.