Skip to content

Commit

Permalink
wip to make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 11, 2024
1 parent aeb5403 commit 42be03f
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 53 deletions.
35 changes: 21 additions & 14 deletions neural_lam/build_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
}


def main():
def main(input_args=None):
parser = argparse.ArgumentParser(
description="Graph generation using WMG",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
Expand Down Expand Up @@ -61,16 +61,12 @@ def main():
help="Limit multi-scale mesh to given number of levels, "
"from bottom up",
)
parser.add_argument(
"--hierarchical",
action="store_true",
help="Generate hierarchical mesh graph (default: False)",
)
args = parser.parse_args()
args = parser.parse_args(input_args)

# Load grid positions
config_loader = config.Config.from_file(args.data_config)

# TODO Do not get normalised positions
coords = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy()
# (num_nodes_full, 2)

Expand Down Expand Up @@ -126,13 +122,24 @@ def main():
graph, attr="direction"
)
for direction, graph in m2m_direction_comp.items():
wmg.save.to_pyg(
graph=graph,
name=f"mesh_{direction}",
list_from_attribute="level",
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
)
if direction == "same":
# Name just m2m to be consistent with non-hierarchical
wmg.save.to_pyg(
graph=graph,
name="m2m",
list_from_attribute="level",
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
)
else:
# up and down directions
wmg.save.to_pyg(
graph=graph,
name=f"mesh_{direction}",
list_from_attribute="levels",
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
)
else:
wmg.save.to_pyg(
graph=graph,
Expand Down
6 changes: 3 additions & 3 deletions neural_lam/interaction_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(
"""
Create a new InteractionNet
edge_index: (2,M), Edges in pyg format
edge_index: (2,M), Edges in pyg format, with boeth sender and receiver
node indices starting at 0
input_dim: Dimensionality of input representations,
for both nodes and edges
update_edges: If new edge representations should be computed
Expand All @@ -52,8 +53,7 @@ def __init__(
# Default to input dim if not explicitly given
hidden_dim = input_dim

# Make both sender and receiver indices of edge_index start at 0
edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0]
# any edge_index used here must start sender and rec. nodes at index 0
# Store number of receiver nodes according to edge_index
self.num_rec = edge_index[1].max() + 1
edge_index[0] = (
Expand Down
26 changes: 23 additions & 3 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ def __iter__(self):
return (self[i] for i in range(len(self)))


def zero_index_edge_index(edge_index):
"""
Make both sender and receiver indices of edge_index start at 0
"""
return edge_index - edge_index.min(dim=1, keepdim=True)[0]


def load_graph(graph_name, device="cpu"):
"""
Load all tensors representing the graph
Expand All @@ -128,11 +135,16 @@ def loads_file(fn):

# Load edges (edge_index)
m2m_edge_index = BufferList(
loads_file("m2m_edge_index.pt"), persistent=False
[zero_index_edge_index(ei) for ei in loads_file("m2m_edge_index.pt")],
persistent=False,
) # List of (2, M_m2m[l])
g2m_edge_index = loads_file("g2m_edge_index.pt") # (2, M_g2m)
m2g_edge_index = loads_file("m2g_edge_index.pt") # (2, M_m2g)

# Change first indices to 0
g2m_edge_index = zero_index_edge_index(g2m_edge_index)
m2g_edge_index = zero_index_edge_index(m2g_edge_index)

n_levels = len(m2m_edge_index)
hierarchical = n_levels > 1 # Nor just single level mesh graph

Expand Down Expand Up @@ -168,10 +180,18 @@ def loads_file(fn):
if hierarchical:
# Load up and down edges and features
mesh_up_edge_index = BufferList(
loads_file("mesh_up_edge_index.pt"), persistent=False
[
zero_index_edge_index(ei)
for ei in loads_file("mesh_up_edge_index.pt")
],
persistent=False,
) # List of (2, M_up[l])
mesh_down_edge_index = BufferList(
loads_file("mesh_down_edge_index.pt"), persistent=False
[
zero_index_edge_index(ei)
for ei in loads_file("mesh_down_edge_index.pt")
],
persistent=False,
) # List of (2, M_down[l])

mesh_up_features = loads_file(
Expand Down
115 changes: 91 additions & 24 deletions plot_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def main():

# Load graph data
hierarchical, graph_ldict = utils.load_graph(args.graph)
(g2m_edge_index, m2g_edge_index, m2m_edge_index,) = (
(
g2m_edge_index,
m2g_edge_index,
m2m_edge_index,
) = (
graph_ldict["g2m_edge_index"],
graph_ldict["m2g_edge_index"],
graph_ldict["m2m_edge_index"],
Expand All @@ -66,11 +70,9 @@ def main():
(grid_pos, np.expand_dims(z_grid, axis=1)), axis=1
)

# List of edges to plot, (edge_index, color, line_width, label)
edge_plot_list = [
(m2g_edge_index.numpy(), "black", 0.4, "M2G"),
(g2m_edge_index.numpy(), "black", 0.4, "G2M"),
]
# List of edges to plot, (edge_index, from_pos, to_pos, color,
# line_width, label)
edge_plot_list = []

# Mesh positioning and edges to plot differ if we have a hierarchical graph
if hierarchical:
Expand All @@ -89,24 +91,80 @@ def main():
mesh_static_features, start=1
)
]
mesh_pos = np.concatenate(mesh_level_pos, axis=0)
all_mesh_pos = np.concatenate(mesh_level_pos, axis=0)
grid_con_mesh_pos = mesh_level_pos[0]

# Add inter-level mesh edges
edge_plot_list += [
(level_ei.numpy(), "blue", 1, f"M2M Level {level}")
for level, level_ei in enumerate(m2m_edge_index)
(
level_ei.numpy(),
level_pos,
level_pos,
"blue",
1,
f"M2M Level {level}",
)
for level, (level_ei, level_pos) in enumerate(
zip(m2m_edge_index, mesh_level_pos)
)
]

# Add intra-level mesh edges
up_edges_ei = np.concatenate(
[level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1
up_edges_ei = [
level_up_ei.numpy() for level_up_ei in mesh_up_edge_index
]
down_edges_ei = [
level_down_ei.numpy() for level_down_ei in mesh_down_edge_index
]
# Add up edges
for level_i, (up_ei, from_pos, to_pos) in enumerate(
zip(up_edges_ei, mesh_level_pos[:-1], mesh_level_pos[1:])
):
edge_plot_list.append(
(
up_ei,
from_pos,
to_pos,
"green",
1,
f"Mesh up {level_i}-{level_i+1}",
)
)
# Add down edges
for level_i, (down_ei, from_pos, to_pos) in enumerate(
zip(down_edges_ei, mesh_level_pos[1:], mesh_level_pos[:-1])
):
edge_plot_list.append(
(
down_ei,
from_pos,
to_pos,
"green",
1,
f"Mesh down {level_i+1}-{level_i}",
)
)

edge_plot_list.append(
(
m2g_edge_index.numpy(),
grid_con_mesh_pos,
grid_pos,
"black",
0.4,
"M2G",
)
)
down_edges_ei = np.concatenate(
[level_down_ei.numpy() for level_down_ei in mesh_down_edge_index],
axis=1,
edge_plot_list.append(
(
g2m_edge_index.numpy(),
grid_pos,
grid_con_mesh_pos,
"black",
0.4,
"G2M",
)
)
edge_plot_list.append((up_edges_ei, "green", 1, "Mesh up"))
edge_plot_list.append((down_edges_ei, "green", 1, "Mesh down"))

mesh_node_size = 2.5
else:
Expand All @@ -120,21 +178,30 @@ def main():
(mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1
)

edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M"))
edge_plot_list.append(
(m2m_edge_index.numpy(), mesh_pos, mesh_pos, "blue", 1, "M2M")
)
edge_plot_list.append(
(m2g_edge_index.numpy(), mesh_pos, grid_pos, "black", 0.4, "M2G")
)
edge_plot_list.append(
(g2m_edge_index.numpy(), grid_pos, mesh_pos, "black", 0.4, "G2M")
)

# All node positions in one array
node_pos = np.concatenate((mesh_pos, grid_pos), axis=0)
all_mesh_pos = mesh_pos

# Add edges
data_objs = []
for (
ei,
from_pos,
to_pos,
col,
width,
label,
) in edge_plot_list:
edge_start = node_pos[ei[0]] # (M, 2)
edge_end = node_pos[ei[1]] # (M, 2)
edge_start = from_pos[ei[0]] # (M, 2)
edge_end = to_pos[ei[1]] # (M, 2)
n_edges = edge_start.shape[0]

x_edges = np.stack(
Expand Down Expand Up @@ -171,9 +238,9 @@ def main():
)
data_objs.append(
go.Scatter3d(
x=mesh_pos[:, 0],
y=mesh_pos[:, 1],
z=mesh_pos[:, 2],
x=all_mesh_pos[:, 0],
y=all_mesh_pos[:, 1],
z=all_mesh_pos[:, 2],
mode="markers",
marker={"color": "blue", "size": mesh_node_size},
name="Mesh nodes",
Expand Down
26 changes: 17 additions & 9 deletions tests/test_mllam_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# First-party
from neural_lam.config import Config
from neural_lam.create_mesh import main as create_mesh
from neural_lam.build_graph import main as build_graph
from neural_lam.train_model import main as train_model
from neural_lam.utils import load_static_data
from neural_lam.weather_dataset import WeatherDataset
Expand Down Expand Up @@ -66,14 +66,15 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
n_state_features = len(var_names)
n_prediction_timesteps = dataset.sample_length - n_input_steps

nx, ny = config.values["grid_shape_state"]
n_grid = nx * ny
static_data = load_static_data(dataset_name)
n_grid = static_data["interior_mask"].sum().item()
n_boundary = static_data["boundary_mask"].sum().item()

# check that the dataset is not empty
assert len(dataset) > 0

# get the first item
init_states, target_states, forcing = dataset[0]
init_states, target_states, forcing, boundary_forcing = dataset[0]

# check that the shapes of the tensors are correct
assert init_states.shape == (n_input_steps, n_grid, n_state_features)
Expand All @@ -87,6 +88,11 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
n_grid,
n_forcing_features,
)
assert boundary_forcing.shape == (
n_prediction_timesteps,
n_boundary,
2 * n_grid + n_forcing_features, # TODO Adjust dimensionality
)

static_data = load_static_data(dataset_name=dataset_name)

Expand Down Expand Up @@ -117,12 +123,14 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath):

def test_create_graph_reduced_meps_dataset():
args = [
"--graph=hierarchical",
"--hierarchical",
"--output_dir=graphs/reduced_meps_hierarchical",
"--archetype=hierarchical",
"--data_config=data/meps_example_reduced/data_config.yaml",
"--levels=2",
"--max_num_levels=2",
"--mesh_node_distance=0.05",
# Distance for normalized data, might need adjustment
]
create_mesh(args)
build_graph(args)


def test_train_model_reduced_meps_dataset():
Expand All @@ -131,7 +139,7 @@ def test_train_model_reduced_meps_dataset():
"--data_config=data/meps_example_reduced/data_config.yaml",
"--n_workers=4",
"--epochs=1",
"--graph=hierarchical",
"--graph=reduced_meps_hierarchical",
"--hidden_dim=16",
"--hidden_layers=1",
"--processor_layers=1",
Expand Down

0 comments on commit 42be03f

Please sign in to comment.