diff --git a/.gitignore b/.gitignore index 8cd4e45..fdb51d3 100644 --- a/.gitignore +++ b/.gitignore @@ -79,6 +79,7 @@ tags # macos .DS_Store +__MACOSX # pdm (https://pdm-project.org/en/stable/) .pdm-python diff --git a/README.md b/README.md index 8462e38..2561e03 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,10 @@ training: v100m: 1.0 ``` -For now the neural-lam config only defines two things: 1) the kind of data store and the path to its config, and 2) the weighting of different features in the loss function. +For now the neural-lam config only defines two things: 1) the kind of data +store and the path to its config, and 2) the weighting of different features in +the loss function. If you don't define the state feature weighting it will default +to weighting all features equally. (This example is taken from the `tests/datastore_examples/mdp` directory.) @@ -270,11 +273,88 @@ Graphs used in the initial paper are also available for download at the same lin Note that this is far too little data to train any useful models, but all pre-processing and training steps can be run with it. It should thus be useful to make sure that your python environment is set up correctly and that all the code can be ran without any issues. 
+The following datastore configuration works with the MEPS dataset: + ```yaml # meps.datastore.yaml +dataset: + name: meps_example + num_forcing_features: 16 + var_longnames: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + - t_isobaricInhPa_500_instant + - t_isobaricInhPa_850_instant + - u_hybrid_65_instant + - u_isobaricInhPa_850_instant + - v_hybrid_65_instant + - v_isobaricInhPa_850_instant + - wvint_entireAtmosphere_0_instant + - z_isobaricInhPa_1000_instant + - z_isobaricInhPa_500_instant + var_names: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + - t_500 + - t_850 + - u_65 + - u_850 + - v_65 + - v_850 + - wvint_0 + - z_1000 + - z_500 + var_units: + - Pa + - Pa + - W/m\textsuperscript{2} + - W/m\textsuperscript{2} + - "-" + - "-" + - K + - K + - K + - K + - m/s + - m/s + - m/s + - m/s + - kg/m\textsuperscript{2} + - m\textsuperscript{2}/s\textsuperscript{2} + - m\textsuperscript{2}/s\textsuperscript{2} + num_timesteps: 65 + num_ensemble_members: 2 + step_length: 3 + remove_state_features_with_index: [15] +grid_shape_state: +- 268 +- 238 +projection: + class_name: LambertConformal + kwargs: + central_latitude: 63.3 + central_longitude: 15.0 + standard_parallels: + - 63.3 + - 63.3 ``` +You can then use it in a neural-lam configuration file like this: + ```yaml +# config.yaml datastore: kind: npyfilesmeps config_path: meps.datastore.yaml @@ -286,43 +366,23 @@ training: v100m: 1.0 ``` -## Pre-processing - -There are two main steps in the pre-processing pipeline: creating the graph and creating additional features/normalisation/boundary-masks. - -The amount of pre-processing required will depend on what kind of datastore you will be using for training. 
- -### Additional inputs - -#### MultiZarr Datastore - -* `python -m neural_lam.create_boundary_mask` -* `python -m neural_lam.create_datetime_forcings` -* `python -m neural_lam.create_norm` +For npy-file based datastores you must separately run the command that creates the variables used for standardization: -#### NpyFiles Datastore - -#### MDP (mllam-data-prep) Datastore +```bash +python -m neural_lam.datastore.npyfilesmeps.compute_standardization_stats +``` -An overview of how the different pre-processing steps, training and files depend on each other is given in this figure: -

- -

