Skip to content

Commit

Permalink
made create_mesh callable as python function with arguments.
Browse files Browse the repository at this point in the history
Fixed error in plotting where non-callable cartopy projection from Config was called
used current mesh generation from neural-lam instead of weather-model-graphs
finished test of training call
  • Loading branch information
SimonKamuk committed May 23, 2024
1 parent 7fa7cdd commit 569d061
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 47 deletions.
4 changes: 2 additions & 2 deletions create_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def prepend_node_index(graph, new_index):
return networkx.relabel_nodes(graph, to_mapping, copy=True)


def main():
def main(input_args=None):
parser = ArgumentParser(description="Graph generation arguments")
parser.add_argument(
"--data_config",
Expand Down Expand Up @@ -186,7 +186,7 @@ def main():
default=0,
help="Generate hierarchical mesh graph (default: 0, no)",
)
args = parser.parse_args()
args = parser.parse_args(input_args)

# Load grid positions
config_loader = config.Config.from_file(args.data_config)
Expand Down
4 changes: 2 additions & 2 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def plot_prediction(
1,
2,
figsize=(13, 7),
subplot_kw={"projection": data_config.coords_projection()},
subplot_kw={"projection": data_config.coords_projection},
)

# Plot pred and target
Expand Down Expand Up @@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):

fig, ax = plt.subplots(
figsize=(5, 4.8),
subplot_kw={"projection": data_config.coords_projection()},
subplot_kw={"projection": data_config.coords_projection},
)

ax.coastlines() # Add coastline outlines
Expand Down
103 changes: 60 additions & 43 deletions tests/test_mllam_dataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
# Standard library
from pathlib import Path

# Third-party
import numpy as np
import weather_model_graphs as wmg
import os

# First-party
from create_mesh import main as create_mesh
from neural_lam.config import Config
from neural_lam.utils import load_static_data
from neural_lam.weather_dataset import WeatherDataset
from train_model import main
from train_model import main as train_model

# from pathlib import Path
# import numpy as np
# import weather_model_graphs as wmg


os.environ["WANDB_DISABLED"] = "true"


def test_load_reduced_meps_dataset():
Expand Down Expand Up @@ -83,56 +87,69 @@ def test_load_reduced_meps_dataset():
assert set(static_data.keys()) == required_props


def test_create_graph_reduced_meps_dataset():
dataset_name = "meps_example_reduced"
static_dir_path = Path("data", dataset_name, "static")
graph_dir_path = Path("graphs", "hierarchial")

# -- Static grid node features --
xy_grid = np.load(static_dir_path / "nwp_xy.npy")

# create the full graph
graph = wmg.create.archetype.create_oscarsson_hierarchical_graph(
xy_grid=xy_grid
)

# split the graph by component
graph_components = wmg.split_graph_by_edge_attribute(
graph=graph, attr="component"
)
# def test_create_wmg_graph_reduced_meps_dataset():
# dataset_name = "meps_example_reduced"
# static_dir_path = Path("data", dataset_name, "static")
# graph_dir_path = Path("graphs", "hierarchial")

# # -- Static grid node features --
# xy_grid = np.load(static_dir_path / "nwp_xy.npy")

# # create the full graph
# graph = wmg.create.archetype.create_oscarsson_hierarchical_graph(
# xy_grid=xy_grid
# )

# # split the graph by component
# graph_components = wmg.split_graph_by_edge_attribute(
# graph=graph, attr="component"
# )

# m2m_graph = graph_components.pop("m2m")
# m2m_graph_components = wmg.split_graph_by_edge_attribute(
# graph=m2m_graph, attr="direction"
# )
# m2m_graph_components = {
# f"m2m_{name}": graph for name, graph in m2m_graph_components.items()
# }
# graph_components.update(m2m_graph_components)

# # save the graph components to disk in pytorch-geometric format
# for component_name, graph_component in graph_components.items():
# kwargs = {}
# wmg.save.to_pyg(
# graph=graph_component,
# name=component_name,
# output_directory=graph_dir_path,
# **kwargs,
# )

m2m_graph = graph_components.pop("m2m")
m2m_graph_components = wmg.split_graph_by_edge_attribute(
graph=m2m_graph, attr="direction"
)
m2m_graph_components = {
f"m2m_{name}": graph for name, graph in m2m_graph_components.items()
}
graph_components.update(m2m_graph_components)

# save the graph components to disk in pytorch-geometric format
for component_name, graph_component in graph_components.items():
kwargs = {}
wmg.save.to_pyg(
graph=graph_component,
name=component_name,
output_directory=graph_dir_path,
**kwargs,
)
def test_create_graph_reduced_meps_dataset():
args = [
"--graph=hierarchical",
"--hierarchical=1",
"--data_config=data/meps_example_reduced/data_config.yaml",
"--levels=2",
]
create_mesh(args)


def test_train_model_reduced_meps_dataset():
args = [
"--model=hi_lam",
"--data_config=data/meps_example_reduced/data_config.yaml",
"--n_workers=1",
"--n_workers=4",
"--epochs=1",
"--graph=hierarchical",
"--hidden_dim=16",
"--hidden_layers=1",
"--processor_layers=1",
"--ar_steps=1",
"--eval=val",
"--wandb_project=None",
"--n_example_pred=0",
]
main(args)
train_model(args)


test_train_model_reduced_meps_dataset()

0 comments on commit 569d061

Please sign in to comment.