From 5a65ce9c2684f8a426b9cb248a310d38e0857224 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Mon, 4 Nov 2024 18:18:34 +0100 Subject: [PATCH] Start building graphs with wmg --- neural_lam/build_graph.py | 153 ++++++++++++++++++++++++++++++++++++++ neural_lam/utils.py | 18 ++++- plot_graph.py | 6 +- pyproject.toml | 1 + 4 files changed, 172 insertions(+), 6 deletions(-) create mode 100644 neural_lam/build_graph.py diff --git a/neural_lam/build_graph.py b/neural_lam/build_graph.py new file mode 100644 index 00000000..034f82cd --- /dev/null +++ b/neural_lam/build_graph.py @@ -0,0 +1,153 @@ +# Standard library +import argparse +import os + +# Third-party +import numpy as np +import weather_model_graphs as wmg + +# Local +from . import config, utils + +WMG_ARCHETYPES = { + "keisler": wmg.create.archetype.create_keisler_graph, + "graphcast": wmg.create.archetype.create_graphcast_graph, + "hierarchical": wmg.create.archetype.create_oskarsson_hierarchical_graph, +} + + +def main(): + parser = argparse.ArgumentParser( + description="Graph generation using WMG", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Inputs and outputs + parser.add_argument( + "--data_config", + type=str, + default="neural_lam/data_config.yaml", + help="Path to data config file", + ) + parser.add_argument( + "--output_dir", + type=str, + default="graphs", + help="Directory to save graph to", + ) + + # Graph structure + parser.add_argument( + "--archetype", + type=str, + default="keisler", + help="Archetype to use to create graph (keisler/graphcast/hierarchical)", + ) + parser.add_argument( + "--mesh_node_distance", + type=float, + default=3.0, + help="Distance between created mesh nodes", + ) + parser.add_argument( + "--level_refinement_factor", + type=float, + default=3, + help="Refinement factor between grid points and bottom level of mesh hierarchy", + ) + parser.add_argument( + "--max_num_levels", + type=int, + help="Limit multi-scale mesh to given number of levels, " + "from bottom up", + ) + parser.add_argument( + "--hierarchical", + action="store_true", + help="Generate hierarchical mesh graph (default: False)", + ) + args = parser.parse_args() + + # Load grid positions + config_loader = config.Config.from_file(args.data_config) + + coords = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy() + # (num_nodes_full, 2) + + # Construct mask + static_data = utils.load_static_data(config_loader.dataset.name) + decode_mask = np.concatenate( + ( + np.ones(static_data["grid_static_features"].shape[0], dtype=bool), + np.zeros( + static_data["boundary_static_features"].shape[0], dtype=bool + ), + ), + axis=0, + ) + + # Build graph + assert ( + args.archetype in WMG_ARCHETYPES + ), f"Unknown archetype: {args.archetype}" + archetype_create_func = WMG_ARCHETYPES[args.archetype] + + create_kwargs = { + "coords": coords, + "mesh_node_distance": args.mesh_node_distance, + "projection": None, + "decode_mask": decode_mask, + } + if args.archetype != "keisler": + # Add additional multi-level kwargs + create_kwargs.update( + { + "level_refinement_factor": args.level_refinement_factor, + "max_num_levels": args.max_num_levels, + } + ) + + graph = archetype_create_func(**create_kwargs) + graph_comp = wmg.split_graph_by_edge_attribute(graph, attr="component") + + print("Created graph:") + for name, subgraph in graph_comp.items(): + print(f"{name}: {subgraph}") + + # Save graph + os.makedirs(args.output_dir, exist_ok=True) + for component, graph in graph_comp.items(): + # TODO This is all hack, saving in wmg needs to be consistent with nl + if component == "m2m": + if args.archetype == "hierarchical": + # Split by direction + m2m_direction_comp = wmg.split_graph_by_edge_attribute( + graph, attr="direction" + ) + for direction, graph in m2m_direction_comp.items(): + wmg.save.to_pyg( + graph=graph, + name=f"mesh_{direction}", + list_from_attribute="level", + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) + else: + wmg.save.to_pyg( + graph=graph, + name=component, + list_from_attribute="dummy", + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) + else: + wmg.save.to_pyg( + graph=graph, + name=component, + edge_features=["len", "vdiff"], + output_directory=args.output_dir, + ) + + +if __name__ == "__main__": + main() diff --git a/neural_lam/utils.py b/neural_lam/utils.py index a414e357..2891e9b0 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -154,7 +154,7 @@ def loads_file(fn): # Load static node features mesh_static_features = loads_file( - "mesh_features.pt" + "m2m_node_features.pt" ) # List of (N_mesh[l], d_mesh_static) # Some checks for consistency @@ -281,3 +281,19 @@ def init_wandb_metrics(wandb_logger, val_steps): experiment.define_metric("val_mean_loss", summary="min") for step in val_steps: experiment.define_metric(f"val_loss_unroll{step}", summary="min") + + +def get_reordered_grid_pos(dataset_name, device="cpu"): + """ + Interior nodes first, then boundary + """ + static_data = load_static_data(dataset_name) + + return torch.cat( + ( + static_data["grid_static_features"][:, :2], + static_data["boundary_static_features"][:, :2], + ), + dim=0, + ) + # (num_total_grid_nodes, 2) diff --git a/plot_graph.py b/plot_graph.py index e47e62c0..46a63a7f 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -58,12 +58,8 @@ def main(): ) mesh_static_features = graph_ldict["mesh_static_features"] - grid_static_features = utils.load_static_data(config_loader.dataset.name)[ - "grid_static_features" - ] - # Extract values needed, turn to numpy - grid_pos = grid_static_features[:, :2].numpy() + grid_pos = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy() # Add in z-dimension z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) grid_pos = np.concatenate( diff --git a/pyproject.toml b/pyproject.toml index 14b7e69a..22a1cafe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "plotly>=5.15.0", "torch>=2.3.0", "torch-geometric==2.3.1", + "weather-model-graphs>=0.2.0" ] requires-python = ">=3.9"