From b57bc7ac0c10d9467f04f2500a078543c9b310a1 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 2 Oct 2024 09:54:39 +0200
Subject: [PATCH] introduce neural-lam config

---
Review notes (free-form commentary placed after the `---` separator; it is
not part of the commit message and not part of the diff below):

  What this patch does
  * Adds neural_lam/config.py with:
      - DatastoreKindStr: a str subclass whose __new__ raises ValueError
        when the value is not one of DATASTORES.keys().
      - DatastoreSelection, TrainingConfig: plain dataclasses holding
        `kind` / `config_path` and `state_feature_weights` respectively.
      - NeuralLAMConfig: a dataclass deriving from
        dataclass_wizard.YAMLWizard with `datastore` and `training` fields.
      - load_config_and_datastore(config_path): loads the YAML via
        NeuralLAMConfig.from_yaml_file, resolves datastore.config_path
        relative to the directory of `config_path`, calls init_datastore,
        and raises ValueError unless the keys of
        training.state_feature_weights are exactly the names returned by
        datastore.get_vars_names(category="state").
  * neural_lam/create_graph.py and neural_lam/train_model.py: the two
    positional arguments (datastore kind, datastore config path) are
    replaced by one optional `--config` argument whose default is
    tests/datastore_examples/mdp/config.yaml; both now obtain the datastore
    from load_config_and_datastore.
  * pyproject.toml: adds `line_length = 80` directly after the
    known_first_party list (the isort settings, judging by the neighbouring
    comment text; the section header itself is outside the hunk).
  * Adds tests/datastore_examples/mdp/config.yaml selecting kind `mdp`,
    datastore config `danra.example.yaml`, and weights for u100m / v100m.

  Points for the reviewer to confirm
  * NOTE(review): this is a CLI interface change. Existing invocations that
    pass the datastore kind and path positionally will no longer parse.
  * NOTE(review): train_model.py checks the weights with `assert`; assert
    statements are skipped when Python runs with -O, so this guard is not
    guaranteed to execute. load_config_and_datastore already raises
    ValueError on a key mismatch, so the assert only adds the "non-empty"
    condition.
  * NOTE(review): the ValueError message in load_config_and_datastore always
    interpolates both the missing set and the additional set, so when only
    one side differs the other is presumably rendered as `set()`.
  * NOTE(review): the `--config` default is a relative path under tests/, so
    it presumably resolves only when the command is run from the repository
    root. Confirm this is intended as a default rather than required input.
  * NOTE(review): config.training.state_feature_weights is loaded and
    validated but, per the TODO added in train_model.py, not yet passed to
    any model.
  * NOTE(review): the return annotation uses the builtin generic
    `tuple[...]` next to typing.Dict / typing.Union; builtin generics in
    annotations need Python 3.9+. Confirm against the project's minimum
    supported version.
  * NOTE(review): the From: header above carries no e-mail address. It is
    left exactly as found; `git am` may refuse it.
  * NOTE(review): indentation and blank context lines in the hunks below
    were laid out to satisfy the hunk headers' line counts. Verify with
    `git apply --check` before relying on this file.

 neural_lam/config.py                     | 123 +++++++++++++++++++++++
 neural_lam/create_graph.py               |  16 +--
 neural_lam/train_model.py                |  25 ++---
 pyproject.toml                           |   1 +
 tests/datastore_examples/mdp/config.yaml |   7 ++
 5 files changed, 147 insertions(+), 25 deletions(-)
 create mode 100644 neural_lam/config.py
 create mode 100644 tests/datastore_examples/mdp/config.yaml

