Skip to content

Commit

Permalink
introduce neural-lam config
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Oct 2, 2024
1 parent 7e46194 commit b57bc7a
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 25 deletions.
123 changes: 123 additions & 0 deletions neural_lam/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Standard library
import dataclasses
from pathlib import Path
from typing import Dict, Union

# Third-party
import dataclass_wizard

# Local
from .datastore import (
DATASTORES,
MDPDatastore,
NpyFilesDatastoreMEPS,
init_datastore,
)


class DatastoreKindStr(str):
    """
    String subclass that only accepts the names of registered datastores.

    The value is checked against the keys of `DATASTORES` when an instance
    is created.

    Raises
    ------
    ValueError
        If the value is not one of the keys of `DATASTORES`.
    """

    # Live view of the datastore names registered in `DATASTORES`
    VALID_KINDS = DATASTORES.keys()

    def __new__(cls, value):
        if value not in cls.VALID_KINDS:
            # List the accepted kinds so that a typo in the config file can
            # be fixed without having to look up the registry
            raise ValueError(
                f"Invalid datastore kind: {value}. "
                f"Valid kinds are: {list(cls.VALID_KINDS)}"
            )
        return super().__new__(cls, value)


@dataclasses.dataclass
class DatastoreSelection:
    """
    Selects which datastore neural-lam should use.

    Attributes
    ----------
    kind : DatastoreKindStr
        Name of the datastore implementation; `mdp` and `npyfilesmeps` are
        the kinds currently implemented.
    config_path : str
        Location of the selected datastore's own configuration file, taken
        to be relative to the neural-lam configuration file.
    """

    kind: DatastoreKindStr
    config_path: str


@dataclasses.dataclass
class TrainingConfig:
    """
    Training-related settings for neural-lam.

    Attributes
    ----------
    state_feature_weights : Dict[str, float]
        Mapping from state feature name to the weight that feature is given
        in the loss function during training.
    """

    state_feature_weights: Dict[str, float]


@dataclasses.dataclass
class NeuralLAMConfig(dataclass_wizard.YAMLWizard):
    """
    Top-level Neural-LAM configuration.

    Holds the settings needed to use Neural-LAM; being a
    `dataclass_wizard.YAMLWizard` it can be loaded from a YAML file.

    Attributes
    ----------
    datastore : DatastoreSelection
        Which datastore to use and where its configuration lives.
    training : TrainingConfig
        Settings used when training the model.
    """

    datastore: DatastoreSelection
    training: TrainingConfig


def load_config_and_datastore(
    config_path: str,
) -> tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]:
    """
    Load the neural-lam configuration and the datastore specified in the
    configuration.

    Parameters
    ----------
    config_path : str
        Path to the Neural-LAM configuration file.

    Returns
    -------
    tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]
        The Neural-LAM configuration and the loaded datastore.

    Raises
    ------
    ValueError
        If the state feature weights in the configuration do not name
        exactly the state features of the datastore.
    """
    config = NeuralLAMConfig.from_yaml_file(config_path)
    # datastore config is assumed to be relative to the config file
    datastore_config_path = (
        Path(config_path).parent / config.datastore.config_path
    )
    datastore = init_datastore(
        datastore_kind=config.datastore.kind, config_path=datastore_config_path
    )

    # TODO: This check should maybe be moved somewhere else, but I'm not sure
    # where right now...
    _check_state_feature_weights(
        state_feature_names=datastore.get_vars_names(category="state"),
        state_feature_weights=config.training.state_feature_weights,
    )

    return config, datastore


def _check_state_feature_weights(state_feature_names, state_feature_weights):
    """
    Check that a weight is given for each state feature, and for no others.

    Parameters
    ----------
    state_feature_names
        Names of the state features in the datastore (assumed to be
        strings, so that they can be sorted for the error message).
    state_feature_weights : Dict[str, float]
        Weights from the configuration, keyed by state feature name.

    Raises
    ------
    ValueError
        If a state feature has no weight, or a weight is given for a
        feature that is not in the datastore.
    """
    expected = set(state_feature_names)
    provided = set(state_feature_weights)
    if provided == expected:
        return

    # Report only the mismatches that actually occur, in sorted order so
    # that the message is deterministic and never contains an empty `set()`
    problems = []
    missing_features = sorted(expected - provided)
    if missing_features:
        problems.append(f"Weights are missing for: {missing_features}.")
    additional_features = sorted(provided - expected)
    if additional_features:
        problems.append(
            "Weights are defined for features which are not in the "
            f"datastore: {additional_features}."
        )

    raise ValueError(
        "State feature weights must be provided for each state feature in "
        f"the datastore ({state_feature_names}). " + " ".join(problems)
    )
16 changes: 5 additions & 11 deletions neural_lam/create_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch_geometric.utils.convert import from_networkx

# Local
from .datastore import DATASTORES
from .config import load_config_and_datastore
from .datastore.base import BaseCartesianDatastore


Expand Down Expand Up @@ -551,15 +551,9 @@ def create_graph_from_datastore(
def cli(input_args=None):
parser = ArgumentParser(description="Graph generation arguments")
parser.add_argument(
"datastore",
"--config",
type=str,
choices=DATASTORES.keys(),
help="kind of data store to use",
)
parser.add_argument(
"datastore_config_path",
type=str,
help="path to the data store config",
default="tests/datastore_examples/mdp/config.yaml",
)
parser.add_argument(
"--name",
Expand All @@ -586,8 +580,8 @@ def cli(input_args=None):
)
args = parser.parse_args(input_args)

DatastoreClass = DATASTORES[args.datastore]
datastore = DatastoreClass(config_path=args.datastore_config_path)
# Load neural-lam configuration and datastore to use
_, datastore = load_config_and_datastore(config_path=args.config)

create_graph_from_datastore(
datastore=datastore,
Expand Down
25 changes: 11 additions & 14 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# Local
from . import utils
from .datastore import DATASTORES, init_datastore
from .config import load_config_and_datastore
from .models import GraphLAM, HiLAM, HiLAMParallel
from .weather_dataset import WeatherDataModule

Expand All @@ -28,15 +28,9 @@ def main(input_args=None):
description="Train or evaluate NeurWP models for LAM"
)
parser.add_argument(
"datastore_kind",
"--config",
type=str,
choices=DATASTORES.keys(),
help="Kind of datastore to use",
)
parser.add_argument(
"datastore_config_path",
type=str,
help="Path for the datastore config",
default="tests/datastore_examples/mdp/config.yaml",
)
parser.add_argument(
"--model",
Expand Down Expand Up @@ -226,11 +220,14 @@ def main(input_args=None):
# Set seed
seed.seed_everything(args.seed)

# Create datastore
datastore = init_datastore(
datastore_kind=args.datastore_kind,
config_path=args.datastore_config_path,
)
# Load neural-lam configuration and datastore to use
config, datastore = load_config_and_datastore(config_path=args.config)
# TODO: config.training.state_feature_weights need passing in somewhere,
# probably to ARModel, so that it can be used in the loss function
assert (
config.training.state_feature_weights
), "No state feature weights found in config"

# Create datamodule
data_module = WeatherDataModule(
datastore=datastore,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ known_first_party = [
# Add first-party modules that may be misclassified by isort
"neural_lam",
]
line_length = 80

[tool.flake8]
max-line-length = 80
Expand Down
7 changes: 7 additions & 0 deletions tests/datastore_examples/mdp/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
datastore:
kind: mdp
config_path: danra.example.yaml
training:
state_feature_weights:
u100m: 1.0
v100m: 1.0

0 comments on commit b57bc7a

Please sign in to comment.