Skip to content

Commit

Permalink
Start building graphs with wmg
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 4, 2024
1 parent 365675d commit 5a65ce9
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 6 deletions.
153 changes: 153 additions & 0 deletions neural_lam/build_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Standard library
import argparse
import os

# Third-party
import numpy as np
import weather_model_graphs as wmg

# Local
from . import config, utils

WMG_ARCHETYPES = {
"keisler": wmg.create.archetype.create_keisler_graph,
"graphcast": wmg.create.archetype.create_graphcast_graph,
"hierarchical": wmg.create.archetype.create_oskarsson_hierarchical_graph,
}


def main():
parser = argparse.ArgumentParser(
description="Graph generation using WMG",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Inputs and outputs
parser.add_argument(
"--data_config",
type=str,
default="neural_lam/data_config.yaml",
help="Path to data config file",
)
parser.add_argument(
"--output_dir",
type=str,
default="graphs",
help="Directory to save graph to",
)

# Graph structure
parser.add_argument(
"--archetype",
type=str,
default="keisler",
help="Archetype to use to create graph (keisler/graphcast/hierarchical)",
)
parser.add_argument(
"--mesh_node_distance",
type=float,
default=3.0,
help="Distance between created mesh nodes",
)
parser.add_argument(
"--level_refinement_factor",
type=float,
default=3,
help="Refinement factor between grid points and bottom level of mesh hierarchy",
)
parser.add_argument(
"--max_num_levels",
type=int,
help="Limit multi-scale mesh to given number of levels, "
"from bottom up",
)
parser.add_argument(
"--hierarchical",
action="store_true",
help="Generate hierarchical mesh graph (default: False)",
)
args = parser.parse_args()

# Load grid positions
config_loader = config.Config.from_file(args.data_config)

coords = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy()
# (num_nodes_full, 2)

# Construct mask
static_data = utils.load_static_data(config_loader.dataset.name)
decode_mask = np.concatenate(
(
np.ones(static_data["grid_static_features"].shape[0], dtype=bool),
np.zeros(
static_data["boundary_static_features"].shape[0], dtype=bool
),
),
axis=0,
)

# Build graph
assert (
args.archetype in WMG_ARCHETYPES
), f"Unknown archetype: {args.archetype}"
archetype_create_func = WMG_ARCHETYPES[args.archetype]

create_kwargs = {
"coords": coords,
"mesh_node_distance": args.mesh_node_distance,
"projection": None,
"decode_mask": decode_mask,
}
if args.archetype != "keisler":
# Add additional multi-level kwargs
create_kwargs.update(
{
"level_refinement_factor": args.level_refinement_factor,
"max_num_levels": args.max_num_levels,
}
)

graph = archetype_create_func(**create_kwargs)
graph_comp = wmg.split_graph_by_edge_attribute(graph, attr="component")

print("Created graph:")
for name, subgraph in graph_comp.items():
print(f"{name}: {subgraph}")

# Save graph
os.makedirs(args.output_dir, exist_ok=True)
for component, graph in graph_comp.items():
# TODO This is all hack, saving in wmg needs to be consistent with nl
if component == "m2m":
if args.archetype == "hierarchical":
# Split by direction
m2m_direction_comp = wmg.split_graph_by_edge_attribute(
graph, attr="direction"
)
for direction, graph in m2m_direction_comp.items():
wmg.save.to_pyg(
graph=graph,
name=f"mesh_{direction}",
list_from_attribute="level",
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
)
else:
wmg.save.to_pyg(
graph=graph,
name=component,
list_from_attribute="dummy",
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
)
else:
wmg.save.to_pyg(
graph=graph,
name=component,
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
)


if __name__ == "__main__":
main()
18 changes: 17 additions & 1 deletion neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def loads_file(fn):

# Load static node features
mesh_static_features = loads_file(
"mesh_features.pt"
"m2m_node_features.pt"
) # List of (N_mesh[l], d_mesh_static)

# Some checks for consistency
Expand Down Expand Up @@ -281,3 +281,19 @@ def init_wandb_metrics(wandb_logger, val_steps):
experiment.define_metric("val_mean_loss", summary="min")
for step in val_steps:
experiment.define_metric(f"val_loss_unroll{step}", summary="min")


def get_reordered_grid_pos(dataset_name, device="cpu"):
"""
Interior nodes first, then boundary
"""
static_data = load_static_data(dataset_name)

return torch.cat(
(
static_data["grid_static_features"][:, :2],
static_data["boundary_static_features"][:, :2],
),
dim=0,
)
# (num_total_grid_nodes, 2)
6 changes: 1 addition & 5 deletions plot_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ def main():
)
mesh_static_features = graph_ldict["mesh_static_features"]

grid_static_features = utils.load_static_data(config_loader.dataset.name)[
"grid_static_features"
]

# Extract values needed, turn to numpy
grid_pos = grid_static_features[:, :2].numpy()
grid_pos = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy()
# Add in z-dimension
z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],))
grid_pos = np.concatenate(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"plotly>=5.15.0",
"torch>=2.3.0",
"torch-geometric==2.3.1",
"weather-model-graphs>=0.2.0"
]
requires-python = ">=3.9"

Expand Down

0 comments on commit 5a65ce9

Please sign in to comment.