diff --git a/create_mesh.py b/create_mesh.py index f04b4d4b..41557a97 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -153,7 +153,7 @@ def prepend_node_index(graph, new_index): return networkx.relabel_nodes(graph, to_mapping, copy=True) -def main(): +def main(input_args=None): parser = ArgumentParser(description="Graph generation arguments") parser.add_argument( "--data_config", @@ -186,7 +186,7 @@ def main(): default=0, help="Generate hierarchical mesh graph (default: 0, no)", ) - args = parser.parse_args() + args = parser.parse_args(input_args) # Load grid positions config_loader = config.Config.from_file(args.data_config) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 2b6abf15..8c9ca77c 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -87,7 +87,7 @@ def plot_prediction( 1, 2, figsize=(13, 7), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) # Plot pred and target @@ -136,7 +136,7 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None): fig, ax = plt.subplots( figsize=(5, 4.8), - subplot_kw={"projection": data_config.coords_projection()}, + subplot_kw={"projection": data_config.coords_projection}, ) ax.coastlines() # Add coastline outlines diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py index 0dd454bd..bd638c78 100644 --- a/tests/test_mllam_dataset.py +++ b/tests/test_mllam_dataset.py @@ -1,15 +1,19 @@ # Standard library -from pathlib import Path - -# Third-party -import numpy as np -import weather_model_graphs as wmg +import os # First-party +from create_mesh import main as create_mesh from neural_lam.config import Config from neural_lam.utils import load_static_data from neural_lam.weather_dataset import WeatherDataset -from train_model import main +from train_model import main as train_model + +# from pathlib import Path +# import numpy as np +# import weather_model_graphs as wmg + + +os.environ["WANDB_DISABLED"] = "true" def test_load_reduced_meps_dataset(): @@ -83,49 +87,59 @@ def test_load_reduced_meps_dataset(): assert set(static_data.keys()) == required_props -def test_create_graph_reduced_meps_dataset(): - dataset_name = "meps_example_reduced" - static_dir_path = Path("data", dataset_name, "static") - graph_dir_path = Path("graphs", "hierarchial") - - # -- Static grid node features -- - xy_grid = np.load(static_dir_path / "nwp_xy.npy") - - # create the full graph - graph = wmg.create.archetype.create_oscarsson_hierarchical_graph( - xy_grid=xy_grid - ) - - # split the graph by component - graph_components = wmg.split_graph_by_edge_attribute( - graph=graph, attr="component" - ) +# def test_create_wmg_graph_reduced_meps_dataset(): +# dataset_name = "meps_example_reduced" +# static_dir_path = Path("data", dataset_name, "static") +# graph_dir_path = Path("graphs", "hierarchial") + +# # -- Static grid node features -- +# xy_grid = np.load(static_dir_path / "nwp_xy.npy") + +# # create the full graph +# graph = wmg.create.archetype.create_oscarsson_hierarchical_graph( +# xy_grid=xy_grid +# ) + +# # split the graph by component +# graph_components = wmg.split_graph_by_edge_attribute( +# graph=graph, attr="component" +# ) + +# m2m_graph = graph_components.pop("m2m") +# m2m_graph_components = wmg.split_graph_by_edge_attribute( +# graph=m2m_graph, attr="direction" +# ) +# m2m_graph_components = { +# f"m2m_{name}": graph for name, graph in m2m_graph_components.items() +# } +# graph_components.update(m2m_graph_components) + +# # save the graph components to disk in pytorch-geometric format +# for component_name, graph_component in graph_components.items(): +# kwargs = {} +# wmg.save.to_pyg( +# graph=graph_component, +# name=component_name, +# output_directory=graph_dir_path, +# **kwargs, +# ) - m2m_graph = graph_components.pop("m2m") - m2m_graph_components = wmg.split_graph_by_edge_attribute( - graph=m2m_graph, attr="direction" - ) - m2m_graph_components = { - f"m2m_{name}": graph for name, graph in m2m_graph_components.items() - } - graph_components.update(m2m_graph_components) - # save the graph components to disk in pytorch-geometric format - for component_name, graph_component in graph_components.items(): - kwargs = {} - wmg.save.to_pyg( - graph=graph_component, - name=component_name, - output_directory=graph_dir_path, - **kwargs, - ) +def test_create_graph_reduced_meps_dataset(): + args = [ + "--graph=hierarchical", + "--hierarchical=1", + "--data_config=data/meps_example_reduced/data_config.yaml", + "--levels=2", + ] + create_mesh(args) def test_train_model_reduced_meps_dataset(): args = [ "--model=hi_lam", "--data_config=data/meps_example_reduced/data_config.yaml", - "--n_workers=1", + "--n_workers=4", "--epochs=1", "--graph=hierarchical", "--hidden_dim=16", @@ -133,6 +147,9 @@ def test_train_model_reduced_meps_dataset(): "--processor_layers=1", "--ar_steps=1", "--eval=val", - "--wandb_project=None", + "--n_example_pred=0", ] - main(args) + train_model(args) + + +test_train_model_reduced_meps_dataset()