-In order to start training models at least three pre-processing steps have to be run: +### Graph creation -### Create graph Run `python -m neural_lam.create_mesh` with suitable options to generate the graph you want to use (see `python neural_lam.create_mesh --help` for a list of options). The graphs used for the different models in the [paper](#graph-based-neural-weather-prediction-for-limited-area-modeling) can be created as: -* **GC-LAM**: `python -m neural_lam.create_mesh --graph multiscale` -* **Hi-LAM**: `python -m neural_lam.create_mesh --graph hierarchical --hierarchical` (also works for Hi-LAM-Parallel) -* **L1-LAM**: `python -m neural_lam.create_mesh --graph 1level --levels 1` +* **GC-LAM**: `python -m neural_lam.create_mesh --graph multiscale` +* **Hi-LAM**: `python -m neural_lam.create_mesh --graph hierarchical --hierarchical` (also works for Hi-LAM-Parallel) +* **L1-LAM**: `python -m neural_lam.create_mesh --graph 1level --levels 1` The graph-related files are stored in a directory called `graphs`. -### Create remaining static features -To create the remaining static files run `python -m neural_lam.create_grid_features` and `python -m neural_lam.create_parameter_weights`. - ## Weights & Biases Integration The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it. When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface. @@ -340,12 +400,11 @@ wandb off ``` ## Train Models -Models can be trained using `python -m neural_lam.train_model `. +Models can be trained using `python -m neural_lam.train_model `. Run `python neural_lam.train_model --help` for a full list of training options. 
A few of the key ones are outlined below: -* ``: The kind of datastore that you are using (should be one of `npyfiles`, `multizarr` or `mllam`) -* ``: Path to the data store configuration file +* ``: Path to the configuration for neural-lam (for example in `data/myexperiment/config.yaml`). * `--model`: Which model to train * `--graph`: Which graph to use with the model * `--epochs`: Number of epochs to train for diff --git a/neural_lam/config.py b/neural_lam/config.py index 33393c5..d3e0969 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -51,11 +51,11 @@ class ManualStateFeatureWeighting: Attributes ---------- - values : Dict[str, float] + weights : Dict[str, float] Manual weights for the state features. """ - values: Dict[str, float] + weights: Dict[str, float] @dataclasses.dataclass @@ -123,6 +123,17 @@ class _(dataclass_wizard.JSONWizard.Meta): tag_key = "__config_class__" auto_assign_tags = True + # ensure that all parts of the loaded configuration match the + # dataclasses used + # TODO: this should be enabled once + # https://github.com/rnag/dataclass-wizard/issues/137 is fixed, but + # currently cannot be used together with `auto_assign_tags` due to a + # bug it seems + # raise_on_unknown_json_key = True + + +class InvalidConfigError(Exception): + pass def load_config_and_datastore( @@ -142,7 +153,13 @@ def load_config_and_datastore( tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]] The Neural-LAM configuration and the loaded datastore. """ - config = NeuralLAMConfig.from_yaml_file(config_path) + try: + config = NeuralLAMConfig.from_yaml_file(config_path) + except dataclass_wizard.errors.UnknownJSONKey as ex: + raise InvalidConfigError( + "There was an error loading the configuration file at " + f"{config_path}. 
" + ) from ex # datastore config is assumed to be relative to the config file datastore_config_path = ( Path(config_path).parent / config.datastore.config_path diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index 885d1ae..21daa34 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -2,9 +2,9 @@ import os import subprocess from argparse import ArgumentParser +from pathlib import Path # Third-party -import numpy as np import torch import torch.distributed as dist from torch.utils.data.distributed import DistributedSampler @@ -102,6 +102,10 @@ def save_stats( mean = torch.mean(means, dim=0) # (d_features,) second_moment = torch.mean(squares, dim=0) # (d_features,) std = torch.sqrt(second_moment - mean**2) # (d_features,) + print( + f"Saving {filename_prefix} mean and std.-dev. to " + f"{filename_prefix}_mean.pt and {filename_prefix}_std.pt" + ) torch.save( mean.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_mean.pt") ) @@ -120,52 +124,42 @@ def save_stats( flux_mean = torch.mean(flux_means) # (,) flux_second_moment = torch.mean(flux_squares) # (,) flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) + print("Saving flux mean and std.-dev. 
to flux_stats.pt") torch.save( torch.stack((flux_mean, flux_std)).cpu(), os.path.join(static_dir_path, "flux_stats.pt"), ) -def main(): +def main( + datastore_config_path, batch_size, step_length, n_workers, distributed +): """ Pre-compute parameter weights to be used in loss function + + Arguments + --------- + datastore_config_path : str + Path to datastore config file + batch_size : int + Batch size when iterating over the dataset + step_length : int + Step length in hours to consider single time step + n_workers : int + Number of workers in data loader + distributed : bool + Run the script in distributed mode """ - parser = ArgumentParser(description="Training arguments") - parser.add_argument( - "--datastore_config_path", type=str, help="Path to data config file" - ) - parser.add_argument( - "--batch_size", - type=int, - default=32, - help="Batch size when iterating over the dataset", - ) - parser.add_argument( - "--step_length", - type=int, - default=3, - help="Step length in hours to consider single time step (default: 3)", - ) - parser.add_argument( - "--n_workers", - type=int, - default=4, - help="Number of workers in data loader (default: 4)", - ) - parser.add_argument( - "--distributed", - action="store_true", - help="Run the script in distributed mode (default: False)", - ) - args = parser.parse_args() - distributed = bool(args.distributed) rank = get_rank() world_size = get_world_size() datastore = init_datastore( - datastore_kind="npyfilesmeps", config_path=args.datastore_config_path + datastore_kind="npyfilesmeps", config_path=datastore_config_path ) + static_dir_path = Path(datastore_config_path).parent / "static" + os.makedirs(static_dir_path, exist_ok=True) + if distributed: setup(rank, world_size) device = torch.device( @@ -173,44 +167,21 @@ def main(): ) torch.cuda.set_device(device) if torch.cuda.is_available() else None - if rank == 0: - static_dir_path = os.path.join(datastore.root_path, "static") - # Create parameter weights based on height - # 
based on fig A.1 in graph cast paper - w_dict = { - "2": 1.0, - "0": 0.1, - "65": 0.065, - "1000": 0.1, - "850": 0.05, - "500": 0.03, - } - w_list = np.array( - [ - w_dict[par.split("_")[-2]] - for par in datastore.get_vars_long_names(category="state") - ] - ) - print("Saving parameter weights...") - np.save( - os.path.join(static_dir_path, "parameter_weights.npy"), - w_list.astype("float32"), - ) - - # XXX: is this correct? - ar_steps = 61 + # XXX (lcd@dmi.dk): I don't quite understand why, but below fails with the + # MEPS example dataset if I just use `datastore._num_timesteps - 2` which + # I would assume would be ok + ar_steps = datastore._num_timesteps - 10 ds = WeatherDataset( datastore=datastore, split="train", ar_steps=ar_steps, - # pred_length=63, standardize=False, ) if distributed: ds = PaddedWeatherDataset( ds, world_size, - args.batch_size, + batch_size, ) sampler = DistributedSampler( ds, num_replicas=world_size, rank=rank, shuffle=False @@ -219,9 +190,9 @@ def main(): sampler = None loader = torch.utils.data.DataLoader( ds, - args.batch_size, + batch_size, shuffle=False, - num_workers=args.n_workers, + num_workers=n_workers, sampler=sampler, ) @@ -311,7 +282,7 @@ def main(): ds_standard = PaddedWeatherDataset( ds_standard, world_size, - args.batch_size, + batch_size, ) sampler_standard = DistributedSampler( ds_standard, num_replicas=world_size, rank=rank, shuffle=False @@ -320,12 +291,12 @@ def main(): sampler_standard = None loader_standard = torch.utils.data.DataLoader( ds_standard, - args.batch_size, + batch_size, shuffle=False, - num_workers=args.n_workers, + num_workers=n_workers, sampler=sampler_standard, ) - used_subsample_len = (65 // args.step_length) * args.step_length + used_subsample_len = (65 // step_length) * step_length diff_means, diff_squares = [], [] @@ -341,13 +312,13 @@ def main(): # Note: batch contains only 1h-steps stepped_batch = torch.cat( [ - batch[:, ss_i : used_subsample_len : args.step_length] - for ss_i in 
range(args.step_length) + batch[:, ss_i:used_subsample_len:step_length] + for ss_i in range(step_length) ], dim=0, ) # (N_batch', N_t, N_grid, d_features), - # N_batch' = args.step_length*N_batch + # N_batch' = step_length*N_batch batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1] # (N_batch', N_t-1, N_grid, d_features) diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu()) @@ -378,7 +349,7 @@ def main(): ), ) original_indices = ds_standard.get_original_window_indices( - args.step_length + step_length ) diff_means, diff_squares = ( [diff_means_gathered[i] for i in original_indices], @@ -395,5 +366,45 @@ def main(): dist.destroy_process_group() +def cli(): + parser = ArgumentParser(description="Training arguments") + parser.add_argument( + "datastore_config_path", type=str, help="Path to data config file" + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size when iterating over the dataset", + ) + parser.add_argument( + "--step_length", + type=int, + default=3, + help="Step length in hours to consider single time step (default: 3)", + ) + parser.add_argument( + "--n_workers", + type=int, + default=4, + help="Number of workers in data loader (default: 4)", + ) + parser.add_argument( + "--distributed", + action="store_true", + help="Run the script in distributed mode (default: False)", + ) + args = parser.parse_args() + distributed = bool(args.distributed) + + main( + datastore_config_path=args.datastore_config_path, + batch_size=args.batch_size, + step_length=args.step_length, + n_workers=args.n_workers, + distributed=distributed, + ) + + if __name__ == "__main__": - main() + cli() diff --git a/neural_lam/datastore/npyfilesmeps/config.py b/neural_lam/datastore/npyfilesmeps/config.py index a4cf5b1..1a9d729 100644 --- a/neural_lam/datastore/npyfilesmeps/config.py +++ b/neural_lam/datastore/npyfilesmeps/config.py @@ -1,5 +1,5 @@ # Standard library -from dataclasses import dataclass +from dataclasses import dataclass, field from 
typing import Any, Dict, List # Third-party @@ -46,7 +46,7 @@ class Dataset: num_timesteps: int step_length: int num_ensemble_members: int - remove_state_features_with_index: List[int] + remove_state_features_with_index: List[int] = field(default_factory=list) @dataclass diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index eb524d3..ffa70dc 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -6,6 +6,7 @@ # Standard library import functools import re +import warnings from functools import cached_property from pathlib import Path from typing import List @@ -719,8 +720,17 @@ def load_pickled_tensor(fn): if category == "state": mean_values = load_pickled_tensor("parameter_mean.pt") std_values = load_pickled_tensor("parameter_std.pt") - mean_diff_values = load_pickled_tensor("diff_mean.pt") - std_diff_values = load_pickled_tensor("diff_std.pt") + try: + mean_diff_values = load_pickled_tensor("diff_mean.pt") + std_diff_values = load_pickled_tensor("diff_std.pt") + except FileNotFoundError: + warnings.warn(f"Could not load diff mean/std for {category}") + # XXX: this is a hack, but when running + # compute_standardization_stats the diff mean/std files are + # created, but require the std and mean files + mean_diff_values = np.empty_like(mean_values) + std_diff_values = np.empty_like(std_values) + elif category == "forcing": flux_stats = load_pickled_tensor("flux_stats.pt") # (2,) flux_mean, flux_std = flux_stats diff --git a/neural_lam/loss_weighting.py b/neural_lam/loss_weighting.py index e4238f1..c842b20 100644 --- a/neural_lam/loss_weighting.py +++ b/neural_lam/loss_weighting.py @@ -27,7 +27,7 @@ def get_manual_state_feature_weights( List of floats containing the state feature weights. 
""" state_feature_names = datastore.get_vars_names(category="state") - feature_weight_names = weighting_config.keys() + feature_weight_names = weighting_config.weights.keys() # Check that the state_feature_weights dictionary has a weight for each # state feature in the datastore. @@ -44,7 +44,7 @@ def get_manual_state_feature_weights( ) state_feature_weights = [ - weighting_config.values[feature] for feature in state_feature_names + weighting_config.weights[feature] for feature in state_feature_names ] return state_feature_weights diff --git a/tests/conftest.py b/tests/conftest.py index 9b5364d..84f6fb2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,9 @@ # First-party from neural_lam.datastore import DATASTORES, init_datastore +from neural_lam.datastore.npyfilesmeps import ( + compute_standardization_stats as compute_standardization_stats_meps, +) # Local from .dummy_datastore import DummyDatastore @@ -20,17 +23,20 @@ # Initializing variables for the s3 client S3_BUCKET_NAME = "mllam-testdata" -S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int" -S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip" +# S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int" +S3_ENDPOINT_URL = "http://localhost:8000" +# S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip" +# TODO: I will upload this to AWS S3 once I have located the credentials... 
+S3_FILE_PATH = "meps_example_reduced.v0.2.0.zip" S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH]) TEST_DATA_KNOWN_HASH = ( - "98c7a2f442922de40c6891fe3e5d190346889d6e0e97550170a82a7ce58a72b7" + "7ff2e07e04cfcd77631115f800c9d49188bb2a7c2a2777da3cea219f926d0c86" ) def download_meps_example_reduced_dataset(): # Download and unzip test data into data/meps_example_reduced - root_path = DATASTORE_EXAMPLES_ROOT_PATH / "npy" + root_path = DATASTORE_EXAMPLES_ROOT_PATH / "npyfilesmeps" dataset_path = root_path / "meps_example_reduced" pooch.retrieve( @@ -41,7 +47,7 @@ def download_meps_example_reduced_dataset(): fname="meps_example_reduced.zip", ) - config_path = dataset_path / "data_config.yaml" + config_path = dataset_path / "meps_example_reduced.datastore.yaml" with open(config_path, "r") as f: config = yaml.safe_load(f) @@ -58,11 +64,25 @@ def download_meps_example_reduced_dataset(): with open(config_path, "w") as f: yaml.dump(config, f) + # create parameters + compute_standardization_stats_meps.main( + datastore_config_path=config_path, + batch_size=8, + step_length=3, + n_workers=1, + distributed=False, + ) + return config_path DATASTORES_EXAMPLES = dict( - mdp=(DATASTORE_EXAMPLES_ROOT_PATH / "mdp" / "danra.datastore.yaml"), + mdp=( + DATASTORE_EXAMPLES_ROOT_PATH + / "mdp" + / "danra_100m_winds" + / "danra.datastore.yaml" + ), npyfilesmeps=download_meps_example_reduced_dataset(), dummydata=None, ) diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore index 82c481f..e84e649 100644 --- a/tests/datastore_examples/.gitignore +++ b/tests/datastore_examples/.gitignore @@ -1,2 +1,2 @@ -npy/*.zip -npy/meps_example_reduced/ +npyfilesmeps/*.zip +npyfilesmeps/meps_example_reduced/ diff --git a/tests/datastore_examples/mdp/.gitignore b/tests/datastore_examples/mdp/danra_100m_winds/.gitignore similarity index 100% rename from tests/datastore_examples/mdp/.gitignore rename to 
tests/datastore_examples/mdp/danra_100m_winds/.gitignore diff --git a/tests/datastore_examples/mdp/config.yaml b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml similarity index 93% rename from tests/datastore_examples/mdp/config.yaml rename to tests/datastore_examples/mdp/danra_100m_winds/config.yaml index 8696755..0bb5c5e 100644 --- a/tests/datastore_examples/mdp/config.yaml +++ b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml @@ -4,6 +4,6 @@ datastore: training: state_feature_weighting: __config_class__: ManualStateFeatureWeighting - values: + weights: u100m: 1.0 v100m: 1.0 diff --git a/tests/datastore_examples/mdp/danra.datastore.yaml b/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml similarity index 100% rename from tests/datastore_examples/mdp/danra.datastore.yaml rename to tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml diff --git a/tests/datastore_examples/npy/config_meps.yaml b/tests/datastore_examples/npy/config_meps.yaml deleted file mode 100644 index 21259fd..0000000 --- a/tests/datastore_examples/npy/config_meps.yaml +++ /dev/null @@ -1,15 +0,0 @@ -datastore: - kind: npyfilesmeps - config_path: meps_example/data_config.yaml -training: - state_feature_weighting: - __config_class__: ManualStateFeatureWeighting - values: - nlwrs_0: 1.0 - nswrs_0: 1.0 - pres_0g: 1.0 - pres_0s: 1.0 - r_2: 1.0 - r_65: 1.0 - t_2: 1.0 - t_65: 1.0 diff --git a/tests/test_config.py b/tests/test_config.py index 7d8357c..4bb7c1c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,7 +8,9 @@ @pytest.mark.parametrize( "state_weighting_config", [ - nlconfig.ManualStateFeatureWeighting(values=dict(u100m=1.0, v100m=0.5)), + nlconfig.ManualStateFeatureWeighting( + weights=dict(u100m=1.0, v100m=0.5) + ), nlconfig.UniformFeatureWeighting(), ], ) @@ -53,7 +55,7 @@ def test_config_serialization(state_weighting_config): datastore=nlconfig.DatastoreSelection(kind="mdp", config_path=""), training=nlconfig.TrainingConfig( 
state_feature_weighting=nlconfig.ManualStateFeatureWeighting( - values=dict(u100m=1.0, v100m=1.0) + weights=dict(u100m=1.0, v100m=1.0) ) ), )