Skip to content

Commit

Permalink
Refactor Graph class
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoOlivo committed Dec 5, 2024
1 parent e0e9364 commit e1bfadc
Showing 1 changed file with 43 additions and 105 deletions.
148 changes: 43 additions & 105 deletions pina/graph.py
Original file line number Diff line number Diff line change
@@ -1,118 +1,56 @@
""" Module for Loss class """

import logging
from torch_geometric.nn import MessagePassing, InstanceNorm, radius_graph
from torch_geometric.data import Data
import torch
from . import LabelTensor
from torch_geometric.nn import radius_graph
from torch_geometric.data import Data

class Graph:
"""
PINA Graph managing the PyG Data class.
"""
def __init__(self, data):
self.data = data

@staticmethod
def _build_triangulation(**kwargs):
logging.debug("Creating graph with triangulation mode.")

# check for mandatory arguments
if "nodes_coordinates" not in kwargs:
raise ValueError("Nodes coordinates must be provided in the kwargs.")
if "nodes_data" not in kwargs:
raise ValueError("Nodes data must be provided in the kwargs.")
if "triangles" not in kwargs:
raise ValueError("Triangles must be provided in the kwargs.")

nodes_coordinates = kwargs["nodes_coordinates"]
nodes_data = kwargs["nodes_data"]
triangles = kwargs["triangles"]

def __init__(self, x=None, pos=None, edge_index=None, edge_attr=None, **kwargs):
if isinstance(x, torch.Tensor):
self.size_x = x.size(0)

if isinstance(pos, torch.Tensor):
self.size_pos = pos.size(0)
self.data = None
if x is not None and pos is not None:
self.build_graphs_list(x, pos, **kwargs)

def less_first(a, b):
return [a, b] if a < b else [b, a]

list_of_edges = []

for triangle in triangles:
for e1, e2 in [[0, 1], [1, 2], [2, 0]]:
list_of_edges.append(less_first(triangle[e1],triangle[e2]))

array_of_edges = torch.unique(torch.Tensor(list_of_edges), dim=0) # remove duplicates
array_of_edges = array_of_edges.t().contiguous()
print(array_of_edges)

# list_of_lengths = []

# for p1,p2 in array_of_edges:
# x1, y1 = tri.points[p1]
# x2, y2 = tri.points[p2]
# list_of_lengths.append((x1-x2)**2 + (y1-y2)**2)

# array_of_lengths = np.sqrt(np.array(list_of_lengths))

# return array_of_edges, array_of_lengths
def build_graphs_list(self, x, pos, method='radius',
build_edge_attr=False, **kwargs):
"""
Build the graph from the node features and the node positions.
"""
if isinstance(x, list) and isinstance(pos, list):
if len(x) != len(pos):
raise ValueError("The number of node features and node positions"
" must be the same.")
if isinstance(x, (torch.Tensor, LabelTensor)) and isinstance(
pos, list):
x = [x] * len(pos) # Copy just the reference
if isinstance(pos, (torch.Tensor, LabelTensor)):
edge_idx = [self._build_edge_index(pos, method, **kwargs)] * len(x)
else:
edge_idx = [self._build_edge_index(p, method, **kwargs) for p in pos]
if build_edge_attr is not None:
edge_attr = [self._build_edge_attr(p, e) for p, e in zip(pos, edge_idx)]
else:
edge_attr = [None] * len(x)

return Data(
x=nodes_data,
pos=nodes_coordinates.T,

edge_index=array_of_edges,
)
graphs = []
for i in range(len(x)):
graphs.append(Data(x=x[i], pos=pos[i], edge_index=edge_idx[i],
edge_attr=edge_attr[i]))
self.data = graphs

@staticmethod
def _build_radius(**kwargs):
logging.debug("Creating graph with radius mode.")

# check for mandatory arguments
if "nodes_coordinates" not in kwargs:
raise ValueError("Nodes coordinates must be provided in the kwargs.")
if "nodes_data" not in kwargs:
raise ValueError("Nodes data must be provided in the kwargs.")
if "radius" not in kwargs:
raise ValueError("Radius must be provided in the kwargs.")

nodes_coordinates = kwargs["nodes_coordinates"]
nodes_data = kwargs["nodes_data"]
radius = kwargs["radius"]

edges_data = kwargs.get("edge_data", None)
loop = kwargs.get("loop", False)
batch = kwargs.get("batch", None)

logging.debug(f"radius: {radius}, loop: {loop}, "
f"batch: {batch}")

edge_index = radius_graph(
x=nodes_coordinates.tensor,
r=radius,
loop=loop,
batch=batch,
)

logging.debug(f"edge_index computed")
return Data(
x=nodes_data.tensor,
pos=nodes_coordinates.tensor,
edge_index=edge_index,
edge_attr=edges_data,
)
def _build_edge_index(pos, method, **kwargs):
if method == 'radius':
return radius_graph(pos, **kwargs)
else:
raise ValueError("The method must be 'radius'.")

@staticmethod
def build(mode, **kwargs):
"""
Constructor for the `Graph` class.
"""
if mode == "radius":
graph = Graph._build_radius(**kwargs)
elif mode == "triangulation":
graph = Graph._build_triangulation(**kwargs)
else:
raise ValueError(f"Mode {mode} not recognized")

return Graph(graph)
def _build_edge_attr(pos, edge_index,):
return torch.norm((pos[edge_index[0]] - pos[edge_index[1]]), dim=-1)


def __repr__(self):
return f"Graph(data={self.data})"

0 comments on commit e1bfadc

Please sign in to comment.