Skip to content

Commit

Permalink
Implementation of DataLoader and DataModule (#383)
Browse files Browse the repository at this point in the history
Refactoring for 0.2
* Data module, data loader and dataset
* Refactor LabelTensor
* Refactor solvers 

Co-authored-by: dario-coscia <[email protected]>
  • Loading branch information
FilippoOlivo and dario-coscia authored Nov 27, 2024
1 parent dbb5476 commit 8f458d2
Show file tree
Hide file tree
Showing 35 changed files with 814 additions and 1,792 deletions.
8 changes: 4 additions & 4 deletions pina/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
"Trainer", "LabelTensor", "Plotter", "Condition",
"PinaDataModule", 'TorchOptimizer', 'Graph',
]

from .meta import *
Expand All @@ -9,9 +9,9 @@
from .trainer import Trainer
from .plotter import Plotter
from .condition.condition import Condition
from .data import SamplePointDataset

from .data import PinaDataModule
from .data import PinaDataLoader

from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph
18 changes: 12 additions & 6 deletions pina/collector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import LabelTensor
from .utils import check_consistency, merge_tensors


Expand Down Expand Up @@ -66,9 +67,12 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
for loc in sample_locations:
# get condition
condition = self.problem.conditions[loc]
condition_domain = condition.domain
if isinstance(condition_domain, str):
condition_domain = self.problem.domains[condition_domain]
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if (not self._is_conditions_ready[loc]):
if not self._is_conditions_ready[loc]:
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
Expand All @@ -84,10 +88,11 @@ def store_sample_domains(self, n, mode, variables, sample_locations):

# get the samples
samples = [
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
condition_domain.sample(n=n, mode=mode,
variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
if set(pts.labels).issubset(sorted(self.problem.input_variables)):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
Expand All @@ -110,5 +115,6 @@ def add_points(self, new_points_dict):
if not self._is_conditions_ready[k]:
raise RuntimeError(
'Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k][
'input_points'].vstack(v)
self.data_collections[k]['input_points'] = LabelTensor.vstack(
[self.data_collections[k][
'input_points'], v])
3 changes: 1 addition & 2 deletions pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ class DataConditionInterface(ConditionInterface):

def __init__(self, input_points, conditional_variables=None):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.conditional_variables = conditional_variables
self._condition_type = 'unsupervised'

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
Expand Down
5 changes: 2 additions & 3 deletions pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@ class DomainEquationCondition(ConditionInterface):

def __init__(self, domain, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.domain = domain
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'domain':
check_consistency(value, (DomainInterface))
check_consistency(value, (DomainInterface, str))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key == 'equation':
check_consistency(value, (EquationInterface))
Expand Down
3 changes: 1 addition & 2 deletions pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ class InputPointsEquationCondition(ConditionInterface):

def __init__(self, input_points, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'input_points':
Expand Down
6 changes: 3 additions & 3 deletions pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch_geometric

from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
Expand All @@ -16,16 +17,15 @@ class InputOutputPointsCondition(ConditionInterface):

def __init__(self, input_points, output_points):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.output_points = output_points
self._condition_type = ['supervised', 'physics']

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
check_consistency(value, (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data))
InputOutputPointsCondition.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
13 changes: 5 additions & 8 deletions pina/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
Import data classes
"""
__all__ = [
'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
'PinaDataModule',
'PinaDataset'
]

from .pina_dataloader import PinaDataLoader
from .supervised_dataset import SupervisedDataset
from .sample_dataset import SamplePointDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_batch import Batch


from .data_module import PinaDataModule
from .base_dataset import BaseDataset
from .dataset import PinaDataset
157 changes: 0 additions & 157 deletions pina/data/base_dataset.py

This file was deleted.

Loading

0 comments on commit 8f458d2

Please sign in to comment.