Skip to content

Commit

Permalink
Add Graph support in Dataset and Dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoOlivo committed Oct 23, 2024
1 parent 7dee77e commit 5970eee
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 48 deletions.
8 changes: 3 additions & 5 deletions pina/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.utils.data import Dataset
import torch
from ..label_tensor import LabelTensor
from ..graph import Graph


class BaseDataset(Dataset):
Expand Down Expand Up @@ -47,16 +48,14 @@ def __init__(self, problem, device):
for name, data in collector.data_collections.items():
keys = []
for k, v in data.items():
if isinstance(v, LabelTensor):
if isinstance(v, LabelTensor) or (isinstance(v, list) and all(isinstance(i, Graph) for i in v)):
keys.append(k)
if sorted(self.__slots__) == sorted(keys):

for slot in self.__slots__:
current_list = getattr(self, slot)
current_list.append(data[slot])
self.condition_names[idx] = name
idx += 1

if len(getattr(self, self.__slots__[0])) > 0:
input_list = getattr(self, self.__slots__[0])
self.condition_indices = torch.cat(
Expand Down Expand Up @@ -89,11 +88,10 @@ def __getattribute__(self, item):
def __getitem__(self, idx):
if isinstance(idx, str):
return getattr(self, idx).to(self.device)

if isinstance(idx, slice):
to_return_list = []
for i in self.__slots__:
to_return_list.append(getattr(self, i)[[idx]].to(self.device))
to_return_list.append(getattr(self, i)[idx].to(self.device))
return to_return_list

if isinstance(idx, (tuple, list)):
Expand Down
3 changes: 2 additions & 1 deletion pina/data/pina_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

class Batch:
"""
Implementation of the Batch class used during training to perform SGD optimization.
Implementation of the Batch class used during training to perform SGD
optimization.
"""

def __init__(self, dataset_dict, idx_dict):
Expand Down
7 changes: 6 additions & 1 deletion pina/data/pina_subset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Module for PinaSubset class
"""
from pina import LabelTensor


class PinaSubset:
Expand All @@ -23,4 +24,8 @@ def __len__(self):
return len(self.indices)

def __getattr__(self, name):
return self.dataset.__getattribute__(name)
tensor = self.dataset.__getattribute__(name)
if isinstance(tensor, LabelTensor):
return tensor[self.indices]
if isinstance(tensor, list):
return [tensor[i] for i in self.indices]
3 changes: 2 additions & 1 deletion pina/data/supervised_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

class SupervisedDataset(BaseDataset):
"""
This class extends the BaseDataset to handle datasets that consist of input-output pairs.
This class extends the BaseDataset to handle datasets that consist of
input-output pairs.
"""
data_type = 'supervised'
__slots__ = ['input_points', 'output_points']
110 changes: 70 additions & 40 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import math
import torch
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, UnsupervisedDataset, unsupervised_dataset
from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, \
UnsupervisedDataset
from pina.data import PinaDataLoader
from pina import LabelTensor, Condition
from pina.equation import Equation
from pina.domain import CartesianDomain
from pina.problem import SpatialProblem
from pina.problem import SpatialProblem, AbstractProblem
from pina.operators import laplacian
from pina.equation.equation_factory import FixedValue
from pina.graph import Graph



def laplace_equation(input_, output_):
Expand All @@ -30,49 +33,49 @@ class Poisson(SpatialProblem):

conditions = {
'gamma1':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 1
}),
equation=FixedValue(0.0)),
'gamma2':
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': [0, 1],
'y': 0
}),
equation=FixedValue(0.0)),
'gamma3':
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': 1,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'gamma4':
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
Condition(domain=CartesianDomain({
'x': 0,
'y': [0, 1]
}),
equation=FixedValue(0.0)),
'D':
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
['x', 'y']),
equation=my_laplace),
Condition(input_points=LabelTensor(torch.rand(size=(100, 2)),
['x', 'y']),
equation=my_laplace),
'data':
Condition(input_points=in_, output_points=out_),
Condition(input_points=in_, output_points=out_),
'data2':
Condition(input_points=in2_, output_points=out2_),
Condition(input_points=in2_, output_points=out2_),
'unsupervised':
Condition(
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
['alpha']),
),
Condition(
input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(45, 1)),
['alpha']),
),
'unsupervised2':
Condition(
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
['alpha']),
)
Condition(
input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']),
conditional_variables=LabelTensor(torch.ones(size=(90, 1)),
['alpha']),
)
}


Expand All @@ -98,8 +101,8 @@ def test_data():
assert dataset.input_points.shape == (61, 2)
assert dataset['input_points'].labels == ['x', 'y']
assert dataset.input_points.labels == ['x', 'y']
assert dataset['input_points', 3:].shape == (58, 2)
assert dataset[3:][1].labels == ['u']
assert dataset.input_points[3:].shape == (58, 2)
assert dataset.output_points[:3].labels == ['u']
assert dataset.output_points.shape == (61, 1)
assert dataset.output_points.labels == ['u']
assert dataset.condition_indices.dtype == torch.uint8
Expand Down Expand Up @@ -192,5 +195,32 @@ def test_loader():
assert i.physics.input_points.requires_grad == True
assert i.unsupervised.input_points.requires_grad == True

coordinates = LabelTensor(torch.rand((100, 100, 2)), labels=['x', 'y'])
data = LabelTensor(torch.rand((100, 100, 3)), labels=['ux', 'uy', 'p'])
class GraphProblem(AbstractProblem):

output = LabelTensor(torch.rand((100, 3)), labels=['ux', 'uy', 'p'])
input = [Graph.build('radius',
nodes_coordinates=coordinates[i,:,:],
nodes_data=data[i, :, :], radius=0.2)
for i in
range(100)]
output_variables = ['u']

conditions = {
'graph_data': Condition(input_points=input, output_points=output)
}


graph_problem = GraphProblem()


def test_loader_graph():
data_module = PinaDataModule(graph_problem, device='cpu', batch_size=10)
data_module.setup()
loader = data_module.train_dataloader()
for i in loader:
assert len(i) <= 10
assert isinstance(i.supervised.input_points, list)
assert all(isinstance(x, Graph) for x in i.supervised.input_points)

test_loader()

0 comments on commit 5970eee

Please sign in to comment.