diff --git a/neural_lam/config.py b/neural_lam/config.py
new file mode 100644
index 00000000..7524e3bd
--- /dev/null
+++ b/neural_lam/config.py
@@ -0,0 +1,123 @@
+# Standard library
+import dataclasses
+from pathlib import Path
+from typing import Dict, Union
+
+# Third-party
+import dataclass_wizard
+
+# Local
+from .datastore import (
+    DATASTORES,
+    MDPDatastore,
+    NpyFilesDatastoreMEPS,
+    init_datastore,
+)
+
+
+class DatastoreKindStr(str):
+    VALID_KINDS = DATASTORES.keys()
+
+    def __new__(cls, value):
+        if value not in cls.VALID_KINDS:
+            raise ValueError(f"Invalid datastore kind: {value}")
+        return super().__new__(cls, value)
+
+
+@dataclasses.dataclass
+class DatastoreSelection:
+    """
+    Configuration for selecting a datastore to use with neural-lam.
+
+    Attributes
+    ----------
+    kind : DatastoreKindStr
+        The kind of datastore to use, currently `mdp` or `npyfilesmeps` are
+        implemented.
+    config_path : str
+        The path to the configuration file for the selected datastore, this is
+        assumed to be relative to the configuration file for neural-lam.
+    """
+
+    kind: DatastoreKindStr
+    config_path: str
+
+
+@dataclasses.dataclass
+class TrainingConfig:
+    """
+    Configuration related to training neural-lam
+
+    Attributes
+    ----------
+    state_feature_weights : Dict[str, float]
+        The weights for each state feature in the datastore to use in the loss
+        function during training.
+    """
+
+    state_feature_weights: Dict[str, float]
+
+
+@dataclasses.dataclass
+class NeuralLAMConfig(dataclass_wizard.YAMLWizard):
+    """
+    Dataclass for Neural-LAM configuration. This class is used to load and
+    store the configuration for using Neural-LAM.
+
+    Attributes
+    ----------
+    datastore : DatastoreSelection
+        The configuration for the datastore to use.
+    training : TrainingConfig
+        The configuration for training the model.
+    """
+
+    datastore: DatastoreSelection
+    training: TrainingConfig
+
+
+def load_config_and_datastore(
+    config_path: str,
+) -> tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]:
+    """
+    Load the neural-lam configuration and the datastore specified in the
+    configuration.
+
+    Parameters
+    ----------
+    config_path : str
+        Path to the Neural-LAM configuration file.
+
+    Returns
+    -------
+    tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]
+        The Neural-LAM configuration and the loaded datastore.
+    """
+    config = NeuralLAMConfig.from_yaml_file(config_path)
+    # datastore config is assumed to be relative to the config file
+    datastore_config_path = (
+        Path(config_path).parent / config.datastore.config_path
+    )
+    datastore = init_datastore(
+        datastore_kind=config.datastore.kind, config_path=datastore_config_path
+    )
+
+    # TODO: This check should maybe be moved somewhere else, but I'm not sure
+    # where right now... check that the config state feature weights include a
+    # weight for each state feature
+    state_feature_names = datastore.get_vars_names(category="state")
+    named_feature_weights = config.training.state_feature_weights.keys()
+
+    if set(named_feature_weights) != set(state_feature_names):
+        additional_features = set(named_feature_weights) - set(
+            state_feature_names
+        )
+        missing_features = set(state_feature_names) - set(named_feature_weights)
+        raise ValueError(
+            f"State feature weights must be provided for each state feature in "
+            f"the datastore ({state_feature_names}). {missing_features} are "
+            "missing and weights are defined for the features "
+            f"{additional_features} which are not in the datastore."
+        )
+
+    return config, datastore
diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index 0b267f67..4f656d05 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -13,7 +13,7 @@
 from torch_geometric.utils.convert import from_networkx
 
 # Local
-from .datastore import DATASTORES
+from .config import load_config_and_datastore
 from .datastore.base import BaseCartesianDatastore
 
 
@@ -551,15 +551,9 @@ def create_graph_from_datastore(
 def cli(input_args=None):
     parser = ArgumentParser(description="Graph generation arguments")
     parser.add_argument(
-        "datastore",
+        "--config",
         type=str,
-        choices=DATASTORES.keys(),
-        help="kind of data store to use",
-    )
-    parser.add_argument(
-        "datastore_config_path",
-        type=str,
-        help="path to the data store config",
+        default="tests/datastore_examples/mdp/config.yaml",
     )
     parser.add_argument(
         "--name",
@@ -586,8 +580,8 @@ def cli(input_args=None):
     )
     args = parser.parse_args(input_args)
 
-    DatastoreClass = DATASTORES[args.datastore]
-    datastore = DatastoreClass(config_path=args.datastore_config_path)
+    # Load neural-lam configuration and datastore to use
+    _, datastore = load_config_and_datastore(config_path=args.config)
 
     create_graph_from_datastore(
         datastore=datastore,
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index a1918994..e2700bc0 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -11,7 +11,7 @@
 
 # Local
 from . import utils
-from .datastore import DATASTORES, init_datastore
+from .config import load_config_and_datastore
 from .models import GraphLAM, HiLAM, HiLAMParallel
 from .weather_dataset import WeatherDataModule
 
@@ -28,15 +28,9 @@ def main(input_args=None):
         description="Train or evaluate NeurWP models for LAM"
     )
     parser.add_argument(
-        "datastore_kind",
+        "--config",
         type=str,
-        choices=DATASTORES.keys(),
-        help="Kind of datastore to use",
-    )
-    parser.add_argument(
-        "datastore_config_path",
-        type=str,
-        help="Path for the datastore config",
+        default="tests/datastore_examples/mdp/config.yaml",
     )
     parser.add_argument(
         "--model",
@@ -226,11 +220,14 @@ def main(input_args=None):
     # Set seed
     seed.seed_everything(args.seed)
 
-    # Create datastore
-    datastore = init_datastore(
-        datastore_kind=args.datastore_kind,
-        config_path=args.datastore_config_path,
-    )
+    # Load neural-lam configuration and datastore to use
+    config, datastore = load_config_and_datastore(config_path=args.config)
+    # TODO: config.training.state_feature_weights need passing in somewhere,
+    # probably to ARModel, so that it can be used in the loss function
+    assert (
+        config.training.state_feature_weights
+    ), "No state feature weights found in config"
+
     # Create datamodule
     data_module = WeatherDataModule(
         datastore=datastore,
diff --git a/pyproject.toml b/pyproject.toml
index da6664cf..349e459d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,6 +68,7 @@ known_first_party = [
     # Add first-party modules that may be misclassified by isort
     "neural_lam",
 ]
+line_length = 80
 
 [tool.flake8]
 max-line-length = 80
diff --git a/tests/datastore_examples/mdp/config.yaml b/tests/datastore_examples/mdp/config.yaml
new file mode 100644
index 00000000..44a87ca4
--- /dev/null
+++ b/tests/datastore_examples/mdp/config.yaml
@@ -0,0 +1,7 @@
+datastore:
+  kind: mdp
+  config_path: danra.example.yaml
+training:
+  state_feature_weights:
+    u100m: 1.0
+    v100m: 1.0